1use bob_core::types::ToolDescriptor;
49
/// Coarse category assigned to a user request.
///
/// Each category carries a table of lowercase keywords (see
/// [`TaskCategory::keywords`]) and a table of tool-id prefixes (see
/// [`TaskCategory::tool_prefixes`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskCategory {
    /// Source-code work (implementing, refactoring, building, debugging, ...).
    Coding,
    /// Looking things up, explaining, comparing.
    Research,
    /// Reading, writing, listing or moving files and directories.
    FileOps,
    /// Running commands, scripts and package managers.
    Shell,
    /// Version-control operations.
    Git,
    /// Parsing, converting and querying data.
    DataProcessing,
    /// Greetings and generic conversational requests.
    General,
    /// No category. Has no keywords and no tool prefixes, and is not part of
    /// [`TaskCategory::all`].
    Unknown,
}

impl TaskCategory {
    /// Returns every category that has a keyword table, in declaration order.
    ///
    /// [`TaskCategory::Unknown`] is deliberately absent from the slice.
    #[must_use]
    pub fn all() -> &'static [Self] {
        static CLASSIFIABLE: [TaskCategory; 7] = [
            TaskCategory::Coding,
            TaskCategory::Research,
            TaskCategory::FileOps,
            TaskCategory::Shell,
            TaskCategory::Git,
            TaskCategory::DataProcessing,
            TaskCategory::General,
        ];
        &CLASSIFIABLE
    }

    /// Returns the lowercase keywords associated with this category.
    ///
    /// Some entries are multi-word phrases, some carry a trailing space
    /// (for example `"ls "`), and some are file extensions. A few entries
    /// (such as `".rs"`) appear in more than one category.
    /// [`TaskCategory::Unknown`] yields an empty slice.
    #[must_use]
    pub fn keywords(&self) -> &'static [&'static str] {
        match *self {
            TaskCategory::Coding => &[
                "code", "function", "class", "implement", "refactor", "compile", "build",
                "test", "debug", "error", "bug", "fix", "syntax", "variable", "type", "trait",
                "struct", "enum", "module", "import", "export", "api", "endpoint", "handler",
                "library", "dependency", "crate", "package",
                // Source-file extensions.
                ".rs", ".ts", ".py", ".go", ".java", ".cpp", ".js", ".jsx", ".tsx",
            ],
            TaskCategory::Research => &[
                "search", "find", "look up", "research", "what is", "how does", "explain",
                "documentation", "docs", "learn", "understand", "compare", "difference",
                "pros and cons", "alternative", "best practice", "recommendation",
            ],
            TaskCategory::FileOps => &[
                "read file", "write file", "create file", "delete file", "list files",
                "list directory", "directory", "folder", "copy file", "move file", "rename",
                "file content",
                // Command names; the trailing space is part of the keyword.
                "cat ", "ls ", "mkdir", "touch", "rm ", "cp ", "mv ",
                "read the file", "write the file", "open file", "open the file", "read src",
                "read the src",
                // File extensions.
                ".rs", ".toml", ".json", ".yaml", ".yml", ".md", ".txt", ".csv", ".log",
            ],
            TaskCategory::Shell => &[
                "run", "execute", "command", "shell", "bash", "terminal", "script", "install",
                "apt", "brew", "npm", "pip", "cargo", "docker", "kubectl", "ssh",
            ],
            TaskCategory::Git => &[
                "git", "commit", "push", "pull", "branch", "merge", "rebase", "diff", "log",
                "status", "checkout", "clone", "repository", "repo", "pr", "pull request",
            ],
            TaskCategory::DataProcessing => &[
                "parse", "transform", "convert", "json", "csv", "xml", "sql", "database",
                "query", "aggregate", "filter", "sort", "map", "reduce", "process data",
            ],
            TaskCategory::General => &[
                "hello", "hi", "hey", "help", "thanks", "please", "can you", "could you",
                "would you", "tell me",
            ],
            TaskCategory::Unknown => &[],
        }
    }

    /// Returns the tool-id prefixes that mark a tool as relevant to this category.
    ///
    /// [`TaskCategory::General`] and [`TaskCategory::Unknown`] have no
    /// prefixes and yield an empty slice.
    #[must_use]
    pub fn tool_prefixes(&self) -> &'static [&'static str] {
        match *self {
            TaskCategory::Coding => &["code.", "lint.", "format.", "test.", "build."],
            TaskCategory::Research => &["web.", "search.", "browse.", "fetch."],
            TaskCategory::FileOps => &["file.", "read_file", "write_file", "list_files"],
            TaskCategory::Shell => &["shell.", "exec.", "command."],
            TaskCategory::Git => &["git.", "gh."],
            TaskCategory::DataProcessing => &["data.", "parse.", "transform.", "sql."],
            TaskCategory::General | TaskCategory::Unknown => &[],
        }
    }
}
254
255#[derive(Debug, Clone)]
257pub struct ClassificationResult {
258 pub category: TaskCategory,
260 pub secondary_categories: Vec<TaskCategory>,
262 pub confidence: f64,
264 pub matched_keywords: Vec<String>,
266 pub filtered_tools: Vec<ToolDescriptor>,
268}
269
270#[async_trait::async_trait]
275pub trait Classifier: Send + Sync {
276 async fn classify(
278 &self,
279 input: &str,
280 available_tools: &[ToolDescriptor],
281 ) -> ClassificationResult;
282
283 fn supports_multi_category(&self) -> bool {
285 false
286 }
287}
288
/// Keyword-matching classifier with a configurable confidence floor.
///
/// The floor is compared against the top category's score by the
/// [`Classifier`] implementation; scores below it are reported as
/// [`TaskCategory::Unknown`].
#[derive(Debug, Clone)]
pub struct HeuristicClassifier {
    /// Minimum score the top category must reach, kept within `0.0..=1.0`.
    min_confidence: f64,
}

