1use anyhow::Result;
7use async_trait::async_trait;
8
9use super::{AgentBackend, AgentHandle, AgentRequest};
10
11#[cfg(feature = "direct-api")]
12use {
13 super::{AgentEvent, AgentResult, AgentStatus, ToolCallRecord},
14 tokio::sync::mpsc,
15 tokio_util::sync::CancellationToken,
16};
17
#[cfg(feature = "direct-api")]
/// Agent backend that runs the in-process agent loop
/// (`crate::llm::agent::run_agent_loop`) and streams its progress back as
/// [`AgentEvent`]s.
#[derive(Debug)]
pub struct DirectApiBackend {
    /// Limit forwarded as the `max_tokens` argument of `run_agent_loop`.
    max_tokens: u32,
}
25
#[cfg(feature = "direct-api")]
impl DirectApiBackend {
    /// `max_tokens` limit used by [`DirectApiBackend::new`].
    const DEFAULT_MAX_TOKENS: u32 = 16_000;

    /// Creates a backend whose `max_tokens` limit (forwarded to the agent loop)
    /// is 16 000.
    pub fn new() -> Self {
        Self {
            max_tokens: Self::DEFAULT_MAX_TOKENS,
        }
    }
}

// Provided because `new()` takes no arguments (clippy `new_without_default`);
// lets the backend be used wherever a `Default` value is expected.
#[cfg(feature = "direct-api")]
impl Default for DirectApiBackend {
    fn default() -> Self {
        Self::new()
    }
}
32
#[cfg(feature = "direct-api")]
#[async_trait]
impl AgentBackend for DirectApiBackend {
    /// Runs the agent loop in the background and returns a handle that streams
    /// [`AgentEvent`]s to the caller.
    ///
    /// Two tasks are spawned:
    /// * a producer that drives `agent::run_agent_loop`, which emits
    ///   `StreamEvent`s; if the loop returns an error, an error event followed by
    ///   an unsuccessful completion event is sent on its behalf;
    /// * a bridge that translates `StreamEvent`s into `AgentEvent`s, accumulating
    ///   the streamed text and tool-call records, and finishing with exactly one
    ///   `AgentEvent::Complete` (unless the caller dropped the event receiver).
    ///
    /// Triggering the returned cancellation token aborts the producer task and
    /// reports `AgentStatus::Cancelled`.
    ///
    /// # Errors
    /// Returns an error if `req.provider` is set and
    /// `AgentProvider::from_provider_str` rejects it.
    async fn execute(&self, req: AgentRequest) -> Result<AgentHandle> {
        use crate::commands::spawn::headless::events::{StreamEvent, StreamEventKind};
        use crate::llm::agent;
        use crate::llm::provider::AgentProvider;

        // Capacity of both the internal stream channel and the outward event channel.
        const CHANNEL_CAPACITY: usize = 1000;

        let provider = if let Some(ref p) = req.provider {
            AgentProvider::from_provider_str(p)?
        } else {
            AgentProvider::Anthropic
        };

        let model = req.model.clone();
        let max_tokens = self.max_tokens;
        let prompt = req.prompt.clone();
        let working_dir = req.working_dir.clone();
        let system_prompt = req.system_prompt.clone();

        let (event_tx, events) = mpsc::channel(CHANNEL_CAPACITY);
        let (stream_tx, mut stream_rx) = mpsc::channel::<StreamEvent>(CHANNEL_CAPACITY);
        let cancel = CancellationToken::new();
        let cancel_watch = cancel.clone();

        // Producer: drive the agent loop. A failure is reported through the same
        // stream so the bridge sees it in order with everything else.
        let stream_tx_err = stream_tx.clone();
        let agent_task = tokio::spawn(async move {
            if let Err(e) = agent::run_agent_loop(
                &prompt,
                system_prompt.as_deref(),
                &working_dir,
                model.as_deref(),
                max_tokens,
                stream_tx,
                &provider,
            )
            .await
            {
                let _ = stream_tx_err.send(StreamEvent::error(&e.to_string())).await;
                let _ = stream_tx_err.send(StreamEvent::complete(false)).await;
            }
        });

        // Bridge: translate `StreamEvent`s into `AgentEvent`s for the caller.
        tokio::spawn(async move {
            let mut full_text = String::new();
            let mut tool_calls: Vec<ToolCallRecord> = Vec::new();
            // Most recent error message reported by the agent; used to explain a
            // failed completion instead of a generic message.
            let mut last_error: Option<String> = None;

            let status = loop {
                tokio::select! {
                    _ = cancel_watch.cancelled() => {
                        // Stop the agent loop as well; otherwise it keeps calling the
                        // provider and running tools after the caller cancelled.
                        agent_task.abort();
                        break AgentStatus::Cancelled;
                    }
                    event = stream_rx.recv() => {
                        let stream_event = match event {
                            Some(stream_event) => stream_event,
                            None => {
                                // Every sender is gone, so the producer task has
                                // finished. A panic there closes the stream without a
                                // Complete event and must not be reported as success.
                                break match agent_task.await {
                                    Ok(()) => AgentStatus::Completed,
                                    Err(e) => {
                                        AgentStatus::Failed(format!("Agent task failed: {e}"))
                                    }
                                };
                            }
                        };

                        let agent_event = match &stream_event.kind {
                            StreamEventKind::TextDelta { text } => {
                                full_text.push_str(text);
                                AgentEvent::TextDelta(text.clone())
                            }
                            StreamEventKind::ToolStart { tool_name, tool_id, .. } => {
                                tool_calls.push(ToolCallRecord {
                                    id: tool_id.clone(),
                                    name: tool_name.clone(),
                                    output: String::new(),
                                });
                                AgentEvent::ToolCallStart {
                                    id: tool_id.clone(),
                                    name: tool_name.clone(),
                                }
                            }
                            StreamEventKind::ToolResult { tool_id, success, .. } => {
                                let output = if *success { "ok" } else { "error" };
                                if let Some(record) =
                                    tool_calls.iter_mut().find(|r| r.id == *tool_id)
                                {
                                    record.output = output.into();
                                }
                                AgentEvent::ToolCallEnd {
                                    id: tool_id.clone(),
                                    output: output.into(),
                                }
                            }
                            StreamEventKind::Complete { success } => {
                                break if *success {
                                    AgentStatus::Completed
                                } else {
                                    AgentStatus::Failed(
                                        last_error
                                            .take()
                                            .unwrap_or_else(|| "Agent reported failure".into()),
                                    )
                                };
                            }
                            StreamEventKind::Error { message } => {
                                last_error = Some(message.clone());
                                AgentEvent::Error(message.clone())
                            }
                            StreamEventKind::SessionAssigned { .. } => continue,
                        };

                        if event_tx.send(agent_event).await.is_err() {
                            // The caller dropped the event receiver; nobody is listening.
                            return;
                        }
                    }
                }
            };

            let _ = event_tx
                .send(AgentEvent::Complete(AgentResult {
                    text: full_text,
                    status,
                    tool_calls,
                    usage: None,
                }))
                .await;
        });

        Ok(AgentHandle { events, cancel })
    }
}
164
#[cfg(not(feature = "direct-api"))]
/// Placeholder used when the `direct-api` feature is disabled; its
/// `AgentBackend::execute` always returns an error.
#[derive(Debug)]
pub struct DirectApiBackend;
168
169#[cfg(not(feature = "direct-api"))]
170impl DirectApiBackend {
171 pub fn new() -> Self {
172 Self
173 }
174}
175
176#[cfg(not(feature = "direct-api"))]
177#[async_trait]
178impl AgentBackend for DirectApiBackend {
179 async fn execute(&self, _req: AgentRequest) -> Result<AgentHandle> {
180 anyhow::bail!("Direct API backend requires the 'direct-api' feature to be enabled")
181 }
182}