Skip to main content

aider/
client.rs

1use std::{
2    collections::BTreeMap,
3    path::PathBuf,
4    pin::Pin,
5    process::Stdio,
6    task::{Context, Poll},
7    time::{Duration, Instant},
8};
9
10use futures_core::Stream;
11use tokio::{
12    io::{AsyncBufReadExt, AsyncReadExt, BufReader},
13    sync::{mpsc, oneshot},
14};
15
16use crate::{
17    AiderCliError, AiderStreamJsonCompletion, AiderStreamJsonControlHandle, AiderStreamJsonError,
18    AiderStreamJsonEvent, AiderStreamJsonHandle, AiderStreamJsonResultPayload,
19    AiderStreamJsonRunRequest, AiderTerminationHandle, DynAiderStreamJsonCompletion,
20    DynAiderStreamJsonEventStream,
21};
22
/// Maximum number of leading stderr bytes retained by `capture_stderr`; the rest is discarded.
const STDERR_CAPTURE_MAX_BYTES: usize = 4096;
/// Generic failure message used when no more specific reason is available.
const RUN_FAILED_MESSAGE: &str = "aider run failed";
/// Failure message reported for exit code 42.
const INVALID_INPUT_MESSAGE: &str = "invalid input";
/// Failure message reported for exit code 53.
const TURN_LIMIT_EXCEEDED_MESSAGE: &str = "turn limit exceeded";

/// Client that spawns the aider CLI and exposes its stream-json output.
///
/// Construct one with [`AiderCliClient::builder`].
#[derive(Clone, Debug)]
pub struct AiderCliClient {
    /// Executable handed to `Command::new` for every run.
    pub(crate) binary: PathBuf,
    /// Extra environment variables set on each spawned child (added to the inherited environment).
    pub(crate) env: BTreeMap<String, String>,
    /// Optional time limit for a run, measured from when the runner task starts; `None` = no limit.
    pub(crate) timeout: Option<Duration>,
}

