Skip to main content

oxillama_runtime/
sequence_pool.rs

//! SSM runtime bridge — polymorphic sequence-state pool.
//!
//! # Overview
//!
//! OxiLLaMa supports two categories of model architectures:
//!
//! 1. **Attention-based** (LLaMA, Qwen3, Mistral, Gemma, Phi, …): per-sequence
//!    state is the KV cache (a contiguous K/V buffer per layer).
//! 2. **SSM-based** (Mamba-2, …): per-sequence state is a set of per-layer
//!    recurrent hidden vectors; there is no KV cache.
//!
//! The [`SequencePool`] enum abstracts over both kinds.  SSM slots hold their
//! state behind the
//! [`oxillama_arch::common::sequence_state::SequenceState`] trait; the KV-based
//! variant wraps a [`KvCachePool`].  The engine picks the right pool variant at
//! load time by examining the loaded architecture.  Each variant has its own
//! entry points — `alloc_kv` / `free_kv` for KV pages, and `alloc_ssm` /
//! `release_ssm` / `ssm_slot` for SSM slots — and calling the wrong family
//! returns an error (or `None`), so callers branch on `is_kv_based` / `is_ssm`.
//!
//! ## Design notes
//!
//! - Slots are identified by a `usize` index (same as `Sequence::slot_id`).
//! - A slot is "live" between `alloc` and `release`; free slots keep their
//!   (reset) `Box<dyn SequenceState>` so it can be reused by the next `alloc`.
//! - On `release` the state is **reset** (zeroed) and returned to the free pool.
//! - Neither variant interacts with the KV cache from `kv_cache/mod.rs`; the
//!   KV-based pool manages its own separate per-slot state.
//! - The SSM state pool owns the `Box<dyn SequenceState>` objects outright;
//!   the KV-based pool keeps a `KvCachePool` from which page indices are lent.
//!
//! ## Thread safety
//!
//! `SequencePool` is **not** `Send` + `Sync` by itself; it is intended to be
//! owned by a single-threaded engine or wrapped in a `Mutex` by the caller.
32
33use crate::kv_pool::KvCachePool;
34use oxillama_arch::common::sequence_state::SequenceState;
35use thiserror::Error;
36
37// ─── Error ────────────────────────────────────────────────────────────────────
38
39/// Errors produced by pool operations.
40#[derive(Debug, Error)]
41pub enum PoolError {
42    /// The pool has no free slots; the request must wait or be rejected.
43    #[error("sequence pool exhausted: no free slots available")]
44    Exhausted,
45    /// A slot index passed to a pool operation does not identify a live slot.
46    #[error("invalid slot index {0}: slot is not live or out of range")]
47    InvalidSlot(usize),
48}
49
50/// Convenience alias.
51pub type PoolResult<T> = Result<T, PoolError>;
52
53// ─── SequenceSlot ─────────────────────────────────────────────────────────────
54
55/// A live sequence slot in the [`SsmStatePool`].
56///
57/// Each slot carries:
58/// - `state`: the arch-specific [`SequenceState`] (SSM hidden state or an
59///   attention position counter wrapped by the arch crate).
60/// - `position`: current token position in the sequence (mirrors
61///   `state.step_position()`, but accessible without a vtable call).
62/// - `request_id`: the logical request ID associated with this slot (matches
63///   [`Sequence::id`](crate::scheduler::Sequence::id)); `0` = unassigned.
64pub struct SequenceSlot {
65    /// Arch-specific sequence state (SSM hidden vectors, or attention counter).
66    pub state: Box<dyn SequenceState>,
67    /// Current token position (0-indexed).
68    pub position: usize,
69    /// Request ID bound to this slot (0 = none).
70    pub request_id: u64,
71}
72
73impl SequenceSlot {
74    /// Create a new slot with the given state and a zero request ID.
75    pub fn new(state: Box<dyn SequenceState>) -> Self {
76        Self {
77            position: state.step_position(),
78            state,
79            request_id: 0,
80        }
81    }
82
83    /// Advance the internal position counter by one (after each forward step).
84    ///
85    /// Also calls `state.advance()` so both the slot's cached `position` and
86    /// the underlying trait object stay in sync.
87    pub fn step(&mut self) {
88        self.state.advance();
89        self.position = self.state.step_position();
90    }
91
92    /// Reset the slot to position 0 and clear the state.
93    ///
94    /// The slot's `request_id` is reset to 0 as well so that stale IDs are
95    /// not accidentally read after re-allocation.
96    pub fn reset(&mut self) {
97        self.state.reset();
98        self.position = 0;
99        self.request_id = 0;
100    }
101}
102
103// ─── SsmStatePool ─────────────────────────────────────────────────────────────
104
105/// A free-list pool of [`SequenceSlot`]s for SSM-based models.
106///
107/// Pre-allocates `capacity` slots at construction time.  Slots are identified
108/// by their index into the internal `slots` vector.
109///
110/// Free slots are tracked by a `free_list: Vec<usize>`.  `alloc` pops from the
111/// list; `release` resets the slot and pushes back.
112pub struct SsmStatePool {
113    /// All slots, indexed by slot ID.  `Some` = live (allocated); `None` = never
114    /// initialised (only possible during construction before the pool is fully
115    /// initialised — in practice every index 0..capacity is always `Some`).
116    slots: Vec<Option<SequenceSlot>>,
117    /// Indices of slots currently on the free list.
118    free_list: Vec<usize>,
119    /// Total capacity (never changes after construction).
120    capacity: usize,
121}
122
123impl SsmStatePool {
124    /// Create a pool by calling `ForwardPass::allocate_sequence_state` for each slot.
125    ///
126    /// This is the preferred construction path because it delegates state
127    /// allocation to the architecture implementation, rather than hard-coding
128    /// the state type at the call site. The runtime calls this once at model
129    /// load time, after `ForwardPass` is available.
130    ///
131    /// ```ignore
132    /// let pool = SsmStatePool::from_forward_pass(fwd_pass.as_ref(), capacity, max_ctx);
133    /// ```
134    pub fn from_forward_pass(
135        forward_pass: &dyn oxillama_arch::traits::ForwardPass,
136        capacity: usize,
137        max_context_length: usize,
138    ) -> Self {
139        Self::new(capacity, |_| {
140            forward_pass.allocate_sequence_state(max_context_length)
141        })
142    }
143
144    /// Create a new pool using a factory closure to produce each slot's state.
145    ///
146    /// The closure is called once per slot with the slot index.  Use it to
147    /// initialise arch-specific state (e.g. `Mamba2SequenceState::new(...)`).
148    ///
149    /// ```ignore
150    /// let pool = SsmStatePool::new(8, |_| {
151    ///     Box::new(Mamba2SequenceState::new(24, 16, 256, 4096))
152    /// });
153    /// ```
154    pub fn new<F>(capacity: usize, mut make_state: F) -> Self
155    where
156        F: FnMut(usize) -> Box<dyn SequenceState>,
157    {
158        let mut slots = Vec::with_capacity(capacity);
159        let mut free_list = Vec::with_capacity(capacity);
160
161        for i in 0..capacity {
162            let state = make_state(i);
163            slots.push(Some(SequenceSlot::new(state)));
164            free_list.push(i);
165        }
166
167        Self {
168            slots,
169            free_list,
170            capacity,
171        }
172    }
173
174    /// Allocate a free slot and bind it to `request_id`.
175    ///
176    /// Returns `Ok(slot_idx)` on success.
177    ///
178    /// # Errors
179    ///
180    /// Returns [`PoolError::Exhausted`] when no free slots are available.
181    pub fn alloc(&mut self, request_id: u64) -> PoolResult<usize> {
182        let idx = self.free_list.pop().ok_or(PoolError::Exhausted)?;
183        if let Some(slot) = self.slots[idx].as_mut() {
184            slot.request_id = request_id;
185        }
186        Ok(idx)
187    }
188
189    /// Release slot `idx` back to the free list.
190    ///
191    /// The slot's state is reset (zeroed) and its `request_id` cleared.
192    ///
193    /// # Errors
194    ///
195    /// Returns [`PoolError::InvalidSlot`] if `idx` is out of range or already free.
196    pub fn release(&mut self, idx: usize) -> PoolResult<()> {
197        if idx >= self.slots.len() {
198            return Err(PoolError::InvalidSlot(idx));
199        }
200        // Check that the slot is not already on the free list.
201        if self.free_list.contains(&idx) {
202            return Err(PoolError::InvalidSlot(idx));
203        }
204        if let Some(slot) = self.slots[idx].as_mut() {
205            slot.reset();
206        }
207        self.free_list.push(idx);
208        Ok(())
209    }
210
211    /// Get a shared reference to slot `idx`.
212    ///
213    /// Returns `None` if the slot has never been initialised (should not happen
214    /// for a correctly-constructed pool) or the index is out of range.
215    pub fn slot(&self, idx: usize) -> Option<&SequenceSlot> {
216        self.slots.get(idx)?.as_ref()
217    }
218
219    /// Get a mutable reference to slot `idx`.
220    ///
221    /// Returns `None` if the slot is uninitialised or out of range.
222    pub fn slot_mut(&mut self, idx: usize) -> Option<&mut SequenceSlot> {
223        self.slots.get_mut(idx)?.as_mut()
224    }
225
226    /// Total pool capacity (never changes).
227    pub fn capacity(&self) -> usize {
228        self.capacity
229    }
230
231    /// Number of currently free (unallocated) slots.
232    pub fn free_count(&self) -> usize {
233        self.free_list.len()
234    }
235
236    /// Number of currently allocated (live) slots.
237    pub fn used_count(&self) -> usize {
238        self.capacity.saturating_sub(self.free_list.len())
239    }
240}
241
242// ─── SequencePool ─────────────────────────────────────────────────────────────
243
244/// Dispatch-enum over the two pool backends.
245///
246/// At model-load time the engine inspects the loaded architecture and
247/// constructs either a `KvBased` pool (for any transformer) or an `Ssm` pool
248/// (for Mamba-2 and similar).  Both variants expose the same interface through
249/// [`SequencePool`]'s methods.
250///
251/// # KV-based pooling
252///
253/// The KV-based variant stores state in a [`KvCachePool`] of page-sized slabs.
254/// Slots are identified by page indices returned by `KvCachePool::alloc`.
255///
256/// # SSM pooling
257///
258/// The SSM variant stores the full per-layer recurrent state in an
259/// [`SsmStatePool`].  The `alloc_ssm` / `release_ssm` helpers delegate to it.
260pub enum SequencePool {
261    /// Attention-transformer pool (KV cache pages).
262    KvBased(KvCachePool),
263    /// SSM pool (per-layer recurrent hidden states).
264    Ssm(SsmStatePool),
265}
266
267impl SequencePool {
268    /// Allocate a slot from the KV-based pool.
269    ///
270    /// Returns the page index on success.
271    ///
272    /// # Errors
273    ///
274    /// `PoolError::Exhausted` if the pool is full.
275    /// `PoolError::InvalidSlot(usize::MAX)` if called on an `Ssm` variant.
276    pub fn alloc_kv(&mut self) -> PoolResult<usize> {
277        match self {
278            SequencePool::KvBased(pool) => pool.alloc().ok_or(PoolError::Exhausted),
279            SequencePool::Ssm(_) => Err(PoolError::InvalidSlot(usize::MAX)),
280        }
281    }
282
283    /// Free a page in the KV-based pool.
284    ///
285    /// # Errors
286    ///
287    /// `PoolError::InvalidSlot` if called on an `Ssm` variant.
288    pub fn free_kv(&mut self, page_idx: usize) -> PoolResult<()> {
289        match self {
290            SequencePool::KvBased(pool) => {
291                pool.free(page_idx);
292                Ok(())
293            }
294            SequencePool::Ssm(_) => Err(PoolError::InvalidSlot(page_idx)),
295        }
296    }
297
298    /// Allocate an SSM slot bound to `request_id`.
299    ///
300    /// # Errors
301    ///
302    /// `PoolError::Exhausted` if the pool is full.
303    /// `PoolError::InvalidSlot(usize::MAX)` if called on a `KvBased` variant.
304    pub fn alloc_ssm(&mut self, request_id: u64) -> PoolResult<usize> {
305        match self {
306            SequencePool::Ssm(pool) => pool.alloc(request_id),
307            SequencePool::KvBased(_) => Err(PoolError::InvalidSlot(usize::MAX)),
308        }
309    }
310
311    /// Release an SSM slot by index.
312    ///
313    /// # Errors
314    ///
315    /// `PoolError::InvalidSlot` if `idx` is invalid or already free, or if
316    /// called on a `KvBased` variant.
317    pub fn release_ssm(&mut self, idx: usize) -> PoolResult<()> {
318        match self {
319            SequencePool::Ssm(pool) => pool.release(idx),
320            SequencePool::KvBased(_) => Err(PoolError::InvalidSlot(idx)),
321        }
322    }
323
324    /// Get an immutable reference to an SSM slot.
325    ///
326    /// Returns `None` for KV-based pools or out-of-range indices.
327    pub fn ssm_slot(&self, idx: usize) -> Option<&SequenceSlot> {
328        match self {
329            SequencePool::Ssm(pool) => pool.slot(idx),
330            SequencePool::KvBased(_) => None,
331        }
332    }
333
334    /// Get a mutable reference to an SSM slot.
335    ///
336    /// Returns `None` for KV-based pools or out-of-range indices.
337    pub fn ssm_slot_mut(&mut self, idx: usize) -> Option<&mut SequenceSlot> {
338        match self {
339            SequencePool::Ssm(pool) => pool.slot_mut(idx),
340            SequencePool::KvBased(_) => None,
341        }
342    }
343
344    /// Returns `true` if this pool uses the KV-cache backend.
345    pub fn is_kv_based(&self) -> bool {
346        matches!(self, SequencePool::KvBased(_))
347    }
348
349    /// Returns `true` if this pool uses the SSM backend.
350    pub fn is_ssm(&self) -> bool {
351        matches!(self, SequencePool::Ssm(_))
352    }
353}
354
355// ─── Tests ────────────────────────────────────────────────────────────────────
356
#[cfg(test)]
mod tests {
    use super::*;
    use oxillama_arch::common::sequence_state::{AttentionSequenceState, Mamba2SequenceState};

