Skip to main content

brainwires_knowledge/knowledge/bks_pks/
inference.rs

1//! Truth inference engine
2//!
3//! Converts patterns and signals into behavioral truths.
4//! Handles deduplication, merging, and confidence scoring.
5
6use super::collector::FailurePattern;
7use super::truth::{BehavioralTruth, TruthCategory, TruthSource};
8use std::collections::HashMap;
9
/// Engine for inferring [`BehavioralTruth`]s from failure patterns, user corrections
/// and explicit teaching.
#[derive(Debug, Clone)]
pub struct TruthInferenceEngine {
    /// Known blocking commands mapped to the flag or alternative that avoids blocking.
    /// Keys are matched case-insensitively as substrings of a failure pattern.
    known_nonblocking_flags: HashMap<String, String>,

    /// Minimum number of occurrences before a failure pattern yields a truth.
    min_occurrences: u32,

    /// Client ID recorded on every created truth, for provenance.
    client_id: Option<String>,
}
21
22impl TruthInferenceEngine {
23    /// Create a new inference engine
24    pub fn new(min_occurrences: u32, client_id: Option<String>) -> Self {
25        let mut known_flags = HashMap::new();
26
27        // Known fixes for common blocking commands
28        known_flags.insert("pm2 logs".to_string(), "--nostream".to_string());
29        known_flags.insert("docker logs".to_string(), "--follow=false".to_string());
30        known_flags.insert("tail -f".to_string(), "tail -n".to_string());
31        known_flags.insert("watch".to_string(), "-n 1 -e".to_string());
32
33        Self {
34            known_nonblocking_flags: known_flags,
35            min_occurrences,
36            client_id,
37        }
38    }
39
40    /// Infer truth from a failure pattern
41    pub fn infer_from_failure(&self, pattern: &FailurePattern) -> Option<BehavioralTruth> {
42        if pattern.occurrences < self.min_occurrences {
43            return None;
44        }
45
46        // Check if we have a known fix
47        if let Some(fix) = self.find_known_fix(&pattern.pattern) {
48            return Some(self.create_command_fix_truth(pattern, &fix));
49        }
50
51        // Check if this looks like a timeout/blocking issue
52        if self.looks_like_blocking(&pattern.error_pattern) {
53            return Some(self.create_blocking_warning_truth(pattern));
54        }
55
56        // Generic failure pattern
57        Some(self.create_generic_failure_truth(pattern))
58    }
59
60    /// Find a known fix for a command pattern
61    fn find_known_fix(&self, pattern: &str) -> Option<String> {
62        for (cmd, fix) in &self.known_nonblocking_flags {
63            if pattern.to_lowercase().contains(&cmd.to_lowercase()) {
64                return Some(fix.clone());
65            }
66        }
67        None
68    }
69
70    /// Check if error looks like blocking/timeout
71    fn looks_like_blocking(&self, error_pattern: &Option<String>) -> bool {
72        if let Some(error) = error_pattern {
73            let error_lower = error.to_lowercase();
74            error_lower.contains("timeout")
75                || error_lower.contains("block")
76                || error_lower.contains("hang")
77                || error_lower.contains("wait")
78                || error_lower.contains("stuck")
79        } else {
80            false
81        }
82    }
83
84    /// Create a truth for a known command fix
85    fn create_command_fix_truth(&self, pattern: &FailurePattern, fix: &str) -> BehavioralTruth {
86        let rule = format!(
87            "Use '{}' flag with '{}' to avoid blocking",
88            fix, pattern.pattern
89        );
90
91        let rationale = format!(
92            "'{}' without '{}' can block indefinitely. Detected {} failures.",
93            pattern.pattern, fix, pattern.occurrences
94        );
95
96        BehavioralTruth::new(
97            TruthCategory::CommandUsage,
98            pattern.pattern.clone(),
99            rule,
100            rationale,
101            TruthSource::FailurePattern,
102            self.client_id.clone(),
103        )
104    }
105
106    /// Create a truth warning about blocking behavior
107    fn create_blocking_warning_truth(&self, pattern: &FailurePattern) -> BehavioralTruth {
108        let rule = format!(
109            "'{}' may block or timeout - consider using a non-blocking alternative or spawning a monitor",
110            pattern.pattern
111        );
112
113        let rationale = format!(
114            "Detected {} timeout/blocking failures with '{}'",
115            pattern.occurrences, pattern.pattern
116        );
117
118        BehavioralTruth::new(
119            TruthCategory::PatternAvoidance,
120            pattern.pattern.clone(),
121            rule,
122            rationale,
123            TruthSource::FailurePattern,
124            self.client_id.clone(),
125        )
126    }
127
128    /// Create a generic failure truth
129    fn create_generic_failure_truth(&self, pattern: &FailurePattern) -> BehavioralTruth {
130        let error_info = pattern
131            .error_pattern
132            .as_ref()
133            .map(|e| format!(" (error: {})", truncate(e, 50)))
134            .unwrap_or_default();
135
136        let rule = format!(
137            "'{}' frequently fails{} - consider alternatives",
138            pattern.pattern, error_info
139        );
140
141        let rationale = format!(
142            "Detected {} failures with '{}' across {} contexts",
143            pattern.occurrences,
144            pattern.pattern,
145            pattern.contexts.len()
146        );
147
148        BehavioralTruth::new(
149            TruthCategory::PatternAvoidance,
150            pattern.pattern.clone(),
151            rule,
152            rationale,
153            TruthSource::FailurePattern,
154            self.client_id.clone(),
155        )
156    }
157
158    /// Infer category from correction context
159    pub fn infer_category_from_correction(
160        &self,
161        context: &str,
162        wrong: &str,
163        right: &str,
164    ) -> TruthCategory {
165        let combined = format!("{} {} {}", context, wrong, right).to_lowercase();
166
167        // Check for specific patterns
168        if combined.contains("spawn") || combined.contains("agent") || combined.contains("monitor")
169        {
170            TruthCategory::TaskStrategy
171        } else if combined.contains("--") || combined.contains("flag") {
172            TruthCategory::CommandUsage
173        } else if combined.contains("error")
174            || combined.contains("fail")
175            || combined.contains("retry")
176        {
177            TruthCategory::ErrorRecovery
178        } else if combined.contains("context")
179            || combined.contains("token")
180            || combined.contains("parallel")
181        {
182            TruthCategory::ResourceManagement
183        } else if combined.contains("don't")
184            || combined.contains("avoid")
185            || combined.contains("never")
186        {
187            TruthCategory::PatternAvoidance
188        } else {
189            TruthCategory::ToolBehavior
190        }
191    }
192
193    /// Create a truth from a correction
194    pub fn create_correction_truth(
195        &self,
196        context: &str,
197        wrong: &str,
198        right: &str,
199    ) -> BehavioralTruth {
200        let category = self.infer_category_from_correction(context, wrong, right);
201
202        let rule = format!(
203            "Instead of '{}', use '{}'",
204            truncate(wrong, 50),
205            truncate(right, 50)
206        );
207
208        let rationale = format!(
209            "User corrected behavior in context: {}",
210            truncate(context, 100)
211        );
212
213        BehavioralTruth::new(
214            category,
215            context.to_string(),
216            rule,
217            rationale,
218            TruthSource::ConversationCorrection,
219            self.client_id.clone(),
220        )
221    }
222
223    /// Create a truth from explicit teaching
224    pub fn create_explicit_truth(
225        &self,
226        rule: &str,
227        rationale: Option<&str>,
228        category: TruthCategory,
229        context: Option<&str>,
230    ) -> BehavioralTruth {
231        let context_pattern = context
232            .map(|c| c.to_string())
233            .unwrap_or_else(|| extract_context_from_rule(rule));
234
235        let rationale = rationale
236            .map(|r| r.to_string())
237            .unwrap_or_else(|| "Explicitly taught by user".to_string());
238
239        BehavioralTruth::new(
240            category,
241            context_pattern,
242            rule.to_string(),
243            rationale,
244            TruthSource::ExplicitCommand,
245            self.client_id.clone(),
246        )
247    }
248
249    /// Check if two truths are similar enough to merge
250    pub fn should_merge(&self, existing: &BehavioralTruth, new: &BehavioralTruth) -> bool {
251        // Same category
252        if existing.category != new.category {
253            return false;
254        }
255
256        // Similar context patterns
257        let context_similarity = jaccard_similarity(
258            &existing.context_pattern.to_lowercase(),
259            &new.context_pattern.to_lowercase(),
260        );
261
262        // Similar rules
263        let rule_similarity =
264            jaccard_similarity(&existing.rule.to_lowercase(), &new.rule.to_lowercase());
265
266        context_similarity > 0.5 && rule_similarity > 0.3
267    }
268
269    /// Merge a new truth into an existing one
270    pub fn merge_truths(&self, existing: &mut BehavioralTruth, new: &BehavioralTruth) {
271        // Combine reinforcements
272        existing.reinforcements += new.reinforcements;
273        existing.contradictions += new.contradictions;
274
275        // Average confidence with bias toward newer
276        existing.confidence = 0.7 * existing.confidence + 0.3 * new.confidence;
277
278        // Update timestamp
279        if new.last_used > existing.last_used {
280            existing.last_used = new.last_used;
281        }
282
283        existing.version += 1;
284    }
285}
286
/// Derive a short context pattern from a free-form rule string.
///
/// Strategies, tried in order:
/// 1. The text between the first pair of single quotes, provided a closing quote exists.
/// 2. The first word that is not a filler word and consists only of alphanumerics, `-`
///    or `_`, joined with the word that follows it (if any).
/// 3. Otherwise, the first three whitespace-separated words.
fn extract_context_from_rule(rule: &str) -> String {
    const FILLER_WORDS: [&str; 10] = [
        "use", "with", "the", "a", "to", "for", "when", "if", "instead", "of",
    ];

    // `'` is one byte, so `open + 1` is always a char boundary.
    let quoted = rule.find('\'').and_then(|open| {
        let after_open = &rule[open + 1..];
        after_open.find('\'').map(|close| &after_open[..close])
    });
    if let Some(inner) = quoted {
        return inner.to_string();
    }

    let words: Vec<&str> = rule.split_whitespace().collect();

    let is_command_like =
        |w: &str| w.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_');

    let first_command = words.iter().enumerate().find(|&(_, &w)| {
        !FILLER_WORDS.contains(&w.to_lowercase().as_str()) && is_command_like(w)
    });

    match first_command {
        Some((idx, &word)) => match words.get(idx + 1) {
            Some(next) => format!("{} {}", word, next),
            None => word.to_string(),
        },
        None => words[..words.len().min(3)].join(" "),
    }
}
324
/// Word-based Jaccard similarity: |A ∩ B| / |A ∪ B| over the sets of
/// whitespace-separated words of `a` and `b`.
///
/// Returns a value in `[0.0, 1.0]`. Two inputs with no words at all count as identical (1.0).
fn jaccard_similarity(a: &str, b: &str) -> f64 {
    let left: std::collections::HashSet<&str> = a.split_whitespace().collect();
    let right: std::collections::HashSet<&str> = b.split_whitespace().collect();

    let union_size = left.union(&right).count();
    if union_size == 0 {
        // Only possible when both sides are empty.
        return 1.0;
    }

    let shared = left.intersection(&right).count();
    shared as f64 / union_size as f64
}
343
/// Truncate `s` to at most `max_len` characters, marking the cut with `...`.
///
/// If `s` has at most `max_len` chars, it is returned unchanged. Otherwise the first
/// `max_len - 3` chars are kept and `...` is appended; when `max_len < 3` this yields just
/// `...`. Lengths are counted in `char`s and the cut always lands on a char boundary. The
/// previous byte-slicing version panicked on multi-byte UTF-8 text and is identical for ASCII.
fn truncate(s: &str, max_len: usize) -> String {
    // `nth(max_len)` is `None` exactly when `s` has at most `max_len` chars.
    if s.char_indices().nth(max_len).is_none() {
        return s.to_string();
    }

    let keep = max_len.saturating_sub(3);
    // Byte offset of the first dropped char; `s` has more than `keep` chars here.
    let cut = s.char_indices().nth(keep).map_or(s.len(), |(idx, _)| idx);
    format!("{}...", &s[..cut])
}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine shared by the tests: threshold of 3 occurrences, no client id.
    fn engine() -> TruthInferenceEngine {
        TruthInferenceEngine::new(3, None)
    }

    /// Build a `BehavioralTruth` without a client id from borrowed strings.
    fn truth(
        category: TruthCategory,
        context: &str,
        rule: &str,
        rationale: &str,
        source: TruthSource,
    ) -> BehavioralTruth {
        BehavioralTruth::new(
            category,
            context.to_string(),
            rule.to_string(),
            rationale.to_string(),
            source,
            None,
        )
    }

    #[test]
    fn test_infer_from_known_failure() {
        let failure = FailurePattern {
            pattern: "pm2 logs myapp".to_string(),
            error_pattern: Some("timeout".to_string()),
            occurrences: 5,
            timestamps: vec![1, 2, 3, 4, 5],
            contexts: vec!["test".to_string()],
        };

        let inferred = engine()
            .infer_from_failure(&failure)
            .expect("5 occurrences exceed the threshold of 3");

        assert!(inferred.rule.contains("--nostream"));
        assert_eq!(inferred.category, TruthCategory::CommandUsage);
    }

    #[test]
    fn test_infer_from_blocking_failure() {
        let failure = FailurePattern {
            pattern: "some-command".to_string(),
            error_pattern: Some("connection timeout after 30s".to_string()),
            occurrences: 3,
            timestamps: vec![1, 2, 3],
            contexts: vec!["test".to_string()],
        };

        let inferred = engine()
            .infer_from_failure(&failure)
            .expect("3 occurrences meet the threshold of 3");

        let mentions_blocking = ["block", "timeout"]
            .iter()
            .any(|&kw| inferred.rule.contains(kw));
        assert!(mentions_blocking);
    }

    #[test]
    fn test_category_inference() {
        let e = engine();
        let cases = [
            (("task", "poll inline", "spawn agent"), TruthCategory::TaskStrategy),
            (("pm2", "logs", "--nostream flag"), TruthCategory::CommandUsage),
        ];

        for ((context, wrong, right), expected) in cases {
            assert_eq!(e.infer_category_from_correction(context, wrong, right), expected);
        }
    }

    #[test]
    fn test_jaccard_similarity() {
        assert_eq!(jaccard_similarity("a b c", "a b c"), 1.0);
        assert_eq!(jaccard_similarity("a b c", "d e f"), 0.0);

        let half = jaccard_similarity("a b c", "a b d");
        assert!((half - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_extract_context() {
        let quoted = extract_context_from_rule("Use '--nostream' with pm2 logs");
        assert_eq!(quoted, "--nostream");

        let unquoted = extract_context_from_rule("cargo build should use cargo-watch");
        assert_eq!(unquoted, "cargo build");
    }

    #[test]
    fn test_should_merge() {
        let e = engine();

        let base = truth(
            TruthCategory::CommandUsage,
            "pm2 logs",
            "Use --nostream flag",
            "Avoids blocking",
            TruthSource::ExplicitCommand,
        );
        let similar = truth(
            TruthCategory::CommandUsage,
            "pm2 logs app",
            "Use --nostream flag to avoid blocking",
            "Different rationale",
            TruthSource::FailurePattern,
        );
        let unrelated = truth(
            TruthCategory::TaskStrategy,
            "something else",
            "Different rule entirely",
            "Different",
            TruthSource::ExplicitCommand,
        );

        assert!(e.should_merge(&base, &similar));
        assert!(!e.should_merge(&base, &unrelated));
    }
}