1use crate::core::provider_bandit::ProviderBandit;
17
/// A single suggestion to preload data from a provider for an upcoming task.
///
/// Values are produced by [`predict_preloads`], which returns them sorted by
/// descending [`confidence`](Self::confidence).
#[derive(Debug, Clone, PartialEq)]
pub struct PreloadPrediction {
    /// Identifier of the provider to preload from (for example `"github"`).
    pub provider_id: String,
    /// Provider-specific resource to preload (for example `"issues"`).
    pub action: String,
    /// Ranking score; a higher value means the preload is considered more useful.
    pub confidence: f64,
    /// Human-readable explanation of why this prediction was made.
    pub reason: String,
}
26
/// Trigger words shared by every issue-tracker row of [`KEYWORD_MAPPINGS`].
const ISSUE_KEYWORDS: &[&str] = &["bug", "error", "crash", "fix", "broken", "issue", "defect"];

/// Lookup table of `(trigger keywords, provider id, action)` rows.
///
/// Row order matters to consumers that de-duplicate `(provider, action)` pairs:
/// the earliest matching row wins.
///
/// NOTE(review): the wiki/documentation row maps to the `"issues"` action, the
/// same action as the bug row — confirm this is intentional.
static KEYWORD_MAPPINGS: &[(&[&str], &str, &str)] = &[
    (ISSUE_KEYWORDS, "github", "issues"),
    (ISSUE_KEYWORDS, "jira", "issues"),
    (
        &["pr", "pull", "merge", "review", "branch"],
        "github",
        "pull_requests",
    ),
    (
        &["database", "table", "schema", "column", "migration", "sql", "db"],
        "postgres",
        "schemas",
    ),
    (
        &["sprint", "story", "epic", "velocity", "backlog"],
        "jira",
        "sprints",
    ),
    (
        &["wiki", "doc", "documentation", "guide", "howto"],
        "github",
        "issues",
    ),
];
68
69pub fn predict_preloads(
71 task_description: &str,
72 available_providers: &[String],
73 bandit: &mut ProviderBandit,
74 max_predictions: usize,
75) -> Vec<PreloadPrediction> {
76 let task_lower = task_description.to_lowercase();
77 let task_words: Vec<&str> = task_lower.split_whitespace().collect();
78
79 let mut predictions: Vec<PreloadPrediction> = Vec::new();
80
81 for &(keywords, provider, action) in KEYWORD_MAPPINGS {
82 if !available_providers.iter().any(|p| p == provider) {
83 continue;
84 }
85
86 let matching_keywords: Vec<&&str> = keywords
87 .iter()
88 .filter(|kw| task_words.iter().any(|tw| tw.contains(*kw)))
89 .collect();
90
91 if matching_keywords.is_empty() {
92 continue;
93 }
94
95 let keyword_confidence = matching_keywords.len() as f64 / keywords.len() as f64;
96
97 let task_type = infer_task_type(&task_lower);
98 let bandit_score = bandit.estimated_probability(&task_type, provider);
99
100 let combined = 0.6 * keyword_confidence + 0.4 * bandit_score;
101
102 if !predictions
103 .iter()
104 .any(|p| p.provider_id == provider && p.action == action)
105 {
106 predictions.push(PreloadPrediction {
107 provider_id: provider.to_string(),
108 action: action.to_string(),
109 confidence: combined,
110 reason: format!(
111 "keywords: {}",
112 matching_keywords
113 .iter()
114 .map(|k| **k)
115 .collect::<Vec<_>>()
116 .join(", ")
117 ),
118 });
119 }
120 }
121
122 predictions.sort_by(|a, b| {
123 b.confidence
124 .partial_cmp(&a.confidence)
125 .unwrap_or(std::cmp::Ordering::Equal)
126 });
127 predictions.truncate(max_predictions);
128 predictions
129}
130
/// Classifies a task description into a coarse task-type label.
///
/// Categories are tried in the order `"bugfix"`, `"feature"`, `"refactor"`; the
/// first one with a marker word contained anywhere in `task` wins, and
/// `"general"` is returned when none match. Matching is plain, case-sensitive
/// substring containment.
fn infer_task_type(task: &str) -> String {
    // Ordered by precedence: earlier categories shadow later ones.
    const CATEGORIES: [(&str, &[&str]); 3] = [
        ("bugfix", &["bug", "fix", "error", "crash"]),
        ("feature", &["feature", "add", "implement"]),
        ("refactor", &["refactor", "clean", "improve"]),
    ];

    CATEGORIES
        .iter()
        .find(|&&(_, markers)| markers.iter().any(|&marker| task.contains(marker)))
        .map_or("general", |&(label, _)| label)
        .to_string()
}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned provider list from string literals.
    fn providers(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    /// Reports whether `predictions` contains an entry for the given pair.
    fn has_prediction(predictions: &[PreloadPrediction], provider: &str, action: &str) -> bool {
        predictions
            .iter()
            .any(|p| p.provider_id == provider && p.action == action)
    }

    #[test]
    fn predict_bug_fix_suggests_issues() {
        let mut bandit = ProviderBandit::new();
        let providers = providers(&["github", "jira"]);

        let predictions = predict_preloads(
            "Fix the authentication bug in the login flow",
            &providers,
            &mut bandit,
            5,
        );

        assert!(!predictions.is_empty());
        assert!(has_prediction(&predictions, "github", "issues"));
    }

    #[test]
    fn predict_db_task_suggests_schemas() {
        let mut bandit = ProviderBandit::new();
        let providers = providers(&["postgres"]);

        let predictions = predict_preloads(
            "Add a new column to the users database table",
            &providers,
            &mut bandit,
            5,
        );

        assert!(has_prediction(&predictions, "postgres", "schemas"));
    }

    #[test]
    fn predict_pr_review_suggests_pull_requests() {
        let mut bandit = ProviderBandit::new();
        let providers = providers(&["github"]);

        let predictions = predict_preloads(
            "Review the open pull requests and merge the approved ones",
            &providers,
            &mut bandit,
            5,
        );

        assert!(has_prediction(&predictions, "github", "pull_requests"));
    }

    #[test]
    fn predict_empty_task_returns_empty() {
        let mut bandit = ProviderBandit::new();
        let providers = providers(&["github"]);

        let predictions = predict_preloads("", &providers, &mut bandit, 5);

        assert!(predictions.is_empty());
    }

    #[test]
    fn predict_unavailable_provider_skipped() {
        let mut bandit = ProviderBandit::new();
        // Postgres is deliberately absent from the available providers.
        let providers = providers(&["github"]);

        let predictions = predict_preloads(
            "Fix the database schema migration",
            &providers,
            &mut bandit,
            5,
        );

        assert!(!predictions.iter().any(|p| p.provider_id == "postgres"));
    }

    #[test]
    fn predict_respects_max_predictions() {
        let mut bandit = ProviderBandit::new();
        let providers = providers(&["github", "jira", "postgres"]);

        let predictions = predict_preloads(
            "Fix the bug in database schema and review pull requests",
            &providers,
            &mut bandit,
            2,
        );

        assert!(predictions.len() <= 2);
    }

    #[test]
    fn predict_bandit_trained_boosts_confidence() {
        let mut bandit = ProviderBandit::new();
        for _ in 0..20 {
            bandit.update("bugfix", "github", true);
            bandit.update("bugfix", "jira", false);
        }

        let providers = providers(&["github", "jira"]);
        let predictions = predict_preloads(
            "Fix the crash bug in authentication",
            &providers,
            &mut bandit,
            5,
        );

        // Both providers are available and the task contains issue keywords, so
        // a missing prediction is a failure rather than a reason to skip the check.
        let gh = predictions
            .iter()
            .find(|p| p.provider_id == "github" && p.action == "issues")
            .expect("github issues prediction should be present");
        let jira = predictions
            .iter()
            .find(|p| p.provider_id == "jira" && p.action == "issues")
            .expect("jira issues prediction should be present");

        assert!(
            gh.confidence > jira.confidence,
            "Trained bandit should boost github over jira"
        );
    }

    #[test]
    fn infer_task_type_correctness() {
        assert_eq!(infer_task_type("fix the crash bug"), "bugfix");
        assert_eq!(infer_task_type("add new feature"), "feature");
        assert_eq!(infer_task_type("refactor the auth module"), "refactor");
        assert_eq!(infer_task_type("update documentation"), "general");
    }
}