1use std::process::Stdio;
37
38use aonyx_core::{
39 AonyxError, ChatChunk, ChatRequest, ChatStream, LlmProvider, Message, Result, Role,
40};
41use async_stream::try_stream;
42use async_trait::async_trait;
43use serde::Deserialize;
44use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
45use tokio::process::Command;
46
/// Default executable for the Claude Code CLI. As a bare name (no path separator),
/// the OS resolves it through `PATH` when the provider spawns it.
pub const CLAUDE_DEFAULT_BIN: &str = "claude";
49
/// An [`LlmProvider`] backed by the Claude Code CLI; every chat request spawns a
/// fresh child process.
///
/// Build one with [`ClaudeCodeProvider::new`] (or `Default`) and customise it with
/// [`with_binary`](Self::with_binary) and [`with_extra_args`](Self::with_extra_args).
#[derive(Clone, Debug)]
pub struct ClaudeCodeProvider {
    /// Executable to spawn; a bare name is looked up on `PATH`.
    binary: String,
    /// Extra arguments appended after the provider's own flags on every invocation.
    extra_args: Vec<String>,
}
56
57impl ClaudeCodeProvider {
58 pub fn new() -> Self {
60 Self {
61 binary: CLAUDE_DEFAULT_BIN.to_string(),
62 extra_args: Vec::new(),
63 }
64 }
65
66 pub fn with_binary(mut self, binary: impl Into<String>) -> Self {
68 self.binary = binary.into();
69 self
70 }
71
72 pub fn with_extra_args(mut self, args: Vec<String>) -> Self {
75 self.extra_args = args;
76 self
77 }
78
79 pub fn binary(&self) -> &str {
81 &self.binary
82 }
83}
84
85impl Default for ClaudeCodeProvider {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91#[async_trait]
92impl LlmProvider for ClaudeCodeProvider {
93 fn name(&self) -> &str {
94 "claude-code"
95 }
96
97 async fn chat_stream(&self, req: ChatRequest) -> Result<ChatStream> {
98 let prompt = render_conversation(&req.messages);
99
100 let mut cmd = Command::new(&self.binary);
101 cmd.arg("-p")
102 .arg("--output-format")
103 .arg("stream-json")
104 .arg("--verbose");
105 if !req.model.is_empty() {
106 cmd.arg("--model").arg(&req.model);
107 }
108 for arg in &self.extra_args {
109 cmd.arg(arg);
110 }
111 cmd.stdin(Stdio::piped())
112 .stdout(Stdio::piped())
113 .stderr(Stdio::piped())
114 .kill_on_drop(true);
115
116 let mut child = cmd.spawn().map_err(|e| {
117 AonyxError::Provider(format!(
118 "claude-code spawn: {e}; is '{}' installed and on PATH?",
119 self.binary
120 ))
121 })?;
122
123 if let Some(mut stdin) = child.stdin.take() {
124 stdin
125 .write_all(prompt.as_bytes())
126 .await
127 .map_err(|e| AonyxError::Provider(format!("claude-code stdin: {e}")))?;
128 stdin
129 .shutdown()
130 .await
131 .map_err(|e| AonyxError::Provider(format!("claude-code stdin close: {e}")))?;
132 }
133
134 let stdout = child
135 .stdout
136 .take()
137 .ok_or_else(|| AonyxError::Provider("claude-code: no stdout pipe".into()))?;
138 let mut reader = BufReader::new(stdout).lines();
139
140 let chunk_stream = try_stream! {
141 let mut last_text = String::new();
142 let mut emitted_finish = false;
143 loop {
144 match reader.next_line().await {
145 Ok(Some(line)) => {
146 if line.trim().is_empty() {
147 continue;
148 }
149 if let Some(chunk) = parse_event_line(&line, &mut last_text) {
150 if chunk.finished {
151 emitted_finish = true;
152 }
153 yield chunk;
154 }
155 }
156 Ok(None) => break,
157 Err(e) => {
158 Err(AonyxError::Provider(format!("claude-code read: {e}")))?;
159 }
160 }
161 }
162
163 match child.wait().await {
164 Ok(status) if !status.success() => {
165 Err(AonyxError::Provider(format!(
166 "claude-code exit {}",
167 status.code().unwrap_or(-1)
168 )))?;
169 }
170 Err(e) => {
171 Err(AonyxError::Provider(format!("claude-code wait: {e}")))?;
172 }
173 Ok(_) => {}
174 }
175
176 if !emitted_finish {
177 yield ChatChunk {
178 delta_text: String::new(),
179 tool_call: None,
180 finished: true,
181 };
182 }
183 };
184
185 Ok(Box::pin(chunk_stream))
186 }
187}
188
189fn render_conversation(messages: &[Message]) -> String {
190 let mut out = String::new();
191 for m in messages {
192 let tag = match m.role {
193 Role::System => "[system]",
194 Role::User => "[user]",
195 Role::Assistant => "[assistant]",
196 Role::Tool => "[tool result]",
197 };
198 out.push_str(tag);
199 out.push('\n');
200 out.push_str(&m.content);
201 out.push_str("\n\n");
202 }
203 out
204}
205
206#[derive(Deserialize)]
207#[serde(tag = "type")]
208enum ClaudeEvent {
209 #[serde(rename = "assistant")]
210 Assistant { message: ClaudeMessage },
211 #[serde(rename = "result")]
215 Result(serde::de::IgnoredAny),
216 #[serde(other)]
217 Other,
218}
219
220#[derive(Deserialize)]
221struct ClaudeMessage {
222 #[serde(default)]
223 content: Vec<ClaudeContent>,
224}
225
226#[derive(Deserialize)]
227#[serde(tag = "type")]
228enum ClaudeContent {
229 #[serde(rename = "text")]
230 Text { text: String },
231 #[serde(other)]
232 Other,
233}
234
235fn extract_text(message: ClaudeMessage) -> String {
236 let mut out = String::new();
237 for c in message.content {
238 if let ClaudeContent::Text { text } = c {
239 out.push_str(&text);
240 }
241 }
242 out
243}
244
245pub(crate) fn parse_event_line(line: &str, last_text: &mut String) -> Option<ChatChunk> {
247 let event: ClaudeEvent = serde_json::from_str(line).ok()?;
248 match event {
249 ClaudeEvent::Assistant { message } => {
250 let full = extract_text(message);
251 if full.is_empty() {
252 return None;
253 }
254 if full.starts_with(last_text.as_str()) && full.len() > last_text.len() {
256 let delta = full[last_text.len()..].to_string();
257 *last_text = full;
258 Some(ChatChunk {
259 delta_text: delta,
260 tool_call: None,
261 finished: false,
262 })
263 } else if full == *last_text {
264 None
265 } else {
266 *last_text = full.clone();
268 Some(ChatChunk {
269 delta_text: full,
270 tool_call: None,
271 finished: false,
272 })
273 }
274 }
275 ClaudeEvent::Result(_) => Some(ChatChunk {
276 delta_text: String::new(),
277 tool_call: None,
278 finished: true,
279 }),
280 ClaudeEvent::Other => None,
281 }
282}
283
// Unit tests for provider construction, prompt rendering and stream-json parsing.
// None of them spawn the CLI.
#[cfg(test)]
mod tests {
    use super::*;
    use aonyx_core::Message;

    #[test]
    fn provider_name_is_claude_code() {
        let p = ClaudeCodeProvider::new();
        assert_eq!(p.name(), "claude-code");
        assert_eq!(p.binary(), CLAUDE_DEFAULT_BIN);
    }

    #[test]
    fn with_binary_overrides_default() {
        let p = ClaudeCodeProvider::new().with_binary("/opt/claude");
        assert_eq!(p.binary(), "/opt/claude");
    }

    #[test]
    fn default_matches_new() {
        let p = ClaudeCodeProvider::default();
        assert_eq!(p.binary(), CLAUDE_DEFAULT_BIN);
        assert!(p.extra_args.is_empty());
    }

    #[test]
    fn with_extra_args_replaces_previous_args() {
        let p = ClaudeCodeProvider::new()
            .with_extra_args(vec!["--first".to_string()])
            .with_extra_args(vec!["--second".to_string()]);
        assert_eq!(p.extra_args, vec!["--second".to_string()]);
    }

    #[test]
    fn render_conversation_tags_every_role() {
        let msgs = vec![
            Message::new(Role::System, "be brief"),
            Message::new(Role::User, "hi"),
            Message::new(Role::Assistant, "hello"),
            Message::new(Role::Tool, "tool said x"),
        ];
        let s = render_conversation(&msgs);
        assert!(s.contains("[system]"));
        assert!(s.contains("be brief"));
        assert!(s.contains("[user]"));
        assert!(s.contains("hi"));
        assert!(s.contains("[assistant]"));
        assert!(s.contains("hello"));
        assert!(s.contains("[tool result]"));
        assert!(s.contains("tool said x"));
    }

    #[test]
    fn render_conversation_uses_exact_layout() {
        let msgs = vec![
            Message::new(Role::System, "a"),
            Message::new(Role::User, "b"),
        ];
        assert_eq!(render_conversation(&msgs), "[system]\na\n\n[user]\nb\n\n");
    }

    #[test]
    fn render_conversation_of_empty_history_is_empty() {
        assert_eq!(render_conversation(&[]), "");
    }

    #[test]
    fn parses_assistant_text_event() {
        let mut last = String::new();
        let line = r#"{"type":"assistant","message":{"content":[{"type":"text","text":"Hello"}]}}"#;
        let got = parse_event_line(line, &mut last).expect("parsed");
        assert_eq!(got.delta_text, "Hello");
        assert!(!got.finished);
        assert_eq!(last, "Hello");
    }

    #[test]
    fn emits_delta_when_assistant_message_grows() {
        let mut last = String::from("Hello");
        let line =
            r#"{"type":"assistant","message":{"content":[{"type":"text","text":"Hello world"}]}}"#;
        let got = parse_event_line(line, &mut last).expect("parsed");
        assert_eq!(got.delta_text, " world");
        assert_eq!(last, "Hello world");
    }

    #[test]
    fn growth_with_multibyte_text_emits_utf8_delta() {
        let mut last = String::from("héllo");
        let line =
            r#"{"type":"assistant","message":{"content":[{"type":"text","text":"héllo wörld"}]}}"#;
        let got = parse_event_line(line, &mut last).expect("parsed");
        assert_eq!(got.delta_text, " wörld");
        assert_eq!(last, "héllo wörld");
    }

    #[test]
    fn duplicate_assistant_message_is_ignored() {
        let mut last = String::from("Hello");
        let line = r#"{"type":"assistant","message":{"content":[{"type":"text","text":"Hello"}]}}"#;
        assert!(parse_event_line(line, &mut last).is_none());
    }

    #[test]
    fn replaced_assistant_message_emits_full_text() {
        let mut last = String::from("draft answer");
        let line =
            r#"{"type":"assistant","message":{"content":[{"type":"text","text":"final reply"}]}}"#;
        let got = parse_event_line(line, &mut last).expect("parsed");
        assert_eq!(got.delta_text, "final reply");
        assert_eq!(last, "final reply");
    }

    #[test]
    fn shrunk_assistant_message_emits_full_text() {
        let mut last = String::from("Hello world");
        let line = r#"{"type":"assistant","message":{"content":[{"type":"text","text":"Hello"}]}}"#;
        let got = parse_event_line(line, &mut last).expect("parsed");
        assert_eq!(got.delta_text, "Hello");
        assert_eq!(last, "Hello");
    }

    #[test]
    fn concatenates_text_blocks_and_skips_others() {
        let mut last = String::new();
        let line = r#"{"type":"assistant","message":{"content":[{"type":"text","text":"a"},{"type":"tool_use","id":"x","name":"Read","input":{}},{"type":"text","text":"b"}]}}"#;
        let got = parse_event_line(line, &mut last).expect("parsed");
        assert_eq!(got.delta_text, "ab");
        assert_eq!(last, "ab");
    }

    #[test]
    fn assistant_message_without_content_is_ignored() {
        let mut last = String::from("keep");
        let line = r#"{"type":"assistant","message":{}}"#;
        assert!(parse_event_line(line, &mut last).is_none());
        assert_eq!(last, "keep");
    }

    #[test]
    fn result_event_marks_finished() {
        let mut last = String::new();
        let line = r#"{"type":"result","subtype":"success","result":"done","cost_usd":0.001,"duration_ms":1234,"num_turns":1,"session_id":"abc","is_error":false}"#;
        let got = parse_event_line(line, &mut last).expect("parsed");
        assert!(got.finished);
        assert!(got.delta_text.is_empty());
    }

    #[test]
    fn ignores_system_init_event() {
        let mut last = String::new();
        let line = r#"{"type":"system","subtype":"init","session_id":"abc"}"#;
        assert!(parse_event_line(line, &mut last).is_none());
    }

    #[test]
    fn ignores_user_tool_result_event() {
        let mut last = String::new();
        let line =
            r#"{"type":"user","message":{"content":[{"type":"tool_result","content":"x"}]}}"#;
        assert!(parse_event_line(line, &mut last).is_none());
        assert!(last.is_empty());
    }

    #[test]
    fn ignores_non_text_content_blocks() {
        let mut last = String::new();
        let line = r#"{"type":"assistant","message":{"content":[{"type":"tool_use","id":"x","name":"Read","input":{}}]}}"#;
        assert!(parse_event_line(line, &mut last).is_none());
    }

    #[test]
    fn malformed_json_is_silently_skipped() {
        let mut last = String::new();
        assert!(parse_event_line("not json", &mut last).is_none());
    }
}