use candle_core::{Device, IndexOp, Result, Tensor};
use super::KvCache;
use crate::layers_masker::PastKvLenCache;
/// A fixed-capacity pool of per-sequence recurrent-layer states (e.g. for
/// Mamba-style layers): a rolling convolution buffer plus a recurrent/SSM
/// state tensor, both indexed by slot along dimension 0.
#[derive(Debug)]
pub struct RecurrentStatePool {
    /// Convolution state, shape `(capacity, conv_dim, conv_width)`.
    pub conv_state: Tensor,
    /// Recurrent state, shape `(capacity, *state_dims)`.
    pub recurrent_state: Tensor,
    // Per-slot token offsets (how many tokens each slot has processed).
    seqlen_offsets: Vec<usize>,
    // Slot indices available for allocation; used as a LIFO stack via `pop`.
    free_slots: Vec<usize>,
    // Current number of slots; doubled on demand by `grow`.
    capacity: usize,
    conv_dim: usize,
    conv_width: usize,
    // Trailing dimensions of the recurrent state (excluding the slot axis).
    state_dims: Vec<usize>,
    dtype: candle_core::DType,
    device: Device,
}
/// Starting slot capacity for a new pool; capacity doubles when exhausted.
const INITIAL_POOL_CAPACITY: usize = 4;
impl RecurrentStatePool {
    /// Create a pool with `INITIAL_POOL_CAPACITY` zero-initialized slots.
    ///
    /// `state_dims` are the per-slot trailing dimensions of the recurrent
    /// state; the pooled tensor gets shape `(capacity, *state_dims)`.
    pub fn new(
        conv_dim: usize,
        conv_width: usize,
        state_dims: Vec<usize>,
        dtype: candle_core::DType,
        device: &Device,
    ) -> Result<Self> {
        let capacity = INITIAL_POOL_CAPACITY;
        let conv_state = Tensor::zeros((capacity, conv_dim, conv_width), dtype, device)?;
        let mut recurrent_shape = vec![capacity];
        recurrent_shape.extend_from_slice(&state_dims);
        let recurrent_state = Tensor::zeros(recurrent_shape, dtype, device)?;
        // Reversed so that `pop` hands out low slot indices first.
        let free_slots: Vec<usize> = (0..capacity).rev().collect();
        let seqlen_offsets = vec![0; capacity];
        Ok(Self {
            conv_state,
            recurrent_state,
            seqlen_offsets,
            free_slots,
            capacity,
            conv_dim,
            conv_width,
            state_dims,
            dtype,
            device: device.clone(),
        })
    }

    /// Double the capacity, copying the existing slot states into the front
    /// of freshly zeroed tensors and pushing the new slots onto the free
    /// list.
    fn grow(&mut self) -> Result<()> {
        let new_capacity = self.capacity * 2;
        let new_conv = Tensor::zeros(
            (new_capacity, self.conv_dim, self.conv_width),
            self.dtype,
            &self.device,
        )?;
        // `slice_set` writes into the tensor's storage in place.
        new_conv.slice_set(&self.conv_state, 0, 0)?;
        let mut recurrent_shape = vec![new_capacity];
        recurrent_shape.extend_from_slice(&self.state_dims);
        let new_recurrent = Tensor::zeros(recurrent_shape, self.dtype, &self.device)?;
        new_recurrent.slice_set(&self.recurrent_state, 0, 0)?;
        self.free_slots.extend((self.capacity..new_capacity).rev());
        self.seqlen_offsets.resize(new_capacity, 0);
        self.conv_state = new_conv;
        self.recurrent_state = new_recurrent;
        self.capacity = new_capacity;
        tracing::info!("Recurrent state pool grew to capacity {new_capacity}");
        Ok(())
    }

    /// Take a free slot, growing the pool if necessary, and zero its state.
    /// Returns `None` only if the pool could not grow.
    pub fn allocate(&mut self) -> Option<usize> {
        if self.free_slots.is_empty() {
            if let Err(e) = self.grow() {
                tracing::error!("Failed to grow recurrent state pool: {e}");
                return None;
            }
        }
        let slot_idx = self.free_slots.pop()?;
        // Best-effort zeroing: a failure leaves potentially stale state but
        // the allocation itself still succeeds.
        if self.reset_slot(slot_idx).is_err() {
            tracing::warn!("Failed to reset recurrent state slot {slot_idx}, state may be stale");
        }
        Some(slot_idx)
    }

