use crate::agent_loop::{agent_loop, AgentLoopConfig};
use crate::context::ExecutionLimits;
use crate::provider::model::ModelConfig;
use crate::provider::StreamProvider;
use crate::shared_state::SharedState;
use crate::tools::shared_state_tool::SharedStateTool;
use crate::types::*;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Default cap on sub-agent loop turns; `SubAgentTool::new` uses it until it is
/// overridden with `SubAgentTool::with_max_turns`.
const DEFAULT_MAX_TURNS: usize = 10;
/// A tool that delegates a `task` string to a nested agent loop (a "sub-agent").
///
/// Build it with `SubAgentTool::new` and the `with_*` builder methods. Every
/// `execute` call starts from an empty message history containing only the
/// delegated task, so no conversation is carried over between calls (anything
/// shared lives in the optional `shared_state`).
pub struct SubAgentTool {
/// Name reported by `AgentTool::name` and `AgentTool::label`.
tool_name: String,
/// Description reported by `AgentTool::description`; `new` fills in a generic
/// "Delegate a task to the '<name>' sub-agent" sentence.
tool_description: String,
/// Base system prompt for the sub-agent (empty by default).
system_prompt: String,
/// Model identifier forwarded to the agent loop (empty by default).
model: String,
/// API key forwarded to the agent loop (empty by default).
api_key: String,
/// Provider handed to the agent loop for model calls.
provider: Arc<dyn StreamProvider>,
/// Tools exposed to the sub-agent; each is shared via `Arc` and wrapped per run.
tools: Vec<Arc<dyn AgentTool>>,
/// Thinking level forwarded to the agent loop (`ThinkingLevel::Off` by default).
thinking_level: ThinkingLevel,
/// Optional `max_tokens` setting forwarded to the agent loop (unset by default).
max_tokens: Option<u32>,
/// Cache settings forwarded to the agent loop.
cache_config: CacheConfig,
/// Tool-execution strategy forwarded to the agent loop.
tool_execution: ToolExecutionStrategy,
/// Retry settings forwarded to the agent loop.
retry_config: crate::retry::RetryConfig,
/// Used as `ExecutionLimits::max_turns` for every run.
max_turns: usize,
/// When set, each run gets a `shared_state` tool and a summary of the state is
/// appended to the system prompt.
shared_state: Option<SharedState>,
/// Optional value forwarded to the agent loop as `turn_delay`
/// (presumably a pause between turns — exact semantics live in `agent_loop`).
turn_delay: Option<std::time::Duration>,
/// Optional model configuration forwarded to the agent loop.
model_config: Option<ModelConfig>,
}
impl SubAgentTool {
    /// Creates a sub-agent tool called `name` that runs its loop on `provider`.
    ///
    /// Defaults: description `"Delegate a task to the '<name>' sub-agent"`, empty
    /// system prompt, model and API key, no tools, thinking `Off`, no
    /// `max_tokens`, default cache / tool-execution / retry settings,
    /// `DEFAULT_MAX_TURNS` turns, no shared state, no turn delay and no model
    /// configuration.
    #[must_use]
    pub fn new(name: impl Into<String>, provider: Arc<dyn StreamProvider>) -> Self {
        let name = name.into();
        Self {
            tool_description: format!("Delegate a task to the '{}' sub-agent", name),
            tool_name: name,
            system_prompt: String::new(),
            model: String::new(),
            api_key: String::new(),
            provider,
            tools: Vec::new(),
            thinking_level: ThinkingLevel::Off,
            max_tokens: None,
            cache_config: CacheConfig::default(),
            tool_execution: ToolExecutionStrategy::default(),
            retry_config: crate::retry::RetryConfig::default(),
            max_turns: DEFAULT_MAX_TURNS,
            shared_state: None,
            turn_delay: None,
            model_config: None,
        }
    }

    // All builders below consume `self` and return the updated value; the
    // `#[must_use]` markers make a discarded result (which would silently drop
    // the whole tool configuration) a compiler warning.

