1use crate::error::{LearningError, Result};
7use crate::intent_tracker::{ArchitecturalDecision, DriftDetection};
8use crate::models::LearnedPattern;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct DriftDetectionConfig {
15 pub confidence_threshold: f32,
17 pub min_occurrences_for_pattern: usize,
19 pub strict_mode: bool,
21}
22
23impl Default for DriftDetectionConfig {
24 fn default() -> Self {
25 Self {
26 confidence_threshold: 0.7,
27 min_occurrences_for_pattern: 3,
28 strict_mode: false,
29 }
30 }
31}
32
/// Tracks learned architectural patterns and records drifts (deviations,
/// inconsistencies, violations) detected against them.
pub struct DriftDetector {
    /// Settings controlling pattern registration and deviation checks.
    config: DriftDetectionConfig,
    /// Registered baseline patterns, keyed by pattern `id`.
    patterns: HashMap<String, LearnedPattern>,
    /// Every drift recorded so far, in detection order.
    drifts: Vec<DriftDetection>,
}
42
43impl DriftDetector {
44 pub fn new() -> Self {
46 Self::with_config(DriftDetectionConfig::default())
47 }
48
49 pub fn with_config(config: DriftDetectionConfig) -> Self {
51 Self {
52 config,
53 patterns: HashMap::new(),
54 drifts: Vec::new(),
55 }
56 }
57
58 pub fn register_pattern(&mut self, pattern: LearnedPattern) -> Result<()> {
60 if pattern.occurrences < self.config.min_occurrences_for_pattern {
61 return Err(LearningError::PatternExtractionFailed(
62 format!(
63 "Pattern must have at least {} occurrences",
64 self.config.min_occurrences_for_pattern
65 ),
66 ));
67 }
68
69 self.patterns.insert(pattern.id.clone(), pattern);
70 Ok(())
71 }
72
73 pub fn check_deviation(
75 &mut self,
76 decision: &ArchitecturalDecision,
77 pattern_type: &str,
78 ) -> Result<Option<DriftDetection>> {
79 let pattern = self
81 .patterns
82 .values()
83 .find(|p| p.pattern_type.contains(pattern_type));
84
85 if let Some(pattern) = pattern {
86 if decision.confidence < pattern.confidence * self.config.confidence_threshold {
88 let drift_type = if self.config.strict_mode {
89 "violation"
90 } else {
91 "deviation"
92 };
93
94 let drift = DriftDetection::new(
95 decision.id.clone(),
96 drift_type.to_string(),
97 "medium".to_string(),
98 format!(
99 "Decision confidence ({:.2}) is below pattern confidence ({:.2})",
100 decision.confidence, pattern.confidence
101 ),
102 );
103
104 self.drifts.push(drift.clone());
105 Ok(Some(drift))
106 } else {
107 Ok(None)
108 }
109 } else {
110 Ok(None)
111 }
112 }
113
114 pub fn detect_inconsistency(
116 &mut self,
117 decision_id: &str,
118 expected_behavior: &str,
119 actual_behavior: &str,
120 ) -> Result<DriftDetection> {
121 if expected_behavior == actual_behavior {
122 return Err(LearningError::PatternExtractionFailed(
123 "No inconsistency detected".to_string(),
124 ));
125 }
126
127 let drift = DriftDetection::new(
128 decision_id.to_string(),
129 "inconsistency".to_string(),
130 "low".to_string(),
131 format!(
132 "Expected: {}, Actual: {}",
133 expected_behavior, actual_behavior
134 ),
135 );
136
137 self.drifts.push(drift.clone());
138 Ok(drift)
139 }
140
141 pub fn detect_violation(
143 &mut self,
144 decision_id: &str,
145 violation_description: &str,
146 ) -> Result<DriftDetection> {
147 let drift = DriftDetection::new(
148 decision_id.to_string(),
149 "violation".to_string(),
150 "high".to_string(),
151 violation_description.to_string(),
152 );
153
154 self.drifts.push(drift.clone());
155 Ok(drift)
156 }
157
158 pub fn get_drifts(&self) -> Vec<DriftDetection> {
160 self.drifts.clone()
161 }
162
163 pub fn get_drifts_by_severity(&self, severity: &str) -> Vec<DriftDetection> {
165 self.drifts
166 .iter()
167 .filter(|d| d.severity == severity)
168 .cloned()
169 .collect()
170 }
171
172 pub fn get_drifts_for_decision(&self, decision_id: &str) -> Vec<DriftDetection> {
174 self.drifts
175 .iter()
176 .filter(|d| d.decision_id == decision_id)
177 .cloned()
178 .collect()
179 }
180
181 pub fn clear_drifts(&mut self) {
183 self.drifts.clear();
184 }
185
186 pub fn get_statistics(&self) -> DriftStatistics {
188 let total_drifts = self.drifts.len();
189 let high_severity = self
190 .drifts
191 .iter()
192 .filter(|d| d.severity == "high")
193 .count();
194 let medium_severity = self
195 .drifts
196 .iter()
197 .filter(|d| d.severity == "medium")
198 .count();
199 let low_severity = self
200 .drifts
201 .iter()
202 .filter(|d| d.severity == "low")
203 .count();
204
205 let violations = self
206 .drifts
207 .iter()
208 .filter(|d| d.drift_type == "violation")
209 .count();
210 let deviations = self
211 .drifts
212 .iter()
213 .filter(|d| d.drift_type == "deviation")
214 .count();
215 let inconsistencies = self
216 .drifts
217 .iter()
218 .filter(|d| d.drift_type == "inconsistency")
219 .count();
220
221 DriftStatistics {
222 total_drifts,
223 high_severity,
224 medium_severity,
225 low_severity,
226 violations,
227 deviations,
228 inconsistencies,
229 }
230 }
231}
232
233impl Default for DriftDetector {
234 fn default() -> Self {
235 Self::new()
236 }
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize)]
241pub struct DriftStatistics {
242 pub total_drifts: usize,
244 pub high_severity: usize,
246 pub medium_severity: usize,
248 pub low_severity: usize,
250 pub violations: usize,
252 pub deviations: usize,
254 pub inconsistencies: usize,
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intent_tracker::ArchitecturalDecision;
    use crate::models::LearnedPattern;

    /// A "layering" pattern with 5 occurrences and confidence 0.9.
    fn layering_pattern() -> LearnedPattern {
        let mut pattern = LearnedPattern::new(
            "layering".to_string(),
            "Layered architecture pattern".to_string(),
        );
        pattern.occurrences = 5;
        pattern.confidence = 0.9;
        pattern
    }

    /// A detector (threshold 0.7, min occurrences 1) with `layering_pattern` registered.
    fn detector_with_layering(strict_mode: bool) -> DriftDetector {
        let mut detector = DriftDetector::with_config(DriftDetectionConfig {
            confidence_threshold: 0.7,
            min_occurrences_for_pattern: 1,
            strict_mode,
        });
        detector
            .register_pattern(layering_pattern())
            .expect("Failed to register pattern");
        detector
    }

    /// A freshly constructed "layering" decision with its default confidence.
    fn layering_decision() -> ArchitecturalDecision {
        ArchitecturalDecision::new(
            "layering".to_string(),
            "Layered architecture".to_string(),
            "Separation of concerns".to_string(),
            "0.1.0".to_string(),
        )
    }

    #[test]
    fn test_drift_detector_creation() {
        let detector = DriftDetector::new();
        assert_eq!(detector.get_drifts().len(), 0);
    }

    #[test]
    fn test_register_pattern() {
        let mut detector = DriftDetector::new();
        assert!(detector.register_pattern(layering_pattern()).is_ok());
    }

    #[test]
    fn test_register_pattern_insufficient_occurrences() {
        let mut detector = DriftDetector::new();
        let pattern = LearnedPattern::new(
            "layering".to_string(),
            "Layered architecture pattern".to_string(),
        );

        assert!(detector.register_pattern(pattern).is_err());
    }

    #[test]
    fn test_detect_inconsistency() {
        let mut detector = DriftDetector::new();
        let drift = detector
            .detect_inconsistency("decision_1", "async_pattern", "sync_pattern")
            .expect("Failed to detect inconsistency");

        assert_eq!(drift.drift_type, "inconsistency");
        assert_eq!(drift.severity, "low");
        assert_eq!(detector.get_drifts().len(), 1);
    }

    #[test]
    fn test_detect_inconsistency_identical_behaviors_is_error() {
        let mut detector = DriftDetector::new();
        assert!(detector
            .detect_inconsistency("decision_1", "same", "same")
            .is_err());
        assert_eq!(detector.get_drifts().len(), 0);
    }

    #[test]
    fn test_detect_violation() {
        let mut detector = DriftDetector::new();
        let drift = detector
            .detect_violation("decision_1", "Direct layer bypass detected")
            .expect("Failed to detect violation");

        assert_eq!(drift.drift_type, "violation");
        assert_eq!(drift.severity, "high");
        assert_eq!(detector.get_drifts().len(), 1);
    }

    #[test]
    fn test_get_drifts_by_severity() {
        let mut detector = DriftDetector::new();
        detector
            .detect_violation("decision_1", "Violation")
            .expect("Failed to detect violation");
        detector
            .detect_inconsistency("decision_2", "expected", "actual")
            .expect("Failed to detect inconsistency");

        let high_severity = detector.get_drifts_by_severity("high");
        let low_severity = detector.get_drifts_by_severity("low");

        assert_eq!(high_severity.len(), 1);
        assert_eq!(low_severity.len(), 1);
    }

    #[test]
    fn test_get_drifts_for_decision() {
        let mut detector = DriftDetector::new();
        detector
            .detect_violation("decision_1", "Violation 1")
            .expect("Failed to detect violation");
        detector
            .detect_violation("decision_1", "Violation 2")
            .expect("Failed to detect violation");
        detector
            .detect_violation("decision_2", "Violation 3")
            .expect("Failed to detect violation");

        let drifts_for_decision_1 = detector.get_drifts_for_decision("decision_1");
        assert_eq!(drifts_for_decision_1.len(), 2);
    }

    #[test]
    fn test_clear_drifts() {
        let mut detector = DriftDetector::new();
        detector
            .detect_violation("decision_1", "Violation")
            .expect("Failed to detect violation");

        assert_eq!(detector.get_drifts().len(), 1);
        detector.clear_drifts();
        assert_eq!(detector.get_drifts().len(), 0);
    }

    #[test]
    fn test_get_statistics() {
        let mut detector = DriftDetector::new();
        detector
            .detect_violation("decision_1", "Violation")
            .expect("Failed to detect violation");
        detector
            .detect_inconsistency("decision_2", "expected", "actual")
            .expect("Failed to detect inconsistency");

        let stats = detector.get_statistics();
        assert_eq!(stats.total_drifts, 2);
        assert_eq!(stats.high_severity, 1);
        assert_eq!(stats.low_severity, 1);
        assert_eq!(stats.violations, 1);
        assert_eq!(stats.inconsistencies, 1);
    }

    #[test]
    fn test_get_statistics_counts_deviations() {
        let mut detector = detector_with_layering(false);
        detector
            .check_deviation(&layering_decision(), "layering")
            .expect("Failed to check deviation")
            .expect("Expected a deviation to be recorded");

        let stats = detector.get_statistics();
        assert_eq!(stats.total_drifts, 1);
        assert_eq!(stats.medium_severity, 1);
        assert_eq!(stats.deviations, 1);
        assert_eq!(stats.violations, 0);
    }

    #[test]
    fn test_check_deviation() {
        let mut detector = detector_with_layering(false);

        let drift = detector
            .check_deviation(&layering_decision(), "layering")
            .expect("Failed to check deviation");

        assert!(drift.is_some());
    }

    #[test]
    fn test_check_deviation_strict_mode_reports_violation() {
        let mut detector = detector_with_layering(true);

        let drift = detector
            .check_deviation(&layering_decision(), "layering")
            .expect("Failed to check deviation")
            .expect("Expected a drift to be recorded");

        assert_eq!(drift.drift_type, "violation");
        assert_eq!(drift.severity, "medium");
        assert_eq!(detector.get_drifts().len(), 1);
    }

    #[test]
    fn test_check_deviation_without_matching_pattern() {
        let mut detector = detector_with_layering(false);

        let drift = detector
            .check_deviation(&layering_decision(), "microservices")
            .expect("Failed to check deviation");

        assert!(drift.is_none());
        assert_eq!(detector.get_drifts().len(), 0);
    }

    #[test]
    fn test_check_deviation_confident_decision_is_not_drift() {
        let mut detector = detector_with_layering(false);
        let mut decision = layering_decision();
        // 0.95 is above the required 0.9 * 0.7 = 0.63.
        decision.confidence = 0.95;

        let drift = detector
            .check_deviation(&decision, "layering")
            .expect("Failed to check deviation");

        assert!(drift.is_none());
        assert_eq!(detector.get_drifts().len(), 0);
    }
}