use crate::kv_pool::KvCachePool;
use oxillama_arch::common::sequence_state::SequenceState;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum PoolError {
#[error("sequence pool exhausted: no free slots available")]
Exhausted,
#[error("invalid slot index {0}: slot is not live or out of range")]
InvalidSlot(usize),
}
/// Shorthand for a `Result` whose error type is [`PoolError`].
pub type PoolResult<T> = Result<T, PoolError>;
/// One pooled sequence: a boxed [`SequenceState`] together with the
/// bookkeeping values the pool keeps next to it.
pub struct SequenceSlot {
/// The sequence state owned by this slot.
pub state: Box<dyn SequenceState>,
/// Cached copy of `state.step_position()`: seeded by `new`, refreshed by
/// `step`, and set to `0` by `reset`.
pub position: usize,
/// Identifier of the request using this slot. It is `0` after `new` and
/// after `reset`; `SsmStatePool::alloc` overwrites it with the caller's id.
pub request_id: u64,
}
impl SequenceSlot {
    /// Wraps `state` in a new slot.
    ///
    /// The cached `position` is read from `state.step_position()` and
    /// `request_id` starts out as `0`.
    pub fn new(state: Box<dyn SequenceState>) -> Self {
        let position = state.step_position();
        Self {
            state,
            position,
            request_id: 0,
        }
    }

    /// Advances the wrapped state by one step, then re-reads its step
    /// position into the cached `position` field.
    pub fn step(&mut self) {
        self.state.advance();
        let current = self.state.step_position();
        self.position = current;
    }

    /// Resets the wrapped state and clears the slot's bookkeeping:
    /// both `position` and `request_id` are set back to `0`.
    pub fn reset(&mut self) {
        self.state.reset();
        self.request_id = 0;
        self.position = 0;
    }
}
/// Fixed-size pool of [`SequenceSlot`]s that are handed out and returned by index.
pub struct SsmStatePool {
/// Slot storage, indexed by slot number. `new` fills every entry with `Some`.
slots: Vec<Option<SequenceSlot>>,
/// Indices currently available for allocation. Used as a stack: `alloc`
/// pops from the end and `release` pushes onto the end.
free_list: Vec<usize>,
/// Number of slots requested when the pool was built.
capacity: usize,
}
impl SsmStatePool {
    /// Builds a pool of `capacity` slots whose states are produced by
    /// `forward_pass.allocate_sequence_state(max_context_length)`, called once
    /// per slot.
    pub fn from_forward_pass(
        forward_pass: &dyn oxillama_arch::traits::ForwardPass,
        capacity: usize,
        max_context_length: usize,
    ) -> Self {
        Self::new(capacity, |_| forward_pass.allocate_sequence_state(max_context_length))
    }

    /// Builds a pool of `capacity` slots.
    ///
    /// `make_state` is invoked once per slot index, in ascending order from
    /// `0` to `capacity - 1`. Every slot starts out free.
    pub fn new<F>(capacity: usize, mut make_state: F) -> Self
    where
        F: FnMut(usize) -> Box<dyn SequenceState>,
    {
        let slots: Vec<Option<SequenceSlot>> = (0..capacity)
            .map(|index| Some(SequenceSlot::new(make_state(index))))
            .collect();
        let free_list: Vec<usize> = (0..capacity).collect();
        Self {
            slots,
            free_list,
            capacity,
        }
    }

    /// Takes a free slot, tags it with `request_id`, and returns its index.
    ///
    /// The most recently freed index is handed out first.
    ///
    /// # Errors
    ///
    /// Returns [`PoolError::Exhausted`] when no slot is free.
    pub fn alloc(&mut self, request_id: u64) -> PoolResult<usize> {
        match self.free_list.pop() {
            None => Err(PoolError::Exhausted),
            Some(idx) => {
                if let Some(slot) = &mut self.slots[idx] {
                    slot.request_id = request_id;
                }
                Ok(idx)
            }
        }
    }

    /// Returns slot `idx` to the free list after resetting it.
    ///
    /// # Errors
    ///
    /// Returns [`PoolError::InvalidSlot`] carrying `idx` when the index is out
    /// of range or the slot is already on the free list.
    pub fn release(&mut self, idx: usize) -> PoolResult<()> {
        let out_of_range = idx >= self.slots.len();
        // Short-circuit keeps the free-list scan from running for bad indices.
        if out_of_range || self.free_list.contains(&idx) {
            return Err(PoolError::InvalidSlot(idx));
        }
        if let Some(slot) = &mut self.slots[idx] {
            slot.reset();
        }
        self.free_list.push(idx);
        Ok(())
    }