    /// Replaces the description shown to the parent agent.
    #[must_use]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.tool_description = desc.into();
        self
    }

    /// Sets the sub-agent's base system prompt.
    #[must_use]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Sets the model identifier passed to the agent loop.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Sets the API key passed to the agent loop.
    #[must_use]
    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = key.into();
        self
    }

    /// Replaces (does not extend) the set of tools available to the sub-agent.
    #[must_use]
    pub fn with_tools(mut self, tools: Vec<Arc<dyn AgentTool>>) -> Self {
        self.tools = tools;
        self
    }

    /// Sets the thinking level passed to the agent loop.
    #[must_use]
    pub fn with_thinking(mut self, level: ThinkingLevel) -> Self {
        self.thinking_level = level;
        self
    }

    /// Sets the `max_tokens` value passed to the agent loop.
    #[must_use]
    pub fn with_max_tokens(mut self, max: u32) -> Self {
        self.max_tokens = Some(max);
        self
    }

    /// Sets the cache configuration passed to the agent loop.
    #[must_use]
    pub fn with_cache_config(mut self, config: CacheConfig) -> Self {
        self.cache_config = config;
        self
    }

    /// Sets how the sub-agent's tool calls are executed.
    #[must_use]
    pub fn with_tool_execution(mut self, strategy: ToolExecutionStrategy) -> Self {
        self.tool_execution = strategy;
        self
    }

    /// Sets the retry configuration passed to the agent loop.
    #[must_use]
    pub fn with_retry_config(mut self, config: crate::retry::RetryConfig) -> Self {
        self.retry_config = config;
        self
    }

    /// Sets the per-run turn cap (`ExecutionLimits::max_turns`).
    #[must_use]
    pub fn with_max_turns(mut self, max: usize) -> Self {
        self.max_turns = max;
        self
    }

    /// Attaches a shared variable store. Each run then gets a `shared_state`
    /// tool and a summary of the store appended to its system prompt.
    #[must_use]
    pub fn with_shared_state(mut self, state: SharedState) -> Self {
        self.shared_state = Some(state);
        self
    }

    /// Sets the `turn_delay` value passed to the agent loop.
    #[must_use]
    pub fn with_turn_delay(mut self, delay: std::time::Duration) -> Self {
        self.turn_delay = Some(delay);
        self
    }

    /// Sets the model configuration passed to the agent loop.
    #[must_use]
    pub fn with_model_config(mut self, config: ModelConfig) -> Self {
        self.model_config = Some(config);
        self
    }
}
/// Adapts a shared `Arc<dyn AgentTool>` so it can sit in the
/// `Vec<Box<dyn AgentTool>>` that an `AgentContext` owns.
///
/// Every trait method implemented below forwards unchanged to the wrapped
/// tool. NOTE(review): any provided (default) methods of `AgentTool` that are
/// not implemented here are not forwarded; if the trait has such methods and a
/// wrapped tool overrides them, confirm whether they should be delegated too.
struct ArcToolWrapper(Arc<dyn AgentTool>);

impl ArcToolWrapper {
    /// The shared tool that all calls are delegated to.
    fn inner(&self) -> &dyn AgentTool {
        &*self.0
    }
}

#[async_trait::async_trait]
impl AgentTool for ArcToolWrapper {
    fn name(&self) -> &str {
        self.inner().name()
    }

    fn label(&self) -> &str {
        self.inner().label()
    }

    fn description(&self) -> &str {
        self.inner().description()
    }

    fn parameters_schema(&self) -> serde_json::Value {
        self.inner().parameters_schema()
    }

