1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
use std::sync::Arc;
use anyhow::Result;
use crate::traits::{TokenIdType, Tokenizer as TokenizerTrait};
/// Maintains state for an ongoing sequence of tokens and their decoded text.
///
/// Mirrors the design of the native `DecodeStream` in the HuggingFace `tokenizers`
/// crate but works through the `dyn Tokenizer` trait so it supports all backends
/// (HuggingFace, Tiktoken, Mock).
///
/// Key design decisions (matching native `DecodeStream`):
/// - **Token draining**: consumed tokens are drained from the buffer after each
///   successful step, keeping memory bounded regardless of generation length.
/// - **Prefix caching**: the decoded prefix string is cached between calls,
///   avoiding a redundant `decode()` on the next step.
///
/// The three incremental-decoding fields (`token_ids`, `cached_prefix`,
/// `prefix_index`) are handed to the tokenizer's `decode_step` as mutable
/// references by `append_token()`; this type does not modify them itself
/// except to extend `token_ids` in `append_text()` and to reset them in `clear()`.
pub struct Sequence {
/// The tokenizer used for encoding (`append_text`) and decoding
/// (`append_token`, `text`). Shared via `Arc`.
tokenizer: Arc<dyn TokenizerTrait>,
/// Sliding window of token ids needed for correct incremental decoding.
/// Consumed tokens are drained after each successful step, so this is
/// NOT the full token history once `append_token()` has been used.
token_ids: Vec<TokenIdType>,
/// Total number of tokens ever appended (does NOT reset on drain).
/// Used by callers that need the logical sequence length; this is the value
/// reported by `len()` and `is_empty()`. Only `clear()` sets it back to zero.
total_tokens: usize,
/// Index within `token_ids` that marks the start of the "prefix" window.
/// Everything before this has already been decoded and can be drained.
/// Starts at 0 and is updated through `decode_step`.
prefix_index: usize,
/// Cached decoded prefix string from the previous successful step.
/// On the next call we skip one `decode()` by reusing this.
/// Starts empty and is updated through `decode_step`.
cached_prefix: String,
/// Whether to skip special tokens when decoding; forwarded unchanged to
/// both `decode_step` and `decode`.
skip_special_tokens: bool,
}
/// Compact `Debug` rendering for [`Sequence`].
///
/// The tokenizer is shown as a fixed placeholder string, and the token buffer
/// is abbreviated to its first and last ten ids once it holds more than twenty
/// entries, so log lines stay short for long buffers.
impl std::fmt::Debug for Sequence {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Buffers up to this length are printed in full.
        const FULL_PRINT_LIMIT: usize = 20;
        /// Number of ids kept at each end of an abbreviated buffer.
        const EDGE: usize = 10;

        // Render the buffer preview up front so the builder chain below stays flat.
        let ids = &self.token_ids;
        let preview = if ids.len() > FULL_PRINT_LIMIT {
            let head = &ids[..EDGE];
            let tail = &ids[ids.len() - EDGE..];
            format!("{head:?} ... {tail:?}")
        } else {
            format!("{ids:?}")
        };

