dynamo_llm/mocker/
sequence.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::mocker::protocols::MoveBlock;
5use crate::tokens::blocks::UniqueBlock;
6use crate::tokens::{TokenBlockSequence, Tokens};
7use derive_getters::Getters;
8use rand::random;
9use uuid;
10
11/// Create unique blocks from a TokenBlockSequence
12fn create_unique_blocks_from_sequence(
13    tokens: &TokenBlockSequence,
14    uuid: Option<uuid::Uuid>,
15    block_size: usize,
16    enable_prefix_caching: bool,
17) -> Vec<UniqueBlock> {
18    let mut unique_blocks: Vec<UniqueBlock> = tokens
19        .blocks()
20        .iter()
21        .map(|block| {
22            if enable_prefix_caching {
23                UniqueBlock::FullBlock(block.sequence_hash())
24            } else {
25                UniqueBlock::FullBlock(random::<u64>())
26            }
27        })
28        .collect();
29
30    // Only push the partial block if tokens count isn't a multiple of block_size
31    if !tokens.total_tokens().is_multiple_of(block_size) {
32        unique_blocks.push(match uuid {
33            Some(uuid) => UniqueBlock::PartialBlock(uuid),
34            None => UniqueBlock::default(),
35        });
36    }
37    unique_blocks
38}
39
/// A sequence that is actively being built, with the ability to add tokens and commit to hashes.
///
/// Holds the prompt plus generated tokens in a `TokenBlockSequence`. It mirrors them as a list of
/// `UniqueBlock`s (full blocks followed by at most one trailing partial block), from which the
/// `MoveBlock` signals returned by its methods are derived.
/// TODO: reuse tokens
#[derive(Debug, Getters)]
pub struct ActiveSequence {
    /// One entry per KV block: a `FullBlock` for each complete block, plus a trailing
    /// `PartialBlock` while the token count is not a multiple of `block_size`.
    unique_blocks: Vec<UniqueBlock>,

    /// Prompt and generated tokens, grouped into blocks of `block_size` tokens.
    tokens: TokenBlockSequence,

    /// Number of tokens per block (defaults to 64 in `new`; must be greater than 1).
    #[getter(copy)]
    block_size: usize,

    /// Upper bound on `generated_tokens`; `generate` asserts it has not been reached yet.
    #[getter(copy)]
    max_output_tokens: usize,

    /// Tokens pushed since creation or the last reset (`pop` decrements it, saturating at 0).
    #[getter(copy)]
    generated_tokens: usize,

    /// Largest `generated_tokens` value recorded across resets.
    #[getter(copy)]
    already_generated_tokens: usize,

    /// Length of the original prompt; a reset truncates the tokens back to this length.
    #[getter(copy)]
    num_input_tokens: usize,

    /// Pending `Use` signal for the current block list, set on creation and on every reset
    /// and consumed by `take_creation_signal`.
    creation_signal: Option<MoveBlock>,

