use crate::errors::NoosResult;
/// Common interface for tokenizers that convert between text and `u32` token ids.
///
/// The `Send + Sync` supertrait bounds require every implementor to be safe to
/// move to, and share between, threads.
pub trait NoosTokenizer: Send + Sync {
/// Converts `text` into a sequence of token ids.
///
/// `add_special_tokens` is presumably a request to include tokenizer-specific
/// special tokens in the output; the trait does not define its exact effect,
/// so check the concrete implementation.
///
/// # Errors
///
/// Returns the implementor's error through `NoosResult` if encoding fails.
fn encode(&self, text: &str, add_special_tokens: bool) -> NoosResult<Vec<u32>>;
/// Converts a slice of token ids back into a single `String`.
///
/// # Errors
///
/// Returns the implementor's error through `NoosResult` if decoding fails.
fn decode(&self, tokens: &[u32]) -> NoosResult<String>;
/// Converts one token id into its `String` form.
///
/// # Errors
///
/// Returns the implementor's error through `NoosResult` if decoding fails.
fn decode_token(&self, token: u32) -> NoosResult<String>;
/// Returns the vocabulary size reported by the tokenizer.
fn vocab_size(&self) -> usize;
/// Returns the token id the tokenizer reports as end-of-sequence (EOS).
fn eos_token_id(&self) -> u32;
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
/// Deterministic stand-in tokenizer for unit tests.
///
/// Derives `Debug` so instances can be printed in assertion failure output,
/// and `Clone` so tests can duplicate a configured mock.
#[derive(Debug, Clone)]
pub(crate) struct MockTokenizer {
    /// Number of ids in the mock vocabulary.
    vocab_size: usize,
    /// Token id reported as end-of-sequence.
    eos_id: u32,
}
impl MockTokenizer {
    /// Creates a mock tokenizer with `vocab_size` ids, using the last id
    /// (`vocab_size - 1`) as the end-of-sequence token.
    ///
    /// # Panics
    ///
    /// Panics if `vocab_size` is zero (there would be no last id to use as
    /// EOS) or if the last id does not fit in a `u32`. Checking explicitly
    /// gives the same failure in debug and release builds instead of an
    /// arithmetic-overflow panic in one and a silently wrapped id in the other.
    pub fn new(vocab_size: usize) -> Self {
        let last_id = vocab_size
            .checked_sub(1)
            .expect("MockTokenizer::new: vocab_size must be at least 1");
        assert!(
            last_id <= u32::MAX as usize,
            "MockTokenizer::new: vocab_size must fit in the u32 token-id space"
        );
        Self {
            vocab_size,
            // Lossless: the range was checked just above.
            eos_id: last_id as u32,
        }
    }
}
/// Maps a token id onto the single character the mock decoder emits for it.
///
/// Only the low byte of `token` is used. It is offset from `b'a'` with
/// wrapping arithmetic, so id 0 is `'a'`, id 1 is `'b'`, and so on. Wrapping
/// keeps debug builds from panicking on ids whose low byte exceeds 158 and
/// makes debug and release builds produce the same output.
fn mock_token_char(token: u32) -> char {
    // Truncation to the low byte is intentional for the mock.
    char::from((token as u8).wrapping_add(b'a'))
}

impl NoosTokenizer for MockTokenizer {
    /// Encodes each byte of `text` as `byte % vocab_size`.
    ///
    /// `_add_special_tokens` is ignored by the mock.
    ///
    /// # Panics
    ///
    /// Panics on non-empty input if `vocab_size`, truncated to `u32`, is zero
    /// (remainder by zero).
    fn encode(&self, text: &str, _add_special_tokens: bool) -> NoosResult<Vec<u32>> {
        let modulus = self.vocab_size as u32;
        Ok(text.bytes().map(|byte| u32::from(byte) % modulus).collect())
    }

    /// Decodes every id to one character and concatenates the results.
    fn decode(&self, tokens: &[u32]) -> NoosResult<String> {
        Ok(tokens.iter().copied().map(mock_token_char).collect())
    }

    /// Decodes a single id to a one-character string, using the same mapping
    /// as `decode`.
    fn decode_token(&self, token: u32) -> NoosResult<String> {
        Ok(mock_token_char(token).to_string())
    }

    /// Returns the vocabulary size the mock was configured with.
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Returns the configured end-of-sequence token id.
    fn eos_token_id(&self) -> u32 {
        self.eos_id
    }
}
/// With a 256-entry vocabulary every input byte maps to its own value, so the
/// encoded ids must equal the ASCII codes of the input.
#[test]
fn mock_tokenizer_encodes() {
    let tokenizer = MockTokenizer::new(256);
    let tokens = tokenizer.encode("hello", false).unwrap();
    // Pin the exact ids (b'h', b'e', b'l', b'l', b'o'), not just their count.
    assert_eq!(tokens, vec![104, 101, 108, 108, 111]);
}
/// Ids 0, 1 and 2 decode to the first three lowercase letters.
#[test]
fn mock_tokenizer_decodes() {
    let tokenizer = MockTokenizer::new(256);
    let text = tokenizer.decode(&[0, 1, 2]).unwrap();
    // Pin the exact text, not just its length.
    assert_eq!(text, "abc");
}
/// The mock reports the last id of its vocabulary as the EOS token.
#[test]
fn mock_tokenizer_eos() {
    let eos = MockTokenizer::new(100).eos_token_id();
    assert_eq!(eos, 99);
}
}