    /// Return a slot to the free list and clear its token offset.
    ///
    /// Out-of-range or already-free slots are ignored with a warning:
    /// pushing a duplicate onto `free_slots` would later hand the same slot
    /// to two different sequences, silently corrupting both states.
    pub fn free(&mut self, slot_idx: usize) {
        debug_assert!(slot_idx < self.capacity);
        if slot_idx >= self.capacity {
            tracing::warn!("Ignoring free of out-of-range recurrent state slot {slot_idx}");
            return;
        }
        if self.free_slots.contains(&slot_idx) {
            tracing::warn!("Ignoring double free of recurrent state slot {slot_idx}");
            return;
        }
        self.seqlen_offsets[slot_idx] = 0;
        self.free_slots.push(slot_idx);
    }

    /// Number of tokens slot `slot_idx` has processed so far.
    pub fn get_seqlen_offset(&self, slot_idx: usize) -> usize {
        self.seqlen_offsets[slot_idx]
    }

    /// Overwrite slot `slot_idx`'s token offset.
    pub fn set_seqlen_offset(&mut self, slot_idx: usize, offset: usize) {
        self.seqlen_offsets[slot_idx] = offset;
    }

    /// Advance slot `slot_idx`'s token offset by `delta`.
    pub fn increment_seqlen_offset(&mut self, slot_idx: usize, delta: usize) {
        self.seqlen_offsets[slot_idx] += delta;
    }

    /// Gather conv states for a batch of slot indices (u32 index tensor),
    /// producing a `(batch, conv_dim, conv_width)` tensor.
    pub fn gather_conv_state(&self, state_indices: &Tensor) -> Result<Tensor> {
        self.conv_state.index_select(state_indices, 0)
    }

    /// Gather recurrent states for a batch of slot indices (u32 index
    /// tensor), producing a `(batch, *state_dims)` tensor.
    pub fn gather_recurrent_state(&self, state_indices: &Tensor) -> Result<Tensor> {
        self.recurrent_state.index_select(state_indices, 0)
    }

    /// Scatter per-batch conv states back into their slots.
    ///
    /// `state_indices` maps batch position -> slot index; `values` is
    /// indexed by batch position along dimension 0.
    pub fn scatter_conv_state(&mut self, state_indices: &Tensor, values: &Tensor) -> Result<()> {
        let indices: Vec<u32> = state_indices.to_vec1()?;
        for (batch_idx, &slot_idx) in indices.iter().enumerate() {
            let value = values.i(batch_idx)?.unsqueeze(0)?.contiguous()?;
            self.conv_state.slice_set(&value, 0, slot_idx as usize)?;
        }
        Ok(())
    }

    /// Scatter per-batch recurrent states back into their slots; see
    /// `scatter_conv_state` for the indexing convention.
    pub fn scatter_recurrent_state(
        &mut self,
        state_indices: &Tensor,
        values: &Tensor,
    ) -> Result<()> {
        let indices: Vec<u32> = state_indices.to_vec1()?;
        for (batch_idx, &slot_idx) in indices.iter().enumerate() {
            let value = values.i(batch_idx)?.unsqueeze(0)?.contiguous()?;
            self.recurrent_state
                .slice_set(&value, 0, slot_idx as usize)?;
        }
        Ok(())
    }

    /// Zero slot `slot_idx`'s conv and recurrent state and reset its token
    /// offset. The slot's free/allocated status is unchanged.
    pub fn reset_slot(&mut self, slot_idx: usize) -> Result<()> {
        let zero_conv = Tensor::zeros(
            (1, self.conv_dim, self.conv_width),
            self.dtype,
            &self.device,
        )?;
        let mut recurrent_shape = vec![1usize];
        recurrent_shape.extend_from_slice(&self.state_dims);
        let zero_recurrent = Tensor::zeros(recurrent_shape, self.dtype, &self.device)?;
        self.conv_state.slice_set(&zero_conv, 0, slot_idx)?;
        self.recurrent_state
            .slice_set(&zero_recurrent, 0, slot_idx)?;
        self.seqlen_offsets[slot_idx] = 0;
        Ok(())
    }

    /// Zero every slot's state and mark all slots free.
    pub fn reset(&mut self) -> Result<()> {
        self.conv_state = self.conv_state.zeros_like()?;
        self.recurrent_state = self.recurrent_state.zeros_like()?;
        self.seqlen_offsets.fill(0);
        self.free_slots = (0..self.capacity).rev().collect();
        Ok(())
    }

