pub mod cli;
pub mod direct;
pub mod simulated;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::commands::spawn::terminal::Harness;
/// A backend that can run an agent for a single [`AgentRequest`].
///
/// The `Send + Sync` bound allows implementations to be used behind
/// `Box<dyn AgentBackend>` (the type returned by `create_backend` and
/// `create_simulated_backend`) and shared across async tasks.
#[async_trait]
pub trait AgentBackend: Send + Sync {
    /// Starts executing `request` and returns a handle for observing it.
    ///
    /// Progress is delivered as [`AgentEvent`]s through the returned
    /// [`AgentHandle`]; call [`AgentHandle::result`] to wait for the final
    /// outcome.
    ///
    /// # Errors
    ///
    /// Returns an error when the backend cannot start the execution
    /// (conditions are backend-specific — NOTE(review): confirm per impl).
    async fn execute(&self, request: AgentRequest) -> Result<AgentHandle>;
}
/// Parameters describing one agent execution.
///
/// Build it with struct-update syntax on top of [`Default`], e.g.
/// `AgentRequest { prompt: "...".into(), ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct AgentRequest {
    /// The user prompt handed to the agent.
    pub prompt: String,
    /// Optional system prompt.
    pub system_prompt: Option<String>,
    /// Directory the execution is associated with; defaults to the process's
    /// current directory.
    pub working_dir: PathBuf,
    /// Optional model identifier (interpretation is backend-specific).
    pub model: Option<String>,
    /// Optional provider identifier (interpretation is backend-specific).
    pub provider: Option<String>,
    /// Optional limit on the number of agent turns.
    pub max_turns: Option<usize>,
    /// Optional time limit for the execution.
    pub timeout: Option<Duration>,
    /// Optional reasoning-effort setting, passed through as a string.
    pub reasoning_effort: Option<String>,
}

impl Default for AgentRequest {
    /// An empty prompt with every optional setting unset.
    ///
    /// `working_dir` is the process's current directory, or an empty path when
    /// that directory cannot be determined.
    fn default() -> Self {
        let working_dir = std::env::current_dir().unwrap_or_default();
        Self {
            working_dir,
            prompt: String::default(),
            system_prompt: Option::default(),
            model: Option::default(),
            provider: Option::default(),
            max_turns: Option::default(),
            timeout: Option::default(),
            reasoning_effort: Option::default(),
        }
    }
}
pub struct AgentHandle {
pub events: mpsc::Receiver<AgentEvent>,
pub cancel: CancellationToken,
}
impl AgentHandle {
pub async fn result(mut self) -> Result<AgentResult> {
let mut text_parts = Vec::new();
let mut tool_calls = Vec::new();
let mut status = AgentStatus::Completed;
let usage = None;
while let Some(event) = self.events.recv().await {
match event {
AgentEvent::TextDelta(delta) => text_parts.push(delta),
AgentEvent::TextComplete(text) => {
text_parts.clear();
text_parts.push(text);
}
AgentEvent::ToolCallStart { id, name } => {
tool_calls.push(ToolCallRecord {
id,
name,
output: String::new(),
});
}
AgentEvent::ToolCallEnd { id, output } => {
if let Some(record) = tool_calls.iter_mut().find(|r| r.id == id) {
record.output = output;
}
}
AgentEvent::Complete(result) => return Ok(result),
AgentEvent::Error(msg) => {
status = AgentStatus::Failed(msg);
break;
}
AgentEvent::ThinkingDelta(_) => {}
}
}
Ok(AgentResult {
text: text_parts.join(""),
status,
tool_calls,
usage,
})
}
}
/// Final outcome of an agent execution.
///
/// Returned by [`AgentHandle::result`], either as received in an
/// [`AgentEvent::Complete`] event or assembled from the streamed events.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentResult {
    /// The agent's response text.
    pub text: String,
    /// How the execution ended.
    pub status: AgentStatus,
    /// Tool calls observed during the execution.
    pub tool_calls: Vec<ToolCallRecord>,
    /// Token usage, if reported. When `AgentHandle::result` assembles the
    /// result from streamed events (no `Complete` event), this is `None`.
    pub usage: Option<TokenUsage>,
}
/// How an agent execution ended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AgentStatus {
    /// Finished normally; the only status treated as success.
    Completed,
    /// Failed with the contained error message.
    Failed(String),
    /// The execution was cancelled.
    Cancelled,
    /// The execution ran out of time (presumably `AgentRequest::timeout`;
    /// set by backends — NOTE(review): confirm).
    Timeout,
}

