syntaxdot_encoders/dependency/
encoder.rs

1use std::fmt;
2
3use itertools::multizip;
4use ndarray::{s, ArrayView1};
5use numberer::Numberer;
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8use udgraph::graph::{DepTriple, Node, Sentence};
9use udgraph::token::Token;
10use udgraph::Error;
11
12use crate::categorical::{ImmutableNumberer, MutableNumberer, Number};
13
/// Dependency encoding.
///
/// Holds the encoded dependency graph of a single sentence. Both vectors
/// contain one entry per non-ROOT token, in token order.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DependencyEncoding {
    /// The head of each (non-ROOT) token.
    pub heads: Vec<usize>,

    /// The dependency relation of each (non-ROOT) token, as the number
    /// assigned to the relation by the encoder.
    pub relations: Vec<usize>,
}
23
/// Errors that can occur while encoding the dependency graph of a sentence.
///
/// In both variants `sent` holds the forms of the sentence's tokens and
/// `token` is the zero-based position in `sent` of the offending token.
/// They are used to render the sentence with the token bracketed when the
/// error is displayed.
#[derive(Clone, Debug, Eq, Error, PartialEq)]
pub enum EncodeError {
    /// The token does not have a head.
    MissingHead { token: usize, sent: Vec<String> },

    /// The token does not have a dependency relation.
    MissingRelation { token: usize, sent: Vec<String> },
}
32
33impl EncodeError {
34    /// Construct `EncodeError::MissingHead` from a CoNLL-U graph.
35    ///
36    /// Construct an error. `token` is the node index for which the
37    /// error applies in `sentence`.
38    pub fn missing_head(token: usize, sentence: &Sentence) -> Self {
39        Self::MissingHead {
40            sent: Self::sentence_to_forms(sentence),
41            token: token - 1,
42        }
43    }
44
45    /// Construct `EncodeError::MissingRelation` from a CoNLL-X graph.
46    ///
47    /// Construct an error. `token` is the node index for which the
48    /// error applies in `sentence`.
49    pub fn missing_relation(token: usize, sentence: &Sentence) -> Self {
50        Self::MissingRelation {
51            sent: Self::sentence_to_forms(sentence),
52            token: token - 1,
53        }
54    }
55
56    fn format_bracketed(bracket_idx: usize, tokens: &[String]) -> String {
57        let mut tokens = tokens.to_owned();
58        tokens.insert(bracket_idx + 1, "]".to_string());
59        tokens.insert(bracket_idx, "[".to_string());
60
61        tokens.join(" ")
62    }
63
64    fn sentence_to_forms(sentence: &Sentence) -> Vec<String> {
65        sentence
66            .iter()
67            .filter_map(Node::token)
68            .map(Token::form)
69            .map(ToOwned::to_owned)
70            .collect()
71    }
72}
73
74impl fmt::Display for EncodeError {
75    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
76        use EncodeError::*;
77
78        match self {
79            MissingHead { token, sent } => write!(
80                f,
81                "Token does not have a head:\n\n{}\n",
82                Self::format_bracketed(*token, sent),
83            ),
84            MissingRelation { token, sent } => write!(
85                f,
86                "Token does not have a dependency relation:\n\n{}\n",
87                Self::format_bracketed(*token, sent),
88            ),
89        }
90    }
91}
92
/// Arc-factored dependency encoder/decoder.
///
/// The encoder is generic over the numberer `N` that maps dependency
/// relation strings to numbers and back.
#[derive(Serialize, Deserialize)]
pub struct DependencyEncoder<N>
where
    N: Number<String>,
{
    /// Numberer for dependency relations: used to look up the number of a
    /// relation when encoding and the relation of a number when decoding.
    relations: N,
}
101
102impl<N> DependencyEncoder<N>
103where
104    N: Number<String>,
105{
106    /// Encode a sentence.
107    ///
108    /// Returns the encoding of the dependency graph.
109    pub fn encode(&self, sentence: &Sentence) -> Result<DependencyEncoding, EncodeError> {
110        let dep_graph = sentence.dep_graph();
111
112        let mut heads = Vec::with_capacity(sentence.len());
113        let mut relations = Vec::with_capacity(sentence.len());
114
115        for token_idx in 1..sentence.len() {
116            let head = dep_graph
117                .head(token_idx)
118                .ok_or_else(|| EncodeError::missing_head(token_idx, sentence))?
119                .head();
120            heads.push(head);
121
122            let relation = dep_graph
123                .head(token_idx)
124                .and_then(|triple| triple.relation().map(ToString::to_string))
125                .ok_or_else(|| EncodeError::missing_relation(token_idx, sentence))?;
126            relations.push(
127                self.relations
128                    .number(relation.to_string())
129                    .expect("Unknown dependency relation"),
130            );
131        }
132
133        Ok(DependencyEncoding { heads, relations })
134    }
135
136    /// Decode a dependency graph from a score matrix.
137    ///
138    /// The following arguments must be provided:
139    ///
140    /// * `pairwise_head_score`: edge (arc) score matrix, `pairwise_head_score[dependent][head]`
141    ///   is the score for attaching `dependent` to `head`.
142    /// * `best_pairwise_relations`: represents per dependent the best dependency relation
143    ///   given a head (`best_pairwise_relations[dependent, head]`).
144    /// * `sentence`: the sentence in which to store the dependency relations.
145    pub fn decode(
146        &self,
147        sent_heads: ArrayView1<i64>,
148        best_pairwise_relations: ArrayView1<i32>,
149        sentence: &mut Sentence,
150    ) -> Result<(), Error> {
151        // Unwrap the heads, skipping the root vertex.
152        let heads = sent_heads.slice(s![1..]);
153
154        let relations = best_pairwise_relations
155            .into_iter()
156            .skip(1)
157            .cloned()
158            .collect::<Vec<_>>();
159
160        for (dep, &head, relation) in multizip((1..sentence.len(), heads, relations)) {
161            let relation = self
162                .relations
163                .value(relation as usize)
164                // We should never predict an unknown relation, that would mean that
165                // the model does not correspond to the label inventory. This cannot
166                // happen, because the model's shape is based on the number of relations
167                // reported by instances of this type.
168                .unwrap_or_else(|| panic!("Predicted an unknown relation: {}", relation));
169            sentence
170                .dep_graph_mut()
171                .add_deprel::<String>(DepTriple::new(head as usize, Some(relation), dep))?;
172        }
173
174        Ok(())
175    }
176
177    pub fn n_relations(&self) -> usize {
178        self.relations.len()
179    }
180}
181
/// Dependency encoder whose relations are numbered by an `ImmutableNumberer`.
pub type ImmutableDependencyEncoder = DependencyEncoder<ImmutableNumberer<String>>;
183
/// Dependency encoder whose relations are numbered by a `MutableNumberer`.
pub type MutableDependencyEncoder = DependencyEncoder<MutableNumberer<String>>;
185
186impl Default for MutableDependencyEncoder {
187    fn default() -> Self {
188        DependencyEncoder {
189            relations: MutableNumberer::new(Numberer::new(0)),
190        }
191    }
192}
193
194impl MutableDependencyEncoder {
195    /// Create a mutable dependency encoder.
196    pub fn new() -> Self {
197        Default::default()
198    }
199}
200
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::BufReader;
    use std::iter::once;

    use conllu::io::Reader;
    use udgraph::graph::{DepTriple, Sentence};
    use udgraph::token::Token;

    use crate::dependency::{DependencyEncoding, EncodeError, MutableDependencyEncoder};
    use ndarray::Array1;

    static NON_PROJECTIVE_DATA: &str = "testdata/lassy-small-dev.conllu";

    /// Build the four-token example sentence without any dependency arcs.
    fn unattached_sentence() -> Sentence {
        ["Ze", "koopt", "een", "auto"]
            .iter()
            .map(|form| Token::new(*form))
            .collect()
    }

    /// Attach `dependent` to `head` with the given (optional) relation.
    fn attach(sentence: &mut Sentence, head: usize, relation: Option<&str>, dependent: usize) {
        sentence
            .dep_graph_mut()
            .add_deprel(DepTriple::new(head, relation, dependent))
            .unwrap();
    }

    #[test]
    pub fn encoding_fails_with_missing_head() {
        let sent = unattached_sentence();
        let encoder = MutableDependencyEncoder::new();

        let result = encoder.encode(&sent);

        assert!(matches!(result, Err(EncodeError::MissingHead { .. })));
    }

    #[test]
    pub fn encoding_fails_with_missing_relation() {
        // The arc to the first token carries no relation.
        let arcs: [(usize, Option<&str>, usize); 4] = [
            (0, Some("root"), 2),
            (2, None, 1),
            (2, Some("obj"), 4),
            (4, Some("det"), 3),
        ];

        let mut sent = unattached_sentence();
        for &(head, relation, dependent) in arcs.iter() {
            attach(&mut sent, head, relation, dependent);
        }

        let encoder = MutableDependencyEncoder::new();
        let result = encoder.encode(&sent);

        assert!(matches!(result, Err(EncodeError::MissingRelation { .. })));
    }

    #[test]
    pub fn encoder_encodes_correctly() {
        let arcs: [(usize, Option<&str>, usize); 4] = [
            (0, Some("root"), 2),
            (2, Some("nsubj"), 1),
            (2, Some("obj"), 4),
            (4, Some("det"), 3),
        ];

        let mut sent = unattached_sentence();
        for &(head, relation, dependent) in arcs.iter() {
            attach(&mut sent, head, relation, dependent);
        }

        let encoder = MutableDependencyEncoder::new();
        let encoding = encoder.encode(&sent).unwrap();

        let expected = DependencyEncoding {
            heads: vec![2, 0, 4, 2],
            relations: vec![0, 1, 2, 3],
        };
        assert_eq!(encoding, expected);
    }

    #[test]
    pub fn no_changes_in_encode_decode_roundtrip() {
        let file = File::open(NON_PROJECTIVE_DATA).unwrap();
        let reader = Reader::new(BufReader::new(file));

        let encoder = MutableDependencyEncoder::new();

        for sentence in reader {
            let sentence = sentence.unwrap();
            let DependencyEncoding { heads, relations } = encoder.encode(&sentence).unwrap();

            // Prepend an entry for the root vertex, which the decoder skips.
            let heads: Array1<i64> = once(0)
                .chain(heads.into_iter().map(|head| head as i64))
                .collect();
            let best_relations: Array1<i32> = once(-1)
                .chain(relations.into_iter().map(|relation| relation as i32))
                .collect();

            // Decoding the encoding into a copy must reproduce the sentence.
            let mut decoded_sentence = sentence.clone();
            encoder
                .decode(heads.view(), best_relations.view(), &mut decoded_sentence)
                .unwrap();
            assert_eq!(decoded_sentence, sentence);
        }
    }
}