Skip to main content

brainwires_reasoning/
complexity.rs

1//! Complexity Scorer - Task Complexity Assessment
2//!
3//! Uses a provider to score task complexity (0.0 - 1.0),
4//! enabling adaptive k adjustment in MDAP voting.
5
6use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// Result of complexity scoring.
///
/// `score` and `confidence` are both meant to lie in `0.0..=1.0`. The
/// constructors on this type uphold that range; the fields are public, so
/// values written directly by a caller are not checked.
#[derive(Clone, Debug)]
pub struct ComplexityResult {
    /// Complexity score (0.0 = trivial, 1.0 = very complex)
    pub score: f32,
    /// Confidence in the score (0.0 - 1.0)
    pub confidence: f32,
    /// Whether LLM was used (vs default)
    pub used_local_llm: bool,
}

impl ComplexityResult {
    /// Score reported when nothing better is known (medium complexity).
    const NEUTRAL_SCORE: f32 = 0.5;

    /// Create a default complexity result (fallback).
    ///
    /// Returns a medium score (0.5) with low confidence (0.3) and
    /// `used_local_llm` set to `false`.
    pub fn default_complexity() -> Self {
        Self {
            score: Self::NEUTRAL_SCORE,
            confidence: 0.3,
            used_local_llm: false,
        }
    }