    // ── SequenceSlot ──────────────────────────────────────────────────────────

    /// `step()` must move the cached position and the wrapped state together.
    #[test]
    fn sequence_slot_position_advances() {
        let mut slot = SequenceSlot::new(Box::new(AttentionSequenceState::new(512)));

        assert_eq!(slot.position, 0, "initial position must be 0");
        assert_eq!(slot.state.step_position(), 0);

        slot.step();
        assert_eq!(slot.position, 1, "position after one step must be 1");
        assert_eq!(slot.state.step_position(), 1);

        for _ in 0..2 {
            slot.step();
        }
        assert_eq!(slot.position, 3);
        assert_eq!(slot.state.step_position(), 3);
    }

    /// `reset()` must rewind the position and drop the bound request ID.
    #[test]
    fn sequence_slot_reset_clears_position() {
        let mut slot = SequenceSlot::new(Box::new(AttentionSequenceState::new(64)));
        slot.request_id = 42;
        for _ in 0..2 {
            slot.step();
        }
        assert_eq!(slot.position, 2);

        slot.reset();
        assert_eq!(slot.position, 0, "position must be 0 after reset");
        assert_eq!(slot.state.step_position(), 0);
        assert_eq!(slot.request_id, 0, "request_id must be cleared by reset");
    }

