1use crate::core::provider_bandit::ProviderBandit;
17
/// A provider action predicted to be useful for an upcoming task.
///
/// Produced by `predict_preloads`; callers can use it to warm caches before
/// the task actually needs the data.
#[derive(Debug, Clone, PartialEq)]
pub struct PreloadPrediction {
    /// Identifier of the provider to query (e.g. `"github"`, `"postgres"`).
    pub provider_id: String,
    /// Provider action to preload (e.g. `"issues"`, `"schemas"`).
    pub action: String,
    /// Relevance score used to rank predictions; higher means more relevant.
    pub confidence: f64,
    /// Human-readable explanation of why this prediction was made.
    pub reason: String,
}
26
/// Vocabulary that signals a defect-tracking task; shared by every
/// issue-tracker provider entry in `KEYWORD_MAPPINGS`.
const ISSUE_KEYWORDS: &[&str] = &["bug", "error", "crash", "fix", "broken", "issue", "defect"];

/// Static routing table consulted by `predict_preloads`.
///
/// Each entry is `(keywords, provider_id, action)`: when a task description
/// mentions one of `keywords` and `provider_id` is available, `action` on that
/// provider becomes a preload candidate.
static KEYWORD_MAPPINGS: &[(&[&str], &str, &str)] = &[
    (ISSUE_KEYWORDS, "github", "issues"),
    (ISSUE_KEYWORDS, "jira", "issues"),
    (&["pr", "pull", "merge", "review", "branch"], "github", "pull_requests"),
    (
        &["database", "table", "schema", "column", "migration", "sql", "db"],
        "postgres",
        "schemas",
    ),
    (&["sprint", "story", "epic", "velocity", "backlog"], "jira", "sprints"),
    // NOTE(review): documentation keywords currently route to GitHub issues;
    // confirm this is intended rather than a placeholder for a wiki/docs action.
    (&["wiki", "doc", "documentation", "guide", "howto"], "github", "issues"),
];
68
69pub fn predict_preloads(
71 task_description: &str,
72 available_providers: &[String],
73 bandit: &mut ProviderBandit,
74 max_predictions: usize,
75) -> Vec<PreloadPrediction> {
76 let task_lower = task_description.to_lowercase();
77 let task_words: Vec<&str> = task_lower.split_whitespace().collect();
78
79 let mut predictions: Vec<PreloadPrediction> = Vec::new();
80
81 for &(keywords, provider, action) in KEYWORD_MAPPINGS {
82 if !available_providers.iter().any(|p| p == provider) {
83 continue;
84 }
85
86 let matching_keywords: Vec<&&str> = keywords
87 .iter()
88 .filter(|kw| task_words.iter().any(|tw| tw.contains(*kw)))
89 .collect();
90
91 if matching_keywords.is_empty() {
92 continue;
93 }
94
95 let keyword_confidence = matching_keywords.len() as f64 / keywords.len() as f64;
96
97 let task_type = infer_task_type(&task_lower);
98 let bandit_score = bandit.estimated_probability(&task_type, provider);
99
100 let combined = 0.6 * keyword_confidence + 0.4 * bandit_score;
101
102 if !predictions
103 .iter()
104 .any(|p| p.provider_id == provider && p.action == action)
105 {
106 predictions.push(PreloadPrediction {
107 provider_id: provider.to_string(),
108 action: action.to_string(),
109 confidence: combined,
110 reason: format!(
111 "keywords: {}",
112 matching_keywords
113 .iter()
114 .map(|k| **k)
115 .collect::<Vec<_>>()
116 .join(", ")
117 ),
118 });
119 }
120 }
121
122 predictions.sort_by(|a, b| {
123 b.confidence
124 .partial_cmp(&a.confidence)
125 .unwrap_or(std::cmp::Ordering::Equal)
126 });
127 predictions.truncate(max_predictions);
128 predictions
129}
130
/// Classifies a task description into a coarse task-type label.
///
/// Matching is a plain, case-sensitive substring search (callers such as
/// `predict_preloads` pass an already-lowercased string). Rules are checked in
/// priority order and the first hit wins:
///
/// 1. `"bugfix"` — mentions `bug`, `fix`, `error`, or `crash`
/// 2. `"feature"` — mentions `feature`, `add`, or `implement`
/// 3. `"refactor"` — mentions `refactor`, `clean`, or `improve`
/// 4. `"general"` — anything else
///
/// Because this is substring-based, words that merely contain a keyword also
/// count (e.g. `"address"` contains `"add"`).
#[must_use]
pub fn infer_task_type(task: &str) -> String {
    const RULES: [(&[&str], &str); 3] = [
        (&["bug", "fix", "error", "crash"], "bugfix"),
        (&["feature", "add", "implement"], "feature"),
        (&["refactor", "clean", "improve"], "refactor"),
    ];

    RULES
        .iter()
        .find(|(needles, _)| needles.iter().any(|needle| task.contains(*needle)))
        .map_or("general", |&(_, label)| label)
        .to_string()
}
149
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn predict_bug_fix_suggests_issues() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["github".into(), "jira".into()];

        let predictions = predict_preloads(
            "Fix the authentication bug in the login flow",
            &providers,
            &mut bandit,
            5,
        );

        assert!(!predictions.is_empty());
        assert!(predictions
            .iter()
            .any(|p| p.provider_id == "github" && p.action == "issues"));
    }

    #[test]
    fn predict_db_task_suggests_schemas() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["postgres".into()];

        let predictions = predict_preloads(
            "Add a new column to the users database table",
            &providers,
            &mut bandit,
            5,
        );

        assert!(predictions
            .iter()
            .any(|p| p.provider_id == "postgres" && p.action == "schemas"));
    }

    #[test]
    fn predict_pr_review_suggests_pull_requests() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["github".into()];

        let predictions = predict_preloads(
            "Review the open pull requests and merge the approved ones",
            &providers,
            &mut bandit,
            5,
        );

        assert!(predictions
            .iter()
            .any(|p| p.provider_id == "github" && p.action == "pull_requests"));
    }

    #[test]
    fn predict_empty_task_returns_empty() {
        let mut bandit = ProviderBandit::new();
        let predictions = predict_preloads("", &["github".into()], &mut bandit, 5);
        assert!(predictions.is_empty());
    }

    #[test]
    fn predict_unavailable_provider_skipped() {
        let mut bandit = ProviderBandit::new();
        let predictions = predict_preloads(
            "Fix the database schema migration",
            &["github".into()],
            &mut bandit,
            5,
        );

        assert!(!predictions.iter().any(|p| p.provider_id == "postgres"));
    }

    #[test]
    fn predict_respects_max_predictions() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["github".into(), "jira".into(), "postgres".into()];

        let predictions = predict_preloads(
            "Fix the bug in database schema and review pull requests",
            &providers,
            &mut bandit,
            2,
        );

        assert!(predictions.len() <= 2);
    }

    #[test]
    fn predict_zero_max_predictions_returns_empty() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["github".into(), "jira".into()];

        let predictions = predict_preloads("Fix the crash bug", &providers, &mut bandit, 0);

        assert!(predictions.is_empty());
    }

    #[test]
    fn predict_results_sorted_by_descending_confidence() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["github".into(), "jira".into(), "postgres".into()];

        let predictions = predict_preloads(
            "Fix the bug in database schema and review pull requests",
            &providers,
            &mut bandit,
            10,
        );

        assert!(predictions
            .windows(2)
            .all(|pair| pair[0].confidence >= pair[1].confidence));
    }

    #[test]
    fn predict_bandit_trained_boosts_confidence() {
        let mut bandit = ProviderBandit::new();
        for _ in 0..20 {
            bandit.update("bugfix", "github", true);
            bandit.update("bugfix", "jira", false);
        }

        let providers = vec!["github".into(), "jira".into()];
        let predictions = predict_preloads(
            "Fix the crash bug in authentication",
            &providers,
            &mut bandit,
            5,
        );

        // Both lookups must succeed; previously a missing prediction made the
        // comparison below silently skipped and the test passed vacuously.
        let gh = predictions
            .iter()
            .find(|p| p.provider_id == "github" && p.action == "issues")
            .expect("github issues should be predicted for a bug-fix task");
        let jira = predictions
            .iter()
            .find(|p| p.provider_id == "jira" && p.action == "issues")
            .expect("jira issues should be predicted for a bug-fix task");

        assert!(
            gh.confidence > jira.confidence,
            "Trained bandit should boost github over jira"
        );
    }

    #[test]
    fn infer_task_type_correctness() {
        assert_eq!(infer_task_type("fix the crash bug"), "bugfix");
        assert_eq!(infer_task_type("add new feature"), "feature");
        assert_eq!(infer_task_type("refactor the auth module"), "refactor");
        assert_eq!(infer_task_type("update documentation"), "general");
    }

    #[test]
    fn infer_task_type_bugfix_takes_precedence() {
        // Rules are checked in priority order: bug-related terms win over
        // feature terms when both are present.
        assert_eq!(infer_task_type("add error handling"), "bugfix");
        assert_eq!(infer_task_type("clean up and add tests"), "feature");
    }
}
279}