Skip to main content

bob_runtime/
classifier.rs

1//! # ICE Classifier / Router
2//!
3//! Implements the Interpreter-Classifier-Executor cognitive architecture
4//! inspired by OpenAgent's cortex design.
5//!
6//! ## Overview
7//!
8//! Before sending a large prompt to the LLM, the classifier:
9//!
10//! 1. **Interprets** the user's intent from their message
11//! 2. **Classifies** the task type (coding, research, file ops, etc.)
12//! 3. **Routes** only the relevant tools into the LLM context
13//!
14//! This reduces token usage and prevents "tool overload" where having
15//! too many tools degrades LLM performance.
16//!
17//! ## Architecture
18//!
19//! ```text
20//! User Input
21//!     │
22//!     ▼
23//! ┌──────────────┐
24//! │ Interpreter  │  Extract intent + entities
25//! └──────┬───────┘
26//!        │
27//!        ▼
28//! ┌──────────────┐
29//! │ Classifier   │  Map intent → task category
30//! └──────┬───────┘
31//!        │
32//!        ▼
33//! ┌──────────────┐
34//! │ Router       │  Filter tools for category
35//! └──────────────┘
36//! ```
37//!
38//! ## Example
39//!
40//! ```rust,ignore
41//! use bob_runtime::classifier::{Classifier, HeuristicClassifier};
42//!
43//! let classifier = HeuristicClassifier::new();
44//! let routing = classifier.classify("read the file src/main.rs", &tools).await;
45//! // routing.filtered_tools only contains file-related tools
46//! ```
47
48use bob_core::types::ToolDescriptor;
49
/// The kind of work a user request is asking for.
///
/// A classifier assigns one of these to each input; the category then decides
/// which tools are routed into the LLM context.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskCategory {
    /// Reading, writing, or refactoring source code.
    Coding,
    /// Web search, research, and general information gathering.
    Research,
    /// File system operations such as read, write, and list.
    FileOps,
    /// Running shell commands.
    Shell,
    /// Git operations.
    Git,
    /// Analysing or transforming data.
    DataProcessing,
    /// General conversation and Q&A.
    General,
    /// The input could not be classified.
    Unknown,
}
70
71impl TaskCategory {
72    /// All known categories.
73    #[must_use]
74    pub fn all() -> &'static [Self] {
75        &[
76            Self::Coding,
77            Self::Research,
78            Self::FileOps,
79            Self::Shell,
80            Self::Git,
81            Self::DataProcessing,
82            Self::General,
83        ]
84    }
85
86    /// Keywords that strongly indicate this category.
87    #[must_use]
88    pub fn keywords(&self) -> &'static [&'static str] {
89        match self {
90            Self::Coding => &[
91                "code",
92                "function",
93                "class",
94                "implement",
95                "refactor",
96                "compile",
97                "build",
98                "test",
99                "debug",
100                "error",
101                "bug",
102                "fix",
103                "syntax",
104                "variable",
105                "type",
106                "trait",
107                "struct",
108                "enum",
109                "module",
110                "import",
111                "export",
112                "api",
113                "endpoint",
114                "handler",
115                "library",
116                "dependency",
117                "crate",
118                "package",
119                ".rs",
120                ".ts",
121                ".py",
122                ".go",
123                ".java",
124                ".cpp",
125                ".js",
126                ".jsx",
127                ".tsx",
128            ],
129            Self::Research => &[
130                "search",
131                "find",
132                "look up",
133                "research",
134                "what is",
135                "how does",
136                "explain",
137                "documentation",
138                "docs",
139                "learn",
140                "understand",
141                "compare",
142                "difference",
143                "pros and cons",
144                "alternative",
145                "best practice",
146                "recommendation",
147            ],
148            Self::FileOps => &[
149                "read file",
150                "write file",
151                "create file",
152                "delete file",
153                "list files",
154                "list directory",
155                "directory",
156                "folder",
157                "copy file",
158                "move file",
159                "rename",
160                "file content",
161                "cat ",
162                "ls ",
163                "mkdir",
164                "touch",
165                "rm ",
166                "cp ",
167                "mv ",
168                "read the file",
169                "write the file",
170                "open file",
171                "open the file",
172                "read src",
173                "read the src",
174                ".rs",
175                ".toml",
176                ".json",
177                ".yaml",
178                ".yml",
179                ".md",
180                ".txt",
181                ".csv",
182                ".log",
183            ],
184            Self::Shell => &[
185                "run", "execute", "command", "shell", "bash", "terminal", "script", "install",
186                "apt", "brew", "npm", "pip", "cargo", "docker", "kubectl", "ssh",
187            ],
188            Self::Git => &[
189                "git",
190                "commit",
191                "push",
192                "pull",
193                "branch",
194                "merge",
195                "rebase",
196                "diff",
197                "log",
198                "status",
199                "checkout",
200                "clone",
201                "repository",
202                "repo",
203                "pr",
204                "pull request",
205            ],
206            Self::DataProcessing => &[
207                "parse",
208                "transform",
209                "convert",
210                "json",
211                "csv",
212                "xml",
213                "sql",
214                "database",
215                "query",
216                "aggregate",
217                "filter",
218                "sort",
219                "map",
220                "reduce",
221                "process data",
222            ],
223            Self::General => &[
224                "hello",
225                "hi",
226                "hey",
227                "help",
228                "thanks",
229                "please",
230                "can you",
231                "could you",
232                "would you",
233                "tell me",
234            ],
235            Self::Unknown => &[],
236        }
237    }
238
239    /// Tool name prefixes that belong to this category.
240    #[must_use]
241    pub fn tool_prefixes(&self) -> &'static [&'static str] {
242        match self {
243            Self::Coding => &["code.", "lint.", "format.", "test.", "build."],
244            Self::Research => &["web.", "search.", "browse.", "fetch."],
245            Self::FileOps => &["file.", "read_file", "write_file", "list_files"],
246            Self::Shell => &["shell.", "exec.", "command."],
247            Self::Git => &["git.", "gh."],
248            Self::DataProcessing => &["data.", "parse.", "transform.", "sql."],
249            Self::General => &[],
250            Self::Unknown => &[],
251        }
252    }
253}
254
255/// Classification result with confidence and routing information.
256#[derive(Debug, Clone)]
257pub struct ClassificationResult {
258    /// The primary task category.
259    pub category: TaskCategory,
260    /// Secondary categories (if multi-category input).
261    pub secondary_categories: Vec<TaskCategory>,
262    /// Confidence score (0.0 - 1.0).
263    pub confidence: f64,
264    /// Extracted keywords that led to classification.
265    pub matched_keywords: Vec<String>,
266    /// Tools filtered for this classification.
267    pub filtered_tools: Vec<ToolDescriptor>,
268}
269
270/// Trait for task classifiers.
271///
272/// Classifiers analyze user input and determine the task category
273/// to enable intelligent tool routing.
274#[async_trait::async_trait]
275pub trait Classifier: Send + Sync {
276    /// Classify user input and return routing information.
277    async fn classify(
278        &self,
279        input: &str,
280        available_tools: &[ToolDescriptor],
281    ) -> ClassificationResult;
282
283    /// Whether this classifier can handle multi-category classification.
284    fn supports_multi_category(&self) -> bool {
285        false
286    }
287}
288
/// Heuristic keyword-based classifier.
///
/// Fast, deterministic classification using keyword matching.
/// No LLM calls required — pure string matching.
///
/// `HeuristicClassifier::default()` and [`HeuristicClassifier::new`] build the
/// same value.
///
/// ## Example
///
/// ```rust,ignore
/// use bob_runtime::classifier::{Classifier, HeuristicClassifier, TaskCategory};
///
/// let classifier = HeuristicClassifier::new();
/// let result = classifier.classify("read the file src/main.rs", &tools).await;
/// assert_eq!(result.category, TaskCategory::FileOps);
/// ```
#[derive(Debug, Clone)]
pub struct HeuristicClassifier {
    /// Minimum confidence threshold to accept classification.
    min_confidence: f64,
}

