1use crate::models::Output;
2use std::collections::HashMap;
3use strsim::normalized_levenshtein;
4
5#[cfg(feature = "native")]
6#[allow(unused_imports)]
7use rayon::prelude::*;
8
9#[derive(Debug, Clone, PartialEq)]
10pub struct DriftMetrics {
11 pub consistency_score: f64, pub agreement_rate: f64, pub drift_score: f64, pub consensus_output: Option<String>,
15 pub consensus_confidence: ConsensusConfidence,
16 pub outliers: Vec<usize>, }
18
/// Coarse confidence bucket attached to a consensus result.
///
/// The enum is a plain tag with no payload, so it is `Copy` and can be used
/// as a hash-map key or compared with `Eq` without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConsensusConfidence {
    /// Strong agreement between outputs.
    High,
    /// Moderate agreement between outputs.
    Medium,
    /// Weak agreement between outputs.
    Low,
    /// No consensus could be established (for example, no outputs at all).
    None,
}
26
/// Overall health classification derived from a set of drift metrics.
///
/// The enum carries no data, so it is `Copy` and supports `Eq`/`Hash` in
/// addition to the comparison and debug output it already offered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriftStatus {
    /// Outputs are consistent with each other.
    Stable,
    /// Outputs show noticeable variation.
    Drifting,
    /// Outputs vary so much that they should not be trusted.
    Critical,
}
33
/// Computes similarity-based drift metrics over a set of textual outputs.
///
/// `Debug` and `PartialEq` are derived so that a calculator can be logged and
/// compared in assertions; the only state is a single `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct DriftCalculator {
    /// Minimum similarity for two outputs to be treated as agreeing.
    similarity_threshold: f64,
}
38
39impl DriftCalculator {
40 pub fn new() -> Self {
41 Self {
42 similarity_threshold: 0.85,
43 }
44 }
45
46 pub fn with_threshold(threshold: f64) -> Self {
47 Self {
48 similarity_threshold: threshold.clamp(0.0, 1.0),
49 }
50 }
51
52 pub fn similarity_threshold(&self) -> f64 {
54 self.similarity_threshold
55 }
56
57 pub fn calculate_drift(&self, outputs: &[String]) -> DriftMetrics {
59 if outputs.is_empty() {
60 return DriftMetrics {
61 consistency_score: 1.0,
62 agreement_rate: 1.0,
63 drift_score: 0.0,
64 consensus_output: None,
65 consensus_confidence: ConsensusConfidence::None,
66 outliers: Vec::new(),
67 };
68 }
69
70 if outputs.len() == 1 {
71 return DriftMetrics {
72 consistency_score: 1.0,
73 agreement_rate: 1.0,
74 drift_score: 0.0,
75 consensus_output: Some(outputs[0].clone()),
76 consensus_confidence: ConsensusConfidence::High,
77 outliers: Vec::new(),
78 };
79 }
80
81 let similarities = self.calculate_pairwise_similarities(outputs);
83
84 let total_pairs = outputs.len() * (outputs.len() - 1) / 2;
86 let avg_similarity = similarities.iter().sum::<f64>() / total_pairs as f64;
87
88 let agreement_rate = self.calculate_agreement_rate(outputs);
90
91 let drift_score = 1.0 - avg_similarity;
93
94 let consensus_output = self.find_consensus(outputs);
96
97 let consensus_confidence = match agreement_rate {
99 rate if rate > 0.8 => ConsensusConfidence::High,
100 rate if rate >= 0.5 => ConsensusConfidence::Medium,
101 rate if rate > 0.0 => ConsensusConfidence::Low,
102 _ => ConsensusConfidence::None,
103 };
104
105 let outliers = self.find_outliers(outputs, &consensus_output);
107
108 DriftMetrics {
109 consistency_score: avg_similarity,
110 agreement_rate,
111 drift_score,
112 consensus_output,
113 consensus_confidence,
114 outliers,
115 }
116 }
117
118 pub fn calculate_drift_from_outputs(&self, outputs: &[Output]) -> DriftMetrics {
120 let strings: Vec<String> = outputs
121 .iter()
122 .map(|output| output.value.to_string())
123 .collect();
124
125 self.calculate_drift(&strings)
126 }
127
128 pub fn get_status(&self, metrics: &DriftMetrics) -> DriftStatus {
130 match metrics.consistency_score {
131 score if score >= 0.85 => DriftStatus::Stable,
132 score if score >= 0.5 => DriftStatus::Drifting,
133 _ => DriftStatus::Critical,
134 }
135 }
136
137 fn semantic_similarity(&self, a: &str, b: &str) -> f64 {
140 if a == b {
141 return 1.0;
142 }
143
144 if let (Ok(num_a), Ok(num_b)) = (a.parse::<f64>(), b.parse::<f64>()) {
146 let diff = (num_a - num_b).abs();
148 let avg = (num_a.abs() + num_b.abs()) / 2.0;
149 if avg == 0.0 {
150 1.0 } else {
152 (1.0 - (diff / avg)).max(0.0)
153 }
154 } else {
155 normalized_levenshtein(a, b)
157 }
158 }
159
160 fn calculate_pairwise_similarities(&self, outputs: &[String]) -> Vec<f64> {
162 let mut similarities = Vec::new();
163
164 for i in 0..outputs.len() {
165 for j in (i + 1)..outputs.len() {
166 let sim = self.semantic_similarity(&outputs[i], &outputs[j]);
167 similarities.push(sim);
168 }
169 }
170
171 similarities
172 }
173
174 fn calculate_agreement_rate(&self, outputs: &[String]) -> f64 {
176 if outputs.len() <= 1 {
177 return 1.0;
178 }
179
180 let mut clusters: Vec<Vec<String>> = Vec::new();
182
183 for output in outputs {
184 let mut found_cluster = false;
185
186 for cluster in &mut clusters {
187 let cluster_repr: &String = cluster.first().unwrap();
188 if self.semantic_similarity(output, cluster_repr) >= self.similarity_threshold {
189 cluster.push(output.clone());
190 found_cluster = true;
191 break;
192 }
193 }
194
195 if !found_cluster {
196 clusters.push(vec![output.clone()]);
197 }
198 }
199
200 let max_cluster_size = clusters.iter().map(|c| c.len()).max().unwrap_or(0);
202 max_cluster_size as f64 / outputs.len() as f64
203 }
204
205 fn find_consensus(&self, outputs: &[String]) -> Option<String> {
207 if outputs.is_empty() {
208 return None;
209 }
210
211 let mut frequency_map: HashMap<String, usize> = HashMap::new();
213 for output in outputs {
214 *frequency_map.entry(output.clone()).or_insert(0) += 1;
215 }
216
217 if let Some((most_frequent, count)) = frequency_map.iter().max_by_key(|(_, &count)| count) {
219 if *count > outputs.len() / 2 {
220 return Some(most_frequent.clone());
221 }
222 }
223
224 let mut best_output = outputs[0].clone();
226 let mut best_avg_similarity = 0.0;
227
228 for candidate in outputs {
229 let similarities: Vec<f64> = outputs
230 .iter()
231 .map(|other| self.semantic_similarity(candidate, other))
232 .collect();
233
234 let avg_similarity = similarities.iter().sum::<f64>() / similarities.len() as f64;
235
236 if avg_similarity > best_avg_similarity {
237 best_avg_similarity = avg_similarity;
238 best_output = candidate.clone();
239 }
240 }
241
242 Some(best_output)
243 }
244
245 fn find_outliers(&self, outputs: &[String], consensus: &Option<String>) -> Vec<usize> {
247 let Some(consensus_output) = consensus else {
248 return Vec::new();
249 };
250
251 outputs
252 .iter()
253 .enumerate()
254 .filter_map(|(i, output)| {
255 let similarity = self.semantic_similarity(output, consensus_output);
256 if similarity < self.similarity_threshold * 0.7 {
257 Some(i)
259 } else {
260 None
261 }
262 })
263 .collect()
264 }
265}
266
267impl Default for DriftCalculator {
268 fn default() -> Self {
269 Self::new()
270 }
271}
272
273pub struct ConsensusEngine {
275 required_runs: usize,
276 agreement_threshold: f64,
277 drift_calculator: DriftCalculator,
278}
279
280impl ConsensusEngine {
281 pub fn new(required_runs: usize, agreement_threshold: f64) -> Self {
282 Self {
283 required_runs,
284 agreement_threshold: agreement_threshold.clamp(0.0, 1.0),
285 drift_calculator: DriftCalculator::new(),
286 }
287 }
288
289 pub fn run_with_consensus<F, T>(&self, f: F) -> ConsensusResult<T>
291 where
292 F: Fn() -> T,
293 T: Clone + PartialEq + ToString,
294 {
295 let outputs: Vec<T> = (0..self.required_runs).map(|_| f()).collect();
296
297 let output_strings: Vec<String> = outputs.iter().map(|output| output.to_string()).collect();
298
299 let metrics = self.drift_calculator.calculate_drift(&output_strings);
300 let meets_threshold = metrics.agreement_rate >= self.agreement_threshold;
301
302 let consensus = if let Some(consensus_str) = &metrics.consensus_output {
304 outputs
305 .iter()
306 .find(|output| output.to_string() == *consensus_str)
307 .cloned()
308 } else {
309 None
310 };
311
312 ConsensusResult {
313 outputs,
314 consensus,
315 metrics,
316 meets_threshold,
317 }
318 }
319}
320
321#[derive(Debug, Clone)]
322pub struct ConsensusResult<T> {
323 pub outputs: Vec<T>,
324 pub consensus: Option<T>,
325 pub metrics: DriftMetrics,
326 pub meets_threshold: bool,
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Converts string literals into the owned strings `calculate_drift` expects.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds metrics whose only field relevant to `get_status` is the
    /// consistency score.
    fn metrics_with_consistency(score: f64, confidence: ConsensusConfidence) -> DriftMetrics {
        DriftMetrics {
            consistency_score: score,
            agreement_rate: score,
            drift_score: 1.0 - score,
            consensus_output: Some("test".to_string()),
            consensus_confidence: confidence,
            outliers: Vec::new(),
        }
    }

    #[test]
    fn test_drift_calculator_empty_outputs() {
        let metrics = DriftCalculator::new().calculate_drift(&[]);

        assert_eq!(metrics.consistency_score, 1.0);
        assert_eq!(metrics.agreement_rate, 1.0);
        assert_eq!(metrics.drift_score, 0.0);
        assert_eq!(metrics.consensus_output, None);
        assert_eq!(metrics.consensus_confidence, ConsensusConfidence::None);
        assert!(metrics.outliers.is_empty());
    }

    #[test]
    fn test_drift_calculator_single_output() {
        let metrics = DriftCalculator::new().calculate_drift(&owned(&["hello"]));

        assert_eq!(metrics.consistency_score, 1.0);
        assert_eq!(metrics.agreement_rate, 1.0);
        assert_eq!(metrics.drift_score, 0.0);
        assert_eq!(metrics.consensus_output, Some("hello".to_string()));
        assert_eq!(metrics.consensus_confidence, ConsensusConfidence::High);
        assert!(metrics.outliers.is_empty());
    }

    #[test]
    fn test_drift_calculator_identical_outputs() {
        let metrics = DriftCalculator::new().calculate_drift(&owned(&["hello", "hello", "hello"]));

        assert_eq!(metrics.consistency_score, 1.0);
        assert_eq!(metrics.agreement_rate, 1.0);
        assert_eq!(metrics.drift_score, 0.0);
        assert_eq!(metrics.consensus_output, Some("hello".to_string()));
        assert_eq!(metrics.consensus_confidence, ConsensusConfidence::High);
        assert!(metrics.outliers.is_empty());
    }

    #[test]
    fn test_drift_calculator_different_outputs() {
        let metrics =
            DriftCalculator::new().calculate_drift(&owned(&["apple", "orange", "banana"]));

        assert!(metrics.consistency_score < 1.0);
        assert!(metrics.drift_score > 0.0);
        assert!(metrics.consensus_output.is_some());
    }

    #[test]
    fn test_semantic_similarity() {
        let calculator = DriftCalculator::new();

        // Identical text.
        assert_eq!(calculator.semantic_similarity("hello", "hello"), 1.0);

        // A single-character edit keeps the strings close.
        let near = calculator.semantic_similarity("hello", "helo");
        assert!(near > 0.5 && near < 1.0);

        // Unrelated text.
        let far = calculator.semantic_similarity("hello", "xyz");
        assert!(far < 0.5);

        // Numeric strings are compared by relative difference.
        let close_numbers = calculator.semantic_similarity("100", "101");
        assert!(close_numbers > 0.8);

        let distant_numbers = calculator.semantic_similarity("100", "200");
        assert!(distant_numbers < 0.8);
    }

    #[test]
    fn test_drift_status() {
        let calculator = DriftCalculator::new();

        let cases = [
            (0.9, ConsensusConfidence::High, DriftStatus::Stable),
            (0.7, ConsensusConfidence::Medium, DriftStatus::Drifting),
            (0.3, ConsensusConfidence::Low, DriftStatus::Critical),
        ];

        for (score, confidence, expected) in cases {
            let metrics = metrics_with_consistency(score, confidence);
            assert_eq!(calculator.get_status(&metrics), expected);
        }
    }

    #[test]
    fn test_drift_from_outputs() {
        let calculator = DriftCalculator::new();
        let outputs = vec![
            Output::new("result", json!("hello"), "string"),
            Output::new("result", json!("hello"), "string"),
            Output::new("result", json!("hi"), "string"),
        ];

        let metrics = calculator.calculate_drift_from_outputs(&outputs);
        assert!(metrics.consistency_score > 0.5);
        assert!(metrics.consistency_score < 1.0);
    }

    #[test]
    fn test_consensus_engine() {
        let engine = ConsensusEngine::new(5, 0.8);

        let result = engine.run_with_consensus(|| "consistent".to_string());

        assert_eq!(result.outputs.len(), 5);
        assert!(result.meets_threshold);
        assert_eq!(result.consensus, Some("consistent".to_string()));
        assert_eq!(result.metrics.consistency_score, 1.0);
    }

    #[test]
    fn test_outlier_detection() {
        let outputs = owned(&["apple", "apple", "apple", "completely_different_output"]);

        let metrics = DriftCalculator::new().calculate_drift(&outputs);
        assert_eq!(metrics.outliers, vec![3]);
    }

    #[test]
    fn test_numerical_consensus() {
        let metrics = DriftCalculator::new().calculate_drift(&owned(&["100", "101", "99"]));

        assert!(metrics.consistency_score > 0.8);
        assert!(metrics.consensus_output.is_some());
    }

    #[test]
    fn test_threshold_configuration() {
        // "hello" vs "helo" has a similarity of 0.8, which sits between the
        // two thresholds below, so the outcome depends on the configuration.
        let outputs = owned(&["hello", "helo"]);

        let strict = DriftCalculator::with_threshold(0.9).calculate_drift(&outputs);
        assert!(strict.agreement_rate < 1.0);

        let lenient = DriftCalculator::with_threshold(0.75).calculate_drift(&outputs);
        assert_eq!(lenient.agreement_rate, 1.0);

        // The configured value is stored, and out-of-range values are clamped.
        assert_eq!(DriftCalculator::with_threshold(0.9).similarity_threshold(), 0.9);
        assert_eq!(DriftCalculator::with_threshold(1.5).similarity_threshold(), 1.0);
        assert_eq!(DriftCalculator::with_threshold(-0.5).similarity_threshold(), 0.0);
    }
}