Skip to main content

lean_ctx/core/
active_inference.rs

1//! Active Inference Preload — predictive context loading.
2//!
3//! Uses the agent's task description and recent interactions to predict
4//! which providers and resources will be needed next, then preloads them
5//! into the session cache before the agent asks.
6//!
7//! Scientific basis: Active Inference (Friston 2010; Parr, Pezzulo & Friston 2022).
8//! The system acts to reduce expected surprise by preloading context that
9//! minimizes the predicted free energy of future queries.
10//!
11//! Strategy:
12//!   1. Parse task keywords → predict relevant provider actions
13//!   2. Score predictions using the provider bandit
14//!   3. Preload top-k predictions into session cache
15
16use crate::core::provider_bandit::ProviderBandit;
17
/// A predicted preload action.
///
/// Describes one provider/action pair that `predict_preloads` proposes to
/// load ahead of time, together with the score and explanation behind it.
#[derive(Debug, Clone, PartialEq)]
pub struct PreloadPrediction {
    /// Identifier of the provider to preload from (e.g. `"github"`).
    pub provider_id: String,
    /// Provider action to preload (e.g. `"issues"`).
    pub action: String,
    /// Combined keyword/bandit score; higher means a stronger prediction.
    pub confidence: f64,
    /// Human-readable explanation of why this action was predicted.
    pub reason: String,
}
26
/// Keywords that signal defect-related work; shared by every issue-tracker row.
const BUG_KEYWORDS: &[&str] = &["bug", "error", "crash", "fix", "broken", "issue", "defect"];

