1use crate::common::protocols::MoveBlock;
5use derive_getters::Getters;
6use dynamo_tokens::blocks::UniqueBlock;
7use dynamo_tokens::{TokenBlockSequence, Tokens};
8use rand::random;
9use validator::Validate;
10
11fn create_unique_blocks_from_sequence(
13 tokens: &TokenBlockSequence,
14 block_size: usize,
15 enable_prefix_caching: bool,
16) -> Vec<UniqueBlock> {
17 let mut unique_blocks: Vec<UniqueBlock> = tokens
18 .blocks()
19 .iter()
20 .map(|block| {
21 if enable_prefix_caching {
22 UniqueBlock::FullBlock(block.sequence_hash())
23 } else {
24 UniqueBlock::FullBlock(random::<u64>())
25 }
26 })
27 .collect();
28
29 if !tokens.total_tokens().is_multiple_of(block_size) {
31 unique_blocks.push(UniqueBlock::default());
32 }
33 unique_blocks
34}
35
/// One request's token sequence together with the `UniqueBlock`s that back it.
///
/// Methods on this type emit `MoveBlock` signals (`Use`, `Promote`, `Destroy`,
/// `Deref`) describing how those blocks change as tokens are pushed, popped,
/// generated, freed, or reset.
///
/// NOTE(review): presumably these signals are consumed by a KV block manager;
/// confirm against the caller.
#[derive(Debug, Getters, Validate)]
pub struct ActiveSequence {
    /// One entry per block held by the sequence: `FullBlock`s for completed
    /// blocks, followed by at most one trailing `PartialBlock`.
    unique_blocks: Vec<UniqueBlock>,

    /// All tokens of the sequence (prompt plus generated), split into blocks.
    tokens: TokenBlockSequence,

    /// Number of tokens per block.
    ///
    /// It must be at least 2. `push` detects a new block via
    /// `len % block_size == 1`, and that can never hold when `block_size` is 1.
    #[getter(copy)]
    #[validate(range(min = 2))]
    block_size: usize,

    /// Upper bound on the number of tokens `generate` may produce.
    #[getter(copy)]
    max_output_tokens: usize,

    /// Count of tokens appended via `push`/`generate`. `pop` decrements it,
    /// saturating at 0.
    #[getter(copy)]
    generated_tokens: usize,

    /// Length of the prompt given to the constructor.
    #[getter(copy)]
    num_input_tokens: usize,

    /// Pending `MoveBlock::Use` signal for the blocks the sequence currently holds.
    ///
    /// It is set on construction and by `reset_with_signal`, and consumed by
    /// `take_creation_signal`.
    creation_signal: Option<MoveBlock>,

    /// When true, full blocks are keyed by their sequence hash. When false,
    /// each full block gets a random id.
    #[getter(copy)]
    enable_prefix_caching: bool,

