1use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11use brainwires_tools::ToolCategory;
12
13use crate::InferenceTimer;
14
15#[derive(Clone, Debug)]
17pub struct RouteResult {
18 pub categories: Vec<ToolCategory>,
20 pub confidence: f32,
22 pub used_local_llm: bool,
24}
25
26impl RouteResult {
27 pub fn from_fallback(categories: Vec<ToolCategory>) -> Self {
29 Self {
30 categories,
31 confidence: 0.5, used_local_llm: false,
33 }
34 }
35
36 pub fn from_local(categories: Vec<ToolCategory>, confidence: f32) -> Self {
38 Self {
39 categories,
40 confidence,
41 used_local_llm: true,
42 }
43 }
44}
45
46pub struct LocalRouter {
48 provider: Arc<dyn Provider>,
49 model_id: String,
50}
51
52impl LocalRouter {
53 pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
55 Self {
56 provider,
57 model_id: model_id.into(),
58 }
59 }
60
61 pub async fn classify(&self, query: &str) -> Option<RouteResult> {
65 let timer = InferenceTimer::new("route_classify", &self.model_id);
66
67 let system_prompt = self.build_classification_prompt();
68 let user_prompt = format!(
69 "Classify this query into tool categories. Output ONLY the category names, comma-separated.\n\nQuery: {}",
70 query
71 );
72
73 let messages = vec![Message::user(&user_prompt)];
74 let options = ChatOptions::deterministic(50).system(system_prompt);
75
76 match self.provider.chat(&messages, None, &options).await {
77 Ok(response) => {
78 let text = response.message.text_or_summary();
79 let categories = self.parse_categories(&text);
80
81 if categories.is_empty() {
82 timer.finish(false);
83 return None;
84 }
85
86 timer.finish(true);
87 Some(RouteResult::from_local(categories, 0.85))
88 }
89 Err(e) => {
90 warn!(target: "local_llm", "Route classification failed: {}", e);
91 timer.finish(false);
92 None
93 }
94 }
95 }
96
97 fn build_classification_prompt(&self) -> String {
99 r#"You are a tool category classifier. Given a user query, output the relevant tool categories.
100
101Available categories:
102- FileOps: File operations (read, write, edit, create, delete, list files/directories)
103- Search: Text search (grep, find patterns, locate text)
104- SemanticSearch: Semantic/concept search (codebase queries, embeddings, RAG)
105- Git: Git operations (commit, diff, branch, merge, status, log)
106- TaskManager: Task tracking (todos, progress, subtasks)
107- AgentPool: Multi-agent operations (spawn, parallel, background)
108- Web: HTTP/API operations (fetch, request, download)
109- WebSearch: Internet search (google, browse, scrape)
110- Bash: Shell commands (run, execute, npm, cargo, pip, docker)
111- Planning: Design/architecture (plan, strategy, roadmap)
112- Context: Memory/recall (remember, previous, earlier)
113- Orchestrator: Script automation (workflow, batch)
114- CodeExecution: Code execution (run code, python, javascript)
115
116Rules:
1171. Output ONLY category names, comma-separated
1182. Include multiple categories if query spans multiple domains
1193. Always include FileOps if file operations might be needed
1204. Be conservative - only include clearly relevant categories"#.to_string()
121 }
122
123 fn parse_categories(&self, output: &str) -> Vec<ToolCategory> {
125 let mut categories = Vec::new();
126 let output_lower = output.to_lowercase();
127
128 let category_mappings = [
130 ("fileops", ToolCategory::FileOps),
131 ("file", ToolCategory::FileOps),
132 ("search", ToolCategory::Search),
133 ("semanticsearch", ToolCategory::SemanticSearch),
134 ("semantic", ToolCategory::SemanticSearch),
135 ("git", ToolCategory::Git),
136 ("taskmanager", ToolCategory::TaskManager),
137 ("task", ToolCategory::TaskManager),
138 ("agentpool", ToolCategory::AgentPool),
139 ("agent", ToolCategory::AgentPool),
140 ("web", ToolCategory::Web),
141 ("websearch", ToolCategory::WebSearch),
142 ("bash", ToolCategory::Bash),
143 ("shell", ToolCategory::Bash),
144 ("planning", ToolCategory::Planning),
145 ("plan", ToolCategory::Planning),
146 ("context", ToolCategory::Context),
147 ("orchestrator", ToolCategory::Orchestrator),
148 ("codeexecution", ToolCategory::CodeExecution),
149 ("code", ToolCategory::CodeExecution),
150 ];
151
152 for (keyword, category) in category_mappings {
153 if output_lower.contains(keyword) && !categories.contains(&category) {
154 categories.push(category);
155 }
156 }
157
158 categories
159 }
160}
161
162pub struct LocalRouterBuilder {
164 provider: Option<Arc<dyn Provider>>,
165 model_id: String,
166}
167
168impl Default for LocalRouterBuilder {
169 fn default() -> Self {
170 Self {
171 provider: None,
172 model_id: "lfm2-350m".to_string(),
173 }
174 }
175}
176
177impl LocalRouterBuilder {
178 pub fn new() -> Self {
180 Self::default()
181 }
182
183 pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
185 self.provider = Some(provider);
186 self
187 }
188
189 pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
191 self.model_id = model_id.into();
192 self
193 }
194
195 pub fn build(self) -> Option<LocalRouter> {
197 self.provider.map(|p| LocalRouter::new(p, self.model_id))
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_route_result_from_fallback() {
        let result = RouteResult::from_fallback(vec![ToolCategory::FileOps, ToolCategory::Search]);
        assert!(!result.used_local_llm);
        assert_eq!(result.confidence, 0.5);
        assert_eq!(result.categories.len(), 2);
    }

    #[test]
    fn test_route_result_from_local() {
        let result = RouteResult::from_local(vec![ToolCategory::Git], 0.9);
        assert!(result.used_local_llm);
        assert_eq!(result.confidence, 0.9);
    }

    /// Without a provider the builder cannot produce a router.
    #[test]
    fn test_builder_without_provider_returns_none() {
        assert!(LocalRouterBuilder::new().build().is_none());
        assert!(LocalRouterBuilder::default().model_id("custom").build().is_none());
    }

    /// The default model id is used unless `model_id` overrides it.
    #[test]
    fn test_builder_model_id() {
        assert_eq!(LocalRouterBuilder::default().model_id, "lfm2-350m");
        assert_eq!(LocalRouterBuilder::new().model_id("custom").model_id, "custom");
    }

    /// Checks the expected keyword -> category correspondence for a typical
    /// reply. NOTE(review): this does not call `LocalRouter::parse_categories`,
    /// because building a `LocalRouter` needs a `Provider`; a mock provider
    /// would let this exercise the real parser.
    #[test]
    fn test_parse_categories() {
        let output = "FileOps, Git, Bash";
        let output_lower = output.to_lowercase();
        let mut categories = Vec::new();

        if output_lower.contains("fileops") || output_lower.contains("file") {
            categories.push(ToolCategory::FileOps);
        }
        if output_lower.contains("git") {
            categories.push(ToolCategory::Git);
        }
        if output_lower.contains("bash") {
            categories.push(ToolCategory::Bash);
        }

        assert!(categories.contains(&ToolCategory::FileOps));
        assert!(categories.contains(&ToolCategory::Git));
        assert!(categories.contains(&ToolCategory::Bash));
    }
}