Skip to main content

lean_ctx/core/
active_inference.rs

1//! Active Inference Preload — predictive context loading.
2//!
3//! Uses the agent's task description and recent interactions to predict
4//! which providers and resources will be needed next, then preloads them
5//! into the session cache before the agent asks.
6//!
7//! Scientific basis: Active Inference (Friston 2010; Parr, Pezzulo & Friston 2022).
8//! The system acts to reduce expected surprise by preloading context that
9//! minimizes the predicted free energy of future queries.
10//!
11//! Strategy:
12//!   1. Parse task keywords → predict relevant provider actions
13//!   2. Score predictions using the provider bandit
14//!   3. Preload top-k predictions into session cache
15
16use crate::core::provider_bandit::ProviderBandit;
17
/// A predicted preload action.
///
/// Produced by [`predict_preloads`]; each value names one provider action
/// that is worth loading ahead of time, together with how strongly the task
/// text and the bandit support that choice.
#[derive(Debug, Clone, PartialEq)]
pub struct PreloadPrediction {
    /// Identifier of the provider to preload from (e.g. `"github"`).
    pub provider_id: String,
    /// Provider action to preload (e.g. `"issues"`, `"schemas"`).
    pub action: String,
    /// Combined score: `0.6 * keyword match ratio + 0.4 * bandit estimate`.
    pub confidence: f64,
    /// Human-readable justification, formatted as `"keywords: a, b, c"`.
    pub reason: String,
}
26
/// Keywords that signal defect-related work; shared by every issue tracker
/// entry in [`KEYWORD_MAPPINGS`].
const DEFECT_KEYWORDS: &[&str] = &["bug", "error", "crash", "fix", "broken", "issue", "defect"];

