Skip to main content

battlecommand_forge/
router.rs

/// Campbell's Complexity Theory — Dual Assessment Router.
///
/// Ported from battleclaw-v2 router.rs. Dual-factor scoring:
/// 1. Rule-based keyword + structural analysis (fast, deterministic)
/// 2. AI-assisted scoring via configured complexity model (nuanced)
///
/// The two scores are blended, with explicit handling for when they disagree.
7use crate::llm::LlmClient;
8
9// ─── Types ───
10
/// Coarse complexity bucket for a task, derived from a numeric score.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Tier {
    /// C1-C3.
    Trivial,
    /// C4-C6.
    Moderate,
    /// C7-C8.
    Complex,
    /// C9-C10.
    Expert,
}

impl Tier {
    /// Returns the display label for this tier, e.g. `"C4-C6 Moderate"`.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Expert => "C9-C10 Expert",
            Self::Complex => "C7-C8 Complex",
            Self::Moderate => "C4-C6 Moderate",
            Self::Trivial => "C1-C3 Trivial",
        }
    }

    /// Buckets a numeric score into a tier.
    ///
    /// Each threshold is an inclusive lower bound: `>= 9.0` is `Expert`,
    /// `>= 7.0` is `Complex`, `>= 4.0` is `Moderate`. Anything that meets no
    /// threshold is `Trivial`; that includes NaN, since every comparison
    /// against NaN is false.
    pub fn from_score(score: f32) -> Self {
        // Ordered from the highest floor down so the first hit wins.
        const FLOORS: [(f32, Tier); 3] = [
            (9.0, Tier::Expert),
            (7.0, Tier::Complex),
            (4.0, Tier::Moderate),
        ];

        FLOORS
            .iter()
            .find(|(floor, _)| score >= *floor)
            .map_or(Tier::Trivial, |&(_, tier)| tier)
    }
}
41
/// Identifies which assessment produced a complexity score.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ComplexitySource {
    /// Rule-based only.
    Rules,
    /// AI assessment only.
    Ai,
    /// Combined rule + AI.
    Dual,
}