    /// Shared access to slot `idx`, or `None` if the index is out of range.
    ///
    /// Liveness is not checked: a slot that is currently free is still returned.
    pub fn slot(&self, idx: usize) -> Option<&SequenceSlot> {
        self.slots.get(idx).and_then(Option::as_ref)
    }

    /// Mutable access to slot `idx`, or `None` if the index is out of range.
    ///
    /// Liveness is not checked: a slot that is currently free is still returned.
    pub fn slot_mut(&mut self, idx: usize) -> Option<&mut SequenceSlot> {
        self.slots.get_mut(idx).and_then(Option::as_mut)
    }

    /// Total number of slots the pool was built with.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of slots currently available to `alloc`.
    pub fn free_count(&self) -> usize {
        self.free_list.len()
    }

    /// Number of slots currently allocated (capacity minus free, saturating at zero).
    pub fn used_count(&self) -> usize {
        let free = self.free_list.len();
        self.capacity.saturating_sub(free)
    }
}
/// Sequence storage backed by either a [`KvCachePool`] or an [`SsmStatePool`].
///
/// The `*_kv` methods only succeed on the `KvBased` variant and the `*_ssm`
/// methods only on the `Ssm` variant; calling one on the other variant yields
/// an error (or `None` for the slot accessors).
pub enum SequencePool {
/// Storage managed by a `KvCachePool`.
KvBased(KvCachePool),
/// Storage managed by an `SsmStatePool`.
Ssm(SsmStatePool),
}
impl SequencePool {
    /// Allocates from the underlying `KvCachePool` and returns the index it yields.
    ///
    /// # Errors
    ///
    /// * [`PoolError::Exhausted`] when the KV pool's `alloc` returns `None`.
    /// * [`PoolError::InvalidSlot`] with `usize::MAX` when this is the `Ssm` variant.
    pub fn alloc_kv(&mut self) -> PoolResult<usize> {
        match self {
            Self::KvBased(kv) => kv.alloc().ok_or(PoolError::Exhausted),
            Self::Ssm(_) => Err(PoolError::InvalidSlot(usize::MAX)),
        }
    }

    /// Hands `page_idx` back to the underlying `KvCachePool`.
    ///
    /// # Errors
    ///
    /// [`PoolError::InvalidSlot`] carrying `page_idx` when this is the `Ssm` variant.
    pub fn free_kv(&mut self, page_idx: usize) -> PoolResult<()> {
        match self {
            Self::Ssm(_) => Err(PoolError::InvalidSlot(page_idx)),
            Self::KvBased(kv) => {
                kv.free(page_idx);
                Ok(())
            }
        }
    }

    /// Allocates a slot for `request_id` from the underlying `SsmStatePool`.
    ///
    /// # Errors
    ///
    /// Whatever `SsmStatePool::alloc` reports, or [`PoolError::InvalidSlot`]
    /// with `usize::MAX` when this is the `KvBased` variant.
    pub fn alloc_ssm(&mut self, request_id: u64) -> PoolResult<usize> {
        match self {
            Self::KvBased(_) => Err(PoolError::InvalidSlot(usize::MAX)),
            Self::Ssm(ssm) => ssm.alloc(request_id),
        }
    }

    /// Releases slot `idx` in the underlying `SsmStatePool`.
    ///
    /// # Errors
    ///
    /// Whatever `SsmStatePool::release` reports, or [`PoolError::InvalidSlot`]
    /// carrying `idx` when this is the `KvBased` variant.
    pub fn release_ssm(&mut self, idx: usize) -> PoolResult<()> {
        match self {
            Self::KvBased(_) => Err(PoolError::InvalidSlot(idx)),
            Self::Ssm(ssm) => ssm.release(idx),
        }
    }

    /// Shared access to SSM slot `idx`; always `None` for the `KvBased` variant.
    pub fn ssm_slot(&self, idx: usize) -> Option<&SequenceSlot> {
        match self {
            Self::KvBased(_) => None,
            Self::Ssm(ssm) => ssm.slot(idx),
        }
    }

    /// Mutable access to SSM slot `idx`; always `None` for the `KvBased` variant.
    pub fn ssm_slot_mut(&mut self, idx: usize) -> Option<&mut SequenceSlot> {
        match self {
            Self::KvBased(_) => None,
            Self::Ssm(ssm) => ssm.slot_mut(idx),
        }
    }

    /// `true` when this pool is the `KvBased` variant.
    pub fn is_kv_based(&self) -> bool {
        matches!(self, Self::KvBased(_))
    }

