//! Drift and consensus metrics over repeated outputs.
//!
//! Source path: `briefcase_core/drift.rs`

1use crate::models::Output;
2use std::collections::HashMap;
3use strsim::normalized_levenshtein;
4
5#[cfg(feature = "native")]
6#[allow(unused_imports)]
7use rayon::prelude::*;
8
/// Consistency metrics computed over a set of string outputs.
#[derive(Debug, Clone, PartialEq)]
pub struct DriftMetrics {
    /// Average pairwise similarity, 0.0 - 1.0; higher means more consistent.
    pub consistency_score: f64,
    /// Size of the largest cluster of similar outputs divided by the number of
    /// outputs, i.e. a fraction in 0.0 - 1.0 rather than a 0 - 100 percentage.
    pub agreement_rate: f64,
    /// `1.0 - consistency_score`, 0.0 - 1.0; higher means more drift.
    pub drift_score: f64,
    /// The output chosen as consensus; `None` when there were no outputs.
    pub consensus_output: Option<String>,
    /// Confidence bucket derived from `agreement_rate`.
    pub consensus_confidence: ConsensusConfidence,
    /// Indices (into the input slice) of outputs flagged as outliers.
    pub outliers: Vec<usize>,
}
18
/// Confidence in the consensus output, bucketed by agreement rate.
///
/// The enum is fieldless, so it is `Copy`, `Eq` and `Hash` in addition to the
/// original derives; existing `.clone()` / `==` call sites keep working.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConsensusConfidence {
    /// More than 80% agreement.
    High,
    /// 50% - 80% agreement.
    Medium,
    /// Less than 50% agreement.
    Low,
    /// No consensus possible.
    None,
}
26
/// Coarse drift classification derived from a consistency score.
///
/// The enum is fieldless, so it is `Copy`, `Eq` and `Hash` in addition to the
/// original derives; existing `.clone()` / `==` call sites keep working.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriftStatus {
    /// `consistency_score >= 0.85`.
    Stable,
    /// `consistency_score` in `0.5..0.85`.
    Drifting,
    /// `consistency_score < 0.5`.
    Critical,
}
33
/// Computes drift metrics for a set of outputs.
///
/// Derives `Debug` and `PartialEq` like the other public types in this module
/// so it can be logged and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct DriftCalculator {
    /// Minimum similarity (0.0 - 1.0) for two outputs to count as agreeing.
    similarity_threshold: f64,
}
38
39impl DriftCalculator {
40    pub fn new() -> Self {
41        Self {
42            similarity_threshold: 0.85,
43        }
44    }
45
46    pub fn with_threshold(threshold: f64) -> Self {
47        Self {
48            similarity_threshold: threshold.clamp(0.0, 1.0),
49        }
50    }
51
52    /// Get the similarity threshold
53    pub fn similarity_threshold(&self) -> f64 {
54        self.similarity_threshold
55    }
56
57    /// Calculate drift metrics from a list of string outputs
58    pub fn calculate_drift(&self, outputs: &[String]) -> DriftMetrics {
59        if outputs.is_empty() {
60            return DriftMetrics {
61                consistency_score: 1.0,
62                agreement_rate: 1.0,
63                drift_score: 0.0,
64                consensus_output: None,
65                consensus_confidence: ConsensusConfidence::None,
66                outliers: Vec::new(),
67            };
68        }
69
70        if outputs.len() == 1 {
71            return DriftMetrics {
72                consistency_score: 1.0,
73                agreement_rate: 1.0,
74                drift_score: 0.0,
75                consensus_output: Some(outputs[0].clone()),
76                consensus_confidence: ConsensusConfidence::High,
77                outliers: Vec::new(),
78            };
79        }
80
81        // Calculate pairwise similarities
82        let similarities = self.calculate_pairwise_similarities(outputs);
83
84        // Calculate average similarity (consistency score)
85        let total_pairs = outputs.len() * (outputs.len() - 1) / 2;
86        let avg_similarity = similarities.iter().sum::<f64>() / total_pairs as f64;
87
88        // Calculate agreement rate (exact or near-exact matches)
89        let agreement_rate = self.calculate_agreement_rate(outputs);
90
91        // Drift score is inverse of consistency
92        let drift_score = 1.0 - avg_similarity;
93
94        // Find consensus output
95        let consensus_output = self.find_consensus(outputs);
96
97        // Determine consensus confidence
98        let consensus_confidence = match agreement_rate {
99            rate if rate > 0.8 => ConsensusConfidence::High,
100            rate if rate >= 0.5 => ConsensusConfidence::Medium,
101            rate if rate > 0.0 => ConsensusConfidence::Low,
102            _ => ConsensusConfidence::None,
103        };
104
105        // Find outliers (outputs significantly different from consensus)
106        let outliers = self.find_outliers(outputs, &consensus_output);
107
108        DriftMetrics {
109            consistency_score: avg_similarity,
110            agreement_rate,
111            drift_score,
112            consensus_output,
113            consensus_confidence,
114            outliers,
115        }
116    }
117
118    /// Calculate drift from Output structs (uses value field)
119    pub fn calculate_drift_from_outputs(&self, outputs: &[Output]) -> DriftMetrics {
120        let strings: Vec<String> = outputs
121            .iter()
122            .map(|output| output.value.to_string())
123            .collect();
124
125        self.calculate_drift(&strings)
126    }
127
128    /// Determine drift status from metrics
129    pub fn get_status(&self, metrics: &DriftMetrics) -> DriftStatus {
130        match metrics.consistency_score {
131            score if score >= 0.85 => DriftStatus::Stable,
132            score if score >= 0.5 => DriftStatus::Drifting,
133            _ => DriftStatus::Critical,
134        }
135    }
136
137    /// Calculate semantic similarity between two strings
138    /// Uses Levenshtein distance normalized by length
139    fn semantic_similarity(&self, a: &str, b: &str) -> f64 {
140        if a == b {
141            return 1.0;
142        }
143
144        // Try to parse as numbers for numeric comparison
145        if let (Ok(num_a), Ok(num_b)) = (a.parse::<f64>(), b.parse::<f64>()) {
146            // For numbers, use relative difference
147            let diff = (num_a - num_b).abs();
148            let avg = (num_a.abs() + num_b.abs()) / 2.0;
149            if avg == 0.0 {
150                1.0 // Both are zero
151            } else {
152                (1.0 - (diff / avg)).max(0.0)
153            }
154        } else {
155            // For text, use normalized Levenshtein distance
156            normalized_levenshtein(a, b)
157        }
158    }
159
160    /// Calculate pairwise similarities for all combinations
161    fn calculate_pairwise_similarities(&self, outputs: &[String]) -> Vec<f64> {
162        let mut similarities = Vec::new();
163
164        for i in 0..outputs.len() {
165            for j in (i + 1)..outputs.len() {
166                let sim = self.semantic_similarity(&outputs[i], &outputs[j]);
167                similarities.push(sim);
168            }
169        }
170
171        similarities
172    }
173
174    /// Calculate agreement rate (percentage of outputs that are similar enough)
175    fn calculate_agreement_rate(&self, outputs: &[String]) -> f64 {
176        if outputs.len() <= 1 {
177            return 1.0;
178        }
179
180        // Count unique clusters of similar outputs
181        let mut clusters: Vec<Vec<String>> = Vec::new();
182
183        for output in outputs {
184            let mut found_cluster = false;
185
186            for cluster in &mut clusters {
187                let cluster_repr: &String = cluster.first().unwrap();
188                if self.semantic_similarity(output, cluster_repr) >= self.similarity_threshold {
189                    cluster.push(output.clone());
190                    found_cluster = true;
191                    break;
192                }
193            }
194
195            if !found_cluster {
196                clusters.push(vec![output.clone()]);
197            }
198        }
199
200        // Find the largest cluster
201        let max_cluster_size = clusters.iter().map(|c| c.len()).max().unwrap_or(0);
202        max_cluster_size as f64 / outputs.len() as f64
203    }
204
205    /// Find the consensus output (most common or centroid)
206    fn find_consensus(&self, outputs: &[String]) -> Option<String> {
207        if outputs.is_empty() {
208            return None;
209        }
210
211        // Try frequency-based consensus first
212        let mut frequency_map: HashMap<String, usize> = HashMap::new();
213        for output in outputs {
214            *frequency_map.entry(output.clone()).or_insert(0) += 1;
215        }
216
217        // If there's a clear most frequent output
218        if let Some((most_frequent, count)) = frequency_map.iter().max_by_key(|(_, &count)| count) {
219            if *count > outputs.len() / 2 {
220                return Some(most_frequent.clone());
221            }
222        }
223
224        // Otherwise, find the output with highest average similarity to all others
225        let mut best_output = outputs[0].clone();
226        let mut best_avg_similarity = 0.0;
227
228        for candidate in outputs {
229            let similarities: Vec<f64> = outputs
230                .iter()
231                .map(|other| self.semantic_similarity(candidate, other))
232                .collect();
233
234            let avg_similarity = similarities.iter().sum::<f64>() / similarities.len() as f64;
235
236            if avg_similarity > best_avg_similarity {
237                best_avg_similarity = avg_similarity;
238                best_output = candidate.clone();
239            }
240        }
241
242        Some(best_output)
243    }
244
245    /// Find outliers (outputs significantly different from consensus)
246    fn find_outliers(&self, outputs: &[String], consensus: &Option<String>) -> Vec<usize> {
247        let Some(consensus_output) = consensus else {
248            return Vec::new();
249        };
250
251        outputs
252            .iter()
253            .enumerate()
254            .filter_map(|(i, output)| {
255                let similarity = self.semantic_similarity(output, consensus_output);
256                if similarity < self.similarity_threshold * 0.7 {
257                    // More strict for outliers
258                    Some(i)
259                } else {
260                    None
261                }
262            })
263            .collect()
264    }
265}
266
267impl Default for DriftCalculator {
268    fn default() -> Self {
269        Self::new()
270    }
271}
272
/// Consensus engine for N-of-M agreement.
///
/// Runs a function a fixed number of times and checks whether the agreement
/// rate of the results reaches a threshold.
pub struct ConsensusEngine {
    /// Number of times the function is invoked per consensus run.
    required_runs: usize,
    /// Minimum agreement rate for a run to meet the threshold; clamped to
    /// `0.0..=1.0` by `ConsensusEngine::new`.
    agreement_threshold: f64,
    /// Calculator used to score the collected outputs.
    drift_calculator: DriftCalculator,
}
279
280impl ConsensusEngine {
281    pub fn new(required_runs: usize, agreement_threshold: f64) -> Self {
282        Self {
283            required_runs,
284            agreement_threshold: agreement_threshold.clamp(0.0, 1.0),
285            drift_calculator: DriftCalculator::new(),
286        }
287    }
288
289    /// Run function N times and return consensus result
290    pub fn run_with_consensus<F, T>(&self, f: F) -> ConsensusResult<T>
291    where
292        F: Fn() -> T,
293        T: Clone + PartialEq + ToString,
294    {
295        let outputs: Vec<T> = (0..self.required_runs).map(|_| f()).collect();
296
297        let output_strings: Vec<String> = outputs.iter().map(|output| output.to_string()).collect();
298
299        let metrics = self.drift_calculator.calculate_drift(&output_strings);
300        let meets_threshold = metrics.agreement_rate >= self.agreement_threshold;
301
302        // Find consensus by matching against the consensus string
303        let consensus = if let Some(consensus_str) = &metrics.consensus_output {
304            outputs
305                .iter()
306                .find(|output| output.to_string() == *consensus_str)
307                .cloned()
308        } else {
309            None
310        };
311
312        ConsensusResult {
313            outputs,
314            consensus,
315            metrics,
316            meets_threshold,
317        }
318    }
319}
320
/// Outcome of a consensus run.
#[derive(Debug, Clone)]
pub struct ConsensusResult<T> {
    /// Every value produced by the runs, in execution order.
    pub outputs: Vec<T>,
    /// The output matching the consensus string, if there is one.
    pub consensus: Option<T>,
    /// Drift metrics computed over the string forms of `outputs`.
    pub metrics: DriftMetrics,
    /// Whether the agreement rate reached the engine's agreement threshold.
    pub meets_threshold: bool,
}
328
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Turns string literals into the owned strings the calculator consumes.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Asserts perfect consistency together with the expected consensus and confidence.
    fn assert_no_drift(
        metrics: &DriftMetrics,
        consensus: Option<&str>,
        confidence: ConsensusConfidence,
    ) {
        assert_eq!(metrics.consistency_score, 1.0);
        assert_eq!(metrics.agreement_rate, 1.0);
        assert_eq!(metrics.drift_score, 0.0);
        assert_eq!(metrics.consensus_output.as_deref(), consensus);
        assert_eq!(metrics.consensus_confidence, confidence);
        assert!(metrics.outliers.is_empty());
    }

    /// Builds metrics with the consensus output "test" and no outliers.
    fn metrics_with(score: f64, drift: f64, confidence: ConsensusConfidence) -> DriftMetrics {
        DriftMetrics {
            consistency_score: score,
            agreement_rate: score,
            drift_score: drift,
            consensus_output: Some("test".to_string()),
            consensus_confidence: confidence,
            outliers: Vec::new(),
        }
    }

    #[test]
    fn test_drift_calculator_empty_outputs() {
        let metrics = DriftCalculator::new().calculate_drift(&[]);

        assert_no_drift(&metrics, None, ConsensusConfidence::None);
    }

    #[test]
    fn test_drift_calculator_single_output() {
        let metrics = DriftCalculator::new().calculate_drift(&owned(&["hello"]));

        assert_no_drift(&metrics, Some("hello"), ConsensusConfidence::High);
    }

    #[test]
    fn test_drift_calculator_identical_outputs() {
        let metrics = DriftCalculator::new().calculate_drift(&owned(&["hello", "hello", "hello"]));

        assert_no_drift(&metrics, Some("hello"), ConsensusConfidence::High);
    }

    #[test]
    fn test_drift_calculator_different_outputs() {
        let metrics = DriftCalculator::new().calculate_drift(&owned(&["apple", "orange", "banana"]));

        assert!(metrics.consistency_score < 1.0);
        assert!(metrics.drift_score > 0.0);
        assert!(metrics.consensus_output.is_some());
    }

    #[test]
    fn test_semantic_similarity() {
        let calculator = DriftCalculator::new();

        // Identical strings
        assert_eq!(calculator.semantic_similarity("hello", "hello"), 1.0);

        // Similar strings
        let typo = calculator.semantic_similarity("hello", "helo");
        assert!(typo > 0.5 && typo < 1.0);

        // Completely different strings
        let unrelated = calculator.semantic_similarity("hello", "xyz");
        assert!(unrelated < 0.5);

        // Numbers that are close together
        let close_numbers = calculator.semantic_similarity("100", "101");
        assert!(close_numbers > 0.8);

        // Numbers that are far apart
        let distant_numbers = calculator.semantic_similarity("100", "200");
        assert!(distant_numbers < 0.8);
    }

    #[test]
    fn test_drift_status() {
        let calculator = DriftCalculator::new();

        let cases = [
            (
                metrics_with(0.9, 0.1, ConsensusConfidence::High),
                DriftStatus::Stable,
            ),
            (
                metrics_with(0.7, 0.3, ConsensusConfidence::Medium),
                DriftStatus::Drifting,
            ),
            (
                metrics_with(0.3, 0.7, ConsensusConfidence::Low),
                DriftStatus::Critical,
            ),
        ];

        for (metrics, expected) in &cases {
            assert_eq!(calculator.get_status(metrics), *expected);
        }
    }

    #[test]
    fn test_drift_from_outputs() {
        let outputs: Vec<Output> = vec![
            Output::new("result", json!("hello"), "string"),
            Output::new("result", json!("hello"), "string"),
            Output::new("result", json!("hi"), "string"),
        ];

        let metrics = DriftCalculator::new().calculate_drift_from_outputs(&outputs);

        assert!(metrics.consistency_score > 0.5);
        assert!(metrics.consistency_score < 1.0);
    }

    #[test]
    fn test_consensus_engine() {
        let engine = ConsensusEngine::new(5, 0.8);

        // A function that always returns the same value must reach consensus.
        let result = engine.run_with_consensus(|| "consistent".to_string());

        assert_eq!(result.outputs.len(), 5);
        assert!(result.meets_threshold);
        assert_eq!(result.consensus, Some("consistent".to_string()));
        assert_eq!(result.metrics.consistency_score, 1.0);
    }

    #[test]
    fn test_outlier_detection() {
        let outputs = owned(&["apple", "apple", "apple", "completely_different_output"]);

        let metrics = DriftCalculator::new().calculate_drift(&outputs);

        assert_eq!(metrics.outliers, vec![3]);
    }

    #[test]
    fn test_numerical_consensus() {
        let metrics = DriftCalculator::new().calculate_drift(&owned(&["100", "101", "99"]));

        assert!(metrics.consistency_score > 0.8);
        assert!(metrics.consensus_output.is_some());
    }

    #[test]
    fn test_threshold_configuration() {
        // "helo" is a typo of "hello"; with a stricter threshold the two must
        // no longer count as agreeing.
        let strict = DriftCalculator::with_threshold(0.9);

        let metrics = strict.calculate_drift(&owned(&["hello", "helo"]));

        assert!(metrics.agreement_rate < 1.0);
    }
}