35impl AiderCliClient {
36    pub fn builder() -> crate::AiderCliClientBuilder {
37        crate::AiderCliClientBuilder::default()
38    }
39
40    pub async fn stream_json(
41        &self,
42        request: AiderStreamJsonRunRequest,
43    ) -> Result<AiderStreamJsonHandle, AiderCliError> {
44        let (events, completion, _termination) = self.spawn_stream_json(request).await?;
45        Ok(AiderStreamJsonHandle { events, completion })
46    }
47
48    pub async fn stream_json_control(
49        &self,
50        request: AiderStreamJsonRunRequest,
51    ) -> Result<AiderStreamJsonControlHandle, AiderCliError> {
52        let (events, completion, termination) = self.spawn_stream_json(request).await?;
53        Ok(AiderStreamJsonControlHandle {
54            events,
55            completion,
56            termination,
57        })
58    }
59
60    async fn spawn_stream_json(
61        &self,
62        request: AiderStreamJsonRunRequest,
63    ) -> Result<
64        (
65            DynAiderStreamJsonEventStream,
66            DynAiderStreamJsonCompletion,
67            AiderTerminationHandle,
68        ),
69        AiderCliError,
70    > {
71        let argv = request.argv()?;
72        let mut command = tokio::process::Command::new(&self.binary);
73        command
74            .args(argv)
75            .stdin(Stdio::null())
76            .stdout(Stdio::piped())
77            .stderr(Stdio::piped());
78
79        if let Some(working_dir) = request.working_directory() {
80            command.current_dir(working_dir);
81        }
82
83        for (key, value) in &self.env {
84            command.env(key, value);
85        }
86
87        let mut child = command.spawn().map_err(|source| {
88            if source.kind() == std::io::ErrorKind::NotFound {
89                AiderCliError::MissingBinary
90            } else {
91                AiderCliError::Spawn {
92                    binary: self.binary.clone(),
93                    source,
94                }
95            }
96        })?;
97
98        let stdout = child.stdout.take().ok_or(AiderCliError::MissingStdout)?;
99        let stderr_capture = child
100            .stderr
101            .take()
102            .map(|stderr| tokio::spawn(async move { capture_stderr(stderr).await }));
103        let timeout = self.timeout;
104        let termination = AiderTerminationHandle::new();
105        let termination_for_runner = termination.clone();
106
107        let (events_tx, events_rx) = mpsc::channel(32);
108        let (completion_tx, completion_rx) = oneshot::channel();
109
110        tokio::spawn(async move {
111            let result = run_aider_child(
112                child,
113                stdout,
114                stderr_capture,
115                events_tx,
116                timeout,
117                termination_for_runner,
118            )
119            .await;
120            let _ = completion_tx.send(result);
121        });
122
123        let events: DynAiderStreamJsonEventStream =
124            Box::pin(AiderStreamJsonEventChannelStream::new(events_rx));
125
126        let completion: DynAiderStreamJsonCompletion = Box::pin(async move {
127            completion_rx
128                .await
129                .map_err(|_| AiderCliError::Join("stream-json task dropped".to_string()))?
130        });
131
132        Ok((events, completion, termination))
133    }
134}
135
136struct AiderStreamJsonEventChannelStream {
137    rx: mpsc::Receiver<Result<AiderStreamJsonEvent, AiderStreamJsonError>>,
138}
139
140impl AiderStreamJsonEventChannelStream {
141    fn new(rx: mpsc::Receiver<Result<AiderStreamJsonEvent, AiderStreamJsonError>>) -> Self {
142        Self { rx }
143    }
144}
145
146impl Stream for AiderStreamJsonEventChannelStream {
147    type Item = Result<AiderStreamJsonEvent, AiderStreamJsonError>;
148
149    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150        self.get_mut().rx.poll_recv(cx)
151    }
152}
153
154#[derive(Default)]
155struct CompletionAccumulator {
156    session_id: Option<String>,
157    model: Option<String>,
158    assistant_text: String,
159    raw_result: Option<Value>,
160}
161
162use serde_json::Value;
163
164impl CompletionAccumulator {
165    fn observe(&mut self, event: &AiderStreamJsonEvent) {
166        match event {
167            AiderStreamJsonEvent::Init {
168                session_id, model, ..
169            } => {
170                self.session_id = Some(session_id.clone());
171                self.model = Some(model.clone());
172            }
173            AiderStreamJsonEvent::Message {
174                role,
175                content,
176                delta,
177                ..
178            } if role == "assistant" => {
179                if *delta || self.assistant_text.is_empty() {
180                    self.assistant_text.push_str(content);
181                } else {
182                    self.assistant_text.push('\n');
183                    self.assistant_text.push_str(content);
184                }
185            }
186            AiderStreamJsonEvent::Result { payload } => {
187                self.raw_result = Some(payload.raw.clone());
188            }
189            _ => {}
190        }
191    }
192
193    fn final_text(&self) -> Option<String> {
194        (!self.assistant_text.is_empty()).then(|| self.assistant_text.clone())
195    }
196}
197
198async fn run_aider_child(
199    mut child: tokio::process::Child,
200    stdout: tokio::process::ChildStdout,
201    stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
202    events_tx: mpsc::Sender<Result<AiderStreamJsonEvent, AiderStreamJsonError>>,
203    timeout: Option<Duration>,
204    termination: AiderTerminationHandle,
205) -> Result<AiderStreamJsonCompletion, AiderCliError> {
206    let mut reader = BufReader::new(stdout);
207    let mut parser = crate::AiderStreamJsonParser::new();
208    let mut line = String::new();
209    let mut events_open = true;
210    let mut completion = CompletionAccumulator::default();
211    let mut last_result: Option<AiderStreamJsonResultPayload> = None;
212    let mut termination_requested = false;
213    let deadline = timeout.map(|value| Instant::now() + value);
214    let mut exit_status = None;
215
216    loop {
217        if let Some(deadline) = deadline {
218            if Instant::now() >= deadline {
219                match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
220                    Ok(ChildExit::Exited(status)) => {
221                        exit_status = Some(status);
222                        break;
223                    }
224                    Ok(ChildExit::TimedOut) => {
225                        let _ = consume_stderr_capture(stderr_capture).await;
226                        return Err(AiderCliError::Timeout {
227                            timeout: timeout.expect("deadline implies timeout"),
228                        });
229                    }
230                    Err(err) => return Err(err),
231                }
232            }
233        }
234
235        line.clear();
236        let read_result = if let Some(deadline) = deadline {
237            let remaining = deadline.saturating_duration_since(Instant::now());
238            tokio::select! {
239                _ = termination.requested() => {
240                    termination_requested = true;
241                    let _ = child.start_kill();
242                    break;
243                }
244                read = tokio::time::timeout(remaining, reader.read_line(&mut line)) => {
245                    match read {
246                        Ok(result) => result,
247                        Err(_) => {
248                            match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
249                                Ok(ChildExit::Exited(status)) => {
250                                    exit_status = Some(status);
251                                    break;
252                                }
253                                Ok(ChildExit::TimedOut) => {
254                                    let _ = consume_stderr_capture(stderr_capture).await;
255                                    return Err(AiderCliError::Timeout {
256                                        timeout: timeout.expect("deadline implies timeout"),
257                                    });
258                                }
259                                Err(err) => return Err(err),
260                            }
261                        }
262                    }
263                }
264            }
265        } else {
266            tokio::select! {
267                _ = termination.requested() => {
268                    termination_requested = true;
269                    let _ = child.start_kill();
270                    break;
271                }
272                read = reader.read_line(&mut line) => read,
273            }
274        };
275
276        let bytes = match read_result {
277            Ok(bytes) => bytes,
278            Err(err) => {
279                let _ = child.start_kill();
280                let _ = child.wait().await;
281                let _ = consume_stderr_capture(stderr_capture).await;
282                return Err(AiderCliError::StdoutRead(err));
283            }
284        };
285
286        if bytes == 0 {
287            break;
288        }
289
290        let parsed = parser.parse_line(line.trim_end_matches('\n'));
291        match parsed {
292            Ok(Some(event)) => {
293                completion.observe(&event);
294                if let AiderStreamJsonEvent::Result { payload } = &event {
295                    last_result = Some(payload.clone());
296                }
297                if events_open && events_tx.send(Ok(event)).await.is_err() {
298                    events_open = false;
299                }
300            }
301            Ok(None) => {}
302            Err(error) => {
303                if events_open && events_tx.send(Err(error)).await.is_err() {
304                    events_open = false;
305                }
306            }
307        }
308    }
309
310    let status = match exit_status {
311        Some(status) => status,
312        None => match wait_for_child_exit(&mut child, timeout, deadline).await {
313            Ok(ChildExit::Exited(status)) => status,
314            Ok(ChildExit::TimedOut) => {
315                let _ = consume_stderr_capture(stderr_capture).await;
316                return Err(AiderCliError::Timeout {
317                    timeout: timeout.expect("deadline implies timeout"),
318                });
319            }
320            Err(err) => return Err(err),
321        },
322    };
323
324    let _stderr = consume_stderr_capture(stderr_capture).await?;
325
326    if !status.success() {
327        if termination_requested {
328            drop(events_tx);
329            return Ok(AiderStreamJsonCompletion {
330                status,
331                final_text: None,
332                session_id: completion.session_id,
333                model: completion.model,
334                raw_result: completion.raw_result,
335            });
336        }
337
338        let exit_code = status.code();
339        let message = classify_run_failure(exit_code, last_result.as_ref());
340        if last_result.is_none() && events_open {
341            let _ = events_tx
342                .send(Ok(AiderStreamJsonEvent::Error {
343                    severity: "error".to_string(),
344                    message: message.clone(),
345                    raw: Value::Null,
346                }))
347                .await;
348        }
349        drop(events_tx);
350        return Err(AiderCliError::RunFailed {
351            status,
352            exit_code,
353            message,
354            result_error_type: last_result
355                .as_ref()
356                .and_then(|payload| payload.error_type.clone()),
357        });
358    }
359
360    drop(events_tx);
361    Ok(AiderStreamJsonCompletion {
362        status,
363        final_text: completion.final_text(),
364        session_id: completion.session_id,
365        model: completion.model,
366        raw_result: completion.raw_result,
367    })
368}
369
/// Outcome of [`wait_for_child_exit`].
#[derive(Debug, Clone, Copy)]
enum ChildExit {
    /// The child exited (for whatever reason) with this status.
    Exited(std::process::ExitStatus),
    /// The deadline passed while the child was still running; it was killed and reaped.
    TimedOut,
}