    async fn execute(
        &self,
        params: serde_json::Value,
        ctx: ToolContext,
    ) -> Result<ToolResult, ToolError> {
        let Self(shared) = self;
        shared.execute(params, ctx).await
    }
}
#[async_trait::async_trait]
impl AgentTool for SubAgentTool {
fn name(&self) -> &str {
&self.tool_name
}
fn label(&self) -> &str {
&self.tool_name
}
fn description(&self) -> &str {
&self.tool_description
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"task": {
"type": "string",
"description": "The task to delegate to this sub-agent"
}
},
"required": ["task"]
})
}
async fn execute(
&self,
params: serde_json::Value,
ctx: ToolContext,
) -> Result<ToolResult, ToolError> {
let cancel = ctx.cancel;
let on_update = ctx.on_update;
let on_progress = ctx.on_progress;
let task = params
.get("task")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::InvalidArgs("Missing required 'task' parameter".into()))?
.to_string();
let mut tools: Vec<Box<dyn AgentTool>> = self
.tools
.iter()
.map(|t| Box::new(ArcToolWrapper(Arc::clone(t))) as Box<dyn AgentTool>)
.collect();
let mut system_prompt = self.system_prompt.clone();
if let Some(ref state) = self.shared_state {
tools.push(Box::new(SharedStateTool::new(state.clone())));
let summary = state.summary().await;
system_prompt.push_str(&format!(
"\n\n## Shared State\nYou have access to a shared variable store via the `shared_state` tool.\nAvailable: {}",
summary
));
}
let mut context = AgentContext {
system_prompt,
messages: Vec::new(),
tools,
};
let config = AgentLoopConfig {
provider: self.provider.clone(),
model: self.model.clone(),
api_key: self.api_key.clone(),
thinking_level: self.thinking_level,
max_tokens: self.max_tokens,
temperature: None,
model_config: self.model_config.clone(),
convert_to_llm: None,
transform_context: None,
get_steering_messages: None,
get_follow_up_messages: None,
context_config: None,
compaction_strategy: None,
execution_limits: Some(ExecutionLimits {
max_turns: self.max_turns,
max_total_tokens: 1_000_000,
max_duration: std::time::Duration::from_secs(300),
}),
cache_config: self.cache_config.clone(),
tool_execution: self.tool_execution.clone(),
retry_config: self.retry_config.clone(),
before_turn: None,
after_turn: None,
on_error: None,
input_filters: vec![],
turn_delay: self.turn_delay,
};
let (tx, mut rx) = mpsc::unbounded_channel();
let forward_handle = if on_update.is_some() || on_progress.is_some() {
let tool_name = self.tool_name.clone();
Some(tokio::spawn(async move {
while let Some(event) = rx.recv().await {
if let AgentEvent::ProgressMessage { text, .. } = &event {
if let Some(ref cb) = on_progress {
cb(text.clone());
}
}
if let Some(ref on_update) = on_update {
let update_text = match &event {
AgentEvent::MessageUpdate {
delta: StreamDelta::Text { delta },
..
} => Some(delta.clone()),
AgentEvent::ToolExecutionStart { tool_name, .. } => {
Some(format!("[sub-agent calling tool: {}]", tool_name))
}
_ => None,
};
if let Some(text) = update_text {
on_update(ToolResult {
content: vec![Content::Text { text }],
details: serde_json::json!({ "sub_agent": tool_name }),
});
}
}
}
}))
} else {
None
};
let prompt = AgentMessage::Llm(Message::user(task));
let new_messages = agent_loop(vec![prompt], &mut context, &config, tx, cancel).await;
if let Some(handle) = forward_handle {
let _ = handle.await;
}
if let Some(error_msg) = extract_error(&new_messages) {
return Err(ToolError::Failed(format!(
"Sub-agent '{}' failed: {}",
self.tool_name, error_msg
)));
}
let result_text = extract_final_text(&new_messages);
let details = serde_json::json!({
"sub_agent": self.tool_name,
"turns": new_messages.len(),
});
Ok(ToolResult {
content: vec![Content::Text { text: result_text }],
details,
})
}
}
fn extract_error(messages: &[AgentMessage]) -> Option<String> {
for msg in messages.iter().rev() {
if let AgentMessage::Llm(Message::Assistant {
stop_reason,
error_message,
..
}) = msg
{
if *stop_reason == StopReason::Error {
return Some(
error_message
.clone()
.unwrap_or_else(|| "Unknown error".into()),
);
}
}
}
None
}
/// Picks the sub-agent's answer from its messages.
///
/// The answer is the newest assistant message that has at least one non-empty
/// text part, with those parts joined by newlines. Assistant messages without
/// any non-empty text are skipped in favour of older ones. If no assistant
/// message has text, a fixed placeholder is returned.
fn extract_final_text(messages: &[AgentMessage]) -> String {
    const NO_TEXT_PLACEHOLDER: &str = "(sub-agent produced no text output)";

    messages
        .iter()
        .rev()
        .find_map(|msg| match msg {
            AgentMessage::Llm(Message::Assistant { content, .. }) => {
                let parts: Vec<&str> = content
                    .iter()
                    .filter_map(|part| match part {
                        Content::Text { text } if !text.is_empty() => Some(text.as_str()),
                        _ => None,
                    })
                    .collect();
                if parts.is_empty() {
                    None
                } else {
                    Some(parts.join("\n"))
                }
            }
            _ => None,
        })
        .unwrap_or_else(|| NO_TEXT_PLACEHOLDER.to_string())
}