use candle_core::{Result, Tensor};
use std::collections::HashMap;

/// Per-module state: maps a module name to that module's named state tensors.
pub type ModelState = HashMap<String, HashMap<String, Tensor>>;

/// Keys under which attention modules store their cursor fields and KV buffers.
pub const ATTN_POS_KEY: &str = "pos";
pub const ATTN_LEN_KEY: &str = "l";
pub const ATTN_HEAD_KEY: &str = "head";
pub const ATTN_K_BUF_KEY: &str = "k_buf";
pub const ATTN_V_BUF_KEY: &str = "v_buf";
/// Key under which `get_offset`/`set_offset` store the per-module offset.
pub const OFFSET_KEY: &str = "offset";
/// Scalar cursor describing an attention module's position in its cache.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AttentionCursor {
    pub pos: usize,
    pub len: usize,
    pub head: usize,
}
/// Creates an empty state map. Per-module entries are created lazily, so the
/// batch and sequence sizes are currently unused.
pub fn init_states(_batch_size: usize, _seq_len: usize) -> ModelState {
    HashMap::new()
}
/// Returns the state map for `module_name`, inserting an empty one if absent.
pub fn get_or_create_state<'a>(
    state: &'a mut ModelState,
    module_name: &str,
) -> &'a mut HashMap<String, Tensor> {
    state.entry(module_name.to_string()).or_default()
}
/// Extracts a non-negative scalar from a rank-0 tensor, accepting either
/// i64 (clamped at zero) or u32 storage.
fn tensor_to_usize(t: &Tensor) -> Option<usize> {
    if let Ok(v) = t.to_scalar::<i64>() {
        return Some(v.max(0) as usize);
    }
    if let Ok(v) = t.to_scalar::<u32>() {
        return Some(v as usize);
    }
    None
}
/// Decodes an `AttentionCursor` from a module's state map; missing or
/// malformed fields default to zero.
pub fn read_attention_cursor(module_state: &HashMap<String, Tensor>) -> AttentionCursor {
    AttentionCursor {
        pos: module_state
            .get(ATTN_POS_KEY)
            .and_then(tensor_to_usize)
            .unwrap_or(0),
        len: module_state
            .get(ATTN_LEN_KEY)
            .and_then(tensor_to_usize)
            .unwrap_or(0),
        head: module_state
            .get(ATTN_HEAD_KEY)
            .and_then(tensor_to_usize)
            .unwrap_or(0),
    }
}
/// Encodes `cursor` into `module_state` as rank-0 tensors. All three fields
/// are stored as i64 so that `increment_steps` (which reads i64) can advance
/// any of them; previously `pos` was written as u32, which `increment_steps`
/// could not read.
pub fn write_attention_cursor(
    module_state: &mut HashMap<String, Tensor>,
    cursor: AttentionCursor,
    device: &candle_core::Device,
) -> Result<()> {
    module_state.insert(
        ATTN_POS_KEY.to_string(),
        Tensor::new(cursor.pos as i64, device)?,
    );
    module_state.insert(
        ATTN_LEN_KEY.to_string(),
        Tensor::new(cursor.len as i64, device)?,
    );
    module_state.insert(
        ATTN_HEAD_KEY.to_string(),
        Tensor::new(cursor.head as i64, device)?,
    );
    Ok(())
}
/// Reads the cursor for `module_name`, or the all-zero default if the module
/// has no state yet.
pub fn get_attention_cursor(state: &ModelState, module_name: &str) -> AttentionCursor {
    state
        .get(module_name)
        .map(read_attention_cursor)
        .unwrap_or_default()
}
/// Adds `increment` to the i64 scalar stored under `key` in every module's
/// state; modules without the key (or with a non-i64 value) are skipped.
pub fn increment_steps(state: &mut ModelState, key: &str, increment: usize) {
    for module_state in state.values_mut() {
        if let Some(step_tensor) = module_state.get_mut(key)
            && let Ok(current) = step_tensor.to_scalar::<i64>()
            && let Ok(new_tensor) = Tensor::new(current + increment as i64, step_tensor.device())
        {
            *step_tensor = new_tensor;
        }
    }
}
/// Reads the `OFFSET_KEY` scalar for `module_name`, clamping negative values
/// and defaulting to 0. Without the clamp, a negative i64 would wrap to a
/// huge usize.
pub fn get_offset(state: &ModelState, module_name: &str) -> usize {
    state
        .get(module_name)
        .and_then(|s| s.get(OFFSET_KEY))
        .and_then(|t| t.to_scalar::<i64>().ok())
        .unwrap_or(0)
        .max(0) as usize
}
/// Writes `offset` for `module_name`, placing the scalar on the same device
/// as the module's existing tensors (falling back to CPU for a fresh module).
pub fn set_offset(state: &mut ModelState, module_name: &str, offset: usize) -> Result<()> {
    let module_state = get_or_create_state(state, module_name);
    let device = module_state
        .values()
        .next()
        .map(|t| t.device().clone())
        .unwrap_or(candle_core::Device::Cpu);
    module_state.insert(OFFSET_KEY.to_string(), Tensor::new(offset as i64, &device)?);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_init_states() {
        let state = init_states(1, 100);
        assert!(state.is_empty());
    }

    #[test]
    fn test_get_or_create_state() {
        let mut state = init_states(1, 100);
        let module_state = get_or_create_state(&mut state, "test_module");
        assert!(module_state.is_empty());
        assert!(state.contains_key("test_module"));
    }

    #[test]
    fn test_offset_operations() -> Result<()> {
        let mut state = init_states(1, 100);
        assert_eq!(get_offset(&state, "test"), 0);
        set_offset(&mut state, "test", 42)?;
        assert_eq!(get_offset(&state, "test"), 42);
        Ok(())
    }
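
    // Added sketch: round-trips a cursor through a module's state via the
    // write/read helpers. Assumes CPU-resident rank-0 tensors suffice for
    // the check; the module name "attn" is illustrative.
    #[test]
    fn test_attention_cursor_roundtrip() -> Result<()> {
        let mut state = init_states(1, 100);
        let device = candle_core::Device::Cpu;
        // An absent module yields the all-zero default cursor.
        assert_eq!(get_attention_cursor(&state, "attn"), AttentionCursor::default());
        let cursor = AttentionCursor { pos: 3, len: 7, head: 2 };
        write_attention_cursor(get_or_create_state(&mut state, "attn"), cursor, &device)?;
        assert_eq!(get_attention_cursor(&state, "attn"), cursor);
        Ok(())
    }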
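
    // Added sketch: `increment_steps` advances the i64 offset scalar that
    // `set_offset` writes, across every module holding the key. Module names
    // "a" and "b" are illustrative.
    #[test]
    fn test_increment_steps() -> Result<()> {
        let mut state = init_states(1, 100);
        set_offset(&mut state, "a", 10)?;
        set_offset(&mut state, "b", 20)?;
        increment_steps(&mut state, OFFSET_KEY, 5);
        assert_eq!(get_offset(&state, "a"), 15);
        assert_eq!(get_offset(&state, "b"), 25);
        // A key no module holds is a no-op.
        increment_steps(&mut state, "missing", 5);
        assert_eq!(get_offset(&state, "a"), 15);
        Ok(())
    }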
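
    // Added sketch: stashes KV-cache buffers under the exported key names.
    // The (batch, heads, dim) shape and dtype here are illustrative
    // assumptions, not prescribed by this module.
    #[test]
    fn test_kv_buffer_storage() -> Result<()> {
        let mut state = init_states(1, 100);
        let device = candle_core::Device::Cpu;
        let k = Tensor::zeros((1, 4, 16), candle_core::DType::F32, &device)?;
        let v = Tensor::zeros((1, 4, 16), candle_core::DType::F32, &device)?;
        let module_state = get_or_create_state(&mut state, "attn");
        module_state.insert(ATTN_K_BUF_KEY.to_string(), k);
        module_state.insert(ATTN_V_BUF_KEY.to_string(), v);
        assert_eq!(state["attn"][ATTN_K_BUF_KEY].dims(), &[1, 4, 16]);
        Ok(())
    }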
}