    /// When true, signals that create or promote full blocks also carry those
    /// blocks' token ids.
    #[getter(copy)]
    emit_token_ids: bool,
}
65
66impl ActiveSequence {
67 pub fn new(
69 tokens: Vec<u32>,
70 max_output_tokens: usize,
71 block_size: Option<usize>,
72 enable_prefix_caching: bool,
73 emit_token_ids: bool,
74 ) -> Self {
75 let block_size = block_size.unwrap_or(64);
76 let num_input_tokens = tokens.len();
77
78 let block_token_ids: Option<Vec<Vec<u32>>> = if emit_token_ids {
79 let num_complete = tokens.len() / block_size;
80 Some(
81 tokens
82 .chunks(block_size)
83 .take(num_complete)
84 .map(|c| c.to_vec())
85 .collect(),
86 )
87 } else {
88 None
89 };
90
91 let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337));
92 let unique_blocks =
93 create_unique_blocks_from_sequence(&tokens, block_size, enable_prefix_caching);
94 let block_hashes = tokens.blocks().iter().map(|b| b.block_hash()).collect();
95 let creation_signal = Some(MoveBlock::Use(
96 unique_blocks.clone(),
97 block_hashes,
98 block_token_ids,
99 ));
100
101 let seq = Self {
102 unique_blocks,
103 tokens,
104 block_size,
105 max_output_tokens,
106 generated_tokens: 0,
107 num_input_tokens,
108 creation_signal,
109 enable_prefix_caching,
110 emit_token_ids,
111 };
112 seq.validate().expect("invalid ActiveSequence");
113 seq
114 }
115
116 pub fn extra_tokens(&self) -> u32 {
117 (self.len() % self.block_size) as u32
118 }
119
120 pub fn len(&self) -> usize {
121 self.tokens.total_tokens()
122 }
123
124 pub fn is_empty(&self) -> bool {
125 self.tokens.total_tokens() == 0
126 }
127
128 pub fn take_creation_signal(&mut self) -> Option<MoveBlock> {
129 self.creation_signal.take()
130 }
131
132 pub fn block_hashes(&self) -> Vec<u64> {
133 self.tokens
134 .blocks()
135 .iter()
136 .map(|block| block.block_hash())
137 .collect()
138 }
139
140 pub fn new_with_signal(
142 tokens: Vec<u32>,
143 max_output_tokens: usize,
144 block_size: Option<usize>,
145 enable_prefix_caching: bool,
146 ) -> (Self, Option<MoveBlock>) {
147 let mut sequence = Self::new(
148 tokens,
149 max_output_tokens,
150 block_size,
151 enable_prefix_caching,
152 false,
153 );
154 let signal = sequence.take_creation_signal();
155 (sequence, signal)
156 }
157
158 pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> {
160 self.tokens.append(token).expect("Token push failed.");
161 self.generated_tokens += 1;
162
163 if self.len() % self.block_size != 1 {
164 return None;
165 }
166
167 let mut signals = Vec::new();
170
171 if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() {
173 let last_complete = self.tokens.last_complete_block().unwrap();
174 let last_seq_hash = if self.enable_prefix_caching {
175 last_complete.sequence_hash()
176 } else {
177 random::<u64>()
178 };
179 let last_block_hash = last_complete.block_hash();
180 let promote_token_ids = if self.emit_token_ids {
181 Some(last_complete.tokens().to_vec())
182 } else {
183 None
184 };
185 self.unique_blocks.pop();
186
187 let second_to_last_hash = self.unique_blocks.last().map(|block| match block {
189 UniqueBlock::FullBlock(hash) => *hash,
190 UniqueBlock::PartialBlock(_) => panic!("Cannot have a partial block as parent"),
191 });
192
193 self.unique_blocks
194 .push(UniqueBlock::FullBlock(last_seq_hash));
195 signals.push(MoveBlock::Promote(
196 uuid,
197 last_seq_hash,
198 second_to_last_hash,
199 last_block_hash,
200 promote_token_ids,
201 ));
202 }
203
204 let new_partial_block = UniqueBlock::default();
205 self.unique_blocks.push(new_partial_block.clone());
206 signals.push(MoveBlock::Use(vec![new_partial_block], vec![], None));
207 Some(signals)
208 }
209
210 pub fn generate(&mut self) -> Vec<MoveBlock> {
222 assert!(
224 self.generated_tokens < self.max_output_tokens,
225 "Cannot generate more tokens: reached max_output_tokens limit"
226 );
227
228 let token = random::<u32>();
230
231 let mut signals = Vec::new();
233
234 if let Some(move_blocks) = self.push(token) {
236 signals.extend(move_blocks);
237 }
238
239 if self.generated_tokens != self.max_output_tokens {
241 return signals;
242 }
243
244 signals.extend(self.free_signal());
246 signals
247 }
248
249 pub fn free_signal(&self) -> Vec<MoveBlock> {
251 self.unique_blocks
252 .iter()
253 .rev()
254 .map(|block| match block {
255 UniqueBlock::PartialBlock(uuid) => {
256 MoveBlock::Destroy(vec![UniqueBlock::PartialBlock(*uuid)])
257 }
258 UniqueBlock::FullBlock(hash) => {
259 MoveBlock::Deref(vec![UniqueBlock::FullBlock(*hash)])
260 }
261 })
262 .collect()
263 }
264
265 pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> {
268 let free_signal = self.free_signal();
269
270 let block_token_ids = if self.emit_token_ids {
273 Some(
274 self.tokens
275 .blocks()
276 .iter()
277 .map(|b| b.tokens().to_vec())
278 .collect(),
279 )
280 } else {
281 None
282 };
283
284 self.creation_signal = Some(MoveBlock::Use(
285 self.unique_blocks.clone(),
286 self.block_hashes(),
287 block_token_ids,
288 ));
289
290 free_signal
291 }
292
293 pub fn pop(&mut self) {
295 self.tokens.pop();
296 self.generated_tokens = self.generated_tokens.saturating_sub(1);
297
298 if self.tokens.total_tokens().is_multiple_of(self.block_size) {
300 self.unique_blocks.pop();
301 }
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_active_sequence_push() {
        // 15 prompt tokens with block size 16: a single partial block.
        let initial_tokens: Vec<u32> = (0..15).collect();
        let (mut seq1, signal1) =
            ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true);
        assert_eq!(seq1.num_input_tokens(), 15);
        assert_eq!(seq1.len(), 15);

        assert!(signal1.is_some());
        match &signal1 {
            Some(MoveBlock::Use(blocks, _hashes, ..)) => {
                assert_eq!(blocks.len(), 1);
            }
            _ => panic!("Expected Use signal"),
        }

        // The 16th token fills the block but does not promote it yet.
        let signal_15 = seq1.push(15);
        assert!(
            signal_15.is_none(),
            "Completing a block should not trigger signals"
        );

        // The 17th token spills into a new block: promote the old one, use a new one.
        let signal_16 = seq1.push(16);
        assert!(signal_16.is_some());
        let signal_16 = signal_16.unwrap();
        assert_eq!(signal_16.len(), 2);

        match &signal_16[0] {
            MoveBlock::Promote(_, _, parent_hash, _hash, ..) => {
                assert_eq!(*parent_hash, None);
            }
            _ => panic!("Expected Promote signal as first signal"),
        }

        match &signal_16[1] {
            MoveBlock::Use(blocks, _hashes, ..) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal as second signal"),
        }

        assert_eq!(seq1.unique_blocks().len(), 2);
        assert_eq!(seq1.len(), 17);
        assert_eq!(seq1.len() % seq1.block_size(), 1);

        // seq2 starts with one complete block, then rolls back and redoes a push.
        let extended_tokens: Vec<u32> = (0..16).collect();
        let (mut seq2, _) = ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), true);
        seq2.push(16);
        seq2.pop();
        seq2.push(16);

        // With prefix caching, identical full blocks get identical ids...
        assert_eq!(
            seq1.unique_blocks()[0],
            seq2.unique_blocks()[0],
            "First blocks should be the same"
        );

        // ...while each partial block gets its own id.
        assert_ne!(
            seq1.unique_blocks()[1],
            seq2.unique_blocks()[1],
            "Second blocks should be different"
        );

        // Roll seq1 back across the block boundary and re-push the same token.
        seq1.push(17);
        seq1.pop();
        seq1.pop();
        seq1.push(16);

        // Fill one more block in both sequences.
        for token in 17..33 {
            seq1.push(token);
            seq2.push(token);
        }

        assert_eq!(
            seq1.unique_blocks().len(),
            3,
            "seq1 should have exactly 3 blocks"
        );
        assert_eq!(
            seq2.unique_blocks().len(),
            3,
            "seq2 should have exactly 3 blocks"
        );
        assert_eq!(
            seq1.len() % seq1.block_size(),
            1,
            "seq1 should have 1 partial token"
        );
        assert_eq!(
            seq2.len() % seq2.block_size(),
            1,
            "seq2 should have 1 partial token"
        );

        assert_eq!(
            &seq1.unique_blocks()[0..2],
            &seq2.unique_blocks()[0..2],
            "First two blocks should be identical"
        );

        // Fill the third block; the next push promotes it with block 1 as parent.
        for token in 33..48 {
            seq1.push(token);
        }

        let signal = seq1.push(48);
        let signal = signal.unwrap();

        match &signal[0] {
            MoveBlock::Promote(_, _, parent_hash, _hash, ..) => {
                if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] {
                    assert_eq!(
                        *parent_hash,
                        Some(expected_hash),
                        "Parent hash should match unique_blocks[1]"
                    );
                } else {
                    panic!("unique_blocks[1] should be a full block");
                }
            }
            _ => panic!("Expected Promote signal as first signal"),
        }

        // Reset frees blocks but keeps the token count.
        let free_signals = seq1.reset_with_signal();

        assert_eq!(seq1.generated_tokens(), 34);

        assert!(!free_signals.is_empty());
    }

    #[test]
    fn test_active_sequence_generate_signals() {
        // 14 prompt tokens, block size 16, at most 5 generated tokens.
        let initial_tokens: Vec<u32> = (0..14).collect();
        let (mut seq, signal) = ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), true);

        assert!(signal.is_some());
        match signal {
            Some(MoveBlock::Use(blocks, _hashes, ..)) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal for the initial partial block"),
        }

        // The first two generated tokens only fill the existing partial block (16 tokens).
        seq.generate();
        let signals_first = seq.generate();
        assert_eq!(signals_first.len(), 0);

        // The third generated token (17th overall) spills into a new block.
        let signals_second = seq.generate();
        assert_eq!(signals_second.len(), 2);

        match &signals_second[0] {
            MoveBlock::Promote(_, _, parent_hash, _hash, ..) => {
                assert_eq!(*parent_hash, None);
            }
            _ => panic!("Expected Promote signal as first signal after third token"),
        }

        match &signals_second[1] {
            MoveBlock::Use(blocks, _hashes, ..) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal as second signal after third token"),
        }

        let signals_third = seq.generate();
        assert_eq!(signals_third.len(), 0);

        // The fifth generated token reaches max_output_tokens, so every block is freed.
        let signals_last = seq.generate();
        assert_eq!(signals_last.len(), 2);

        match &signals_last[0] {
            MoveBlock::Destroy(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Destroy signal for partial block after fifth token"),
        }

        match &signals_last[1] {
            MoveBlock::Deref(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::FullBlock(_)));
            }
            _ => panic!("Expected Deref signal for full block after fifth token"),
        }
    }

    #[test]
    #[should_panic]
    fn test_active_sequence_rejects_block_size_one() {
        // `push` detects a new block via `len % block_size == 1`, which never
        // holds for a block size of 1, so construction must refuse it.
        let _ = ActiveSequence::new(vec![1, 2, 3], 10, Some(1), false, false);
    }
}