1mod heuristic;
6mod similarity;
7mod types;
8
9pub use heuristic::Heuristic;
10pub use types::{Pattern, PatternEffectiveness};
11
12use crate::types::TaskContext;
13use chrono::Duration;
14
15impl Pattern {
16 #[must_use]
18 pub fn is_relevant_to(&self, query_context: &TaskContext) -> bool {
19 if let Some(pattern_context) = self.context() {
20 if pattern_context.domain == query_context.domain {
22 return true;
23 }
24
25 if pattern_context.language == query_context.language
27 && pattern_context.language.is_some()
28 {
29 return true;
30 }
31
32 let common_tags: Vec<_> = pattern_context
34 .tags
35 .iter()
36 .filter(|t| query_context.tags.contains(t))
37 .collect();
38
39 if !common_tags.is_empty() {
40 return true;
41 }
42 }
43
44 false
45 }
46
47 #[must_use]
50 pub fn similarity_key(&self) -> String {
51 match self {
52 Pattern::ToolSequence { tools, context, .. } => {
53 format!("tool_seq:{}:{}", tools.join(","), context.domain)
54 }
55 Pattern::DecisionPoint {
56 condition,
57 action,
58 context,
59 ..
60 } => {
61 format!("decision:{}:{}:{}", condition, action, context.domain)
62 }
63 Pattern::ErrorRecovery {
64 error_type,
65 recovery_steps,
66 context,
67 ..
68 } => {
69 format!(
70 "error_recovery:{}:{}:{}",
71 error_type,
72 recovery_steps.join(","),
73 context.domain
74 )
75 }
76 Pattern::ContextPattern {
77 context_features,
78 recommended_approach,
79 ..
80 } => {
81 format!(
82 "context:{}:{}",
83 context_features.join(","),
84 recommended_approach
85 )
86 }
87 }
88 }
89
90 #[must_use]
93 pub fn similarity_score(&self, other: &Self) -> f32 {
94 if std::mem::discriminant(self) != std::mem::discriminant(other) {
96 return 0.0;
97 }
98
99 match (self, other) {
100 (
101 Pattern::ToolSequence {
102 tools: tools1,
103 context: ctx1,
104 ..
105 },
106 Pattern::ToolSequence {
107 tools: tools2,
108 context: ctx2,
109 ..
110 },
111 ) => {
112 let sequence_similarity = similarity::sequence_similarity(tools1, tools2);
113 let context_similarity = similarity::context_similarity(ctx1, ctx2);
114 sequence_similarity * 0.7 + context_similarity * 0.3
116 }
117 (
118 Pattern::DecisionPoint {
119 condition: cond1,
120 action: act1,
121 context: ctx1,
122 ..
123 },
124 Pattern::DecisionPoint {
125 condition: cond2,
126 action: act2,
127 context: ctx2,
128 ..
129 },
130 ) => {
131 let condition_sim = similarity::string_similarity(cond1, cond2);
132 let action_sim = similarity::string_similarity(act1, act2);
133 let context_sim = similarity::context_similarity(ctx1, ctx2);
134 condition_sim * 0.4 + action_sim * 0.4 + context_sim * 0.2
136 }
137 (
138 Pattern::ErrorRecovery {
139 error_type: err1,
140 recovery_steps: steps1,
141 context: ctx1,
142 ..
143 },
144 Pattern::ErrorRecovery {
145 error_type: err2,
146 recovery_steps: steps2,
147 context: ctx2,
148 ..
149 },
150 ) => {
151 let error_sim = similarity::string_similarity(err1, err2);
152 let steps_sim = similarity::sequence_similarity(steps1, steps2);
153 let context_sim = similarity::context_similarity(ctx1, ctx2);
154 error_sim * 0.4 + steps_sim * 0.4 + context_sim * 0.2
156 }
157 (
158 Pattern::ContextPattern {
159 context_features: feat1,
160 recommended_approach: rec1,
161 ..
162 },
163 Pattern::ContextPattern {
164 context_features: feat2,
165 recommended_approach: rec2,
166 ..
167 },
168 ) => {
169 let features_sim = similarity::sequence_similarity(feat1, feat2);
170 let approach_sim = similarity::string_similarity(rec1, rec2);
171 features_sim * 0.6 + approach_sim * 0.4
173 }
174 _ => 0.0,
175 }
176 }
177
178 #[must_use]
181 pub fn confidence(&self) -> f32 {
182 let success_rate = self.success_rate();
183 let sample_size = self.sample_size() as f32;
184
185 if sample_size == 0.0 {
186 return 0.0;
187 }
188
189 success_rate * sample_size.sqrt()
190 }
191
192 pub fn merge_with(&mut self, other: &Self) {
195 if std::mem::discriminant(self) != std::mem::discriminant(other) {
197 return;
198 }
199
200 match (self, other) {
201 (
202 Pattern::ToolSequence {
203 success_rate: sr1,
204 occurrence_count: oc1,
205 avg_latency: lat1,
206 ..
207 },
208 Pattern::ToolSequence {
209 success_rate: sr2,
210 occurrence_count: oc2,
211 avg_latency: lat2,
212 ..
213 },
214 ) => {
215 let total_count = *oc1 + *oc2;
216 *sr1 = (*sr1 * *oc1 as f32 + *sr2 * *oc2 as f32) / total_count as f32;
218 *lat1 = Duration::milliseconds(
220 (lat1.num_milliseconds() * *oc1 as i64 + lat2.num_milliseconds() * *oc2 as i64)
221 / total_count as i64,
222 );
223 *oc1 = total_count;
224 }
225 (
226 Pattern::DecisionPoint {
227 outcome_stats: stats1,
228 ..
229 },
230 Pattern::DecisionPoint {
231 outcome_stats: stats2,
232 ..
233 },
234 ) => {
235 stats1.success_count += stats2.success_count;
236 stats1.failure_count += stats2.failure_count;
237 stats1.total_count += stats2.total_count;
238 stats1.avg_duration_secs = (stats1.avg_duration_secs
240 * (stats1.total_count - stats2.total_count) as f32
241 + stats2.avg_duration_secs * stats2.total_count as f32)
242 / stats1.total_count as f32;
243 }
244 (
245 Pattern::ErrorRecovery {
246 success_rate: sr1, ..
247 },
248 Pattern::ErrorRecovery {
249 success_rate: sr2, ..
250 },
251 ) => {
252 *sr1 = (*sr1 + *sr2) / 2.0;
254 }
257 (
258 Pattern::ContextPattern {
259 evidence: ev1,
260 success_rate: sr1,
261 ..
262 },
263 Pattern::ContextPattern {
264 evidence: ev2,
265 success_rate: sr2,
266 ..
267 },
268 ) => {
269 let size1 = ev1.len();
270 let size2 = ev2.len();
271 ev1.extend_from_slice(ev2);
273 *sr1 = (*sr1 * size1 as f32 + *sr2 * size2 as f32) / (size1 + size2) as f32;
275 }
276 _ => {}
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ComplexityLevel;
    use uuid::Uuid;

    /// Builds a `Pattern::ToolSequence` with a fresh random id and a fresh
    /// `PatternEffectiveness`. The latency argument is in milliseconds.
    ///
    /// This is a macro rather than a function so that `occurrence_count` keeps
    /// whatever integer type the variant declares.
    macro_rules! tool_sequence {
        (
            tools: $tools:expr,
            context: $context:expr,
            rate: $rate:expr,
            latency_ms: $ms:expr,
            count: $count:expr $(,)?
        ) => {
            Pattern::ToolSequence {
                id: Uuid::new_v4(),
                tools: $tools,
                context: $context,
                success_rate: $rate,
                avg_latency: Duration::milliseconds($ms),
                occurrence_count: $count,
                effectiveness: PatternEffectiveness::new(),
            }
        };
    }

    /// Converts string literals into the owned names a pattern stores.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    /// A default `TaskContext` whose only customised field is `domain`.
    fn in_domain(domain: &str) -> TaskContext {
        TaskContext {
            domain: domain.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_pattern_id() {
        let sample = tool_sequence!(
            tools: names(&["tool1", "tool2"]),
            context: TaskContext::default(),
            rate: 0.9,
            latency_ms: 100,
            count: 5,
        );

        assert!(sample.success_rate() > 0.8);
        assert!(sample.context().is_some());
    }

    #[test]
    fn test_pattern_similarity_key() {
        let first = tool_sequence!(
            tools: names(&["read", "write"]),
            context: in_domain("web-api"),
            rate: 0.9,
            latency_ms: 100,
            count: 5,
        );
        let second = tool_sequence!(
            tools: names(&["read", "write"]),
            context: in_domain("web-api"),
            rate: 0.8,
            latency_ms: 120,
            count: 3,
        );

        // Statistics differ, but the descriptive fields are identical.
        assert_eq!(first.similarity_key(), second.similarity_key());
    }

    #[test]
    fn test_pattern_similarity_score() {
        let rust_web_api = || TaskContext {
            language: Some("rust".to_string()),
            ..in_domain("web-api")
        };

        let first = tool_sequence!(
            tools: names(&["read", "write"]),
            context: rust_web_api(),
            rate: 0.9,
            latency_ms: 100,
            count: 5,
        );
        let second = tool_sequence!(
            tools: names(&["read", "write"]),
            context: rust_web_api(),
            rate: 0.8,
            latency_ms: 120,
            count: 3,
        );

        let score = first.similarity_score(&second);
        assert!(score > 0.9);
    }

    #[test]
    fn test_pattern_confidence() {
        let sample = tool_sequence!(
            tools: names(&["tool1"]),
            context: TaskContext::default(),
            rate: 0.8,
            latency_ms: 100,
            count: 16,
        );

        // 0.8 * sqrt(16) = 3.2
        let estimate = sample.confidence();
        assert!((estimate - 3.2).abs() < 0.01);
    }

    #[test]
    fn test_pattern_merge() {
        let mut merged = tool_sequence!(
            tools: names(&["read", "write"]),
            context: TaskContext::default(),
            rate: 0.8,
            latency_ms: 100,
            count: 10,
        );
        let incoming = tool_sequence!(
            tools: names(&["read", "write"]),
            context: TaskContext::default(),
            rate: 0.9,
            latency_ms: 200,
            count: 10,
        );

        merged.merge_with(&incoming);

        if let Pattern::ToolSequence {
            occurrence_count,
            success_rate,
            ..
        } = merged
        {
            assert_eq!(occurrence_count, 20);
            assert!((success_rate - 0.85).abs() < 0.01);
        } else {
            panic!("Expected ToolSequence");
        }
    }

    #[test]
    fn test_pattern_relevance() {
        let pattern = tool_sequence!(
            tools: Vec::new(),
            context: TaskContext {
                language: Some("rust".to_string()),
                framework: None,
                complexity: ComplexityLevel::Moderate,
                domain: "web-api".to_string(),
                tags: names(&["async"]),
            },
            rate: 0.9,
            latency_ms: 100,
            count: 1,
        );

        // Matching domain alone is enough.
        let same_domain = in_domain("web-api");
        assert!(pattern.is_relevant_to(&same_domain));

        // Matching language alone is enough.
        let same_language = TaskContext {
            language: Some("rust".to_string()),
            ..in_domain("cli")
        };
        assert!(pattern.is_relevant_to(&same_language));

        // Nothing in common.
        let unrelated = TaskContext {
            language: Some("python".to_string()),
            ..in_domain("data-science")
        };
        assert!(!pattern.is_relevant_to(&unrelated));
    }

    #[test]
    fn test_heuristic_evidence_update() {
        let mut rule = Heuristic::new(
            "When refactoring async code".to_string(),
            "Use tokio::spawn for CPU-intensive tasks".to_string(),
            0.7,
        );
        assert_eq!(rule.evidence.sample_size, 0);

        // 1 success out of 1.
        rule.update_evidence(Uuid::new_v4(), true);
        assert_eq!(rule.evidence.sample_size, 1);
        assert_eq!(rule.evidence.success_rate, 1.0);

        // 1 success out of 2.
        rule.update_evidence(Uuid::new_v4(), false);
        assert_eq!(rule.evidence.sample_size, 2);
        assert_eq!(rule.evidence.success_rate, 0.5);

        // 2 successes out of 3.
        rule.update_evidence(Uuid::new_v4(), true);
        assert_eq!(rule.evidence.sample_size, 3);
        assert!((rule.evidence.success_rate - 0.666).abs() < 0.01);
    }
}