    /// Create a result from LLM scoring.
    ///
    /// Both inputs are clamped into `0.0..=1.0`, and `used_local_llm` is set
    /// to `true`.
    ///
    /// NaN inputs are sanitized rather than passed through:
    /// - a NaN `score` becomes the neutral score (0.5) and forces the
    ///   confidence to 0.0, since the reported score is then a placeholder;
    /// - a NaN `confidence` becomes 0.0.
    pub fn from_local(score: f32, confidence: f32) -> Self {
        // `f32::clamp` returns NaN when the receiver is NaN, so NaN has to be
        // handled explicitly or it would leak out of the documented range.
        let score_usable = !score.is_nan();
        let confidence_usable = score_usable && !confidence.is_nan();

        Self {
            score: if score_usable {
                score.clamp(0.0, 1.0)
            } else {
                Self::NEUTRAL_SCORE
            },
            confidence: if confidence_usable {
                confidence.clamp(0.0, 1.0)
            } else {
                0.0
            },
            used_local_llm: true,
        }
    }
}
44
45/// Complexity scorer for task difficulty assessment
46pub struct ComplexityScorer {
47    provider: Arc<dyn Provider>,
48    model_id: String,
49}
50
51impl ComplexityScorer {
52    /// Create a new complexity scorer
53    pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
54        Self {
55            provider,
56            model_id: model_id.into(),
57        }
58    }
59
60    /// Score the complexity of a task description
61    ///
62    /// Returns a score from 0.0 (trivial) to 1.0 (very complex).
63    /// Returns None if scoring fails, allowing fallback to default.
64    pub async fn score(&self, task_description: &str) -> Option<ComplexityResult> {
65        let timer = InferenceTimer::new("complexity_score", &self.model_id);
66
67        let system_prompt = self.build_scoring_prompt();
68        let user_prompt = format!(
69            "Rate the complexity of this task from 0.0 (trivial) to 1.0 (very complex). Output ONLY a decimal number.\n\nTask: {}",
70            task_description
71        );
72
73        let messages = vec![Message::user(&user_prompt)];
74        let options = ChatOptions::deterministic(10).system(system_prompt);
75
76        match self.provider.chat(&messages, None, &options).await {
77            Ok(response) => {
78                let text = response.message.text_or_summary();
79                if let Some(score) = self.parse_score(&text) {
80                    timer.finish(true);
81                    Some(ComplexityResult::from_local(score, 0.8))
82                } else {
83                    timer.finish(false);
84                    None
85                }
86            }
87            Err(e) => {
88                warn!(target: "local_llm", "Complexity scoring failed: {}", e);
89                timer.finish(false);
90                None
91            }
92        }
93    }
94
95    /// Score complexity synchronously (for use in sync contexts)
96    /// Uses heuristics instead of LLM for speed.
97    pub fn score_heuristic(&self, task_description: &str) -> ComplexityResult {
98        let desc_lower = task_description.to_lowercase();
99        let mut score: f32 = 0.3; // Base score
100
101        // Complexity indicators (increase score)
102        let complex_indicators = [
103            ("multiple", 0.1),
104            ("several", 0.1),
105            ("complex", 0.15),
106            ("difficult", 0.15),
107            ("careful", 0.1),
108            ("ensure", 0.05),
109            ("validate", 0.1),
110            ("analyze", 0.1),
111            ("refactor", 0.15),
112            ("architecture", 0.2),
113            ("design", 0.1),
114            ("optimize", 0.15),
115            ("performance", 0.1),
116            ("security", 0.15),
117            ("concurrent", 0.2),
118            ("async", 0.1),
119            ("parallel", 0.15),
120            ("distributed", 0.2),
121        ];
122
123        // Simplicity indicators (decrease score)
124        let simple_indicators = [
125            ("simple", -0.1),
126            ("trivial", -0.15),
127            ("just", -0.05),
128            ("only", -0.05),
129            ("basic", -0.1),
130            ("single", -0.05),
131            ("one", -0.05),
132            ("quick", -0.1),
133            ("easy", -0.1),
134        ];
135
136        for (keyword, adjustment) in complex_indicators {
137            if desc_lower.contains(keyword) {
138                score += adjustment;
139            }
140        }
141
142        for (keyword, adjustment) in simple_indicators {
143            if desc_lower.contains(keyword) {
144                score += adjustment;
145            }
146        }
147
148        // Length-based adjustment (longer = more complex)
149        let word_count = task_description.split_whitespace().count();
150        if word_count > 50 {
151            score += 0.15;
152        } else if word_count > 30 {
153            score += 0.1;
154        } else if word_count < 10 {
155            score -= 0.1;
156        }
157
158        ComplexityResult {
159            score: score.clamp(0.0, 1.0),
160            confidence: 0.4, // Lower confidence for heuristic
161            used_local_llm: false,
162        }
163    }
164
165    /// Build the system prompt for complexity scoring
166    fn build_scoring_prompt(&self) -> String {
167        r#"You are a task complexity evaluator. Given a task description, output a complexity score.
168
169Scoring guide:
170- 0.0-0.2: Trivial (single step, no decisions)
171- 0.2-0.4: Simple (few steps, straightforward)
172- 0.4-0.6: Moderate (multiple steps, some decisions)
173- 0.6-0.8: Complex (many steps, careful reasoning needed)
174- 0.8-1.0: Very complex (intricate logic, multiple dependencies)
175
176Consider:
177- Number of steps or operations needed
178- Required reasoning depth
179- Ambiguity in requirements
180- Dependencies between parts
181- Potential for errors
182
183Output ONLY a decimal number between 0.0 and 1.0."#
184            .to_string()
185    }
186
187    /// Parse the LLM output to extract a score
188    fn parse_score(&self, output: &str) -> Option<f32> {
189        // Try to find a floating point number in the output
190        let cleaned = output.trim();
191
192        // Direct parse
193        if let Ok(score) = cleaned.parse::<f32>() {
194            return Some(score.clamp(0.0, 1.0));
195        }
196
197        // Look for a number pattern
198        let number_pattern = regex::Regex::new(r"(\d+\.?\d*)").ok()?;
199        if let Some(captures) = number_pattern.captures(cleaned)
200            && let Some(m) = captures.get(1)
201            && let Ok(score) = m.as_str().parse::<f32>()
202        {
203            return Some(score.clamp(0.0, 1.0));
204        }
205
206        None
207    }
208}
209
210/// Builder for ComplexityScorer
211pub struct ComplexityScorerBuilder {
212    provider: Option<Arc<dyn Provider>>,
213    model_id: String,
214}
215
216impl Default for ComplexityScorerBuilder {
217    fn default() -> Self {
218        Self {
219            provider: None,
220            model_id: "lfm2-350m".to_string(),
221        }
222    }
223}
224
225impl ComplexityScorerBuilder {
226    /// Create a new builder with default settings.
227    pub fn new() -> Self {
228        Self::default()
229    }
230
231    /// Set the provider to use for complexity scoring.
232    pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
233        self.provider = Some(provider);
234        self
235    }
236
237    /// Set the model ID to use for inference.
238    pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
239        self.model_id = model_id.into();
240        self
241    }
242
243    /// Build the complexity scorer, returning `None` if no provider was set.
244    pub fn build(self) -> Option<ComplexityScorer> {
245        self.provider
246            .map(|p| ComplexityScorer::new(p, self.model_id))
247    }
248}
249
#[cfg(test)]
mod tests {
    //! Unit tests for the complexity module.
    //!
    //! Building a `ComplexityScorer` needs a `Provider`, and no stub provider
    //! is available here. The keyword-heuristic and score-parsing tests
    //! therefore run against small local helper functions defined in this
    //! module; those helpers do not call the scorer's own methods.

