dynamo-mocker 1.1.0

Mock LLM scheduler and KV manager for testing
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! # KV Manager (vLLM Backend)
//! A synchronous implementation of a block manager that handles MoveBlock signals for caching KV blocks.
//!
//! Uses [`HashCache`] for O(1) block lookups with active/inactive pool management.
//!
//! ## Block Operations
//! The KV manager processes four types of MoveBlock signals:
//!
//! ### Use
//! - Checks if block exists in active pool → increment reference count
//! - If in inactive pool → move to active pool
//! - If neither → try evicting from inactive pool to make room
//! - If inactive pool is empty → stop allocating; `process` reports a partial allocation so the scheduler can pre-empt a running request
//!
//! ### Destroy
//! - Removes the block from the active pool
//!
//! ### Deref
//! - Decrements reference count of a block in active pool
//! - If count reaches zero → move block to inactive pool
//!
//! ### Promote
//! - Converts a partial block (uuid) into a full block (global block hash)
//!
//! ## Preemption
//! If a Use operation fails (typically due to insufficient space), `process` returns the number of
//! blocks allocated before the failure (fewer than requested), which the scheduler uses as its
//! signal to preempt. Initial KV block allocations for new requests should not fail, thanks to the
//! capacity checking performed during scheduling.
//!
//! ## NOTE
//! For simplicity (or non-simplicity), reference counting is tracked manually instead of using
//! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror
//! implementation of the main block manager.
use crate::cache::HashCache;
use crate::common::kv_cache_trace;
use crate::common::protocols::{KvEventPublishers, MoveBlock, PrefillCost};
use crate::common::sequence::ActiveSequence;
use dynamo_kv_router::protocols::{
    ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
    KvCacheStoredBlockData, LocalBlockHash,
};
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash};
use rustc_hash::FxHashMap;

/// Mock vLLM-style KV block manager.
///
/// Tracks active (referenced) and inactive (reusable) blocks in a [`HashCache`] and, when
/// publishers are configured, emits KV cache store/remove events tagged with `dp_rank`.
pub struct KvManager {
    /// Number of tokens per KV block.
    block_size: usize,
    /// Active and inactive block pools.
    cache: HashCache,
    /// Data-parallel rank attached to every published event.
    dp_rank: u32,
    /// Destinations for KV cache events; may be empty.
    kv_event_publishers: KvEventPublishers,
    /// Id handed to the next published event (starts at 0, increments by 1).
    next_event_id: u64,
}

impl KvManager {
    pub fn new(max_capacity: usize, block_size: usize) -> Self {
        Self::new_with_event_sink(max_capacity, block_size, KvEventPublishers::default(), 0)
    }

    pub fn new_with_event_sink(
        max_capacity: usize,
        block_size: usize,
        kv_event_publishers: KvEventPublishers,
        dp_rank: u32,
    ) -> Self {
        debug_assert!(max_capacity > 0, "max_capacity must be > 0");
        if !kv_event_publishers.is_empty() {
            tracing::info!(
                "KvManager initialized with event sink for DP rank {dp_rank} with block_size {block_size}"
            );
        }

        KvManager {
            cache: HashCache::new(max_capacity),
            block_size,
            kv_event_publishers,
            dp_rank,
            next_event_id: 0,
        }
    }

    /// Converts stored/removed blocks into KvCacheEventData and publishes if sink is available.
    fn publish_kv_event(
        &mut self,
        full_blocks: Vec<SequenceHash>,
        local_hashes: &[BlockHash],
        parent_hash: Option<u64>,
        is_store: bool,
        token_ids: Option<Vec<Vec<u32>>>,
    ) {
        if full_blocks.is_empty() {
            return;
        }

        kv_cache_trace::log_vllm_trace(
            if is_store { "allocation" } else { "eviction" },
            self.dp_rank,
            self.block_size,
            self.cache.num_active(),
            self.cache.num_inactive(),
            self.cache.max_capacity(),
        );

        if self.kv_event_publishers.is_empty() {
            return;
        }

        let event_data = if is_store {
            let num_blocks = full_blocks.len();
            let local_hashes_slice = &local_hashes[local_hashes
                .len()
                .checked_sub(num_blocks)
                .expect("local hashes fewer than stored blocks")..];

            KvCacheEventData::Stored(KvCacheStoreData {
                parent_hash: parent_hash.map(ExternalSequenceBlockHash),
                blocks: full_blocks
                    .into_iter()
                    .zip(local_hashes_slice.iter())
                    .map(|(global_hash, local_hash)| KvCacheStoredBlockData {
                        block_hash: ExternalSequenceBlockHash(global_hash),
                        tokens_hash: LocalBlockHash(*local_hash),
                        mm_extra_info: None,
                    })
                    .collect(),
            })
        } else {
            KvCacheEventData::Removed(KvCacheRemoveData {
                block_hashes: full_blocks
                    .into_iter()
                    .map(ExternalSequenceBlockHash)
                    .collect(),
            })
        };
        // Use incremental event ID starting from 0 and incrementing by 1 for each event.
        let event_id = self.next_event_id;
        self.next_event_id += 1;

        let event = KvCacheEvent {
            event_id,
            data: event_data,
            dp_rank: self.dp_rank,
        };

        if let Err(e) = self
            .kv_event_publishers
            .publish(event, token_ids.as_deref())
        {
            tracing::warn!("Failed to publish KV event: {e}");
        }
    }

