ricecoder_learning/
pattern_validator.rs

//! Pattern validation component for verifying pattern correctness.
2use crate::error::Result;
3use crate::models::{Decision, LearnedPattern};
4
/// Validates learned patterns against historical decisions.
///
/// The validator is a stateless, zero-sized type: every method takes its
/// inputs explicitly, so a value can be freely copied or created on demand.
#[derive(Debug, Clone, Copy)]
pub struct PatternValidator;
7
8impl PatternValidator {
9    /// Create a new pattern validator
10    pub fn new() -> Self {
11        Self
12    }
13
14    /// Validate a pattern against historical decisions
15    ///
16    /// Tests the pattern against all historical decisions to verify correctness.
17    /// Returns a validation score (0.0 to 1.0) indicating how well the pattern
18    /// matches the historical decisions.
19    pub fn validate_pattern(
20        &self,
21        pattern: &LearnedPattern,
22        decisions: &[Decision],
23    ) -> Result<ValidationResult> {
24        if decisions.is_empty() {
25            return Ok(ValidationResult {
26                pattern_id: pattern.id.clone(),
27                is_valid: false,
28                validation_score: 0.0,
29                matching_decisions: 0,
30                total_decisions: 0,
31                mismatches: Vec::new(),
32                confidence_recommendation: 0.0,
33            });
34        }
35
36        let mut matching_count = 0;
37        let mut mismatches = Vec::new();
38
39        // Count how many decisions match this pattern
40        for decision in decisions {
41            if decision.decision_type == pattern.pattern_type {
42                let matches = self.decision_matches_pattern(decision, pattern);
43                if matches {
44                    matching_count += 1;
45                } else {
46                    mismatches.push(decision.id.clone());
47                }
48            }
49        }
50
51        // Calculate validation score
52        let validation_score = if !decisions.is_empty() {
53            matching_count as f32 / decisions.len() as f32
54        } else {
55            0.0
56        };
57
58        // Determine if pattern is valid (>= 70% match rate)
59        let is_valid = validation_score >= 0.7;
60
61        // Recommend confidence based on validation score
62        let confidence_recommendation = self.calculate_confidence_recommendation(
63            validation_score,
64            pattern.occurrences,
65            matching_count,
66        );
67
68        Ok(ValidationResult {
69            pattern_id: pattern.id.clone(),
70            is_valid,
71            validation_score,
72            matching_decisions: matching_count,
73            total_decisions: decisions.len(),
74            mismatches,
75            confidence_recommendation,
76        })
77    }
78
79    /// Check if a decision matches a pattern
80    fn decision_matches_pattern(&self, decision: &Decision, pattern: &LearnedPattern) -> bool {
81        // Check if decision type matches
82        if decision.decision_type != pattern.pattern_type {
83            return false;
84        }
85
86        // Check if decision matches any of the pattern examples
87        for example in &pattern.examples {
88            if decision.input == example.input && decision.output == example.output {
89                return true;
90            }
91        }
92
93        false
94    }
95
96    /// Calculate a confidence recommendation based on validation results
97    fn calculate_confidence_recommendation(
98        &self,
99        validation_score: f32,
100        occurrences: usize,
101        matching_count: usize,
102    ) -> f32 {
103        // Base confidence on validation score
104        let validation_factor = validation_score;
105
106        // Increase confidence with more occurrences (up to a point)
107        let occurrence_factor = (occurrences as f32 / 10.0).min(1.0);
108
109        // Increase confidence with more matches
110        let match_factor = (matching_count as f32 / 5.0).min(1.0);
111
112        // Combined confidence recommendation
113        let confidence = (validation_factor * 0.5) + (occurrence_factor * 0.25) + (match_factor * 0.25);
114
115        confidence.min(1.0).max(0.0)
116    }
117
118    /// Validate multiple patterns
119    pub fn validate_patterns(
120        &self,
121        patterns: &[LearnedPattern],
122        decisions: &[Decision],
123    ) -> Result<Vec<ValidationResult>> {
124        let mut results = Vec::new();
125
126        for pattern in patterns {
127            let result = self.validate_pattern(pattern, decisions)?;
128            results.push(result);
129        }
130
131        Ok(results)
132    }
133
134    /// Get validation statistics for a set of patterns
135    pub fn get_validation_statistics(
136        &self,
137        validation_results: &[ValidationResult],
138    ) -> ValidationStatistics {
139        if validation_results.is_empty() {
140            return ValidationStatistics {
141                total_patterns: 0,
142                valid_patterns: 0,
143                invalid_patterns: 0,
144                average_validation_score: 0.0,
145                average_confidence_recommendation: 0.0,
146                total_mismatches: 0,
147            };
148        }
149
150        let valid_count = validation_results.iter().filter(|r| r.is_valid).count();
151        let invalid_count = validation_results.len() - valid_count;
152        let avg_score: f32 = validation_results.iter().map(|r| r.validation_score).sum::<f32>()
153            / validation_results.len() as f32;
154        let avg_confidence: f32 =
155            validation_results.iter().map(|r| r.confidence_recommendation).sum::<f32>()
156                / validation_results.len() as f32;
157        let total_mismatches: usize = validation_results.iter().map(|r| r.mismatches.len()).sum();
158
159        ValidationStatistics {
160            total_patterns: validation_results.len(),
161            valid_patterns: valid_count,
162            invalid_patterns: invalid_count,
163            average_validation_score: avg_score,
164            average_confidence_recommendation: avg_confidence,
165            total_mismatches,
166        }
167    }
168}
169
170impl Default for PatternValidator {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
/// Result of validating a single pattern against a set of decisions.
///
/// Implements `PartialEq` so results can be compared directly (e.g. in
/// assertions). `Eq` is not derived because the score fields are `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct ValidationResult {
    /// ID of the validated pattern
    pub pattern_id: String,
    /// Whether the pattern is valid (>= 70% match rate)
    pub is_valid: bool,
    /// Validation score (0.0 to 1.0)
    pub validation_score: f32,
    /// Number of decisions matching this pattern
    pub matching_decisions: usize,
    /// Total number of decisions analyzed
    pub total_decisions: usize,
    /// IDs of decisions that don't match the pattern
    pub mismatches: Vec<String>,
    /// Recommended confidence score based on validation
    pub confidence_recommendation: f32,
}
194
/// Statistics about pattern validation results.
///
/// `Default` yields the all-zero summary used for an empty result set, and
/// `PartialEq` allows summaries to be compared directly. `Eq` is not derived
/// because the averages are `f32`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ValidationStatistics {
    /// Total number of patterns validated
    pub total_patterns: usize,
    /// Number of valid patterns
    pub valid_patterns: usize,
    /// Number of invalid patterns
    pub invalid_patterns: usize,
    /// Average validation score across all patterns
    pub average_validation_score: f32,
    /// Average confidence recommendation across all patterns
    pub average_confidence_recommendation: f32,
    /// Total number of mismatches across all patterns
    pub total_mismatches: usize,
}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{DecisionContext, PatternExample};
    use std::path::PathBuf;

    /// Builds a decision of the given type with a fixed dummy context.
    fn create_test_decision(
        decision_type: &str,
        input: serde_json::Value,
        output: serde_json::Value,
    ) -> Decision {
        Decision::new(
            DecisionContext {
                project_path: PathBuf::from("/project"),
                file_path: PathBuf::from("/project/src/main.rs"),
                line_number: 10,
                agent_type: "test_agent".to_string(),
            },
            decision_type.to_string(),
            input,
            output,
        )
    }

    /// Builds a decision whose payloads are `{"input": input}` and
    /// `{"output": output}`.
    fn decision_with(decision_type: &str, input: &str, output: &str) -> Decision {
        create_test_decision(
            decision_type,
            serde_json::json!({ "input": input }),
            serde_json::json!({ "output": output }),
        )
    }

    /// Builds a pattern holding a single example with payloads
    /// `{"input": input}` / `{"output": output}` and an empty context.
    fn pattern_with_example(
        pattern_type: &str,
        description: &str,
        input: &str,
        output: &str,
    ) -> LearnedPattern {
        let mut pattern = LearnedPattern::new(pattern_type.to_string(), description.to_string());
        pattern.examples.push(PatternExample {
            input: serde_json::json!({ "input": input }),
            output: serde_json::json!({ "output": output }),
            context: serde_json::json!({}),
        });
        pattern
    }

    #[test]
    fn test_pattern_validator_creation() {
        let validator = PatternValidator::new();
        // The validator carries no state, so it must be a zero-sized type.
        assert_eq!(std::mem::size_of_val(&validator), 0);
    }

    #[test]
    fn test_validate_pattern_empty_decisions() {
        let pattern = LearnedPattern::new("test".to_string(), "Test pattern".to_string());

        let outcome = PatternValidator::new()
            .validate_pattern(&pattern, &[])
            .unwrap();

        assert!(!outcome.is_valid);
        assert_eq!(outcome.validation_score, 0.0);
        assert_eq!(outcome.matching_decisions, 0);
    }

    #[test]
    fn test_validate_pattern_matching() {
        let decisions = [
            decision_with("code_generation", "test", "result"),
            decision_with("code_generation", "test", "result"),
        ];
        let pattern = pattern_with_example("code_generation", "Test pattern", "test", "result");

        let outcome = PatternValidator::new()
            .validate_pattern(&pattern, &decisions)
            .unwrap();

        assert!(outcome.is_valid);
        assert!(outcome.validation_score > 0.0);
        assert_eq!(outcome.matching_decisions, 2);
    }

    #[test]
    fn test_validate_pattern_no_matches() {
        let decisions = [decision_with("code_generation", "test", "result")];
        let pattern =
            pattern_with_example("code_generation", "Test pattern", "different", "different");

        let outcome = PatternValidator::new()
            .validate_pattern(&pattern, &decisions)
            .unwrap();

        assert!(!outcome.is_valid);
        assert_eq!(outcome.validation_score, 0.0);
        assert_eq!(outcome.matching_decisions, 0);
    }

    #[test]
    fn test_validate_pattern_partial_matches() {
        let decisions = [
            decision_with("code_generation", "test", "result"),
            decision_with("code_generation", "different", "different"),
            decision_with("code_generation", "test", "result"),
        ];
        let pattern = pattern_with_example("code_generation", "Test pattern", "test", "result");

        let outcome = PatternValidator::new()
            .validate_pattern(&pattern, &decisions)
            .unwrap();

        // Two of three decisions match (about 66.7%), below the 70% bar.
        assert!(!outcome.is_valid);
        assert!(outcome.validation_score > 0.6);
        assert!(outcome.validation_score < 0.7);
        assert_eq!(outcome.matching_decisions, 2);
    }

    #[test]
    fn test_validate_pattern_high_match_rate() {
        let decisions = [
            decision_with("code_generation", "test", "result"),
            decision_with("code_generation", "test", "result"),
            decision_with("code_generation", "test", "result"),
            decision_with("code_generation", "different", "different"),
        ];
        let pattern = pattern_with_example("code_generation", "Test pattern", "test", "result");

        let outcome = PatternValidator::new()
            .validate_pattern(&pattern, &decisions)
            .unwrap();

        // Three of four decisions match (75%), which clears the 70% bar.
        assert!(outcome.is_valid);
        assert!(outcome.validation_score > 0.7);
        assert_eq!(outcome.matching_decisions, 3);
    }

    #[test]
    fn test_validate_patterns_multiple() {
        let decisions = [
            decision_with("code_generation", "test", "result"),
            decision_with("refactoring", "test", "result"),
        ];
        let patterns = [
            pattern_with_example("code_generation", "Pattern 1", "test", "result"),
            pattern_with_example("refactoring", "Pattern 2", "test", "result"),
        ];

        let outcomes = PatternValidator::new()
            .validate_patterns(&patterns, &decisions)
            .unwrap();

        assert_eq!(outcomes.len(), 2);
    }

    #[test]
    fn test_get_validation_statistics() {
        let inputs = [
            ValidationResult {
                pattern_id: "pattern1".to_string(),
                is_valid: true,
                validation_score: 0.8,
                matching_decisions: 4,
                total_decisions: 5,
                mismatches: vec!["decision1".to_string()],
                confidence_recommendation: 0.75,
            },
            ValidationResult {
                pattern_id: "pattern2".to_string(),
                is_valid: false,
                validation_score: 0.5,
                matching_decisions: 2,
                total_decisions: 4,
                mismatches: vec!["decision2".to_string(), "decision3".to_string()],
                confidence_recommendation: 0.5,
            },
        ];

        let summary = PatternValidator::new().get_validation_statistics(&inputs);

        assert_eq!(summary.total_patterns, 2);
        assert_eq!(summary.valid_patterns, 1);
        assert_eq!(summary.invalid_patterns, 1);
        assert!(summary.average_validation_score > 0.6);
        assert!(summary.average_validation_score < 0.7);
        assert_eq!(summary.total_mismatches, 3);
    }

    #[test]
    fn test_get_validation_statistics_empty() {
        let summary = PatternValidator::new().get_validation_statistics(&[]);

        assert_eq!(summary.total_patterns, 0);
        assert_eq!(summary.valid_patterns, 0);
        assert_eq!(summary.invalid_patterns, 0);
        assert_eq!(summary.average_validation_score, 0.0);
    }

    #[test]
    fn test_confidence_recommendation_calculation() {
        let validator = PatternValidator::new();

        // High validation score with saturated occurrence and match factors.
        let strong = validator.calculate_confidence_recommendation(0.9, 10, 5);
        assert!(strong > 0.7);

        // Low validation score with little supporting evidence.
        let weak = validator.calculate_confidence_recommendation(0.3, 1, 1);
        assert!(weak < 0.5);

        // All-zero inputs produce a zero recommendation.
        let none = validator.calculate_confidence_recommendation(0.0, 0, 0);
        assert_eq!(none, 0.0);
    }

    #[test]
    fn test_validate_pattern_different_type() {
        let decisions = [decision_with("code_generation", "test", "result")];
        let pattern = pattern_with_example("refactoring", "Test pattern", "test", "result");

        let outcome = PatternValidator::new()
            .validate_pattern(&pattern, &decisions)
            .unwrap();

        // The decision type differs from the pattern type, so nothing matches.
        assert!(!outcome.is_valid);
        assert_eq!(outcome.matching_decisions, 0);
    }
}