syntaxdot_encoders/dependency/
encoder.rs1use std::fmt;
2
3use itertools::multizip;
4use ndarray::{s, ArrayView1};
5use numberer::Numberer;
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8use udgraph::graph::{DepTriple, Node, Sentence};
9use udgraph::token::Token;
10use udgraph::Error;
11
12use crate::categorical::{ImmutableNumberer, MutableNumberer, Number};
13
/// Encoded dependency structure of a single sentence.
///
/// `DependencyEncoder::encode` fills both vectors in lock-step, pushing one
/// entry per encoded token in sentence order, so `heads[i]` and
/// `relations[i]` describe the same token.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DependencyEncoding {
    /// Index of the head node of each token, as reported by the sentence's
    /// dependency graph.
    pub heads: Vec<usize>,

    /// Number that the encoder's relation numberer assigned to the
    /// dependency relation of each token.
    pub relations: Vec<usize>,
}
23
/// Errors that can occur while encoding the dependency structure of a
/// sentence.
///
/// Both variants carry the word forms of the sentence (`sent`) and the
/// position of the offending token within those forms (`token`), which the
/// `Display` implementation uses to print the sentence with that token
/// placed between brackets.
#[derive(Clone, Debug, Eq, Error, PartialEq)]
pub enum EncodeError {
    /// The token has no head in the dependency graph.
    MissingHead { token: usize, sent: Vec<String> },

    /// The token has a head, but the arc does not carry a dependency
    /// relation.
    MissingRelation { token: usize, sent: Vec<String> },
}
32
33impl EncodeError {
34 pub fn missing_head(token: usize, sentence: &Sentence) -> Self {
39 Self::MissingHead {
40 sent: Self::sentence_to_forms(sentence),
41 token: token - 1,
42 }
43 }
44
45 pub fn missing_relation(token: usize, sentence: &Sentence) -> Self {
50 Self::MissingRelation {
51 sent: Self::sentence_to_forms(sentence),
52 token: token - 1,
53 }
54 }
55
56 fn format_bracketed(bracket_idx: usize, tokens: &[String]) -> String {
57 let mut tokens = tokens.to_owned();
58 tokens.insert(bracket_idx + 1, "]".to_string());
59 tokens.insert(bracket_idx, "[".to_string());
60
61 tokens.join(" ")
62 }
63
64 fn sentence_to_forms(sentence: &Sentence) -> Vec<String> {
65 sentence
66 .iter()
67 .filter_map(Node::token)
68 .map(Token::form)
69 .map(ToOwned::to_owned)
70 .collect()
71 }
72}
73
74impl fmt::Display for EncodeError {
75 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
76 use EncodeError::*;
77
78 match self {
79 MissingHead { token, sent } => write!(
80 f,
81 "Token does not have a head:\n\n{}\n",
82 Self::format_bracketed(*token, sent),
83 ),
84 MissingRelation { token, sent } => write!(
85 f,
86 "Token does not have a dependency relation:\n\n{}\n",
87 Self::format_bracketed(*token, sent),
88 ),
89 }
90 }
91}
92
/// Encoder/decoder between the dependency arcs of a sentence and numeric
/// head/relation representations.
///
/// The type parameter `N` is the numberer that maps dependency relation
/// strings to numbers and back.
#[derive(Serialize, Deserialize)]
pub struct DependencyEncoder<N>
where
    N: Number<String>,
{
    // Mapping between dependency relation strings and their numbers.
    relations: N,
}
101
102impl<N> DependencyEncoder<N>
103where
104 N: Number<String>,
105{
106 pub fn encode(&self, sentence: &Sentence) -> Result<DependencyEncoding, EncodeError> {
110 let dep_graph = sentence.dep_graph();
111
112 let mut heads = Vec::with_capacity(sentence.len());
113 let mut relations = Vec::with_capacity(sentence.len());
114
115 for token_idx in 1..sentence.len() {
116 let head = dep_graph
117 .head(token_idx)
118 .ok_or_else(|| EncodeError::missing_head(token_idx, sentence))?
119 .head();
120 heads.push(head);
121
122 let relation = dep_graph
123 .head(token_idx)
124 .and_then(|triple| triple.relation().map(ToString::to_string))
125 .ok_or_else(|| EncodeError::missing_relation(token_idx, sentence))?;
126 relations.push(
127 self.relations
128 .number(relation.to_string())
129 .expect("Unknown dependency relation"),
130 );
131 }
132
133 Ok(DependencyEncoding { heads, relations })
134 }
135
136 pub fn decode(
146 &self,
147 sent_heads: ArrayView1<i64>,
148 best_pairwise_relations: ArrayView1<i32>,
149 sentence: &mut Sentence,
150 ) -> Result<(), Error> {
151 let heads = sent_heads.slice(s![1..]);
153
154 let relations = best_pairwise_relations
155 .into_iter()
156 .skip(1)
157 .cloned()
158 .collect::<Vec<_>>();
159
160 for (dep, &head, relation) in multizip((1..sentence.len(), heads, relations)) {
161 let relation = self
162 .relations
163 .value(relation as usize)
164 .unwrap_or_else(|| panic!("Predicted an unknown relation: {}", relation));
169 sentence
170 .dep_graph_mut()
171 .add_deprel::<String>(DepTriple::new(head as usize, Some(relation), dep))?;
172 }
173
174 Ok(())
175 }
176
177 pub fn n_relations(&self) -> usize {
178 self.relations.len()
179 }
180}
181
/// Dependency encoder backed by an `ImmutableNumberer` for relations.
pub type ImmutableDependencyEncoder = DependencyEncoder<ImmutableNumberer<String>>;

