1use std::{
2 collections::BTreeMap,
3 path::PathBuf,
4 pin::Pin,
5 process::Stdio,
6 task::{Context, Poll},
7 time::{Duration, Instant},
8};
9
10use futures_core::Stream;
11use tokio::{
12 io::{AsyncBufReadExt, AsyncReadExt, BufReader},
13 sync::{mpsc, oneshot},
14};
15
16use crate::{
17 AiderCliError, AiderStreamJsonCompletion, AiderStreamJsonControlHandle, AiderStreamJsonError,
18 AiderStreamJsonEvent, AiderStreamJsonHandle, AiderStreamJsonResultPayload,
19 AiderStreamJsonRunRequest, AiderTerminationHandle, DynAiderStreamJsonCompletion,
20 DynAiderStreamJsonEventStream,
21};
22
23const STDERR_CAPTURE_MAX_BYTES: usize = 4096;
24const RUN_FAILED_MESSAGE: &str = "aider run failed";
25const INVALID_INPUT_MESSAGE: &str = "invalid input";
26const TURN_LIMIT_EXCEEDED_MESSAGE: &str = "turn limit exceeded";
27
28#[derive(Clone, Debug)]
29pub struct AiderCliClient {
30 pub(crate) binary: PathBuf,
31 pub(crate) env: BTreeMap<String, String>,
32 pub(crate) timeout: Option<Duration>,
33}
34
35impl AiderCliClient {
36 pub fn builder() -> crate::AiderCliClientBuilder {
37 crate::AiderCliClientBuilder::default()
38 }
39
40 pub async fn stream_json(
41 &self,
42 request: AiderStreamJsonRunRequest,
43 ) -> Result<AiderStreamJsonHandle, AiderCliError> {
44 let (events, completion, _termination) = self.spawn_stream_json(request).await?;
45 Ok(AiderStreamJsonHandle { events, completion })
46 }
47
48 pub async fn stream_json_control(
49 &self,
50 request: AiderStreamJsonRunRequest,
51 ) -> Result<AiderStreamJsonControlHandle, AiderCliError> {
52 let (events, completion, termination) = self.spawn_stream_json(request).await?;
53 Ok(AiderStreamJsonControlHandle {
54 events,
55 completion,
56 termination,
57 })
58 }
59
60 async fn spawn_stream_json(
61 &self,
62 request: AiderStreamJsonRunRequest,
63 ) -> Result<
64 (
65 DynAiderStreamJsonEventStream,
66 DynAiderStreamJsonCompletion,
67 AiderTerminationHandle,
68 ),
69 AiderCliError,
70 > {
71 let argv = request.argv()?;
72 let mut command = tokio::process::Command::new(&self.binary);
73 command
74 .args(argv)
75 .stdin(Stdio::null())
76 .stdout(Stdio::piped())
77 .stderr(Stdio::piped());
78
79 if let Some(working_dir) = request.working_directory() {
80 command.current_dir(working_dir);
81 }
82
83 for (key, value) in &self.env {
84 command.env(key, value);
85 }
86
87 let mut child = command.spawn().map_err(|source| {
88 if source.kind() == std::io::ErrorKind::NotFound {
89 AiderCliError::MissingBinary
90 } else {
91 AiderCliError::Spawn {
92 binary: self.binary.clone(),
93 source,
94 }
95 }
96 })?;
97
98 let stdout = child.stdout.take().ok_or(AiderCliError::MissingStdout)?;
99 let stderr_capture = child
100 .stderr
101 .take()
102 .map(|stderr| tokio::spawn(async move { capture_stderr(stderr).await }));
103 let timeout = self.timeout;
104 let termination = AiderTerminationHandle::new();
105 let termination_for_runner = termination.clone();
106
107 let (events_tx, events_rx) = mpsc::channel(32);
108 let (completion_tx, completion_rx) = oneshot::channel();
109
110 tokio::spawn(async move {
111 let result = run_aider_child(
112 child,
113 stdout,
114 stderr_capture,
115 events_tx,
116 timeout,
117 termination_for_runner,
118 )
119 .await;
120 let _ = completion_tx.send(result);
121 });
122
123 let events: DynAiderStreamJsonEventStream =
124 Box::pin(AiderStreamJsonEventChannelStream::new(events_rx));
125
126 let completion: DynAiderStreamJsonCompletion = Box::pin(async move {
127 completion_rx
128 .await
129 .map_err(|_| AiderCliError::Join("stream-json task dropped".to_string()))?
130 });
131
132 Ok((events, completion, termination))
133 }
134}
135
136struct AiderStreamJsonEventChannelStream {
137 rx: mpsc::Receiver<Result<AiderStreamJsonEvent, AiderStreamJsonError>>,
138}
139
140impl AiderStreamJsonEventChannelStream {
141 fn new(rx: mpsc::Receiver<Result<AiderStreamJsonEvent, AiderStreamJsonError>>) -> Self {
142 Self { rx }
143 }
144}
145
146impl Stream for AiderStreamJsonEventChannelStream {
147 type Item = Result<AiderStreamJsonEvent, AiderStreamJsonError>;
148
149 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150 self.get_mut().rx.poll_recv(cx)
151 }
152}
153
154#[derive(Default)]
155struct CompletionAccumulator {
156 session_id: Option<String>,
157 model: Option<String>,
158 assistant_text: String,
159 raw_result: Option<Value>,
160}
161
162use serde_json::Value;
163
164impl CompletionAccumulator {
165 fn observe(&mut self, event: &AiderStreamJsonEvent) {
166 match event {
167 AiderStreamJsonEvent::Init {
168 session_id, model, ..
169 } => {
170 self.session_id = Some(session_id.clone());
171 self.model = Some(model.clone());
172 }
173 AiderStreamJsonEvent::Message {
174 role,
175 content,
176 delta,
177 ..
178 } if role == "assistant" => {
179 if *delta || self.assistant_text.is_empty() {
180 self.assistant_text.push_str(content);
181 } else {
182 self.assistant_text.push('\n');
183 self.assistant_text.push_str(content);
184 }
185 }
186 AiderStreamJsonEvent::Result { payload } => {
187 self.raw_result = Some(payload.raw.clone());
188 }
189 _ => {}
190 }
191 }
192
193 fn final_text(&self) -> Option<String> {
194 (!self.assistant_text.is_empty()).then(|| self.assistant_text.clone())
195 }
196}
197
198async fn run_aider_child(
199 mut child: tokio::process::Child,
200 stdout: tokio::process::ChildStdout,
201 stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
202 events_tx: mpsc::Sender<Result<AiderStreamJsonEvent, AiderStreamJsonError>>,
203 timeout: Option<Duration>,
204 termination: AiderTerminationHandle,
205) -> Result<AiderStreamJsonCompletion, AiderCliError> {
206 let mut reader = BufReader::new(stdout);
207 let mut parser = crate::AiderStreamJsonParser::new();
208 let mut line = String::new();
209 let mut events_open = true;
210 let mut completion = CompletionAccumulator::default();
211 let mut last_result: Option<AiderStreamJsonResultPayload> = None;
212 let mut termination_requested = false;
213 let deadline = timeout.map(|value| Instant::now() + value);
214 let mut exit_status = None;
215
216 loop {
217 if let Some(deadline) = deadline {
218 if Instant::now() >= deadline {
219 match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
220 Ok(ChildExit::Exited(status)) => {
221 exit_status = Some(status);
222 break;
223 }
224 Ok(ChildExit::TimedOut) => {
225 let _ = consume_stderr_capture(stderr_capture).await;
226 return Err(AiderCliError::Timeout {
227 timeout: timeout.expect("deadline implies timeout"),
228 });
229 }
230 Err(err) => return Err(err),
231 }
232 }
233 }
234
235 line.clear();
236 let read_result = if let Some(deadline) = deadline {
237 let remaining = deadline.saturating_duration_since(Instant::now());
238 tokio::select! {
239 _ = termination.requested() => {
240 termination_requested = true;
241 let _ = child.start_kill();
242 break;
243 }
244 read = tokio::time::timeout(remaining, reader.read_line(&mut line)) => {
245 match read {
246 Ok(result) => result,
247 Err(_) => {
248 match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
249 Ok(ChildExit::Exited(status)) => {
250 exit_status = Some(status);
251 break;
252 }
253 Ok(ChildExit::TimedOut) => {
254 let _ = consume_stderr_capture(stderr_capture).await;
255 return Err(AiderCliError::Timeout {
256 timeout: timeout.expect("deadline implies timeout"),
257 });
258 }
259 Err(err) => return Err(err),
260 }
261 }
262 }
263 }
264 }
265 } else {
266 tokio::select! {
267 _ = termination.requested() => {
268 termination_requested = true;
269 let _ = child.start_kill();
270 break;
271 }
272 read = reader.read_line(&mut line) => read,
273 }
274 };
275
276 let bytes = match read_result {
277 Ok(bytes) => bytes,
278 Err(err) => {
279 let _ = child.start_kill();
280 let _ = child.wait().await;
281 let _ = consume_stderr_capture(stderr_capture).await;
282 return Err(AiderCliError::StdoutRead(err));
283 }
284 };
285
286 if bytes == 0 {
287 break;
288 }
289
290 let parsed = parser.parse_line(line.trim_end_matches('\n'));
291 match parsed {
292 Ok(Some(event)) => {
293 completion.observe(&event);
294 if let AiderStreamJsonEvent::Result { payload } = &event {
295 last_result = Some(payload.clone());
296 }
297 if events_open && events_tx.send(Ok(event)).await.is_err() {
298 events_open = false;
299 }
300 }
301 Ok(None) => {}
302 Err(error) => {
303 if events_open && events_tx.send(Err(error)).await.is_err() {
304 events_open = false;
305 }
306 }
307 }
308 }
309
310 let status = match exit_status {
311 Some(status) => status,
312 None => match wait_for_child_exit(&mut child, timeout, deadline).await {
313 Ok(ChildExit::Exited(status)) => status,
314 Ok(ChildExit::TimedOut) => {
315 let _ = consume_stderr_capture(stderr_capture).await;
316 return Err(AiderCliError::Timeout {
317 timeout: timeout.expect("deadline implies timeout"),
318 });
319 }
320 Err(err) => return Err(err),
321 },
322 };
323
324 let _stderr = consume_stderr_capture(stderr_capture).await?;
325
326 if !status.success() {
327 if termination_requested {
328 drop(events_tx);
329 return Ok(AiderStreamJsonCompletion {
330 status,
331 final_text: None,
332 session_id: completion.session_id,
333 model: completion.model,
334 raw_result: completion.raw_result,
335 });
336 }
337
338 let exit_code = status.code();
339 let message = classify_run_failure(exit_code, last_result.as_ref());
340 if last_result.is_none() && events_open {
341 let _ = events_tx
342 .send(Ok(AiderStreamJsonEvent::Error {
343 severity: "error".to_string(),
344 message: message.clone(),
345 raw: Value::Null,
346 }))
347 .await;
348 }
349 drop(events_tx);
350 return Err(AiderCliError::RunFailed {
351 status,
352 exit_code,
353 message,
354 result_error_type: last_result
355 .as_ref()
356 .and_then(|payload| payload.error_type.clone()),
357 });
358 }
359
360 drop(events_tx);
361 Ok(AiderStreamJsonCompletion {
362 status,
363 final_text: completion.final_text(),
364 session_id: completion.session_id,
365 model: completion.model,
366 raw_result: completion.raw_result,
367 })
368}
369
370#[derive(Debug, Clone, Copy)]
371enum ChildExit {
372 Exited(std::process::ExitStatus),
373 TimedOut,
374}
375
376async fn wait_for_child_exit(
377 child: &mut tokio::process::Child,
378 timeout: Option<Duration>,
379 deadline: Option<Instant>,
380) -> Result<ChildExit, AiderCliError> {
381 match deadline {
382 None => child
383 .wait()
384 .await
385 .map(ChildExit::Exited)
386 .map_err(AiderCliError::Wait),
387 Some(deadline) => {
388 let remaining = deadline.saturating_duration_since(Instant::now());
389 if remaining.is_zero() {
390 match child.try_wait().map_err(AiderCliError::Wait)? {
391 Some(status) => Ok(ChildExit::Exited(status)),
392 None => {
393 timeout.expect("deadline implies timeout");
394 let _ = child.start_kill();
395 match child.wait().await {
396 Ok(_status) => Ok(ChildExit::TimedOut),
397 Err(err) => Err(AiderCliError::Wait(err)),
398 }
399 }
400 }
401 } else {
402 match tokio::time::timeout(remaining, child.wait()).await {
403 Ok(result) => result.map(ChildExit::Exited).map_err(AiderCliError::Wait),
404 Err(_) => match child.try_wait().map_err(AiderCliError::Wait)? {
405 Some(status) => Ok(ChildExit::Exited(status)),
406 None => {
407 timeout.expect("deadline implies timeout");
408 let _ = child.start_kill();
409 match child.wait().await {
410 Ok(_status) => Ok(ChildExit::TimedOut),
411 Err(err) => Err(AiderCliError::Wait(err)),
412 }
413 }
414 },
415 }
416 }
417 }
418 }
419}
420
421async fn capture_stderr(
422 mut stderr: tokio::process::ChildStderr,
423) -> Result<Vec<u8>, std::io::Error> {
424 let mut captured = Vec::new();
425 let mut buffer = [0u8; 1024];
426
427 loop {
428 let read = stderr.read(&mut buffer).await?;
429 if read == 0 {
430 break;
431 }
432
433 if captured.len() < STDERR_CAPTURE_MAX_BYTES {
434 let remaining = STDERR_CAPTURE_MAX_BYTES - captured.len();
435 captured.extend_from_slice(&buffer[..read.min(remaining)]);
436 }
437 }
438
439 Ok(captured)
440}
441
442async fn consume_stderr_capture(
443 stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
444) -> Result<String, AiderCliError> {
445 let Some(stderr_capture) = stderr_capture else {
446 return Ok(String::new());
447 };
448
449 let captured = stderr_capture
450 .await
451 .map_err(|err| AiderCliError::Join(format!("stderr capture task failed: {err}")))?
452 .map_err(AiderCliError::StderrRead)?;
453
454 Ok(String::from_utf8_lossy(&captured).into_owned())
455}
456
457fn classify_run_failure(
458 exit_code: Option<i32>,
459 result: Option<&AiderStreamJsonResultPayload>,
460) -> String {
461 match exit_code {
462 Some(42) => INVALID_INPUT_MESSAGE.to_string(),
463 Some(53) => TURN_LIMIT_EXCEEDED_MESSAGE.to_string(),
464 _ => result
465 .and_then(|payload| payload.error_message.clone())
466 .filter(|message| !message.trim().is_empty())
467 .unwrap_or_else(|| RUN_FAILED_MESSAGE.to_string()),
468 }
469}
470
#[cfg(test)]
mod tests {
    use std::process::Stdio;
    use std::time::{Duration, Instant};

    use super::{wait_for_child_exit, ChildExit};

    /// An already-exited child must be reported as `Exited`, not killed as timed out, even
    /// when the deadline has already passed.
    #[cfg(unix)]
    #[tokio::test]
    async fn wait_for_child_exit_returns_status_when_deadline_has_elapsed() {
        let mut child = tokio::process::Command::new("sh")
            .args(["-c", "exit 0"])
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("spawn child");
        // Reap the child up front so it has deterministically exited. The previous fixed
        // 50 ms sleep was racy on slow machines. tokio's `Child` caches the exit status, so
        // the helper's `try_wait` on the elapsed-deadline path still observes it.
        child.wait().await.expect("child exits");

        let outcome = wait_for_child_exit(
            &mut child,
            Some(Duration::from_millis(1)),
            Some(Instant::now()),
        )
        .await
        .expect("wait helper succeeds");

        match outcome {
            ChildExit::Exited(status) => assert!(status.success()),
            ChildExit::TimedOut => panic!("expected exited status, got timeout"),
        }
    }
}