    /// Process a MoveBlock instruction synchronously.
    ///
    /// For `MoveBlock::Use`, returns the number of blocks successfully allocated.
    /// On partial failure, blocks 0..N are committed but block N+1 could not be
    /// allocated. Callers should use the return value to track partial progress.
    ///
    /// For other variants, returns the total block count (they always succeed or panic).
    pub fn process(&mut self, event: &MoveBlock) -> usize {
        match event {
            MoveBlock::Use(hashes, local_hashes, token_ids, parent) => {
                let mut blocks_stored = Vec::<u64>::new();
                let mut stored_token_ids: Option<Vec<Vec<u32>>> =
                    token_ids.as_ref().map(|_| Vec::new());

                let mut parent_block: Option<&UniqueBlock> = parent.as_ref();
                let mut allocated = 0;
                for (i, hash) in hashes.iter().enumerate() {
                    // First check if it already exists in active blocks
                    if self.cache.contains_active(hash) {
                        // Block already active, just increment reference count
                        self.cache.increment_ref(hash);
                        parent_block = Some(hash);
                        allocated += 1;
                        continue;
                    }

                    // Then check if it exists in inactive and move it to active if found
                    if self.cache.reactivate(hash) {
                        parent_block = Some(hash);
                        allocated += 1;
                        continue;
                    }

                    // If at max capacity, evict the oldest entry from inactive blocks
                    if self.cache.is_at_capacity() {
                        let Some(evicted) = self.cache.evict_inactive() else {
                            break;
                        };
                        tracing::trace!(
                            "Evicting block from inactive pool: {evicted:?}, dp_rank={}",
                            self.dp_rank
                        );
                        if let UniqueBlock::FullBlock(evicted_full_block) = evicted {
                            self.publish_kv_event(vec![evicted_full_block], &[], None, false, None);
                        }
                    }

                    // Now insert the new block in active blocks with reference count 1
                    self.cache.insert_active(hash.clone(), 1);
                    allocated += 1;
                    // Track blocks for trace/event
                    if let UniqueBlock::FullBlock(stored_full_block) = hash {
                        blocks_stored.push(*stored_full_block);
                        if let Some(ref mut stids) = stored_token_ids {
                            stids.push(token_ids.as_ref().unwrap()[i].clone());
                        }
                    }
                }

                let parent_hash = match parent_block {
                    None => None,
                    Some(UniqueBlock::FullBlock(block)) => Some(*block),
                    Some(UniqueBlock::PartialBlock(_)) => panic!("parent block cannot be partial"),
                };
                self.publish_kv_event(
                    blocks_stored,
                    local_hashes,
                    parent_hash,
                    true,
                    stored_token_ids,
                );
                return allocated;
            }

            MoveBlock::Destroy(hashes) => {
                let mut blocks_destroyed = Vec::<u64>::new();
                // Process blocks in order (already reversed by caller if needed)
                for hash in hashes.iter() {
                    self.cache.remove_active(hash).unwrap();
                    // Track blocks for batch sending
                    if let UniqueBlock::FullBlock(destroyed_full_block) = hash {
                        blocks_destroyed.push(*destroyed_full_block);
                    }
                }

                self.publish_kv_event(blocks_destroyed, &[], None, false, None);
            }

            MoveBlock::Deref(hashes) => {
                // Process blocks in order (already reversed by caller if needed)
                for hash in hashes.iter() {
                    // Decrement reference count and check if we need to move to inactive
                    if let Some(ref_count) = self.cache.get_active_ref_count(hash) {
                        if ref_count == 0 {
                            panic!("Negative reference count would be encountered after Deref.");
                        }

                        // If reference count reaches zero, remove from active and move to inactive
                        if ref_count == 1 {
                            self.cache.deactivate(hash);
                        } else {
                            self.cache.decrement_ref(hash);
                        }
                    }
                }
            }

            MoveBlock::Promote(uuid, hash, parent_hash, local_hash, promote_token_ids) => {
                let uuid_block = UniqueBlock::PartialBlock(*uuid);
                let hash_block = UniqueBlock::FullBlock(*hash);

                assert_eq!(
                    self.cache.remove_active(&uuid_block),
                    Some(1),
                    "uuid_block {uuid_block:?} should exist and be unique with ref_count=1"
                );

                let hash_ref_count = self.cache.get_active_ref_count(&hash_block);
                // Block is new if it's not in active and not in inactive
                let is_new = if hash_ref_count.is_some() {
                    false
                } else {
                    !self.cache.remove_inactive(&hash_block)
                };

                self.cache
                    .insert_active(hash_block, hash_ref_count.unwrap_or(0) + 1);

                if is_new {
                    self.publish_kv_event(
                        vec![*hash],
                        &[*local_hash],
                        *parent_hash,
                        true,
                        promote_token_ids.as_ref().map(|t| vec![t.clone()]),
                    );
                }
            }
        }

        1
    }

