kizzasi_inference/context.rs

//! Inference context management
//!
//! Manages the sliding window of past inputs and hidden states
//! for autoregressive prediction.
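//!
//! # Example
//!
//! A minimal usage sketch; the `kizzasi_inference::context` import path is an
//! assumption about how the crate re-exports this module, and the block is not
//! compiled as a doctest:
//!
//! ```ignore
//! use kizzasi_inference::context::{ContextConfig, InferenceContext};
//! use scirs2_core::ndarray::Array1;
//!
//! let config = ContextConfig::new().max_context(1024).store_history(true);
//! let mut ctx = InferenceContext::new(config);
//! ctx.push(Array1::from_vec(vec![0.5_f32]));
//! assert_eq!(ctx.step_count(), 1);
//! ```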

use crate::error::{InferenceError, InferenceResult};
use crate::pool::TensorPool;
use kizzasi_core::HiddenState;
use scirs2_core::ndarray::Array1;
use std::collections::VecDeque;

/// Configuration for inference context
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextConfig {
    /// Maximum context window size
    pub max_context: usize,
    /// Whether to store full history or just hidden states
    pub store_history: bool,
    /// Number of model layers (for state management)
    pub num_layers: usize,
    /// Hidden state dimension
    pub hidden_dim: usize,
    /// State dimension (for SSMs)
    pub state_dim: usize,
}

impl Default for ContextConfig {
    fn default() -> Self {
        Self {
            max_context: 8192,
            store_history: false,
            num_layers: 4,
            hidden_dim: 256,
            state_dim: 16,
        }
    }
}

impl ContextConfig {
    /// Create a new context configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Set maximum context size
    pub fn max_context(mut self, size: usize) -> Self {
        self.max_context = size;
        self
    }

    /// Enable/disable history storage
    pub fn store_history(mut self, store: bool) -> Self {
        self.store_history = store;
        self
    }

    /// Set number of layers
    pub fn num_layers(mut self, n: usize) -> Self {
        self.num_layers = n;
        self
    }
}

/// Manages inference context including history and hidden states
pub struct InferenceContext {
    config: ContextConfig,
    /// History of past inputs (if `store_history` is true)
    history: VecDeque<Array1<f32>>,
    /// Hidden states for each layer
    states: Vec<HiddenState>,
    /// Number of steps processed
    step_count: usize,
    /// Optional memory pool for efficient allocation
    pool: Option<TensorPool>,
}

impl InferenceContext {
    /// Create a new inference context
    pub fn new(config: ContextConfig) -> Self {
        let states = (0..config.num_layers)
            .map(|_| HiddenState::new(config.hidden_dim, config.state_dim))
            .collect();

        Self {
            config,
            history: VecDeque::new(),
            states,
            step_count: 0,
            pool: None,
        }
    }

    /// Create a new inference context with memory pooling enabled
    pub fn with_pool(config: ContextConfig, pool: TensorPool) -> Self {
        let mut ctx = Self::new(config);
        ctx.pool = Some(pool);
        ctx
    }

    /// Get a reference to the memory pool, if one is set
    pub fn pool(&self) -> Option<&TensorPool> {
        self.pool.as_ref()
    }

    /// Enable memory pooling with the specified pool
    pub fn enable_pooling(&mut self, pool: TensorPool) {
        self.pool = Some(pool);
    }

    /// Disable memory pooling
    pub fn disable_pooling(&mut self) {
        self.pool = None;
    }

    /// Reset the context to its initial state
    pub fn reset(&mut self) {
        self.history.clear();
        for state in &mut self.states {
            state.reset();
        }
        self.step_count = 0;
    }

    /// Record an input step, storing it in the history (and evicting the
    /// oldest entry at capacity) when `store_history` is enabled; the step
    /// count is incremented either way.
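    ///
    /// A sliding-window sketch with hypothetical values (not compiled as a
    /// doctest):
    ///
    /// ```ignore
    /// let config = ContextConfig::new().store_history(true).max_context(2);
    /// let mut ctx = InferenceContext::new(config);
    /// for v in [1.0_f32, 2.0, 3.0] {
    ///     ctx.push(Array1::from_vec(vec![v]));
    /// }
    /// assert_eq!(ctx.history_len(), 2); // 1.0 was evicted
    /// assert_eq!(ctx.step_count(), 3); // step count is never capped
    /// ```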
    pub fn push(&mut self, input: Array1<f32>) {
        if self.config.store_history {
            if self.history.len() >= self.config.max_context {
                self.history.pop_front();
            }
            self.history.push_back(input);
        }
        self.step_count += 1;
    }

    /// Get the current step count
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Get the history length
    pub fn history_len(&self) -> usize {
        self.history.len()
    }

    /// Get the `n` most recent inputs, most recent first
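    ///
    /// Ordering sketch with hypothetical values (not compiled as a doctest):
    ///
    /// ```ignore
    /// let config = ContextConfig::new().store_history(true).max_context(4);
    /// let mut ctx = InferenceContext::new(config);
    /// ctx.push(Array1::from_vec(vec![1.0]));
    /// ctx.push(Array1::from_vec(vec![2.0]));
    /// let recent = ctx.recent_history(2);
    /// assert_eq!(recent[0][0], 2.0); // newest entry comes first
    /// ```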
    pub fn recent_history(&self, n: usize) -> Vec<&Array1<f32>> {
        self.history.iter().rev().take(n).collect()
    }

    /// Get hidden states
    pub fn states(&self) -> &[HiddenState] {
        &self.states
    }

    /// Get mutable hidden states
    pub fn states_mut(&mut self) -> &mut [HiddenState] {
        &mut self.states
    }

    /// Replace the hidden state for `layer`, failing if the index is out of range
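    ///
    /// Error-path sketch; the dimensions match the defaults above (not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let mut ctx = InferenceContext::new(ContextConfig::new()); // 4 layers by default
    /// let state = HiddenState::new(256, 16); // default hidden_dim / state_dim
    /// assert!(ctx.update_state(99, state).is_err()); // layer index out of range
    /// ```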
    pub fn update_state(&mut self, layer: usize, state: HiddenState) -> InferenceResult<()> {
        if layer >= self.states.len() {
            // An out-of-range layer index is reported via the
            // dimension-mismatch variant
            return Err(InferenceError::DimensionMismatch {
                expected: self.states.len(),
                got: layer + 1,
            });
        }
        self.states[layer] = state;
        Ok(())
    }

    /// Get configuration
    pub fn config(&self) -> &ContextConfig {
        &self.config
    }

    /// Check whether the number of processed steps has reached `max_context`
    pub fn is_full(&self) -> bool {
        self.step_count >= self.config.max_context
    }

    /// Get the full history
    pub fn history(&self) -> &VecDeque<Array1<f32>> {
        &self.history
    }

    /// Trim history to the specified length (for memory efficiency)
    pub fn trim_history(&mut self, max_len: usize) {
        while self.history.len() > max_len {
            self.history.pop_front();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_creation() {
        let config = ContextConfig::new().max_context(100).num_layers(2);
        let ctx = InferenceContext::new(config);

        assert_eq!(ctx.step_count(), 0);
        assert_eq!(ctx.states().len(), 2);
    }

    #[test]
    fn test_context_push() {
        let config = ContextConfig::new().store_history(true).max_context(5);
        let mut ctx = InferenceContext::new(config);

        for i in 0..10 {
            ctx.push(Array1::from_vec(vec![i as f32]));
        }

        assert_eq!(ctx.step_count(), 10);
        assert_eq!(ctx.history_len(), 5); // Max context
    }

    #[test]
    fn test_context_reset() {
        let config = ContextConfig::new().store_history(true);
        let mut ctx = InferenceContext::new(config);

        ctx.push(Array1::from_vec(vec![1.0]));
        ctx.push(Array1::from_vec(vec![2.0]));

        assert_eq!(ctx.step_count(), 2);

        ctx.reset();

        assert_eq!(ctx.step_count(), 0);
        assert_eq!(ctx.history_len(), 0);
    }
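
    // Additional sketches beyond the original tests; the values are
    // illustrative and chosen to match the defaults above.
    #[test]
    fn test_trim_history() {
        let config = ContextConfig::new().store_history(true).max_context(10);
        let mut ctx = InferenceContext::new(config);

        for i in 0..8 {
            ctx.push(Array1::from_vec(vec![i as f32]));
        }
        ctx.trim_history(3);

        assert_eq!(ctx.history_len(), 3);
    }

    #[test]
    fn test_update_state_out_of_range() {
        let config = ContextConfig::new().num_layers(2);
        let mut ctx = InferenceContext::new(config);

        // Default dims: hidden_dim = 256, state_dim = 16
        let state = HiddenState::new(256, 16);
        assert!(ctx.update_state(5, state).is_err());
    }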
}