use crate::error::{LearningError, Result};
use crate::models::{Decision, LearnedPattern, PatternExample};
use std::collections::{BTreeMap, HashMap};
5
/// Mines recurring [`Decision`]s into [`LearnedPattern`]s.
///
/// Decisions are grouped by `decision_type` and then by identical input; a
/// group becomes a pattern when it has at least `min_occurrences` members and
/// its computed confidence is at least `min_confidence`.
#[derive(Debug, Clone)]
pub struct PatternCapturer {
    /// Minimum number of decisions (per type, and per identical input) needed
    /// before a pattern is considered.
    min_occurrences: usize,
    /// Minimum confidence a candidate pattern must reach to be returned.
    min_confidence: f32,
}
13
14impl PatternCapturer {
15 pub fn new() -> Self {
17 Self {
18 min_occurrences: 2,
19 min_confidence: 0.5,
20 }
21 }
22
23 pub fn with_settings(min_occurrences: usize, min_confidence: f32) -> Self {
25 Self {
26 min_occurrences,
27 min_confidence,
28 }
29 }
30
31 pub fn extract_patterns(&self, decisions: &[Decision]) -> Result<Vec<LearnedPattern>> {
33 if decisions.is_empty() {
34 return Ok(Vec::new());
35 }
36
37 let mut decisions_by_type: HashMap<String, Vec<&Decision>> = HashMap::new();
39 for decision in decisions {
40 decisions_by_type
41 .entry(decision.decision_type.clone())
42 .or_insert_with(Vec::new)
43 .push(decision);
44 }
45
46 let mut patterns = Vec::new();
47
48 for (decision_type, type_decisions) in decisions_by_type {
50 if type_decisions.len() >= self.min_occurrences {
51 let extracted = self.extract_patterns_for_type(&decision_type, &type_decisions)?;
53 patterns.extend(extracted);
54 }
55 }
56
57 Ok(patterns)
58 }
59
60 fn extract_patterns_for_type(
62 &self,
63 decision_type: &str,
64 decisions: &[&Decision],
65 ) -> Result<Vec<LearnedPattern>> {
66 let mut patterns = Vec::new();
67
68 let mut input_groups: HashMap<String, Vec<&Decision>> = HashMap::new();
70 for decision in decisions {
71 let input_key = serde_json::to_string(&decision.input)
73 .unwrap_or_else(|_| "unknown".to_string());
74 input_groups
75 .entry(input_key)
76 .or_insert_with(Vec::new)
77 .push(decision);
78 }
79
80 for (_input_key, group) in input_groups {
82 if group.len() >= self.min_occurrences {
83 let pattern = self.create_pattern_from_group(decision_type, &group)?;
84 if pattern.confidence >= self.min_confidence {
85 patterns.push(pattern);
86 }
87 }
88 }
89
90 Ok(patterns)
91 }
92
93 fn create_pattern_from_group(
95 &self,
96 decision_type: &str,
97 decisions: &[&Decision],
98 ) -> Result<LearnedPattern> {
99 if decisions.is_empty() {
100 return Err(LearningError::PatternExtractionFailed(
101 "Cannot create pattern from empty group".to_string(),
102 ));
103 }
104
105 let pattern_content = format!(
107 "{}:{}",
108 decision_type,
109 serde_json::to_string(&decisions[0].input).unwrap_or_default()
110 );
111 let pattern_id = format!(
112 "{:x}",
113 md5::compute(pattern_content.as_bytes())
114 );
115
116 let mut pattern = LearnedPattern {
117 id: pattern_id,
118 pattern_type: decision_type.to_string(),
119 description: format!("Pattern for {}", decision_type),
120 examples: Vec::new(),
121 confidence: 0.0,
122 occurrences: 0,
123 created_at: chrono::Utc::now(),
124 last_seen: chrono::Utc::now(),
125 };
126
127 for decision in decisions {
129 let example = PatternExample {
130 input: decision.input.clone(),
131 output: decision.output.clone(),
132 context: serde_json::json!({
133 "agent_type": decision.context.agent_type,
134 "file_path": decision.context.file_path.to_string_lossy(),
135 "line_number": decision.context.line_number,
136 }),
137 };
138 pattern.examples.push(example);
139 }
140
141 pattern.occurrences = decisions.len();
143 pattern.confidence = self.calculate_confidence(decisions)?;
144 pattern.last_seen = decisions.last().unwrap().timestamp;
145
146 Ok(pattern)
147 }
148
149 fn calculate_confidence(&self, decisions: &[&Decision]) -> Result<f32> {
151 if decisions.is_empty() {
152 return Ok(0.0);
153 }
154
155 let output_consistency = if decisions.len() > 1 {
157 let mut output_counts: HashMap<String, usize> = HashMap::new();
159 for decision in decisions {
160 let output_str = serde_json::to_string(&decision.output)
161 .unwrap_or_else(|_| "unknown".to_string());
162 *output_counts.entry(output_str).or_insert(0) += 1;
163 }
164
165 let max_count = output_counts.values().max().copied().unwrap_or(0);
167 max_count as f32 / decisions.len() as f32
168 } else {
169 0.5 };
171
172 let occurrence_factor = (decisions.len() as f32 / 10.0).min(1.0);
174
175 let confidence = (output_consistency * 0.7) + (occurrence_factor * 0.3);
177
178 Ok(confidence.min(1.0).max(0.0))
179 }
180
181 fn outputs_are_similar(&self, output1: &serde_json::Value, output2: &serde_json::Value) -> bool {
183 output1 == output2
186 }
187
188 fn compute_input_hash(&self, input: &serde_json::Value) -> String {
190 format!("{:?}", input)
193 }
194
195 pub fn validate_pattern(
197 &self,
198 pattern: &LearnedPattern,
199 decisions: &[Decision],
200 ) -> Result<f32> {
201 if decisions.is_empty() {
202 return Ok(0.0);
203 }
204
205 let mut matching_count = 0;
206
207 for decision in decisions {
208 if decision.decision_type == pattern.pattern_type {
209 for example in &pattern.examples {
211 if self.decision_matches_example(decision, example) {
212 matching_count += 1;
213 break;
214 }
215 }
216 }
217 }
218
219 let validation_score = matching_count as f32 / decisions.len() as f32;
220 Ok(validation_score)
221 }
222
223 fn decision_matches_example(&self, decision: &Decision, example: &PatternExample) -> bool {
225 decision.input == example.input && decision.output == example.output
227 }
228
229 pub fn update_confidence(
231 &self,
232 pattern: &mut LearnedPattern,
233 validation_score: f32,
234 ) -> Result<()> {
235 let alpha = 0.3; pattern.confidence = (alpha * validation_score) + ((1.0 - alpha) * pattern.confidence);
238
239 Ok(())
240 }
241
242 pub fn extract_patterns_with_analysis(
244 &self,
245 decisions: &[Decision],
246 ) -> Result<Vec<(LearnedPattern, PatternAnalysis)>> {
247 let patterns = self.extract_patterns(decisions)?;
248
249 let mut results = Vec::new();
250 for pattern in patterns {
251 let analysis = self.analyze_pattern(&pattern, decisions)?;
252 results.push((pattern, analysis));
253 }
254
255 Ok(results)
256 }
257
258 fn analyze_pattern(
260 &self,
261 pattern: &LearnedPattern,
262 decisions: &[Decision],
263 ) -> Result<PatternAnalysis> {
264 let validation_score = self.validate_pattern(pattern, decisions)?;
265
266 let mut matching_decisions = 0;
267 for decision in decisions {
268 if decision.decision_type == pattern.pattern_type {
269 for example in &pattern.examples {
270 if self.decision_matches_example(decision, example) {
271 matching_decisions += 1;
272 break;
273 }
274 }
275 }
276 }
277
278 Ok(PatternAnalysis {
279 pattern_id: pattern.id.clone(),
280 validation_score,
281 matching_decisions,
282 total_decisions: decisions.len(),
283 confidence: pattern.confidence,
284 occurrences: pattern.occurrences,
285 })
286 }
287}
288
289impl Default for PatternCapturer {
290 fn default() -> Self {
291 Self::new()
292 }
293}
294
/// Summary of how well a learned pattern explains a set of decisions.
#[derive(Debug, Clone, PartialEq)]
pub struct PatternAnalysis {
    /// Id of the analysed pattern.
    pub pattern_id: String,
    /// Fraction of the checked decisions that matched the pattern
    /// (`matching_decisions / total_decisions`, or 0.0 when there were none).
    pub validation_score: f32,
    /// Decisions of the pattern's type whose input and output equal one of
    /// the pattern's examples.
    pub matching_decisions: usize,
    /// Number of decisions the pattern was checked against, of any type.
    pub total_decisions: usize,
    /// The pattern's confidence at the time of analysis.
    pub confidence: f32,
    /// Number of decisions the pattern was built from.
    pub occurrences: usize,
}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::DecisionContext;
    use std::path::PathBuf;

    /// Builds a decision recorded by `test_agent` at `/project/src/main.rs:10`.
    fn make_decision(
        kind: &str,
        input: serde_json::Value,
        output: serde_json::Value,
    ) -> Decision {
        let ctx = DecisionContext {
            project_path: PathBuf::from("/project"),
            file_path: PathBuf::from("/project/src/main.rs"),
            line_number: 10,
            agent_type: "test_agent".to_string(),
        };

        Decision::new(ctx, kind.to_string(), input, output)
    }

    /// The standard `{"input": "test"}` -> `{"output": "result"}` decision.
    fn sample(kind: &str) -> Decision {
        make_decision(
            kind,
            serde_json::json!({"input": "test"}),
            serde_json::json!({"output": "result"}),
        )
    }

    /// Two independently built, identical `code_generation` decisions.
    fn twin_decisions() -> [Decision; 2] {
        [sample("code_generation"), sample("code_generation")]
    }

    #[test]
    fn test_pattern_capturer_creation() {
        let defaults = PatternCapturer::new();
        assert_eq!(defaults.min_occurrences, 2);
        assert_eq!(defaults.min_confidence, 0.5);
    }

    #[test]
    fn test_pattern_capturer_with_settings() {
        let tuned = PatternCapturer::with_settings(3, 0.7);
        assert_eq!(tuned.min_occurrences, 3);
        assert_eq!(tuned.min_confidence, 0.7);
    }

    #[test]
    fn test_extract_patterns_empty() {
        let found = PatternCapturer::new().extract_patterns(&[]).unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn test_extract_patterns_single_decision() {
        let capturer = PatternCapturer::new();
        let lone = sample("code_generation");

        let found = capturer.extract_patterns(&[lone]).unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn test_extract_patterns_multiple_decisions() {
        let capturer = PatternCapturer::new();

        let found = capturer.extract_patterns(&twin_decisions()).unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].pattern_type, "code_generation");
        assert_eq!(found[0].occurrences, 2);
    }

    #[test]
    fn test_extract_patterns_different_types() {
        let capturer = PatternCapturer::new();
        let [first, second] = twin_decisions();
        let outlier = sample("refactoring");

        let found = capturer
            .extract_patterns(&[first, second, outlier])
            .unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].pattern_type, "code_generation");
    }

    #[test]
    fn test_calculate_confidence() {
        let capturer = PatternCapturer::new();
        let [first, second] = twin_decisions();

        let score = capturer.calculate_confidence(&[&first, &second]).unwrap();
        assert!(score > 0.0);
        assert!(score <= 1.0);
    }

    #[test]
    fn test_validate_pattern() {
        let capturer = PatternCapturer::new();
        let [first, second] = twin_decisions();

        let found = capturer
            .extract_patterns(&[first.clone(), second])
            .unwrap();
        assert_eq!(found.len(), 1);

        let score = capturer.validate_pattern(&found[0], &[first]).unwrap();
        assert!(score >= 0.0);
        assert!(score <= 1.0);
    }

    #[test]
    fn test_update_confidence() {
        let capturer = PatternCapturer::new();
        let mut pattern = LearnedPattern::new(
            "code_generation".to_string(),
            "Test pattern".to_string(),
        );

        let before = pattern.confidence;
        capturer.update_confidence(&mut pattern, 0.9).unwrap();

        assert_ne!(pattern.confidence, before);
        assert!(pattern.confidence > before);
    }

    #[test]
    fn test_extract_patterns_with_analysis() {
        let capturer = PatternCapturer::new();

        let results = capturer
            .extract_patterns_with_analysis(&twin_decisions())
            .unwrap();

        assert_eq!(results.len(), 1);
        let (pattern, analysis) = &results[0];
        assert_eq!(pattern.pattern_type, "code_generation");
        assert!(analysis.validation_score >= 0.0);
        assert!(analysis.validation_score <= 1.0);
    }

    #[test]
    fn test_pattern_examples() {
        let capturer = PatternCapturer::new();

        let found = capturer.extract_patterns(&twin_decisions()).unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].examples.len(), 2);
    }
}
531}