    /// Get the count of blocks that aren't in active or inactive pools
    pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize {
        blocks
            .iter()
            .filter(|&block| !self.cache.contains(block))
            .count()
    }

    /// Get the current capacity (active blocks + inactive blocks)
    pub fn current_capacity(&self) -> usize {
        self.cache.current_capacity()
    }

    /// Get the current capacity as a percentage of the maximum capacity
    pub fn current_capacity_perc(&self) -> f64 {
        self.cache.current_capacity() as f64 / self.cache.max_capacity() as f64
    }

    /// Get the number of active blocks
    pub fn num_active_blocks(&self) -> usize {
        self.cache.num_active()
    }

    /// Get the percentage of active blocks relative to maximum capacity
    pub fn get_active_perc(&self) -> f64 {
        self.cache.num_active() as f64 / self.cache.max_capacity() as f64
    }

    /// Get the number of inactive blocks
    pub fn num_inactive_blocks(&self) -> usize {
        self.cache.num_inactive()
    }

    /// Get the keys of inactive blocks
    pub fn get_inactive_blocks(&self) -> Vec<&UniqueBlock> {
        self.cache.inactive_keys().collect()
    }

    /// Get the keys of active blocks
    pub fn get_active_blocks(&self) -> Vec<&UniqueBlock> {
        self.cache.active_keys().collect()
    }

    pub fn max_capacity(&self) -> usize {
        self.cache.max_capacity()
    }

    pub fn block_size(&self) -> usize {
        self.block_size
    }

    pub fn dp_rank(&self) -> u32 {
        self.dp_rank
    }

    /// Direct access to active blocks map (for tests).
    pub fn active_blocks(&self) -> &FxHashMap<UniqueBlock, usize> {
        self.cache.active_blocks()
    }