/// Keyword → provider action mappings.
///
/// Each row is `(keywords, provider id, action)`: when any of the keywords is
/// found in the task, the given action of the given provider becomes a
/// preload candidate. Rows are evaluated in the order listed here.
static KEYWORD_MAPPINGS: &[(&[&str], &str, &str)] = &[
    // Defect work → issue trackers.
    (BUG_KEYWORDS, "github", "issues"),
    (BUG_KEYWORDS, "jira", "issues"),
    // Code review work → pull requests.
    (
        &["pr", "pull", "merge", "review", "branch"],
        "github",
        "pull_requests",
    ),
    // Data-model work → database schemas.
    (
        &["database", "table", "schema", "column", "migration", "sql", "db"],
        "postgres",
        "schemas",
    ),
    // Planning work → sprints.
    (
        &["sprint", "story", "epic", "velocity", "backlog"],
        "jira",
        "sprints",
    ),
    // Documentation work is also routed to GitHub issues.
    (
        &["wiki", "doc", "documentation", "guide", "howto"],
        "github",
        "issues",
    ),
];
68
69/// Predict which provider actions should be preloaded based on the task.
70pub fn predict_preloads(
71    task_description: &str,
72    available_providers: &[String],
73    bandit: &mut ProviderBandit,
74    max_predictions: usize,
75) -> Vec<PreloadPrediction> {
76    let task_lower = task_description.to_lowercase();
77    let task_words: Vec<&str> = task_lower.split_whitespace().collect();
78
79    let mut predictions: Vec<PreloadPrediction> = Vec::new();
80
81    for &(keywords, provider, action) in KEYWORD_MAPPINGS {
82        if !available_providers.iter().any(|p| p == provider) {
83            continue;
84        }
85
86        let matching_keywords: Vec<&&str> = keywords
87            .iter()
88            .filter(|kw| task_words.iter().any(|tw| tw.contains(*kw)))
89            .collect();
90
91        if matching_keywords.is_empty() {
92            continue;
93        }
94
95        let keyword_confidence = matching_keywords.len() as f64 / keywords.len() as f64;
96
97        let task_type = infer_task_type(&task_lower);
98        let bandit_score = bandit.estimated_probability(&task_type, provider);
99
100        let combined = 0.6 * keyword_confidence + 0.4 * bandit_score;
101
102        if !predictions
103            .iter()
104            .any(|p| p.provider_id == provider && p.action == action)
105        {
106            predictions.push(PreloadPrediction {
107                provider_id: provider.to_string(),
108                action: action.to_string(),
109                confidence: combined,
110                reason: format!(
111                    "keywords: {}",
112                    matching_keywords
113                        .iter()
114                        .map(|k| **k)
115                        .collect::<Vec<_>>()
116                        .join(", ")
117                ),
118            });
119        }
120    }
121
122    predictions.sort_by(|a, b| {
123        b.confidence
124            .partial_cmp(&a.confidence)
125            .unwrap_or(std::cmp::Ordering::Equal)
126    });
127    predictions.truncate(max_predictions);
128    predictions
129}
130
/// Simple task type inference from keywords. Public so the preload feedback loop
/// can bucket outcomes by the same task type the prediction was scored under.
///
/// Matching is case-insensitive and substring-based, so the raw task
/// description and its lowercased form always land in the same bucket.
/// Categories are checked in priority order and the first hit wins:
///
/// 1. `"bugfix"` — contains `bug`, `fix`, `error` or `crash`
/// 2. `"feature"` — contains `feature`, `add` or `implement`
/// 3. `"refactor"` — contains `refactor`, `clean` or `improve`
/// 4. `"general"` — anything else, including the empty string
#[must_use]
pub fn infer_task_type(task: &str) -> String {
    // Normalise case here so callers do not have to remember to lowercase;
    // input that is already lowercase is classified exactly as before.
    let task = task.to_lowercase();
    let mentions = |keywords: &[&str]| keywords.iter().any(|kw| task.contains(*kw));

    let task_type = if mentions(&["bug", "fix", "error", "crash"]) {
        "bugfix"
    } else if mentions(&["feature", "add", "implement"]) {
        "feature"
    } else if mentions(&["refactor", "clean", "improve"]) {
        "refactor"
    } else {
        "general"
    };

    task_type.to_string()
}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true when `predictions` contains the given provider/action pair.
    fn has_prediction(predictions: &[PreloadPrediction], provider: &str, action: &str) -> bool {
        predictions
            .iter()
            .any(|p| p.provider_id == provider && p.action == action)
    }

    #[test]
    fn predict_bug_fix_suggests_issues() {
        let mut bandit = ProviderBandit::new();
        let providers: Vec<String> = vec!["github".into(), "jira".into()];

        let predictions = predict_preloads(
            "Fix the authentication bug in the login flow",
            &providers,
            &mut bandit,
            5,
        );

        assert!(!predictions.is_empty());
        assert!(has_prediction(&predictions, "github", "issues"));
    }

    #[test]
    fn predict_db_task_suggests_schemas() {
        let mut bandit = ProviderBandit::new();
        let providers: Vec<String> = vec!["postgres".into()];

        let predictions = predict_preloads(
            "Add a new column to the users database table",
            &providers,
            &mut bandit,
            5,
        );

        assert!(has_prediction(&predictions, "postgres", "schemas"));
    }

    #[test]
    fn predict_pr_review_suggests_pull_requests() {
        let mut bandit = ProviderBandit::new();
        let providers: Vec<String> = vec!["github".into()];

        let predictions = predict_preloads(
            "Review the open pull requests and merge the approved ones",
            &providers,
            &mut bandit,
            5,
        );

        assert!(has_prediction(&predictions, "github", "pull_requests"));
    }

    #[test]
    fn predict_empty_task_returns_empty() {
        let mut bandit = ProviderBandit::new();
        let predictions = predict_preloads("", &["github".into()], &mut bandit, 5);
        assert!(predictions.is_empty());
    }

    #[test]
    fn predict_unavailable_provider_skipped() {
        let mut bandit = ProviderBandit::new();
        let predictions = predict_preloads(
            "Fix the database schema migration",
            &["github".into()], // postgres not available
            &mut bandit,
            5,
        );

        assert!(!predictions.iter().any(|p| p.provider_id == "postgres"));
    }

    #[test]
    fn predict_respects_max_predictions() {
        let mut bandit = ProviderBandit::new();
        let providers: Vec<String> = vec!["github".into(), "jira".into(), "postgres".into()];

        // The task matches four provider/action pairs (github issues, jira
        // issues, github pull_requests, postgres schemas), so the cap of two
        // must actually truncate the result.
        let predictions = predict_preloads(
            "Fix the bug in database schema and review pull requests",
            &providers,
            &mut bandit,
            2,
        );

        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn predict_bandit_trained_boosts_confidence() {
        let mut bandit = ProviderBandit::new();
        for _ in 0..20 {
            bandit.update("bugfix", "github", true);
            bandit.update("bugfix", "jira", false);
        }

        let providers: Vec<String> = vec!["github".into(), "jira".into()];
        let predictions = predict_preloads(
            "Fix the crash bug in authentication",
            &providers,
            &mut bandit,
            5,
        );

        // Both trackers are available and the task matches the bug keywords,
        // so both predictions must exist; a missing one is a failure rather
        // than a silently skipped comparison.
        let gh = predictions
            .iter()
            .find(|p| p.provider_id == "github" && p.action == "issues")
            .expect("github issues should be predicted for a bugfix task");
        let jira = predictions
            .iter()
            .find(|p| p.provider_id == "jira" && p.action == "issues")
            .expect("jira issues should be predicted for a bugfix task");

        assert!(
            gh.confidence > jira.confidence,
            "Trained bandit should boost github over jira"
        );
    }

    #[test]
    fn infer_task_type_correctness() {
        assert_eq!(infer_task_type("fix the crash bug"), "bugfix");
        assert_eq!(infer_task_type("add new feature"), "feature");
        assert_eq!(infer_task_type("refactor the auth module"), "refactor");
        assert_eq!(infer_task_type("update documentation"), "general");
    }
}