    /// When true, full blocks are identified by their sequence hash; otherwise by a random id.
    #[getter(copy)]
    enable_prefix_caching: bool,
}
68
69impl ActiveSequence {
70    /// Create a new ActiveSequence instance with the provided tokens
71    pub fn new(
72        tokens: Vec<u32>,
73        max_output_tokens: usize,
74        block_size: Option<usize>,
75        enable_prefix_caching: bool,
76    ) -> Self {
77        let block_size = block_size.unwrap_or(64);
78        assert!(block_size > 1, "block_size must be greater than 1");
79        let num_input_tokens = tokens.len();
80
81        let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337));
82        let unique_blocks =
83            create_unique_blocks_from_sequence(&tokens, None, block_size, enable_prefix_caching);
84        let creation_signal = Some(MoveBlock::Use(unique_blocks.clone()));
85
86        Self {
87            unique_blocks,
88            tokens,
89            block_size,
90            max_output_tokens,
91            generated_tokens: 0,
92            already_generated_tokens: 0,
93            num_input_tokens,
94            creation_signal,
95            enable_prefix_caching,
96        }
97    }
98
99    pub fn extra_tokens(&self) -> u32 {
100        (self.len() % self.block_size) as u32
101    }
102
103    pub fn len(&self) -> usize {
104        self.tokens.total_tokens()
105    }
106
107    pub fn is_empty(&self) -> bool {
108        self.tokens.total_tokens() == 0
109    }
110
111    pub fn take_creation_signal(&mut self) -> Option<MoveBlock> {
112        self.creation_signal.take()
113    }
114
115    pub fn block_hashes(&self) -> Vec<u64> {
116        self.tokens
117            .blocks()
118            .iter()
119            .map(|block| block.block_hash())
120            .collect()
121    }
122
123    /// Create a new ActiveSequence instance and return the creation signal
124    pub fn new_with_signal(
125        tokens: Vec<u32>,
126        max_output_tokens: usize,
127        block_size: Option<usize>,
128        enable_prefix_caching: bool,
129    ) -> (Self, Option<MoveBlock>) {
130        let mut sequence = Self::new(tokens, max_output_tokens, block_size, enable_prefix_caching);
131        let signal = sequence.take_creation_signal();
132        (sequence, signal)
133    }
134
135    /// Get the parent hash from the second-to-last block if it exists and is a FullBlock
136    fn get_parent_hash(&self) -> Option<u64> {
137        if self.unique_blocks.len() < 2 {
138            return None;
139        }
140        match &self.unique_blocks[self.unique_blocks.len() - 2] {
141            UniqueBlock::FullBlock(hash) => Some(*hash),
142            _ => panic!("Cannot have a partial block as parent"),
143        }
144    }
145
146    /// Push a token to the sequence
147    pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> {
148        self.tokens.append(token).expect("Token push failed.");
149        self.generated_tokens += 1;
150
151        if self.len() % self.block_size != 1 {
152            return None;
153        }
154
155        // Add a partial block for the first token in a new partial sequence
156        // Send Use signal (to allocate space for this new generation block)
157        let mut signals = Vec::new();
158
159        // Replace last partial block with full block if it exists
160        if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() {
161            let last_block_hash = if self.enable_prefix_caching {
162                self.tokens.last_complete_block().unwrap().sequence_hash()
163            } else {
164                random::<u64>()
165            };
166            self.unique_blocks.pop();
167            self.unique_blocks
168                .push(UniqueBlock::FullBlock(last_block_hash));
169            signals.push(MoveBlock::Promote(
170                uuid,
171                last_block_hash,
172                self.get_parent_hash(),
173            ));
174        }
175
176        let new_partial_block = UniqueBlock::default();
177        self.unique_blocks.push(new_partial_block.clone());
178        signals.push(MoveBlock::Use(vec![new_partial_block]));
179        Some(signals)
180    }
181
182    /// Generate a random token, push it to the sequence, and increment generation count.
183    ///
184    /// This function:
185    /// - Generates a random token and adds it to the current sequence
186    /// - Acquires a new partial block if needed or promotes an existing partial block to a full block
187    /// - Returns appropriate signals for the KvManager to process
188    ///
189    /// # Panics
190    ///
191    /// Calling this function when max_output_tokens has already been reached will cause a panic.
192    /// Always check `generated_tokens < max_output_tokens` before calling this method.
193    pub fn generate(&mut self) -> Vec<MoveBlock> {
194        // Assert that we haven't reached the maximum output tokens
195        assert!(
196            self.generated_tokens < self.max_output_tokens,
197            "Cannot generate more tokens: reached max_output_tokens limit"
198        );
199
200        // Generate a random token
201        let token = random::<u32>();
202
203        // Collect signals
204        let mut signals = Vec::new();
205
206        // Push the token to the sequence and collect any signals
207        if let Some(move_blocks) = self.push(token) {
208            signals.extend(move_blocks);
209        }
210
211        // Check if we've reached the limit after pushing
212        if self.generated_tokens != self.max_output_tokens {
213            return signals;
214        }
215
216        // Free all blocks when we reach max tokens
217        signals.extend(self.free_signal());
218        signals
219    }
220
221    /// Free all blocks, generating appropriate signals for each block type
222    pub fn free_signal(&self) -> Vec<MoveBlock> {
223        self.unique_blocks
224            .iter()
225            .rev()
226            .map(|block| match block {
227                UniqueBlock::PartialBlock(uuid) => {
228                    MoveBlock::Destroy(vec![UniqueBlock::PartialBlock(*uuid)])
229                }
230                UniqueBlock::FullBlock(hash) => {
231                    MoveBlock::Deref(vec![UniqueBlock::FullBlock(*hash)])
232                }
233            })
234            .collect()
235    }
236
237    /// Reset the sequence to its initial state and return the free signals from freeing current blocks
238    pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> {
239        let free_signal = self.free_signal();
240
241        self.tokens.truncate(self.num_input_tokens).unwrap();
242        self.unique_blocks = create_unique_blocks_from_sequence(
243            &self.tokens,
244            None,
245            self.block_size,
246            self.enable_prefix_caching,
247        );
248        self.already_generated_tokens = self.generated_tokens.max(self.already_generated_tokens);
249        self.generated_tokens = 0;
250        self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone()));
251
252        free_signal
253    }
254
255    /// Pops last token in the sequence.
256    pub fn pop(&mut self) {
257        self.tokens.pop();
258        self.generated_tokens = self.generated_tokens.saturating_sub(1);
259
260        // Reverts to the last full block
261        if self.tokens.total_tokens().is_multiple_of(self.block_size) {
262            self.unique_blocks.pop();
263        }
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_active_sequence_push() {
        // Create a sequence with block size 16 initialized with tokens [0..15)
        let initial_tokens: Vec<u32> = (0..15).collect();
        let (mut seq1, signal1) =
            ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true);
        assert_eq!(seq1.num_input_tokens(), 15);
        assert_eq!(seq1.len(), 15);

        // Check that we got a Use signal covering the single partial block
        assert!(signal1.is_some());
        match &signal1 {
            Some(MoveBlock::Use(blocks)) => {
                assert_eq!(blocks.len(), 1);
            }
            _ => panic!("Expected Use signal"),
        }

        // Push token 15, which completes the first block (no signals yet)
        let signal_15 = seq1.push(15);
        assert!(
            signal_15.is_none(),
            "Completing a block should not trigger signals"
        );

        // Push token 16, the first token of a new block: triggers Promote and Use signals
        let signal_16 = seq1.push(16);
        assert!(signal_16.is_some());
        let signal_16 = signal_16.unwrap();
        assert_eq!(signal_16.len(), 2);

        // First signal should be Promote for the previous block (no parent: it is block 0)
        match &signal_16[0] {
            MoveBlock::Promote(_, _, parent_hash) => {
                assert_eq!(*parent_hash, None);
            }
            _ => panic!("Expected Promote signal as first signal"),
        }

        // Second signal should be Use for the new partial block
        match &signal_16[1] {
            MoveBlock::Use(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal as second signal"),
        }

        // Verify state after pushing tokens
        assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block
        assert_eq!(seq1.len(), 17);
        assert_eq!(seq1.len() % seq1.block_size(), 1);

        // Create another sequence with block size 16 initialized with tokens [0..16),
        // then push, pop and re-push token 16
        let extended_tokens: Vec<u32> = (0..16).collect();
        let (mut seq2, _) = ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), true);
        seq2.push(16);
        seq2.pop();
        seq2.push(16);

        // Full blocks with identical content share a hash; partial blocks get distinct ids
        assert_eq!(
            seq1.unique_blocks()[0],
            seq2.unique_blocks()[0],
            "First blocks should be the same"
        );

        assert_ne!(
            seq1.unique_blocks()[1],
            seq2.unique_blocks()[1],
            "Second blocks should be different"
        );

        // Rebuild seq1's partial block. Push 17, then pop 17 and 16; the second pop lands on a
        // block boundary and drops the partial block. Then push 16 again.
        seq1.push(17);
        seq1.pop();
        seq1.pop();
        seq1.push(16);

        // Now push tokens 17..=32 to both sequences
        for token in 17..33 {
            seq1.push(token);
            seq2.push(token);
        }

        // Both sequences should now have 3 blocks:
        // 1. FullBlock for tokens 0-15
        // 2. FullBlock for tokens 16-31
        // 3. PartialBlock holding token 32
        assert_eq!(
            seq1.unique_blocks().len(),
            3,
            "seq1 should have exactly 3 blocks"
        );
        assert_eq!(
            seq2.unique_blocks().len(),
            3,
            "seq2 should have exactly 3 blocks"
        );
        assert_eq!(
            seq1.len() % seq1.block_size(),
            1,
            "seq1 should have 1 partial token"
        );
        assert_eq!(
            seq2.len() % seq2.block_size(),
            1,
            "seq2 should have 1 partial token"
        );

        // Verify that both sequences have identical full blocks (positions 0 and 1)
        assert_eq!(
            &seq1.unique_blocks()[0..2],
            &seq2.unique_blocks()[0..2],
            "First two blocks should be identical"
        );

        // Push tokens 33..=47 to seq1, completing the block of tokens 32-47
        for token in 33..48 {
            seq1.push(token);
        }

        // Push token 48, which starts a new block and therefore triggers signals
        let signal = seq1.push(48);
        let signal = signal.unwrap();

        // Check that signal[0] is Promote
        match &signal[0] {
            MoveBlock::Promote(_, _, parent_hash) => {
                // Check that the parent_hash matches unique_blocks[1], which should be a full block
                if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] {
                    assert_eq!(
                        *parent_hash,
                        Some(expected_hash),
                        "Parent hash should match unique_blocks[1]"
                    );
                } else {
                    panic!("unique_blocks[1] should be a full block");
                }
            }
            _ => panic!("Expected Promote signal as first signal"),
        }

        // Reset seq1 back to its 15 input tokens
        let free_signals = seq1.reset_with_signal();

        // 49 total tokens - 15 input tokens = 34 generated tokens
        assert_eq!(seq1.already_generated_tokens, 34);

        // Verify the reset signals include proper cleanup events
        assert!(!free_signals.is_empty());
    }

    #[test]
    fn test_active_sequence_generate_signals() {
        // Create a sequence with block size 16, max_output_tokens 5, initialized with tokens [0..14)
        let initial_tokens: Vec<u32> = (0..14).collect();
        let (mut seq, signal) = ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), true);

        // Initial signal - should have received a Use signal for the partial block
        assert!(signal.is_some());
        match signal {
            Some(MoveBlock::Use(blocks)) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal for the initial partial block"),
        }

        // Generate first two tokens (reaching 16 tokens, one complete block) - no new signals
        seq.generate();
        let signals_first = seq.generate();
        assert_eq!(signals_first.len(), 0);

        // Generate third token - it starts a new block, so it triggers Promote and Use signals
        let signals_second = seq.generate();
        assert_eq!(signals_second.len(), 2);

        // First signal should be Promote
        match &signals_second[0] {
            MoveBlock::Promote(_, _, parent_hash) => {
                assert_eq!(*parent_hash, None);
            }
            _ => panic!("Expected Promote signal as first signal after third token"),
        }

        // Second signal should be Use for new partial block
        match &signals_second[1] {
            MoveBlock::Use(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal as second signal after third token"),
        }

        // Generate fourth token - should not trigger new signals as it's adding to partial block
        let signals_third = seq.generate();
        assert_eq!(signals_third.len(), 0);

        // Generate fifth (last) token - we reach max_output_tokens, which triggers Destroy and
        // Deref signals (partial block first)
        let signals_last = seq.generate();
        assert_eq!(signals_last.len(), 2);

        // First signal should be Destroy for the partial block
        match &signals_last[0] {
            MoveBlock::Destroy(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Destroy signal for partial block after fifth token"),
        }

        // Second signal should be Deref for the full block
        match &signals_last[1] {
            MoveBlock::Deref(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::FullBlock(_)));
            }
            _ => panic!("Expected Deref signal for full block after fifth token"),
        }
    }
}