    /// Check if a sequence can be scheduled and calculate cost if possible
    pub fn get_prefill_cost(&self, sequence: &ActiveSequence) -> PrefillCost {
        let seq_blocks = sequence.unique_blocks();

        // Find the longest prefix that exists in cache
        // We must stop at the first cache miss since KV states are computed sequentially
        let mut overlap_blocks = 0;
        for block in seq_blocks {
            if !self.cache.contains(block) {
                // First cache miss - can't use anything after this point
                break;
            }
            overlap_blocks += 1;
        }

        let new_blocks = seq_blocks.len() - overlap_blocks;
        // Clamp cached_tokens to handle partial blocks (last block may have < block_size tokens)
        let cached_tokens = (overlap_blocks * self.block_size).min(sequence.num_input_tokens());
        let new_tokens = sequence.num_input_tokens() - cached_tokens;

        PrefillCost {
            new_blocks,
            new_tokens,
            cached_tokens,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use crate::common::protocols::KvCacheEventSink;

    #[test]
    fn test_failure_on_max_capacity() {
        // Create a KvManager with 10 blocks capacity
        let mut manager = KvManager::new(10, 16);

        // Helper function to use multiple blocks that returns the count allocated
        fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> usize {
            let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect();
            let hashes: Vec<_> = ids.into_iter().collect();
            manager.process(&MoveBlock::Use(blocks, hashes, None, None))
        }

        // First use 10 blocks (0 to 9) in a batch
        let response = use_blocks(&mut manager, (0..10).collect());
        assert_eq!(response, 10, "Expected all 10 blocks allocated");

        // Verify we are at capacity
        assert_eq!(manager.current_capacity(), 10);

        // The 11th block should return 0, not panic
        let response = use_blocks(&mut manager, vec![10]);
        assert_eq!(
            response, 0,
            "Expected 0 blocks allocated when exceeding max capacity"
        );
    }

    #[test]
    fn test_block_lifecycle_stringent() {
        // Create a KvManager with 10 blocks capacity (no KV event publisher for tests)
        let mut manager = KvManager::new(10, 16);

        // Helper function to use multiple blocks
        fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) {
            let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect();
            let hashes: Vec<_> = ids.into_iter().collect();
            manager.process(&MoveBlock::Use(blocks, hashes, None, None));
        }

        // Helper function to destroy multiple blocks
        fn destroy_blocks(manager: &mut KvManager, ids: Vec<u64>) {
            let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
            manager.process(&MoveBlock::Destroy(blocks));
        }

        // Helper function to deref multiple blocks
        fn deref_blocks(manager: &mut KvManager, ids: Vec<u64>) {
            let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
            manager.process(&MoveBlock::Deref(blocks));
        }

        // Helper function to check if active blocks contain expected blocks with expected ref counts
        fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) {
            assert_eq!(
                manager.active_blocks().len(),
                expected_blocks.len(),
                "Active blocks count doesn't match expected"
            );

            for &(id, ref_count) in expected_blocks {
                let block = UniqueBlock::FullBlock(id);
                assert!(
                    manager.active_blocks().contains_key(&block),
                    "Block {id} not found in active blocks",
                );
                assert_eq!(
                    manager.active_blocks().get(&block),
                    Some(&ref_count),
                    "Block {id} has wrong reference count",
                );
            }
        }

        // Helper function to check if inactive blocks contain expected blocks
        fn assert_inactive_blocks(
            manager: &KvManager,
            expected_size: usize,
            expected_blocks: &[u64],
        ) {
            let inactive_blocks = manager.get_inactive_blocks();
            let inactive_blocks_count = manager.num_inactive_blocks();

            assert_eq!(
                inactive_blocks_count, expected_size,
                "Inactive blocks count doesn't match expected"
            );

            for &id in expected_blocks {
                let block = UniqueBlock::FullBlock(id);
                assert!(
                    inactive_blocks.iter().any(|&b| *b == block),
                    "Block {id} not found in inactive blocks",
                );
            }
        }

        // First use blocks 0, 1, 2, 3, 4 in a batch
        use_blocks(&mut manager, (0..5).collect());

        // Then use blocks 0, 1, 5, 6 in a batch
        use_blocks(&mut manager, vec![0, 1, 5, 6]);

        // Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
        assert_active_blocks(
            &manager,
            &[(0, 2), (1, 2), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)],
        );

        // Now destroy block 4
        destroy_blocks(&mut manager, vec![4]);

        // And deref blocks 0, 1, 2, 3 in this order as a batch
        deref_blocks(&mut manager, vec![0, 1, 2, 3]);

        // Blocks 2 and 3 dropped to zero references and became inactive;
        // blocks 0 and 1 stay active with one reference each
        assert_inactive_blocks(&manager, 2, &[3, 2]);
        assert_active_blocks(&manager, &[(0, 1), (1, 1), (5, 1), (6, 1)]);

        // Now destroy block 6
        destroy_blocks(&mut manager, vec![6]);

        // And deref blocks 0, 1, 5 in this order as a batch
        deref_blocks(&mut manager, vec![0, 1, 5]);

        // Check that the inactive_blocks is size 5, and contains 0, 1, 2, 3, 5
        assert_inactive_blocks(&manager, 5, &[0, 1, 2, 3, 5]);
        assert_active_blocks(&manager, &[]);

