bamboo_engine/runtime/
complexity_classifier.rs1use bamboo_agent_core::tools::ToolCall;
8use bamboo_agent_core::{AgentError, Session};
9
10use crate::runtime::managers::mini_loop::MiniLoopExecutor;
11
/// Coarse difficulty bucket for the work done in an agent round.
///
/// [`ComplexityClassifier::classify`] yields `Standard` when the executor's
/// answer is not recognised as one of the other buckets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskComplexity {
    /// Light work such as file reading, grep, listing or formatting.
    Simple,
    /// Ordinary work such as editing code, debugging or writing tests.
    Standard,
    /// Heavy work such as architectural decisions, multi-file refactoring or planning.
    Complex,
}

/// Stateless namespace for classifying the complexity of an agent round.
///
/// All functionality lives in associated functions, so the type carries no
/// data. It derives the usual cheap traits so it can be logged, copied and
/// default-constructed like any other unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct ComplexityClassifier;

26impl ComplexityClassifier {
27 pub async fn classify(
31 executor: &dyn MiniLoopExecutor,
32 session: &Session,
33 tool_calls: &[ToolCall],
34 round: usize,
35 ) -> Result<TaskComplexity, AgentError> {
36 let prompt = build_complexity_prompt(tool_calls, round);
37 let context = extract_recent_context(session);
38 let decision = executor.decide(session, &prompt, &context).await?;
39
40 let answer = decision.answer.to_lowercase();
41 if answer.contains("simple") {
42 Ok(TaskComplexity::Simple)
43 } else if answer.contains("complex") {
44 Ok(TaskComplexity::Complex)
45 } else {
46 Ok(TaskComplexity::Standard)
47 }
48 }
49}
50
51fn build_complexity_prompt(tool_calls: &[ToolCall], round: usize) -> String {
52 let tool_names: Vec<&str> = tool_calls
53 .iter()
54 .map(|tc| tc.function.name.as_str())
55 .collect();
56 format!(
57 "Classify the current task complexity based on the last tool calls.\n\
58 Tool calls this round: {}\n\
59 Round number: {}\n\n\
60 Respond with exactly one word: simple, standard, or complex.\n\
61 - simple: file reading, grep, listing, formatting\n\
62 - standard: editing code, debugging, writing tests\n\
63 - complex: architectural decisions, multi-file refactoring, planning",
64 tool_names.join(", "),
65 round
66 )
67}
68
69fn extract_recent_context(session: &Session) -> String {
71 use bamboo_agent_core::Role;
72
73 session
74 .messages
75 .iter()
76 .rev()
77 .filter(|m| matches!(m.role, Role::User | Role::Assistant))
78 .take(4)
79 .map(|m| {
80 let role = match m.role {
81 Role::User => "User",
82 Role::Assistant => "Assistant",
83 _ => "Other",
84 };
85 format!("{}: {}", role, truncate(&m.content, 200))
86 })
87 .collect::<Vec<_>>()
88 .into_iter()
89 .rev()
90 .collect::<Vec<_>>()
91 .join("\n")
92}
93
/// Returns the longest prefix of `s` that is at most `max_len` bytes long.
///
/// When `max_len` falls inside a multi-byte UTF-8 character, the cut moves back
/// to the preceding character boundary. The result is therefore always valid
/// UTF-8 and this function never panics.
fn truncate(s: &str, max_len: usize) -> &str {
    if max_len >= s.len() {
        return s;
    }
    // Index 0 is always a char boundary, so the search cannot come up empty;
    // `unwrap_or(0)` just makes that explicit without a panic path.
    let cut = (0..=max_len)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_short() {
        assert_eq!(truncate("hello", 10), "hello");
    }

    #[test]
    fn test_truncate_exact() {
        assert_eq!(truncate("hello", 5), "hello");
    }

    #[test]
    fn test_truncate_long() {
        assert_eq!(truncate("hello world", 5), "hello");
    }

    #[test]
    fn test_truncate_zero_len() {
        assert_eq!(truncate("hello", 0), "");
        assert_eq!(truncate("", 0), "");
    }

    /// The byte limit may land inside a multi-byte character; `truncate` must
    /// back off to the previous char boundary instead of slicing mid-character.
    #[test]
    fn test_truncate_multibyte_boundary() {
        // 'é' occupies bytes 1..3, so a limit of 2 splits it.
        assert_eq!(truncate("héllo", 2), "h");
        assert_eq!(truncate("héllo", 3), "hé");
        // '日' occupies bytes 0..3; limits inside it back off to 0.
        assert_eq!(truncate("日本", 1), "");
        assert_eq!(truncate("日本", 4), "日");
    }

    #[test]
    fn test_build_complexity_prompt() {
        use bamboo_domain::session::tool_types::FunctionCall;
        let tc = ToolCall {
            id: "1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "Read".to_string(),
                arguments: "{}".to_string(),
            },
        };
        let prompt = build_complexity_prompt(&[tc], 3);
        assert!(prompt.contains("Tool calls this round: Read"));
        assert!(prompt.contains("Round number: 3"));
        assert!(prompt.contains("simple"));
        assert!(prompt.contains("complex"));
    }

    #[test]
    fn test_build_complexity_prompt_no_tool_calls() {
        let prompt = build_complexity_prompt(&[], 1);
        assert!(prompt.contains("Tool calls this round: \n"));
        assert!(prompt.contains("Round number: 1"));
    }
}