    use super::*;

    #[test]
    fn test_complexity_result_default() {
        let result = ComplexityResult::default_complexity();
        assert_eq!(result.score, 0.5);
        assert_eq!(result.confidence, 0.3);
        assert!(!result.used_local_llm);
    }

    #[test]
    fn test_complexity_result_clamping() {
        let result = ComplexityResult::from_local(1.5, 0.9);
        assert_eq!(result.score, 1.0); // Clamped
        assert!(result.used_local_llm);

        let result = ComplexityResult::from_local(-0.5, 0.9);
        assert_eq!(result.score, 0.0); // Clamped
    }

    #[test]
    fn test_complexity_result_confidence_clamping() {
        let result = ComplexityResult::from_local(0.5, 1.7);
        assert_eq!(result.confidence, 1.0); // Clamped

        let result = ComplexityResult::from_local(0.5, -0.2);
        assert_eq!(result.confidence, 0.0); // Clamped
    }

    #[test]
    fn test_builder_without_provider_returns_none() {
        // `build` can only produce a scorer once a provider has been set.
        assert!(ComplexityScorerBuilder::new().build().is_none());
        assert!(
            ComplexityScorerBuilder::default()
                .model_id("other-model")
                .build()
                .is_none()
        );
    }

    #[test]
    fn test_heuristic_scoring() {
        // Test with a simple task
        let simple = "read a file";
        let simple_score = score_heuristic_direct(simple);
        assert!(simple_score < 0.5);

        // Test with a complex task
        let complex = "refactor the architecture to implement a distributed concurrent system with multiple parallel workers";
        let complex_score = score_heuristic_direct(complex);
        assert!(complex_score > 0.5);
    }

    /// Local keyword-only heuristic used by `test_heuristic_scoring`.
    ///
    /// Starts at 0.3, applies each matching keyword adjustment once, and
    /// clamps to `0.0..=1.0`. It uses a reduced keyword table and has no
    /// length-based adjustment.
    fn score_heuristic_direct(task: &str) -> f32 {
        let desc_lower = task.to_lowercase();
        let mut score: f32 = 0.3;

        let complex_indicators = [
            ("multiple", 0.1),
            ("complex", 0.15),
            ("refactor", 0.15),
            ("architecture", 0.2),
            ("concurrent", 0.2),
            ("parallel", 0.15),
            ("distributed", 0.2),
        ];

        let simple_indicators = [("simple", -0.1), ("just", -0.05), ("basic", -0.1)];

        for (keyword, adjustment) in complex_indicators {
            if desc_lower.contains(keyword) {
                score += adjustment;
            }
        }

        for (keyword, adjustment) in simple_indicators {
            if desc_lower.contains(keyword) {
                score += adjustment;
            }
        }

        score.clamp(0.0, 1.0)
    }

    #[test]
    fn test_parse_score() {
        // Test parsing logic
        assert_eq!(parse_score_direct("0.5"), Some(0.5));
        assert_eq!(parse_score_direct("0.85"), Some(0.85));
        assert_eq!(parse_score_direct("The complexity is 0.7"), Some(0.7));
        assert_eq!(parse_score_direct("1.5"), Some(1.0)); // Clamped
        assert_eq!(parse_score_direct("  0.25\n"), Some(0.25)); // Trimmed
        assert_eq!(parse_score_direct("no digits here"), None);
    }

    /// Local score parser used by `test_parse_score`.
    ///
    /// Tries to parse the trimmed text as a whole, then falls back to the
    /// first digit-led number found in it; the value is clamped to
    /// `0.0..=1.0`. Returns `None` when neither step yields a number.
    fn parse_score_direct(output: &str) -> Option<f32> {
        let cleaned = output.trim();
        if let Ok(score) = cleaned.parse::<f32>() {
            return Some(score.clamp(0.0, 1.0));
        }
        let number_pattern = regex::Regex::new(r"(\d+\.?\d*)").ok()?;
        if let Some(captures) = number_pattern.captures(cleaned) {
            if let Some(m) = captures.get(1) {
                if let Ok(score) = m.as_str().parse::<f32>() {
                    return Some(score.clamp(0.0, 1.0));
                }
            }
        }
        None
    }
}