        // Now use 0, 1, 2, 7, 8, 9 as a batch
        use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]);

        // Check that the inactive_blocks is size 2, and contains 3 and 5
        assert_inactive_blocks(&manager, 2, &[3, 5]);
        assert_active_blocks(&manager, &[(0, 1), (1, 1), (2, 1), (7, 1), (8, 1), (9, 1)]);

        // Test the new_blocks method - only block 4 should be new out of [0,1,2,3,4]
        let blocks_to_check: Vec<UniqueBlock> = vec![0, 1, 2, 3, 4]
            .into_iter()
            .map(UniqueBlock::FullBlock)
            .collect();
        assert_eq!(manager.probe_new_blocks(&blocks_to_check), 1);

        // Now use blocks 10, 11, 12 as a batch
        use_blocks(&mut manager, vec![10, 11, 12]);

        // Check that the inactive_blocks is size 1 and contains only 5
        assert_inactive_blocks(&manager, 1, &[5]);

        // At capacity, using block 13 must evict the last inactive block (5)
        use_blocks(&mut manager, vec![13]);
        assert_inactive_blocks(&manager, 0, &[]);
        assert_eq!(manager.num_active_blocks(), 10);
    }

    #[test]
    fn test_chunked_prefill_parent_hash() {
        use std::sync::Mutex;

        use crate::common::sequence::ActiveSequence;

        #[derive(Default)]
        struct CapturingSink {
            events: Mutex<Vec<KvCacheEvent>>,
        }

        impl KvCacheEventSink for CapturingSink {
            fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
                self.events.lock().unwrap().push(event);
                Ok(())
            }
        }

        let block_size = 64;
        let tokens: Vec<u32> = (0..512).collect(); // 8 blocks
        let mut seq = ActiveSequence::new(tokens, 100, Some(block_size), true, false);

        let sink = Arc::new(CapturingSink::default());
        let mut manager = KvManager::new_with_event_sink(
            256,
            block_size,
            KvEventPublishers::new(Some(sink.clone() as _), None),
            0,
        );

        // Chunk 1: allocate blocks 0-3
        let signal = seq.prepare_allocation(256).unwrap();
        manager.process(&signal);
        seq.commit_allocation(256);

        // Chunk 2: allocate blocks 4-7
        let signal = seq.prepare_allocation(512).unwrap();
        manager.process(&signal);
        seq.commit_allocation(512);

        let events = sink.events.lock().unwrap();
        assert_eq!(events.len(), 2, "expected two store events");

        // First event: parent_hash should be None (starts from root)
        let KvCacheEventData::Stored(ref store1) = events[0].data else {
            panic!("expected store event");
        };
        assert!(
            store1.parent_hash.is_none(),
            "first chunk should have no parent"
        );

        // Second event: parent_hash should be the seq_hash of block 3
        // (the last block from the first chunk)
        let KvCacheEventData::Stored(ref store2) = events[1].data else {
            panic!("expected store event");
        };
        let expected_parent = seq.unique_blocks()[3].clone();
        let UniqueBlock::FullBlock(expected_hash) = expected_parent else {
            panic!("expected full block");
        };
        assert_eq!(
            store2.parent_hash,
            Some(ExternalSequenceBlockHash(expected_hash)),
            "second chunk's parent should be block 3's seq_hash"
        );
    }

    #[test]
    fn test_repreempt_after_partial_recompute_only_frees_reallocated_blocks() {
        let mut seq = ActiveSequence::new((0..6).collect(), 16, Some(4), true, false);
        let mut manager = KvManager::new(16, 4);

        let signal = seq.take_creation_signal().unwrap();
        assert_eq!(manager.process(&signal), 2);

        for _ in 0..3 {
            let signals = seq.generate();
            for signal in &signals {
                manager.process(signal);
            }
            if seq.generated_tokens() < seq.max_output_tokens() {
                seq.commit_allocation(seq.len());
            }
        }
        assert_eq!(manager.num_active_blocks(), 3);

        let first_reset = seq.reset_with_signal();
        for signal in &first_reset {
            manager.process(signal);
        }
        assert_eq!(manager.num_active_blocks(), 0);

        let prompt_only = seq.prepare_allocation(seq.num_input_tokens()).unwrap();
        assert_eq!(manager.process(&prompt_only), 2);
        seq.commit_allocation(seq.num_input_tokens());
        assert_eq!(manager.num_active_blocks(), 2);

        let second_reset = seq.reset_with_signal();
        for signal in &second_reset {
            manager.process(signal);
        }

        assert_eq!(manager.num_active_blocks(), 0);
    }
}