Skip to main content

simular/domains/ml/prediction.rs

1use serde::{Deserialize, Serialize};
2
3use crate::engine::rng::SimRng;
4use crate::error::SimResult;
5
// ============================================================================
// Prediction Simulation
// ============================================================================

10/// Prediction state for replay and analysis.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PredictionState {
13    /// Input features.
14    pub input: Vec<f64>,
15    /// Model output.
16    pub output: Vec<f64>,
17    /// Uncertainty estimate (if available).
18    pub uncertainty: Option<f64>,
19    /// Inference latency in microseconds (simulated).
20    pub latency_us: u64,
21    /// Sequence number.
22    pub sequence: u64,
23}
24
25/// Inference configuration.
26#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
27pub struct InferenceConfig {
28    /// Batch size for inference.
29    pub batch_size: usize,
30    /// Temperature for probabilistic outputs.
31    pub temperature: f64,
32    /// Top-k sampling (0 = greedy).
33    pub top_k: usize,
34    /// Enable uncertainty quantification.
35    pub uncertainty: bool,
36    /// Simulated latency base (microseconds).
37    pub base_latency_us: u64,
38}
39
40impl Default for InferenceConfig {
41    fn default() -> Self {
42        Self {
43            batch_size: 32,
44            temperature: 1.0,
45            top_k: 0,
46            uncertainty: false,
47            base_latency_us: 1000,
48        }
49    }
50}
51
52/// Simulated inference scenario for reproducible prediction testing.
53pub struct PredictionSimulation {
54    /// Inference configuration.
55    config: InferenceConfig,
56    /// Deterministic RNG for stochastic models.
57    rng: SimRng,
58    /// Prediction sequence counter.
59    sequence: u64,
60    /// Prediction history.
61    history: Vec<PredictionState>,
62}
63
64impl PredictionSimulation {
65    /// Create new prediction simulation with deterministic seed.
66    #[must_use]
67    pub fn new(seed: u64) -> Self {
68        Self {
69            config: InferenceConfig::default(),
70            rng: SimRng::new(seed),
71            sequence: 0,
72            history: Vec::new(),
73        }
74    }
75
76    /// Create with custom configuration.
77    #[must_use]
78    pub fn with_config(seed: u64, config: InferenceConfig) -> Self {
79        Self {
80            config,
81            rng: SimRng::new(seed),
82            sequence: 0,
83            history: Vec::new(),
84        }
85    }
86
87    /// Get inference configuration.
88    #[must_use]
89    pub fn config(&self) -> &InferenceConfig {
90        &self.config
91    }
92
93    /// Simulate single prediction using a model function.
94    ///
95    /// The `model_fn` takes input and returns output vector.
96    ///
97    /// # Errors
98    ///
99    /// Returns error if model prediction fails.
100    pub fn predict<F>(&mut self, input: &[f64], model_fn: F) -> SimResult<PredictionState>
101    where
102        F: FnOnce(&[f64]) -> Vec<f64>,
103    {
104        // Simulate inference
105        let mut output = model_fn(input);
106
107        // Apply temperature scaling if not 1.0
108        if (self.config.temperature - 1.0).abs() > 1e-10 {
109            output = self.apply_temperature(&output, self.config.temperature);
110        }
111
112        // Apply top-k sampling if configured
113        if self.config.top_k > 0 {
114            output = self.sample_top_k(&output, self.config.top_k);
115        }
116
117        // Compute uncertainty if enabled (simplified: variance of output)
118        let uncertainty = if self.config.uncertainty {
119            Some(self.compute_uncertainty(&output))
120        } else {
121            None
122        };
123
124        // Simulate latency with noise
125        let latency_noise = (self.rng.gen_f64() * 0.2 - 0.1) * self.config.base_latency_us as f64;
126        let latency_us = (self.config.base_latency_us as f64 + latency_noise).max(1.0) as u64;
127
128        let state = PredictionState {
129            input: input.to_vec(),
130            output,
131            uncertainty,
132            latency_us,
133            sequence: self.sequence,
134        };
135
136        self.sequence += 1;
137        self.history.push(state.clone());
138
139        Ok(state)
140    }
141
142    /// Simulate batch prediction.
143    ///
144    /// # Errors
145    ///
146    /// Returns error if any prediction fails.
147    pub fn predict_batch<F>(
148        &mut self,
149        inputs: &[Vec<f64>],
150        model_fn: F,
151    ) -> SimResult<Vec<PredictionState>>
152    where
153        F: Fn(&[f64]) -> Vec<f64>,
154    {
155        inputs
156            .iter()
157            .map(|input| self.predict(input, &model_fn))
158            .collect()
159    }
160
161    /// Apply temperature scaling to logits.
162    #[allow(clippy::unused_self)]
163    fn apply_temperature(&self, logits: &[f64], temperature: f64) -> Vec<f64> {
164        if temperature <= 0.0 {
165            return logits.to_vec();
166        }
167        logits.iter().map(|x| x / temperature).collect()
168    }
169
170    /// Sample top-k values, zeroing out the rest.
171    #[allow(clippy::unused_self)]
172    fn sample_top_k(&self, values: &[f64], k: usize) -> Vec<f64> {
173        if k >= values.len() {
174            return values.to_vec();
175        }
176
177        // Find k-th largest value
178        let mut sorted: Vec<f64> = values.to_vec();
179        sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
180        let threshold = sorted.get(k - 1).copied().unwrap_or(f64::NEG_INFINITY);
181
182        // Zero out values below threshold
183        values
184            .iter()
185            .map(|&v| if v >= threshold { v } else { 0.0 })
186            .collect()
187    }
188
189    /// Compute simplified uncertainty estimate.
190    #[allow(clippy::unused_self)]
191    fn compute_uncertainty(&self, output: &[f64]) -> f64 {
192        if output.is_empty() {
193            return 0.0;
194        }
195        let mean = output.iter().sum::<f64>() / output.len() as f64;
196        let variance = output.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / output.len() as f64;
197        variance.sqrt()
198    }
199
200    /// Get prediction history.
201    #[must_use]
202    pub fn history(&self) -> &[PredictionState] {
203        &self.history
204    }
205
206    /// Reset simulation state.
207    pub fn reset(&mut self, seed: u64) {
208        self.rng = SimRng::new(seed);
209        self.sequence = 0;
210        self.history.clear();
211    }
212}