1use crate::llm::LlmClient;
8
/// Coarse complexity band used for routing, derived from a 1-10 score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Tier {
    /// Scores below 4.0 (labelled "C1-C3").
    Trivial,
    /// Scores in `[4.0, 7.0)` (labelled "C4-C6").
    Moderate,
    /// Scores in `[7.0, 9.0)` (labelled "C7-C8").
    Complex,
    /// Scores of 9.0 and above (labelled "C9-C10").
    Expert,
}

impl Tier {
    /// Human-readable label naming the score band, e.g. `"C4-C6 Moderate"`.
    pub fn label(&self) -> &'static str {
        match self {
            Tier::Trivial => "C1-C3 Trivial",
            Tier::Moderate => "C4-C6 Moderate",
            Tier::Complex => "C7-C8 Complex",
            Tier::Expert => "C9-C10 Expert",
        }
    }

    /// Map a numeric score to a tier.
    ///
    /// `>= 9.0` is Expert, `>= 7.0` Complex, `>= 4.0` Moderate. Anything else,
    /// including NaN (every comparison is false), is Trivial.
    pub fn from_score(score: f32) -> Self {
        if score >= 9.0 {
            Tier::Expert
        } else if score >= 7.0 {
            Tier::Complex
        } else if score >= 4.0 {
            Tier::Moderate
        } else {
            Tier::Trivial
        }
    }
}
41
/// Which signal(s) determined a routing decision's final complexity score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComplexitySource {
    /// Only the rule-based heuristic was used.
    Rules,
    /// The AI assessor's score was used as-is.
    Ai,
    /// The rule and AI scores were combined.
    Dual,
}

