pocket-tts 0.6.2

High-performance CPU-based Text-to-Speech library using Candle
Documentation
//! Voice state management for streaming generation and voice cloning

use candle_core::{Result, Tensor};
use std::collections::HashMap;

/// Model state type for stateful modules
pub type ModelState = HashMap<String, HashMap<String, Tensor>>;

/// Common per-attention state keys.
pub const ATTN_POS_KEY: &str = "pos";
pub const ATTN_LEN_KEY: &str = "l";
pub const ATTN_HEAD_KEY: &str = "head";
pub const ATTN_K_BUF_KEY: &str = "k_buf";
pub const ATTN_V_BUF_KEY: &str = "v_buf";

/// Cursor/scalar metadata for attention cache state.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AttentionCursor {
    pub pos: usize,
    pub len: usize,
    pub head: usize,
}

/// Initialize empty model state for all stateful modules
///
/// Creates a nested HashMap structure that will be populated
/// as modules run their forward passes.
pub fn init_states(_batch_size: usize, _seq_len: usize) -> ModelState {
    // Start with empty state - modules will populate as needed
    HashMap::new()
}

/// Get or create a module's state entry
pub fn get_or_create_state<'a>(
    state: &'a mut ModelState,
    module_name: &str,
) -> &'a mut HashMap<String, Tensor> {
    state.entry(module_name.to_string()).or_default()
}

fn tensor_to_usize(t: &Tensor) -> Option<usize> {
    if let Ok(v) = t.to_scalar::<i64>() {
        return Some(v.max(0) as usize);
    }
    if let Ok(v) = t.to_scalar::<u32>() {
        return Some(v as usize);
    }
    None
}

/// Read attention cursor values from a module state map.
pub fn read_attention_cursor(module_state: &HashMap<String, Tensor>) -> AttentionCursor {
    AttentionCursor {
        pos: module_state
            .get(ATTN_POS_KEY)
            .and_then(tensor_to_usize)
            .unwrap_or(0),
        len: module_state
            .get(ATTN_LEN_KEY)
            .and_then(tensor_to_usize)
            .unwrap_or(0),
        head: module_state
            .get(ATTN_HEAD_KEY)
            .and_then(tensor_to_usize)
            .unwrap_or(0),
    }
}

/// Write attention cursor values into a module state map.
pub fn write_attention_cursor(
    module_state: &mut HashMap<String, Tensor>,
    cursor: AttentionCursor,
    device: &candle_core::Device,
) -> Result<()> {
    module_state.insert(
        ATTN_POS_KEY.to_string(),
        Tensor::new(cursor.pos as u32, device)?,
    );
    module_state.insert(
        ATTN_LEN_KEY.to_string(),
        Tensor::new(cursor.len as i64, device)?,
    );
    module_state.insert(
        ATTN_HEAD_KEY.to_string(),
        Tensor::new(cursor.head as i64, device)?,
    );
    Ok(())
}

/// Read attention cursor for a module name from a full model state.
pub fn get_attention_cursor(state: &ModelState, module_name: &str) -> AttentionCursor {
    state
        .get(module_name)
        .map(read_attention_cursor)
        .unwrap_or_default()
}

/// Increment step counters in model state for all modules
///
/// This is used after processing tokens to update position information
/// for streaming generation.
pub fn increment_steps(state: &mut ModelState, key: &str, increment: usize) {
    for (_module_name, module_state) in state.iter_mut() {
        if let Some(step_tensor) = module_state.get_mut(key)
            && let Ok(current) = step_tensor.to_scalar::<i64>()
            && let Ok(new_tensor) = Tensor::new(current + increment as i64, step_tensor.device())
        {
            *step_tensor = new_tensor;
        }
    }
}

/// Get the current step/offset for a module
pub fn get_offset(state: &ModelState, module_name: &str) -> usize {
    state
        .get(module_name)
        .and_then(|s| s.get("offset"))
        .and_then(|t| t.to_scalar::<i64>().ok())
        .unwrap_or(0) as usize
}

/// Set the offset for a module
pub fn set_offset(state: &mut ModelState, module_name: &str, offset: usize) -> Result<()> {
    let module_state = get_or_create_state(state, module_name);
    let device = module_state
        .values()
        .next()
        .map(|t| t.device().clone())
        .unwrap_or(candle_core::Device::Cpu);
    module_state.insert("offset".to_string(), Tensor::new(offset as i64, &device)?);
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_init_states() {
        let state = init_states(1, 100);
        assert!(state.is_empty());
    }

    #[test]
    fn test_get_or_create_state() {
        let mut state = init_states(1, 100);
        let module_state = get_or_create_state(&mut state, "test_module");
        assert!(module_state.is_empty());
        assert!(state.contains_key("test_module"));
    }

    #[test]
    fn test_offset_operations() -> Result<()> {
        let mut state = init_states(1, 100);

        // Initially offset is 0
        assert_eq!(get_offset(&state, "test"), 0);

        // Set offset
        set_offset(&mut state, "test", 42)?;
        assert_eq!(get_offset(&state, "test"), 42);

        Ok(())
    }
}