impl AgentStatus {
    /// Returns `true` only for [`AgentStatus::Completed`].
    pub fn is_success(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            AgentStatus::Completed => true,
            AgentStatus::Failed(_) | AgentStatus::Cancelled | AgentStatus::Timeout => false,
        }
    }
}
/// One tool invocation made during an agent execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Identifier used to pair a `ToolCallStart` event with its `ToolCallEnd`.
    pub id: String,
    /// Name of the tool that was called.
    pub name: String,
    /// Output reported by the matching `ToolCallEnd`; empty until (or unless)
    /// one arrives.
    pub output: String,
}
/// Token counts for an execution, as reported by the backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,
}
/// An event emitted on [`AgentHandle::events`] during an execution.
///
/// The per-variant notes describe how [`AgentHandle::result`] folds each event.
#[derive(Debug, Clone)]
pub enum AgentEvent {
    /// Incremental chunk of response text; deltas are concatenated in order.
    TextDelta(String),
    /// Full response text; replaces any text accumulated before it.
    TextComplete(String),
    /// A tool call started; `id` pairs it with the matching `ToolCallEnd`.
    ToolCallStart { id: String, name: String },
    /// A tool call finished with `output`; ignored if no start with `id` was seen.
    ToolCallEnd { id: String, output: String },
    /// Incremental reasoning text; not included in the folded result.
    ThinkingDelta(String),
    /// The execution failed with this message; no further events are read.
    Error(String),
    /// Final result; returned as-is, ending the stream.
    Complete(AgentResult),
}
/// Creates the [`AgentBackend`] that drives `harness`.
///
/// When the crate is built with the `direct-api` feature, the `DirectApi`
/// harness is served by `direct::DirectApiBackend`; every other harness is
/// handed to `cli::CliBackend`.
///
/// # Errors
///
/// Propagates any error returned by `cli::CliBackend::new`.
pub fn create_backend(harness: &Harness) -> Result<Box<dyn AgentBackend>> {
    #[cfg(feature = "direct-api")]
    {
        if matches!(harness, Harness::DirectApi) {
            return Ok(Box::new(direct::DirectApiBackend::new()));
        }
    }

    let cli_backend = cli::CliBackend::new(harness.clone())?;
    Ok(Box::new(cli_backend))
}
/// Returns a boxed `simulated::SimulatedBackend`.
///
/// Used by this module's tests to exercise the [`AgentBackend`] flow without a
/// real agent.
pub fn create_simulated_backend() -> Box<dyn AgentBackend> {
    let simulated_backend = simulated::SimulatedBackend;
    Box::new(simulated_backend)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Purely synchronous checks do not need a Tokio runtime.
    #[test]
    fn test_agent_request_default() {
        let req = AgentRequest::default();
        assert!(req.prompt.is_empty());
        assert!(req.system_prompt.is_none());
        assert!(req.model.is_none());
        assert!(req.provider.is_none());
        assert!(req.max_turns.is_none());
        assert!(req.timeout.is_none());
        assert!(req.reasoning_effort.is_none());
    }

    #[test]
    fn test_agent_status_is_success() {
        assert!(AgentStatus::Completed.is_success());
        assert!(!AgentStatus::Failed("err".into()).is_success());
        assert!(!AgentStatus::Cancelled.is_success());
        assert!(!AgentStatus::Timeout.is_success());
    }

    #[tokio::test]
    async fn test_simulated_backend() {
        let backend = create_simulated_backend();
        let req = AgentRequest {
            prompt: "Hello world".into(),
            ..Default::default()
        };
        let handle = backend.execute(req).await.unwrap();
        let result = handle.result().await.unwrap();
        assert!(result.status.is_success());
        assert!(result.text.contains("Simulated"));
    }

    /// Streamed deltas and tool-call start/end pairs are folded into one result
    /// when the stream closes without a terminal event.
    #[tokio::test]
    async fn test_handle_result_folds_streamed_events() {
        let (tx, rx) = mpsc::channel(16);
        let handle = AgentHandle {
            events: rx,
            cancel: CancellationToken::new(),
        };
        tx.send(AgentEvent::TextDelta("Hel".into())).await.unwrap();
        tx.send(AgentEvent::ThinkingDelta("hmm".into())).await.unwrap();
        tx.send(AgentEvent::TextDelta("lo".into())).await.unwrap();
        tx.send(AgentEvent::ToolCallStart {
            id: "1".into(),
            name: "ls".into(),
        })
        .await
        .unwrap();
        tx.send(AgentEvent::ToolCallEnd {
            id: "1".into(),
            output: "a b".into(),
        })
        .await
        .unwrap();
        drop(tx);

        let result = handle.result().await.unwrap();
        assert!(result.status.is_success());
        assert_eq!(result.text, "Hello");
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].name, "ls");
        assert_eq!(result.tool_calls[0].output, "a b");
        assert!(result.usage.is_none());
    }

    /// An `Error` event stops folding and is surfaced as a failed status while
    /// keeping the text received before it.
    #[tokio::test]
    async fn test_handle_result_reports_error() {
        let (tx, rx) = mpsc::channel(4);
        let handle = AgentHandle {
            events: rx,
            cancel: CancellationToken::new(),
        };
        tx.send(AgentEvent::TextDelta("partial".into())).await.unwrap();
        tx.send(AgentEvent::Error("boom".into())).await.unwrap();
        drop(tx);

        let result = handle.result().await.unwrap();
        assert!(matches!(&result.status, AgentStatus::Failed(m) if m == "boom"));
        assert_eq!(result.text, "partial");
    }
}