impl std::fmt::Display for ComplexitySource {
    /// Writes the lowercase tag for the source (`rules`, `ai` or `dual`).
    ///
    /// The tag is written verbatim; width and alignment flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = match self {
            ComplexitySource::Rules => "rules",
            ComplexitySource::Ai => "ai",
            ComplexitySource::Dual => "dual",
        };
        f.write_str(tag)
    }
}
58
/// Detailed outcome of a complexity assessment.
///
/// The field notes describe how `assess_complexity_dual` fills the struct;
/// other constructors, if any exist elsewhere, may differ.
#[derive(Debug, Clone)]
pub struct RoutingResult {
    /// Final complexity score; `assess_complexity_dual` clamps it to 1..=10.
    pub complexity: u32,
    /// Which assessment (rules, AI, or a blend) the final score came from.
    pub source: ComplexitySource,
    /// Tier derived from `complexity` via `Tier::from_score`.
    pub tier: Tier,
    /// Human-readable explanation of how the final score was chosen.
    pub reasoning: String,
    /// The rule-based score, stored as a float.
    pub rule_score: f32,
    /// The AI score, or `None` when no AI score was obtained.
    pub ai_score: Option<f32>,
}
68
69// ─── Public API ───
70
71/// Fast rule-based complexity assessment (no LLM needed).
72pub fn assess_complexity(prompt: &str) -> Tier {
73    let score = rule_score(prompt);
74    Tier::from_score(score as f32)
75}
76
77/// Full dual-factor assessment: rules + AI, returns detailed result.
78pub async fn assess_complexity_dual(prompt: &str, llm: &LlmClient) -> RoutingResult {
79    let rules = rule_score(prompt);
80    let ai = ai_complexity_score(prompt, llm).await;
81
82    let (final_score, source, reasoning) = match ai {
83        Some(ai_score) => {
84            let diff = ai_score as i32 - rules as i32;
85            if diff >= 2 {
86                // AI sees more complexity — trust AI
87                (
88                    ai_score,
89                    ComplexitySource::Ai,
90                    format!(
91                        "Rules: C{}, AI: C{} (using AI — semantic complexity)",
92                        rules, ai_score
93                    ),
94                )
95            } else if diff <= -2 {
96                // Rules see more complexity — weighted average favoring rules
97                let avg = (rules as f64 * 0.6 + ai_score as f64 * 0.4).round() as u32;
98                (
99                    avg,
100                    ComplexitySource::Dual,
101                    format!(
102                        "Rules: C{}, AI: C{} (weighted avg, rules dominant)",
103                        rules, ai_score
104                    ),
105                )
106            } else {
107                // Agreement — average, floor at rule score
108                let avg = ((rules + ai_score) / 2).max(rules);
109                (
110                    avg,
111                    ComplexitySource::Dual,
112                    format!("Rules: C{}, AI: C{} (agreement)", rules, ai_score),
113                )
114            }
115        }
116        None => (
117            rules,
118            ComplexitySource::Rules,
119            format!("Rule-based only (AI unavailable). C{}", rules),
120        ),
121    };
122
123    let final_score = final_score.clamp(1, 10);
124    let tier = Tier::from_score(final_score as f32);
125
126    println!(
127        "   Rules=C{} AI={} Final=C{} => {}",
128        rules,
129        ai.map(|s| format!("C{}", s))
130            .unwrap_or_else(|| "N/A".to_string()),
131        final_score,
132        tier.label()
133    );
134
135    RoutingResult {
136        complexity: final_score,
137        source,
138        tier,
139        reasoning,
140        rule_score: rules as f32,
141        ai_score: ai.map(|s| s as f32),
142    }
143}
144
145// ─── Rule-based scoring (ported from v2 assess_complexity_rules) ───
146
147fn rule_score(prompt: &str) -> u32 {
148    let text = prompt.to_lowercase();
149    let word_count = prompt.split_whitespace().count();
150    let mut score: f64 = 1.0; // base score (matched to v2)
151
152    // ── 1. STRUCTURAL COMPLEXITY ──
153
154    // Count explicit steps
155    let steps = text.matches("step ").count();
156    if steps >= 7 {
157        score += 3.0;
158    } else if steps >= 5 {
159        score += 2.0;
160    } else if steps >= 3 {
161        score += 1.0;
162    }
163
164    // Count file extensions mentioned
165    let file_exts = [
166        ".py", ".ts", ".js", ".tsx", ".jsx", ".json", ".css", ".html", ".go", ".php",
167    ];
168    let file_count: usize = file_exts.iter().filter(|ext| text.contains(*ext)).count();
169    if file_count >= 3 {
170        score += 2.0;
171    } else if file_count >= 2 {
172        score += 1.0;
173    }
174
175    // Count function/class definitions expected
176    let def_keywords = ["function", "class", "def ", "interface", "struct"];
177    let def_count: usize = def_keywords.iter().map(|k| text.matches(k).count()).sum();
178    if def_count >= 3 {
179        score += 1.0;
180    }
181
182    // ── 2. SEMANTIC COMPLEXITY (4-tier keywords, matched to v2) ──
183
184    let trivial = [
185        "simple",
186        "basic",
187        "single",
188        "just",
189        "only",
190        "straightforward",
191    ];
192    let moderate = [
193        "handle",
194        "validate",
195        "check",
196        "multiple",
197        "combine",
198        "integrate",
199        "parse",
200        "convert",
201        "transform",
202        "edge case",
203        "error handling",
204    ];
205    let high = [
206        "refactor",
207        "optimize",
208        "async",
209        "concurrent",
210        "parallel",
211        "nested",
212        "recursive",
213        "complex",
214        "algorithm",
215        "data structure",
216        "database",
217        "api",
218        "service",
219        "module",
220        "component",
221        "cache",
222        "lru",
223        "linked list",
224        "hash map",
225        "tree",
226        "graph",
227        "queue",
228        "stack",
229        "heap",
230        "binary",
231        "sorting",
232        "searching",
233        "o(1)",
234        "o(n)",
235        "o(log",
236        "time complexity",
237    ];
238    let extreme = [
239        "architect",
240        "design system",
241        "framework",
242        "infrastructure",
243        "distributed",
244        "microservice",
245        "migration",
246        "legacy",
247        "security",
248        "authentication",
249        "authorization",
250        "real-time",
251        "multiple files",
252        "full application",
253        "project",
254    ];
255
256    let mut max_tier: f64 = 0.0;
257
258    let extreme_hits = extreme.iter().filter(|k| text.contains(*k)).count();
259    if extreme_hits >= 2 {
260        max_tier = max_tier.max(4.0);
261    } else if extreme_hits == 1 {
262        max_tier = max_tier.max(3.0);
263    }
264
265    let high_hits = high.iter().filter(|k| text.contains(*k)).count();
266    if high_hits >= 3 {
267        max_tier = max_tier.max(3.0);
268    } else if high_hits >= 1 {
269        max_tier = max_tier.max(2.0);
270    }
271
272    let mod_hits = moderate.iter().filter(|k| text.contains(*k)).count();
273    if mod_hits >= 2 {
274        max_tier = max_tier.max(2.0);
275    } else if mod_hits >= 1 {
276        max_tier = max_tier.max(1.0);
277    }
278
279    let trivial_hits = trivial.iter().filter(|k| text.contains(*k)).count();
280    if trivial_hits >= 2 && max_tier <= 1.0 {
281        score -= 0.5;
282    }
283
284    score += max_tier;
285
286    // ── 3. LENGTH FACTOR ──
287    if word_count > 100 {
288        score += 2.0;
289    } else if word_count > 50 {
290        score += 1.0;
291    } else if word_count < 10 {
292        score -= 0.5;
293    }
294
295    // ── 4. LANGUAGE MODIFIER ──
296    let lang = detect_language_hint(&text);
297    match lang {
298        "go" | "rust" => score += 0.5,
299        "typescript" => score += 0.5,
300        _ => {}
301    }
302
303    // Web project boost (HTML/CSS usually need more files/context)
304    if text.contains("html") || text.contains("landing page") || text.contains("website") {
305        score = score.max(7.0);
306    }
307
308    (score.round() as u32).clamp(1, 10)
309}
310
/// Quick language hint from prompt text (for scoring only).
///
/// `lower` is expected to be lowercased already; all matching is
/// case-sensitive against lowercase needles. Checks run in priority order:
/// Rust (`rust`, `cargo`), then Go (`golang`, or the standalone word `go`),
/// then TypeScript (`typescript`, `next.js`). Anything else falls back to
/// `"python"`.
///
/// The word `go` is recognised anywhere in the prompt, including at the very
/// start or end and next to punctuation ("write it in go", "go, please"),
/// not only when surrounded by spaces.
///
/// Only string literals are returned, so the result is `'static` and does not
/// borrow from `lower`.
fn detect_language_hint(lower: &str) -> &'static str {
    // Standalone "go": a whitespace-delimited word, ignoring surrounding
    // punctuation. Every prompt containing " go " still matches.
    let mentions_go_word = || {
        lower
            .split_whitespace()
            .any(|word| word.trim_matches(|c: char| !c.is_alphanumeric()) == "go")
    };

    if lower.contains("rust") || lower.contains("cargo") {
        "rust"
    } else if lower.contains("golang") || mentions_go_word() {
        "go"
    } else if lower.contains("typescript") || lower.contains("next.js") {
        "typescript"
    } else {
        "python"
    }
}
323
324// ─── AI complexity scoring ───
325
326async fn ai_complexity_score(prompt: &str, llm: &LlmClient) -> Option<u32> {
327    let system = "/no_think\nYou are a task complexity assessor for a coding agent system.\n\
328        Rate the complexity of this programming task on a scale of 1-10:\n\
329        - 1-3: Simple (single function, basic logic, no dependencies)\n\
330        - 4-5: Medium (multiple functions, some validation, basic tests)\n\
331        - 6-7: Moderate (multiple files, external APIs, error handling)\n\
332        - 8-9: Complex (architecture design, multiple systems, advanced patterns)\n\
333        - 10: Very Complex (distributed systems, complex algorithms, extensive testing)\n\n\
334        Respond with ONLY a JSON object:\n\
335        {\"complexity\": <number>, \"reasoning\": \"<1 sentence>\"}";
336
337    let response = llm.generate("  AI-SCORE", system, prompt).await.ok()?;
338
339    // Try JSON parse first
340    if let Some(start) = response.find('{') {
341        if let Some(end) = response.rfind('}') {
342            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&response[start..=end]) {
343                if let Some(c) = json["complexity"].as_u64() {
344                    return Some((c as u32).clamp(1, 10));
345                }
346            }
347        }
348    }
349
350    // Fallback: parse bare number
351    for word in response.split_whitespace() {
352        let cleaned = word.trim_matches(|c: char| !c.is_numeric() && c != '.');
353        if let Ok(n) = cleaned.parse::<f32>() {
354            if (1.0..=10.0).contains(&n) {
355                return Some(n.round() as u32);
356            }
357        }
358    }
359    None
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    // ─── Rule-based scoring ───

    /// A bare "hello world" prompt lands in the lowest tier.
    #[test]
    fn test_trivial() {
        let tier = assess_complexity("print hello world");
        assert_eq!(tier, Tier::Trivial);
    }

    /// Trivial wording keeps a short prompt in C1-C3.
    #[test]
    fn test_trivial_simple() {
        let score = rule_score("Simple basic function that prints a number");
        assert!(
            score <= 3,
            "simple function should be C1-C3, got C{}",
            score
        );
    }

    /// A typical CRUD-style request scores in the moderate band.
    #[test]
    fn test_moderate() {
        let prompt =
            "Build a REST API for a todo app with database integration and form validation";
        let score = rule_score(prompt);
        assert!(
            (3..=6).contains(&score),
            "REST API todo app should be C3-C6, got C{}",
            score
        );
    }

    /// Validation plus edge-case handling scores in the moderate band.
    #[test]
    fn test_moderate_validation() {
        let prompt = "Write a function that validates email addresses and handles edge cases";
        let score = rule_score(prompt);
        assert!(
            (3..=6).contains(&score),
            "validation task should be C3-C6, got C{}",
            score
        );
    }

    /// A security-flavoured endpoint scores between C5 and C8.
    #[test]
    fn test_complex() {
        let prompt = "Build a production-ready FastAPI user authentication endpoint with JWT, rate limiting, and security headers";
        let score = rule_score(prompt);
        assert!(score >= 5, "auth endpoint should be C5+, got C{}", score);
        assert!(score <= 8, "auth endpoint should be <=C8, got C{}", score);
    }

    /// A data-structure task scores between C3 and C7 on rules alone.
    #[test]
    fn test_complex_data_structure() {
        let prompt =
            "Implement an LRU cache with O(1) get and put using a hash map and linked list";
        let score = rule_score(prompt);
        assert!(
            score >= 3,
            "LRU cache should be C3+ (rule-based), got C{}",
            score
        );
        assert!(score <= 7, "LRU cache should be <=C7, got C{}", score);
    }

    /// Several extreme-tier keywords push the score to at least C5.
    #[test]
    fn test_expert() {
        let prompt = "Build a distributed consensus algorithm for a microservice infrastructure with real-time replication";
        let score = rule_score(prompt);
        assert!(
            score >= 5,
            "distributed system should be C5+ (rule-based), got C{}",
            score
        );
    }

    /// An architecture-level prompt scores at least C5.
    #[test]
    fn test_extreme_architecture() {
        let prompt = "Design a distributed microservice authentication system with real-time WebSocket notifications and multiple files for the full application";
        let score = rule_score(prompt);
        assert!(
            score >= 5,
            "distributed system should be C5+ (rule-based), got C{}",
            score
        );
    }

    /// Web prompts are raised to at least C7.
    #[test]
    fn test_web_project_boost() {
        let score = rule_score("Create an HTML landing page with sections");
        assert!(
            score >= 7,
            "HTML landing page should get web boost to C7+, got C{}",
            score
        );
    }

    /// Spot checks across the keyword tiers.
    #[test]
    fn test_keyword_score() {
        let hello = rule_score("hello world");
        assert!(hello <= 3);

        let todo_api = rule_score("Build a REST API for a todo app with database");
        assert!(todo_api >= 3);

        let jwt_auth =
            rule_score("Build a JWT authentication system with rate limiting and security");
        assert!(jwt_auth >= 4);

        let compiler = rule_score(
            "Build a distributed compiler with multiple files for the full application",
        );
        assert!(compiler >= 5);
    }

    /// One representative score per tier.
    #[test]
    fn test_tier_from_score() {
        let cases = [
            (2.0, Tier::Trivial),
            (5.0, Tier::Moderate),
            (7.5, Tier::Complex),
            (9.5, Tier::Expert),
        ];
        for (score, expected) in cases.iter() {
            assert_eq!(Tier::from_score(*score), *expected);
        }
    }

    /// Degenerate and very long prompts stay within 1..=10.
    #[test]
    fn test_complexity_always_in_range() {
        let empty = rule_score("");
        assert!(empty >= 1);
        assert!(empty <= 10);

        let single_char = rule_score("x");
        assert!(single_char >= 1);

        let very_long = rule_score(&"word ".repeat(200));
        assert!(very_long <= 10);
    }

    /// A Go prompt never scores below the same prompt without a language.
    #[test]
    fn test_language_modifier() {
        let py = rule_score("Build a module");
        let go = rule_score("Build a golang module");
        assert!(go >= py, "Go should be >= Python complexity");
    }

    /// A longer, keyword-rich prompt outscores a two-word one.
    #[test]
    fn test_length_factor() {
        let long_prompt = format!(
            "Build a system that {}",
            "handles complex logic and ".repeat(10)
        );
        let short = rule_score("add numbers");
        let long = rule_score(&long_prompt);
        assert!(long > short, "longer prompt should score higher");
    }
}