Skip to main content

scud/backend/
direct.rs

//! Direct API backend.
//!
//! Wraps the existing `llm::agent::run_agent_loop()` behind the
//! [`AgentBackend`] trait for in-process LLM execution.
5
6use anyhow::Result;
7use async_trait::async_trait;
8
9use super::{AgentBackend, AgentHandle, AgentRequest};
10
11#[cfg(feature = "direct-api")]
12use {
13    tokio::sync::mpsc,
14    tokio_util::sync::CancellationToken,
15    super::{AgentEvent, AgentResult, AgentStatus, ToolCallRecord},
16};
17
/// Backend that calls LLM APIs directly (in-process).
///
/// Only available with the `direct-api` feature.
#[cfg(feature = "direct-api")]
#[derive(Debug, Clone)]
pub struct DirectApiBackend {
    /// Passed as the `max_tokens` argument of `run_agent_loop` on every `execute` call.
    max_tokens: u32,
}

#[cfg(feature = "direct-api")]
impl DirectApiBackend {
    /// Token budget used by [`DirectApiBackend::new`] and [`Default::default`].
    pub const DEFAULT_MAX_TOKENS: u32 = 16_000;

    /// Creates a backend using [`Self::DEFAULT_MAX_TOKENS`].
    pub fn new() -> Self {
        Self::with_max_tokens(Self::DEFAULT_MAX_TOKENS)
    }

    /// Creates a backend with a custom token budget for each agent run.
    pub fn with_max_tokens(max_tokens: u32) -> Self {
        Self { max_tokens }
    }
}

