1use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// Coarse category of a task, used when choosing a decomposition strategy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TaskType {
    /// Implementation / refactoring style work.
    Code,
    /// Design, architecture or planning work.
    Planning,
    /// Research, investigation or analysis work.
    Analysis,
    /// Small, single-step / atomic work.
    Simple,
    /// No category could be recognised.
    Unknown,
}

impl TaskType {
    /// Loosely parses a task type from free-form text.
    ///
    /// Matching is case-insensitive and substring-based. Keyword groups are
    /// tried in this order, and the first group with any hit wins:
    /// code (`code`, `implement`, `refactor`), planning (`plan`, `design`,
    /// `architect`), analysis (`analy`, `research`, `investigate`), simple
    /// (`simple`, `single`, `atomic`). Text matching none of them yields
    /// [`TaskType::Unknown`]; this function never fails.
    // This parse is infallible and returns `Self`, whereas `std::str::FromStr`
    // must return a `Result`, hence the lint suppression.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        const KEYWORDS: [(TaskType, [&str; 3]); 4] = [
            (TaskType::Code, ["code", "implement", "refactor"]),
            (TaskType::Planning, ["plan", "design", "architect"]),
            (TaskType::Analysis, ["analy", "research", "investigate"]),
            (TaskType::Simple, ["simple", "single", "atomic"]),
        ];

        let haystack = s.to_lowercase();
        KEYWORDS
            .iter()
            .find(|(_, words)| words.iter().any(|word| haystack.contains(word)))
            .map_or(TaskType::Unknown, |(kind, _)| *kind)
    }
}
51
/// Decomposition strategy recommended for a task.
#[derive(Clone, Debug)]
pub enum RecommendedStrategy {
    /// Recursive splitting, suited to complex tasks with many subtasks.
    BinaryRecursive {
        /// Depth limit for the recursion (see
        /// [`RecommendedStrategy::default_depth`] for the default value).
        max_depth: u32,
    },
    /// Step-by-step execution for tasks with a clear ordering.
    Sequential,
    /// Code-specific operations (implementation, refactoring).
    CodeOperations,
    /// No decomposition: the task is treated as atomic.
    None,
}