376async fn wait_for_child_exit(
377    child: &mut tokio::process::Child,
378    timeout: Option<Duration>,
379    deadline: Option<Instant>,
380) -> Result<ChildExit, AiderCliError> {
381    match deadline {
382        None => child
383            .wait()
384            .await
385            .map(ChildExit::Exited)
386            .map_err(AiderCliError::Wait),
387        Some(deadline) => {
388            let remaining = deadline.saturating_duration_since(Instant::now());
389            if remaining.is_zero() {
390                match child.try_wait().map_err(AiderCliError::Wait)? {
391                    Some(status) => Ok(ChildExit::Exited(status)),
392                    None => {
393                        timeout.expect("deadline implies timeout");
394                        let _ = child.start_kill();
395                        match child.wait().await {
396                            Ok(_status) => Ok(ChildExit::TimedOut),
397                            Err(err) => Err(AiderCliError::Wait(err)),
398                        }
399                    }
400                }
401            } else {
402                match tokio::time::timeout(remaining, child.wait()).await {
403                    Ok(result) => result.map(ChildExit::Exited).map_err(AiderCliError::Wait),
404                    Err(_) => match child.try_wait().map_err(AiderCliError::Wait)? {
405                        Some(status) => Ok(ChildExit::Exited(status)),
406                        None => {
407                            timeout.expect("deadline implies timeout");
408                            let _ = child.start_kill();
409                            match child.wait().await {
410                                Ok(_status) => Ok(ChildExit::TimedOut),
411                                Err(err) => Err(AiderCliError::Wait(err)),
412                            }
413                        }
414                    },
415                }
416            }
417        }
418    }
419}
420
421async fn capture_stderr(
422    mut stderr: tokio::process::ChildStderr,
423) -> Result<Vec<u8>, std::io::Error> {
424    let mut captured = Vec::new();
425    let mut buffer = [0u8; 1024];
426
427    loop {
428        let read = stderr.read(&mut buffer).await?;
429        if read == 0 {
430            break;
431        }
432
433        if captured.len() < STDERR_CAPTURE_MAX_BYTES {
434            let remaining = STDERR_CAPTURE_MAX_BYTES - captured.len();
435            captured.extend_from_slice(&buffer[..read.min(remaining)]);
436        }
437    }
438
439    Ok(captured)
440}
441
442async fn consume_stderr_capture(
443    stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
444) -> Result<String, AiderCliError> {
445    let Some(stderr_capture) = stderr_capture else {
446        return Ok(String::new());
447    };
448
449    let captured = stderr_capture
450        .await
451        .map_err(|err| AiderCliError::Join(format!("stderr capture task failed: {err}")))?
452        .map_err(AiderCliError::StderrRead)?;
453
454    Ok(String::from_utf8_lossy(&captured).into_owned())
455}
456
457fn classify_run_failure(
458    exit_code: Option<i32>,
459    result: Option<&AiderStreamJsonResultPayload>,
460) -> String {
461    match exit_code {
462        Some(42) => INVALID_INPUT_MESSAGE.to_string(),
463        Some(53) => TURN_LIMIT_EXCEEDED_MESSAGE.to_string(),
464        _ => result
465            .and_then(|payload| payload.error_message.clone())
466            .filter(|message| !message.trim().is_empty())
467            .unwrap_or_else(|| RUN_FAILED_MESSAGE.to_string()),
468    }
469}
470
#[cfg(test)]
mod tests {
    use std::process::Stdio;
    use std::time::{Duration, Instant};