#[cfg(feature = "direct-api")]
impl Default for DirectApiBackend {
    fn default() -> Self {
        Self::new()
    }
}
34
#[cfg(feature = "direct-api")]
#[async_trait]
impl AgentBackend for DirectApiBackend {
    /// Starts an in-process agent run and returns a handle that streams its events.
    ///
    /// Two tasks are spawned:
    /// 1. the agent task, which drives `llm::agent::run_agent_loop()` and writes
    ///    `StreamEvent`s into an internal channel;
    /// 2. the bridge task, which translates those into `AgentEvent`s on the
    ///    handle's channel and ends with a single `AgentEvent::Complete`
    ///    (unless the consumer has already dropped the receiver).
    ///
    /// Cancelling the handle's token drops the agent-loop future and makes the
    /// bridge emit `Complete` with `AgentStatus::Cancelled`. If the consumer drops
    /// the event receiver, the bridge cancels the token itself so the agent loop
    /// does not keep running unobserved.
    ///
    /// # Errors
    ///
    /// Returns an error if `req.provider` is set but rejected by
    /// `AgentProvider::from_provider_str`. Failures inside the agent loop are
    /// reported through the event stream instead.
    async fn execute(&self, req: AgentRequest) -> Result<AgentHandle> {
        use crate::commands::spawn::headless::events::{StreamEvent, StreamEventKind};
        use crate::llm::agent;
        use crate::llm::provider::AgentProvider;

        let (event_tx, rx) = mpsc::channel(1000);
        let cancel = CancellationToken::new();

        // Fall back to Anthropic when the request does not name a provider.
        let provider = match &req.provider {
            Some(p) => AgentProvider::from_provider_str(p)?,
            None => AgentProvider::Anthropic,
        };

        let model = req.model.clone();
        let max_tokens = self.max_tokens;
        let prompt = req.prompt.clone();
        let working_dir = req.working_dir.clone();
        let system_prompt = req.system_prompt.clone();

        // run_agent_loop emits StreamEvent; the bridge task converts them to AgentEvent.
        let (stream_tx, mut stream_rx) = mpsc::channel::<StreamEvent>(1000);

        // Agent task. Racing the loop against the token means cancellation drops the
        // loop future instead of leaving it running in the background.
        let agent_cancel = cancel.clone();
        let stream_tx_err = stream_tx.clone();
        tokio::spawn(async move {
            let outcome = tokio::select! {
                _ = agent_cancel.cancelled() => return,
                result = agent::run_agent_loop(
                    &prompt,
                    system_prompt.as_deref(),
                    &working_dir,
                    model.as_deref(),
                    max_tokens,
                    stream_tx,
                    &provider,
                ) => result,
            };
            if let Err(e) = outcome {
                // Surface the failure through the stream so the bridge reports it.
                let _ = stream_tx_err
                    .send(StreamEvent::error(&e.to_string()))
                    .await;
                let _ = stream_tx_err.send(StreamEvent::complete(false)).await;
            }
        });

        // Bridge task: StreamEvent -> AgentEvent.
        let bridge_cancel = cancel.clone();
        tokio::spawn(async move {
            let mut full_text = String::new();
            let mut tool_calls: Vec<ToolCallRecord> = Vec::new();
            // Latest error text, used to explain a `Complete { success: false }`.
            let mut last_error: Option<String> = None;

            loop {
                tokio::select! {
                    // Check the token first: cancelling also makes the agent task drop
                    // its senders, which must not be reported as a normal completion.
                    biased;

                    _ = bridge_cancel.cancelled() => {
                        let _ = event_tx
                            .send(AgentEvent::Complete(AgentResult {
                                text: full_text,
                                status: AgentStatus::Cancelled,
                                tool_calls,
                                usage: None,
                            }))
                            .await;
                        break;
                    }
                    event = stream_rx.recv() => {
                        let stream_event = match event {
                            Some(stream_event) => stream_event,
                            None => {
                                // Every sender is gone and no Complete event arrived.
                                let _ = event_tx
                                    .send(AgentEvent::Complete(AgentResult {
                                        text: full_text,
                                        status: AgentStatus::Completed,
                                        tool_calls,
                                        usage: None,
                                    }))
                                    .await;
                                break;
                            }
                        };

                        let agent_event = match &stream_event.kind {
                            StreamEventKind::TextDelta { text } => {
                                full_text.push_str(text);
                                AgentEvent::TextDelta(text.clone())
                            }
                            StreamEventKind::ToolStart { tool_name, tool_id, .. } => {
                                tool_calls.push(ToolCallRecord {
                                    id: tool_id.clone(),
                                    name: tool_name.clone(),
                                    output: String::new(),
                                });
                                AgentEvent::ToolCallStart {
                                    id: tool_id.clone(),
                                    name: tool_name.clone(),
                                }
                            }
                            StreamEventKind::ToolResult { tool_id, success, .. } => {
                                let outcome = if *success { "ok" } else { "error" };
                                if let Some(record) =
                                    tool_calls.iter_mut().find(|r| r.id == *tool_id)
                                {
                                    record.output = outcome.into();
                                }
                                AgentEvent::ToolCallEnd {
                                    id: tool_id.clone(),
                                    output: outcome.into(),
                                }
                            }
                            StreamEventKind::Complete { success } => {
                                let status = if *success {
                                    AgentStatus::Completed
                                } else {
                                    // Prefer the concrete error the agent reported.
                                    AgentStatus::Failed(
                                        last_error
                                            .take()
                                            .unwrap_or_else(|| "Agent reported failure".into()),
                                    )
                                };
                                let _ = event_tx
                                    .send(AgentEvent::Complete(AgentResult {
                                        text: full_text,
                                        status,
                                        tool_calls,
                                        usage: None,
                                    }))
                                    .await;
                                break;
                            }
                            StreamEventKind::Error { message } => {
                                last_error = Some(message.to_string());
                                AgentEvent::Error(message.clone())
                            }
                            StreamEventKind::SessionAssigned { .. } => continue,
                        };

                        if event_tx.send(agent_event).await.is_err() {
                            // The consumer dropped the receiver; stop the agent loop too.
                            bridge_cancel.cancel();
                            break;
                        }
                    }
                }
            }
        });

        Ok(AgentHandle { events: rx, cancel })
    }
}
168
// Stub used when the `direct-api` feature is disabled: it can be constructed,
// but its `execute` always returns an error.
#[cfg(not(feature = "direct-api"))]
#[derive(Debug, Default)]
pub struct DirectApiBackend;

#[cfg(not(feature = "direct-api"))]
impl DirectApiBackend {
    /// Creates the stub backend (equivalent to `DirectApiBackend::default()`).
    pub fn new() -> Self {
        Self
    }
}
179
180#[cfg(not(feature = "direct-api"))]
181#[async_trait]
182impl AgentBackend for DirectApiBackend {
183    async fn execute(&self, _req: AgentRequest) -> Result<AgentHandle> {
184        anyhow::bail!("Direct API backend requires the 'direct-api' feature to be enabled")
185    }
186}