use bamboo_agent_core::tools::ToolCall;
use bamboo_agent_core::{AgentError, Session};
use crate::runtime::managers::mini_loop::MiniLoopExecutor;
/// Coarse difficulty rating for the current agent round, as returned by
/// `ComplexityClassifier::classify`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskComplexity {
/// The model answered "simple" (the prompt describes this as file reading,
/// grep, listing, formatting).
Simple,
/// The model answered "standard" (editing code, debugging, writing tests).
/// Also the fallback when the answer contains no recognized label.
Standard,
/// The model answered "complex" (architectural decisions, multi-file
/// refactoring, planning).
Complex,
}
/// Stateless helper that asks a [`MiniLoopExecutor`] to rate how demanding the
/// current agent round is.
pub struct ComplexityClassifier;

impl ComplexityClassifier {
    /// Classifies the complexity of the current round.
    ///
    /// Builds a one-word classification prompt from the names of `tool_calls` and
    /// `round`, attaches recent user/assistant messages from `session` as context,
    /// and asks `executor` for a decision. The free-form answer is mapped to a
    /// [`TaskComplexity`]; an unrecognized answer yields `TaskComplexity::Standard`.
    ///
    /// # Errors
    ///
    /// Propagates any [`AgentError`] returned by `executor.decide`.
    pub async fn classify(
        executor: &dyn MiniLoopExecutor,
        session: &Session,
        tool_calls: &[ToolCall],
        round: usize,
    ) -> Result<TaskComplexity, AgentError> {
        let prompt = build_complexity_prompt(tool_calls, round);
        let context = extract_recent_context(session);
        let decision = executor.decide(session, &prompt, &context).await?;
        Ok(Self::parse_answer(&decision.answer))
    }

    /// Maps the model's free-form answer to a [`TaskComplexity`].
    ///
    /// The answer is split on every non-alphabetic character. The first resulting
    /// word that equals `simple`, `standard` or `complex` (ASCII case-insensitive)
    /// decides the result. Anything else falls back to `Standard`.
    ///
    /// Whole-word matching matters. With substring matching, an echoed answer such as
    /// "Complexity: standard" would be read as `Complex`, because "complexity"
    /// contains "complex". It also lets "complex, not simple" win on the first label.
    fn parse_answer(answer: &str) -> TaskComplexity {
        answer
            .split(|c: char| !c.is_alphabetic())
            .find_map(|word| {
                if word.eq_ignore_ascii_case("simple") {
                    Some(TaskComplexity::Simple)
                } else if word.eq_ignore_ascii_case("standard") {
                    Some(TaskComplexity::Standard)
                } else if word.eq_ignore_ascii_case("complex") {
                    Some(TaskComplexity::Complex)
                } else {
                    None
                }
            })
            .unwrap_or(TaskComplexity::Standard)
    }
}
/// Builds the classification prompt sent to the mini-loop executor.
///
/// The prompt contains the names of this round's tool calls joined with `", "`,
/// the round number, and instructions asking for exactly one of
/// `simple` / `standard` / `complex`, with a short description of each.
fn build_complexity_prompt(tool_calls: &[ToolCall], round: usize) -> String {
    let joined_names = tool_calls
        .iter()
        .map(|call| call.function.name.as_str())
        .collect::<Vec<&str>>()
        .join(", ");

    // Each trailing `\` continues the literal and skips the next line's leading
    // whitespace, so the indentation below is not part of the prompt.
    format!(
        "Classify the current task complexity based on the last tool calls.\n\
         Tool calls this round: {}\n\
         Round number: {}\n\n\
         Respond with exactly one word: simple, standard, or complex.\n\
         - simple: file reading, grep, listing, formatting\n\
         - standard: editing code, debugging, writing tests\n\
         - complex: architectural decisions, multi-file refactoring, planning",
        joined_names, round
    )
}
/// Renders up to the four most recent user/assistant messages of `session` as
/// context text.
///
/// Messages with any other role are skipped. Each kept message becomes one
/// `"User: ..."` or `"Assistant: ..."` line, with its content cut to at most
/// 200 bytes by `truncate`. Lines are newline-separated, oldest first.
fn extract_recent_context(session: &Session) -> String {
    use bamboo_agent_core::Role;

    // Walk newest-first so `take(4)` keeps the most recent messages.
    let mut lines: Vec<String> = session
        .messages
        .iter()
        .rev()
        .filter_map(|message| {
            let speaker = match message.role {
                Role::User => "User",
                Role::Assistant => "Assistant",
                _ => return None,
            };
            Some(format!("{}: {}", speaker, truncate(&message.content, 200)))
        })
        .take(4)
        .collect();

    // Restore chronological order before joining.
    lines.reverse();
    lines.join("\n")
}
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and ends
/// on a UTF-8 character boundary, so slicing never panics. Returns `s` unchanged
/// when it already fits.
fn truncate(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }
    // `max_len < s.len()` here, and index 0 is always a char boundary, so the
    // search always finds a valid cut point.
    let cut = (0..=max_len)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_short() {
        assert_eq!(truncate("hello", 10), "hello");
    }

    #[test]
    fn test_truncate_exact() {
        assert_eq!(truncate("hello", 5), "hello");
    }

    #[test]
    fn test_truncate_long() {
        assert_eq!(truncate("hello world", 5), "hello");
    }

    #[test]
    fn test_truncate_respects_char_boundary() {
        // "é" occupies bytes 1..3, so a 2-byte limit must back off to byte 1
        // instead of slicing through the middle of the character.
        assert_eq!(truncate("héllo", 2), "h");
        assert_eq!(truncate("héllo", 3), "hé");
    }

    #[test]
    fn test_truncate_zero_limit() {
        assert_eq!(truncate("hello", 0), "");
        assert_eq!(truncate("", 0), "");
    }

    #[test]
    fn test_build_complexity_prompt() {
        use bamboo_domain::session::tool_types::FunctionCall;
        let tc = ToolCall {
            id: "1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "Read".to_string(),
                arguments: "{}".to_string(),
            },
        };
        let prompt = build_complexity_prompt(&[tc], 3);
        // Check the labelled lines, not bare substrings that could match anywhere.
        assert!(prompt.contains("Tool calls this round: Read\n"));
        assert!(prompt.contains("Round number: 3\n"));
        assert!(prompt.contains("simple"));
        assert!(prompt.contains("complex"));
    }
}