use conllu::graph::{Node, Sentence};
use sentencepiece::SentencePieceProcessor;

use crate::input::{SentenceWithPieces, Tokenize};

// Special token identifiers of the fairseq vocabulary used by XLM-RoBERTa.
// Sentencepiece piece identifiers are shifted by `FAIRSEQ_OFFSET` to map
// them into this vocabulary.
const FAIRSEQ_BOS_ID: i64 = 0;
const FAIRSEQ_EOS_ID: i64 = 2;
const FAIRSEQ_OFFSET: i64 = 1;
/// Tokenizer for XLM-RoBERTa models.
///
/// This tokenizer splits each token into sentence pieces and maps the
/// piece identifiers to the fairseq vocabulary used by XLM-RoBERTa.
pub struct XlmRobertaTokenizer {
    spp: SentencePieceProcessor,
}

impl XlmRobertaTokenizer {
    /// Construct a tokenizer from a sentencepiece model.
    pub fn new(spp: SentencePieceProcessor) -> Self {
        XlmRobertaTokenizer { spp }
    }
}
impl From<SentencePieceProcessor> for XlmRobertaTokenizer {
    fn from(spp: SentencePieceProcessor) -> Self {
        XlmRobertaTokenizer::new(spp)
    }
}
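
// A minimal usage sketch, assuming a local copy of the XLM-RoBERTa
// sentencepiece model (the file name below is an assumption, not part of
// this crate):
//
//     let spp = SentencePieceProcessor::load("xlm-roberta-base.spm.model")?;
//     let tokenizer = XlmRobertaTokenizer::from(spp);
//     let pieces_with_offsets = tokenizer.tokenize(sentence);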
impl Tokenize for XlmRobertaTokenizer {
    fn tokenize(&self, sentence: Sentence) -> SentenceWithPieces {
        // Rough capacity estimate of three pieces per token. `sentence.len()`
        // includes the root node, which is not tokenized.
        let mut pieces = Vec::with_capacity((sentence.len() - 1) * 3);
        let mut token_offsets = Vec::with_capacity(sentence.len());

        pieces.push(FAIRSEQ_BOS_ID);

        for token in sentence.iter().filter_map(Node::token) {
            // Record where this token's pieces start.
            token_offsets.push(pieces.len());

            let token_pieces = self
                .spp
                .encode(token.form())
                .expect("The sentencepiece tokenizer failed");

            if token_pieces.is_empty() {
                // Some forms (such as a bare space) yield no pieces; fall
                // back to the unknown piece, so that every token is
                // represented by at least one piece.
                pieces.push(self.spp.unknown_id() as i64 + FAIRSEQ_OFFSET);
            } else {
                pieces.extend(
                    token_pieces
                        .into_iter()
                        .map(|piece| piece.id as i64 + FAIRSEQ_OFFSET),
                );
            }
        }

        pieces.push(FAIRSEQ_EOS_ID);

        SentenceWithPieces {
            pieces: pieces.into(),
            sentence,
            token_offsets,
        }
    }
}
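
// A sketch of consuming the piece-token alignment produced above, where
// `output` is assumed to be the result of `tokenize` and `i` indexes tokens
// in sentence order. `token_offsets[i]` is the index in `pieces` where the
// pieces of token `i` start; the last token's pieces end just before the
// final EOS marker:
//
//     let start = output.token_offsets[i];
//     let end = output
//         .token_offsets
//         .get(i + 1)
//         .copied()
//         .unwrap_or(output.pieces.len() - 1);
//     let piece_ids = output.pieces.slice(ndarray::s![start..end]);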
#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::iter::FromIterator;

    use conllu::graph::Sentence;
    use conllu::token::Token;
    use ndarray::array;
    use sentencepiece::SentencePieceProcessor;

    use crate::input::{Tokenize, XlmRobertaTokenizer};

    fn sentence_from_forms(forms: &[&str]) -> Sentence {
        Sentence::from_iter(forms.iter().map(|&f| Token::new(f)))
    }

    fn xlm_roberta_tokenizer() -> XlmRobertaTokenizer {
        let spp = SentencePieceProcessor::load(env!("XLM_ROBERTA_BASE_SENTENCEPIECE")).unwrap();
        XlmRobertaTokenizer::from(spp)
    }
    #[test]
    fn tokenizer_gives_expected_output() {
        let tokenizer = xlm_roberta_tokenizer();
        let sent = sentence_from_forms(&["Veruntreute", "die", "AWO", "Spendengeld", "?"]);
        let pieces = tokenizer.tokenize(sent);
        assert_eq!(
            pieces.pieces,
            array![0, 310, 23451, 107, 6743, 68, 62, 43789, 207126, 49004, 705, 2]
        );
    }

    #[test]
    fn handles_missing_sentence_pieces() {
        // A bare space yields no sentence pieces and should fall back to the
        // unknown piece identifier.
        let tokenizer = xlm_roberta_tokenizer();
        let sent = sentence_from_forms(&["die", " ", "AWO"]);
        let pieces = tokenizer.tokenize(sent);
        assert_eq!(pieces.pieces, array![0, 68, 1, 62, 43789, 2]);
    }
}