Skip to main content

brainwires_reasoning/
router.rs

1//! Router - Semantic Query Classification
2//!
3//! Uses a provider to classify queries into tool categories,
4//! replacing keyword-based pattern matching with semantic understanding.
5
6use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11use brainwires_tools::ToolCategory;
12
13use crate::InferenceTimer;
14
15/// Result of route classification
16#[derive(Clone, Debug)]
17pub struct RouteResult {
18    /// Classified tool categories
19    pub categories: Vec<ToolCategory>,
20    /// Confidence score (0.0 - 1.0)
21    pub confidence: f32,
22    /// Whether LLM was used (vs fallback)
23    pub used_local_llm: bool,
24}
25
26impl RouteResult {
27    /// Create a result from pattern-based fallback
28    pub fn from_fallback(categories: Vec<ToolCategory>) -> Self {
29        Self {
30            categories,
31            confidence: 0.5, // Lower confidence for fallback
32            used_local_llm: false,
33        }
34    }
35
36    /// Create a result from LLM classification
37    pub fn from_local(categories: Vec<ToolCategory>, confidence: f32) -> Self {
38        Self {
39            categories,
40            confidence,
41            used_local_llm: true,
42        }
43    }
44}
45
46/// Router for semantic query classification
47pub struct LocalRouter {
48    provider: Arc<dyn Provider>,
49    model_id: String,
50}
51
52impl LocalRouter {
53    /// Create a new router with the given provider
54    pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
55        Self {
56            provider,
57            model_id: model_id.into(),
58        }
59    }
60
61    /// Classify a query into tool categories using the provider
62    ///
63    /// Returns None if classification fails, allowing fallback to pattern matching.
64    pub async fn classify(&self, query: &str) -> Option<RouteResult> {
65        let timer = InferenceTimer::new("route_classify", &self.model_id);
66
67        let system_prompt = self.build_classification_prompt();
68        let user_prompt = format!(
69            "Classify this query into tool categories. Output ONLY the category names, comma-separated.\n\nQuery: {}",
70            query
71        );
72
73        let messages = vec![Message::user(&user_prompt)];
74        let options = ChatOptions::deterministic(50).system(system_prompt);
75
76        match self.provider.chat(&messages, None, &options).await {
77            Ok(response) => {
78                let text = response.message.text_or_summary();
79                let categories = self.parse_categories(&text);
80
81                if categories.is_empty() {
82                    timer.finish(false);
83                    return None;
84                }
85
86                timer.finish(true);
87                Some(RouteResult::from_local(categories, 0.85))
88            }
89            Err(e) => {
90                warn!(target: "local_llm", "Route classification failed: {}", e);
91                timer.finish(false);
92                None
93            }
94        }
95    }
96
97    /// Build the system prompt for classification
98    fn build_classification_prompt(&self) -> String {
99        r#"You are a tool category classifier. Given a user query, output the relevant tool categories.
100
101Available categories:
102- FileOps: File operations (read, write, edit, create, delete, list files/directories)
103- Search: Text search (grep, find patterns, locate text)
104- SemanticSearch: Semantic/concept search (codebase queries, embeddings, RAG)
105- Git: Git operations (commit, diff, branch, merge, status, log)
106- TaskManager: Task tracking (todos, progress, subtasks)
107- AgentPool: Multi-agent operations (spawn, parallel, background)
108- Web: HTTP/API operations (fetch, request, download)
109- WebSearch: Internet search (google, browse, scrape)
110- Bash: Shell commands (run, execute, npm, cargo, pip, docker)
111- Planning: Design/architecture (plan, strategy, roadmap)
112- Context: Memory/recall (remember, previous, earlier)
113- Orchestrator: Script automation (workflow, batch)
114- CodeExecution: Code execution (run code, python, javascript)
115
116Rules:
1171. Output ONLY category names, comma-separated
1182. Include multiple categories if query spans multiple domains
1193. Always include FileOps if file operations might be needed
1204. Be conservative - only include clearly relevant categories"#.to_string()
121    }
122
123    /// Parse LLM output into tool categories
124    fn parse_categories(&self, output: &str) -> Vec<ToolCategory> {
125        let mut categories = Vec::new();
126        let output_lower = output.to_lowercase();
127
128        // Parse each potential category
129        let category_mappings = [
130            ("fileops", ToolCategory::FileOps),
131            ("file", ToolCategory::FileOps),
132            ("search", ToolCategory::Search),
133            ("semanticsearch", ToolCategory::SemanticSearch),
134            ("semantic", ToolCategory::SemanticSearch),
135            ("git", ToolCategory::Git),
136            ("taskmanager", ToolCategory::TaskManager),
137            ("task", ToolCategory::TaskManager),
138            ("agentpool", ToolCategory::AgentPool),
139            ("agent", ToolCategory::AgentPool),
140            ("web", ToolCategory::Web),
141            ("websearch", ToolCategory::WebSearch),
142            ("bash", ToolCategory::Bash),
143            ("shell", ToolCategory::Bash),
144            ("planning", ToolCategory::Planning),
145            ("plan", ToolCategory::Planning),
146            ("context", ToolCategory::Context),
147            ("orchestrator", ToolCategory::Orchestrator),
148            ("codeexecution", ToolCategory::CodeExecution),
149            ("code", ToolCategory::CodeExecution),
150        ];
151
152        for (keyword, category) in category_mappings {
153            if output_lower.contains(keyword) && !categories.contains(&category) {
154                categories.push(category);
155            }
156        }
157
158        categories
159    }
160}
161
162/// Builder for LocalRouter
163pub struct LocalRouterBuilder {
164    provider: Option<Arc<dyn Provider>>,
165    model_id: String,
166}
167
168impl Default for LocalRouterBuilder {
169    fn default() -> Self {
170        Self {
171            provider: None,
172            model_id: "lfm2-350m".to_string(),
173        }
174    }
175}
176
177impl LocalRouterBuilder {
178    /// Create a new builder with default settings.
179    pub fn new() -> Self {
180        Self::default()
181    }
182
183    /// Set the provider to use for query routing.
184    pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
185        self.provider = Some(provider);
186        self
187    }
188
189    /// Set the model ID to use for inference.
190    pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
191        self.model_id = model_id.into();
192        self
193    }
194
195    /// Build the local router, returning `None` if no provider was set.
196    pub fn build(self) -> Option<LocalRouter> {
197        self.provider.map(|p| LocalRouter::new(p, self.model_id))
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_route_result_from_fallback() {
        let categories = vec![ToolCategory::FileOps, ToolCategory::Search];
        let fallback = RouteResult::from_fallback(categories);

        assert_eq!(fallback.categories.len(), 2);
        assert_eq!(fallback.confidence, 0.5);
        assert!(!fallback.used_local_llm);
    }

    #[test]
    fn test_route_result_from_local() {
        let local = RouteResult::from_local(vec![ToolCategory::Git], 0.9);

        assert_eq!(local.confidence, 0.9);
        assert!(local.used_local_llm);
    }

    #[test]
    fn test_parse_categories() {
        let _builder = LocalRouterBuilder::default();

        // NOTE(review): this checks keyword detection inline rather than
        // calling `LocalRouter::parse_categories`, because `LocalRouter::new`
        // needs an `Arc<dyn Provider>` and no provider is available here.
        let lowered = "FileOps, Git, Bash".to_lowercase();
        let checks: [(&[&str], ToolCategory); 3] = [
            (&["fileops", "file"], ToolCategory::FileOps),
            (&["git"], ToolCategory::Git),
            (&["bash"], ToolCategory::Bash),
        ];
        let found: Vec<ToolCategory> = checks
            .into_iter()
            .filter(|(keys, _)| keys.iter().any(|key| lowered.contains(*key)))
            .map(|(_, category)| category)
            .collect();

        for expected in [ToolCategory::FileOps, ToolCategory::Git, ToolCategory::Bash] {
            assert!(found.contains(&expected));
        }
    }
}