oxillama_runtime/sequence_pool.rs
1//! SSM runtime bridge — polymorphic sequence-state pool.
2//!
3//! # Overview
4//!
5//! OxiLLaMa supports two categories of model architectures:
6//!
7//! 1. **Attention-based** (LLaMA, Qwen3, Mistral, Gemma, Phi, …): per-sequence
8//! state is the KV cache (a contiguous K/V buffer per layer).
9//! 2. **SSM-based** (Mamba-2, …): per-sequence state is a set of per-layer
10//! recurrent hidden vectors; there is no KV cache.
11//!
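//! The [`SequencePool`] enum wraps one backend per architecture family: a
//! [`KvCachePool`] for attention models, or an [`SsmStatePool`] holding
//! [`oxillama_arch::common::sequence_state::SequenceState`] objects for SSM
//! models. The engine picks the variant at load time by examining the loaded
//! architecture. The methods are backend-specific (`alloc_kv` / `free_kv` for
//! KV pools, `alloc_ssm` / `release_ssm` / `ssm_slot` for SSM pools); callers
//! branch on `is_kv_based()` / `is_ssm()`, and calling a method on the wrong
//! backend returns an error (or `None`) rather than panicking.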
17//!
18//! ## Design notes
19//!
20//! - Slots are identified by a `usize` index (same as `Sequence::slot_id`).
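//! - In the SSM pool every slot always holds a `Box<dyn SequenceState>`; a
//!   slot is "live" while it is allocated, i.e. not on the pool's free list.
//! - On SSM `release` the state is **reset** (zeroed) and the slot is returned
//!   to the free list.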
23//! - Neither variant interacts with the KV cache from `kv_cache/mod.rs`; the
24//! KV-based pool manages its own separate per-slot state.
25//! - The SSM state pool owns the `Box<dyn SequenceState>` objects outright;
26//! the KV-based pool keeps a `KvCachePool` from which page indices are lent.
27//!
28//! ## Thread safety
29//!
30//! `SequencePool` is **not** `Send` + `Sync` by itself; it is intended to be
31//! owned by a single-threaded engine or wrapped in a `Mutex` by the caller.
32
33use crate::kv_pool::KvCachePool;
34use oxillama_arch::common::sequence_state::SequenceState;
35use thiserror::Error;
36
37// ─── Error ────────────────────────────────────────────────────────────────────
38
39/// Errors produced by pool operations.
40#[derive(Debug, Error)]
41pub enum PoolError {
42 /// The pool has no free slots; the request must wait or be rejected.
43 #[error("sequence pool exhausted: no free slots available")]
44 Exhausted,
45 /// A slot index passed to a pool operation does not identify a live slot.
46 #[error("invalid slot index {0}: slot is not live or out of range")]
47 InvalidSlot(usize),
48}
49
50/// Convenience alias.
51pub type PoolResult<T> = Result<T, PoolError>;
52
53// ─── SequenceSlot ─────────────────────────────────────────────────────────────
54
55/// A live sequence slot in the [`SsmStatePool`].
56///
57/// Each slot carries:
58/// - `state`: the arch-specific [`SequenceState`] (SSM hidden state or an
59/// attention position counter wrapped by the arch crate).
60/// - `position`: current token position in the sequence (mirrors
61/// `state.step_position()`, but accessible without a vtable call).
62/// - `request_id`: the logical request ID associated with this slot (matches
63/// [`Sequence::id`](crate::scheduler::Sequence::id)); `0` = unassigned.
64pub struct SequenceSlot {
65 /// Arch-specific sequence state (SSM hidden vectors, or attention counter).
66 pub state: Box<dyn SequenceState>,
67 /// Current token position (0-indexed).
68 pub position: usize,
69 /// Request ID bound to this slot (0 = none).
70 pub request_id: u64,
71}
72
73impl SequenceSlot {
74 /// Create a new slot with the given state and a zero request ID.
75 pub fn new(state: Box<dyn SequenceState>) -> Self {
76 Self {
77 position: state.step_position(),
78 state,
79 request_id: 0,
80 }
81 }
82
83 /// Advance the internal position counter by one (after each forward step).
84 ///
85 /// Also calls `state.advance()` so both the slot's cached `position` and
86 /// the underlying trait object stay in sync.
87 pub fn step(&mut self) {
88 self.state.advance();
89 self.position = self.state.step_position();
90 }
91
92 /// Reset the slot to position 0 and clear the state.
93 ///
94 /// The slot's `request_id` is reset to 0 as well so that stale IDs are
95 /// not accidentally read after re-allocation.
96 pub fn reset(&mut self) {
97 self.state.reset();
98 self.position = 0;
99 self.request_id = 0;
100 }
101}
102
103// ─── SsmStatePool ─────────────────────────────────────────────────────────────
104
105/// A free-list pool of [`SequenceSlot`]s for SSM-based models.
106///
107/// Pre-allocates `capacity` slots at construction time. Slots are identified
108/// by their index into the internal `slots` vector.
109///
110/// Free slots are tracked by a `free_list: Vec<usize>`. `alloc` pops from the
111/// list; `release` resets the slot and pushes back.
112pub struct SsmStatePool {
113 /// All slots, indexed by slot ID. `Some` = live (allocated); `None` = never
114 /// initialised (only possible during construction before the pool is fully
115 /// initialised — in practice every index 0..capacity is always `Some`).
116 slots: Vec<Option<SequenceSlot>>,
117 /// Indices of slots currently on the free list.
118 free_list: Vec<usize>,
119 /// Total capacity (never changes after construction).
120 capacity: usize,
121}
122
123impl SsmStatePool {
124 /// Create a pool by calling `ForwardPass::allocate_sequence_state` for each slot.
125 ///
126 /// This is the preferred construction path because it delegates state
127 /// allocation to the architecture implementation, rather than hard-coding
128 /// the state type at the call site. The runtime calls this once at model
129 /// load time, after `ForwardPass` is available.
130 ///
131 /// ```ignore
132 /// let pool = SsmStatePool::from_forward_pass(fwd_pass.as_ref(), capacity, max_ctx);
133 /// ```
134 pub fn from_forward_pass(
135 forward_pass: &dyn oxillama_arch::traits::ForwardPass,
136 capacity: usize,
137 max_context_length: usize,
138 ) -> Self {
139 Self::new(capacity, |_| {
140 forward_pass.allocate_sequence_state(max_context_length)
141 })
142 }
143
144 /// Create a new pool using a factory closure to produce each slot's state.
145 ///
146 /// The closure is called once per slot with the slot index. Use it to
147 /// initialise arch-specific state (e.g. `Mamba2SequenceState::new(...)`).
148 ///
149 /// ```ignore
150 /// let pool = SsmStatePool::new(8, |_| {
151 /// Box::new(Mamba2SequenceState::new(24, 16, 256, 4096))
152 /// });
153 /// ```
154 pub fn new<F>(capacity: usize, mut make_state: F) -> Self
155 where
156 F: FnMut(usize) -> Box<dyn SequenceState>,
157 {
158 let mut slots = Vec::with_capacity(capacity);
159 let mut free_list = Vec::with_capacity(capacity);
160
161 for i in 0..capacity {
162 let state = make_state(i);
163 slots.push(Some(SequenceSlot::new(state)));
164 free_list.push(i);
165 }
166
167 Self {
168 slots,
169 free_list,
170 capacity,
171 }
172 }
173
174 /// Allocate a free slot and bind it to `request_id`.
175 ///
176 /// Returns `Ok(slot_idx)` on success.
177 ///
178 /// # Errors
179 ///
180 /// Returns [`PoolError::Exhausted`] when no free slots are available.
181 pub fn alloc(&mut self, request_id: u64) -> PoolResult<usize> {
182 let idx = self.free_list.pop().ok_or(PoolError::Exhausted)?;
183 if let Some(slot) = self.slots[idx].as_mut() {
184 slot.request_id = request_id;
185 }
186 Ok(idx)
187 }
188
189 /// Release slot `idx` back to the free list.
190 ///
191 /// The slot's state is reset (zeroed) and its `request_id` cleared.
192 ///
193 /// # Errors
194 ///
195 /// Returns [`PoolError::InvalidSlot`] if `idx` is out of range or already free.
196 pub fn release(&mut self, idx: usize) -> PoolResult<()> {
197 if idx >= self.slots.len() {
198 return Err(PoolError::InvalidSlot(idx));
199 }
200 // Check that the slot is not already on the free list.
201 if self.free_list.contains(&idx) {
202 return Err(PoolError::InvalidSlot(idx));
203 }
204 if let Some(slot) = self.slots[idx].as_mut() {
205 slot.reset();
206 }
207 self.free_list.push(idx);
208 Ok(())
209 }
210
211 /// Get a shared reference to slot `idx`.
212 ///
213 /// Returns `None` if the slot has never been initialised (should not happen
214 /// for a correctly-constructed pool) or the index is out of range.
215 pub fn slot(&self, idx: usize) -> Option<&SequenceSlot> {
216 self.slots.get(idx)?.as_ref()
217 }
218
219 /// Get a mutable reference to slot `idx`.
220 ///
221 /// Returns `None` if the slot is uninitialised or out of range.
222 pub fn slot_mut(&mut self, idx: usize) -> Option<&mut SequenceSlot> {
223 self.slots.get_mut(idx)?.as_mut()
224 }
225
226 /// Total pool capacity (never changes).
227 pub fn capacity(&self) -> usize {
228 self.capacity
229 }
230
231 /// Number of currently free (unallocated) slots.
232 pub fn free_count(&self) -> usize {
233 self.free_list.len()
234 }
235
236 /// Number of currently allocated (live) slots.
237 pub fn used_count(&self) -> usize {
238 self.capacity.saturating_sub(self.free_list.len())
239 }
240}
241
242// ─── SequencePool ─────────────────────────────────────────────────────────────
243
244/// Dispatch-enum over the two pool backends.
245///
246/// At model-load time the engine inspects the loaded architecture and
247/// constructs either a `KvBased` pool (for any transformer) or an `Ssm` pool
248/// (for Mamba-2 and similar). Both variants expose the same interface through
249/// [`SequencePool`]'s methods.
250///
251/// # KV-based pooling
252///
253/// The KV-based variant stores state in a [`KvCachePool`] of page-sized slabs.
254/// Slots are identified by page indices returned by `KvCachePool::alloc`.
255///
256/// # SSM pooling
257///
258/// The SSM variant stores the full per-layer recurrent state in an
259/// [`SsmStatePool`]. The `alloc_ssm` / `release_ssm` helpers delegate to it.
260pub enum SequencePool {
261 /// Attention-transformer pool (KV cache pages).
262 KvBased(KvCachePool),
263 /// SSM pool (per-layer recurrent hidden states).
264 Ssm(SsmStatePool),
265}
266
267impl SequencePool {
268 /// Allocate a slot from the KV-based pool.
269 ///
270 /// Returns the page index on success.
271 ///
272 /// # Errors
273 ///
274 /// `PoolError::Exhausted` if the pool is full.
275 /// `PoolError::InvalidSlot(usize::MAX)` if called on an `Ssm` variant.
276 pub fn alloc_kv(&mut self) -> PoolResult<usize> {
277 match self {
278 SequencePool::KvBased(pool) => pool.alloc().ok_or(PoolError::Exhausted),
279 SequencePool::Ssm(_) => Err(PoolError::InvalidSlot(usize::MAX)),
280 }
281 }
282
283 /// Free a page in the KV-based pool.
284 ///
285 /// # Errors
286 ///
287 /// `PoolError::InvalidSlot` if called on an `Ssm` variant.
288 pub fn free_kv(&mut self, page_idx: usize) -> PoolResult<()> {
289 match self {
290 SequencePool::KvBased(pool) => {
291 pool.free(page_idx);
292 Ok(())
293 }
294 SequencePool::Ssm(_) => Err(PoolError::InvalidSlot(page_idx)),
295 }
296 }
297
298 /// Allocate an SSM slot bound to `request_id`.
299 ///
300 /// # Errors
301 ///
302 /// `PoolError::Exhausted` if the pool is full.
303 /// `PoolError::InvalidSlot(usize::MAX)` if called on a `KvBased` variant.
304 pub fn alloc_ssm(&mut self, request_id: u64) -> PoolResult<usize> {
305 match self {
306 SequencePool::Ssm(pool) => pool.alloc(request_id),
307 SequencePool::KvBased(_) => Err(PoolError::InvalidSlot(usize::MAX)),
308 }
309 }
310
311 /// Release an SSM slot by index.
312 ///
313 /// # Errors
314 ///
315 /// `PoolError::InvalidSlot` if `idx` is invalid or already free, or if
316 /// called on a `KvBased` variant.
317 pub fn release_ssm(&mut self, idx: usize) -> PoolResult<()> {
318 match self {
319 SequencePool::Ssm(pool) => pool.release(idx),
320 SequencePool::KvBased(_) => Err(PoolError::InvalidSlot(idx)),
321 }
322 }
323
324 /// Get an immutable reference to an SSM slot.
325 ///
326 /// Returns `None` for KV-based pools or out-of-range indices.
327 pub fn ssm_slot(&self, idx: usize) -> Option<&SequenceSlot> {
328 match self {
329 SequencePool::Ssm(pool) => pool.slot(idx),
330 SequencePool::KvBased(_) => None,
331 }
332 }
333
334 /// Get a mutable reference to an SSM slot.
335 ///
336 /// Returns `None` for KV-based pools or out-of-range indices.
337 pub fn ssm_slot_mut(&mut self, idx: usize) -> Option<&mut SequenceSlot> {
338 match self {
339 SequencePool::Ssm(pool) => pool.slot_mut(idx),
340 SequencePool::KvBased(_) => None,
341 }
342 }
343
344 /// Returns `true` if this pool uses the KV-cache backend.
345 pub fn is_kv_based(&self) -> bool {
346 matches!(self, SequencePool::KvBased(_))
347 }
348
349 /// Returns `true` if this pool uses the SSM backend.
350 pub fn is_ssm(&self) -> bool {
351 matches!(self, SequencePool::Ssm(_))
352 }
353}
354
355// ─── Tests ────────────────────────────────────────────────────────────────────
356
#[cfg(test)]
mod tests {
    use super::*;
    use oxillama_arch::common::sequence_state::{AttentionSequenceState, Mamba2SequenceState};

    // ── SequenceSlot ──────────────────────────────────────────────────────────

    /// Slot step() advances both the cached position and the inner state.
    #[test]
    fn sequence_slot_position_advances() {
        let state = Box::new(AttentionSequenceState::new(512));
        let mut slot = SequenceSlot::new(state);

        assert_eq!(slot.position, 0, "initial position must be 0");
        assert_eq!(slot.state.step_position(), 0);

        slot.step();
        assert_eq!(slot.position, 1, "position after one step must be 1");
        assert_eq!(slot.state.step_position(), 1);

        slot.step();
        slot.step();
        assert_eq!(slot.position, 3);
        assert_eq!(slot.state.step_position(), 3);
    }

    /// Slot reset() clears position and request_id.
    #[test]
    fn sequence_slot_reset_clears_position() {
        let state = Box::new(AttentionSequenceState::new(64));
        let mut slot = SequenceSlot::new(state);
        slot.request_id = 42;
        slot.step();
        slot.step();
        assert_eq!(slot.position, 2);

        slot.reset();
        assert_eq!(slot.position, 0, "position must be 0 after reset");
        assert_eq!(slot.state.step_position(), 0);
        assert_eq!(slot.request_id, 0, "request_id must be cleared by reset");
    }

    // ── SsmStatePool ──────────────────────────────────────────────────────────

    /// Allocate and release cycles must recycle slot indices.
    #[test]
    fn sequence_pool_allocate_release() {
        let mut pool = SsmStatePool::new(4, |_| {
            Box::new(AttentionSequenceState::new(256)) as Box<dyn SequenceState>
        });

        assert_eq!(pool.capacity(), 4);
        assert_eq!(pool.free_count(), 4);
        assert_eq!(pool.used_count(), 0);

        let idx_a = pool.alloc(1).expect("first alloc must succeed");
        let idx_b = pool.alloc(2).expect("second alloc must succeed");
        assert_ne!(idx_a, idx_b);
        assert_eq!(pool.used_count(), 2);
        assert_eq!(pool.free_count(), 2);
        assert_eq!(pool.slot(idx_a).expect("slot A must exist").request_id, 1);
        assert_eq!(pool.slot(idx_b).expect("slot B must exist").request_id, 2);

        // Release one and re-allocate — must reuse the freed slot.
        pool.release(idx_a).expect("release must succeed");
        assert_eq!(pool.free_count(), 3);
        assert_eq!(pool.used_count(), 1);

        let idx_c = pool.alloc(3).expect("alloc after release");
        assert_eq!(idx_c, idx_a, "freed slot must be reused");
        assert_eq!(pool.used_count(), 2);
        assert_eq!(pool.free_count(), 2);
        assert_eq!(pool.capacity(), 4, "capacity must never change");
    }

    /// Exhausted pool must return PoolError::Exhausted.
    #[test]
    fn ssm_pool_exhaustion_returns_error() {
        let mut pool = SsmStatePool::new(2, |_| {
            Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
        });

        pool.alloc(10).expect("first");
        pool.alloc(11).expect("second");
        assert_eq!(pool.free_count(), 0);
        let err = pool.alloc(12);
        assert!(
            matches!(err, Err(PoolError::Exhausted)),
            "exhausted pool must return Exhausted, got {err:?}"
        );
        // A failed alloc must not disturb the bookkeeping.
        assert_eq!(pool.used_count(), 2);
        assert_eq!(pool.free_count(), 0);
    }

    /// Releasing an already-free slot must return InvalidSlot.
    #[test]
    fn ssm_pool_double_release_errors() {
        let mut pool = SsmStatePool::new(2, |_| {
            Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
        });

        let idx = pool.alloc(1).expect("alloc");
        pool.release(idx).expect("first release");
        let err = pool.release(idx);
        assert!(
            matches!(err, Err(PoolError::InvalidSlot(_))),
            "double-release must return InvalidSlot, got {err:?}"
        );
        // The rejected release must not push a duplicate onto the free list.
        assert_eq!(pool.free_count(), 2);
        assert_eq!(pool.used_count(), 0);
    }

    /// Releasing an in-range slot that was never allocated must return InvalidSlot.
    #[test]
    fn ssm_pool_release_of_never_allocated_slot_errors() {
        let mut pool = SsmStatePool::new(2, |_| {
            Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
        });

        let idx = pool.alloc(1).expect("alloc");
        let never_allocated = if idx == 0 { 1 } else { 0 };
        let err = pool.release(never_allocated);
        assert!(
            matches!(err, Err(PoolError::InvalidSlot(i)) if i == never_allocated),
            "releasing a free slot must return InvalidSlot, got {err:?}"
        );
        assert_eq!(pool.used_count(), 1);
        assert_eq!(pool.free_count(), 1);
    }

    /// Release resets the underlying state (zeroes position and h vectors).
    #[test]
    fn ssm_pool_release_resets_state() {
        let n_layers = 3;
        let d_state = 4;
        let d_inner = 8;
        let mut pool = SsmStatePool::new(2, |_| {
            Box::new(Mamba2SequenceState::new(n_layers, d_state, d_inner, 256))
                as Box<dyn SequenceState>
        });

        let idx = pool.alloc(99).expect("alloc");

        // Advance the slot's position; `expect` ensures these assertions
        // cannot be silently skipped.
        {
            let slot = pool.slot_mut(idx).expect("allocated slot must exist");
            slot.step();
            slot.step();
            assert_eq!(slot.position, 2, "position must be 2 before release");
        }

        pool.release(idx).expect("release");

        // After re-allocation the slot must be fresh (position 0).
        let idx2 = pool.alloc(100).expect("re-alloc");
        assert_eq!(idx2, idx, "must reuse the released slot");
        let slot = pool.slot(idx2).expect("slot must exist");
        assert_eq!(
            slot.position, 0,
            "position must be 0 after re-alloc following release"
        );
        assert_eq!(
            slot.state.step_position(),
            0,
            "state.step_position() must be 0 after release"
        );
        assert_eq!(slot.request_id, 100, "request_id must be updated on alloc");
    }

    // ── SequencePool enum ─────────────────────────────────────────────────────

    /// KvBased pool's alloc_kv / free_kv round-trip works.
    #[test]
    fn sequence_pool_kv_based_alloc_free() {
        let kv_pool = KvCachePool::new(16, 4);
        let mut pool = SequencePool::KvBased(kv_pool);

        assert!(pool.is_kv_based());
        assert!(!pool.is_ssm());

        let idx = pool.alloc_kv().expect("alloc_kv must succeed");
        // Assumes the second `KvCachePool::new` argument is the page count, so
        // valid page indices are 0..4 — keep in sync with the constructor.
        assert!(idx < 4, "page index must be in range 0..4, got {idx}");

        pool.free_kv(idx).expect("free_kv must succeed");
    }

    /// SSM operations on a KvBased pool must fail without panicking.
    #[test]
    fn sequence_pool_kv_rejects_ssm_ops() {
        let kv_pool = KvCachePool::new(16, 4);
        let mut pool = SequencePool::KvBased(kv_pool);
        let err = pool.alloc_ssm(1);
        assert!(
            matches!(err, Err(PoolError::InvalidSlot(_))),
            "alloc_ssm on KvBased must fail, got {err:?}"
        );
        let err = pool.release_ssm(0);
        assert!(
            matches!(err, Err(PoolError::InvalidSlot(0))),
            "release_ssm on KvBased must fail, got {err:?}"
        );
        assert!(pool.ssm_slot(0).is_none(), "ssm_slot on KvBased must be None");
        assert!(
            pool.ssm_slot_mut(0).is_none(),
            "ssm_slot_mut on KvBased must be None"
        );
    }

    /// Ssm pool's alloc_ssm / release_ssm round-trip works.
    #[test]
    fn sequence_pool_ssm_alloc_release() {
        let inner = SsmStatePool::new(4, |_| {
            Box::new(AttentionSequenceState::new(256)) as Box<dyn SequenceState>
        });
        let mut pool = SequencePool::Ssm(inner);

        assert!(pool.is_ssm());
        assert!(!pool.is_kv_based());

        let idx = pool.alloc_ssm(7).expect("alloc_ssm");
        let slot = pool.ssm_slot(idx).expect("slot must exist after alloc");
        assert_eq!(slot.request_id, 7);

        pool.release_ssm(idx).expect("release_ssm");
        // After release the slot is on the free list; ssm_slot still returns it
        // (it's physically present), but request_id should have been cleared.
        let slot = pool.ssm_slot(idx).expect("slot still accessible");
        assert_eq!(slot.request_id, 0, "request_id must be 0 after release");
    }

    /// KV operations on an Ssm pool must fail without panicking.
    #[test]
    fn sequence_pool_ssm_rejects_kv_ops() {
        let inner = SsmStatePool::new(2, |_| {
            Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
        });
        let mut pool = SequencePool::Ssm(inner);
        let err = pool.alloc_kv();
        assert!(
            matches!(err, Err(PoolError::InvalidSlot(_))),
            "alloc_kv on Ssm must fail, got {err:?}"
        );
        let err = pool.free_kv(0);
        assert!(
            matches!(err, Err(PoolError::InvalidSlot(0))),
            "free_kv on Ssm must fail, got {err:?}"
        );
    }

    /// Two independent SSM requests must not share state (isolation test).
    #[test]
    fn mixed_pool_isolation() {
        let n_layers = 2;
        let d_state = 2;
        let d_inner = 4;
        let inner = SsmStatePool::new(4, |_| {
            Box::new(Mamba2SequenceState::new(n_layers, d_state, d_inner, 128))
                as Box<dyn SequenceState>
        });
        let mut pool = SequencePool::Ssm(inner);

        let idx_a = pool.alloc_ssm(1).expect("alloc A");
        let idx_b = pool.alloc_ssm(2).expect("alloc B");
        assert_ne!(idx_a, idx_b, "two requests must occupy different slots");

        // Advance slot A twice.
        {
            let slot_a = pool.ssm_slot_mut(idx_a).expect("slot A must exist");
            slot_a.step();
            slot_a.step();
        }

        let slot_a = pool.ssm_slot(idx_a).expect("slot A must exist");
        assert_eq!(slot_a.position, 2, "slot A must have advanced twice");
        assert_eq!(slot_a.request_id, 1);

        // Slot B must remain at position 0.
        let slot_b = pool.ssm_slot(idx_b).expect("slot B must exist");
        assert_eq!(
            slot_b.position, 0,
            "slot B position must not be affected by slot A's steps"
        );
        assert_eq!(
            slot_b.state.step_position(),
            0,
            "slot B state must not be affected by slot A's steps"
        );
        assert_eq!(slot_b.request_id, 2);
    }

    /// Out-of-range slot index must return PoolError::InvalidSlot.
    #[test]
    fn ssm_pool_out_of_range_slot_errors() {
        let mut pool = SsmStatePool::new(2, |_| {
            Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
        });
        pool.alloc(1).expect("alloc to make one slot live");

        let err = pool.release(99); // way out of range
        assert!(
            matches!(err, Err(PoolError::InvalidSlot(99))),
            "out-of-range release must return InvalidSlot(99), got {err:?}"
        );
        assert_eq!(
            pool.used_count(),
            1,
            "failed release must not change bookkeeping"
        );
        assert_eq!(pool.free_count(), 1);
    }

    /// `slot_reset_on_eos_for_ssm`: when a slot is released (simulating EOS),
    /// the underlying SSM state must be all-zero on next allocation.
    #[test]
    fn slot_reset_on_eos_for_ssm() {
        let n_layers = 2;
        let d_state = 4;
        let d_inner = 8;
        let inner = SsmStatePool::new(2, |_| {
            Box::new(Mamba2SequenceState::new(n_layers, d_state, d_inner, 256))
                as Box<dyn SequenceState>
        });
        let mut pool = SequencePool::Ssm(inner);

        // Allocate, advance a few steps, then simulate EOS by releasing.
        let idx = pool.alloc_ssm(5).expect("alloc");
        {
            let slot = pool.ssm_slot_mut(idx).expect("allocated slot must exist");
            for _ in 0..10 {
                slot.step();
            }
            assert_eq!(slot.position, 10, "must have 10 steps before release");
        }
        pool.release_ssm(idx).expect("release on EOS");

        // Re-allocate: state must be fresh.
        let idx2 = pool.alloc_ssm(6).expect("re-alloc");
        let slot = pool.ssm_slot(idx2).expect("slot must exist");
        assert_eq!(
            slot.position, 0,
            "position must be 0 on fresh re-alloc (EOS reset)"
        );
        assert_eq!(slot.state.step_position(), 0);
    }
}