Skip to main content

dynamo_mocker/common/sequence.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::common::protocols::MoveBlock;
5use derive_getters::Getters;
6use dynamo_tokens::blocks::UniqueBlock;
7use dynamo_tokens::{TokenBlockSequence, Tokens};
8use rand::random;
9use validator::Validate;
10
11/// Create unique blocks from a TokenBlockSequence
12fn create_unique_blocks_from_sequence(
13    tokens: &TokenBlockSequence,
14    block_size: usize,
15    enable_prefix_caching: bool,
16) -> Vec<UniqueBlock> {
17    let mut unique_blocks: Vec<UniqueBlock> = tokens
18        .blocks()
19        .iter()
20        .map(|block| {
21            if enable_prefix_caching {
22                UniqueBlock::FullBlock(block.sequence_hash())
23            } else {
24                UniqueBlock::FullBlock(random::<u64>())
25            }
26        })
27        .collect();
28
29    // Only push the partial block if tokens count isn't a multiple of block_size
30    if !tokens.total_tokens().is_multiple_of(block_size) {
31        unique_blocks.push(UniqueBlock::default());
32    }
33    unique_blocks
34}
35
36/// A sequence that is actively being built, with the ability to add tokens and commit to hashes
37/// TODO: reuse tokens
38#[derive(Debug, Getters, Validate)]
39pub struct ActiveSequence {
40    unique_blocks: Vec<UniqueBlock>,
41
42    tokens: TokenBlockSequence,
43
44    #[getter(copy)]
45    #[validate(range(min = 2))]
46    block_size: usize,
47
48    #[getter(copy)]
49    max_output_tokens: usize,
50
51    #[getter(copy)]
52    generated_tokens: usize,
53
54    #[getter(copy)]
55    num_input_tokens: usize,
56
57    creation_signal: Option<MoveBlock>,
58
59    #[getter(copy)]
60    enable_prefix_caching: bool,
61
62    #[getter(copy)]
63    emit_token_ids: bool,
64}
65
66impl ActiveSequence {
67    /// Create a new ActiveSequence instance with the provided tokens
68    pub fn new(
69        tokens: Vec<u32>,
70        max_output_tokens: usize,
71        block_size: Option<usize>,
72        enable_prefix_caching: bool,
73        emit_token_ids: bool,
74    ) -> Self {
75        let block_size = block_size.unwrap_or(64);
76        let num_input_tokens = tokens.len();
77
78        let block_token_ids: Option<Vec<Vec<u32>>> = if emit_token_ids {
79            let num_complete = tokens.len() / block_size;
80            Some(
81                tokens
82                    .chunks(block_size)
83                    .take(num_complete)
84                    .map(|c| c.to_vec())
85                    .collect(),
86            )
87        } else {
88            None
89        };
90
91        let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337));
92        let unique_blocks =
93            create_unique_blocks_from_sequence(&tokens, block_size, enable_prefix_caching);
94        let block_hashes = tokens.blocks().iter().map(|b| b.block_hash()).collect();
95        let creation_signal = Some(MoveBlock::Use(
96            unique_blocks.clone(),
97            block_hashes,
98            block_token_ids,
99        ));
100
101        let seq = Self {
102            unique_blocks,
103            tokens,
104            block_size,
105            max_output_tokens,
106            generated_tokens: 0,
107            num_input_tokens,
108            creation_signal,
109            enable_prefix_caching,
110            emit_token_ids,
111        };
112        seq.validate().expect("invalid ActiveSequence");
113        seq
114    }
115
116    pub fn extra_tokens(&self) -> u32 {
117        (self.len() % self.block_size) as u32
118    }
119
120    pub fn len(&self) -> usize {
121        self.tokens.total_tokens()
122    }
123
124    pub fn is_empty(&self) -> bool {
125        self.tokens.total_tokens() == 0
126    }
127
128    pub fn take_creation_signal(&mut self) -> Option<MoveBlock> {
129        self.creation_signal.take()
130    }
131
132    pub fn block_hashes(&self) -> Vec<u64> {
133        self.tokens
134            .blocks()
135            .iter()
136            .map(|block| block.block_hash())
137            .collect()
138    }
139
140    /// Create a new ActiveSequence instance and return the creation signal
141    pub fn new_with_signal(
142        tokens: Vec<u32>,
143        max_output_tokens: usize,
144        block_size: Option<usize>,
145        enable_prefix_caching: bool,
146    ) -> (Self, Option<MoveBlock>) {
147        let mut sequence = Self::new(
148            tokens,
149            max_output_tokens,
150            block_size,
151            enable_prefix_caching,
152            false,
153        );
154        let signal = sequence.take_creation_signal();
155        (sequence, signal)
156    }
157
158    /// Push a token to the sequence
159    pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> {
160        self.tokens.append(token).expect("Token push failed.");
161        self.generated_tokens += 1;
162
163        if self.len() % self.block_size != 1 {
164            return None;
165        }
166
167        // Add a partial block for the first token in a new partial sequence
168        // Send Use signal (to allocate space for this new generation block)
169        let mut signals = Vec::new();
170
171        // Replace last partial block with full block if it exists
172        if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() {
173            let last_complete = self.tokens.last_complete_block().unwrap();
174            let last_seq_hash = if self.enable_prefix_caching {
175                last_complete.sequence_hash()
176            } else {
177                random::<u64>()
178            };
179            let last_block_hash = last_complete.block_hash();
180            let promote_token_ids = if self.emit_token_ids {
181                Some(last_complete.tokens().to_vec())
182            } else {
183                None
184            };
185            self.unique_blocks.pop();
186
187            // After pop, the last element is the parent block
188            let second_to_last_hash = self.unique_blocks.last().map(|block| match block {
189                UniqueBlock::FullBlock(hash) => *hash,
190                UniqueBlock::PartialBlock(_) => panic!("Cannot have a partial block as parent"),
191            });
192
193            self.unique_blocks
194                .push(UniqueBlock::FullBlock(last_seq_hash));
195            signals.push(MoveBlock::Promote(
196                uuid,
197                last_seq_hash,
198                second_to_last_hash,
199                last_block_hash,
200                promote_token_ids,
201            ));
202        }
203
204        let new_partial_block = UniqueBlock::default();
205        self.unique_blocks.push(new_partial_block.clone());
206        signals.push(MoveBlock::Use(vec![new_partial_block], vec![], None));
207        Some(signals)
208    }
209
210    /// Generate a random token, push it to the sequence, and increment generation count.
211    ///
212    /// This function:
213    /// - Generates a random token and adds it to the current sequence
214    /// - Acquires a new partial block if needed or promotes an existing partial block to a full block
215    /// - Returns appropriate signals for the KvManager to process
216    ///
217    /// # Panics
218    ///
219    /// Calling this function when max_output_tokens has already been reached will cause a panic.
220    /// Always check `generated_tokens < max_output_tokens` before calling this method.
221    pub fn generate(&mut self) -> Vec<MoveBlock> {
222        // Assert that we haven't reached the maximum output tokens
223        assert!(
224            self.generated_tokens < self.max_output_tokens,
225            "Cannot generate more tokens: reached max_output_tokens limit"
226        );
227
228        // Generate a random token
229        let token = random::<u32>();
230
231        // Collect signals
232        let mut signals = Vec::new();
233
234        // Push the token to the sequence and collect any signals
235        if let Some(move_blocks) = self.push(token) {
236            signals.extend(move_blocks);
237        }
238
239        // Check if we've reached the limit after pushing
240        if self.generated_tokens != self.max_output_tokens {
241            return signals;
242        }
243
244        // Free all blocks when we reach max tokens
245        signals.extend(self.free_signal());
246        signals
247    }
248
249    /// Free all blocks, generating appropriate signals for each block type
250    pub fn free_signal(&self) -> Vec<MoveBlock> {
251        self.unique_blocks
252            .iter()
253            .rev()
254            .map(|block| match block {
255                UniqueBlock::PartialBlock(uuid) => {
256                    MoveBlock::Destroy(vec![UniqueBlock::PartialBlock(*uuid)])
257                }
258                UniqueBlock::FullBlock(hash) => {
259                    MoveBlock::Deref(vec![UniqueBlock::FullBlock(*hash)])
260                }
261            })
262            .collect()
263    }
264
265    /// Move the request to a preempted state and return the free signals from freeing current blocks
266    /// Upon preemption, the sequence retains the tokens generated during the decode phase (if any).
267    pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> {
268        let free_signal = self.free_signal();
269
270        // Don't reset generated_tokens since we're keeping the tokens in the sequence
271
272        let block_token_ids = if self.emit_token_ids {
273            Some(
274                self.tokens
275                    .blocks()
276                    .iter()
277                    .map(|b| b.tokens().to_vec())
278                    .collect(),
279            )
280        } else {
281            None
282        };
283
284        self.creation_signal = Some(MoveBlock::Use(
285            self.unique_blocks.clone(),
286            self.block_hashes(),
287            block_token_ids,
288        ));
289
290        free_signal
291    }
292
293    /// Pops last token in the sequence.
294    pub fn pop(&mut self) {
295        self.tokens.pop();
296        self.generated_tokens = self.generated_tokens.saturating_sub(1);
297
298        // Reverts to the last full block
299        if self.tokens.total_tokens().is_multiple_of(self.block_size) {
300            self.unique_blocks.pop();
301        }
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_active_sequence_push() {
        // Create a sequence with block size 16 initialized with tokens [0, 15)
        let initial_tokens: Vec<u32> = (0..15).collect();
        let (mut seq1, signal1) =
            ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true);
        assert_eq!(seq1.num_input_tokens(), 15);
        assert_eq!(seq1.len(), 15);

        // The creation signal is a Use covering the single (partial) block
        assert!(signal1.is_some());
        match &signal1 {
            Some(MoveBlock::Use(blocks, _hashes, ..)) => {
                assert_eq!(blocks.len(), 1);
            }
            _ => panic!("Expected Use signal"),
        }

        // Push token 15, which fills the first block; promotion waits for the next token
        let signal_15 = seq1.push(15);
        assert!(
            signal_15.is_none(),
            "Completing a block should not trigger signals"
        );

        // Push token 16, the first token of a new block: triggers Promote then Use
        let signal_16 = seq1.push(16);
        assert!(signal_16.is_some());
        let signal_16 = signal_16.unwrap();
        assert_eq!(signal_16.len(), 2);

        // First signal should be Promote for the previous block, which has no parent
        match &signal_16[0] {
            MoveBlock::Promote(_, _, parent_hash, _hash, ..) => {
                assert_eq!(*parent_hash, None);
            }
            _ => panic!("Expected Promote signal as first signal"),
        }

        // Second signal should be Use for the new partial block
        match &signal_16[1] {
            MoveBlock::Use(blocks, _hashes, ..) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal as second signal"),
        }

        // Verify state after pushing tokens
        assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block
        assert_eq!(seq1.len(), 17);
        assert_eq!(seq1.len() % seq1.block_size(), 1);

        // Create another sequence with block size 16 initialized with tokens [0, 16),
        // then push, pop and re-push token 16 so it ends with a freshly allocated partial block
        let extended_tokens: Vec<u32> = (0..16).collect();
        let (mut seq2, _) = ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), true);
        seq2.push(16);
        seq2.pop();
        seq2.push(16);

        // With prefix caching the full blocks match; each partial block has its own id
        assert_eq!(
            seq1.unique_blocks()[0],
            seq2.unique_blocks()[0],
            "First blocks should be the same"
        );

        assert_ne!(
            seq1.unique_blocks()[1],
            seq2.unique_blocks()[1],
            "Second blocks should be different"
        );

        // Roll seq1 back to 16 tokens (dropping its partial block) and push token 16 again
        seq1.push(17);
        seq1.pop();
        seq1.pop();
        seq1.push(16);

        // Now push tokens 17..=32 to both sequences
        for token in 17..33 {
            seq1.push(token);
            seq2.push(token);
        }

        // Both sequences now hold 33 tokens in 3 blocks:
        // 1. FullBlock for tokens 0-15
        // 2. FullBlock for tokens 16-31
        // 3. PartialBlock holding token 32
        assert_eq!(
            seq1.unique_blocks().len(),
            3,
            "seq1 should have exactly 3 blocks"
        );
        assert_eq!(
            seq2.unique_blocks().len(),
            3,
            "seq2 should have exactly 3 blocks"
        );
        assert_eq!(
            seq1.len() % seq1.block_size(),
            1,
            "seq1 should have 1 partial token"
        );
        assert_eq!(
            seq2.len() % seq2.block_size(),
            1,
            "seq2 should have 1 partial token"
        );

        // The two full blocks should be identical across the sequences
        assert_eq!(
            &seq1.unique_blocks()[0..2],
            &seq2.unique_blocks()[0..2],
            "First two blocks should be identical"
        );

        // Push tokens 33..=47 to seq1, filling its third block
        for token in 33..48 {
            seq1.push(token);
        }

        // Push token 48: it opens a fourth block, so the third block is promoted
        let signal = seq1.push(48);
        let signal = signal.unwrap();

        // signal[0] should be the Promote whose parent is the second full block
        match &signal[0] {
            MoveBlock::Promote(_, _, parent_hash, _hash, ..) => {
                // Check that the parent_hash matches unique_blocks[1], which should be a full block
                if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] {
                    assert_eq!(
                        *parent_hash,
                        Some(expected_hash),
                        "Parent hash should match unique_blocks[1]"
                    );
                } else {
                    panic!("unique_blocks[1] should be a full block");
                }
            }
            _ => panic!("Expected Promote signal as first signal"),
        }

        // Preempt seq1: this frees its blocks and re-arms the creation signal
        let free_signals = seq1.reset_with_signal();

        // Generated tokens survive a reset: 49 total tokens - 15 input tokens
        assert_eq!(seq1.generated_tokens(), 34);

        // Verify the reset signals include proper cleanup events
        assert!(!free_signals.is_empty());
    }

    #[test]
    fn test_active_sequence_generate_signals() {
        // Create a sequence with block size 16, max_output_tokens 5, initialized with tokens [0, 14)
        let initial_tokens: Vec<u32> = (0..14).collect();
        let (mut seq, signal) = ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), true);

        // Initial signal - should have received a Use signal for the partial block
        assert!(signal.is_some());
        match signal {
            Some(MoveBlock::Use(blocks, _hashes, ..)) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal for the initial partial block"),
        }

        // Generate the first two tokens (15th and 16th overall) - fills the block, no new signals
        seq.generate();
        let signals_first = seq.generate();
        assert_eq!(signals_first.len(), 0);

        // Generate the third token (17th overall) - opens a new block, triggering Promote and Use
        let signals_second = seq.generate();
        assert_eq!(signals_second.len(), 2);

        // First signal should be Promote
        match &signals_second[0] {
            MoveBlock::Promote(_, _, parent_hash, _hash, ..) => {
                assert_eq!(*parent_hash, None);
            }
            _ => panic!("Expected Promote signal as first signal after third token"),
        }

        // Second signal should be Use for the new partial block
        match &signals_second[1] {
            MoveBlock::Use(blocks, _hashes, ..) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal as second signal after third token"),
        }

        // Generate the fourth token - adds to the partial block, no new signals
        let signals_third = seq.generate();
        assert_eq!(signals_third.len(), 0);

        // Generate the fifth token - reaches max_output_tokens, triggering Destroy and Deref
        let signals_last = seq.generate();
        assert_eq!(signals_last.len(), 2);

        // First signal should be Destroy for the partial block (blocks are freed newest first)
        match &signals_last[0] {
            MoveBlock::Destroy(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Destroy signal for partial block after fifth token"),
        }

        // Second signal should be Deref for the full block
        match &signals_last[1] {
            MoveBlock::Deref(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::FullBlock(_)));
            }
            _ => panic!("Expected Deref signal for full block after fifth token"),
        }
    }
}