//! Multi-turn simulation for conversational/iterative ML model evaluation.
//! File: simular/domains/ml/multi_turn.rs
1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use crate::engine::rng::SimRng;
5use crate::error::{SimError, SimResult};
6
// ============================================================================
// Multi-Turn Simulation
// ============================================================================
/// A single turn in a multi-turn interaction: the input, the generated
/// output, optional ground truth, and the per-turn metrics recorded by
/// `MultiTurnSimulation::turn`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Turn {
    /// Zero-based turn index (position in the conversation history).
    pub index: usize,
    /// Input query/prompt for this turn.
    pub input: String,
    /// Model response text.
    pub output: String,
    /// Ground truth answer, if available for scoring.
    pub expected: Option<String>,
    /// Latency/cost/accuracy metrics for this turn.
    pub metrics: TurnMetrics,
    /// Context window usage in tokens: all prior turns' input+output tokens
    /// plus this turn's input tokens.
    pub context_tokens: usize,
}
27
28/// Metrics for a single turn.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct TurnMetrics {
31    /// Generation latency in milliseconds.
32    pub latency_ms: f64,
33    /// Input tokens.
34    pub input_tokens: usize,
35    /// Output tokens.
36    pub output_tokens: usize,
37    /// Estimated cost (normalized).
38    pub cost: f64,
39    /// Accuracy vs oracle (if available).
40    pub accuracy: Option<f64>,
41}
42
43impl Default for TurnMetrics {
44    fn default() -> Self {
45        Self {
46            latency_ms: 0.0,
47            input_tokens: 0,
48            output_tokens: 0,
49            cost: 0.0,
50            accuracy: None,
51        }
52    }
53}
54
/// Multi-turn evaluation results aggregated over repeated runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTurnEvaluation {
    /// Mean accuracy across all scored turns in all runs; `None` when no
    /// turn had ground truth.
    pub mean_accuracy: Option<f64>,
    /// Mean per-turn latency across all runs; `None` when no turns ran.
    pub mean_latency: Option<f64>,
    /// Cost figure. NOTE(review): despite the name, `evaluate()` stores the
    /// total cost divided by `n_runs` (i.e. mean cost per run) here —
    /// confirm which semantics downstream consumers expect.
    pub total_cost: f64,
    /// Confidence interval level (fixed at 0.95 by `evaluate`).
    pub confidence_interval: f64,
    /// Number of runs performed.
    pub n_runs: usize,
}
69
/// A single model's position in (accuracy, cost, latency) objective space,
/// used for Pareto frontier analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParetoPoint {
    /// Model identifier.
    pub model_id: String,
    /// Accuracy score (higher is better).
    pub accuracy: f64,
    /// Cost metric (lower is better).
    pub cost: f64,
    /// Latency metric (lower is better).
    pub latency: f64,
    /// IDs of models that strictly dominate this one; empty for points on
    /// the Pareto frontier.
    pub dominated_by: Vec<String>,
}
84
/// Pareto frontier analysis results.
#[derive(Debug, Clone, Default)]
pub struct ParetoAnalysis {
    /// Non-dominated solutions (the Pareto frontier).
    pub frontier: Vec<ParetoPoint>,
    /// Value score per model id: accuracy retention scaled by cost and
    /// latency advantage relative to the best observed in each objective.
    pub value_scores: HashMap<String, f64>,
}
93
/// Multi-turn simulation for conversational/iterative model evaluation.
///
/// Simulates per-turn latency (with seeded random jitter) and token-based
/// cost, and implements Pareto frontier analysis across accuracy, cost,
/// and latency.
pub struct MultiTurnSimulation {
    /// Conversation history, in turn order.
    history: Vec<Turn>,
    /// Deterministic RNG used for latency jitter and run reseeding.
    rng: SimRng,
    /// Cost per input token (normalized units).
    input_token_cost: f64,
    /// Cost per output token (normalized units).
    output_token_cost: f64,
    /// Base latency per token in milliseconds.
    latency_per_token_ms: f64,
}
109
110impl MultiTurnSimulation {
111    /// Create new multi-turn simulation.
112    #[must_use]
113    pub fn new(seed: u64) -> Self {
114        Self {
115            history: Vec::new(),
116            rng: SimRng::new(seed),
117            input_token_cost: 0.00001,
118            output_token_cost: 0.00003,
119            latency_per_token_ms: 10.0,
120        }
121    }
122
123    /// Set cost parameters.
124    #[must_use]
125    pub fn with_costs(mut self, input_cost: f64, output_cost: f64) -> Self {
126        self.input_token_cost = input_cost;
127        self.output_token_cost = output_cost;
128        self
129    }
130
131    /// Set latency per token.
132    #[must_use]
133    pub fn with_latency_per_token(mut self, latency_ms: f64) -> Self {
134        self.latency_per_token_ms = latency_ms;
135        self
136    }
137
138    /// Execute a single turn using a response generator.
139    ///
140    /// The `generate_fn` takes (input, history) and returns response string.
141    ///
142    /// # Errors
143    ///
144    /// Returns error if turn execution fails.
145    pub fn turn<F>(
146        &mut self,
147        input: &str,
148        expected: Option<&str>,
149        generate_fn: F,
150    ) -> SimResult<Turn>
151    where
152        F: FnOnce(&str, &[Turn]) -> String,
153    {
154        let input_tokens = self.count_tokens(input);
155
156        // Generate response
157        let output = generate_fn(input, &self.history);
158        let output_tokens = self.count_tokens(&output);
159
160        // Compute latency with noise
161        let base_latency = (input_tokens + output_tokens) as f64 * self.latency_per_token_ms;
162        let noise = (self.rng.gen_f64() * 0.2 - 0.1) * base_latency;
163        let latency_ms = (base_latency + noise).max(1.0);
164
165        // Compute cost
166        let cost = input_tokens as f64 * self.input_token_cost
167            + output_tokens as f64 * self.output_token_cost;
168
169        // Compute accuracy if expected is provided
170        let accuracy = expected.map(|exp| self.compute_accuracy(&output, exp));
171
172        let context_tokens = self
173            .history
174            .iter()
175            .map(|t| t.metrics.input_tokens + t.metrics.output_tokens)
176            .sum::<usize>()
177            + input_tokens;
178
179        let turn = Turn {
180            index: self.history.len(),
181            input: input.to_string(),
182            output,
183            expected: expected.map(String::from),
184            metrics: TurnMetrics {
185                latency_ms,
186                input_tokens,
187                output_tokens,
188                cost,
189                accuracy,
190            },
191            context_tokens,
192        };
193
194        self.history.push(turn.clone());
195        Ok(turn)
196    }
197
198    /// Simplified token counting (words * 1.3).
199    #[allow(clippy::unused_self)]
200    fn count_tokens(&self, text: &str) -> usize {
201        let words = text.split_whitespace().count();
202        (words as f64 * 1.3).ceil() as usize
203    }
204
    /// Compute accuracy between output and expected as Jaccard similarity
    /// over whitespace-separated word sets (intersection / union).
    /// NOTE(review): the previous comment said "Levenshtein similarity",
    /// but the code below is word-overlap, not edit distance.
    #[allow(clippy::unused_self)]
    fn compute_accuracy(&self, output: &str, expected: &str) -> f64 {
        // Both empty: perfect match. Exactly one empty: total miss.
        if expected.is_empty() && output.is_empty() {
            return 1.0;
        }
        if expected.is_empty() || output.is_empty() {
            return 0.0;
        }

        // Simple word overlap similarity
        let output_words: std::collections::HashSet<&str> = output.split_whitespace().collect();
        let expected_words: std::collections::HashSet<&str> = expected.split_whitespace().collect();

        let intersection = output_words.intersection(&expected_words).count();
        let union = output_words.union(&expected_words).count();

        // Reachable when both strings are non-empty but whitespace-only
        // (both word sets empty): treated as a perfect match.
        if union == 0 {
            return 1.0;
        }

        intersection as f64 / union as f64
    }
228
    /// Run complete multi-turn evaluation with statistical analysis.
    ///
    /// Following Princeton methodology: minimum 5 runs, 95% CI.
    /// Each run replays all `queries` against a freshly reset simulation
    /// (history cleared, RNG reseeded), so results are deterministic for a
    /// given construction seed.
    ///
    /// # Errors
    ///
    /// Returns error if fewer than 5 runs are requested or if evaluation fails.
    pub fn evaluate<F>(
        &mut self,
        queries: &[(String, Option<String>)],
        n_runs: usize,
        generate_fn: F,
    ) -> SimResult<MultiTurnEvaluation>
    where
        F: Fn(&str, &[Turn]) -> String,
    {
        if n_runs < 5 {
            return Err(SimError::config(
                "Princeton methodology requires minimum 5 runs".to_string(),
            ));
        }

        let mut all_accuracies: Vec<f64> = Vec::new();
        let mut all_latencies: Vec<f64> = Vec::new();
        let mut total_cost = 0.0;

        for run in 0..n_runs {
            // Reset for each run with derived seed. Note the ordering: the
            // seed is drawn from the *current* RNG before it is replaced by
            // `reset`, so run 0 derives from the constructor seed and each
            // later run derives from the previous run's reset RNG.
            let derived_seed = self.rng.gen_u64().wrapping_add(run as u64);
            self.reset(derived_seed);

            for (query, expected) in queries {
                let turn = self.turn(query, expected.as_deref(), &generate_fn)?;
                // Only turns with ground truth contribute to accuracy.
                if let Some(acc) = turn.metrics.accuracy {
                    all_accuracies.push(acc);
                }
                all_latencies.push(turn.metrics.latency_ms);
                total_cost += turn.metrics.cost;
            }
        }

        let mean_accuracy = if all_accuracies.is_empty() {
            None
        } else {
            Some(all_accuracies.iter().sum::<f64>() / all_accuracies.len() as f64)
        };

        let mean_latency = if all_latencies.is_empty() {
            None
        } else {
            Some(all_latencies.iter().sum::<f64>() / all_latencies.len() as f64)
        };

        Ok(MultiTurnEvaluation {
            mean_accuracy,
            mean_latency,
            // NOTE(review): this divides by n_runs, so the field actually
            // holds mean cost per run despite being named `total_cost`.
            total_cost: total_cost / n_runs as f64,
            confidence_interval: 0.95,
            n_runs,
        })
    }
290
291    /// Compute Pareto frontier across multiple model evaluations.
292    #[must_use]
293    pub fn pareto_analysis(evaluations: &[(String, MultiTurnEvaluation)]) -> ParetoAnalysis {
294        let mut points: Vec<ParetoPoint> = evaluations
295            .iter()
296            .map(|(id, eval)| ParetoPoint {
297                model_id: id.clone(),
298                accuracy: eval.mean_accuracy.unwrap_or(0.0),
299                cost: eval.total_cost,
300                latency: eval.mean_latency.unwrap_or(f64::MAX),
301                dominated_by: Vec::new(),
302            })
303            .collect();
304
305        // Identify dominated points
306        // First pass: identify dominance relationships
307        let mut dominance: Vec<Vec<String>> = vec![Vec::new(); points.len()];
308        for i in 0..points.len() {
309            for j in 0..points.len() {
310                if i != j && Self::dominates(&points[j], &points[i]) {
311                    dominance[i].push(points[j].model_id.clone());
312                }
313            }
314        }
315        // Second pass: assign dominated_by
316        for (i, dominated_by) in dominance.into_iter().enumerate() {
317            points[i].dominated_by = dominated_by;
318        }
319
320        // Compute value scores
321        let baseline_accuracy = points.iter().map(|p| p.accuracy).fold(0.0_f64, f64::max);
322        let baseline_cost = points.iter().map(|p| p.cost).fold(f64::INFINITY, f64::min);
323        let baseline_latency = points
324            .iter()
325            .map(|p| p.latency)
326            .fold(f64::INFINITY, f64::min);
327
328        let value_scores: HashMap<String, f64> = points
329            .iter()
330            .map(|p| {
331                let accuracy_gap = baseline_accuracy - p.accuracy;
332                let cost_ratio = baseline_cost / p.cost.max(1e-10);
333                let latency_ratio = baseline_latency / p.latency.max(1e-10);
334                let value = (1.0 - accuracy_gap) * cost_ratio * latency_ratio;
335                (p.model_id.clone(), value)
336            })
337            .collect();
338
339        let frontier: Vec<ParetoPoint> = points
340            .into_iter()
341            .filter(|p| p.dominated_by.is_empty())
342            .collect();
343
344        ParetoAnalysis {
345            frontier,
346            value_scores,
347        }
348    }
349
350    /// Check if point a dominates point b (better in all objectives).
351    fn dominates(a: &ParetoPoint, b: &ParetoPoint) -> bool {
352        a.accuracy >= b.accuracy
353            && a.cost <= b.cost
354            && a.latency <= b.latency
355            && (a.accuracy > b.accuracy || a.cost < b.cost || a.latency < b.latency)
356    }
357
358    /// Get conversation history.
359    #[must_use]
360    pub fn history(&self) -> &[Turn] {
361        &self.history
362    }
363
364    /// Reset simulation state.
365    pub fn reset(&mut self, seed: u64) {
366        self.rng = SimRng::new(seed);
367        self.history.clear();
368    }
369}