Skip to main content

aonyx_llm/
claude_code.rs

1//! Claude Code provider — uses an installed `claude` binary as the backend.
2//!
3//! Lets users with a Claude subscription (Pro / Max / Team) or an
4//! `ANTHROPIC_API_KEY` already wired into Claude Code drive Aonyx Agent
5//! without configuring a second key in `~/.aonyx/config.toml`.
6//!
7//! ## How it works
8//!
9//! Each `chat_stream` call spawns:
10//!
11//! ```text
12//! claude -p --output-format stream-json --verbose [--model <model>] [extra_args...]
13//! ```
14//!
15//! The full conversation is written to the child's stdin as a single
16//! plain-text transcript (system / user / assistant / tool messages, each
17//! prefixed by a role tag). Claude Code emits one JSON object per line on
18//! stdout. We forward `assistant` text content as [`ChatChunk`] deltas and
19//! emit a terminal chunk on the `result` event.
20//!
21//! ## Behaviour notes
22//!
23//! - **Auth**: handled entirely by Claude Code. Aonyx never sees the user's
24//!   credentials.
25//! - **Streaming**: Claude Code may emit the full assistant message at every
26//!   update (a partial-replace pattern) instead of pure deltas. We track the
27//!   last surface text and forward the suffix when it grows, falling back to
28//!   the full text otherwise.
29//! - **Tool calls**: the V1 implementation forwards text only. Native Claude
30//!   Code tool invocations (Read, Bash, …) happen inside the child process
31//!   and never become Aonyx `ToolCall`s.
32//! - **Prerequisites**: the `claude` binary must be installed and on `PATH`.
33//!   A typical install: `npm install -g @anthropic-ai/claude-code` or
34//!   download from <https://claude.ai/install>.
35
36use std::process::Stdio;
37
38use aonyx_core::{
39    AonyxError, ChatChunk, ChatRequest, ChatStream, LlmProvider, Message, Result, Role,
40};
41use async_stream::try_stream;
42use async_trait::async_trait;
43use serde::Deserialize;
44use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
45use tokio::process::Command;
46
/// Name of the Claude Code executable used when no explicit path is
/// configured; the operating system resolves it through `PATH` at spawn time.
pub const CLAUDE_DEFAULT_BIN: &str = "claude";
49
/// Claude Code provider.
///
/// Holds the path (or bare name) of the `claude` executable together with
/// any extra command-line arguments forwarded to every invocation.
#[derive(Clone, Debug)]
pub struct ClaudeCodeProvider {
    /// Executable to spawn; a bare name is resolved via `PATH`.
    binary: String,
    /// Additional arguments placed after the built-in flags.
    extra_args: Vec<String>,
}
56
57impl ClaudeCodeProvider {
58    /// Build a provider that runs `claude` from `PATH`.
59    pub fn new() -> Self {
60        Self {
61            binary: CLAUDE_DEFAULT_BIN.to_string(),
62            extra_args: Vec::new(),
63        }
64    }
65
66    /// Override the binary path (e.g. `"C:/Users/x/.claude/local/claude.exe"`).
67    pub fn with_binary(mut self, binary: impl Into<String>) -> Self {
68        self.binary = binary.into();
69        self
70    }
71
72    /// Append extra arguments forwarded to every `claude` invocation
73    /// (e.g. `["--max-turns", "5"]`).
74    pub fn with_extra_args(mut self, args: Vec<String>) -> Self {
75        self.extra_args = args;
76        self
77    }
78
79    /// Inspect the configured binary path.
80    pub fn binary(&self) -> &str {
81        &self.binary
82    }
83}
84
85impl Default for ClaudeCodeProvider {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91#[async_trait]
92impl LlmProvider for ClaudeCodeProvider {
93    fn name(&self) -> &str {
94        "claude-code"
95    }
96
97    async fn chat_stream(&self, req: ChatRequest) -> Result<ChatStream> {
98        let prompt = render_conversation(&req.messages);
99
100        let mut cmd = Command::new(&self.binary);
101        cmd.arg("-p")
102            .arg("--output-format")
103            .arg("stream-json")
104            .arg("--verbose");
105        if !req.model.is_empty() {
106            cmd.arg("--model").arg(&req.model);
107        }
108        for arg in &self.extra_args {
109            cmd.arg(arg);
110        }
111        cmd.stdin(Stdio::piped())
112            .stdout(Stdio::piped())
113            .stderr(Stdio::piped())
114            .kill_on_drop(true);
115
116        let mut child = cmd.spawn().map_err(|e| {
117            AonyxError::Provider(format!(
118                "claude-code spawn: {e}; is '{}' installed and on PATH?",
119                self.binary
120            ))
121        })?;
122
123        if let Some(mut stdin) = child.stdin.take() {
124            stdin
125                .write_all(prompt.as_bytes())
126                .await
127                .map_err(|e| AonyxError::Provider(format!("claude-code stdin: {e}")))?;
128            stdin
129                .shutdown()
130                .await
131                .map_err(|e| AonyxError::Provider(format!("claude-code stdin close: {e}")))?;
132        }
133
134        let stdout = child
135            .stdout
136            .take()
137            .ok_or_else(|| AonyxError::Provider("claude-code: no stdout pipe".into()))?;
138        let mut reader = BufReader::new(stdout).lines();
139
140        let chunk_stream = try_stream! {
141            let mut last_text = String::new();
142            let mut emitted_finish = false;
143            loop {
144                match reader.next_line().await {
145                    Ok(Some(line)) => {
146                        if line.trim().is_empty() {
147                            continue;
148                        }
149                        if let Some(chunk) = parse_event_line(&line, &mut last_text) {
150                            if chunk.finished {
151                                emitted_finish = true;
152                            }
153                            yield chunk;
154                        }
155                    }
156                    Ok(None) => break,
157                    Err(e) => {
158                        Err(AonyxError::Provider(format!("claude-code read: {e}")))?;
159                    }
160                }
161            }
162
163            match child.wait().await {
164                Ok(status) if !status.success() => {
165                    Err(AonyxError::Provider(format!(
166                        "claude-code exit {}",
167                        status.code().unwrap_or(-1)
168                    )))?;
169                }
170                Err(e) => {
171                    Err(AonyxError::Provider(format!("claude-code wait: {e}")))?;
172                }
173                Ok(_) => {}
174            }
175
176            if !emitted_finish {
177                yield ChatChunk {
178                    delta_text: String::new(),
179                    tool_call: None,
180                    finished: true,
181                };
182            }
183        };
184
185        Ok(Box::pin(chunk_stream))
186    }
187}
188
189fn render_conversation(messages: &[Message]) -> String {
190    let mut out = String::new();
191    for m in messages {
192        let tag = match m.role {
193            Role::System => "[system]",
194            Role::User => "[user]",
195            Role::Assistant => "[assistant]",
196            Role::Tool => "[tool result]",
197        };
198        out.push_str(tag);
199        out.push('\n');
200        out.push_str(&m.content);
201        out.push_str("\n\n");
202    }
203    out
204}
205
206#[derive(Deserialize)]
207#[serde(tag = "type")]
208enum ClaudeEvent {
209    #[serde(rename = "assistant")]
210    Assistant { message: ClaudeMessage },
211    /// `result` marks end-of-turn; we only care about the tag, every payload
212    /// field (subtype, result, cost_usd, duration_ms, …) is intentionally
213    /// dropped via [`serde::de::IgnoredAny`].
214    #[serde(rename = "result")]
215    Result(serde::de::IgnoredAny),
216    #[serde(other)]
217    Other,
218}
219
220#[derive(Deserialize)]
221struct ClaudeMessage {
222    #[serde(default)]
223    content: Vec<ClaudeContent>,
224}
225
226#[derive(Deserialize)]
227#[serde(tag = "type")]
228enum ClaudeContent {
229    #[serde(rename = "text")]
230    Text { text: String },
231    #[serde(other)]
232    Other,
233}
234
235fn extract_text(message: ClaudeMessage) -> String {
236    let mut out = String::new();
237    for c in message.content {
238        if let ClaudeContent::Text { text } = c {
239            out.push_str(&text);
240        }
241    }
242    out
243}
244
245/// Parse one stream-json line, updating `last_text` for delta tracking.
246pub(crate) fn parse_event_line(line: &str, last_text: &mut String) -> Option<ChatChunk> {
247    let event: ClaudeEvent = serde_json::from_str(line).ok()?;
248    match event {
249        ClaudeEvent::Assistant { message } => {
250            let full = extract_text(message);
251            if full.is_empty() {
252                return None;
253            }
254            // Partial-replace pattern: forward only the new suffix.
255            if full.starts_with(last_text.as_str()) && full.len() > last_text.len() {
256                let delta = full[last_text.len()..].to_string();
257                *last_text = full;
258                Some(ChatChunk {
259                    delta_text: delta,
260                    tool_call: None,
261                    finished: false,
262                })
263            } else if full == *last_text {
264                None
265            } else {
266                // Pure-delta stream: forward as-is and reset the surface.
267                *last_text = full.clone();
268                Some(ChatChunk {
269                    delta_text: full,
270                    tool_call: None,
271                    finished: false,
272                })
273            }
274        }
275        ClaudeEvent::Result(_) => Some(ChatChunk {
276            delta_text: String::new(),
277            tool_call: None,
278            finished: true,
279        }),
280        ClaudeEvent::Other => None,
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use aonyx_core::Message;

    /// Build an `assistant` event line holding a single text block.
    /// `text` must not contain characters that need JSON escaping.
    fn text_event(text: &str) -> String {
        format!(
            r#"{{"type":"assistant","message":{{"content":[{{"type":"text","text":"{text}"}}]}}}}"#
        )
    }

    #[test]
    fn provider_name_is_claude_code() {
        let provider = ClaudeCodeProvider::new();
        assert_eq!(provider.name(), "claude-code");
        assert_eq!(provider.binary(), CLAUDE_DEFAULT_BIN);
    }

    #[test]
    fn with_binary_overrides_default() {
        let provider = ClaudeCodeProvider::new().with_binary("/opt/claude");
        assert_eq!(provider.binary(), "/opt/claude");
    }

    #[test]
    fn render_conversation_tags_every_role() {
        let rendered = render_conversation(&[
            Message::new(Role::System, "be brief"),
            Message::new(Role::User, "hi"),
            Message::new(Role::Assistant, "hello"),
            Message::new(Role::Tool, "tool said x"),
        ]);
        let expected_fragments = [
            "[system]",
            "be brief",
            "[user]",
            "hi",
            "[assistant]",
            "hello",
            "[tool result]",
            "tool said x",
        ];
        for fragment in expected_fragments {
            assert!(rendered.contains(fragment));
        }
    }

    #[test]
    fn parses_assistant_text_event() {
        let mut surface = String::new();
        let chunk = parse_event_line(&text_event("Hello"), &mut surface).expect("parsed");
        assert_eq!(chunk.delta_text, "Hello");
        assert!(!chunk.finished);
        assert_eq!(surface, "Hello");
    }

    #[test]
    fn emits_delta_when_assistant_message_grows() {
        let mut surface = String::from("Hello");
        let chunk = parse_event_line(&text_event("Hello world"), &mut surface).expect("parsed");
        assert_eq!(chunk.delta_text, " world");
        assert_eq!(surface, "Hello world");
    }

    #[test]
    fn duplicate_assistant_message_is_ignored() {
        let mut surface = String::from("Hello");
        assert!(parse_event_line(&text_event("Hello"), &mut surface).is_none());
    }

    #[test]
    fn replaced_assistant_message_emits_full_text() {
        let mut surface = String::from("draft answer");
        let chunk = parse_event_line(&text_event("final reply"), &mut surface).expect("parsed");
        assert_eq!(chunk.delta_text, "final reply");
        assert_eq!(surface, "final reply");
    }

    #[test]
    fn result_event_marks_finished() {
        let mut surface = String::new();
        let line = r#"{"type":"result","subtype":"success","result":"done","cost_usd":0.001,"duration_ms":1234,"num_turns":1,"session_id":"abc","is_error":false}"#;
        let chunk = parse_event_line(line, &mut surface).expect("parsed");
        assert!(chunk.finished);
        assert!(chunk.delta_text.is_empty());
    }

    #[test]
    fn ignores_system_init_event() {
        let mut surface = String::new();
        let line = r#"{"type":"system","subtype":"init","session_id":"abc"}"#;
        assert!(parse_event_line(line, &mut surface).is_none());
    }

    #[test]
    fn ignores_non_text_content_blocks() {
        let mut surface = String::new();
        let line = r#"{"type":"assistant","message":{"content":[{"type":"tool_use","id":"x","name":"Read","input":{}}]}}"#;
        assert!(parse_event_line(line, &mut surface).is_none());
    }

    #[test]
    fn malformed_json_is_silently_skipped() {
        let mut surface = String::new();
        assert!(parse_event_line("not json", &mut surface).is_none());
    }
}