    // ── SsmStatePool ──────────────────────────────────────────────────────────

    /// A released index must be the next one handed out, and the counters must
    /// follow every alloc/release.
    #[test]
    fn sequence_pool_allocate_release() {
        let mut pool = SsmStatePool::new(4, |_slot_idx| {
            Box::new(AttentionSequenceState::new(256)) as Box<dyn SequenceState>
        });

        assert_eq!(pool.capacity(), 4);
        assert_eq!(pool.free_count(), 4);
        assert_eq!(pool.used_count(), 0);

        let first = pool.alloc(1).expect("first alloc must succeed");
        let second = pool.alloc(2).expect("second alloc must succeed");
        assert_ne!(first, second);
        assert_eq!(pool.used_count(), 2);
        assert_eq!(pool.free_count(), 2);

        // Free the first slot; the following alloc must hand it out again.
        pool.release(first).expect("release must succeed");
        assert_eq!(pool.free_count(), 3);

        let recycled = pool.alloc(3).expect("alloc after release");
        assert_eq!(recycled, first, "freed slot must be reused");
        assert_eq!(pool.used_count(), 2);
    }

    /// Allocating past capacity must fail with `PoolError::Exhausted`.
    #[test]
    fn ssm_pool_exhaustion_returns_error() {
        let mut pool = SsmStatePool::new(2, |_slot_idx| {
            Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
        });

        pool.alloc(10).expect("first");
        pool.alloc(11).expect("second");

        let err = pool.alloc(12);
        assert!(
            matches!(err, Err(PoolError::Exhausted)),
            "exhausted pool must return Exhausted, got {err:?}"
        );
    }

