kizzasi_inference/
engine.rs

//! Core inference engine
//!
//! The InferenceEngine coordinates tokenization, model forward pass,
//! and optional constraint enforcement.
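//!
//! A minimal usage sketch (the model construction is assumed to happen
//! elsewhere, so this block is not compiled as a doctest):
//!
//! ```ignore
//! // `model` is assumed to be some Box<dyn AutoregressiveModel>.
//! let config = EngineConfig::new(1, 1);
//! let mut engine = InferenceEngine::with_model(config, model);
//! let input = Array1::from_vec(vec![0.5]);
//! let next = engine.step(&input)?; // predicted x_{t+1}
//! ```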

use crate::context::{ContextConfig, InferenceContext};
use crate::error::{InferenceError, InferenceResult};
use crate::sampling::{Sampler, SamplingConfig};
use kizzasi_model::AutoregressiveModel;
use scirs2_core::ndarray::Array1;

/// Memory-efficient inference modes
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Default)]
pub enum InferenceMode {
    /// Standard mode: full precision and all states kept
    #[default]
    Standard,
    /// Low memory: aggressive state pruning, limited history
    LowMemory,
    /// Streaming mode: minimal state retention, optimized for real-time
    Streaming,
    /// Quantized: use reduced precision for states and activations
    Quantized,
}

/// Configuration for the inference engine
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EngineConfig {
    /// Input dimension
    pub input_dim: usize,
    /// Output dimension
    pub output_dim: usize,
    /// Context configuration
    pub context: ContextConfig,
    /// Whether to apply constraints
    pub apply_constraints: bool,
    /// Sampling configuration
    pub sampling: SamplingConfig,
    /// Whether to use embeddings (for discrete outputs)
    pub use_embeddings: bool,
    /// Inference mode for memory efficiency
    pub inference_mode: InferenceMode,
    /// State pruning threshold (used by LowMemory, Streaming, and Quantized modes)
    /// State entries whose absolute value falls below this threshold are zeroed out
    pub state_prune_threshold: f32,
    /// Maximum history length (for LowMemory/Streaming modes)
    pub max_history_length: Option<usize>,
}

impl Default for EngineConfig {
    fn default() -> Self {
        Self {
            input_dim: 1,
            output_dim: 1,
            context: ContextConfig::default(),
            apply_constraints: true,
            sampling: SamplingConfig::default(),
            use_embeddings: false,
            inference_mode: InferenceMode::Standard,
            state_prune_threshold: 1e-6,
            max_history_length: None,
        }
    }
}

impl EngineConfig {
    /// Create a new engine configuration
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        Self {
            input_dim,
            output_dim,
            ..Default::default()
        }
    }

    /// Set context configuration
    pub fn context(mut self, config: ContextConfig) -> Self {
        self.context = config;
        self
    }

    /// Enable/disable constraint enforcement
    pub fn apply_constraints(mut self, apply: bool) -> Self {
        self.apply_constraints = apply;
        self
    }

    /// Set sampling configuration
    pub fn sampling(mut self, sampling: SamplingConfig) -> Self {
        self.sampling = sampling;
        self
    }

    /// Enable embeddings for discrete outputs
    pub fn use_embeddings(mut self, use_emb: bool) -> Self {
        self.use_embeddings = use_emb;
        self
    }

    /// Set inference mode for memory efficiency
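    ///
    /// Selecting a mode auto-configures the pruning threshold and history
    /// limit; the values asserted below are the ones this builder applies:
    ///
    /// ```ignore
    /// let config = EngineConfig::new(1, 1).inference_mode(InferenceMode::Streaming);
    /// assert_eq!(config.max_history_length, Some(64));
    /// assert_eq!(config.state_prune_threshold, 1e-3);
    /// ```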
    pub fn inference_mode(mut self, mode: InferenceMode) -> Self {
        self.inference_mode = mode;
        // Auto-configure based on mode
        match mode {
            InferenceMode::LowMemory => {
                self.max_history_length = Some(128);
                self.state_prune_threshold = 1e-4;
            }
            InferenceMode::Streaming => {
                self.max_history_length = Some(64);
                self.state_prune_threshold = 1e-3;
            }
            InferenceMode::Quantized => {
                self.state_prune_threshold = 1e-2;
            }
            InferenceMode::Standard => {}
        }
        self
    }

    /// Set state pruning threshold
    pub fn state_prune_threshold(mut self, threshold: f32) -> Self {
        self.state_prune_threshold = threshold;
        self
    }

    /// Set maximum history length
    pub fn max_history_length(mut self, length: usize) -> Self {
        self.max_history_length = Some(length);
        self
    }
}

/// The main inference engine for AGSP
pub struct InferenceEngine {
    config: EngineConfig,
    context: InferenceContext,
    model: Option<Box<dyn AutoregressiveModel>>,
    sampler: Sampler,
    initialized: bool,
}

impl InferenceEngine {
    /// Create a new inference engine without a model
    pub fn new(config: EngineConfig) -> Self {
        let context = InferenceContext::new(config.context.clone());
        let sampler = Sampler::new(config.sampling.clone());
        Self {
            config,
            context,
            model: None,
            sampler,
            initialized: false,
        }
    }

    /// Create a new inference engine with a model
    pub fn with_model(mut config: EngineConfig, model: Box<dyn AutoregressiveModel>) -> Self {
        // Update context config to match model
        config.context.num_layers = model.num_layers();
        config.context.hidden_dim = model.hidden_dim();
        config.context.state_dim = model.state_dim();

        let context = InferenceContext::new(config.context.clone());
        let sampler = Sampler::new(config.sampling.clone());
        Self {
            config,
            context,
            model: Some(model),
            sampler,
            initialized: true,
        }
    }

    /// Set the model for this engine
    ///
    /// This will update the context configuration to match the model's architecture
    pub fn set_model(&mut self, model: Box<dyn AutoregressiveModel>) {
        // Update context config to match model
        self.config.context.num_layers = model.num_layers();
        self.config.context.hidden_dim = model.hidden_dim();
        self.config.context.state_dim = model.state_dim();

        // Recreate context with updated config
        self.context = InferenceContext::new(self.config.context.clone());

        self.model = Some(model);
        self.initialized = true;
    }

    /// Check if a model is loaded
    pub fn has_model(&self) -> bool {
        self.model.is_some()
    }

    /// Perform a single inference step
    ///
    /// This is the core autoregressive prediction:
    /// Given input x_t, predict x_{t+1}
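    ///
    /// A single-step sketch (assumes a model was already loaded via
    /// `with_model` or `set_model`; not compiled as a doctest):
    ///
    /// ```ignore
    /// let x_t = Array1::from_vec(vec![0.5]);
    /// let x_next = engine.step(&x_t)?; // prediction for x_{t+1}
    /// // Inputs whose length differs from `input_dim` are rejected:
    /// let bad = Array1::from_vec(vec![0.5, 0.5]);
    /// assert!(engine.step(&bad).is_err()); // DimensionMismatch
    /// ```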
    pub fn step(&mut self, input: &Array1<f32>) -> InferenceResult<Array1<f32>> {
        if !self.initialized {
            return Err(InferenceError::NotInitialized);
        }

        if input.len() != self.config.input_dim {
            return Err(InferenceError::DimensionMismatch {
                expected: self.config.input_dim,
                got: input.len(),
            });
        }

        // Store in context
        self.context.push(input.clone());

        // Run model forward pass if available
        let logits = if let Some(ref mut model) = self.model {
            // Set model states from context
            let states = self.context.states().to_vec();
            model
                .set_states(states)
                .map_err(|e| InferenceError::ForwardError(e.to_string()))?;

            // Forward pass through model (SignalPredictor trait)
            let output = model
                .step(input)
                .map_err(|e| InferenceError::ForwardError(e.to_string()))?;

            // Update context with new states
            let mut new_states = model.get_states();

            // Apply memory optimization based on inference mode
            self.apply_memory_optimization(&mut new_states);

            for (i, state) in new_states.into_iter().enumerate() {
                self.context.update_state(i, state)?;
            }

            output
        } else {
            // No model - return zeros as fallback
            Array1::zeros(self.config.output_dim)
        };

        // Apply sampling if configured
        let output = if self.config.use_embeddings {
            // For discrete outputs, sample from logits
            let sampled_idx = self.sampler.sample(&logits)?;
            Array1::from_elem(1, sampled_idx)
        } else {
            // For continuous outputs, optionally apply temperature scaling
            // (conventionally, logits are divided by the temperature)
            if (self.config.sampling.temperature - 1.0).abs() > 1e-6 {
                logits.mapv(|x| x / self.config.sampling.temperature)
            } else {
                logits
            }
        };

        Ok(output)
    }

    /// Perform multi-step rollout
    ///
    /// Predicts `steps` future values autoregressively, feeding each
    /// prediction back in as the next input.
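    ///
    /// ```ignore
    /// // Five autoregressive predictions from a single seed input.
    /// let outputs = engine.rollout(&input, 5)?;
    /// assert_eq!(outputs.len(), 5);
    /// ```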
    pub fn rollout(
        &mut self,
        input: &Array1<f32>,
        steps: usize,
    ) -> InferenceResult<Vec<Array1<f32>>> {
        let mut outputs = Vec::with_capacity(steps);
        let mut current = input.clone();

        for _ in 0..steps {
            let output = self.step(&current)?;
            outputs.push(output.clone());
            current = output;
        }

        Ok(outputs)
    }

    /// Reset the engine state
    pub fn reset(&mut self) {
        self.context.reset();
    }

    /// Get the current step count
    pub fn step_count(&self) -> usize {
        self.context.step_count()
    }

    /// Get the configuration
    pub fn config(&self) -> &EngineConfig {
        &self.config
    }

    /// Get the context
    pub fn context(&self) -> &InferenceContext {
        &self.context
    }

    /// Get mutable access to the sampler
    pub fn sampler_mut(&mut self) -> &mut Sampler {
        &mut self.sampler
    }

    /// Get the sampler
    pub fn sampler(&self) -> &Sampler {
        &self.sampler
    }

    /// Perform batched inference on multiple inputs
    ///
    /// Inputs are processed one after another through `step`, so they share
    /// the engine's context and hidden state rather than being processed in
    /// parallel with independent states.
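    ///
    /// ```ignore
    /// // One output per input, produced in order on the shared state.
    /// let outputs = engine.step_batch(&inputs)?;
    /// assert_eq!(outputs.len(), inputs.len());
    /// ```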
    pub fn step_batch(&mut self, inputs: &[Array1<f32>]) -> InferenceResult<Vec<Array1<f32>>> {
        if !self.initialized {
            return Err(InferenceError::NotInitialized);
        }

        let mut outputs = Vec::with_capacity(inputs.len());

        for input in inputs {
            let output = self.step(input)?;
            outputs.push(output);
        }

        Ok(outputs)
    }

    /// Get model information
    pub fn model_info(&self) -> Option<ModelInfo> {
        self.model.as_ref().map(|model| ModelInfo {
            model_type: model.model_type(),
            hidden_dim: model.hidden_dim(),
            state_dim: model.state_dim(),
            num_layers: model.num_layers(),
        })
    }

    /// Apply memory optimization based on inference mode
    fn apply_memory_optimization(&mut self, states: &mut [kizzasi_core::HiddenState]) {
        match self.config.inference_mode {
            InferenceMode::Standard => {
                // No optimization
            }
            InferenceMode::LowMemory | InferenceMode::Streaming => {
                // Prune small values from states
                self.prune_states(states);
                // Trim history if needed
                if let Some(max_len) = self.config.max_history_length {
                    if self.context.history().len() > max_len {
                        self.context.trim_history(max_len);
                    }
                }
            }
            InferenceMode::Quantized => {
                // Apply quantization to states
                self.quantize_states(states);
                // Also prune
                self.prune_states(states);
            }
        }
    }

    /// Prune states by zeroing out small values
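    ///
    /// With the LowMemory default of 1e-4, for example, any entry with
    /// |x| < 1e-4 is replaced by 0.0 while larger entries pass through.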
    fn prune_states(&self, states: &mut [kizzasi_core::HiddenState]) {
        let threshold = self.config.state_prune_threshold;
        for state in states.iter_mut() {
            let pruned = state
                .state()
                .mapv(|x| if x.abs() < threshold { 0.0 } else { x });
            state.update(pruned);
        }
    }

    /// Apply quantization to states (simulate reduced precision)
    fn quantize_states(&self, states: &mut [kizzasi_core::HiddenState]) {
        // Simple quantization: round to fixed precision
        // This simulates reduced-precision behavior without actually changing types
        for state in states.iter_mut() {
            let quantized = state.state().mapv(|x| {
                // Snap to multiples of 1/64, i.e. ~6 fractional bits
                // (coarser than FP16, whose mantissa carries 10 bits)
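                // e.g. 0.1234 -> (0.1234 * 64.0).round() / 64.0 = 8.0 / 64.0 = 0.125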
                let scale = 64.0;
                (x * scale).round() / scale
            });
            state.update(quantized);
        }
    }
}

/// Information about the loaded model
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub model_type: kizzasi_model::ModelType,
    pub hidden_dim: usize,
    pub state_dim: usize,
    pub num_layers: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_engine_creation() {
        let config = EngineConfig::new(3, 3);
        let engine = InferenceEngine::new(config);

        assert_eq!(engine.step_count(), 0);
        assert!(!engine.has_model());
    }

    #[test]
    fn test_engine_step_no_model() {
        let config = EngineConfig::new(3, 3);
        let mut engine = InferenceEngine::new(config);

        let input = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let output = engine.step(&input);

        // Should fail without model
        assert!(output.is_err());
    }

    #[test]
    fn test_engine_with_model() {
        use kizzasi_model::rwkv::{Rwkv, RwkvConfig};

        let model_config = RwkvConfig::new()
            .input_dim(1)
            .hidden_dim(64)
            .intermediate_dim(256)
            .num_layers(2);
        let model = Rwkv::new(model_config).unwrap();

        let config = EngineConfig::new(1, 10);
        let mut engine = InferenceEngine::with_model(config, Box::new(model));

        assert!(engine.has_model());

        let input = Array1::from_vec(vec![0.5]);
        let output = engine.step(&input);

        if let Err(e) = &output {
            eprintln!("Error: {:?}", e);
        }
        assert!(output.is_ok(), "Expected Ok, got: {:?}", output);
        assert_eq!(engine.step_count(), 1);
    }

    #[test]
    fn test_engine_rollout() {
        use kizzasi_model::s4::{S4Config, S4D};

        let model_config = S4Config::new()
            .input_dim(1)
            .hidden_dim(64)
            .state_dim(16)
            .num_layers(2)
            .diagonal(true);
        let model = S4D::new(model_config).unwrap();

        let config = EngineConfig::new(1, 10);
        let mut engine = InferenceEngine::with_model(config, Box::new(model));

        let input = Array1::from_vec(vec![0.5]);
        let outputs = engine.rollout(&input, 5);

        assert!(outputs.is_ok());
        assert_eq!(outputs.unwrap().len(), 5);
        assert_eq!(engine.step_count(), 5);
    }

    #[test]
    fn test_engine_batch() {
        use kizzasi_model::s4::{S4Config, S4D};

        let model_config = S4Config::new()
            .input_dim(1)
            .hidden_dim(64)
            .state_dim(16)
            .num_layers(2)
            .diagonal(true);
        let model = S4D::new(model_config).unwrap();

        let config = EngineConfig::new(1, 10);
        let mut engine = InferenceEngine::with_model(config, Box::new(model));

        let inputs = vec![
            Array1::from_vec(vec![0.1]),
            Array1::from_vec(vec![0.2]),
            Array1::from_vec(vec![0.3]),
        ];

        let outputs = engine.step_batch(&inputs);
        assert!(outputs.is_ok());
        assert_eq!(outputs.unwrap().len(), 3);
    }

    #[test]
    fn test_model_info() {
        use kizzasi_model::rwkv::{Rwkv, RwkvConfig};

        let model_config = RwkvConfig::new()
            .input_dim(1)
            .hidden_dim(128)
            .intermediate_dim(512)
            .num_layers(4);
        let model = Rwkv::new(model_config).unwrap();

        let config = EngineConfig::new(1, 50);
        let engine = InferenceEngine::with_model(config, Box::new(model));

        let info = engine.model_info();
        assert!(info.is_some());

        let info = info.unwrap();
        assert_eq!(info.hidden_dim, 128);
        assert_eq!(info.num_layers, 4);
    }
}