    /// `true` when this pool is the `Ssm` variant.
    pub fn is_ssm(&self) -> bool {
        matches!(self, Self::Ssm(_))
    }
}
// Unit tests for `SequenceSlot`, `SsmStatePool` and `SequencePool`.
#[cfg(test)]
mod tests {
use super::*;
use oxillama_arch::common::sequence_state::{AttentionSequenceState, Mamba2SequenceState};
// `SequenceSlot::step` must keep the cached `position` equal to
// `state.step_position()` after every step.
#[test]
fn sequence_slot_position_advances() {
let state = Box::new(AttentionSequenceState::new(512));
let mut slot = SequenceSlot::new(state);
assert_eq!(slot.position, 0, "initial position must be 0");
assert_eq!(slot.state.step_position(), 0);
slot.step();
assert_eq!(slot.position, 1, "position after one step must be 1");
assert_eq!(slot.state.step_position(), 1);
slot.step();
slot.step();
assert_eq!(slot.position, 3);
assert_eq!(slot.state.step_position(), 3);
}
// `SequenceSlot::reset` must zero the position (both cached and in the
// state) and clear the request id.
#[test]
fn sequence_slot_reset_clears_position() {
let state = Box::new(AttentionSequenceState::new(64));
let mut slot = SequenceSlot::new(state);
slot.request_id = 42;
slot.step();
slot.step();
assert_eq!(slot.position, 2);
slot.reset();
assert_eq!(slot.position, 0, "position must be 0 after reset");
assert_eq!(slot.state.step_position(), 0);
assert_eq!(slot.request_id, 0, "request_id must be cleared by reset");
}
// Allocation and release keep `free_count`/`used_count` consistent, and a
// released slot is the next one handed out.
#[test]
fn sequence_pool_allocate_release() {
let mut pool = SsmStatePool::new(4, |_| {
Box::new(AttentionSequenceState::new(256)) as Box<dyn SequenceState>
});
assert_eq!(pool.capacity(), 4);
assert_eq!(pool.free_count(), 4);
assert_eq!(pool.used_count(), 0);
let idx_a = pool.alloc(1).expect("first alloc must succeed");
let idx_b = pool.alloc(2).expect("second alloc must succeed");
assert_ne!(idx_a, idx_b);
assert_eq!(pool.used_count(), 2);
assert_eq!(pool.free_count(), 2);
pool.release(idx_a).expect("release must succeed");
assert_eq!(pool.free_count(), 3);
let idx_c = pool.alloc(3).expect("alloc after release");
assert_eq!(idx_c, idx_a, "freed slot must be reused");
assert_eq!(pool.used_count(), 2);
}
// Allocating beyond capacity must fail with `PoolError::Exhausted`.
#[test]
fn ssm_pool_exhaustion_returns_error() {
let mut pool = SsmStatePool::new(2, |_| {
Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
});
pool.alloc(10).expect("first");
pool.alloc(11).expect("second");
let err = pool.alloc(12);
assert!(
matches!(err, Err(PoolError::Exhausted)),
"exhausted pool must return Exhausted, got {err:?}"
);
}
// Releasing the same slot twice must be rejected with `InvalidSlot`.
#[test]
fn ssm_pool_double_release_errors() {
let mut pool = SsmStatePool::new(2, |_| {
Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
});
let idx = pool.alloc(1).expect("alloc");
pool.release(idx).expect("first release");
let err = pool.release(idx);
assert!(
matches!(err, Err(PoolError::InvalidSlot(_))),
"double-release must return InvalidSlot, got {err:?}"
);
}
// A slot that was stepped, released and re-allocated must come back with
// position 0 and the new request id.
#[test]
fn ssm_pool_release_resets_state() {
let n_layers = 3;
let d_state = 4;
let d_inner = 8;
let mut pool = SsmStatePool::new(2, |_| {
Box::new(Mamba2SequenceState::new(n_layers, d_state, d_inner, 256))
as Box<dyn SequenceState>
});
let idx = pool.alloc(99).expect("alloc");
if let Some(slot) = pool.slot_mut(idx) {
slot.step();
slot.step();
assert_eq!(slot.position, 2, "position must be 2 before release");
}
pool.release(idx).expect("release");
let idx2 = pool.alloc(100).expect("re-alloc");
assert_eq!(idx2, idx, "must reuse the released slot");
let slot = pool.slot(idx2).expect("slot must exist");
assert_eq!(
slot.position, 0,
"position must be 0 after re-alloc following release"
);
assert_eq!(
slot.state.step_position(),
0,
"state.step_position() must be 0 after release"
);
assert_eq!(slot.request_id, 100, "request_id must be updated on alloc");
}
// The `KvBased` variant forwards `alloc_kv`/`free_kv` to the KV pool.
// NOTE(review): the second `KvCachePool::new` argument looks like the page
// count, given the `idx < 4` assertion; confirm against `KvCachePool`.
#[test]
fn sequence_pool_kv_based_alloc_free() {
let kv_pool = KvCachePool::new(16, 4);
let mut pool = SequencePool::KvBased(kv_pool);
assert!(pool.is_kv_based());
assert!(!pool.is_ssm());
let idx = pool.alloc_kv().expect("alloc_kv must succeed");
assert!(idx < 4, "page index must be in range 0..4, got {idx}");
pool.free_kv(idx).expect("free_kv must succeed");
}
// SSM operations on a `KvBased` pool must fail with `InvalidSlot`.
#[test]
fn sequence_pool_kv_rejects_ssm_ops() {
let kv_pool = KvCachePool::new(16, 4);
let mut pool = SequencePool::KvBased(kv_pool);
let err = pool.alloc_ssm(1);
assert!(
matches!(err, Err(PoolError::InvalidSlot(_))),
"alloc_ssm on KvBased must fail, got {err:?}"
);
}
// The `Ssm` variant forwards alloc/release; a released slot stays
// reachable through `ssm_slot` but has its request id cleared.
#[test]
fn sequence_pool_ssm_alloc_release() {
let inner = SsmStatePool::new(4, |_| {
Box::new(AttentionSequenceState::new(256)) as Box<dyn SequenceState>
});
let mut pool = SequencePool::Ssm(inner);
assert!(pool.is_ssm());
assert!(!pool.is_kv_based());
let idx = pool.alloc_ssm(7).expect("alloc_ssm");
let slot = pool.ssm_slot(idx).expect("slot must exist after alloc");
assert_eq!(slot.request_id, 7);
pool.release_ssm(idx).expect("release_ssm");
let slot = pool.ssm_slot(idx).expect("slot still accessible");
assert_eq!(slot.request_id, 0, "request_id must be 0 after release");
}
// KV operations on an `Ssm` pool must fail with `InvalidSlot`.
#[test]
fn sequence_pool_ssm_rejects_kv_ops() {
let inner = SsmStatePool::new(2, |_| {
Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
});
let mut pool = SequencePool::Ssm(inner);
let err = pool.alloc_kv();
assert!(
matches!(err, Err(PoolError::InvalidSlot(_))),
"alloc_kv on Ssm must fail, got {err:?}"
);
}
// Stepping one slot must not change the position of another slot.
#[test]
fn mixed_pool_isolation() {
let n_layers = 2;
let d_state = 2;
let d_inner = 4;
let inner = SsmStatePool::new(4, |_| {
Box::new(Mamba2SequenceState::new(n_layers, d_state, d_inner, 128))
as Box<dyn SequenceState>
});
let mut pool = SequencePool::Ssm(inner);
let idx_a = pool.alloc_ssm(1).expect("alloc A");
let idx_b = pool.alloc_ssm(2).expect("alloc B");
assert_ne!(idx_a, idx_b, "two requests must occupy different slots");
if let Some(slot_a) = pool.ssm_slot_mut(idx_a) {
slot_a.step();
slot_a.step();
}
let slot_b = pool.ssm_slot(idx_b).expect("slot B must exist");
assert_eq!(
slot_b.position, 0,
"slot B position must not be affected by slot A's steps"
);
}
// Releasing an index beyond the pool's size must return `InvalidSlot`
// carrying that same index.
#[test]
fn ssm_pool_out_of_range_slot_errors() {
let mut pool = SsmStatePool::new(2, |_| {
Box::new(AttentionSequenceState::new(64)) as Box<dyn SequenceState>
});
pool.alloc(1).expect("alloc to make slot 0 live");
let err = pool.release(99); assert!(
matches!(err, Err(PoolError::InvalidSlot(99))),
"out-of-range release must return InvalidSlot(99), got {err:?}"
);
}
// After ten steps and a release (the end-of-sequence path), the slot handed
// out by the next allocation must start again at position 0.
#[test]
fn slot_reset_on_eos_for_ssm() {
let n_layers = 2;
let d_state = 4;
let d_inner = 8;
let inner = SsmStatePool::new(2, |_| {
Box::new(Mamba2SequenceState::new(n_layers, d_state, d_inner, 256))
as Box<dyn SequenceState>
});
let mut pool = SequencePool::Ssm(inner);
let idx = pool.alloc_ssm(5).expect("alloc");
if let Some(slot) = pool.ssm_slot_mut(idx) {
for _ in 0..10 {
slot.step();
}
assert_eq!(slot.position, 10, "must have 10 steps before release");
}
pool.release_ssm(idx).expect("release on EOS");
let idx2 = pool.alloc_ssm(6).expect("re-alloc");
let slot = pool.ssm_slot(idx2).expect("slot must exist");
assert_eq!(
slot.position, 0,
"position must be 0 on fresh re-alloc (EOS reset)"
);
assert_eq!(slot.state.step_position(), 0);
}
}