        let mut out = f.debug_struct("Sequence");
        out.field("tokenizer", &"Arc<dyn Tokenizer>");
        // `format_args!` prints the preview verbatim instead of as a quoted string.
        out.field("token_ids", &format_args!("{}", preview));
        out.field("prefix_index", &self.prefix_index);
        out.field("buffer_len", &ids.len());
        out.field("total_tokens", &self.total_tokens);
        out.finish()
    }
}
impl Sequence {
    /// Builds an empty sequence that keeps special tokens when decoding.
    pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
        Self::with_tokens_and_options(tokenizer, Vec::new(), false)
    }

    /// Builds an empty sequence with an explicit `skip_special_tokens` setting.
    pub fn new_with_options(tokenizer: Arc<dyn TokenizerTrait>, skip_special_tokens: bool) -> Self {
        Self::with_tokens_and_options(tokenizer, Vec::new(), skip_special_tokens)
    }

    /// Builds a sequence pre-populated with `token_ids`, keeping special tokens
    /// when decoding.
    pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
        Self::with_tokens_and_options(tokenizer, token_ids, false)
    }

    /// Builds a sequence pre-populated with `token_ids` and an explicit
    /// `skip_special_tokens` setting.
    ///
    /// The logical length starts at `token_ids.len()`; the incremental decode
    /// state (prefix index and cached prefix) starts out empty.
    pub fn with_tokens_and_options(
        tokenizer: Arc<dyn TokenizerTrait>,
        token_ids: Vec<TokenIdType>,
        skip_special_tokens: bool,
    ) -> Self {
        // Capture the count before the vector is moved into the struct.
        let total_tokens = token_ids.len();
        Self {
            tokenizer,
            token_ids,
            total_tokens,
            prefix_index: 0,
            cached_prefix: String::new(),
            skip_special_tokens,
        }
    }

    /// Returns `true` when no token has been appended since construction or
    /// the last `clear()`.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the logical length: the total number of tokens appended, which
    /// may exceed the size of the retained buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.total_tokens
    }

    /// Resets the sequence to its empty state: token buffer, logical length
    /// and incremental decode cache. The tokenizer and the
    /// `skip_special_tokens` setting are kept.
    pub fn clear(&mut self) {
        self.cached_prefix.clear();
        self.prefix_index = 0;
        self.total_tokens = 0;
        self.token_ids.clear();
    }

    /// Encodes `input` and appends the resulting token ids to the sequence.
    ///
    /// WARNING: Do not mix `append_text()` and `append_token()` on the same
    /// instance. `append_text()` does not invalidate the incremental decode
    /// cache (`cached_prefix`/`prefix_index`), so subsequent `append_token()`
    /// calls would diff against stale state.
    ///
    /// Set `add_special_tokens` to `true` for embeddings, or `false` for chat
    /// completion where the chat template already handles special tokens.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the tokenizer's `encode`; the sequence
    /// is left unchanged in that case.
    pub fn append_text(&mut self, input: &str, add_special_tokens: bool) -> Result<()> {
        let encoded = self.tokenizer.encode(input, add_special_tokens)?;
        let new_ids = encoded.token_ids();
        self.total_tokens += new_ids.len();
        self.token_ids.extend(new_ids);
        Ok(())
    }

    /// Appends one token and returns the text it newly contributes.
    ///
    /// Delegates to `Decoder::decode_step` on the tokenizer trait. For
    /// HuggingFace tokenizers this uses the native `step_decode_stream`; other
    /// backends use the default double-decode fallback. Both paths handle
    /// token draining and prefix caching internally.
    ///
    /// When the step yields no text, an empty string is returned.
    ///
    /// # Errors
    ///
    /// Propagates any error from `decode_step`; the logical length is not
    /// incremented in that case.
    #[inline]
    pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
        let step = self.tokenizer.decode_step(
            token_id,
            &mut self.token_ids,
            &mut self.cached_prefix,
            &mut self.prefix_index,
            self.skip_special_tokens,
        )?;
        // Count the token only once the decode step has succeeded.
        self.total_tokens += 1;
        Ok(step.unwrap_or_default())
    }

    /// Returns the shared tokenizer handle.
    #[inline]
    pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
        &self.tokenizer
    }

    /// Returns the token ids currently held in the buffer (the sliding
    /// window, not the full history).
    #[inline]
    pub fn token_ids(&self) -> &[TokenIdType] {
        &self.token_ids
    }

    /// Decodes the current buffer to text.
    ///
    /// WARNING: after `append_token()` calls, this only decodes the sliding
    /// window (retained tokens), not the full sequence history. Use the
    /// incremental return values from `append_token()` to build the full text.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the tokenizer's `decode`.
    pub fn text(&self) -> Result<String> {
        let skip = self.skip_special_tokens;
        self.tokenizer.decode(&self.token_ids, skip)
    }

    /// Returns whether special tokens are skipped during decoding.
    #[inline]
    pub fn skip_special_tokens(&self) -> bool {
        self.skip_special_tokens
    }
}
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// Shared fixture: an empty `Sequence` backed by a fresh `MockTokenizer`.
    fn empty_sequence() -> Sequence {
        let tokenizer = Arc::new(MockTokenizer::new());
        Sequence::new(tokenizer)
    }

    /// A freshly constructed sequence reports itself as empty.
    #[test]
    fn test_sequence_new() {
        let seq = empty_sequence();
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
    }

    /// Text appended via `append_text` round-trips through `text()`.
    #[test]
    fn test_sequence_append_text() {
        let mut seq = empty_sequence();
        seq.append_text("Hello", false).unwrap();
        assert!(!seq.is_empty());

        let decoded = seq.text().unwrap();
        assert_eq!(decoded, "Hello");
    }

    /// `append_token` returns only the text contributed by the new token.
    #[test]
    fn test_sequence_append_token() {
        let mut seq = empty_sequence();

        // Token 1 is "Hello" in the mock vocabulary.
        let first = seq.append_token(1).unwrap();
        assert_eq!(first, "Hello");

        // Token 2 is "world"; the mock decodes [1, 2] as "Hello world", so the
        // increment carries the separating space.
        let second = seq.append_token(2).unwrap();
        assert_eq!(second, " world");
    }

    /// `clear` brings a populated sequence back to the empty state.
    #[test]
    fn test_sequence_clear() {
        let mut seq = empty_sequence();
        seq.append_text("Hello world", false).unwrap();
        assert!(!seq.is_empty());

        seq.clear();
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
    }

    /// The `Debug` output names the type and includes the token counter.
    #[test]
    fn test_sequence_debug() {
        let mut seq = empty_sequence();
        seq.append_text("Test", false).unwrap();

        let rendered = format!("{seq:?}");
        for needle in ["Sequence", "total_tokens"] {
            assert!(rendered.contains(needle));
        }
    }

    /// The token buffer stays bounded after many appends, and the streamed
    /// pieces still add up to a full decode of every token.
    #[test]
    fn test_sequence_token_drain() {
        let mut seq = empty_sequence();

        let mut streamed = String::new();
        let mut history = Vec::new();
        // 100 appends, cycling through the mock token ids 1..=5.
        for token_id in (1..=5).cycle().take(100) {
            history.push(token_id);
            let piece = seq.append_token(token_id).unwrap();
            streamed.push_str(&piece);
        }

        // Logical length reflects every token ever appended.
        assert_eq!(seq.len(), 100);

        // Draining must have kept the buffer well below the logical length.
        let buffered = seq.token_ids().len();
        assert!(
            buffered < 100,
            "Token buffer should be drained, but has {} entries",
            buffered
        );

        // Streaming output must equal a one-shot decode of the whole history.
        let expected = seq.tokenizer().decode(&history, false).unwrap();
        assert_eq!(
            streamed, expected,
            "Drained incremental output must match full decode"
        );
    }
}