impl std::fmt::Display for ComplexitySource {
    /// Writes the lowercase tag: `"rules"`, `"ai"` or `"dual"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            ComplexitySource::Rules => "rules",
            ComplexitySource::Ai => "ai",
            ComplexitySource::Dual => "dual",
        })
    }
}
58
/// Outcome of [`assess_complexity_dual`]: the blended score together with the
/// inputs and explanation that produced it.
#[derive(Debug, Clone)]
pub struct RoutingResult {
    /// Final complexity score after blending, clamped to `1..=10`.
    pub complexity: u32,
    /// Which signal(s) determined `complexity`.
    pub source: ComplexitySource,
    /// Tier derived from `complexity` via `Tier::from_score`.
    pub tier: Tier,
    /// Human-readable explanation of how the score was chosen.
    pub reasoning: String,
    /// Score from the rule-based heuristic, widened to `f32`.
    pub rule_score: f32,
    /// Score from the AI assessor, or `None` if no usable score was obtained.
    pub ai_score: Option<f32>,
}
68
69pub fn assess_complexity(prompt: &str) -> Tier {
73 let score = rule_score(prompt);
74 Tier::from_score(score as f32)
75}
76
77pub async fn assess_complexity_dual(prompt: &str, llm: &LlmClient) -> RoutingResult {
79 let rules = rule_score(prompt);
80 let ai = ai_complexity_score(prompt, llm).await;
81
82 let (final_score, source, reasoning) = match ai {
83 Some(ai_score) => {
84 let diff = ai_score as i32 - rules as i32;
85 if diff >= 2 {
86 (
88 ai_score,
89 ComplexitySource::Ai,
90 format!(
91 "Rules: C{}, AI: C{} (using AI — semantic complexity)",
92 rules, ai_score
93 ),
94 )
95 } else if diff <= -2 {
96 let avg = (rules as f64 * 0.6 + ai_score as f64 * 0.4).round() as u32;
98 (
99 avg,
100 ComplexitySource::Dual,
101 format!(
102 "Rules: C{}, AI: C{} (weighted avg, rules dominant)",
103 rules, ai_score
104 ),
105 )
106 } else {
107 let avg = ((rules + ai_score) / 2).max(rules);
109 (
110 avg,
111 ComplexitySource::Dual,
112 format!("Rules: C{}, AI: C{} (agreement)", rules, ai_score),
113 )
114 }
115 }
116 None => (
117 rules,
118 ComplexitySource::Rules,
119 format!("Rule-based only (AI unavailable). C{}", rules),
120 ),
121 };
122
123 let final_score = final_score.clamp(1, 10);
124 let tier = Tier::from_score(final_score as f32);
125
126 println!(
127 " Rules=C{} AI={} Final=C{} => {}",
128 rules,
129 ai.map(|s| format!("C{}", s))
130 .unwrap_or_else(|| "N/A".to_string()),
131 final_score,
132 tier.label()
133 );
134
135 RoutingResult {
136 complexity: final_score,
137 source,
138 tier,
139 reasoning,
140 rule_score: rules as f32,
141 ai_score: ai.map(|s| s as f32),
142 }
143}
144
145fn rule_score(prompt: &str) -> u32 {
148 let text = prompt.to_lowercase();
149 let word_count = prompt.split_whitespace().count();
150 let mut score: f64 = 1.0; let steps = text.matches("step ").count();
156 if steps >= 7 {
157 score += 3.0;
158 } else if steps >= 5 {
159 score += 2.0;
160 } else if steps >= 3 {
161 score += 1.0;
162 }
163
164 let file_exts = [
166 ".py", ".ts", ".js", ".tsx", ".jsx", ".json", ".css", ".html", ".go", ".php",
167 ];
168 let file_count: usize = file_exts.iter().filter(|ext| text.contains(*ext)).count();
169 if file_count >= 3 {
170 score += 2.0;
171 } else if file_count >= 2 {
172 score += 1.0;
173 }
174
175 let def_keywords = ["function", "class", "def ", "interface", "struct"];
177 let def_count: usize = def_keywords.iter().map(|k| text.matches(k).count()).sum();
178 if def_count >= 3 {
179 score += 1.0;
180 }
181
182 let trivial = [
185 "simple",
186 "basic",
187 "single",
188 "just",
189 "only",
190 "straightforward",
191 ];
192 let moderate = [
193 "handle",
194 "validate",
195 "check",
196 "multiple",
197 "combine",
198 "integrate",
199 "parse",
200 "convert",
201 "transform",
202 "edge case",
203 "error handling",
204 ];
205 let high = [
206 "refactor",
207 "optimize",
208 "async",
209 "concurrent",
210 "parallel",
211 "nested",
212 "recursive",
213 "complex",
214 "algorithm",
215 "data structure",
216 "database",
217 "api",
218 "service",
219 "module",
220 "component",
221 "cache",
222 "lru",
223 "linked list",
224 "hash map",
225 "tree",
226 "graph",
227 "queue",
228 "stack",
229 "heap",
230 "binary",
231 "sorting",
232 "searching",
233 "o(1)",
234 "o(n)",
235 "o(log",
236 "time complexity",
237 ];
238 let extreme = [
239 "architect",
240 "design system",
241 "framework",
242 "infrastructure",
243 "distributed",
244 "microservice",
245 "migration",
246 "legacy",
247 "security",
248 "authentication",
249 "authorization",
250 "real-time",
251 "multiple files",
252 "full application",
253 "project",
254 ];
255
256 let mut max_tier: f64 = 0.0;
257
258 let extreme_hits = extreme.iter().filter(|k| text.contains(*k)).count();
259 if extreme_hits >= 2 {
260 max_tier = max_tier.max(4.0);
261 } else if extreme_hits == 1 {
262 max_tier = max_tier.max(3.0);
263 }
264
265 let high_hits = high.iter().filter(|k| text.contains(*k)).count();
266 if high_hits >= 3 {
267 max_tier = max_tier.max(3.0);
268 } else if high_hits >= 1 {
269 max_tier = max_tier.max(2.0);
270 }
271
272 let mod_hits = moderate.iter().filter(|k| text.contains(*k)).count();
273 if mod_hits >= 2 {
274 max_tier = max_tier.max(2.0);
275 } else if mod_hits >= 1 {
276 max_tier = max_tier.max(1.0);
277 }
278
279 let trivial_hits = trivial.iter().filter(|k| text.contains(*k)).count();
280 if trivial_hits >= 2 && max_tier <= 1.0 {
281 score -= 0.5;
282 }
283
284 score += max_tier;
285
286 if word_count > 100 {
288 score += 2.0;
289 } else if word_count > 50 {
290 score += 1.0;
291 } else if word_count < 10 {
292 score -= 0.5;
293 }
294
295 let lang = detect_language_hint(&text);
297 match lang {
298 "go" | "rust" => score += 0.5,
299 "typescript" => score += 0.5,
300 _ => {}
301 }
302
303 if text.contains("html") || text.contains("landing page") || text.contains("website") {
305 score = score.max(7.0);
306 }
307
308 (score.round() as u32).clamp(1, 10)
309}
310
/// Guess the implementation language mentioned in an already-lowercased prompt.
///
/// Returns `"rust"`, `"go"` or `"typescript"`, defaulting to `"python"` when
/// nothing matches. Checks run in that order, so a prompt naming both Rust and
/// TypeScript is classified as Rust.
///
/// Short, ambiguous names ("rust", "cargo", "go") must appear as whole words.
/// So "trust" and "frustrating" are not Rust, while "cargo.toml" and "in go."
/// are recognised. Distinctive names ("golang", "typescript", "next.js") are
/// matched as substrings.
fn detect_language_hint(lower: &str) -> &'static str {
    // Whole-word test: split on every non-alphanumeric character.
    let has_word = |word: &str| {
        lower
            .split(|c: char| !c.is_alphanumeric())
            .any(|token| token == word)
    };

    if has_word("rust") || has_word("cargo") {
        "rust"
    } else if lower.contains("golang") || has_word("go") {
        "go"
    } else if lower.contains("typescript") || lower.contains("next.js") {
        "typescript"
    } else {
        "python"
    }
}
323
324async fn ai_complexity_score(prompt: &str, llm: &LlmClient) -> Option<u32> {
327 let system = "/no_think\nYou are a task complexity assessor for a coding agent system.\n\
328 Rate the complexity of this programming task on a scale of 1-10:\n\
329 - 1-3: Simple (single function, basic logic, no dependencies)\n\
330 - 4-5: Medium (multiple functions, some validation, basic tests)\n\
331 - 6-7: Moderate (multiple files, external APIs, error handling)\n\
332 - 8-9: Complex (architecture design, multiple systems, advanced patterns)\n\
333 - 10: Very Complex (distributed systems, complex algorithms, extensive testing)\n\n\
334 Respond with ONLY a JSON object:\n\
335 {\"complexity\": <number>, \"reasoning\": \"<1 sentence>\"}";
336
337 let response = llm.generate(" AI-SCORE", system, prompt).await.ok()?;
338
339 if let Some(start) = response.find('{') {
341 if let Some(end) = response.rfind('}') {
342 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&response[start..=end]) {
343 if let Some(c) = json["complexity"].as_u64() {
344 return Some((c as u32).clamp(1, 10));
345 }
346 }
347 }
348 }
349
350 for word in response.split_whitespace() {
352 let cleaned = word.trim_matches(|c: char| !c.is_numeric() && c != '.');
353 if let Ok(n) = cleaned.parse::<f32>() {
354 if (1.0..=10.0).contains(&n) {
355 return Some(n.round() as u32);
356 }
357 }
358 }
359 None
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    // ----- trivial prompts ---------------------------------------------------

    #[test]
    fn test_trivial() {
        let tier = assess_complexity("print hello world");
        assert_eq!(tier, Tier::Trivial);
    }

    #[test]
    fn test_trivial_simple() {
        let score = rule_score("Simple basic function that prints a number");
        assert!(score <= 3, "simple function should be C1-C3, got C{}", score);
    }

    // ----- moderate prompts --------------------------------------------------

    #[test]
    fn test_moderate() {
        let prompt =
            "Build a REST API for a todo app with database integration and form validation";
        let score = rule_score(prompt);
        let in_band = (3..=6).contains(&score);
        assert!(in_band, "REST API todo app should be C3-C6, got C{}", score);
    }

    #[test]
    fn test_moderate_validation() {
        let prompt = "Write a function that validates email addresses and handles edge cases";
        let score = rule_score(prompt);
        let in_band = (3..=6).contains(&score);
        assert!(in_band, "validation task should be C3-C6, got C{}", score);
    }

    // ----- complex prompts ---------------------------------------------------

    #[test]
    fn test_complex() {
        let prompt = "Build a production-ready FastAPI user authentication endpoint with JWT, rate limiting, and security headers";
        let score = rule_score(prompt);
        assert!(score >= 5, "auth endpoint should be C5+, got C{}", score);
        assert!(score <= 8, "auth endpoint should be <=C8, got C{}", score);
    }

    #[test]
    fn test_complex_data_structure() {
        let prompt = "Implement an LRU cache with O(1) get and put using a hash map and linked list";
        let score = rule_score(prompt);
        assert!(score >= 3, "LRU cache should be C3+ (rule-based), got C{}", score);
        assert!(score <= 7, "LRU cache should be <=C7, got C{}", score);
    }

    // ----- expert prompts ----------------------------------------------------

    #[test]
    fn test_expert() {
        let prompt = "Build a distributed consensus algorithm for a microservice infrastructure with real-time replication";
        let score = rule_score(prompt);
        assert!(
            score >= 5,
            "distributed system should be C5+ (rule-based), got C{}",
            score
        );
    }

    #[test]
    fn test_extreme_architecture() {
        let prompt = "Design a distributed microservice authentication system with real-time WebSocket notifications and multiple files for the full application";
        let score = rule_score(prompt);
        assert!(
            score >= 5,
            "distributed system should be C5+ (rule-based), got C{}",
            score
        );
    }

    #[test]
    fn test_web_project_boost() {
        let score = rule_score("Create an HTML landing page with sections");
        assert!(
            score >= 7,
            "HTML landing page should get web boost to C7+, got C{}",
            score
        );
    }

    // ----- mixed keyword checks ----------------------------------------------

    #[test]
    fn test_keyword_score() {
        assert!(rule_score("hello world") <= 3);
        assert!(rule_score("Build a REST API for a todo app with database") >= 3);
        assert!(rule_score("Build a JWT authentication system with rate limiting and security") >= 4);
        assert!(rule_score("Build a distributed compiler with multiple files for the full application") >= 5);
    }

    // ----- tier mapping and invariants ---------------------------------------

    #[test]
    fn test_tier_from_score() {
        let cases = [
            (2.0, Tier::Trivial),
            (5.0, Tier::Moderate),
            (7.5, Tier::Complex),
            (9.5, Tier::Expert),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(Tier::from_score(input), expected);
        }
    }

    #[test]
    fn test_complexity_always_in_range() {
        assert!(rule_score("") >= 1);
        assert!(rule_score("") <= 10);
        assert!(rule_score("x") >= 1);
        assert!(rule_score(&"word ".repeat(200)) <= 10);
    }

    #[test]
    fn test_language_modifier() {
        let python_score = rule_score("Build a module");
        let golang_score = rule_score("Build a golang module");
        assert!(
            golang_score >= python_score,
            "Go should be >= Python complexity"
        );
    }

    #[test]
    fn test_length_factor() {
        let padding = "handles complex logic and ".repeat(10);
        let lengthy = rule_score(&format!("Build a system that {}", padding));
        let brief = rule_score("add numbers");
        assert!(lengthy > brief, "longer prompt should score higher");
    }
}