use std::sync::Arc;
use crate::{
mock,
traits::{Decoder, Encoder},
Tokenizer,
};
/// Encoding "Hello world" with the mock tokenizer yields token ids `[1, 2]`.
#[test]
fn test_mock_tokenizer_encode() {
    let encoding = mock::MockTokenizer::new()
        .encode("Hello world", false)
        .unwrap();
    assert_eq!(encoding.token_ids(), &[1, 2]);
}
/// Decoding token ids `[1, 2]` with the mock tokenizer reproduces "Hello world".
#[test]
fn test_mock_tokenizer_decode() {
    let decoded = mock::MockTokenizer::new().decode(&[1, 2], false).unwrap();
    assert_eq!(decoded, "Hello world");
}
/// The same id sequence (bos, Hello, world, eos) decodes with the special tokens
/// rendered when `skip_special_tokens` is false, and without them when it is true.
#[test]
fn test_mock_tokenizer_decode_skip_special() {
    let tokenizer = mock::MockTokenizer::new();
    let ids = [1000, 1, 2, 999];
    let cases = [(false, "<bos> Hello world <eos>"), (true, "Hello world")];
    for (skip_special, expected) in cases {
        let decoded = tokenizer.decode(&ids, skip_special).unwrap();
        assert_eq!(decoded, expected);
    }
}
/// The `Tokenizer` wrapper forwards encoding, decoding and vocabulary queries to
/// the wrapped mock tokenizer.
#[test]
fn test_tokenizer_wrapper() {
    let tokenizer = Tokenizer::from_arc(Arc::new(mock::MockTokenizer::new()));

    // Encode / decode round trip.
    let encoded = tokenizer.encode("Hello world", false).unwrap();
    assert_eq!(encoded.token_ids(), &[1, 2]);
    let decoded = tokenizer.decode(&[1, 2], false).unwrap();
    assert_eq!(decoded, "Hello world");

    // Vocabulary size and token <-> id lookups, including misses.
    assert_eq!(tokenizer.vocab_size(), 14);
    assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
    assert_eq!(tokenizer.token_to_id("unknown"), None);
    assert_eq!(tokenizer.id_to_token(1), Some("Hello".to_string()));
    assert_eq!(tokenizer.id_to_token(9999), None);
}
/// A decode stream seeded with `[1, 2]` emits only the new text (" test") when
/// token 3 is pushed, and records the full token history.
#[test]
fn test_decode_stream_basic() {
    let tokenizer = Tokenizer::from_arc(Arc::new(mock::MockTokenizer::new()));
    let prompt = vec![1, 2];
    let mut stream = tokenizer.decode_stream(&prompt, false);

    let emitted = stream.step(3).unwrap();
    assert_eq!(emitted, Some(" test".to_string()));
    assert_eq!(stream.tokens(), &[1, 2, 3]);
}
/// Successive `step` calls each emit just the text contributed by the new token.
#[test]
fn test_decode_stream_multiple_steps() {
    let tokenizer = Tokenizer::from_arc(Arc::new(mock::MockTokenizer::new()));
    let prompt = vec![1];
    let mut stream = tokenizer.decode_stream(&prompt, false);

    // (token pushed, text expected from that step)
    for (token, expected_piece) in [(2, " world"), (3, " test")] {
        let piece = stream.step(token).unwrap();
        assert_eq!(piece, Some(expected_piece.to_string()));
    }
    assert_eq!(stream.tokens(), &[1, 2, 3]);
}
/// Once every step has already emitted its text, `flush` has nothing left to
/// return.
#[test]
fn test_decode_stream_flush() {
    let tokenizer = Tokenizer::from_arc(Arc::new(mock::MockTokenizer::new()));
    let seed = vec![1];
    let mut stream = tokenizer.decode_stream(&seed, false);

    let first = stream.step(2).unwrap();
    assert_eq!(first, Some(" world".to_string()));
    let second = stream.step(3).unwrap();
    assert_eq!(second, Some(" test".to_string()));

    let remainder = stream.flush().unwrap();
    assert_eq!(remainder, None);
}
/// The wrapper exposes the mock tokenizer's special tokens: bos, eos and unk are
/// set, while sep and pad are absent.
#[test]
fn test_special_tokens() {
    let tokenizer = Tokenizer::from_arc(Arc::new(mock::MockTokenizer::new()));
    let specials = tokenizer.get_special_tokens();

    assert_eq!(specials.bos_token.as_deref(), Some("<bos>"));
    assert_eq!(specials.eos_token.as_deref(), Some("<eos>"));
    assert_eq!(specials.unk_token.as_deref(), Some("<unk>"));

    assert!(specials.sep_token.is_none());
    assert!(specials.pad_token.is_none());
}
/// Batch encoding returns one encoding per input, in input order.
#[test]
fn test_batch_encode() {
    let tokenizer = mock::MockTokenizer::new();
    let encodings = tokenizer
        .encode_batch(&["Hello", "world", "test"], false)
        .unwrap();

    assert_eq!(encodings.len(), 3);
    assert_eq!(encodings[0].token_ids(), &[1]);
    assert_eq!(encodings[1].token_ids(), &[2]);
    assert_eq!(encodings[2].token_ids(), &[3]);
}
/// Shares one `Tokenizer` across ten threads; each thread round-trips
/// "Hello test" through encode/decode independently.
#[test]
fn test_thread_safety() {
    use std::thread;
    let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
    let tokenizer = Tokenizer::from_arc(mock_tokenizer);
    let handles: Vec<_> = (0..10)
        .map(|i| {
            let tokenizer_clone = tokenizer.clone();
            thread::spawn(move || {
                let encoding = tokenizer_clone.encode("Hello test", false).unwrap();
                let decoded = tokenizer_clone.decode(encoding.token_ids(), false).unwrap();
                // Both words must survive the round trip; an `||` check would
                // let a lossy decode (one word dropped) pass unnoticed.
                assert!(decoded.contains("Hello") && decoded.contains("test"));
                i
            })
        })
        .collect();
    for (expected, handle) in handles.into_iter().enumerate() {
        // `join` re-raises any worker panic; the returned index confirms each
        // worker ran to completion rather than being silently discarded.
        assert_eq!(handle.join().unwrap(), expected);
    }
}
/// Regression test for `DecodeStream::step` when the newly decoded text is not a
/// plain byte-extension of the previously decoded text.
///
/// The stub decoder below renders `[1, 2]` as `"abc"` (3 bytes) but `[1, 2, 3]` as
/// `"ab\u{1F389}"` (2 ASCII bytes followed by a 4-byte emoji). Byte offset 3 of the
/// new text therefore lies inside the emoji, so slicing it at the old text's byte
/// length would not land on a char boundary. The test expects the step to yield
/// exactly the emoji.
/// NOTE(review): how `DecodeStream` derives the emitted suffix is not visible here;
/// the assertion at the end is the only contract this test pins.
#[test]
fn test_decode_stream_multibyte_char_boundary() {
use anyhow::Result;
use crate::{
stream::DecodeStream,
traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait},
};
// Minimal tokenizer stub; its only interesting behavior is `decode`.
struct MultiByteTokenizer {
special_tokens: SpecialTokens,
}
// Stub encoder: every input encodes to an empty token list.
impl Encoder for MultiByteTokenizer {
fn encode(&self, _input: &str, _add_special_tokens: bool) -> Result<Encoding> {
Ok(Encoding::Plain(vec![]))
}
fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
inputs
.iter()
.map(|s| self.encode(s, add_special_tokens))
.collect()
}
}
// Only the two sequences this test produces have text; anything else decodes
// to an empty string.
impl Decoder for MultiByteTokenizer {
fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
Ok(match token_ids {
// Prefix text: 3 bytes, all ASCII.
[1, 2] => "abc".into(),
// Extended text: "ab" + U+1F389; the trailing 'c' is replaced, so this is
// not a byte-extension of "abc".
[1, 2, 3] => "ab\u{1F389}".into(), _ => String::new(),
})
}
}
// Remaining trait methods return fixed placeholder values.
impl TokenizerTrait for MultiByteTokenizer {
fn vocab_size(&self) -> usize {
10
}
fn get_special_tokens(&self) -> &SpecialTokens {
&self.special_tokens
}
fn token_to_id(&self, _token: &str) -> Option<u32> {
None
}
fn id_to_token(&self, _id: u32) -> Option<String> {
None
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
let tokenizer: Arc<dyn TokenizerTrait> = Arc::new(MultiByteTokenizer {
special_tokens: SpecialTokens::default(),
});
// Seed the stream with [1, 2] (decoded as "abc"), then push token 3.
let prompt_tokens = vec![1, 2];
let mut stream = DecodeStream::new(tokenizer, &prompt_tokens, false);
let result = stream.step(3).unwrap();
// The step must yield the emoji as a complete, valid UTF-8 string.
assert_eq!(result, Some("\u{1F389}".to_string()));
}