    /// Total number of slots (free + allocated).
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of currently unallocated slots.
    pub fn num_free_slots(&self) -> usize {
        self.free_slots.len()
    }

    /// Device the pooled tensors live on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Dtype of the pooled tensors.
    pub fn dtype(&self) -> candle_core::DType {
        self.dtype
    }
}
impl Clone for RecurrentStatePool {
    // NOTE(review): every field is `Clone`, so this matches what
    // `#[derive(Clone)]` would generate; kept as a manual impl so the
    // struct definition stays untouched. The exhaustive destructure makes
    // this fail to compile if a field is ever added.
    fn clone(&self) -> Self {
        let Self {
            conv_state,
            recurrent_state,
            seqlen_offsets,
            free_slots,
            capacity,
            conv_dim,
            conv_width,
            state_dims,
            dtype,
            device,
        } = self;
        Self {
            conv_state: conv_state.clone(),
            recurrent_state: recurrent_state.clone(),
            seqlen_offsets: seqlen_offsets.clone(),
            free_slots: free_slots.clone(),
            capacity: *capacity,
            conv_dim: *conv_dim,
            conv_width: *conv_width,
            state_dims: state_dims.clone(),
            dtype: *dtype,
            device: device.clone(),
        }
    }
}
/// The per-layer cache of a hybrid (attention + recurrent) model.
#[derive(Clone, Debug)]
pub enum HybridLayerCache {
    /// Standard KV cache for an attention layer.
    Attention(KvCache),
    /// Slot pool of conv/recurrent states for a recurrent layer.
    Recurrent(RecurrentStatePool),
}
impl HybridLayerCache {
    /// Clear this layer's cached state. A recurrent pool's reset error is
    /// deliberately discarded to keep this method infallible.
    pub fn reset(&mut self) {
        match self {
            Self::Recurrent(pool) => {
                let _ = pool.reset();
            }
            Self::Attention(kv) => kv.reset(),
        }
    }

    /// Borrow the attention KV cache, or `None` for a recurrent layer.
    pub fn as_kv_cache(&self) -> Option<&KvCache> {
        if let Self::Attention(kv) = self {
            Some(kv)
        } else {
            None
        }
    }

    /// Mutably borrow the attention KV cache, or `None` for a recurrent
    /// layer.
    pub fn as_kv_cache_mut(&mut self) -> Option<&mut KvCache> {
        if let Self::Attention(kv) = self {
            Some(kv)
        } else {
            None
        }
    }

    /// Borrow the recurrent state pool, or `None` for an attention layer.
    pub fn as_recurrent_pool(&self) -> Option<&RecurrentStatePool> {
        if let Self::Recurrent(pool) = self {
            Some(pool)
        } else {
            None
        }
    }

    /// Mutably borrow the recurrent state pool, or `None` for an attention
    /// layer.
    pub fn as_recurrent_pool_mut(&mut self) -> Option<&mut RecurrentStatePool> {
        if let Self::Recurrent(pool) = self {
            Some(pool)
        } else {
            None
        }
    }
}
/// Kind of cache a model layer requires.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HybridLayerType {
    /// Layer uses a standard attention KV cache.
    Attention,
    /// Layer uses pooled conv/recurrent state.
    Recurrent,
}
/// Shape parameters shared by every recurrent layer's state pool.
#[derive(Clone, Debug)]
pub struct RecurrentLayerConfig {
    /// Channel dimension of the rolling convolution state buffer.
    pub conv_dim: usize,
    /// Temporal width of the rolling convolution state buffer.
    pub conv_width: usize,
    /// Trailing dimensions of the recurrent state (excluding the slot axis).
    pub state_dims: Vec<usize>,
}
/// Configuration used to build a `HybridCache`.
#[derive(Clone, Debug)]
pub struct HybridCacheConfig {
    /// Cache kind for each model layer, in layer order.
    pub layer_types: Vec<HybridLayerType>,
    /// Maximum sequence length used to size the attention KV caches.
    pub max_seq_len: usize,
    /// Shared shape parameters for all recurrent layers.
    pub recurrent: RecurrentLayerConfig,
}
/// Cache container for hybrid attention/recurrent models: one
/// `HybridLayerCache` per layer plus shared per-forward-pass metadata.
#[derive(Clone, Debug)]
pub struct HybridCache {
    /// One cache per model layer, in layer order.
    pub caches: Vec<HybridLayerCache>,
    // Configuration this cache was built from.
    config: HybridCacheConfig,
    // Optional batch-position -> pool-slot index tensor for the current
    // forward pass; `None` when not set.
    state_indices: Option<Tensor>,
}
impl HybridCache {
    /// Growth increment (in tokens) for the attention KV caches.
    pub const CACHE_GROW_SIZE: usize = 512;

