1use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// Outcome of a task-complexity estimate.
#[derive(Clone, Debug)]
pub struct ComplexityResult {
    /// Complexity score; [`ComplexityResult::from_local`] clamps it into `0.0..=1.0`.
    pub score: f32,
    /// Confidence attached to `score`; clamped into `0.0..=1.0` by
    /// [`ComplexityResult::from_local`].
    pub confidence: f32,
    /// `true` when built through [`ComplexityResult::from_local`], `false` for the
    /// fallback produced by [`ComplexityResult::default_complexity`].
    pub used_local_llm: bool,
}

impl ComplexityResult {
    /// Returns the fallback result: a mid-range score (`0.5`) with low
    /// confidence (`0.3`), flagged as not coming from a local LLM.
    pub fn default_complexity() -> Self {
        let (score, confidence) = (0.5, 0.3);
        Self {
            score,
            confidence,
            used_local_llm: false,
        }
    }

    /// Builds a result flagged as coming from a local LLM.
    ///
    /// Both `score` and `confidence` are clamped into `0.0..=1.0` before being
    /// stored.
    pub fn from_local(score: f32, confidence: f32) -> Self {
        let into_unit_range = |value: f32| value.clamp(0.0, 1.0);
        Self {
            score: into_unit_range(score),
            confidence: into_unit_range(confidence),
            used_local_llm: true,
        }
    }
}
44
45pub struct ComplexityScorer {
47 provider: Arc<dyn Provider>,
48 model_id: String,
49}
50
51impl ComplexityScorer {
52 pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
54 Self {
55 provider,
56 model_id: model_id.into(),
57 }
58 }
59
60 pub async fn score(&self, task_description: &str) -> Option<ComplexityResult> {
65 let timer = InferenceTimer::new("complexity_score", &self.model_id);
66
67 let system_prompt = self.build_scoring_prompt();
68 let user_prompt = format!(
69 "Rate the complexity of this task from 0.0 (trivial) to 1.0 (very complex). Output ONLY a decimal number.\n\nTask: {}",
70 task_description
71 );
72
73 let messages = vec![Message::user(&user_prompt)];
74 let options = ChatOptions::deterministic(10).system(system_prompt);
75
76 match self.provider.chat(&messages, None, &options).await {
77 Ok(response) => {
78 let text = response.message.text_or_summary();
79 if let Some(score) = self.parse_score(&text) {
80 timer.finish(true);
81 Some(ComplexityResult::from_local(score, 0.8))
82 } else {
83 timer.finish(false);
84 None
85 }
86 }
87 Err(e) => {
88 warn!(target: "local_llm", "Complexity scoring failed: {}", e);
89 timer.finish(false);
90 None
91 }
92 }
93 }
94
95 pub fn score_heuristic(&self, task_description: &str) -> ComplexityResult {
98 let desc_lower = task_description.to_lowercase();
99 let mut score: f32 = 0.3; let complex_indicators = [
103 ("multiple", 0.1),
104 ("several", 0.1),
105 ("complex", 0.15),
106 ("difficult", 0.15),
107 ("careful", 0.1),
108 ("ensure", 0.05),
109 ("validate", 0.1),
110 ("analyze", 0.1),
111 ("refactor", 0.15),
112 ("architecture", 0.2),
113 ("design", 0.1),
114 ("optimize", 0.15),
115 ("performance", 0.1),
116 ("security", 0.15),
117 ("concurrent", 0.2),
118 ("async", 0.1),
119 ("parallel", 0.15),
120 ("distributed", 0.2),
121 ];
122
123 let simple_indicators = [
125 ("simple", -0.1),
126 ("trivial", -0.15),
127 ("just", -0.05),
128 ("only", -0.05),
129 ("basic", -0.1),
130 ("single", -0.05),
131 ("one", -0.05),
132 ("quick", -0.1),
133 ("easy", -0.1),
134 ];
135
136 for (keyword, adjustment) in complex_indicators {
137 if desc_lower.contains(keyword) {
138 score += adjustment;
139 }
140 }
141
142 for (keyword, adjustment) in simple_indicators {
143 if desc_lower.contains(keyword) {
144 score += adjustment;
145 }
146 }
147
148 let word_count = task_description.split_whitespace().count();
150 if word_count > 50 {
151 score += 0.15;
152 } else if word_count > 30 {
153 score += 0.1;
154 } else if word_count < 10 {
155 score -= 0.1;
156 }
157
158 ComplexityResult {
159 score: score.clamp(0.0, 1.0),
160 confidence: 0.4, used_local_llm: false,
162 }
163 }
164
165 fn build_scoring_prompt(&self) -> String {
167 r#"You are a task complexity evaluator. Given a task description, output a complexity score.
168
169Scoring guide:
170- 0.0-0.2: Trivial (single step, no decisions)
171- 0.2-0.4: Simple (few steps, straightforward)
172- 0.4-0.6: Moderate (multiple steps, some decisions)
173- 0.6-0.8: Complex (many steps, careful reasoning needed)
174- 0.8-1.0: Very complex (intricate logic, multiple dependencies)
175
176Consider:
177- Number of steps or operations needed
178- Required reasoning depth
179- Ambiguity in requirements
180- Dependencies between parts
181- Potential for errors
182
183Output ONLY a decimal number between 0.0 and 1.0."#
184 .to_string()
185 }
186
187 fn parse_score(&self, output: &str) -> Option<f32> {
189 let cleaned = output.trim();
191
192 if let Ok(score) = cleaned.parse::<f32>() {
194 return Some(score.clamp(0.0, 1.0));
195 }
196
197 let number_pattern = regex::Regex::new(r"(\d+\.?\d*)").ok()?;
199 if let Some(captures) = number_pattern.captures(cleaned)
200 && let Some(m) = captures.get(1)
201 && let Ok(score) = m.as_str().parse::<f32>()
202 {
203 return Some(score.clamp(0.0, 1.0));
204 }
205
206 None
207 }
208}
209
210pub struct ComplexityScorerBuilder {
212 provider: Option<Arc<dyn Provider>>,
213 model_id: String,
214}
215
216impl Default for ComplexityScorerBuilder {
217 fn default() -> Self {
218 Self {
219 provider: None,
220 model_id: "lfm2-350m".to_string(),
221 }
222 }
223}
224
225impl ComplexityScorerBuilder {
226 pub fn new() -> Self {
228 Self::default()
229 }
230
231 pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
233 self.provider = Some(provider);
234 self
235 }
236
237 pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
239 self.model_id = model_id.into();
240 self
241 }
242
243 pub fn build(self) -> Option<ComplexityScorer> {
245 self.provider
246 .map(|p| ComplexityScorer::new(p, self.model_id))
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_complexity_result_default() {
        let result = ComplexityResult::default_complexity();
        assert_eq!(result.score, 0.5);
        assert_eq!(result.confidence, 0.3);
        assert!(!result.used_local_llm);
    }

    #[test]
    fn test_complexity_result_clamping() {
        let result = ComplexityResult::from_local(1.5, 0.9);
        assert_eq!(result.score, 1.0);
        assert!(result.used_local_llm);

        let result = ComplexityResult::from_local(-0.5, 0.9);
        assert_eq!(result.score, 0.0);

        // Confidence is clamped the same way as the score.
        let result = ComplexityResult::from_local(0.5, 1.7);
        assert_eq!(result.confidence, 1.0);
    }

    /// Exercises the real builder: without a provider there is nothing to
    /// build a scorer from.
    #[test]
    fn test_builder_without_provider_returns_none() {
        assert!(ComplexityScorerBuilder::new().build().is_none());
        assert!(
            ComplexityScorerBuilder::default()
                .model_id("custom-model")
                .build()
                .is_none()
        );
    }

    #[test]
    fn test_heuristic_scoring() {
        let simple = "read a file";
        let simple_score = score_heuristic_direct(simple);
        assert!(simple_score < 0.5);

        let complex = "refactor the architecture to implement a distributed concurrent system with multiple parallel workers";
        let complex_score = score_heuristic_direct(complex);
        assert!(complex_score > 0.5);
    }

    /// Simplified, standalone re-implementation of the keyword part of the
    /// heuristic (reduced keyword list, no word-count adjustment).
    ///
    /// NOTE: this does not call `ComplexityScorer::score_heuristic` — building
    /// a scorer requires a provider — so tests using it only cover this copy.
    fn score_heuristic_direct(task: &str) -> f32 {
        let desc_lower = task.to_lowercase();
        let mut score: f32 = 0.3;

        let complex_indicators = [
            ("multiple", 0.1),
            ("complex", 0.15),
            ("refactor", 0.15),
            ("architecture", 0.2),
            ("concurrent", 0.2),
            ("parallel", 0.15),
            ("distributed", 0.2),
        ];

        let simple_indicators = [("simple", -0.1), ("just", -0.05), ("basic", -0.1)];

        for (keyword, adjustment) in complex_indicators {
            if desc_lower.contains(keyword) {
                score += adjustment;
            }
        }

        for (keyword, adjustment) in simple_indicators {
            if desc_lower.contains(keyword) {
                score += adjustment;
            }
        }

        score.clamp(0.0, 1.0)
    }

    #[test]
    fn test_parse_score() {
        assert_eq!(parse_score_direct("0.5"), Some(0.5));
        assert_eq!(parse_score_direct("0.85"), Some(0.85));
        assert_eq!(parse_score_direct("The complexity is 0.7"), Some(0.7));
        // Out-of-range values are clamped into 0.0..=1.0.
        assert_eq!(parse_score_direct("1.5"), Some(1.0));
    }

    /// Standalone copy of the number-extraction logic.
    ///
    /// NOTE: this does not call `ComplexityScorer::parse_score`, so
    /// `test_parse_score` only covers this copy.
    fn parse_score_direct(output: &str) -> Option<f32> {
        let cleaned = output.trim();
        if let Ok(score) = cleaned.parse::<f32>() {
            return Some(score.clamp(0.0, 1.0));
        }
        let number_pattern = regex::Regex::new(r"(\d+\.?\d*)").ok()?;
        if let Some(captures) = number_pattern.captures(cleaned) {
            if let Some(m) = captures.get(1) {
                if let Ok(score) = m.as_str().parse::<f32>() {
                    return Some(score.clamp(0.0, 1.0));
                }
            }
        }
        None
    }
}