impl Default for HeuristicClassifier {
    /// Equivalent to [`HeuristicClassifier::new`].
    ///
    /// Implemented by hand rather than derived: a derived `Default` would
    /// zero the threshold and so disagree with `new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl HeuristicClassifier {
    /// Threshold used by [`HeuristicClassifier::new`] and `Default`.
    const DEFAULT_MIN_CONFIDENCE: f64 = 0.05;

    /// Creates a classifier with the default confidence threshold of `0.05`.
    #[must_use]
    pub fn new() -> Self {
        Self { min_confidence: Self::DEFAULT_MIN_CONFIDENCE }
    }

    /// Returns the classifier with its confidence threshold replaced.
    ///
    /// `threshold` is clamped to `0.0..=1.0`. A NaN `threshold` is ignored
    /// and the current threshold is kept: `f64::clamp` passes NaN through,
    /// and a NaN threshold would make every `score >= threshold` comparison
    /// false.
    #[must_use]
    pub fn with_min_confidence(mut self, threshold: f64) -> Self {
        if !threshold.is_nan() {
            self.min_confidence = threshold.clamp(0.0, 1.0);
        }
        self
    }
}
323
324#[async_trait::async_trait]
325impl Classifier for HeuristicClassifier {
326 async fn classify(
327 &self,
328 input: &str,
329 available_tools: &[ToolDescriptor],
330 ) -> ClassificationResult {
331 let input_lower = input.to_lowercase();
332
333 let mut scores: Vec<(TaskCategory, f64, Vec<String>)> = TaskCategory::all()
335 .iter()
336 .map(|cat| {
337 let mut score = 0.0;
338 let mut matched = Vec::new();
339
340 for keyword in cat.keywords() {
341 if input_lower.contains(keyword) {
342 score += 1.0;
343 matched.push(keyword.to_string());
344 }
345 }
346
347 let normalized = if cat.keywords().is_empty() {
349 0.0
350 } else {
351 score / cat.keywords().len() as f64
352 };
353
354 (*cat, normalized, matched)
355 })
356 .collect();
357
358 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
360
361 let (primary, primary_score, primary_keywords) =
362 scores.first().cloned().unwrap_or((TaskCategory::Unknown, 0.0, Vec::new()));
363
364 let secondary_threshold = primary_score * 0.5;
366 let secondary: Vec<TaskCategory> = scores
367 .iter()
368 .skip(1)
369 .filter(|(_, score, _)| *score >= secondary_threshold && *score > 0.0)
370 .map(|(cat, _, _)| *cat)
371 .take(2)
372 .collect();
373
374 let filtered_tools =
376 if primary_score >= self.min_confidence && primary != TaskCategory::Unknown {
377 filter_tools_for_category(primary, available_tools)
378 } else {
379 available_tools.to_vec()
381 };
382
383 ClassificationResult {
384 category: if primary_score >= self.min_confidence {
385 primary
386 } else {
387 TaskCategory::Unknown
388 },
389 secondary_categories: secondary,
390 confidence: primary_score,
391 matched_keywords: primary_keywords,
392 filtered_tools,
393 }
394 }
395
396 fn supports_multi_category(&self) -> bool {
397 true
398 }
399}
400
401#[derive(Debug, Clone, Copy, Default)]
405pub struct PassThroughClassifier;
406
407#[async_trait::async_trait]
408impl Classifier for PassThroughClassifier {
409 async fn classify(
410 &self,
411 _input: &str,
412 available_tools: &[ToolDescriptor],
413 ) -> ClassificationResult {
414 ClassificationResult {
415 category: TaskCategory::Unknown,
416 secondary_categories: Vec::new(),
417 confidence: 0.0,
418 matched_keywords: Vec::new(),
419 filtered_tools: available_tools.to_vec(),
420 }
421 }
422}
423
424fn filter_tools_for_category(
426 category: TaskCategory,
427 tools: &[ToolDescriptor],
428) -> Vec<ToolDescriptor> {
429 let prefixes = category.tool_prefixes();
430
431 if prefixes.is_empty() {
432 return tools.to_vec();
433 }
434
435 let matched: Vec<ToolDescriptor> = tools
437 .iter()
438 .filter(|tool| prefixes.iter().any(|prefix| tool.id.starts_with(prefix)))
439 .cloned()
440 .collect();
441
442 if matched.is_empty() { tools.to_vec() } else { matched }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool descriptor from an id and a description.
    fn make_tool(id: &str, desc: &str) -> ToolDescriptor {
        ToolDescriptor::new(id, desc)
    }

    /// One tool per major prefix family, plus several `file.` tools.
    fn sample_tools() -> Vec<ToolDescriptor> {
        let specs = [
            ("file.read", "Read file contents"),
            ("file.write", "Write file contents"),
            ("file.list", "List directory"),
            ("shell.exec", "Execute shell command"),
            ("git.status", "Git status"),
            ("web.search", "Search the web"),
            ("code.lint", "Lint code"),
        ];
        specs.iter().map(|&(id, desc)| make_tool(id, desc)).collect()
    }

    /// Runs a default `HeuristicClassifier` over `input` with the sample
    /// tools and returns the result together with the tools that were offered.
    async fn run_heuristic(input: &str) -> (ClassificationResult, Vec<ToolDescriptor>) {
        let tools = sample_tools();
        let result = HeuristicClassifier::new().classify(input, &tools).await;
        (result, tools)
    }

    #[tokio::test]
    async fn classify_file_operation() {
        let (result, _) = run_heuristic("read the file src/main.rs").await;
        assert_eq!(result.category, TaskCategory::FileOps);
        assert!(!result.filtered_tools.is_empty());
    }

    #[tokio::test]
    async fn classify_coding_task() {
        let (result, _) = run_heuristic("refactor this function to fix the bug").await;
        assert_eq!(result.category, TaskCategory::Coding);
    }

    #[tokio::test]
    async fn classify_shell_task() {
        let (result, _) = run_heuristic("run the install command").await;
        assert_eq!(result.category, TaskCategory::Shell);
    }

    #[tokio::test]
    async fn classify_git_task() {
        let (result, _) = run_heuristic("show me the git status").await;
        assert_eq!(result.category, TaskCategory::Git);
    }

    #[tokio::test]
    async fn classify_research_task() {
        let (result, _) = run_heuristic("search for documentation on rust traits").await;
        assert_eq!(result.category, TaskCategory::Research);
    }

    #[tokio::test]
    async fn classify_general_greeting() {
        let (result, _) = run_heuristic("hello, can you help me?").await;
        assert_eq!(result.category, TaskCategory::General);
    }

    #[tokio::test]
    async fn classify_unknown_returns_all_tools() {
        let (result, tools) = run_heuristic("xyzzy plugh").await;
        assert_eq!(result.category, TaskCategory::Unknown);
        assert_eq!(result.filtered_tools.len(), tools.len());
    }

    #[tokio::test]
    async fn filtered_tools_match_category() {
        let (result, _) = run_heuristic("read the file config.toml").await;
        assert_eq!(result.category, TaskCategory::FileOps);
        assert!(
            result.filtered_tools.iter().any(|t| t.id.starts_with("file.")),
            "should include file-related tools"
        );
    }

    #[tokio::test]
    async fn passthrough_returns_all_tools() {
        let tools = sample_tools();
        let result = PassThroughClassifier.classify("anything", &tools).await;
        assert_eq!(result.filtered_tools.len(), tools.len());
        assert_eq!(result.category, TaskCategory::Unknown);
    }

    #[test]
    fn heuristic_supports_multi_category() {
        assert!(HeuristicClassifier::new().supports_multi_category());
    }
}