1use super::collector::FailurePattern;
7use super::truth::{BehavioralTruth, TruthCategory, TruthSource};
8use std::collections::HashMap;
9
/// Turns recurring failures and user corrections into [`BehavioralTruth`] records.
///
/// Holds a small table of commands that have a known non-blocking alternative,
/// the minimum number of occurrences a failure pattern needs before a truth is
/// inferred from it, and an optional client id attached to every truth created.
#[derive(Debug, Clone)]
pub struct TruthInferenceEngine {
    /// Maps a command snippet (e.g. `"pm2 logs"`) to the flag or replacement
    /// invocation suggested in its place.
    known_nonblocking_flags: HashMap<String, String>,

    /// Failure patterns seen fewer times than this are ignored.
    min_occurrences: u32,

    /// Client id passed to every `BehavioralTruth` this engine builds.
    client_id: Option<String>,
}
21
22impl TruthInferenceEngine {
23 pub fn new(min_occurrences: u32, client_id: Option<String>) -> Self {
25 let mut known_flags = HashMap::new();
26
27 known_flags.insert("pm2 logs".to_string(), "--nostream".to_string());
29 known_flags.insert("docker logs".to_string(), "--follow=false".to_string());
30 known_flags.insert("tail -f".to_string(), "tail -n".to_string());
31 known_flags.insert("watch".to_string(), "-n 1 -e".to_string());
32
33 Self {
34 known_nonblocking_flags: known_flags,
35 min_occurrences,
36 client_id,
37 }
38 }
39
40 pub fn infer_from_failure(&self, pattern: &FailurePattern) -> Option<BehavioralTruth> {
42 if pattern.occurrences < self.min_occurrences {
43 return None;
44 }
45
46 if let Some(fix) = self.find_known_fix(&pattern.pattern) {
48 return Some(self.create_command_fix_truth(pattern, &fix));
49 }
50
51 if self.looks_like_blocking(&pattern.error_pattern) {
53 return Some(self.create_blocking_warning_truth(pattern));
54 }
55
56 Some(self.create_generic_failure_truth(pattern))
58 }
59
60 fn find_known_fix(&self, pattern: &str) -> Option<String> {
62 for (cmd, fix) in &self.known_nonblocking_flags {
63 if pattern.to_lowercase().contains(&cmd.to_lowercase()) {
64 return Some(fix.clone());
65 }
66 }
67 None
68 }
69
70 fn looks_like_blocking(&self, error_pattern: &Option<String>) -> bool {
72 if let Some(error) = error_pattern {
73 let error_lower = error.to_lowercase();
74 error_lower.contains("timeout")
75 || error_lower.contains("block")
76 || error_lower.contains("hang")
77 || error_lower.contains("wait")
78 || error_lower.contains("stuck")
79 } else {
80 false
81 }
82 }
83
84 fn create_command_fix_truth(&self, pattern: &FailurePattern, fix: &str) -> BehavioralTruth {
86 let rule = format!(
87 "Use '{}' flag with '{}' to avoid blocking",
88 fix, pattern.pattern
89 );
90
91 let rationale = format!(
92 "'{}' without '{}' can block indefinitely. Detected {} failures.",
93 pattern.pattern, fix, pattern.occurrences
94 );
95
96 BehavioralTruth::new(
97 TruthCategory::CommandUsage,
98 pattern.pattern.clone(),
99 rule,
100 rationale,
101 TruthSource::FailurePattern,
102 self.client_id.clone(),
103 )
104 }
105
106 fn create_blocking_warning_truth(&self, pattern: &FailurePattern) -> BehavioralTruth {
108 let rule = format!(
109 "'{}' may block or timeout - consider using a non-blocking alternative or spawning a monitor",
110 pattern.pattern
111 );
112
113 let rationale = format!(
114 "Detected {} timeout/blocking failures with '{}'",
115 pattern.occurrences, pattern.pattern
116 );
117
118 BehavioralTruth::new(
119 TruthCategory::PatternAvoidance,
120 pattern.pattern.clone(),
121 rule,
122 rationale,
123 TruthSource::FailurePattern,
124 self.client_id.clone(),
125 )
126 }
127
128 fn create_generic_failure_truth(&self, pattern: &FailurePattern) -> BehavioralTruth {
130 let error_info = pattern
131 .error_pattern
132 .as_ref()
133 .map(|e| format!(" (error: {})", truncate(e, 50)))
134 .unwrap_or_default();
135
136 let rule = format!(
137 "'{}' frequently fails{} - consider alternatives",
138 pattern.pattern, error_info
139 );
140
141 let rationale = format!(
142 "Detected {} failures with '{}' across {} contexts",
143 pattern.occurrences,
144 pattern.pattern,
145 pattern.contexts.len()
146 );
147
148 BehavioralTruth::new(
149 TruthCategory::PatternAvoidance,
150 pattern.pattern.clone(),
151 rule,
152 rationale,
153 TruthSource::FailurePattern,
154 self.client_id.clone(),
155 )
156 }
157
158 pub fn infer_category_from_correction(
160 &self,
161 context: &str,
162 wrong: &str,
163 right: &str,
164 ) -> TruthCategory {
165 let combined = format!("{} {} {}", context, wrong, right).to_lowercase();
166
167 if combined.contains("spawn") || combined.contains("agent") || combined.contains("monitor")
169 {
170 TruthCategory::TaskStrategy
171 } else if combined.contains("--") || combined.contains("flag") {
172 TruthCategory::CommandUsage
173 } else if combined.contains("error")
174 || combined.contains("fail")
175 || combined.contains("retry")
176 {
177 TruthCategory::ErrorRecovery
178 } else if combined.contains("context")
179 || combined.contains("token")
180 || combined.contains("parallel")
181 {
182 TruthCategory::ResourceManagement
183 } else if combined.contains("don't")
184 || combined.contains("avoid")
185 || combined.contains("never")
186 {
187 TruthCategory::PatternAvoidance
188 } else {
189 TruthCategory::ToolBehavior
190 }
191 }
192
193 pub fn create_correction_truth(
195 &self,
196 context: &str,
197 wrong: &str,
198 right: &str,
199 ) -> BehavioralTruth {
200 let category = self.infer_category_from_correction(context, wrong, right);
201
202 let rule = format!(
203 "Instead of '{}', use '{}'",
204 truncate(wrong, 50),
205 truncate(right, 50)
206 );
207
208 let rationale = format!(
209 "User corrected behavior in context: {}",
210 truncate(context, 100)
211 );
212
213 BehavioralTruth::new(
214 category,
215 context.to_string(),
216 rule,
217 rationale,
218 TruthSource::ConversationCorrection,
219 self.client_id.clone(),
220 )
221 }
222
223 pub fn create_explicit_truth(
225 &self,
226 rule: &str,
227 rationale: Option<&str>,
228 category: TruthCategory,
229 context: Option<&str>,
230 ) -> BehavioralTruth {
231 let context_pattern = context
232 .map(|c| c.to_string())
233 .unwrap_or_else(|| extract_context_from_rule(rule));
234
235 let rationale = rationale
236 .map(|r| r.to_string())
237 .unwrap_or_else(|| "Explicitly taught by user".to_string());
238
239 BehavioralTruth::new(
240 category,
241 context_pattern,
242 rule.to_string(),
243 rationale,
244 TruthSource::ExplicitCommand,
245 self.client_id.clone(),
246 )
247 }
248
249 pub fn should_merge(&self, existing: &BehavioralTruth, new: &BehavioralTruth) -> bool {
251 if existing.category != new.category {
253 return false;
254 }
255
256 let context_similarity = jaccard_similarity(
258 &existing.context_pattern.to_lowercase(),
259 &new.context_pattern.to_lowercase(),
260 );
261
262 let rule_similarity =
264 jaccard_similarity(&existing.rule.to_lowercase(), &new.rule.to_lowercase());
265
266 context_similarity > 0.5 && rule_similarity > 0.3
267 }
268
269 pub fn merge_truths(&self, existing: &mut BehavioralTruth, new: &BehavioralTruth) {
271 existing.reinforcements += new.reinforcements;
273 existing.contradictions += new.contradictions;
274
275 existing.confidence = 0.7 * existing.confidence + 0.3 * new.confidence;
277
278 if new.last_used > existing.last_used {
280 existing.last_used = new.last_used;
281 }
282
283 existing.version += 1;
284 }
285}
286
/// Derives a short context pattern from a free-form rule.
///
/// It tries each strategy in order:
/// 1. Return the text of the first non-empty `'quoted'` span whose opening
///    quote starts a word. Apostrophes inside words, such as in "don't", are
///    not treated as opening quotes.
/// 2. Return the first word that is not a stopword and consists only of
///    alphanumerics, `-` or `_`, followed by the next word if there is one.
/// 3. Return the first three words joined by single spaces (empty for an empty
///    rule).
fn extract_context_from_rule(rule: &str) -> String {
    if let Some(quoted) = first_quoted_span(rule) {
        return quoted.to_string();
    }

    const STOPWORDS: [&str; 10] = [
        "use", "with", "the", "a", "to", "for", "when", "if", "instead", "of",
    ];

    let words: Vec<&str> = rule.split_whitespace().collect();

    let is_candidate = |word: &str| {
        !STOPWORDS.iter().any(|stop| stop.eq_ignore_ascii_case(word))
            && word
                .chars()
                .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
    };

    if let Some(i) = words.iter().position(|&w| is_candidate(w)) {
        return match words.get(i + 1) {
            Some(next) => format!("{} {}", words[i], next),
            None => words[i].to_string(),
        };
    }

    words.iter().take(3).copied().collect::<Vec<_>>().join(" ")
}

