use crate::types::*;
use crate::engine::{QueryEngine, QueryEngineConfig};
use crate::error::AgentError;
use crate::env::EnvConfig;
use crate::tools::bash::BashTool;
use crate::tools::read::FileReadTool as ReadTool;
use crate::tools::write::FileWriteTool as WriteTool;
use crate::tools::glob::GlobTool;
use crate::tools::grep::GrepTool;
use crate::tools::edit::FileEditTool;
fn register_all_tool_executors(engine: &mut QueryEngine) {
type BoxFuture<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send>>;
let bash_executor = move |input: serde_json::Value, ctx: &ToolContext| -> BoxFuture<Result<ToolResult, AgentError>> {
let tool_clone = BashTool::new();
let cwd = ctx.cwd.clone();
Box::pin(async move {
let ctx2 = ToolContext { cwd, abort_signal: None };
tool_clone.execute(input, &ctx2).await
})
};
engine.register_tool("Bash".to_string(), bash_executor);
let read_executor = move |input: serde_json::Value, ctx: &ToolContext| -> BoxFuture<Result<ToolResult, AgentError>> {
let tool_clone = ReadTool::new();
let cwd = ctx.cwd.clone();
Box::pin(async move {
let ctx2 = ToolContext { cwd, abort_signal: None };
tool_clone.execute(input, &ctx2).await
})
};
engine.register_tool("FileRead".to_string(), read_executor);
let write_executor = move |input: serde_json::Value, ctx: &ToolContext| -> BoxFuture<Result<ToolResult, AgentError>> {
let tool_clone = WriteTool::new();
let cwd = ctx.cwd.clone();
Box::pin(async move {
let ctx2 = ToolContext { cwd, abort_signal: None };
tool_clone.execute(input, &ctx2).await
})
};
engine.register_tool("FileWrite".to_string(), write_executor);
let glob_executor = move |input: serde_json::Value, ctx: &ToolContext| -> BoxFuture<Result<ToolResult, AgentError>> {
let tool_clone = GlobTool::new();
let cwd = ctx.cwd.clone();
Box::pin(async move {
let ctx2 = ToolContext { cwd, abort_signal: None };
tool_clone.execute(input, &ctx2).await
})
};
engine.register_tool("Glob".to_string(), glob_executor);
let grep_executor = move |input: serde_json::Value, ctx: &ToolContext| -> BoxFuture<Result<ToolResult, AgentError>> {
let tool_clone = GrepTool::new();
let cwd = ctx.cwd.clone();
Box::pin(async move {
let ctx2 = ToolContext { cwd, abort_signal: None };
tool_clone.execute(input, &ctx2).await
})
};
engine.register_tool("Grep".to_string(), grep_executor);
let edit_executor = move |input: serde_json::Value, ctx: &ToolContext| -> BoxFuture<Result<ToolResult, AgentError>> {
let tool_clone = FileEditTool::new();
let cwd = ctx.cwd.clone();
Box::pin(async move {
let ctx2 = ToolContext { cwd, abort_signal: None };
tool_clone.execute(input, &ctx2).await
})
};
engine.register_tool("FileEdit".to_string(), edit_executor);
use crate::tools::skill::SkillTool;
use crate::tools::skill::register_skills_from_dir;
use std::path::Path;
register_skills_from_dir(Path::new("examples/skills"));
let skill_executor = move |input: serde_json::Value, ctx: &ToolContext| -> BoxFuture<Result<ToolResult, AgentError>> {
let tool_clone = SkillTool::new();
let cwd = ctx.cwd.clone();
Box::pin(async move {
let ctx2 = ToolContext { cwd, abort_signal: None };
tool_clone.execute(input, &ctx2).await
})
};
engine.register_tool("Skill".to_string(), skill_executor);
let stub_executor = |input: serde_json::Value, _ctx: &ToolContext| -> BoxFuture<Result<ToolResult, AgentError>> {
let tool_name = input.get("name")
.and_then(|n| n.as_str())
.unwrap_or("unknown")
.to_string();
Box::pin(async move {
Ok(ToolResult {
result_type: "text".to_string(),
tool_use_id: tool_name.clone(),
content: format!("Tool '{}' is not fully implemented yet", tool_name),
is_error: Some(false),
})
})
};
for tool_name in &["TaskCreate", "TaskList", "TaskUpdate", "TaskGet", "TeamCreate", "TeamDelete", "SendMessage", "EnterWorktree", "ExitWorktree", "EnterPlanMode", "ExitPlanMode", "AskUserQuestion", "ToolSearch", "CronCreate", "CronDelete", "CronList", "Config", "TodoWrite", "NotebookEdit", "WebFetch", "WebSearch", "Agent"] {
engine.register_tool(tool_name.to_string(), stub_executor);
}
}
/// A configured agent.
///
/// Holds the resolved model and credentials, the tool definitions it advertises, and the
/// conversation history carried from one query to the next within a single session.
pub struct Agent {
    /// Random identifier generated when the agent is created.
    session_id: String,
    /// Options the agent was built from; `set_system_prompt` and `set_cwd` update them in place.
    config: AgentOptions,
    /// Model name resolved at construction time.
    model: String,
    /// API credential resolved at construction time, if any.
    api_key: Option<String>,
    /// Custom API endpoint resolved at construction time, if any.
    base_url: Option<String>,
    /// Tool definitions supplied via the options. When empty, queries fall back to the full base
    /// tool set.
    tool_pool: Vec<ToolDefinition>,
    /// Conversation history replayed into each new query.
    messages: Vec<Message>,
}
impl From<AgentOptions> for Agent {
fn from(options: AgentOptions) -> Self {
Agent::create(options)
}
}
impl Agent {
pub fn new(model: &str, max_turns: u32) -> Self {
Self::create(AgentOptions {
model: Some(model.to_string()),
max_turns: Some(max_turns),
..Default::default()
})
}
pub fn create(options: AgentOptions) -> Self {
let env_config = EnvConfig::load();
let model = env_config.model.clone()
.or_else(|| options.model.clone())
.unwrap_or_else(|| "claude-sonnet-4-6".to_string());
let api_key = env_config.auth_token.clone()
.or_else(|| options.api_key.clone());
let base_url = env_config.base_url.clone()
.or_else(|| options.base_url.clone());
let session_id = uuid::Uuid::new_v4().to_string();
Self {
config: options.clone(),
model,
api_key,
base_url,
tool_pool: options.tools.clone(),
messages: vec![],
session_id,
}
}
pub fn get_model(&self) -> &str {
&self.model
}
pub fn get_session_id(&self) -> &str {
&self.session_id
}
pub fn get_messages(&self) -> &[Message] {
&self.messages
}
pub fn get_tools(&self) -> &[ToolDefinition] {
&self.tool_pool
}
pub fn set_system_prompt(&mut self, prompt: &str) {
self.config.system_prompt = Some(prompt.to_string());
}
pub fn set_cwd(&mut self, cwd: &str) {
self.config.cwd = Some(cwd.to_string());
}
pub async fn execute_tool(&mut self, name: &str, input: serde_json::Value) -> Result<ToolResult, AgentError> {
let cwd = self.config.cwd.clone().unwrap_or_else(|| std::env::current_dir().map(|p| p.to_string_lossy().to_string()).unwrap_or_else(|_| ".".to_string()));
let model = self.model.clone();
let api_key = self.api_key.clone();
let base_url = self.base_url.clone();
let mut engine = QueryEngine::new(QueryEngineConfig {
cwd: cwd.clone(),
model: model.clone(),
api_key: api_key.clone(),
base_url: base_url.clone(),
tools: vec![],
system_prompt: None,
max_turns: 10,
max_budget_usd: None,
max_tokens: 16384,
can_use_tool: None,
});
register_all_tool_executors(&mut engine);
let agent_tool_executor = move |input: serde_json::Value, _ctx: &ToolContext| -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ToolResult, AgentError>> + Send>> {
let cwd = cwd.clone();
let api_key = api_key.clone();
let base_url = base_url.clone();
let model = model.clone();
Box::pin(async move {
let description = input["description"].as_str().unwrap_or("subagent");
let subagent_prompt = input["prompt"].as_str().unwrap_or("");
let subagent_model = input["model"].as_str().map(|s| s.to_string()).unwrap_or_else(|| model.clone());
let max_turns = input["max_turns"].as_u64().unwrap_or(10) as u32;
let mut sub_engine = QueryEngine::new(QueryEngineConfig {
cwd,
model: subagent_model.to_string(),
api_key,
base_url,
tools: vec![],
system_prompt: None,
max_turns,
max_budget_usd: None,
max_tokens: 16384,
can_use_tool: None,
});
match sub_engine.submit_message(subagent_prompt).await {
Ok(result_text) => {
Ok(ToolResult {
result_type: "text".to_string(),
tool_use_id: "agent_tool".to_string(),
content: format!("[Subagent: {}]\n\n{}", description, result_text),
is_error: Some(false),
})
}
Err(e) => {
Ok(ToolResult {
result_type: "text".to_string(),
tool_use_id: "agent_tool".to_string(),
content: format!("[Subagent: {}] Error: {}", description, e),
is_error: Some(true),
})
}
}
})
};
engine.register_tool("Agent".to_string(), agent_tool_executor);
engine.execute_tool(name, input).await
}
pub async fn prompt(&mut self, prompt: &str) -> Result<QueryResult, AgentError> {
self.query(prompt).await
}
pub async fn query(&mut self, prompt: &str) -> Result<QueryResult, AgentError> {
use crate::memory::load_memory_prompt;
use crate::ai_md::load_ai_md;
use crate::tools::get_all_base_tools;
let cwd = self.config.cwd.clone().unwrap_or_else(|| std::env::current_dir().map(|p| p.to_string_lossy().to_string()).unwrap_or_else(|_| ".".to_string()));
let cwd_path = std::path::Path::new(&cwd);
let model = self.model.clone();
let api_key = self.api_key.clone();
let base_url = self.base_url.clone();
let ai_md_prompt = load_ai_md(cwd_path).ok().flatten();
let memory_prompt = load_memory_prompt();
let system_prompt = match (&ai_md_prompt, &memory_prompt, &self.config.system_prompt) {
(Some(ai_md), Some(mem), Some(custom)) => Some(format!("{}\n\n{}\n\n{}", ai_md, mem, custom)),
(Some(ai_md), Some(mem), None) => Some(format!("{}\n\n{}", ai_md, mem)),
(Some(ai_md), None, Some(custom)) => Some(format!("{}\n\n{}", ai_md, custom)),
(Some(ai_md), None, None) => Some(ai_md.clone()),
(None, Some(mem), Some(custom)) => Some(format!("{}\n\n{}", mem, custom)),
(None, Some(mem), None) => Some(mem.clone()),
(None, None, Some(custom)) => Some(custom.clone()),
(None, None, None) => None,
};
let tools = if self.tool_pool.is_empty() {
get_all_base_tools()
} else {
self.tool_pool.clone()
};
let mut engine = QueryEngine::new(QueryEngineConfig {
cwd: cwd.clone(),
model: model.clone(),
api_key: api_key.clone(),
base_url: base_url.clone(),
tools,
system_prompt,
max_turns: self.config.max_turns.unwrap_or(10),
max_budget_usd: self.config.max_budget_usd,
max_tokens: self.config.max_tokens.unwrap_or(16384),
can_use_tool: None,
});
let agent_tool_executor = move |input: serde_json::Value, _ctx: &ToolContext| -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ToolResult, AgentError>> + Send>> {
let cwd = cwd.clone();
let api_key = api_key.clone();
let base_url = base_url.clone();
let model = model.clone();
Box::pin(async move {
let description = input["description"].as_str().unwrap_or("subagent");
let subagent_prompt = input["prompt"].as_str().unwrap_or("");
let subagent_model = input["model"].as_str().map(|s| s.to_string()).unwrap_or_else(|| model.clone());
let max_turns = input["max_turns"].as_u64().unwrap_or(10) as u32;
let mut sub_engine = QueryEngine::new(QueryEngineConfig {
cwd,
model: subagent_model.to_string(),
api_key,
base_url,
tools: vec![],
system_prompt: None,
max_turns,
max_budget_usd: None,
max_tokens: 16384,
can_use_tool: None,
});
match sub_engine.submit_message(subagent_prompt).await {
Ok(result_text) => {
Ok(ToolResult {
result_type: "text".to_string(),
tool_use_id: "agent_tool".to_string(),
content: format!("[Subagent: {}]\n\n{}", description, result_text),
is_error: Some(false),
})
}
Err(e) => {
Ok(ToolResult {
result_type: "text".to_string(),
tool_use_id: "agent_tool".to_string(),
content: format!("[Subagent: {}] Error: {}", description, e),
is_error: Some(true),
})
}
}
})
};
register_all_tool_executors(&mut engine);
engine.register_tool("Agent".to_string(), agent_tool_executor);
engine.set_messages(self.messages.clone());
let start = std::time::Instant::now();
let response_text = engine.submit_message(prompt).await?;
let messages = engine.get_messages();
let usage = TokenUsage {
input_tokens: 100, output_tokens: 50,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
};
self.messages = messages;
Ok(QueryResult {
text: response_text,
usage,
num_turns: engine.get_turn_count(),
duration_ms: start.elapsed().as_millis() as u64,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction is fully synchronous, so a plain #[test] suffices; no async runtime needed.
    #[test]
    fn test_create_agent() {
        let agent = Agent::create(AgentOptions {
            model: Some("claude-sonnet-4-6".to_string()),
            ..Default::default()
        });
        // The environment may override the requested model, so only check that one was resolved.
        assert!(!agent.get_model().is_empty());
        // A brand-new agent has no conversation history yet.
        assert!(agent.get_messages().is_empty());
        // Session ids are hyphenated v4 UUID strings, which are 36 characters long.
        assert_eq!(agent.get_session_id().len(), 36);
    }

    #[test]
    fn test_session_ids_are_unique() {
        let first = Agent::new("claude-sonnet-4-6", 1);
        let second = Agent::new("claude-sonnet-4-6", 1);
        assert_ne!(first.get_session_id(), second.get_session_id());
    }
}