1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use crate::engine::rng::SimRng;
5use crate::error::{SimError, SimResult};
6
/// A single simulated conversation turn: one input/output exchange plus
/// the metrics recorded for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Turn {
    /// Zero-based position of this turn in the conversation history.
    pub index: usize,
    /// The query text supplied for this turn.
    pub input: String,
    /// The text produced by the generation callback.
    pub output: String,
    /// Reference answer used for accuracy scoring, if one was provided.
    pub expected: Option<String>,
    /// Latency, token, cost, and accuracy measurements for this turn.
    pub metrics: TurnMetrics,
    /// Input + output tokens of all prior turns, plus this turn's input
    /// tokens (this turn's own output is not counted as context).
    pub context_tokens: usize,
}
27
/// Per-turn measurements produced by the simulator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurnMetrics {
    /// Simulated end-to-end latency in milliseconds.
    pub latency_ms: f64,
    /// Estimated token count of the input text.
    pub input_tokens: usize,
    /// Estimated token count of the generated output.
    pub output_tokens: usize,
    /// Simulated cost (input tokens * input price + output tokens * output price).
    pub cost: f64,
    /// Word-overlap accuracy against the expected answer; `None` when no
    /// reference answer was supplied for the turn.
    pub accuracy: Option<f64>,
}
42
43impl Default for TurnMetrics {
44 fn default() -> Self {
45 Self {
46 latency_ms: 0.0,
47 input_tokens: 0,
48 output_tokens: 0,
49 cost: 0.0,
50 accuracy: None,
51 }
52 }
53}
54
/// Aggregate metrics returned by `MultiTurnSimulation::evaluate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTurnEvaluation {
    /// Mean accuracy over all scored turns across all runs, or `None`
    /// if no turn had a reference answer.
    pub mean_accuracy: Option<f64>,
    /// Mean per-turn latency (milliseconds) across all runs, or `None`
    /// if no turns were executed.
    pub mean_latency: Option<f64>,
    /// NOTE(review): despite the name, `evaluate` populates this with the
    /// summed cost divided by `n_runs` — i.e. the mean cost per run.
    pub total_cost: f64,
    /// Confidence level used for reporting; currently always 0.95.
    pub confidence_interval: f64,
    /// Number of independent runs that were aggregated.
    pub n_runs: usize,
}
69
/// One model's position in accuracy/cost/latency space for Pareto
/// comparison.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParetoPoint {
    /// Identifier of the evaluated model.
    pub model_id: String,
    /// Mean accuracy; 0.0 when the evaluation carried no accuracy data.
    pub accuracy: f64,
    /// Cost taken from the evaluation's `total_cost` field.
    pub cost: f64,
    /// Mean latency; `f64::MAX` when the evaluation carried no latency data.
    pub latency: f64,
    /// IDs of models that Pareto-dominate this one; empty for points on
    /// the frontier.
    pub dominated_by: Vec<String>,
}
84
/// Result of `MultiTurnSimulation::pareto_analysis`.
#[derive(Debug, Clone, Default)]
pub struct ParetoAnalysis {
    /// The non-dominated points (the Pareto frontier).
    pub frontier: Vec<ParetoPoint>,
    /// Per-model value score combining accuracy gap, cost ratio, and
    /// latency ratio relative to the best values observed in the set.
    pub value_scores: HashMap<String, f64>,
}
93
/// Deterministic multi-turn conversation simulator.
///
/// Maintains the turn history and derives latency/cost metrics from a
/// seeded RNG plus per-token pricing parameters, so identical seeds and
/// inputs reproduce identical results.
pub struct MultiTurnSimulation {
    /// All turns executed since construction or the last `reset`.
    history: Vec<Turn>,
    /// Seeded RNG used to add noise to simulated latency and to derive
    /// per-run seeds during evaluation.
    rng: SimRng,
    /// Cost charged per input token.
    input_token_cost: f64,
    /// Cost charged per output token.
    output_token_cost: f64,
    /// Base latency contribution per token, in milliseconds.
    latency_per_token_ms: f64,
}
109
impl MultiTurnSimulation {
    /// Creates a simulator seeded with `seed`, using default pricing
    /// (0.00001 per input token, 0.00003 per output token) and a base
    /// latency of 10 ms per token.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            history: Vec::new(),
            rng: SimRng::new(seed),
            input_token_cost: 0.00001,
            output_token_cost: 0.00003,
            latency_per_token_ms: 10.0,
        }
    }

    /// Builder-style override of the per-token input and output costs.
    #[must_use]
    pub fn with_costs(mut self, input_cost: f64, output_cost: f64) -> Self {
        self.input_token_cost = input_cost;
        self.output_token_cost = output_cost;
        self
    }

    /// Builder-style override of the base latency per token (milliseconds).
    #[must_use]
    pub fn with_latency_per_token(mut self, latency_ms: f64) -> Self {
        self.latency_per_token_ms = latency_ms;
        self
    }

    /// Executes one conversation turn: invokes `generate_fn` with the
    /// input and the prior history, simulates latency and cost for the
    /// exchange, optionally scores accuracy against `expected`, and
    /// appends the completed turn to the history.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; the `SimResult` return type is kept
    /// so the signature matches the engine's fallible APIs.
    pub fn turn<F>(
        &mut self,
        input: &str,
        expected: Option<&str>,
        generate_fn: F,
    ) -> SimResult<Turn>
    where
        F: FnOnce(&str, &[Turn]) -> String,
    {
        let input_tokens = self.count_tokens(input);

        let output = generate_fn(input, &self.history);
        let output_tokens = self.count_tokens(&output);

        // Latency scales linearly with total tokens, plus uniform noise in
        // [-10%, +10%) of the base value, floored at 1 ms. Note this draws
        // exactly one RNG value per turn, which keeps runs reproducible.
        let base_latency = (input_tokens + output_tokens) as f64 * self.latency_per_token_ms;
        let noise = (self.rng.gen_f64() * 0.2 - 0.1) * base_latency;
        let latency_ms = (base_latency + noise).max(1.0);

        let cost = input_tokens as f64 * self.input_token_cost
            + output_tokens as f64 * self.output_token_cost;

        // Accuracy is only scored when a reference answer was supplied.
        let accuracy = expected.map(|exp| self.compute_accuracy(&output, exp));

        // Context size: all prior turns' input + output tokens plus this
        // turn's input; the new output is not part of its own context.
        let context_tokens = self
            .history
            .iter()
            .map(|t| t.metrics.input_tokens + t.metrics.output_tokens)
            .sum::<usize>()
            + input_tokens;

        let turn = Turn {
            index: self.history.len(),
            input: input.to_string(),
            output,
            expected: expected.map(String::from),
            metrics: TurnMetrics {
                latency_ms,
                input_tokens,
                output_tokens,
                cost,
                accuracy,
            },
            context_tokens,
        };

        self.history.push(turn.clone());
        Ok(turn)
    }

    /// Rough token estimate: whitespace-separated word count times 1.3,
    /// rounded up.
    #[allow(clippy::unused_self)]
    fn count_tokens(&self, text: &str) -> usize {
        let words = text.split_whitespace().count();
        (words as f64 * 1.3).ceil() as usize
    }

    /// Word-level Jaccard similarity between `output` and `expected`:
    /// |intersection| / |union| of their whitespace-separated word sets.
    /// Two empty strings score 1.0; exactly one empty string scores 0.0.
    #[allow(clippy::unused_self)]
    fn compute_accuracy(&self, output: &str, expected: &str) -> f64 {
        if expected.is_empty() && output.is_empty() {
            return 1.0;
        }
        if expected.is_empty() || output.is_empty() {
            return 0.0;
        }

        let output_words: std::collections::HashSet<&str> = output.split_whitespace().collect();
        let expected_words: std::collections::HashSet<&str> = expected.split_whitespace().collect();

        let intersection = output_words.intersection(&expected_words).count();
        let union = output_words.union(&expected_words).count();

        // Reachable only when both strings are non-empty but contain
        // nothing except whitespace; treat that as a perfect match.
        if union == 0 {
            return 1.0;
        }

        intersection as f64 / union as f64
    }

    /// Runs the full `queries` conversation `n_runs` times and aggregates
    /// accuracy, latency, and cost across every turn of every run.
    ///
    /// Each run clears the history and reseeds the RNG from a seed derived
    /// from the current RNG state, so results are deterministic for a given
    /// initial seed. The simulator is left holding the final run's state.
    ///
    /// NOTE(review): the returned `total_cost` is the summed cost divided
    /// by `n_runs` (mean cost per run), and `confidence_interval` is the
    /// fixed confidence level 0.95, not a computed interval width.
    ///
    /// # Errors
    ///
    /// Returns a config error when `n_runs < 5`.
    pub fn evaluate<F>(
        &mut self,
        queries: &[(String, Option<String>)],
        n_runs: usize,
        generate_fn: F,
    ) -> SimResult<MultiTurnEvaluation>
    where
        F: Fn(&str, &[Turn]) -> String,
    {
        if n_runs < 5 {
            return Err(SimError::config(
                "Princeton methodology requires minimum 5 runs".to_string(),
            ));
        }

        let mut all_accuracies: Vec<f64> = Vec::new();
        let mut all_latencies: Vec<f64> = Vec::new();
        let mut total_cost = 0.0;

        for run in 0..n_runs {
            // Derive the next run's seed from the current RNG stream so the
            // whole evaluation is reproducible from the initial seed.
            let derived_seed = self.rng.gen_u64().wrapping_add(run as u64);
            self.reset(derived_seed);

            for (query, expected) in queries {
                let turn = self.turn(query, expected.as_deref(), &generate_fn)?;
                if let Some(acc) = turn.metrics.accuracy {
                    all_accuracies.push(acc);
                }
                all_latencies.push(turn.metrics.latency_ms);
                total_cost += turn.metrics.cost;
            }
        }

        let mean_accuracy = if all_accuracies.is_empty() {
            None
        } else {
            Some(all_accuracies.iter().sum::<f64>() / all_accuracies.len() as f64)
        };

        let mean_latency = if all_latencies.is_empty() {
            None
        } else {
            Some(all_latencies.iter().sum::<f64>() / all_latencies.len() as f64)
        };

        Ok(MultiTurnEvaluation {
            mean_accuracy,
            mean_latency,
            total_cost: total_cost / n_runs as f64,
            confidence_interval: 0.95,
            n_runs,
        })
    }

    /// Computes the Pareto frontier and per-model value scores for a set
    /// of named evaluations.
    ///
    /// Missing accuracy defaults to 0.0 and missing latency to `f64::MAX`,
    /// so unevaluated models cannot dominate measured ones. The value score
    /// is `(1 - accuracy_gap) * cost_ratio * latency_ratio`, where each
    /// ratio compares a model against the best (highest accuracy, lowest
    /// cost, lowest latency) observed in the set; higher is better.
    #[must_use]
    pub fn pareto_analysis(evaluations: &[(String, MultiTurnEvaluation)]) -> ParetoAnalysis {
        let mut points: Vec<ParetoPoint> = evaluations
            .iter()
            .map(|(id, eval)| ParetoPoint {
                model_id: id.clone(),
                accuracy: eval.mean_accuracy.unwrap_or(0.0),
                cost: eval.total_cost,
                latency: eval.mean_latency.unwrap_or(f64::MAX),
                dominated_by: Vec::new(),
            })
            .collect();

        // Pairwise O(n^2) dominance pass, collected separately so `points`
        // isn't mutated while being read.
        let mut dominance: Vec<Vec<String>> = vec![Vec::new(); points.len()];
        for i in 0..points.len() {
            for j in 0..points.len() {
                if i != j && Self::dominates(&points[j], &points[i]) {
                    dominance[i].push(points[j].model_id.clone());
                }
            }
        }
        for (i, dominated_by) in dominance.into_iter().enumerate() {
            points[i].dominated_by = dominated_by;
        }

        // Best-in-set baselines for the value score.
        let baseline_accuracy = points.iter().map(|p| p.accuracy).fold(0.0_f64, f64::max);
        let baseline_cost = points.iter().map(|p| p.cost).fold(f64::INFINITY, f64::min);
        let baseline_latency = points
            .iter()
            .map(|p| p.latency)
            .fold(f64::INFINITY, f64::min);

        let value_scores: HashMap<String, f64> = points
            .iter()
            .map(|p| {
                let accuracy_gap = baseline_accuracy - p.accuracy;
                // `.max(1e-10)` guards against division by zero for free or
                // instantaneous models.
                let cost_ratio = baseline_cost / p.cost.max(1e-10);
                let latency_ratio = baseline_latency / p.latency.max(1e-10);
                let value = (1.0 - accuracy_gap) * cost_ratio * latency_ratio;
                (p.model_id.clone(), value)
            })
            .collect();

        // The frontier is every point no other point dominates.
        let frontier: Vec<ParetoPoint> = points
            .into_iter()
            .filter(|p| p.dominated_by.is_empty())
            .collect();

        ParetoAnalysis {
            frontier,
            value_scores,
        }
    }

    /// Pareto dominance: `a` dominates `b` when it is no worse on all
    /// three axes (higher accuracy, lower cost, lower latency are better)
    /// and strictly better on at least one.
    fn dominates(a: &ParetoPoint, b: &ParetoPoint) -> bool {
        a.accuracy >= b.accuracy
            && a.cost <= b.cost
            && a.latency <= b.latency
            && (a.accuracy > b.accuracy || a.cost < b.cost || a.latency < b.latency)
    }

    /// Read-only view of all turns executed since the last reset.
    #[must_use]
    pub fn history(&self) -> &[Turn] {
        &self.history
    }

    /// Clears the history and reseeds the RNG, returning the simulator to
    /// a fresh, reproducible state. Pricing and latency settings are kept.
    pub fn reset(&mut self, seed: u64) {
        self.rng = SimRng::new(seed);
        self.history.clear();
    }
}