Skip to main content

bamboo_engine/runtime/
complexity_classifier.rs

//! Task complexity classifier for dynamic model routing.
//!
//! Uses the fast model (via `MiniLoopExecutor`) to classify the current task
//! complexity at the end of each agent round. The classification result drives
//! per-round model selection in the agent pipeline.
6
7use bamboo_agent_core::tools::ToolCall;
8use bamboo_agent_core::{AgentError, Session};
9
10use crate::runtime::managers::mini_loop::MiniLoopExecutor;
11
/// How demanding the current task is, used to pick a model for the next round.
///
/// Variants are listed from the cheapest model tier to the most capable one.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskComplexity {
    /// Lightweight work such as reading files, grepping, listing, or formatting;
    /// suitable for the cheapest model.
    Simple,
    /// Everyday coding work such as editing, debugging, or writing tests;
    /// handled by the default model.
    Standard,
    /// Architecture decisions, multi-file refactoring, or planning;
    /// routed to the most capable model.
    Complex,
}
22
23/// Classifies task complexity using the fast model via `MiniLoopExecutor`.
24pub struct ComplexityClassifier;
25
26impl ComplexityClassifier {
27    /// Classify task complexity based on the current round's tool calls.
28    ///
29    /// Returns `TaskComplexity::Standard` as a safe default when classification fails.
30    pub async fn classify(
31        executor: &dyn MiniLoopExecutor,
32        session: &Session,
33        tool_calls: &[ToolCall],
34        round: usize,
35    ) -> Result<TaskComplexity, AgentError> {
36        let prompt = build_complexity_prompt(tool_calls, round);
37        let context = extract_recent_context(session);
38        let decision = executor.decide(session, &prompt, &context).await?;
39
40        let answer = decision.answer.to_lowercase();
41        if answer.contains("simple") {
42            Ok(TaskComplexity::Simple)
43        } else if answer.contains("complex") {
44            Ok(TaskComplexity::Complex)
45        } else {
46            Ok(TaskComplexity::Standard)
47        }
48    }
49}
50
51fn build_complexity_prompt(tool_calls: &[ToolCall], round: usize) -> String {
52    let tool_names: Vec<&str> = tool_calls
53        .iter()
54        .map(|tc| tc.function.name.as_str())
55        .collect();
56    format!(
57        "Classify the current task complexity based on the last tool calls.\n\
58         Tool calls this round: {}\n\
59         Round number: {}\n\n\
60         Respond with exactly one word: simple, standard, or complex.\n\
61         - simple: file reading, grep, listing, formatting\n\
62         - standard: editing code, debugging, writing tests\n\
63         - complex: architectural decisions, multi-file refactoring, planning",
64        tool_names.join(", "),
65        round
66    )
67}
68
69/// Extract the last few user/assistant messages as context for classification.
70fn extract_recent_context(session: &Session) -> String {
71    use bamboo_agent_core::Role;
72
73    session
74        .messages
75        .iter()
76        .rev()
77        .filter(|m| matches!(m.role, Role::User | Role::Assistant))
78        .take(4)
79        .map(|m| {
80            let role = match m.role {
81                Role::User => "User",
82                Role::Assistant => "Assistant",
83                _ => "Other",
84            };
85            format!("{}: {}", role, truncate(&m.content, 200))
86        })
87        .collect::<Vec<_>>()
88        .into_iter()
89        .rev()
90        .collect::<Vec<_>>()
91        .join("\n")
92}
93
/// Return a prefix of `s` that is at most `max_len` bytes long.
///
/// When `max_len` falls inside a multi-byte UTF-8 character, the cut moves back
/// to the closest preceding char boundary, so slicing never panics. Strings that
/// already fit are returned unchanged.
fn truncate(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }
    // Index 0 is always a char boundary, so the search can't come up empty.
    let cut = (0..=max_len)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_short() {
        assert_eq!(truncate("hello", 10), "hello");
    }

    #[test]
    fn test_truncate_exact() {
        assert_eq!(truncate("hello", 5), "hello");
    }

    #[test]
    fn test_truncate_long() {
        assert_eq!(truncate("hello world", 5), "hello");
    }

    /// The byte limit lands inside a multi-byte character: the cut must back off
    /// to the previous char boundary instead of panicking.
    #[test]
    fn test_truncate_multibyte_boundary() {
        // 'é' is 2 bytes (byte indices 1..3), so byte 2 is mid-character.
        assert_eq!(truncate("héllo", 2), "h");
        // Each of these CJK chars is 3 bytes; a 1-byte budget fits none of them.
        assert_eq!(truncate("日本", 1), "");
        assert_eq!(truncate("日本語", 4), "日");
    }

    #[test]
    fn test_truncate_zero_len() {
        assert_eq!(truncate("abc", 0), "");
        assert_eq!(truncate("", 0), "");
    }

    #[test]
    fn test_build_complexity_prompt() {
        use bamboo_domain::session::tool_types::FunctionCall;
        let tc = ToolCall {
            id: "1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "Read".to_string(),
                arguments: "{}".to_string(),
            },
        };
        let prompt = build_complexity_prompt(&[tc], 3);
        assert!(prompt.contains("Read"));
        assert!(prompt.contains("3"));
        assert!(prompt.contains("simple"));
        assert!(prompt.contains("complex"));
    }
}