// syntaxdot_tokenizers/xlm_roberta.rs

use std::path::Path;

use sentencepiece::SentencePieceProcessor;
use udgraph::graph::{Node, Sentence};

use super::{SentenceWithPieces, Tokenize};
use crate::TokenizerError;

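// XLM-R uses the fairseq vocabulary layout, which reserves <s> = 0,
// <pad> = 1, </s> = 2 and <unk> = 3 at the start of the vocabulary, so
// sentencepiece ids are shifted by an offset of 1.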
const FAIRSEQ_BOS_ID: i64 = 0;
const FAIRSEQ_EOS_ID: i64 = 2;
const FAIRSEQ_OFFSET: i64 = 1;
const FAIRSEQ_UNK: i64 = 3;

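/// A tokenizer for XLM-RoBERTa models.
///
/// Tokens are split into sentence pieces with a [`SentencePieceProcessor`],
/// and the piece ids are remapped to XLM-R's fairseq vocabulary indices.
///
/// A minimal usage sketch; marked `ignore` because the crate-root re-export
/// and the model file name are assumptions:
///
/// ```ignore
/// use syntaxdot_tokenizers::{Tokenize, XlmRobertaTokenizer};
/// use udgraph::graph::Sentence;
/// use udgraph::token::Token;
///
/// let tokenizer = XlmRobertaTokenizer::open("sentencepiece.bpe.model")?;
/// let sentence: Sentence = vec![Token::new("Hello"), Token::new("world")]
///     .into_iter()
///     .collect();
/// let pieces = tokenizer.tokenize(sentence);
/// assert_eq!(pieces.pieces[0], 0); // sequences start with <s>
/// ```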
pub struct XlmRobertaTokenizer {
    spp: SentencePieceProcessor,
}

impl XlmRobertaTokenizer {
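    /// Construct a tokenizer from an already loaded sentencepiece model.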
    pub fn new(spp: SentencePieceProcessor) -> Self {
        XlmRobertaTokenizer { spp }
    }

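    /// Load the sentencepiece model at the given path and construct a
    /// tokenizer from it. Fails with a [`TokenizerError`] when the model
    /// cannot be loaded.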
    pub fn open<P>(model: P) -> Result<Self, TokenizerError>
    where
        P: AsRef<Path>,
    {
        let spp = SentencePieceProcessor::open(model)?;
        Ok(Self::new(spp))
    }
}

impl From<SentencePieceProcessor> for XlmRobertaTokenizer {
    fn from(spp: SentencePieceProcessor) -> Self {
        XlmRobertaTokenizer::new(spp)
    }
}

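// Tokenization brackets a sentence's pieces with the fairseq BOS and EOS
// ids; `token_offsets` records the index of the first piece of each token.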
impl Tokenize for XlmRobertaTokenizer {
    fn tokenize(&self, sentence: Sentence) -> SentenceWithPieces {
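        // Heuristic preallocation: a `Sentence` includes a root node, so it
        // has `len() - 1` tokens; assume roughly three pieces per token.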
        let mut pieces = Vec::with_capacity((sentence.len() - 1) * 3);
        let mut token_offsets = Vec::with_capacity(sentence.len());

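        // Start the sequence with the beginning-of-sequence marker (<s>).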
        pieces.push(FAIRSEQ_BOS_ID);

        for token in sentence.iter().filter_map(Node::token) {
            token_offsets.push(pieces.len());

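            // Split the token form into sentence pieces.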
            let token_pieces = self
                .spp
                .encode(token.form())
                .expect("The sentencepiece tokenizer failed");

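            // Remap sentencepiece ids to fairseq ids: the sentencepiece
            // unknown id maps to fairseq's <unk>, all other ids are shifted
            // by the fairseq offset.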
            if !token_pieces.is_empty() {
                pieces.extend(token_pieces.into_iter().map(|piece| {
                    let piece_id = piece.id as i64;
                    if piece_id == self.spp.unk_id() as i64 {
                        FAIRSEQ_UNK
                    } else {
                        piece_id + FAIRSEQ_OFFSET
                    }
                }));
            } else {
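                // Forms for which sentencepiece produces no pieces at all
                // (e.g. a bare space) fall back to the unknown id plus the
                // fairseq offset.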
                pieces.push(self.spp.unk_id() as i64 + FAIRSEQ_OFFSET);
            }
        }

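        // Close the sequence with the end-of-sequence marker (</s>).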
        pieces.push(FAIRSEQ_EOS_ID);

        SentenceWithPieces {
            pieces: pieces.into(),
            sentence,
            token_offsets,
        }
    }
}

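// These tests require the XLM-R sentencepiece model: they are only compiled
// with the `model-tests` feature and read the model path at build time from
// the `XLM_ROBERTA_BASE_SENTENCEPIECE` environment variable.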
#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::iter::FromIterator;

    use ndarray::array;
    use sentencepiece::SentencePieceProcessor;
    use udgraph::graph::Sentence;
    use udgraph::token::Token;

    use super::XlmRobertaTokenizer;
    use crate::Tokenize;

    fn sentence_from_forms(forms: &[&str]) -> Sentence {
        Sentence::from_iter(forms.iter().map(|&f| Token::new(f)))
    }

    fn xlm_roberta_tokenizer() -> XlmRobertaTokenizer {
        let spp = SentencePieceProcessor::open(env!("XLM_ROBERTA_BASE_SENTENCEPIECE")).unwrap();
        XlmRobertaTokenizer::from(spp)
    }

    #[test]
    fn tokenizer_gives_expected_output() {
        let tokenizer = xlm_roberta_tokenizer();
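        // German: "Did the AWO embezzle donation money?"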
        let sent = sentence_from_forms(&["Veruntreute", "die", "AWO", "Spendengeld", "?"]);
        let pieces = tokenizer.tokenize(sent);
        assert_eq!(
            pieces.pieces,
            array![0, 310, 23451, 107, 6743, 68, 62, 43789, 207126, 49004, 705, 2]
        );
    }

    #[test]
    fn handles_missing_sentence_pieces() {
        let tokenizer = xlm_roberta_tokenizer();
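        // The bare space yields no sentence pieces and falls back to the
        // unknown id plus the fairseq offset (1).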
        let sent = sentence_from_forms(&["die", " ", "AWO"]);
        let pieces = tokenizer.tokenize(sent);
        assert_eq!(pieces.pieces, array![0, 68, 1, 62, 43789, 2]);
    }
}