/// Keyword → provider action mappings.
///
/// Each entry is `(keywords, provider id, action)`. An entry contributes a
/// prediction when at least one of its keywords occurs in the task text.
/// Note that `("github", "issues")` is reachable through two entries (defect
/// keywords and documentation keywords).
static KEYWORD_MAPPINGS: &[(&[&str], &str, &str)] = &[
    // Defect work → issue trackers.
    (DEFECT_KEYWORDS, "github", "issues"),
    (DEFECT_KEYWORDS, "jira", "issues"),
    // Code review work → pull requests.
    (
        &["pr", "pull", "merge", "review", "branch"],
        "github",
        "pull_requests",
    ),
    // Data-model work → database schemas.
    (
        &["database", "table", "schema", "column", "migration", "sql", "db"],
        "postgres",
        "schemas",
    ),
    // Planning work → sprint data.
    (
        &["sprint", "story", "epic", "velocity", "backlog"],
        "jira",
        "sprints",
    ),
    // Documentation work → GitHub issues.
    (
        &["wiki", "doc", "documentation", "guide", "howto"],
        "github",
        "issues",
    ),
];
68
69/// Predict which provider actions should be preloaded based on the task.
70pub fn predict_preloads(
71    task_description: &str,
72    available_providers: &[String],
73    bandit: &mut ProviderBandit,
74    max_predictions: usize,
75) -> Vec<PreloadPrediction> {
76    let task_lower = task_description.to_lowercase();
77    let task_words: Vec<&str> = task_lower.split_whitespace().collect();
78
79    let mut predictions: Vec<PreloadPrediction> = Vec::new();
80
81    for &(keywords, provider, action) in KEYWORD_MAPPINGS {
82        if !available_providers.iter().any(|p| p == provider) {
83            continue;
84        }
85
86        let matching_keywords: Vec<&&str> = keywords
87            .iter()
88            .filter(|kw| task_words.iter().any(|tw| tw.contains(*kw)))
89            .collect();
90
91        if matching_keywords.is_empty() {
92            continue;
93        }
94
95        let keyword_confidence = matching_keywords.len() as f64 / keywords.len() as f64;
96
97        let task_type = infer_task_type(&task_lower);
98        let bandit_score = bandit.estimated_probability(&task_type, provider);
99
100        let combined = 0.6 * keyword_confidence + 0.4 * bandit_score;
101
102        if !predictions
103            .iter()
104            .any(|p| p.provider_id == provider && p.action == action)
105        {
106            predictions.push(PreloadPrediction {
107                provider_id: provider.to_string(),
108                action: action.to_string(),
109                confidence: combined,
110                reason: format!(
111                    "keywords: {}",
112                    matching_keywords
113                        .iter()
114                        .map(|k| **k)
115                        .collect::<Vec<_>>()
116                        .join(", ")
117                ),
118            });
119        }
120    }
121
122    predictions.sort_by(|a, b| {
123        b.confidence
124            .partial_cmp(&a.confidence)
125            .unwrap_or(std::cmp::Ordering::Equal)
126    });
127    predictions.truncate(max_predictions);
128    predictions
129}
130
/// Classify a task description into a coarse task type.
///
/// Returns `"bugfix"`, `"feature"`, `"refactor"` or, when nothing matches,
/// `"general"`. Categories are checked in that order and the first one with
/// any marker present wins. Matching is case-sensitive substring containment
/// on the whole text, so `"address"` counts as containing `"add"`.
fn infer_task_type(task: &str) -> String {
    // Ordered by precedence: earlier rows win when several would match.
    const RULES: &[(&str, &[&str])] = &[
        ("bugfix", &["bug", "fix", "error", "crash"]),
        ("feature", &["feature", "add", "implement"]),
        ("refactor", &["refactor", "clean", "improve"]),
    ];

    RULES
        .iter()
        .find(|(_, markers)| markers.iter().any(|marker| task.contains(*marker)))
        .map_or("general", |(label, _)| *label)
        .to_string()
}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Defect keywords with GitHub available should yield a github/issues prediction.
    #[test]
    fn predict_bug_fix_suggests_issues() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["github".into(), "jira".into()];

        let predictions = predict_preloads(
            "Fix the authentication bug in the login flow",
            &providers,
            &mut bandit,
            5,
        );

        assert!(!predictions.is_empty());
        assert!(predictions
            .iter()
            .any(|p| p.provider_id == "github" && p.action == "issues"));
    }

    /// Database keywords with Postgres available should yield postgres/schemas.
    #[test]
    fn predict_db_task_suggests_schemas() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["postgres".into()];

        let predictions = predict_preloads(
            "Add a new column to the users database table",
            &providers,
            &mut bandit,
            5,
        );

        assert!(predictions
            .iter()
            .any(|p| p.provider_id == "postgres" && p.action == "schemas"));
    }

    /// Review keywords with GitHub available should yield github/pull_requests.
    #[test]
    fn predict_pr_review_suggests_pull_requests() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["github".into()];

        let predictions = predict_preloads(
            "Review the open pull requests and merge the approved ones",
            &providers,
            &mut bandit,
            5,
        );

        assert!(predictions
            .iter()
            .any(|p| p.provider_id == "github" && p.action == "pull_requests"));
    }

    /// An empty task has no words, so no keyword can match.
    #[test]
    fn predict_empty_task_returns_empty() {
        let mut bandit = ProviderBandit::new();
        let predictions = predict_preloads("", &["github".into()], &mut bandit, 5);
        assert!(predictions.is_empty());
    }

    /// Mappings for providers that are not available must never be predicted.
    #[test]
    fn predict_unavailable_provider_skipped() {
        let mut bandit = ProviderBandit::new();
        let predictions = predict_preloads(
            "Fix the database schema migration",
            &["github".into()], // postgres not available
            &mut bandit,
            5,
        );

        assert!(!predictions.iter().any(|p| p.provider_id == "postgres"));
    }

    /// The result is truncated to `max_predictions`.
    #[test]
    fn predict_respects_max_predictions() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["github".into(), "jira".into(), "postgres".into()];

        let predictions = predict_preloads(
            "Fix the bug in database schema and review pull requests",
            &providers,
            &mut bandit,
            2,
        );

        // The task matches four distinct (provider, action) pairs, so the
        // limit must actually bite; `<= 2` would also accept an empty result.
        assert_eq!(predictions.len(), 2);
    }

    /// A bandit trained in favour of GitHub should rank it above Jira.
    #[test]
    fn predict_bandit_trained_boosts_confidence() {
        let mut bandit = ProviderBandit::new();
        for _ in 0..20 {
            bandit.update("bugfix", "github", true);
            bandit.update("bugfix", "jira", false);
        }

        let providers = vec!["github".into(), "jira".into()];
        let predictions = predict_preloads(
            "Fix the crash bug in authentication",
            &providers,
            &mut bandit,
            5,
        );

        // Both predictions must exist; otherwise the comparison below would
        // be skipped and the test would pass without checking anything.
        let gh = predictions
            .iter()
            .find(|p| p.provider_id == "github" && p.action == "issues")
            .expect("github/issues should be predicted for a bugfix task");
        let jira = predictions
            .iter()
            .find(|p| p.provider_id == "jira" && p.action == "issues")
            .expect("jira/issues should be predicted for a bugfix task");

        assert!(
            gh.confidence > jira.confidence,
            "Trained bandit should boost github over jira"
        );
    }

    /// Each task type is recognised from its characteristic keywords.
    #[test]
    fn infer_task_type_correctness() {
        assert_eq!(infer_task_type("fix the crash bug"), "bugfix");
        assert_eq!(infer_task_type("add new feature"), "feature");
        assert_eq!(infer_task_type("refactor the auth module"), "refactor");
        assert_eq!(infer_task_type("update documentation"), "general");
    }
}