use std::sync::Arc;
use anyhow::Result;
use crate::traits::{TokenIdType, Tokenizer as TokenizerTrait};
/// A sequence of token ids bound to a tokenizer, supporting both bulk text
/// appends and incremental, one-token-at-a-time decoding.
pub struct Sequence {
/// Tokenizer used by `append_text` (encode), `append_token` (`decode_step`)
/// and `text` (decode).
tokenizer: Arc<dyn TokenizerTrait>,
/// Token buffer handed to `decode_step`. It may be drained during incremental
/// decoding, so it is not guaranteed to hold every token counted in
/// `total_tokens`.
token_ids: Vec<TokenIdType>,
/// Count of all tokens added (initial, via `append_text`, via `append_token`);
/// reported by `len()` and reset only by `clear()`.
total_tokens: usize,
/// Incremental-decode index owned by `decode_step`; starts at 0 and is reset by
/// `clear()`.
// NOTE(review): exact meaning (offset into `token_ids`?) is defined by
// `decode_step`, which is not visible here — confirm.
prefix_index: usize,
/// Incremental-decode text cache owned by `decode_step`; starts empty and is
/// cleared by `clear()`.
// NOTE(review): presumably the decoded text of the buffered prefix — confirm
// against `decode_step`.
cached_prefix: String,
/// Forwarded to every decode call (`decode_step` and `decode`).
skip_special_tokens: bool,
}
impl std::fmt::Debug for Sequence {
    /// Formats the sequence for diagnostics.
    ///
    /// The tokenizer is shown only as a type placeholder. A token buffer longer
    /// than twenty entries is abbreviated to its first and last ten ids, so long
    /// generations do not flood logs. `skip_special_tokens` is included because
    /// it changes what every decode call returns.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Number of ids shown at each end of an abbreviated buffer.
        const EDGE: usize = 10;

        let ids = &self.token_ids;
        let ids_repr = if ids.len() <= 2 * EDGE {
            format!("{ids:?}")
        } else {
            let head = &ids[..EDGE];
            let tail = &ids[ids.len() - EDGE..];
            format!("{head:?} ... {tail:?}")
        };

