syntaxdot_tokenizers/xlm_roberta.rs

use std::path::Path;

use sentencepiece::SentencePieceProcessor;
use udgraph::graph::{Node, Sentence};

use super::{SentenceWithPieces, Tokenize};
use crate::TokenizerError;

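// Special piece ids in the fairseq vocabulary. Regular sentencepiece
// piece ids are shifted by `FAIRSEQ_OFFSET` to map them into this
// vocabulary (see `Tokenize::tokenize` below).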
const FAIRSEQ_BOS_ID: i64 = 0;
const FAIRSEQ_EOS_ID: i64 = 2;
const FAIRSEQ_OFFSET: i64 = 1;
const FAIRSEQ_UNK: i64 = 3;

/// Tokenizer for XLM-RoBERTa models.
///
/// XLM-RoBERTa uses a sentencepiece tokenizer. However, we cannot use
/// sentencepiece in the intended way: we would have to detokenize
/// sentences, and it is not guaranteed that each token then has its own
/// piece, which is required for sequence labeling. So instead, we use
/// sentencepiece as a subword tokenizer, encoding each token separately.
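///
/// # Example
///
/// A minimal usage sketch. The model path is a placeholder, and the
/// re-exports of `Tokenize` and `XlmRobertaTokenizer` from the crate
/// root are assumed here.
///
/// ```ignore
/// use std::iter::FromIterator;
///
/// use syntaxdot_tokenizers::{Tokenize, XlmRobertaTokenizer};
/// use udgraph::graph::Sentence;
/// use udgraph::token::Token;
///
/// // Placeholder path to an XLM-R sentencepiece model.
/// let tokenizer = XlmRobertaTokenizer::open("sentencepiece.bpe.model").unwrap();
///
/// let sentence = Sentence::from_iter(vec![Token::new("die"), Token::new("AWO")]);
/// let with_pieces = tokenizer.tokenize(sentence);
///
/// // The pieces are wrapped in BOS (0) and EOS (2); `token_offsets`
/// // holds the index of the first piece of each token.
/// assert_eq!(with_pieces.pieces[0], 0);
/// assert_eq!(with_pieces.token_offsets.len(), 2);
/// ```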
pub struct XlmRobertaTokenizer {
    spp: SentencePieceProcessor,
}

impl XlmRobertaTokenizer {
    pub fn new(spp: SentencePieceProcessor) -> Self {
        XlmRobertaTokenizer { spp }
    }

    pub fn open<P>(model: P) -> Result<Self, TokenizerError>
    where
        P: AsRef<Path>,
    {
        let spp = SentencePieceProcessor::open(model)?;
        Ok(Self::new(spp))
    }
}

impl From<SentencePieceProcessor> for XlmRobertaTokenizer {
    fn from(spp: SentencePieceProcessor) -> Self {
        XlmRobertaTokenizer::new(spp)
    }
}

impl Tokenize for XlmRobertaTokenizer {
    fn tokenize(&self, sentence: Sentence) -> SentenceWithPieces {
        // An average of three pieces per token ought to be enough for
        // everyone ;).
        let mut pieces = Vec::with_capacity((sentence.len() - 1) * 3);
        let mut token_offsets = Vec::with_capacity(sentence.len());

        pieces.push(FAIRSEQ_BOS_ID);

        for token in sentence.iter().filter_map(Node::token) {
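            // Record the position of the first piece of this token in
            // the piece sequence.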
            token_offsets.push(pieces.len());

            let token_pieces = self
                .spp
                .encode(token.form())
                .expect("The sentencepiece tokenizer failed");

            if !token_pieces.is_empty() {
                pieces.extend(token_pieces.into_iter().map(|piece| {
                    let piece_id = piece.id as i64;
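
                    // Map the sentencepiece unknown id to the fairseq
                    // unknown id; shift all other piece ids by the
                    // fairseq offset.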
                    if piece_id == self.spp.unk_id() as i64 {
                        FAIRSEQ_UNK
                    } else {
                        piece_id + FAIRSEQ_OFFSET
                    }
                }));
            } else {
                // Fall back to the unknown piece id when sentencepiece
                // does not give any output for the token. This should
                // not happen under normal circumstances, since
                // sentencepiece returns the unknown id for tokens that
                // it cannot segment. However, the input may be corrupt
                // and use some form of non-tab whitespace as a token
                // form, for which sentencepiece returns no pieces at
                // all.
                pieces.push(self.spp.unk_id() as i64 + FAIRSEQ_OFFSET);
            }
        }

        pieces.push(FAIRSEQ_EOS_ID);

        SentenceWithPieces {
            pieces: pieces.into(),
            sentence,
            token_offsets,
        }
    }
}

#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::iter::FromIterator;

    use ndarray::array;
    use sentencepiece::SentencePieceProcessor;
    use udgraph::graph::Sentence;
    use udgraph::token::Token;

    use super::XlmRobertaTokenizer;
    use crate::Tokenize;

    fn sentence_from_forms(forms: &[&str]) -> Sentence {
        Sentence::from_iter(forms.iter().map(|&f| Token::new(f)))
    }

    fn xlm_roberta_tokenizer() -> XlmRobertaTokenizer {
        let spp = SentencePieceProcessor::open(env!("XLM_ROBERTA_BASE_SENTENCEPIECE")).unwrap();
        XlmRobertaTokenizer::from(spp)
    }

    #[test]
    fn tokenizer_gives_expected_output() {
        let tokenizer = xlm_roberta_tokenizer();
        let sent = sentence_from_forms(&["Veruntreute", "die", "AWO", "Spendengeld", "?"]);
        let pieces = tokenizer.tokenize(sent);
        assert_eq!(
            pieces.pieces,
            array![0, 310, 23451, 107, 6743, 68, 62, 43789, 207126, 49004, 705, 2]
        );
    }

    #[test]
    fn handles_missing_sentence_pieces() {
        let tokenizer = xlm_roberta_tokenizer();
        let sent = sentence_from_forms(&["die", " ", "AWO"]);
        let pieces = tokenizer.tokenize(sent);
        assert_eq!(pieces.pieces, array![0, 68, 1, 62, 43789, 2]);
    }
}