use std::path::Path;

use sentencepiece::SentencePieceProcessor;
use udgraph::graph::{Node, Sentence};

use super::{SentenceWithPieces, Tokenize};
use crate::TokenizerError;
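
/// Tokenizer for ALBERT models.
///
/// Wraps a sentencepiece model and applies it as a subword tokenizer:
/// each token of a sentence is encoded into one or more sentence
/// pieces, and the `[CLS]` and `[SEP]` pieces that ALBERT expects are
/// added around the sentence.
///
/// A minimal usage sketch (the model path is a placeholder):
///
/// ```ignore
/// let tokenizer = AlbertTokenizer::open("albert-base-v2.spm")?;
/// // `sentence` is a `udgraph::graph::Sentence`.
/// let with_pieces = tokenizer.tokenize(sentence);
/// ```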
pub struct AlbertTokenizer {
    spp: SentencePieceProcessor,
}

impl AlbertTokenizer {
    /// Construct a tokenizer from a loaded sentencepiece model.
    pub fn new(spp: SentencePieceProcessor) -> Self {
        AlbertTokenizer { spp }
    }

    /// Load the sentencepiece model at `model` and construct a tokenizer.
    pub fn open<P>(model: P) -> Result<Self, TokenizerError>
    where
        P: AsRef<Path>,
    {
        let spp = SentencePieceProcessor::open(model)?;
        Ok(Self::new(spp))
    }
}
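
/// Conversion from a sentencepiece model: lets callers obtain a
/// tokenizer with `spp.into()`.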
impl From<SentencePieceProcessor> for AlbertTokenizer {
    fn from(spp: SentencePieceProcessor) -> Self {
        AlbertTokenizer::new(spp)
    }
}

impl Tokenize for AlbertTokenizer {
    fn tokenize(&self, sentence: Sentence) -> SentenceWithPieces {
        // Most tokens encode to a handful of pieces; reserve some
        // headroom, including room for the [CLS] and [SEP] pieces.
        let mut pieces = Vec::with_capacity((sentence.len() + 1) * 3);
        let mut token_offsets = Vec::with_capacity(sentence.len());

        // Every piece sequence starts with [CLS]. `piece_to_id` returns
        // a `Result` wrapping an `Option`: the first `expect` handles a
        // lookup error, the second a vocabulary that lacks the piece.
        pieces.push(
            self.spp
                .piece_to_id("[CLS]")
                .expect("sentencepiece lookup of [CLS] failed")
                .expect("ALBERT model does not have a [CLS] token") as i64,
        );

        for token in sentence.iter().filter_map(Node::token) {
            // Remember where this token's pieces start.
            token_offsets.push(pieces.len());

            let token_pieces = self
                .spp
                .encode(token.form())
                .expect("The sentencepiece tokenizer failed");

            if token_pieces.is_empty() {
                // Some forms (e.g. bare whitespace) encode to no pieces
                // at all; map such tokens to the unknown piece, so that
                // every token has at least one piece.
                pieces.push(self.spp.unk_id() as i64);
            } else {
                pieces.extend(token_pieces.into_iter().map(|piece| piece.id as i64));
            }
        }

        // Every piece sequence ends with [SEP].
        pieces.push(
            self.spp
                .piece_to_id("[SEP]")
                .expect("sentencepiece lookup of [SEP] failed")
                .expect("ALBERT model does not have a [SEP] token") as i64,
        );

        SentenceWithPieces {
            pieces: pieces.into(),
            sentence,
            token_offsets,
        }
    }
}
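
// These tests are gated behind the `model-tests` feature and require
// the `ALBERT_BASE_V2_SENTENCEPIECE` environment variable to point at
// an ALBERT v2 sentencepiece model when the tests are compiled.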
#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::iter::FromIterator;

    use ndarray::array;
    use sentencepiece::SentencePieceProcessor;
    use udgraph::graph::Sentence;
    use udgraph::token::Token;

    use super::AlbertTokenizer;
    use crate::Tokenize;

    fn sentence_from_forms(forms: &[&str]) -> Sentence {
        Sentence::from_iter(forms.iter().map(|&f| Token::new(f)))
    }

    fn albert_tokenizer() -> AlbertTokenizer {
        let spp = SentencePieceProcessor::open(env!("ALBERT_BASE_V2_SENTENCEPIECE")).unwrap();
        AlbertTokenizer::new(spp)
    }

    #[test]
    fn tokenizer_gives_expected_output() {
        let tokenizer = albert_tokenizer();
        let sent = sentence_from_forms(&["pierre", "vinken", "will", "join", "the", "board", "."]);
        let pieces = tokenizer.tokenize(sent);
        assert_eq!(
            pieces.pieces,
            array![2, 5399, 9730, 2853, 129, 1865, 14, 686, 13, 9, 3]
        );
    }

    #[test]
    fn handles_missing_sentence_pieces() {
        // " " encodes to no sentence pieces, so it must surface as the
        // unknown piece (id 1) rather than being dropped.
        let tokenizer = albert_tokenizer();
        let sent = sentence_from_forms(&["pierre", " ", "vinken"]);
        let pieces = tokenizer.tokenize(sent);
        assert_eq!(pieces.pieces, array![2, 5399, 1, 9730, 2853, 3]);
    }
}