impl RecommendedStrategy {
    /// Default `max_depth` used for [`RecommendedStrategy::BinaryRecursive`].
    pub fn default_depth() -> u32 {
        const DEFAULT_MAX_DEPTH: u32 = 10;
        DEFAULT_MAX_DEPTH
    }
}
74
75#[derive(Clone, Debug)]
77pub struct StrategyResult {
78 pub strategy: RecommendedStrategy,
80 pub task_type: TaskType,
82 pub confidence: f32,
84 pub used_local_llm: bool,
86 pub reasoning: Option<String>,
88}
89
90impl StrategyResult {
91 pub fn from_local(
93 strategy: RecommendedStrategy,
94 task_type: TaskType,
95 confidence: f32,
96 reasoning: Option<String>,
97 ) -> Self {
98 Self {
99 strategy,
100 task_type,
101 confidence,
102 used_local_llm: true,
103 reasoning,
104 }
105 }
106
107 pub fn from_heuristic(strategy: RecommendedStrategy, task_type: TaskType) -> Self {
109 Self {
110 strategy,
111 task_type,
112 confidence: 0.5,
113 used_local_llm: false,
114 reasoning: None,
115 }
116 }
117}
118
119pub struct StrategySelector {
121 provider: Arc<dyn Provider>,
122 model_id: String,
123}
124
125impl StrategySelector {
126 pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
128 Self {
129 provider,
130 model_id: model_id.into(),
131 }
132 }
133
134 pub async fn select_strategy(&self, task: &str) -> Option<StrategyResult> {
136 let timer = InferenceTimer::new("select_strategy", &self.model_id);
137
138 let prompt = self.build_selection_prompt(task);
139
140 let messages = vec![Message::user(&prompt)];
141 let options = ChatOptions::deterministic(100);
142
143 match self.provider.chat(&messages, None, &options).await {
144 Ok(response) => {
145 let output = response.message.text_or_summary();
146 let result = self.parse_selection(&output);
147 timer.finish(true);
148 Some(result)
149 }
150 Err(e) => {
151 warn!(target: "local_llm", "Strategy selection failed: {}", e);
152 timer.finish(false);
153 None
154 }
155 }
156 }
157
158 pub fn select_heuristic(&self, task: &str) -> StrategyResult {
160 let lower = task.to_lowercase();
161 let word_count = task.split_whitespace().count();
162
163 let task_type = self.classify_task_type(&lower);
165
166 let strategy = match task_type {
168 TaskType::Simple => RecommendedStrategy::None,
169 TaskType::Code => {
170 if word_count > 30 {
171 RecommendedStrategy::BinaryRecursive { max_depth: 8 }
172 } else {
173 RecommendedStrategy::CodeOperations
174 }
175 }
176 TaskType::Planning => {
177 if word_count > 50 {
178 RecommendedStrategy::BinaryRecursive { max_depth: 10 }
179 } else {
180 RecommendedStrategy::Sequential
181 }
182 }
183 TaskType::Analysis => RecommendedStrategy::Sequential,
184 TaskType::Unknown => {
185 if word_count < 10 {
187 RecommendedStrategy::None
188 } else if word_count < 30 {
189 RecommendedStrategy::Sequential
190 } else {
191 RecommendedStrategy::BinaryRecursive { max_depth: 10 }
192 }
193 }
194 };
195
196 StrategyResult::from_heuristic(strategy, task_type)
197 }
198
199 fn classify_task_type(&self, lower: &str) -> TaskType {
201 let code_indicators = [
203 "implement",
204 "code",
205 "function",
206 "class",
207 "method",
208 "refactor",
209 "debug",
210 "fix bug",
211 "write a",
212 "create a function",
213 "add a feature",
214 ];
215
216 let planning_indicators = [
218 "plan",
219 "design",
220 "architect",
221 "strategy",
222 "roadmap",
223 "outline",
224 "structure",
225 "organize",
226 ];
227
228 let analysis_indicators = [
230 "analyze",
231 "research",
232 "investigate",
233 "explain",
234 "understand",
235 "review",
236 "audit",
237 "examine",
238 "study",
239 ];
240
241 let simple_indicators = ["just", "simply", "only", "quick", "small change"];
243
244 if code_indicators.iter().any(|i| lower.contains(i)) {
246 return TaskType::Code;
247 }
248
249 if planning_indicators.iter().any(|i| lower.contains(i)) {
250 return TaskType::Planning;
251 }
252
253 if analysis_indicators.iter().any(|i| lower.contains(i)) {
254 return TaskType::Analysis;
255 }
256
257 if simple_indicators.iter().any(|i| lower.contains(i)) {
258 return TaskType::Simple;
259 }
260
261 TaskType::Unknown
262 }
263
264 fn build_selection_prompt(&self, task: &str) -> String {
266 format!(
267 r#"Analyze this task and recommend the best decomposition strategy.
268
269Task: "{}"
270
271Available strategies:
2721. BINARY_RECURSIVE - Best for complex tasks that can be split recursively (many subtasks)
2732. SEQUENTIAL - Best for step-by-step tasks with clear ordering (moderate complexity)
2743. CODE_OPERATIONS - Best for code-specific tasks (implementation, refactoring)
2754. NONE - Best for simple, atomic tasks that don't need decomposition
276
277Also classify the task type:
278- CODE: Implementation, refactoring, debugging
279- PLANNING: Design, architecture, strategy
280- ANALYSIS: Research, investigation, review
281- SIMPLE: Quick, single-step tasks
282
283Output format:
284STRATEGY: <strategy_name>
285TYPE: <task_type>
286REASON: <brief explanation>
287
288Selection:"#,
289 if task.len() > 300 { &task[..300] } else { task }
290 )
291 }
292
293 fn parse_selection(&self, output: &str) -> StrategyResult {
295 let upper = output.to_uppercase();
296
297 let strategy = if upper.contains("BINARY_RECURSIVE") || upper.contains("BINARY RECURSIVE") {
299 RecommendedStrategy::BinaryRecursive { max_depth: 10 }
300 } else if upper.contains("SEQUENTIAL") {
301 RecommendedStrategy::Sequential
302 } else if upper.contains("CODE_OPERATIONS") || upper.contains("CODE OPERATIONS") {
303 RecommendedStrategy::CodeOperations
304 } else if upper.contains("NONE") {
305 RecommendedStrategy::None
306 } else {
307 RecommendedStrategy::Sequential
309 };
310
311 let task_type = if upper.contains("TYPE: CODE") || upper.contains("TYPE:CODE") {
313 TaskType::Code
314 } else if upper.contains("TYPE: PLANNING") || upper.contains("TYPE:PLANNING") {
315 TaskType::Planning
316 } else if upper.contains("TYPE: ANALYSIS") || upper.contains("TYPE:ANALYSIS") {
317 TaskType::Analysis
318 } else if upper.contains("TYPE: SIMPLE") || upper.contains("TYPE:SIMPLE") {
319 TaskType::Simple
320 } else {
321 TaskType::Unknown
322 };
323
324 let reasoning = if let Some(reason_start) = output.find("REASON:") {
326 let reason_text = &output[reason_start + 7..];
327 let end = reason_text.find('\n').unwrap_or(reason_text.len());
328 Some(reason_text[..end].trim().to_string())
329 } else {
330 None
331 };
332
333 StrategyResult::from_local(strategy, task_type, 0.8, reasoning)
334 }
335}
336
337pub struct StrategySelectorBuilder {
339 provider: Option<Arc<dyn Provider>>,
340 model_id: String,
341}
342
343impl Default for StrategySelectorBuilder {
344 fn default() -> Self {
345 Self {
346 provider: None,
347 model_id: "lfm2-1.2b".to_string(), }
349 }
350}
351
352impl StrategySelectorBuilder {
353 pub fn new() -> Self {
355 Self::default()
356 }
357
358 pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
360 self.provider = Some(provider);
361 self
362 }
363
364 pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
366 self.model_id = model_id.into();
367 self
368 }
369
370 pub fn build(self) -> Option<StrategySelector> {
372 self.provider
373 .map(|p| StrategySelector::new(p, self.model_id))
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_type_parsing() {
        assert_eq!(TaskType::from_str("code"), TaskType::Code);
        assert_eq!(TaskType::from_str("implement feature"), TaskType::Code);
        assert_eq!(
            TaskType::from_str("design architecture"),
            TaskType::Planning
        );
        assert_eq!(TaskType::from_str("analyze the data"), TaskType::Analysis);
        assert_eq!(TaskType::from_str("simple fix"), TaskType::Simple);
        assert_eq!(TaskType::from_str("random text"), TaskType::Unknown);
    }

    #[test]
    fn test_default_depth() {
        assert_eq!(RecommendedStrategy::default_depth(), 10);
    }

    #[test]
    fn test_strategy_result_constructors() {
        let heuristic =
            StrategyResult::from_heuristic(RecommendedStrategy::Sequential, TaskType::Analysis);
        assert!(!heuristic.used_local_llm);
        assert!((heuristic.confidence - 0.5).abs() < f32::EPSILON);
        assert!(heuristic.reasoning.is_none());
        assert_eq!(heuristic.task_type, TaskType::Analysis);

        let local = StrategyResult::from_local(
            RecommendedStrategy::None,
            TaskType::Simple,
            0.8,
            Some("atomic".to_string()),
        );
        assert!(local.used_local_llm);
        assert!(matches!(local.strategy, RecommendedStrategy::None));
        assert_eq!(local.reasoning.as_deref(), Some("atomic"));
    }

    #[test]
    fn test_builder_requires_provider() {
        assert!(StrategySelectorBuilder::new().build().is_none());
        assert!(StrategySelectorBuilder::default()
            .model_id("other-model")
            .build()
            .is_none());
    }

    #[test]
    fn test_heuristic_selection_code() {
        let result = select_heuristic_direct("Implement a new authentication system with OAuth2");
        assert_eq!(result.task_type, TaskType::Code);
        assert!(matches!(result.strategy, RecommendedStrategy::CodeOperations));
    }

    #[test]
    fn test_heuristic_selection_simple() {
        let result = select_heuristic_direct("just fix the typo");
        assert_eq!(result.task_type, TaskType::Simple);
        assert!(matches!(result.strategy, RecommendedStrategy::None));
    }

    #[test]
    fn test_heuristic_selection_planning() {
        let result =
            select_heuristic_direct("Design the system architecture for the new microservice");
        assert_eq!(result.task_type, TaskType::Planning);
        assert!(matches!(result.strategy, RecommendedStrategy::Sequential));
    }

    // NOTE(review): the `*_direct` helpers below duplicate StrategySelector's
    // private logic, because constructing a selector needs an
    // `Arc<dyn Provider>` that these tests cannot build. They do NOT exercise
    // the real methods and must be kept in sync with them by hand.

    /// Mirror of `StrategySelector::select_heuristic`.
    fn select_heuristic_direct(task: &str) -> StrategyResult {
        let lower = task.to_lowercase();
        let word_count = task.split_whitespace().count();

        let task_type = classify_task_type_direct(&lower);

        let strategy = match task_type {
            TaskType::Simple => RecommendedStrategy::None,
            TaskType::Code => {
                if word_count > 30 {
                    RecommendedStrategy::BinaryRecursive { max_depth: 8 }
                } else {
                    RecommendedStrategy::CodeOperations
                }
            }
            TaskType::Planning => {
                if word_count > 50 {
                    RecommendedStrategy::BinaryRecursive { max_depth: 10 }
                } else {
                    RecommendedStrategy::Sequential
                }
            }
            TaskType::Analysis => RecommendedStrategy::Sequential,
            TaskType::Unknown => {
                if word_count < 10 {
                    RecommendedStrategy::None
                } else if word_count < 30 {
                    RecommendedStrategy::Sequential
                } else {
                    RecommendedStrategy::BinaryRecursive { max_depth: 10 }
                }
            }
        };

        StrategyResult::from_heuristic(strategy, task_type)
    }

    /// Mirror of `StrategySelector::classify_task_type`, using the same keyword
    /// lists in the same priority order.
    fn classify_task_type_direct(lower: &str) -> TaskType {
        let code_indicators = [
            "implement",
            "code",
            "function",
            "class",
            "method",
            "refactor",
            "debug",
            "fix bug",
            "write a",
            "create a function",
            "add a feature",
        ];
        let planning_indicators = [
            "plan",
            "design",
            "architect",
            "strategy",
            "roadmap",
            "outline",
            "structure",
            "organize",
        ];
        let analysis_indicators = [
            "analyze",
            "research",
            "investigate",
            "explain",
            "understand",
            "review",
            "audit",
            "examine",
            "study",
        ];
        let simple_indicators = ["just", "simply", "only", "quick", "small change"];

        if code_indicators.iter().any(|i| lower.contains(i)) {
            return TaskType::Code;
        }
        if planning_indicators.iter().any(|i| lower.contains(i)) {
            return TaskType::Planning;
        }
        if analysis_indicators.iter().any(|i| lower.contains(i)) {
            return TaskType::Analysis;
        }
        if simple_indicators.iter().any(|i| lower.contains(i)) {
            return TaskType::Simple;
        }
        TaskType::Unknown
    }

    #[test]
    fn test_parse_selection() {
        let output = r#"STRATEGY: BINARY_RECURSIVE
TYPE: CODE
REASON: Task involves multiple implementation steps"#;

        let result = parse_selection_direct(output);
        assert!(matches!(
            result.strategy,
            RecommendedStrategy::BinaryRecursive { .. }
        ));
        assert_eq!(result.task_type, TaskType::Code);
        assert!(result.reasoning.is_some());
    }

    /// Simplified mirror of the whole-reply substring scan used by
    /// `StrategySelector::parse_selection`.
    fn parse_selection_direct(output: &str) -> StrategyResult {
        let upper = output.to_uppercase();

        let strategy = if upper.contains("BINARY_RECURSIVE") || upper.contains("BINARY RECURSIVE")
        {
            RecommendedStrategy::BinaryRecursive { max_depth: 10 }
        } else if upper.contains("SEQUENTIAL") {
            RecommendedStrategy::Sequential
        } else if upper.contains("CODE_OPERATIONS") || upper.contains("CODE OPERATIONS") {
            RecommendedStrategy::CodeOperations
        } else if upper.contains("NONE") {
            RecommendedStrategy::None
        } else {
            RecommendedStrategy::Sequential
        };

        let task_type = if upper.contains("TYPE: CODE") || upper.contains("TYPE:CODE") {
            TaskType::Code
        } else if upper.contains("TYPE: PLANNING") || upper.contains("TYPE:PLANNING") {
            TaskType::Planning
        } else if upper.contains("TYPE: ANALYSIS") || upper.contains("TYPE:ANALYSIS") {
            TaskType::Analysis
        } else if upper.contains("TYPE: SIMPLE") || upper.contains("TYPE:SIMPLE") {
            TaskType::Simple
        } else {
            TaskType::Unknown
        };

        let reasoning = if let Some(reason_start) = output.find("REASON:") {
            let reason_text = &output[reason_start + 7..];
            let end = reason_text.find('\n').unwrap_or(reason_text.len());
            Some(reason_text[..end].trim().to_string())
        } else {
            None
        };

        StrategyResult::from_local(strategy, task_type, 0.8, reasoning)
    }
}