use std::path::Path;
use anyhow::{Error, Result};
use sentencepiece::SentencePieceProcessor;
/// File name of the source-side SentencePiece model; joined onto the
/// directory passed to `Tokenizer::new`.
const SOURCE_SPM_FILE: &str = "source.spm";
/// File name of the target-side SentencePiece model; joined onto the
/// directory passed to `Tokenizer::new`.
const TARGET_SPM_FILE: &str = "target.spm";
/// Tokenizer backed by a pair of SentencePiece models: one for turning input
/// text into pieces and one for turning pieces back into text.
pub struct Tokenizer {
/// Processor loaded from the source model file; used by `encode`.
encoder: SentencePieceProcessor,
/// Processor loaded from the target model file; used by `decode`.
decoder: SentencePieceProcessor,
}
impl Tokenizer {
pub fn new<T: AsRef<Path>>(path: T) -> Result<Self> {
Tokenizer::from_file(
path.as_ref().join(SOURCE_SPM_FILE),
path.as_ref().join(TARGET_SPM_FILE),
)
}
pub fn from_file<T: AsRef<Path>, U: AsRef<Path>>(src: T, target: U) -> Result<Self> {
Ok(Self {
encoder: SentencePieceProcessor::open(src)?,
decoder: SentencePieceProcessor::open(target)?,
})
}
}
impl crate::Tokenizer for Tokenizer {
    /// Splits `input` into SentencePiece pieces using the source model and
    /// appends a trailing `"</s>"` token.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying processor fails to encode `input`.
    fn encode(&self, input: &str) -> Result<Vec<String>> {
        let pieces = self.encoder.encode(input)?;
        // Reserve one extra slot so appending the trailing marker never reallocates.
        let mut tokens = Vec::with_capacity(pieces.len() + 1);
        // The encoded pieces are owned here, so move each string out instead of cloning it.
        tokens.extend(pieces.into_iter().map(|piece| piece.piece));
        // NOTE(review): "</s>" is presumably the end-of-sentence marker the
        // downstream model expects — confirm against the consumer.
        tokens.push("</s>".to_string());
        Ok(tokens)
    }

    /// Joins `tokens` back into text using the target model.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying processor fails to decode the pieces.
    fn decode(&self, tokens: Vec<String>) -> Result<String> {
        let text = self.decoder.decode_pieces(&tokens)?;
        Ok(text)
    }
}