/// Dependency encoder backed by a `MutableNumberer` for relations.
pub type MutableDependencyEncoder = DependencyEncoder<MutableNumberer<String>>;
185
186impl Default for MutableDependencyEncoder {
187 fn default() -> Self {
188 DependencyEncoder {
189 relations: MutableNumberer::new(Numberer::new(0)),
190 }
191 }
192}
193
194impl MutableDependencyEncoder {
195 pub fn new() -> Self {
197 Default::default()
198 }
199}
200
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::BufReader;
    use std::iter::once;

    use conllu::io::Reader;
    use udgraph::graph::{DepTriple, Sentence};
    use udgraph::token::Token;

    use crate::dependency::{DependencyEncoding, EncodeError, MutableDependencyEncoder};
    use ndarray::Array1;

    static NON_PROJECTIVE_DATA: &str = "testdata/lassy-small-dev.conllu";

    /// Build the four-token example sentence, without any dependency arcs.
    fn example_sentence() -> Sentence {
        ["Ze", "koopt", "een", "auto"]
            .iter()
            .map(|form| Token::new(*form))
            .collect()
    }

    /// Add `(head, relation, dependent)` arcs to `sentence` in the given
    /// order, panicking when an arc is rejected.
    fn add_arcs(sentence: &mut Sentence, arcs: &[(usize, Option<&str>, usize)]) {
        for &(head, relation, dependent) in arcs {
            sentence
                .dep_graph_mut()
                .add_deprel(DepTriple::new(head, relation, dependent))
                .unwrap();
        }
    }

    #[test]
    pub fn encoding_fails_with_missing_head() {
        // No arcs at all: the very first token lacks a head.
        let sent = example_sentence();
        let encoder = MutableDependencyEncoder::new();

        let result = encoder.encode(&sent);

        assert!(matches!(result, Err(EncodeError::MissingHead { .. })));
    }

    #[test]
    pub fn encoding_fails_with_missing_relation() {
        let mut sent = example_sentence();
        // The arc 2 -> 1 deliberately carries no relation.
        add_arcs(
            &mut sent,
            &[
                (0, Some("root"), 2),
                (2, None, 1),
                (2, Some("obj"), 4),
                (4, Some("det"), 3),
            ],
        );
        let encoder = MutableDependencyEncoder::new();

        let result = encoder.encode(&sent);

        assert!(matches!(result, Err(EncodeError::MissingRelation { .. })));
    }

    #[test]
    pub fn encoder_encodes_correctly() {
        let mut sent = example_sentence();
        add_arcs(
            &mut sent,
            &[
                (0, Some("root"), 2),
                (2, Some("nsubj"), 1),
                (2, Some("obj"), 4),
                (4, Some("det"), 3),
            ],
        );
        let encoder = MutableDependencyEncoder::new();

        let expected = DependencyEncoding {
            heads: vec![2, 0, 4, 2],
            relations: vec![0, 1, 2, 3],
        };

        assert_eq!(encoder.encode(&sent).unwrap(), expected)
    }

    #[test]
    pub fn no_changes_in_encode_decode_roundtrip() {
        let file = File::open(NON_PROJECTIVE_DATA).unwrap();
        let reader = Reader::new(BufReader::new(file));

        let encoder = MutableDependencyEncoder::new();

        for sentence in reader {
            let sentence = sentence.unwrap();
            let DependencyEncoding { heads, relations } = encoder.encode(&sentence).unwrap();

            // The decoder skips position 0 of both arrays, so prepend a
            // placeholder entry to each.
            let heads = once(0)
                .chain(heads.into_iter().map(|head| head as i64))
                .collect::<Array1<_>>();
            let best_relations = once(-1)
                .chain(relations.into_iter().map(|relation| relation as i32))
                .collect::<Array1<_>>();

            let mut decoded_sentence = sentence.clone();
            encoder
                .decode(heads.view(), best_relations.view(), &mut decoded_sentence)
                .unwrap();

            assert_eq!(decoded_sentence, sentence);
        }
    }
}