impl Default for HeuristicClassifier {
    /// Same as [`HeuristicClassifier::new`].
    ///
    /// Implemented by hand because a derived `Default` would zero the
    /// threshold and disagree with `new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl HeuristicClassifier {
    /// Threshold used by [`HeuristicClassifier::new`].
    const DEFAULT_MIN_CONFIDENCE: f64 = 0.05;

    /// Create a new heuristic classifier with the default confidence
    /// threshold of `0.05`.
    #[must_use]
    pub fn new() -> Self {
        Self { min_confidence: Self::DEFAULT_MIN_CONFIDENCE }
    }

    /// Create with custom minimum confidence threshold.
    ///
    /// `threshold` is clamped to `0.0..=1.0`. A NaN threshold is ignored and
    /// the current value is kept: `f64::clamp` would pass NaN through, and a
    /// NaN threshold makes every `>=` comparison against it false.
    #[must_use]
    pub fn with_min_confidence(mut self, threshold: f64) -> Self {
        if !threshold.is_nan() {
            self.min_confidence = threshold.clamp(0.0, 1.0);
        }
        self
    }
}
323
324#[async_trait::async_trait]
325impl Classifier for HeuristicClassifier {
326    async fn classify(
327        &self,
328        input: &str,
329        available_tools: &[ToolDescriptor],
330    ) -> ClassificationResult {
331        let input_lower = input.to_lowercase();
332
333        // Score each category
334        let mut scores: Vec<(TaskCategory, f64, Vec<String>)> = TaskCategory::all()
335            .iter()
336            .map(|cat| {
337                let mut score = 0.0;
338                let mut matched = Vec::new();
339
340                for keyword in cat.keywords() {
341                    if input_lower.contains(keyword) {
342                        score += 1.0;
343                        matched.push(keyword.to_string());
344                    }
345                }
346
347                // Normalize by keyword count
348                let normalized = if cat.keywords().is_empty() {
349                    0.0
350                } else {
351                    score / cat.keywords().len() as f64
352                };
353
354                (*cat, normalized, matched)
355            })
356            .collect();
357
358        // Sort by score descending
359        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
360
361        let (primary, primary_score, primary_keywords) =
362            scores.first().cloned().unwrap_or((TaskCategory::Unknown, 0.0, Vec::new()));
363
364        // Secondary categories (score > 50% of primary)
365        let secondary_threshold = primary_score * 0.5;
366        let secondary: Vec<TaskCategory> = scores
367            .iter()
368            .skip(1)
369            .filter(|(_, score, _)| *score >= secondary_threshold && *score > 0.0)
370            .map(|(cat, _, _)| *cat)
371            .take(2)
372            .collect();
373
374        // Filter tools based on primary category
375        let filtered_tools =
376            if primary_score >= self.min_confidence && primary != TaskCategory::Unknown {
377                filter_tools_for_category(primary, available_tools)
378            } else {
379                // Unknown or low confidence: return all tools
380                available_tools.to_vec()
381            };
382
383        ClassificationResult {
384            category: if primary_score >= self.min_confidence {
385                primary
386            } else {
387                TaskCategory::Unknown
388            },
389            secondary_categories: secondary,
390            confidence: primary_score,
391            matched_keywords: primary_keywords,
392            filtered_tools,
393        }
394    }
395
396    fn supports_multi_category(&self) -> bool {
397        true
398    }
399}
400
401/// No-op classifier that returns all tools without filtering.
402///
403/// Useful as a fallback or when classification is not desired.
404#[derive(Debug, Clone, Copy, Default)]
405pub struct PassThroughClassifier;
406
407#[async_trait::async_trait]
408impl Classifier for PassThroughClassifier {
409    async fn classify(
410        &self,
411        _input: &str,
412        available_tools: &[ToolDescriptor],
413    ) -> ClassificationResult {
414        ClassificationResult {
415            category: TaskCategory::Unknown,
416            secondary_categories: Vec::new(),
417            confidence: 0.0,
418            matched_keywords: Vec::new(),
419            filtered_tools: available_tools.to_vec(),
420        }
421    }
422}
423
424/// Filter tools that match a task category's prefixes.
425fn filter_tools_for_category(
426    category: TaskCategory,
427    tools: &[ToolDescriptor],
428) -> Vec<ToolDescriptor> {
429    let prefixes = category.tool_prefixes();
430
431    if prefixes.is_empty() {
432        return tools.to_vec();
433    }
434
435    // Also include tools that match secondary categories
436    let matched: Vec<ToolDescriptor> = tools
437        .iter()
438        .filter(|tool| prefixes.iter().any(|prefix| tool.id.starts_with(prefix)))
439        .cloned()
440        .collect();
441
442    // If filtering produces too few tools, fall back to all tools
443    if matched.is_empty() { tools.to_vec() } else { matched }
444}
445
446// ── Tests ────────────────────────────────────────────────────────────
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool descriptor from an id and a description.
    fn make_tool(id: &str, desc: &str) -> ToolDescriptor {
        ToolDescriptor::new(id, desc)
    }

    /// A small tool set covering several categories.
    fn sample_tools() -> Vec<ToolDescriptor> {
        [
            ("file.read", "Read file contents"),
            ("file.write", "Write file contents"),
            ("file.list", "List directory"),
            ("shell.exec", "Execute shell command"),
            ("git.status", "Git status"),
            ("web.search", "Search the web"),
            ("code.lint", "Lint code"),
        ]
        .iter()
        .map(|&(id, desc)| make_tool(id, desc))
        .collect()
    }

    /// Runs a default `HeuristicClassifier` over `input` with `tools`.
    async fn run_heuristic(input: &str, tools: &[ToolDescriptor]) -> ClassificationResult {
        HeuristicClassifier::new().classify(input, tools).await
    }

    // ── Heuristic Classifier ────────────────────────────────────────

    #[tokio::test]
    async fn classify_file_operation() {
        let outcome = run_heuristic("read the file src/main.rs", &sample_tools()).await;
        assert_eq!(outcome.category, TaskCategory::FileOps);
        assert!(!outcome.filtered_tools.is_empty());
    }

    #[tokio::test]
    async fn classify_coding_task() {
        let outcome =
            run_heuristic("refactor this function to fix the bug", &sample_tools()).await;
        assert_eq!(outcome.category, TaskCategory::Coding);
    }

    #[tokio::test]
    async fn classify_shell_task() {
        let outcome = run_heuristic("run the install command", &sample_tools()).await;
        assert_eq!(outcome.category, TaskCategory::Shell);
    }

    #[tokio::test]
    async fn classify_git_task() {
        let outcome = run_heuristic("show me the git status", &sample_tools()).await;
        assert_eq!(outcome.category, TaskCategory::Git);
    }

    #[tokio::test]
    async fn classify_research_task() {
        let outcome =
            run_heuristic("search for documentation on rust traits", &sample_tools()).await;
        assert_eq!(outcome.category, TaskCategory::Research);
    }

    #[tokio::test]
    async fn classify_general_greeting() {
        let outcome = run_heuristic("hello, can you help me?", &sample_tools()).await;
        assert_eq!(outcome.category, TaskCategory::General);
    }

    #[tokio::test]
    async fn classify_unknown_returns_all_tools() {
        let tools = sample_tools();
        let outcome = run_heuristic("xyzzy plugh", &tools).await;
        assert_eq!(outcome.category, TaskCategory::Unknown);
        assert_eq!(outcome.filtered_tools.len(), tools.len());
    }

    #[tokio::test]
    async fn filtered_tools_match_category() {
        let outcome = run_heuristic("read the file config.toml", &sample_tools()).await;
        assert_eq!(outcome.category, TaskCategory::FileOps);

        let has_file_tool = outcome.filtered_tools.iter().any(|t| t.id.starts_with("file."));
        assert!(has_file_tool, "should include file-related tools");
    }

    // ── PassThrough Classifier ──────────────────────────────────────

    #[tokio::test]
    async fn passthrough_returns_all_tools() {
        let tools = sample_tools();
        let outcome = PassThroughClassifier.classify("anything", &tools).await;
        assert_eq!(outcome.filtered_tools.len(), tools.len());
        assert_eq!(outcome.category, TaskCategory::Unknown);
    }

    // ── Multi-category ──────────────────────────────────────────────

    #[test]
    fn heuristic_supports_multi_category() {
        assert!(HeuristicClassifier::new().supports_multi_category());
    }
}
567}