use crate::context::{ContextConfig, InferenceContext};
use crate::error::{InferenceError, InferenceResult};
use crate::sampling::{Sampler, SamplingConfig};
use kizzasi_model::AutoregressiveModel;
use scirs2_core::ndarray::Array1;

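/// Execution mode that selects how aggressively the engine trades precision and
/// history length for lower memory use during inference.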
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Default)]
pub enum InferenceMode {
    /// Full-precision inference with no pruning or history trimming.
    #[default]
    Standard,
    /// Prunes small state values and trims the context history each step.
    LowMemory,
    /// Like `LowMemory`, but with more aggressive defaults intended for long-running streams.
    Streaming,
    /// Quantizes hidden states to a coarse grid in addition to pruning them.
    Quantized,
}

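/// Configuration for an [`InferenceEngine`]: I/O dimensions, context and sampling
/// settings, and the memory-optimization knobs applied per [`InferenceMode`].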
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EngineConfig {
    /// Dimensionality of each input vector passed to `step`.
    pub input_dim: usize,
    /// Dimensionality of the engine's output vectors.
    pub output_dim: usize,
    /// Configuration forwarded to the [`InferenceContext`].
    pub context: ContextConfig,
    /// Whether output constraints should be applied.
    pub apply_constraints: bool,
    /// Sampling configuration used for embedding outputs and temperature scaling.
    pub sampling: SamplingConfig,
    /// When true, the output is a single sampled index instead of the raw logits.
    pub use_embeddings: bool,
    /// Memory-optimization mode applied after each model step.
    pub inference_mode: InferenceMode,
    /// Magnitude below which hidden-state values are zeroed when pruning.
    pub state_prune_threshold: f32,
    /// Optional cap on the context history length.
    pub max_history_length: Option<usize>,
}

impl Default for EngineConfig {
    fn default() -> Self {
        Self {
            input_dim: 1,
            output_dim: 1,
            context: ContextConfig::default(),
            apply_constraints: true,
            sampling: SamplingConfig::default(),
            use_embeddings: false,
            inference_mode: InferenceMode::Standard,
            state_prune_threshold: 1e-6,
            max_history_length: None,
        }
    }
}

impl EngineConfig {
    /// Creates a config with the given input/output dimensions and defaults elsewhere.
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        Self {
            input_dim,
            output_dim,
            ..Default::default()
        }
    }

    pub fn context(mut self, config: ContextConfig) -> Self {
        self.context = config;
        self
    }

    pub fn apply_constraints(mut self, apply: bool) -> Self {
        self.apply_constraints = apply;
        self
    }

    pub fn sampling(mut self, sampling: SamplingConfig) -> Self {
        self.sampling = sampling;
        self
    }

    pub fn use_embeddings(mut self, use_emb: bool) -> Self {
        self.use_embeddings = use_emb;
        self
    }

    /// Sets the inference mode and installs mode-specific defaults for the pruning
    /// threshold and history cap where applicable; `Standard` leaves them untouched.
    pub fn inference_mode(mut self, mode: InferenceMode) -> Self {
        self.inference_mode = mode;
        match mode {
            InferenceMode::LowMemory => {
                self.max_history_length = Some(128);
                self.state_prune_threshold = 1e-4;
            }
            InferenceMode::Streaming => {
                self.max_history_length = Some(64);
                self.state_prune_threshold = 1e-3;
            }
            InferenceMode::Quantized => {
                self.state_prune_threshold = 1e-2;
            }
            InferenceMode::Standard => {}
        }
        self
    }

    pub fn state_prune_threshold(mut self, threshold: f32) -> Self {
        self.state_prune_threshold = threshold;
        self
    }

    pub fn max_history_length(mut self, length: usize) -> Self {
        self.max_history_length = Some(length);
        self
    }
}

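/// Autoregressive inference engine that drives an [`AutoregressiveModel`] one step
/// at a time while keeping hidden states and history in an [`InferenceContext`].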
pub struct InferenceEngine {
    config: EngineConfig,
    context: InferenceContext,
    model: Option<Box<dyn AutoregressiveModel>>,
    sampler: Sampler,
    initialized: bool,
}

impl InferenceEngine {
    /// Creates an engine without a model; `step` returns
    /// [`InferenceError::NotInitialized`] until a model is attached via `set_model`.
    pub fn new(config: EngineConfig) -> Self {
        let context = InferenceContext::new(config.context.clone());
        let sampler = Sampler::new(config.sampling.clone());
        Self {
            config,
            context,
            model: None,
            sampler,
            initialized: false,
        }
    }

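    /// Creates an engine around an existing model, copying the model's layer,
    /// hidden, and state dimensions into the context configuration.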
    pub fn with_model(mut config: EngineConfig, model: Box<dyn AutoregressiveModel>) -> Self {
        config.context.num_layers = model.num_layers();
        config.context.hidden_dim = model.hidden_dim();
        config.context.state_dim = model.state_dim();

        let context = InferenceContext::new(config.context.clone());
        let sampler = Sampler::new(config.sampling.clone());
        Self {
            config,
            context,
            model: Some(model),
            sampler,
            initialized: true,
        }
    }

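    /// Attaches a model to the engine, rebuilding the inference context so that it
    /// matches the model's layer, hidden, and state dimensions.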
    pub fn set_model(&mut self, model: Box<dyn AutoregressiveModel>) {
        self.config.context.num_layers = model.num_layers();
        self.config.context.hidden_dim = model.hidden_dim();
        self.config.context.state_dim = model.state_dim();

        self.context = InferenceContext::new(self.config.context.clone());

        self.model = Some(model);
        self.initialized = true;
    }

    /// Returns `true` if a model is attached.
    pub fn has_model(&self) -> bool {
        self.model.is_some()
    }

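    /// Runs a single autoregressive step: validates the input dimension, pushes the
    /// input into the context, advances the model with the current hidden states,
    /// applies the configured memory optimization to the new states, and finally
    /// post-processes the output according to the sampling configuration.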
    pub fn step(&mut self, input: &Array1<f32>) -> InferenceResult<Array1<f32>> {
        if !self.initialized {
            return Err(InferenceError::NotInitialized);
        }

        if input.len() != self.config.input_dim {
            return Err(InferenceError::DimensionMismatch {
                expected: self.config.input_dim,
                got: input.len(),
            });
        }

        // Record the input in the context history.
        self.context.push(input.clone());

        let logits = if let Some(ref mut model) = self.model {
            // Restore the model's hidden states from the context before stepping.
            let states = self.context.states().to_vec();
            model
                .set_states(states)
                .map_err(|e| InferenceError::ForwardError(e.to_string()))?;

            let output = model
                .step(input)
                .map_err(|e| InferenceError::ForwardError(e.to_string()))?;

            // Pull the updated states back out, optimize them, and write them
            // into the context for the next step.
            let mut new_states = model.get_states();

            self.apply_memory_optimization(&mut new_states);

            for (i, state) in new_states.into_iter().enumerate() {
                self.context.update_state(i, state)?;
            }

            output
        } else {
            Array1::zeros(self.config.output_dim)
        };

        let output = if self.config.use_embeddings {
            // Embedding mode: sample a single index from the logits.
            let sampled_idx = self.sampler.sample(&logits)?;
            Array1::from_elem(1, sampled_idx)
        } else if (self.config.sampling.temperature - 1.0).abs() > 1e-6 {
            // Continuous mode: scale the raw output by the configured temperature.
            logits.mapv(|x| x * self.config.sampling.temperature)
        } else {
            logits
        };

        Ok(output)
    }

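    /// Autoregressively generates `steps` outputs, feeding each output back in as
    /// the next input.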
    pub fn rollout(
        &mut self,
        input: &Array1<f32>,
        steps: usize,
    ) -> InferenceResult<Vec<Array1<f32>>> {
        let mut outputs = Vec::with_capacity(steps);
        let mut current = input.clone();

        for _ in 0..steps {
            let output = self.step(&current)?;
            outputs.push(output.clone());
            current = output;
        }

        Ok(outputs)
    }

    pub fn reset(&mut self) {
        self.context.reset();
    }

    pub fn step_count(&self) -> usize {
        self.context.step_count()
    }

    pub fn config(&self) -> &EngineConfig {
        &self.config
    }

    pub fn context(&self) -> &InferenceContext {
        &self.context
    }

    pub fn sampler_mut(&mut self) -> &mut Sampler {
        &mut self.sampler
    }

    pub fn sampler(&self) -> &Sampler {
        &self.sampler
    }

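    /// Runs `step` sequentially over a slice of inputs, sharing the same evolving
    /// context, and returns the outputs in order.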
    pub fn step_batch(&mut self, inputs: &[Array1<f32>]) -> InferenceResult<Vec<Array1<f32>>> {
        if !self.initialized {
            return Err(InferenceError::NotInitialized);
        }

        let mut outputs = Vec::with_capacity(inputs.len());

        for input in inputs {
            let output = self.step(input)?;
            outputs.push(output);
        }

        Ok(outputs)
    }

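    /// Returns basic information about the attached model, if any.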
    pub fn model_info(&self) -> Option<ModelInfo> {
        self.model.as_ref().map(|model| ModelInfo {
            model_type: model.model_type(),
            hidden_dim: model.hidden_dim(),
            state_dim: model.state_dim(),
            num_layers: model.num_layers(),
        })
    }

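    /// Applies the configured memory optimization (state pruning, history trimming,
    /// and/or quantization) to the hidden states produced by the latest step.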
    fn apply_memory_optimization(&mut self, states: &mut [kizzasi_core::HiddenState]) {
        match self.config.inference_mode {
            InferenceMode::Standard => {
                // Standard mode keeps full precision and unbounded history.
            }
            InferenceMode::LowMemory | InferenceMode::Streaming => {
                self.prune_states(states);
                if let Some(max_len) = self.config.max_history_length {
                    if self.context.history().len() > max_len {
                        self.context.trim_history(max_len);
                    }
                }
            }
            InferenceMode::Quantized => {
                self.quantize_states(states);
                self.prune_states(states);
            }
        }
    }

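    /// Zeroes state entries whose magnitude falls below `state_prune_threshold`.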
    fn prune_states(&self, states: &mut [kizzasi_core::HiddenState]) {
        let threshold = self.config.state_prune_threshold;
        for state in states.iter_mut() {
            let pruned = state
                .state()
                .mapv(|x| if x.abs() < threshold { 0.0 } else { x });
            state.update(pruned);
        }
    }

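    /// Rounds state values to a fixed 1/64 grid to reduce their effective precision.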
    fn quantize_states(&self, states: &mut [kizzasi_core::HiddenState]) {
        for state in states.iter_mut() {
            let quantized = state.state().mapv(|x| {
                // Fixed-point style rounding: 64 levels per unit.
                let scale = 64.0;
                (x * scale).round() / scale
            });
            state.update(quantized);
        }
    }
}

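/// Summary of the attached model as reported by [`InferenceEngine::model_info`].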
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub model_type: kizzasi_model::ModelType,
    pub hidden_dim: usize,
    pub state_dim: usize,
    pub num_layers: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_engine_creation() {
        let config = EngineConfig::new(3, 3);
        let engine = InferenceEngine::new(config);

        assert_eq!(engine.step_count(), 0);
        assert!(!engine.has_model());
    }

    #[test]
    fn test_engine_step_no_model() {
        let config = EngineConfig::new(3, 3);
        let mut engine = InferenceEngine::new(config);

        let input = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let output = engine.step(&input);

        // No model attached, so the engine is uninitialized and must refuse to step.
        assert!(output.is_err());
    }

    #[test]
    fn test_engine_with_model() {
        use kizzasi_model::rwkv::{Rwkv, RwkvConfig};

        let model_config = RwkvConfig::new()
            .input_dim(1)
            .hidden_dim(64)
            .intermediate_dim(256)
            .num_layers(2);
        let model = Rwkv::new(model_config).unwrap();

        let config = EngineConfig::new(1, 10);
        let mut engine = InferenceEngine::with_model(config, Box::new(model));

        assert!(engine.has_model());

        let input = Array1::from_vec(vec![0.5]);
        let output = engine.step(&input);

        if let Err(e) = &output {
            eprintln!("Error: {:?}", e);
        }
        assert!(output.is_ok(), "Expected Ok, got: {:?}", output);
        assert_eq!(engine.step_count(), 1);
    }

    #[test]
    fn test_engine_rollout() {
        use kizzasi_model::s4::{S4Config, S4D};

        let model_config = S4Config::new()
            .input_dim(1)
            .hidden_dim(64)
            .state_dim(16)
            .num_layers(2)
            .diagonal(true);
        let model = S4D::new(model_config).unwrap();

        let config = EngineConfig::new(1, 10);
        let mut engine = InferenceEngine::with_model(config, Box::new(model));

        let input = Array1::from_vec(vec![0.5]);
        let outputs = engine.rollout(&input, 5);

        assert!(outputs.is_ok());
        assert_eq!(outputs.unwrap().len(), 5);
        assert_eq!(engine.step_count(), 5);
    }

    #[test]
    fn test_engine_batch() {
        use kizzasi_model::s4::{S4Config, S4D};

        let model_config = S4Config::new()
            .input_dim(1)
            .hidden_dim(64)
            .state_dim(16)
            .num_layers(2)
            .diagonal(true);
        let model = S4D::new(model_config).unwrap();

        let config = EngineConfig::new(1, 10);
        let mut engine = InferenceEngine::with_model(config, Box::new(model));

        let inputs = vec![
            Array1::from_vec(vec![0.1]),
            Array1::from_vec(vec![0.2]),
            Array1::from_vec(vec![0.3]),
        ];

        let outputs = engine.step_batch(&inputs);
        assert!(outputs.is_ok());
        assert_eq!(outputs.unwrap().len(), 3);
    }

    #[test]
    fn test_model_info() {
        use kizzasi_model::rwkv::{Rwkv, RwkvConfig};

        let model_config = RwkvConfig::new()
            .input_dim(1)
            .hidden_dim(128)
            .intermediate_dim(512)
            .num_layers(4);
        let model = Rwkv::new(model_config).unwrap();

        let config = EngineConfig::new(1, 50);
        let engine = InferenceEngine::with_model(config, Box::new(model));

        let info = engine.model_info();
        assert!(info.is_some());

        let info = info.unwrap();
        assert_eq!(info.hidden_dim, 128);
        assert_eq!(info.num_layers, 4);
    }
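
    // A small sanity check on the builder itself (no model backend needed): the
    // `inference_mode` setter installs the mode-specific defaults defined in
    // `EngineConfig::inference_mode` above.
    #[test]
    fn test_inference_mode_defaults() {
        let config = EngineConfig::new(1, 1).inference_mode(InferenceMode::LowMemory);
        assert_eq!(config.inference_mode, InferenceMode::LowMemory);
        assert_eq!(config.max_history_length, Some(128));
        assert!((config.state_prune_threshold - 1e-4).abs() < f32::EPSILON);

        let config = EngineConfig::new(1, 1).inference_mode(InferenceMode::Streaming);
        assert_eq!(config.max_history_length, Some(64));

        let config = EngineConfig::new(1, 1).inference_mode(InferenceMode::Standard);
        assert_eq!(config.max_history_length, None);
    }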
}