kizzasi_inference/
context.rs1use crate::error::{InferenceError, InferenceResult};
7use crate::pool::TensorPool;
8use kizzasi_core::HiddenState;
9use scirs2_core::ndarray::Array1;
10use std::collections::VecDeque;
11
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct ContextConfig {
15 pub max_context: usize,
17 pub store_history: bool,
19 pub num_layers: usize,
21 pub hidden_dim: usize,
23 pub state_dim: usize,
25}
26
27impl Default for ContextConfig {
28 fn default() -> Self {
29 Self {
30 max_context: 8192,
31 store_history: false,
32 num_layers: 4,
33 hidden_dim: 256,
34 state_dim: 16,
35 }
36 }
37}
38
39impl ContextConfig {
40 pub fn new() -> Self {
42 Self::default()
43 }
44
45 pub fn max_context(mut self, size: usize) -> Self {
47 self.max_context = size;
48 self
49 }
50
51 pub fn store_history(mut self, store: bool) -> Self {
53 self.store_history = store;
54 self
55 }
56
57 pub fn num_layers(mut self, n: usize) -> Self {
59 self.num_layers = n;
60 self
61 }
62}
63
64pub struct InferenceContext {
66 config: ContextConfig,
67 history: VecDeque<Array1<f32>>,
69 states: Vec<HiddenState>,
71 step_count: usize,
73 pool: Option<TensorPool>,
75}
76
77impl InferenceContext {
78 pub fn new(config: ContextConfig) -> Self {
80 let states = (0..config.num_layers)
81 .map(|_| HiddenState::new(config.hidden_dim, config.state_dim))
82 .collect();
83
84 Self {
85 config,
86 history: VecDeque::new(),
87 states,
88 step_count: 0,
89 pool: None,
90 }
91 }
92
93 pub fn with_pool(config: ContextConfig, pool: TensorPool) -> Self {
95 let states = (0..config.num_layers)
96 .map(|_| HiddenState::new(config.hidden_dim, config.state_dim))
97 .collect();
98
99 Self {
100 config,
101 history: VecDeque::new(),
102 states,
103 step_count: 0,
104 pool: Some(pool),
105 }
106 }
107
108 pub fn pool(&self) -> Option<&TensorPool> {
110 self.pool.as_ref()
111 }
112
113 pub fn enable_pooling(&mut self, pool: TensorPool) {
115 self.pool = Some(pool);
116 }
117
118 pub fn disable_pooling(&mut self) {
120 self.pool = None;
121 }
122
123 pub fn reset(&mut self) {
125 self.history.clear();
126 for state in &mut self.states {
127 state.reset();
128 }
129 self.step_count = 0;
130 }
131
132 pub fn push(&mut self, input: Array1<f32>) {
134 if self.config.store_history {
135 if self.history.len() >= self.config.max_context {
136 self.history.pop_front();
137 }
138 self.history.push_back(input);
139 }
140 self.step_count += 1;
141 }
142
143 pub fn step_count(&self) -> usize {
145 self.step_count
146 }
147
148 pub fn history_len(&self) -> usize {
150 self.history.len()
151 }
152
153 pub fn recent_history(&self, n: usize) -> Vec<&Array1<f32>> {
155 self.history.iter().rev().take(n).collect()
156 }
157
158 pub fn states(&self) -> &[HiddenState] {
160 &self.states
161 }
162
163 pub fn states_mut(&mut self) -> &mut [HiddenState] {
165 &mut self.states
166 }
167
168 pub fn update_state(&mut self, layer: usize, state: HiddenState) -> InferenceResult<()> {
170 if layer >= self.states.len() {
171 return Err(InferenceError::DimensionMismatch {
172 expected: self.states.len(),
173 got: layer + 1,
174 });
175 }
176 self.states[layer] = state;
177 Ok(())
178 }
179
180 pub fn config(&self) -> &ContextConfig {
182 &self.config
183 }
184
185 pub fn is_full(&self) -> bool {
187 self.step_count >= self.config.max_context
188 }
189
190 pub fn history(&self) -> &VecDeque<Array1<f32>> {
192 &self.history
193 }
194
195 pub fn trim_history(&mut self, max_len: usize) {
197 while self.history.len() > max_len {
198 self.history.pop_front();
199 }
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_creation() {
        let config = ContextConfig::new().max_context(100).num_layers(2);
        let ctx = InferenceContext::new(config);

        assert_eq!(ctx.step_count(), 0);
        assert_eq!(ctx.states().len(), 2);
    }

    #[test]
    fn test_context_push() {
        let config = ContextConfig::new().store_history(true).max_context(5);
        let mut ctx = InferenceContext::new(config);

        for i in 0..10 {
            ctx.push(Array1::from_vec(vec![i as f32]));
        }

        assert_eq!(ctx.step_count(), 10);
        assert_eq!(ctx.history_len(), 5);
        // Only the newest five inputs (5..10) survive eviction.
        assert_eq!(ctx.history()[0][0], 5.0);
        assert_eq!(ctx.history()[4][0], 9.0);
    }

    #[test]
    fn test_context_reset() {
        let config = ContextConfig::new().store_history(true);
        let mut ctx = InferenceContext::new(config);

        ctx.push(Array1::from_vec(vec![1.0]));
        ctx.push(Array1::from_vec(vec![2.0]));

        assert_eq!(ctx.step_count(), 2);

        ctx.reset();

        assert_eq!(ctx.step_count(), 0);
        assert_eq!(ctx.history_len(), 0);
    }

    #[test]
    fn test_push_without_history_only_counts() {
        let mut ctx = InferenceContext::new(ContextConfig::new().max_context(5));

        ctx.push(Array1::from_vec(vec![1.0]));
        ctx.push(Array1::from_vec(vec![2.0]));

        assert_eq!(ctx.step_count(), 2);
        assert_eq!(ctx.history_len(), 0);
    }

    #[test]
    fn test_recent_history_is_newest_first() {
        let config = ContextConfig::new().store_history(true).max_context(10);
        let mut ctx = InferenceContext::new(config);

        for i in 0..5 {
            ctx.push(Array1::from_vec(vec![i as f32]));
        }

        let recent = ctx.recent_history(2);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0][0], 4.0);
        assert_eq!(recent[1][0], 3.0);
        // Asking for more than is stored returns everything.
        assert_eq!(ctx.recent_history(100).len(), 5);
    }

    #[test]
    fn test_trim_history_keeps_newest() {
        let config = ContextConfig::new().store_history(true).max_context(10);
        let mut ctx = InferenceContext::new(config);

        for i in 0..5 {
            ctx.push(Array1::from_vec(vec![i as f32]));
        }

        ctx.trim_history(10);
        assert_eq!(ctx.history_len(), 5);

        ctx.trim_history(2);
        assert_eq!(ctx.history_len(), 2);
        assert_eq!(ctx.history()[0][0], 3.0);
        assert_eq!(ctx.history()[1][0], 4.0);
        // Trimming never touches the step counter.
        assert_eq!(ctx.step_count(), 5);
    }

    #[test]
    fn test_update_state_rejects_out_of_range_layer() {
        let config = ContextConfig::new().num_layers(2);
        let mut ctx = InferenceContext::new(config);
        let (hidden_dim, state_dim) = (ctx.config().hidden_dim, ctx.config().state_dim);

        assert!(ctx
            .update_state(1, HiddenState::new(hidden_dim, state_dim))
            .is_ok());
        assert!(ctx
            .update_state(2, HiddenState::new(hidden_dim, state_dim))
            .is_err());
    }

    #[test]
    fn test_is_full_tracks_step_count() {
        let mut ctx = InferenceContext::new(ContextConfig::new().max_context(3));

        ctx.push(Array1::from_vec(vec![0.0]));
        ctx.push(Array1::from_vec(vec![1.0]));
        assert!(!ctx.is_full());

        ctx.push(Array1::from_vec(vec![2.0]));
        assert!(ctx.is_full());

        ctx.reset();
        assert!(!ctx.is_full());
    }
}