/// Returns the text inside the first non-empty `'...'` span whose opening quote
/// is not preceded by an alphanumeric character, or `None` if there is none.
fn first_quoted_span(rule: &str) -> Option<&str> {
    let mut from = 0;
    while let Some(offset) = rule[from..].find('\'') {
        let open = from + offset;
        // '\'' is a single byte, so the next index is a char boundary.
        let inner_start = open + 1;

        // An apostrophe right after a letter/digit ("don't", "pm2's") is part
        // of a word, not the start of a quotation.
        let prev = rule[..open].chars().next_back();
        let opens_word = !matches!(prev, Some(c) if c.is_alphanumeric());

        if opens_word {
            let Some(len) = rule[inner_start..].find('\'') else {
                return None;
            };
            let inner = &rule[inner_start..inner_start + len];
            if !inner.trim().is_empty() {
                return Some(inner);
            }
            from = inner_start + len + 1;
        } else {
            from = inner_start;
        }
    }
    None
}
324
/// Jaccard similarity of the whitespace-separated word sets of `a` and `b`.
///
/// Returns `|A ∩ B| / |A ∪ B|`, a value in `[0.0, 1.0]`. Words are compared
/// exactly, so callers lowercase beforehand if they want case-insensitivity.
/// Two inputs with no words at all count as identical and score `1.0`.
fn jaccard_similarity(a: &str, b: &str) -> f64 {
    let words_a: std::collections::HashSet<&str> = a.split_whitespace().collect();
    let words_b: std::collections::HashSet<&str> = b.split_whitespace().collect();

    // The union is empty exactly when both sets are empty, so this guard also
    // rules out a zero denominator below.
    if words_a.is_empty() && words_b.is_empty() {
        return 1.0;
    }

    let intersection = words_a.intersection(&words_b).count();
    // Inclusion–exclusion: |A ∪ B| = |A| + |B| - |A ∩ B|.
    let union = words_a.len() + words_b.len() - intersection;
    intersection as f64 / union as f64
}
343
/// Shortens `s` to at most `max_len` bytes, appending `"..."` when it is cut.
///
/// Strings that already fit are returned unchanged. Otherwise the kept prefix
/// is `max_len - 3` bytes (saturating), moved back to the nearest UTF-8
/// character boundary so a multi-byte character is never split. Slicing inside
/// a character would panic. For `max_len < 3` the result is just `"..."`.
fn truncate(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    let mut cut = max_len.saturating_sub(3);
    // Index 0 is always a boundary, so this loop terminates.
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &s[..cut])
}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a client-less `BehavioralTruth` built from string slices.
    fn make_truth(
        category: TruthCategory,
        context: &str,
        rule: &str,
        rationale: &str,
        source: TruthSource,
    ) -> BehavioralTruth {
        BehavioralTruth::new(
            category,
            context.to_string(),
            rule.to_string(),
            rationale.to_string(),
            source,
            None,
        )
    }

    #[test]
    fn test_infer_from_known_failure() {
        let engine = TruthInferenceEngine::new(3, None);
        let failure = FailurePattern {
            pattern: "pm2 logs myapp".to_string(),
            error_pattern: Some("timeout".to_string()),
            occurrences: 5,
            timestamps: vec![1, 2, 3, 4, 5],
            contexts: vec!["test".to_string()],
        };

        let inferred = engine
            .infer_from_failure(&failure)
            .expect("5 occurrences clears the threshold of 3");

        assert!(inferred.rule.contains("--nostream"));
        assert_eq!(inferred.category, TruthCategory::CommandUsage);
    }

    #[test]
    fn test_infer_from_blocking_failure() {
        let engine = TruthInferenceEngine::new(3, None);
        let failure = FailurePattern {
            pattern: "some-command".to_string(),
            error_pattern: Some("connection timeout after 30s".to_string()),
            occurrences: 3,
            timestamps: vec![1, 2, 3],
            contexts: vec!["test".to_string()],
        };

        let inferred = engine
            .infer_from_failure(&failure)
            .expect("3 occurrences meets the threshold of 3");

        let mentions_blocking = ["block", "timeout"]
            .iter()
            .any(|&word| inferred.rule.contains(word));
        assert!(mentions_blocking);
    }

    #[test]
    fn test_category_inference() {
        let engine = TruthInferenceEngine::new(3, None);
        let cases = [
            ("task", "poll inline", "spawn agent", TruthCategory::TaskStrategy),
            ("pm2", "logs", "--nostream flag", TruthCategory::CommandUsage),
        ];

        for (context, wrong, right, expected) in cases {
            assert_eq!(
                engine.infer_category_from_correction(context, wrong, right),
                expected
            );
        }
    }

    #[test]
    fn test_jaccard_similarity() {
        let identical = jaccard_similarity("a b c", "a b c");
        let disjoint = jaccard_similarity("a b c", "d e f");
        let half = jaccard_similarity("a b c", "a b d");

        assert_eq!(identical, 1.0);
        assert_eq!(disjoint, 0.0);
        assert!((half - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_extract_context() {
        let quoted = extract_context_from_rule("Use '--nostream' with pm2 logs");
        assert_eq!(quoted, "--nostream");

        let unquoted = extract_context_from_rule("cargo build should use cargo-watch");
        assert_eq!(unquoted, "cargo build");
    }

    #[test]
    fn test_should_merge() {
        let engine = TruthInferenceEngine::new(3, None);

        let base = make_truth(
            TruthCategory::CommandUsage,
            "pm2 logs",
            "Use --nostream flag",
            "Avoids blocking",
            TruthSource::ExplicitCommand,
        );
        let near_duplicate = make_truth(
            TruthCategory::CommandUsage,
            "pm2 logs app",
            "Use --nostream flag to avoid blocking",
            "Different rationale",
            TruthSource::FailurePattern,
        );
        let unrelated = make_truth(
            TruthCategory::TaskStrategy,
            "something else",
            "Different rule entirely",
            "Different",
            TruthSource::ExplicitCommand,
        );

        assert!(engine.should_merge(&base, &near_duplicate));
        assert!(!engine.should_merge(&base, &unrelated));
    }
}