    /// Build one cache per layer according to `config.layer_types`:
    /// attention layers get a [`KvCache`], recurrent layers get a
    /// [`RecurrentStatePool`] shaped by `config.recurrent`.
    pub fn new(
        config: HybridCacheConfig,
        dtype: candle_core::DType,
        device: &Device,
    ) -> Result<Self> {
        let mut caches = Vec::with_capacity(config.layer_types.len());
        for layer_type in &config.layer_types {
            let cache = match layer_type {
                HybridLayerType::Attention => HybridLayerCache::Attention(KvCache::new_normal(
                    2,
                    config.max_seq_len,
                    Self::CACHE_GROW_SIZE,
                )),
                HybridLayerType::Recurrent => HybridLayerCache::Recurrent(RecurrentStatePool::new(
                    config.recurrent.conv_dim,
                    config.recurrent.conv_width,
                    config.recurrent.state_dims.clone(),
                    dtype,
                    device,
                )?),
            };
            caches.push(cache);
        }
        Ok(Self {
            caches,
            config,
            state_indices: None,
        })
    }

    /// Allocate one state slot in every recurrent layer for a new sequence.
    ///
    /// All pools are kept in lock-step: every layer must hand out the same
    /// slot index. On any allocation failure or slot mismatch the partial
    /// allocation is rolled back and `None` is returned. Also returns
    /// `None` when the model has no recurrent layers.
    pub fn allocate_seq(&mut self) -> Option<usize> {
        let recurrent_layers: Vec<usize> = self
            .caches
            .iter()
            .enumerate()
            .filter_map(|(idx, cache)| match cache {
                HybridLayerCache::Recurrent(_) => Some(idx),
                HybridLayerCache::Attention(_) => None,
            })
            .collect();
        let mut expected_slot = None;
        let mut allocated_slots = Vec::new();
        for &layer_idx in &recurrent_layers {
            // Scoped so the mutable borrow of `self.caches` ends before any
            // rollback that needs `&mut self` again.
            let maybe_slot = {
                let HybridLayerCache::Recurrent(pool) = &mut self.caches[layer_idx] else {
                    unreachable!("recurrent_layers only contains recurrent entries");
                };
                pool.allocate()
            };
            let Some(slot_idx) = maybe_slot else {
                // Pool exhausted and could not grow: undo earlier layers.
                self.rollback_allocation(&recurrent_layers, &allocated_slots);
                return None;
            };
            if let Some(expected) = expected_slot {
                if slot_idx != expected {
                    tracing::warn!(
                        "Hybrid recurrent pool slot mismatch: expected {expected}, got {slot_idx}. Rolling back allocation."
                    );
                    // Free the just-allocated mismatching slot, then the
                    // slots handed out by earlier layers.
                    if let HybridLayerCache::Recurrent(pool) = &mut self.caches[layer_idx] {
                        pool.free(slot_idx);
                    }
                    self.rollback_allocation(&recurrent_layers, &allocated_slots);
                    return None;
                }
            } else {
                expected_slot = Some(slot_idx);
            }
            allocated_slots.push(slot_idx);
        }
        expected_slot
    }

    /// Free the slots handed out by a partially completed `allocate_seq`:
    /// `slots[i]` is returned to the pool at layer `recurrent_layers[i]`.
    fn rollback_allocation(&mut self, recurrent_layers: &[usize], slots: &[usize]) {
        for (&layer_idx, &slot_idx) in recurrent_layers.iter().zip(slots.iter()) {
            if let HybridLayerCache::Recurrent(pool) = &mut self.caches[layer_idx] {
                pool.free(slot_idx);
            }
        }
    }

    /// Release slot `slot_idx` in every recurrent layer.
    pub fn free_seq(&mut self, slot_idx: usize) {
        for cache in &mut self.caches {
            if let HybridLayerCache::Recurrent(pool) = cache {
                pool.free(slot_idx);
            }
        }
    }

