Skip to main content

brainwires_reasoning/
strategy_selector.rs

1//! Strategy Selector - Decomposition Strategy Selection
2//!
3//! Uses a provider to analyze tasks and recommend the optimal
4//! decomposition strategy for MDAP execution.
5
6use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// Coarse category assigned to a task before a decomposition strategy is chosen.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskType {
    /// Work on source code, such as implementation or refactoring.
    Code,
    /// Planning work that spans multiple steps.
    Planning,
    /// Research- and analysis-oriented work.
    Analysis,
    /// A simple task that takes a single step.
    Simple,
    /// The task is ambiguous or fits no other category.
    Unknown,
}
28
29impl TaskType {
30    /// Parse from string
31    #[allow(clippy::should_implement_trait)]
32    pub fn from_str(s: &str) -> Self {
33        let lower = s.to_lowercase();
34        if lower.contains("code") || lower.contains("implement") || lower.contains("refactor") {
35            TaskType::Code
36        } else if lower.contains("plan") || lower.contains("design") || lower.contains("architect")
37        {
38            TaskType::Planning
39        } else if lower.contains("analy")
40            || lower.contains("research")
41            || lower.contains("investigate")
42        {
43            TaskType::Analysis
44        } else if lower.contains("simple") || lower.contains("single") || lower.contains("atomic") {
45            TaskType::Simple
46        } else {
47            TaskType::Unknown
48        }
49    }
50}
51
/// Decomposition strategy suggested for a task.
#[derive(Debug, Clone)]
pub enum RecommendedStrategy {
    /// Binary recursive decomposition of the task.
    BinaryRecursive {
        /// Upper bound on the recursion depth of the decomposition.
        max_depth: u32,
    },
    /// Work through the task one step after another.
    Sequential,
    /// Code-specific decomposition.
    CodeOperations,
    /// Leave the task as is; no decomposition is required.
    None,
}
67
68impl RecommendedStrategy {
69    /// Get default max_depth for binary recursive
70    pub fn default_depth() -> u32 {
71        10
72    }
73}
74
75/// Result of strategy selection
76#[derive(Clone, Debug)]
77pub struct StrategyResult {
78    /// Recommended strategy
79    pub strategy: RecommendedStrategy,
80    /// Task type classification
81    pub task_type: TaskType,
82    /// Confidence score
83    pub confidence: f32,
84    /// Whether LLM was used
85    pub used_local_llm: bool,
86    /// Reasoning for the selection
87    pub reasoning: Option<String>,
88}
89
90impl StrategyResult {
91    /// Create from LLM selection
92    pub fn from_local(
93        strategy: RecommendedStrategy,
94        task_type: TaskType,
95        confidence: f32,
96        reasoning: Option<String>,
97    ) -> Self {
98        Self {
99            strategy,
100            task_type,
101            confidence,
102            used_local_llm: true,
103            reasoning,
104        }
105    }
106
107    /// Create from heuristic selection
108    pub fn from_heuristic(strategy: RecommendedStrategy, task_type: TaskType) -> Self {
109        Self {
110            strategy,
111            task_type,
112            confidence: 0.5,
113            used_local_llm: false,
114            reasoning: None,
115        }
116    }
117}
118
119/// Strategy selector for MDAP decomposition
120pub struct StrategySelector {
121    provider: Arc<dyn Provider>,
122    model_id: String,
123}
124
125impl StrategySelector {
126    /// Create a new strategy selector
127    pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
128        Self {
129            provider,
130            model_id: model_id.into(),
131        }
132    }
133
134    /// Select the optimal decomposition strategy for a task
135    pub async fn select_strategy(&self, task: &str) -> Option<StrategyResult> {
136        let timer = InferenceTimer::new("select_strategy", &self.model_id);
137
138        let prompt = self.build_selection_prompt(task);
139
140        let messages = vec![Message::user(&prompt)];
141        let options = ChatOptions::deterministic(100);
142
143        match self.provider.chat(&messages, None, &options).await {
144            Ok(response) => {
145                let output = response.message.text_or_summary();
146                let result = self.parse_selection(&output);
147                timer.finish(true);
148                Some(result)
149            }
150            Err(e) => {
151                warn!(target: "local_llm", "Strategy selection failed: {}", e);
152                timer.finish(false);
153                None
154            }
155        }
156    }
157
158    /// Heuristic strategy selection (pattern-based fallback)
159    pub fn select_heuristic(&self, task: &str) -> StrategyResult {
160        let lower = task.to_lowercase();
161        let word_count = task.split_whitespace().count();
162
163        // Detect task type
164        let task_type = self.classify_task_type(&lower);
165
166        // Select strategy based on task type and complexity
167        let strategy = match task_type {
168            TaskType::Simple => RecommendedStrategy::None,
169            TaskType::Code => {
170                if word_count > 30 {
171                    RecommendedStrategy::BinaryRecursive { max_depth: 8 }
172                } else {
173                    RecommendedStrategy::CodeOperations
174                }
175            }
176            TaskType::Planning => {
177                if word_count > 50 {
178                    RecommendedStrategy::BinaryRecursive { max_depth: 10 }
179                } else {
180                    RecommendedStrategy::Sequential
181                }
182            }
183            TaskType::Analysis => RecommendedStrategy::Sequential,
184            TaskType::Unknown => {
185                // Use complexity heuristics
186                if word_count < 10 {
187                    RecommendedStrategy::None
188                } else if word_count < 30 {
189                    RecommendedStrategy::Sequential
190                } else {
191                    RecommendedStrategy::BinaryRecursive { max_depth: 10 }
192                }
193            }
194        };
195
196        StrategyResult::from_heuristic(strategy, task_type)
197    }
198
199    /// Classify task type from content
200    fn classify_task_type(&self, lower: &str) -> TaskType {
201        // Code indicators
202        let code_indicators = [
203            "implement",
204            "code",
205            "function",
206            "class",
207            "method",
208            "refactor",
209            "debug",
210            "fix bug",
211            "write a",
212            "create a function",
213            "add a feature",
214        ];
215
216        // Planning indicators
217        let planning_indicators = [
218            "plan",
219            "design",
220            "architect",
221            "strategy",
222            "roadmap",
223            "outline",
224            "structure",
225            "organize",
226        ];
227
228        // Analysis indicators
229        let analysis_indicators = [
230            "analyze",
231            "research",
232            "investigate",
233            "explain",
234            "understand",
235            "review",
236            "audit",
237            "examine",
238            "study",
239        ];
240
241        // Simple indicators
242        let simple_indicators = ["just", "simply", "only", "quick", "small change"];
243
244        // Check code first (common case)
245        if code_indicators.iter().any(|i| lower.contains(i)) {
246            return TaskType::Code;
247        }
248
249        if planning_indicators.iter().any(|i| lower.contains(i)) {
250            return TaskType::Planning;
251        }
252
253        if analysis_indicators.iter().any(|i| lower.contains(i)) {
254            return TaskType::Analysis;
255        }
256
257        if simple_indicators.iter().any(|i| lower.contains(i)) {
258            return TaskType::Simple;
259        }
260
261        TaskType::Unknown
262    }
263
264    /// Build the selection prompt
265    fn build_selection_prompt(&self, task: &str) -> String {
266        format!(
267            r#"Analyze this task and recommend the best decomposition strategy.
268
269Task: "{}"
270
271Available strategies:
2721. BINARY_RECURSIVE - Best for complex tasks that can be split recursively (many subtasks)
2732. SEQUENTIAL - Best for step-by-step tasks with clear ordering (moderate complexity)
2743. CODE_OPERATIONS - Best for code-specific tasks (implementation, refactoring)
2754. NONE - Best for simple, atomic tasks that don't need decomposition
276
277Also classify the task type:
278- CODE: Implementation, refactoring, debugging
279- PLANNING: Design, architecture, strategy
280- ANALYSIS: Research, investigation, review
281- SIMPLE: Quick, single-step tasks
282
283Output format:
284STRATEGY: <strategy_name>
285TYPE: <task_type>
286REASON: <brief explanation>
287
288Selection:"#,
289            if task.len() > 300 { &task[..300] } else { task }
290        )
291    }
292
293    /// Parse the LLM output
294    fn parse_selection(&self, output: &str) -> StrategyResult {
295        let upper = output.to_uppercase();
296
297        // Parse strategy
298        let strategy = if upper.contains("BINARY_RECURSIVE") || upper.contains("BINARY RECURSIVE") {
299            RecommendedStrategy::BinaryRecursive { max_depth: 10 }
300        } else if upper.contains("SEQUENTIAL") {
301            RecommendedStrategy::Sequential
302        } else if upper.contains("CODE_OPERATIONS") || upper.contains("CODE OPERATIONS") {
303            RecommendedStrategy::CodeOperations
304        } else if upper.contains("NONE") {
305            RecommendedStrategy::None
306        } else {
307            // Default to sequential for ambiguous cases
308            RecommendedStrategy::Sequential
309        };
310
311        // Parse task type
312        let task_type = if upper.contains("TYPE: CODE") || upper.contains("TYPE:CODE") {
313            TaskType::Code
314        } else if upper.contains("TYPE: PLANNING") || upper.contains("TYPE:PLANNING") {
315            TaskType::Planning
316        } else if upper.contains("TYPE: ANALYSIS") || upper.contains("TYPE:ANALYSIS") {
317            TaskType::Analysis
318        } else if upper.contains("TYPE: SIMPLE") || upper.contains("TYPE:SIMPLE") {
319            TaskType::Simple
320        } else {
321            TaskType::Unknown
322        };
323
324        // Extract reasoning
325        let reasoning = if let Some(reason_start) = output.find("REASON:") {
326            let reason_text = &output[reason_start + 7..];
327            let end = reason_text.find('\n').unwrap_or(reason_text.len());
328            Some(reason_text[..end].trim().to_string())
329        } else {
330            None
331        };
332
333        StrategyResult::from_local(strategy, task_type, 0.8, reasoning)
334    }
335}
336
337/// Builder for StrategySelector
338pub struct StrategySelectorBuilder {
339    provider: Option<Arc<dyn Provider>>,
340    model_id: String,
341}
342
343impl Default for StrategySelectorBuilder {
344    fn default() -> Self {
345        Self {
346            provider: None,
347            model_id: "lfm2-1.2b".to_string(), // Larger model for better reasoning
348        }
349    }
350}
351
352impl StrategySelectorBuilder {
353    /// Create a new builder with default settings.
354    pub fn new() -> Self {
355        Self::default()
356    }
357
358    /// Set the provider to use for strategy selection.
359    pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
360        self.provider = Some(provider);
361        self
362    }
363
364    /// Set the model ID to use for inference.
365    pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
366        self.model_id = model_id.into();
367        self
368    }
369
370    /// Build the strategy selector, returning `None` if no provider was set.
371    pub fn build(self) -> Option<StrategySelector> {
372        self.provider
373            .map(|p| StrategySelector::new(p, self.model_id))
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    // Building a `StrategySelector` requires a `Provider`, which these tests do
    // not construct. The `*_direct` helpers below therefore re-implement the
    // selector's private heuristic and parsing rules; keep them in sync with
    // `StrategySelector` so the tests keep describing the shipped behaviour.

    #[test]
    fn test_task_type_parsing() {
        assert_eq!(TaskType::from_str("code"), TaskType::Code);
        assert_eq!(TaskType::from_str("implement feature"), TaskType::Code);
        assert_eq!(
            TaskType::from_str("design architecture"),
            TaskType::Planning
        );
        assert_eq!(TaskType::from_str("analyze the data"), TaskType::Analysis);
        assert_eq!(TaskType::from_str("simple fix"), TaskType::Simple);
        assert_eq!(TaskType::from_str("random text"), TaskType::Unknown);
    }

    #[test]
    fn test_heuristic_selection_code() {
        let result = select_heuristic_direct("Implement a new authentication system with OAuth2");
        assert_eq!(result.task_type, TaskType::Code);
    }

    #[test]
    fn test_heuristic_selection_simple() {
        let result = select_heuristic_direct("just fix the typo");
        assert_eq!(result.task_type, TaskType::Simple);
        assert!(matches!(result.strategy, RecommendedStrategy::None));
    }

    #[test]
    fn test_heuristic_selection_planning() {
        let result =
            select_heuristic_direct("Design the system architecture for the new microservice");
        assert_eq!(result.task_type, TaskType::Planning);
    }

    /// Test-local copy of the selector's heuristic strategy selection.
    fn select_heuristic_direct(task: &str) -> StrategyResult {
        let lower = task.to_lowercase();
        let word_count = task.split_whitespace().count();

        let task_type = classify_task_type_direct(&lower);

        let strategy = match task_type {
            TaskType::Simple => RecommendedStrategy::None,
            TaskType::Code => {
                if word_count > 30 {
                    RecommendedStrategy::BinaryRecursive { max_depth: 8 }
                } else {
                    RecommendedStrategy::CodeOperations
                }
            }
            TaskType::Planning => {
                if word_count > 50 {
                    RecommendedStrategy::BinaryRecursive { max_depth: 10 }
                } else {
                    RecommendedStrategy::Sequential
                }
            }
            TaskType::Analysis => RecommendedStrategy::Sequential,
            TaskType::Unknown => {
                if word_count < 10 {
                    RecommendedStrategy::None
                } else if word_count < 30 {
                    RecommendedStrategy::Sequential
                } else {
                    RecommendedStrategy::BinaryRecursive { max_depth: 10 }
                }
            }
        };

        StrategyResult::from_heuristic(strategy, task_type)
    }

    /// Test-local copy of the selector's keyword classification, using the
    /// same indicator lists and the same group order.
    fn classify_task_type_direct(lower: &str) -> TaskType {
        let code_indicators = [
            "implement",
            "code",
            "function",
            "class",
            "method",
            "refactor",
            "debug",
            "fix bug",
            "write a",
            "create a function",
            "add a feature",
        ];
        let planning_indicators = [
            "plan",
            "design",
            "architect",
            "strategy",
            "roadmap",
            "outline",
            "structure",
            "organize",
        ];
        let analysis_indicators = [
            "analyze",
            "research",
            "investigate",
            "explain",
            "understand",
            "review",
            "audit",
            "examine",
            "study",
        ];
        let simple_indicators = ["just", "simply", "only", "quick", "small change"];

        if code_indicators.iter().any(|i| lower.contains(*i)) {
            return TaskType::Code;
        }
        if planning_indicators.iter().any(|i| lower.contains(*i)) {
            return TaskType::Planning;
        }
        if analysis_indicators.iter().any(|i| lower.contains(*i)) {
            return TaskType::Analysis;
        }
        if simple_indicators.iter().any(|i| lower.contains(*i)) {
            return TaskType::Simple;
        }
        TaskType::Unknown
    }

    #[test]
    fn test_parse_selection() {
        let output = r#"STRATEGY: BINARY_RECURSIVE
TYPE: CODE
REASON: Task involves multiple implementation steps"#;

        let result = parse_selection_direct(output);
        assert!(matches!(
            result.strategy,
            RecommendedStrategy::BinaryRecursive { .. }
        ));
        assert_eq!(result.task_type, TaskType::Code);
        assert!(result.reasoning.is_some());
    }

    #[test]
    fn test_parse_selection_lenient_format() {
        // Space-separated strategy name, no space after `TYPE:`, no reason line.
        let result = parse_selection_direct("strategy: binary recursive\ntype:planning");
        assert!(matches!(
            result.strategy,
            RecommendedStrategy::BinaryRecursive { .. }
        ));
        assert_eq!(result.task_type, TaskType::Planning);
        assert!(result.reasoning.is_none());
    }

    /// Test-local copy of the selector's output parsing.
    fn parse_selection_direct(output: &str) -> StrategyResult {
        let upper = output.to_uppercase();

        let strategy = if upper.contains("BINARY_RECURSIVE") || upper.contains("BINARY RECURSIVE") {
            RecommendedStrategy::BinaryRecursive { max_depth: 10 }
        } else if upper.contains("SEQUENTIAL") {
            RecommendedStrategy::Sequential
        } else if upper.contains("CODE_OPERATIONS") || upper.contains("CODE OPERATIONS") {
            RecommendedStrategy::CodeOperations
        } else if upper.contains("NONE") {
            RecommendedStrategy::None
        } else {
            RecommendedStrategy::Sequential
        };

        let task_type = if upper.contains("TYPE: CODE") || upper.contains("TYPE:CODE") {
            TaskType::Code
        } else if upper.contains("TYPE: PLANNING") || upper.contains("TYPE:PLANNING") {
            TaskType::Planning
        } else if upper.contains("TYPE: ANALYSIS") || upper.contains("TYPE:ANALYSIS") {
            TaskType::Analysis
        } else if upper.contains("TYPE: SIMPLE") || upper.contains("TYPE:SIMPLE") {
            TaskType::Simple
        } else {
            TaskType::Unknown
        };

        let reasoning = if let Some(reason_start) = output.find("REASON:") {
            let reason_text = &output[reason_start + 7..];
            let end = reason_text.find('\n').unwrap_or(reason_text.len());
            Some(reason_text[..end].trim().to_string())
        } else {
            None
        };

        StrategyResult::from_local(strategy, task_type, 0.8, reasoning)
    }
}