        f.debug_struct("Sequence")
            .field("tokenizer", &"Arc<dyn Tokenizer>")
            // `format_args!` keeps the pre-rendered list unquoted in the output.
            .field("token_ids", &format_args!("{ids_repr}"))
            .field("prefix_index", &self.prefix_index)
            .field("buffer_len", &ids.len())
            .field("total_tokens", &self.total_tokens)
            .field("skip_special_tokens", &self.skip_special_tokens)
            .finish()
    }
}
impl Sequence {
    /// Creates an empty sequence that keeps special tokens when decoding.
    pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
        Self::new_with_options(tokenizer, false)
    }

    /// Creates an empty sequence.
    ///
    /// `skip_special_tokens` is forwarded to every decode performed by
    /// [`Self::append_token`] and [`Self::text`].
    pub fn new_with_options(tokenizer: Arc<dyn TokenizerTrait>, skip_special_tokens: bool) -> Self {
        // Delegate so the "empty" and "pre-populated" initial states cannot drift apart.
        Self::with_tokens_and_options(tokenizer, Vec::new(), skip_special_tokens)
    }

    /// Creates a sequence pre-populated with `token_ids`, keeping special tokens
    /// when decoding.
    pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
        Self::with_tokens_and_options(tokenizer, token_ids, false)
    }

    /// Creates a sequence pre-populated with `token_ids`.
    ///
    /// The initial tokens count toward [`Self::len`]. Incremental-decode state
    /// starts at index 0 with an empty cached prefix.
    // NOTE(review): whether the first `append_token` re-emits the text of these
    // initial tokens depends on `decode_step`, which is not visible here — confirm.
    pub fn with_tokens_and_options(
        tokenizer: Arc<dyn TokenizerTrait>,
        token_ids: Vec<TokenIdType>,
        skip_special_tokens: bool,
    ) -> Self {
        Self {
            tokenizer,
            // Evaluated before `token_ids` is moved into the struct below.
            total_tokens: token_ids.len(),
            token_ids,
            prefix_index: 0,
            cached_prefix: String::new(),
            skip_special_tokens,
        }
    }

    /// Returns `true` when the sequence holds no tokens (see [`Self::len`]).
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.total_tokens == 0
    }

    /// Returns the total number of tokens added to the sequence.
    ///
    /// This counts every token (initial, appended as text, or appended one at a
    /// time). Once the decode buffer has been drained, it can exceed
    /// `self.token_ids().len()`.
    #[inline]
    pub fn len(&self) -> usize {
        self.total_tokens
    }

    /// Removes all tokens and resets the incremental-decode state.
    ///
    /// The tokenizer and the `skip_special_tokens` setting are kept.
    pub fn clear(&mut self) {
        self.token_ids.clear();
        self.total_tokens = 0;
        self.prefix_index = 0;
        self.cached_prefix.clear();
    }

    /// Encodes `input` and appends the resulting token ids to the buffer.
    ///
    /// No decoded text is produced by this call.
    ///
    /// # Errors
    ///
    /// Returns any error from the tokenizer's `encode`; the sequence is left
    /// unchanged in that case.
    pub fn append_text(&mut self, input: &str, add_special_tokens: bool) -> Result<()> {
        let encoding = self.tokenizer.encode(input, add_special_tokens)?;
        let ids = encoding.token_ids();
        self.token_ids.extend_from_slice(ids);
        self.total_tokens += ids.len();
        Ok(())
    }

    /// Appends a single token and returns the text it makes decodable.
    ///
    /// Delegates to the tokenizer's `decode_step`, which receives (and may
    /// update) the token buffer, the cached prefix and the prefix index. An
    /// empty string is returned when `decode_step` yields no text.
    ///
    /// # Errors
    ///
    /// Propagates errors from `decode_step`; in that case [`Self::len`] is not
    /// incremented.
    #[inline]
    pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
        let emitted = self.tokenizer.decode_step(
            token_id,
            &mut self.token_ids,
            &mut self.cached_prefix,
            &mut self.prefix_index,
            self.skip_special_tokens,
        )?;
        self.total_tokens += 1;
        Ok(emitted.unwrap_or_default())
    }

    /// Returns the tokenizer used for encoding and decoding.
    #[inline]
    pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
        &self.tokenizer
    }

    /// Returns the current token buffer.
    ///
    /// This is the buffer handed to `decode_step`. It may be drained as tokens
    /// are appended, so it is not guaranteed to contain every token counted by
    /// [`Self::len`].
    #[inline]
    pub fn token_ids(&self) -> &[TokenIdType] {
        &self.token_ids
    }

    /// Decodes the current token buffer into text.
    ///
    /// # Errors
    ///
    /// Returns any error from the tokenizer's `decode`.
    // NOTE(review): only `self.token_ids` is decoded. If `decode_step` has
    // drained that buffer, text for drained tokens is not included — confirm
    // whether callers expect the full sequence text here.
    pub fn text(&self) -> Result<String> {
        self.tokenizer
            .decode(&self.token_ids, self.skip_special_tokens)
    }

    /// Returns whether special tokens are skipped when decoding.
    #[inline]
    pub fn skip_special_tokens(&self) -> bool {
        self.skip_special_tokens
    }
}
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// Builds a sequence backed by a fresh mock tokenizer with default options.
    fn mock_sequence() -> Sequence {
        Sequence::new(Arc::new(MockTokenizer::new()))
    }

    #[test]
    fn test_sequence_new() {
        let fresh = mock_sequence();
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
    }

    #[test]
    fn test_sequence_append_text() {
        let mut seq = mock_sequence();
        seq.append_text("Hello", false).unwrap();
        assert!(!seq.is_empty());
        assert_eq!(seq.text().unwrap(), "Hello");
    }

    #[test]
    fn test_sequence_append_token() {
        let mut seq = mock_sequence();
        // Mock ids 1 and 2 stream out as two consecutive text pieces.
        let first = seq.append_token(1).unwrap();
        assert_eq!(first, "Hello");
        let second = seq.append_token(2).unwrap();
        assert_eq!(second, " world");
    }

    #[test]
    fn test_sequence_clear() {
        let mut seq = mock_sequence();
        seq.append_text("Hello world", false).unwrap();
        assert!(!seq.is_empty());

        seq.clear();
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
    }

    #[test]
    fn test_sequence_debug() {
        let mut seq = mock_sequence();
        seq.append_text("Test", false).unwrap();
        let rendered = format!("{seq:?}");
        for needle in ["Sequence", "total_tokens"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_sequence_token_drain() {
        let mut seq = mock_sequence();
        let mut streamed = String::new();
        let mut fed = Vec::new();

        for step in 0..100 {
            // Cycle through mock token ids 1..=5.
            let id = step % 5 + 1;
            fed.push(id);
            let piece = seq.append_token(id).unwrap();
            streamed.push_str(&piece);
        }

        assert_eq!(seq.len(), 100);
        let buffered = seq.token_ids().len();
        assert!(
            buffered < 100,
            "Token buffer should be drained, but has {} entries",
            buffered
        );

        let expected = seq.tokenizer().decode(&fed, false).unwrap();
        assert_eq!(
            streamed, expected,
            "Drained incremental output must match full decode"
        );
    }
}