    use super::{wait_for_child_exit, ChildExit};

    /// Spawns `sh -c <script>` with stdout and stderr discarded.
    #[cfg(unix)]
    fn spawn_sh(script: &str) -> tokio::process::Child {
        tokio::process::Command::new("sh")
            .args(["-c", script])
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("spawn child")
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn wait_for_child_exit_returns_status_when_deadline_has_elapsed() {
        let mut child = spawn_sh("exit 0");
        // Give the child time to exit so the zero-budget `try_wait` path observes it.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let outcome = wait_for_child_exit(
            &mut child,
            Some(Duration::from_millis(1)),
            Some(Instant::now()),
        )
        .await
        .expect("wait helper succeeds");

        match outcome {
            ChildExit::Exited(status) => assert!(status.success()),
            ChildExit::TimedOut => panic!("expected exited status, got timeout"),
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn wait_for_child_exit_returns_status_before_deadline() {
        let mut child = spawn_sh("exit 3");
        let timeout = Duration::from_secs(10);

        let outcome = wait_for_child_exit(
            &mut child,
            Some(timeout),
            Some(Instant::now() + timeout),
        )
        .await
        .expect("wait helper succeeds");

        match outcome {
            ChildExit::Exited(status) => assert_eq!(status.code(), Some(3)),
            ChildExit::TimedOut => panic!("expected exited status, got timeout"),
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn wait_for_child_exit_kills_child_that_outlives_deadline() {
        // `exec` makes the shell become `sleep`, so the kill reaches the long-running process.
        let mut child = spawn_sh("exec sleep 30");
        let timeout = Duration::from_millis(50);

        let outcome = wait_for_child_exit(
            &mut child,
            Some(timeout),
            Some(Instant::now() + timeout),
        )
        .await
        .expect("wait helper succeeds");

        assert!(matches!(outcome, ChildExit::TimedOut), "got {outcome:?}");
        // The helper reaps the killed child, so its status is available without blocking.
        assert!(child.try_wait().expect("try_wait succeeds").is_some());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn wait_for_child_exit_without_deadline_waits_for_exit() {
        let mut child = spawn_sh("exit 0");

        let outcome = wait_for_child_exit(&mut child, None, None)
            .await
            .expect("wait helper succeeds");

        match outcome {
            ChildExit::Exited(status) => assert!(status.success()),
            ChildExit::TimedOut => panic!("expected exited status, got timeout"),
        }
    }
}