    /// A second release of the same index must fail with `InvalidSlot`.
    #[test]
    fn ssm_pool_double_release_errors() {
        let mut pool = SsmStatePool::new(2, |_slot_idx| {
            Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
        });

        let idx = pool.alloc(1).expect("alloc");
        pool.release(idx).expect("first release");

        let err = pool.release(idx);
        assert!(
            matches!(err, Err(PoolError::InvalidSlot(_))),
            "double-release must return InvalidSlot, got {err:?}"
        );
    }

    /// Releasing a slot must rewind its position before it is handed out again.
    #[test]
    fn ssm_pool_release_resets_state() {
        let (n_layers, d_state, d_inner) = (3, 4, 8);
        let mut pool = SsmStatePool::new(2, |_slot_idx| {
            Box::new(Mamba2SequenceState::new(n_layers, d_state, d_inner, 256))
                as Box<dyn SequenceState>
        });

        let idx = pool.alloc(99).expect("alloc");

        // Move the slot forward so there is something to reset.
        if let Some(slot) = pool.slot_mut(idx) {
            for _ in 0..2 {
                slot.step();
            }
            assert_eq!(slot.position, 2, "position must be 2 before release");
        }

        pool.release(idx).expect("release");

        // The same index comes back, rewound and bound to the new request.
        let idx2 = pool.alloc(100).expect("re-alloc");
        assert_eq!(idx2, idx, "must reuse the released slot");

        let slot = pool.slot(idx2).expect("slot must exist");
        assert_eq!(
            slot.position, 0,
            "position must be 0 after re-alloc following release"
        );
        assert_eq!(
            slot.state.step_position(),
            0,
            "state.step_position() must be 0 after release"
        );
        assert_eq!(slot.request_id, 100, "request_id must be updated on alloc");
    }

