Skip to main content

opencode/
client.rs

1use std::{
2    collections::BTreeMap,
3    path::PathBuf,
4    pin::Pin,
5    process::Stdio,
6    task::{Context, Poll},
7    time::{Duration, Instant},
8};
9
10use futures_core::Stream;
11use tokio::{
12    io::{AsyncBufReadExt, AsyncReadExt, BufReader},
13    sync::{mpsc, oneshot},
14};
15
16use crate::{
17    DynOpencodeRunJsonCompletion, DynOpencodeRunJsonEventStream, OpencodeError,
18    OpencodeRunCompletion, OpencodeRunJsonControlHandle, OpencodeRunJsonEvent,
19    OpencodeRunJsonHandle, OpencodeRunJsonParseError, OpencodeRunJsonParser, OpencodeRunRequest,
20    OpencodeTerminationHandle,
21};
22
/// Maximum number of stderr bytes kept from a child; later output is read and discarded.
const STDERR_CAPTURE_MAX_BYTES: usize = 4 * 1024;
/// Message used for the `TerminalError` event and the `RunFailed` error when an
/// unsuccessful exit is not classified as a session-selection failure.
const RUN_FAILED_MESSAGE: &str = "opencode run failed";
25
/// How a request picks an existing session; used to word selection-failure errors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SelectionMode {
    /// Chosen when the request has `continue_requested()` set and no session id.
    Last,
    /// Chosen when the request carries an explicit `session_id()`.
    Id,
}
31
/// Client that launches `opencode run` processes.
///
/// Obtain one through [`OpencodeClient::builder`].
#[derive(Debug, Clone)]
pub struct OpencodeClient {
    /// Executable spawned for every run.
    pub(crate) binary: PathBuf,
    /// Extra environment variables set on each spawned process.
    pub(crate) env: BTreeMap<String, String>,
    /// Optional wall-clock limit applied to each run.
    pub(crate) timeout: Option<Duration>,
}
38
39impl OpencodeClient {
40    pub fn builder() -> crate::OpencodeClientBuilder {
41        crate::OpencodeClientBuilder::default()
42    }
43
44    pub async fn run_json(
45        &self,
46        request: OpencodeRunRequest,
47    ) -> Result<OpencodeRunJsonHandle, OpencodeError> {
48        let (events, completion, _termination) = self.spawn_run_json(request).await?;
49        Ok(OpencodeRunJsonHandle { events, completion })
50    }
51
52    pub async fn run_json_control(
53        &self,
54        request: OpencodeRunRequest,
55    ) -> Result<OpencodeRunJsonControlHandle, OpencodeError> {
56        let (events, completion, termination) = self.spawn_run_json(request).await?;
57        Ok(OpencodeRunJsonControlHandle {
58            events,
59            completion,
60            termination,
61        })
62    }
63
64    async fn spawn_run_json(
65        &self,
66        request: OpencodeRunRequest,
67    ) -> Result<
68        (
69            DynOpencodeRunJsonEventStream,
70            DynOpencodeRunJsonCompletion,
71            OpencodeTerminationHandle,
72        ),
73        OpencodeError,
74    > {
75        let selection_mode = selection_mode(&request);
76        let argv = request.argv()?;
77        let mut command = tokio::process::Command::new(&self.binary);
78        command
79            .args(argv)
80            .stdin(Stdio::null())
81            .stdout(Stdio::piped())
82            .stderr(Stdio::piped());
83
84        for (key, value) in &self.env {
85            command.env(key, value);
86        }
87
88        let mut child = command.spawn().map_err(|source| {
89            if source.kind() == std::io::ErrorKind::NotFound {
90                OpencodeError::MissingBinary
91            } else {
92                OpencodeError::Spawn {
93                    binary: self.binary.clone(),
94                    source,
95                }
96            }
97        })?;
98
99        let stdout = child.stdout.take().ok_or(OpencodeError::MissingStdout)?;
100        let stderr_capture = child
101            .stderr
102            .take()
103            .map(|stderr| tokio::spawn(async move { capture_stderr(stderr).await }));
104        let timeout = self.timeout;
105        let termination = OpencodeTerminationHandle::new();
106        let termination_for_runner = termination.clone();
107
108        let (events_tx, events_rx) = mpsc::channel(32);
109        let (completion_tx, completion_rx) = oneshot::channel();
110
111        tokio::spawn(async move {
112            let result = run_opencode_child(
113                child,
114                stdout,
115                stderr_capture,
116                events_tx,
117                timeout,
118                termination_for_runner,
119                selection_mode,
120            )
121            .await;
122            let _ = completion_tx.send(result);
123        });
124
125        let events: DynOpencodeRunJsonEventStream =
126            Box::pin(OpencodeRunJsonEventChannelStream::new(events_rx));
127
128        let completion: DynOpencodeRunJsonCompletion = Box::pin(async move {
129            completion_rx
130                .await
131                .map_err(|_| OpencodeError::Join("run-json task dropped".to_string()))?
132        });
133
134        Ok((events, completion, termination))
135    }
136}
137
138struct OpencodeRunJsonEventChannelStream {
139    rx: mpsc::Receiver<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>,
140}
141
142impl OpencodeRunJsonEventChannelStream {
143    fn new(rx: mpsc::Receiver<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>) -> Self {
144        Self { rx }
145    }
146}
147
148impl Stream for OpencodeRunJsonEventChannelStream {
149    type Item = Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>;
150
151    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
152        self.get_mut().rx.poll_recv(cx)
153    }
154}
155
156async fn run_opencode_child(
157    mut child: tokio::process::Child,
158    stdout: tokio::process::ChildStdout,
159    stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
160    events_tx: mpsc::Sender<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>,
161    timeout: Option<Duration>,
162    termination: OpencodeTerminationHandle,
163    selection_mode: Option<SelectionMode>,
164) -> Result<OpencodeRunCompletion, OpencodeError> {
165    let mut reader = BufReader::new(stdout);
166    let mut parser = OpencodeRunJsonParser::new();
167    let mut line = String::new();
168    let mut events_open = true;
169    let mut final_text = String::new();
170    let mut saw_finish = false;
171    let mut termination_requested = false;
172    let deadline = timeout.map(|value| Instant::now() + value);
173
174    loop {
175        if let Some(deadline) = deadline {
176            if Instant::now() >= deadline {
177                match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
178                    Ok(_) => {
179                        let _ = consume_stderr_capture(stderr_capture).await;
180                        return Err(OpencodeError::Timeout {
181                            timeout: timeout.expect("deadline implies timeout"),
182                        });
183                    }
184                    Err(err) => return Err(err),
185                }
186            }
187        }
188
189        line.clear();
190        let read_result = if let Some(deadline) = deadline {
191            let remaining = deadline.saturating_duration_since(Instant::now());
192            tokio::select! {
193                _ = termination.requested() => {
194                    termination_requested = true;
195                    let _ = child.start_kill();
196                    break;
197                }
198                read = tokio::time::timeout(remaining, reader.read_line(&mut line)) => {
199                    match read {
200                        Ok(result) => result,
201                        Err(_) => {
202                            match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
203                                Ok(_) => {
204                                    let _ = consume_stderr_capture(stderr_capture).await;
205                                    return Err(OpencodeError::Timeout {
206                                        timeout: timeout.expect("deadline implies timeout"),
207                                    });
208                                }
209                                Err(err) => return Err(err),
210                            }
211                        }
212                    }
213                }
214            }
215        } else {
216            tokio::select! {
217                _ = termination.requested() => {
218                    termination_requested = true;
219                    let _ = child.start_kill();
220                    break;
221                }
222                read = reader.read_line(&mut line) => read,
223            }
224        };
225
226        let bytes = match read_result {
227            Ok(bytes) => bytes,
228            Err(err) => {
229                let _ = child.start_kill();
230                let _ = child.wait().await;
231                let _ = consume_stderr_capture(stderr_capture).await;
232                return Err(OpencodeError::StdoutRead(err));
233            }
234        };
235
236        if bytes == 0 {
237            break;
238        }
239
240        let parsed = parser.parse_line(line.trim_end_matches('\n'));
241        match parsed {
242            Ok(Some(event)) => {
243                if let OpencodeRunJsonEvent::Text { text, .. } = &event {
244                    final_text.push_str(text);
245                } else if matches!(event, OpencodeRunJsonEvent::StepFinish { .. }) {
246                    saw_finish = true;
247                }
248
249                if events_open && events_tx.send(Ok(event)).await.is_err() {
250                    events_open = false;
251                }
252            }
253            Ok(None) => {}
254            Err(error) => {
255                if events_open && events_tx.send(Err(error)).await.is_err() {
256                    events_open = false;
257                }
258            }
259        }
260    }
261
262    let status = match wait_for_child_exit(&mut child, timeout, deadline).await {
263        Ok(status) => status,
264        Err(err @ OpencodeError::Timeout { .. }) => {
265            let _ = consume_stderr_capture(stderr_capture).await;
266            return Err(err);
267        }
268        Err(err) => return Err(err),
269    };
270    let stderr = consume_stderr_capture(stderr_capture).await?;
271    if !status.success() {
272        if termination_requested {
273            drop(events_tx);
274            return Ok(OpencodeRunCompletion {
275                status,
276                final_text: None,
277            });
278        }
279        if let Some(message) = classify_selection_failure(&stderr, selection_mode) {
280            if events_open {
281                let _ = events_tx
282                    .send(Ok(OpencodeRunJsonEvent::TerminalError {
283                        message: message.clone(),
284                        raw: serde_json::Value::Null,
285                    }))
286                    .await;
287            }
288            drop(events_tx);
289            return Err(OpencodeError::SelectionFailed { message });
290        }
291        if events_open {
292            let _ = events_tx
293                .send(Ok(OpencodeRunJsonEvent::TerminalError {
294                    message: RUN_FAILED_MESSAGE.to_string(),
295                    raw: serde_json::Value::Null,
296                }))
297                .await;
298        }
299        drop(events_tx);
300        return Err(OpencodeError::RunFailed {
301            status,
302            message: RUN_FAILED_MESSAGE.to_string(),
303        });
304    }
305    drop(events_tx);
306
307    let final_text = (saw_finish && !final_text.is_empty()).then_some(final_text);
308
309    Ok(OpencodeRunCompletion { status, final_text })
310}
311
312async fn wait_for_child_exit(
313    child: &mut tokio::process::Child,
314    timeout: Option<Duration>,
315    deadline: Option<Instant>,
316) -> Result<std::process::ExitStatus, OpencodeError> {
317    match deadline {
318        None => child.wait().await.map_err(OpencodeError::Wait),
319        Some(deadline) => {
320            let remaining = deadline.saturating_duration_since(Instant::now());
321            if remaining.is_zero() {
322                let timeout = timeout.expect("deadline implies timeout");
323                let _ = child.start_kill();
324                match child.wait().await {
325                    Ok(_status) => Err(OpencodeError::Timeout { timeout }),
326                    Err(err) => Err(OpencodeError::Wait(err)),
327                }
328            } else {
329                match tokio::time::timeout(remaining, child.wait()).await {
330                    Ok(result) => result.map_err(OpencodeError::Wait),
331                    Err(_) => {
332                        let timeout = timeout.expect("deadline implies timeout");
333                        let _ = child.start_kill();
334                        match child.wait().await {
335                            Ok(_status) => Err(OpencodeError::Timeout { timeout }),
336                            Err(err) => Err(OpencodeError::Wait(err)),
337                        }
338                    }
339                }
340            }
341        }
342    }
343}
344
345fn selection_mode(request: &OpencodeRunRequest) -> Option<SelectionMode> {
346    if request.session_id().is_some() {
347        Some(SelectionMode::Id)
348    } else if request.continue_requested() {
349        Some(SelectionMode::Last)
350    } else {
351        None
352    }
353}
354
355async fn capture_stderr(
356    mut stderr: tokio::process::ChildStderr,
357) -> Result<Vec<u8>, std::io::Error> {
358    let mut captured = Vec::new();
359    let mut buffer = [0u8; 1024];
360
361    loop {
362        let read = stderr.read(&mut buffer).await?;
363        if read == 0 {
364            break;
365        }
366
367        if captured.len() < STDERR_CAPTURE_MAX_BYTES {
368            let remaining = STDERR_CAPTURE_MAX_BYTES - captured.len();
369            captured.extend_from_slice(&buffer[..read.min(remaining)]);
370        }
371    }
372
373    Ok(captured)
374}
375
376async fn consume_stderr_capture(
377    stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
378) -> Result<String, OpencodeError> {
379    let Some(stderr_capture) = stderr_capture else {
380        return Ok(String::new());
381    };
382
383    let captured = stderr_capture
384        .await
385        .map_err(|err| OpencodeError::Join(format!("stderr capture task failed: {err}")))?
386        .map_err(OpencodeError::StderrRead)?;
387
388    Ok(String::from_utf8_lossy(&captured).into_owned())
389}
390
391fn classify_selection_failure(
392    stderr: &str,
393    selection_mode: Option<SelectionMode>,
394) -> Option<String> {
395    let selection_mode = selection_mode?;
396    let stderr = stderr.to_ascii_lowercase();
397
398    let saw_not_found = (stderr.contains("not found")
399        && (stderr.contains("session")
400            || stderr.contains("thread")
401            || stderr.contains("conversation")))
402        || stderr.contains("no session")
403        || stderr.contains("no sessions")
404        || stderr.contains("unknown session")
405        || stderr.contains("no thread")
406        || stderr.contains("no threads")
407        || stderr.contains("unknown thread")
408        || stderr.contains("no conversation")
409        || stderr.contains("unknown conversation");
410
411    if !saw_not_found {
412        return None;
413    }
414
415    Some(match selection_mode {
416        SelectionMode::Last => "no session found".to_string(),
417        SelectionMode::Id => "session not found".to_string(),
418    })
419}