simular/domains/ml/
prediction.rs1use serde::{Deserialize, Serialize};
2
3use crate::engine::rng::SimRng;
4use crate::error::SimResult;
5
/// A single simulated prediction: the model input, post-processed
/// output, optional uncertainty estimate, simulated latency, and the
/// sequence number assigned by the simulation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionState {
    /// Feature vector passed to the model function.
    pub input: Vec<f64>,
    /// Model output after temperature scaling / top-k filtering.
    pub output: Vec<f64>,
    /// Standard deviation of the output values when uncertainty
    /// tracking is enabled; `None` otherwise.
    pub uncertainty: Option<f64>,
    /// Simulated inference latency in microseconds (floored at 1).
    pub latency_us: u64,
    /// Zero-based index of this prediction within the simulation run.
    pub sequence: u64,
}
24
/// Tunable knobs for the simulated inference pipeline.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Intended batch size.
    /// NOTE(review): not currently consulted by `predict_batch` —
    /// confirm whether inputs are expected to be chunked by this value.
    pub batch_size: usize,
    /// Temperature for logit scaling; values other than 1.0 divide each
    /// output element.
    pub temperature: f64,
    /// Keep only the `top_k` largest outputs (0 disables filtering).
    pub top_k: usize,
    /// When true, attach the output's standard deviation to each
    /// prediction.
    pub uncertainty: bool,
    /// Baseline simulated latency in microseconds (±10% noise applied).
    pub base_latency_us: u64,
}
39
40impl Default for InferenceConfig {
41 fn default() -> Self {
42 Self {
43 batch_size: 32,
44 temperature: 1.0,
45 top_k: 0,
46 uncertainty: false,
47 base_latency_us: 1000,
48 }
49 }
50}
51
/// Deterministic, seeded simulation of an ML inference service.
///
/// Wraps a caller-supplied model function with temperature scaling,
/// top-k filtering, optional uncertainty estimation, and synthetic
/// latency, recording every prediction in an internal history.
pub struct PredictionSimulation {
    /// Pipeline configuration (temperature, top-k, latency, ...).
    config: InferenceConfig,
    /// Seeded RNG driving latency noise; reseeded by `reset`.
    rng: SimRng,
    /// Next sequence number to assign; increments once per prediction.
    sequence: u64,
    /// All predictions made since construction or the last `reset`.
    history: Vec<PredictionState>,
}
63
64impl PredictionSimulation {
65 #[must_use]
67 pub fn new(seed: u64) -> Self {
68 Self {
69 config: InferenceConfig::default(),
70 rng: SimRng::new(seed),
71 sequence: 0,
72 history: Vec::new(),
73 }
74 }
75
76 #[must_use]
78 pub fn with_config(seed: u64, config: InferenceConfig) -> Self {
79 Self {
80 config,
81 rng: SimRng::new(seed),
82 sequence: 0,
83 history: Vec::new(),
84 }
85 }
86
87 #[must_use]
89 pub fn config(&self) -> &InferenceConfig {
90 &self.config
91 }
92
93 pub fn predict<F>(&mut self, input: &[f64], model_fn: F) -> SimResult<PredictionState>
101 where
102 F: FnOnce(&[f64]) -> Vec<f64>,
103 {
104 let mut output = model_fn(input);
106
107 if (self.config.temperature - 1.0).abs() > 1e-10 {
109 output = self.apply_temperature(&output, self.config.temperature);
110 }
111
112 if self.config.top_k > 0 {
114 output = self.sample_top_k(&output, self.config.top_k);
115 }
116
117 let uncertainty = if self.config.uncertainty {
119 Some(self.compute_uncertainty(&output))
120 } else {
121 None
122 };
123
124 let latency_noise = (self.rng.gen_f64() * 0.2 - 0.1) * self.config.base_latency_us as f64;
126 let latency_us = (self.config.base_latency_us as f64 + latency_noise).max(1.0) as u64;
127
128 let state = PredictionState {
129 input: input.to_vec(),
130 output,
131 uncertainty,
132 latency_us,
133 sequence: self.sequence,
134 };
135
136 self.sequence += 1;
137 self.history.push(state.clone());
138
139 Ok(state)
140 }
141
142 pub fn predict_batch<F>(
148 &mut self,
149 inputs: &[Vec<f64>],
150 model_fn: F,
151 ) -> SimResult<Vec<PredictionState>>
152 where
153 F: Fn(&[f64]) -> Vec<f64>,
154 {
155 inputs
156 .iter()
157 .map(|input| self.predict(input, &model_fn))
158 .collect()
159 }
160
161 #[allow(clippy::unused_self)]
163 fn apply_temperature(&self, logits: &[f64], temperature: f64) -> Vec<f64> {
164 if temperature <= 0.0 {
165 return logits.to_vec();
166 }
167 logits.iter().map(|x| x / temperature).collect()
168 }
169
170 #[allow(clippy::unused_self)]
172 fn sample_top_k(&self, values: &[f64], k: usize) -> Vec<f64> {
173 if k >= values.len() {
174 return values.to_vec();
175 }
176
177 let mut sorted: Vec<f64> = values.to_vec();
179 sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
180 let threshold = sorted.get(k - 1).copied().unwrap_or(f64::NEG_INFINITY);
181
182 values
184 .iter()
185 .map(|&v| if v >= threshold { v } else { 0.0 })
186 .collect()
187 }
188
189 #[allow(clippy::unused_self)]
191 fn compute_uncertainty(&self, output: &[f64]) -> f64 {
192 if output.is_empty() {
193 return 0.0;
194 }
195 let mean = output.iter().sum::<f64>() / output.len() as f64;
196 let variance = output.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / output.len() as f64;
197 variance.sqrt()
198 }
199
200 #[must_use]
202 pub fn history(&self) -> &[PredictionState] {
203 &self.history
204 }
205
206 pub fn reset(&mut self, seed: u64) {
208 self.rng = SimRng::new(seed);
209 self.sequence = 0;
210 self.history.clear();
211 }
212}