    // ── SequencePool enum ─────────────────────────────────────────────────────

    /// `alloc_kv` followed by `free_kv` must succeed on a `KvBased` pool.
    #[test]
    fn sequence_pool_kv_based_alloc_free() {
        let mut pool = SequencePool::KvBased(KvCachePool::new(16, 4));

        assert!(pool.is_kv_based());
        assert!(!pool.is_ssm());

        let idx = pool.alloc_kv().expect("alloc_kv must succeed");
        // The returned page index is expected to be below 4.
        assert!(idx < 4, "page index must be in range 0..4, got {idx}");

        pool.free_kv(idx).expect("free_kv must succeed");
    }

    /// SSM allocation on a `KvBased` pool must be rejected.
    #[test]
    fn sequence_pool_kv_rejects_ssm_ops() {
        let mut pool = SequencePool::KvBased(KvCachePool::new(16, 4));

        let err = pool.alloc_ssm(1);
        assert!(
            matches!(err, Err(PoolError::InvalidSlot(_))),
            "alloc_ssm on KvBased must fail, got {err:?}"
        );
    }

    /// `alloc_ssm` followed by `release_ssm` must succeed on an `Ssm` pool.
    #[test]
    fn sequence_pool_ssm_alloc_release() {
        let mut pool = SequencePool::Ssm(SsmStatePool::new(4, |_slot_idx| {
            Box::new(AttentionSequenceState::new(256)) as Box<dyn SequenceState>
        }));

        assert!(pool.is_ssm());
        assert!(!pool.is_kv_based());

        let idx = pool.alloc_ssm(7).expect("alloc_ssm");
        let bound = pool.ssm_slot(idx).expect("slot must exist after alloc");
        assert_eq!(bound.request_id, 7);

        pool.release_ssm(idx).expect("release_ssm");
        // A released slot stays reachable through `ssm_slot`, but its
        // request ID has been wiped by the reset.
        let released = pool.ssm_slot(idx).expect("slot still accessible");
        assert_eq!(released.request_id, 0, "request_id must be 0 after release");
    }