    /// Zero slot `slot_idx`'s state in every recurrent layer without
    /// freeing it.
    pub fn reset_seq(&mut self, slot_idx: usize) -> Result<()> {
        for cache in &mut self.caches {
            if let HybridLayerCache::Recurrent(pool) = cache {
                pool.reset_slot(slot_idx)?;
            }
        }
        Ok(())
    }

    /// Reset every layer cache (attention and recurrent).
    pub fn reset(&mut self) {
        for cache in &mut self.caches {
            cache.reset();
        }
    }

    /// Number of layers covered by this cache.
    pub fn num_layers(&self) -> usize {
        self.caches.len()
    }

    /// Cache kind for each layer, in layer order.
    pub fn layer_types(&self) -> &[HybridLayerType] {
        &self.config.layer_types
    }

    /// The configuration this cache was built from.
    pub fn config(&self) -> &HybridCacheConfig {
        &self.config
    }

    /// Mutable access to the cache of `layer`, if it exists.
    pub fn get_mut(&mut self, layer: usize) -> Option<&mut HybridLayerCache> {
        self.caches.get_mut(layer)
    }

    /// Shared access to the cache of `layer`, if it exists.
    pub fn get(&self, layer: usize) -> Option<&HybridLayerCache> {
        self.caches.get(layer)
    }

    /// Set (or clear, with `None`) the batch-position -> pool-slot index
    /// tensor for the current forward pass.
    pub fn set_state_indices(&mut self, indices: Option<Tensor>) {
        self.state_indices = indices;
    }

    /// The current batch-position -> pool-slot index tensor, if any.
    pub fn state_indices(&self) -> Option<&Tensor> {
        self.state_indices.as_ref()
    }
}
impl PastKvLenCache for HybridCache {
fn get_past_kv_len(&self) -> Result<usize> {
for cache in &self.caches {
if let HybridLayerCache::Attention(kv) = cache {
return Ok(kv.current_seq_len());
}
}
Ok(0)
}
}
impl HybridCache {
    /// Truncate every attention layer's KV cache to `len` tokens; recurrent
    /// layers are left untouched.
    pub fn truncate_attention_to(&mut self, len: usize) -> Result<()> {
        self.caches
            .iter_mut()
            .filter_map(HybridLayerCache::as_kv_cache_mut)
            .try_for_each(|kv| kv.set_len(len))
    }
}
/// Owned copy of one recurrent layer's state for a single slot, produced
/// by `HybridCache::snapshot_recurrent_state`.
#[derive(Clone, Debug)]
pub struct RecurrentStateSnapshot {
    /// Conv state with a leading batch axis of 1: `(1, conv_dim, conv_width)`.
    pub conv_state: Tensor,
    /// Recurrent state with a leading batch axis of 1: `(1, *state_dims)`.
    pub recurrent_state: Tensor,
    /// The slot's token offset at snapshot time.
    pub seqlen_offset: usize,
}
impl HybridCache {
#[allow(clippy::cast_possible_truncation)]
pub fn snapshot_recurrent_state(&self, slot_idx: usize) -> Result<Vec<RecurrentStateSnapshot>> {
let mut snapshots = Vec::new();
for cache in &self.caches {
if let HybridLayerCache::Recurrent(pool) = cache {
let idx_tensor = Tensor::from_vec(vec![slot_idx as u32], (1,), pool.device())?;
let conv = pool.gather_conv_state(&idx_tensor)?;
let recurrent = pool.gather_recurrent_state(&idx_tensor)?;
snapshots.push(RecurrentStateSnapshot {
conv_state: conv,
recurrent_state: recurrent,
seqlen_offset: pool.get_seqlen_offset(slot_idx),
});
}
}
Ok(snapshots)
}
#[allow(clippy::cast_possible_truncation)]
pub fn restore_recurrent_state(
&mut self,
slot_idx: usize,
snapshots: &[RecurrentStateSnapshot],
) -> Result<()> {
let mut snap_iter = snapshots.iter();
for cache in &mut self.caches {
if let HybridLayerCache::Recurrent(pool) = cache {
if let Some(snap) = snap_iter.next() {
let conv = snap.conv_state.to_device(pool.device())?;
let recurrent = snap.recurrent_state.to_device(pool.device())?;
let idx_tensor = Tensor::from_vec(vec![slot_idx as u32], (1,), pool.device())?;
pool.scatter_conv_state(&idx_tensor, &conv)?;
pool.scatter_recurrent_state(&idx_tensor, &recurrent)?;
pool.set_seqlen_offset(slot_idx, snap.seqlen_offset);
}
}
}
Ok(())
}
}