    /// KV allocation on an `Ssm` pool must be rejected.
    #[test]
    fn sequence_pool_ssm_rejects_kv_ops() {
        let mut pool = SequencePool::Ssm(SsmStatePool::new(2, |_slot_idx| {
            Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
        }));

        let err = pool.alloc_kv();
        assert!(
            matches!(err, Err(PoolError::InvalidSlot(_))),
            "alloc_kv on Ssm must fail, got {err:?}"
        );
    }

    /// Stepping one request's slot must not move another request's slot.
    #[test]
    fn mixed_pool_isolation() {
        let (n_layers, d_state, d_inner) = (2, 2, 4);
        let mut pool = SequencePool::Ssm(SsmStatePool::new(4, |_slot_idx| {
            Box::new(Mamba2SequenceState::new(n_layers, d_state, d_inner, 128))
                as Box<dyn SequenceState>
        }));

        let idx_a = pool.alloc_ssm(1).expect("alloc A");
        let idx_b = pool.alloc_ssm(2).expect("alloc B");
        assert_ne!(idx_a, idx_b, "two requests must occupy different slots");

        // Step only slot A.
        if let Some(slot_a) = pool.ssm_slot_mut(idx_a) {
            for _ in 0..2 {
                slot_a.step();
            }
        }

        // Slot B has not been touched and must still be at the start.
        let slot_b = pool.ssm_slot(idx_b).expect("slot B must exist");
        assert_eq!(
            slot_b.position, 0,
            "slot B position must not be affected by slot A's steps"
        );
    }

    /// Releasing an index beyond the pool's capacity must report that index.
    #[test]
    fn ssm_pool_out_of_range_slot_errors() {
        let mut pool = SsmStatePool::new(2, |_slot_idx| {
            Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
        });
        pool.alloc(1).expect("alloc to make slot 0 live");

        // 99 is far beyond the two slots this pool owns.
        let err = pool.release(99);
        assert!(
            matches!(err, Err(PoolError::InvalidSlot(99))),
            "out-of-range release must return InvalidSlot(99), got {err:?}"
        );
    }

    /// `slot_reset_on_eos_for_ssm`: releasing a slot (simulating EOS) must
    /// leave it rewound to position 0 for the next allocation.
    #[test]
    fn slot_reset_on_eos_for_ssm() {
        let (n_layers, d_state, d_inner) = (2, 4, 8);
        let mut pool = SequencePool::Ssm(SsmStatePool::new(2, |_slot_idx| {
            Box::new(Mamba2SequenceState::new(n_layers, d_state, d_inner, 256))
                as Box<dyn SequenceState>
        }));

        // Run the sequence for ten steps, then release it as if EOS was hit.
        let idx = pool.alloc_ssm(5).expect("alloc");
        if let Some(slot) = pool.ssm_slot_mut(idx) {
            (0..10).for_each(|_| slot.step());
            assert_eq!(slot.position, 10, "must have 10 steps before release");
        }
        pool.release_ssm(idx).expect("release on EOS");

        // The next allocation must see a rewound slot.
        let idx2 = pool.alloc_ssm(6).expect("re-alloc");
        let slot = pool.ssm_slot(idx2).expect("slot must exist");
        assert_eq!(
            slot.position, 0,
            "position must be 0 on fresh re-alloc (EOS reset)"
        );
        assert_eq!(slot.state.step_position(), 0);
    }
}