// codex-cli-sdk 0.0.1
//
// Rust SDK for the OpenAI Codex CLI
// Documentation
use crate::errors::{Error, Result};
use async_trait::async_trait;
use futures_core::Stream;
use serde_json::Value;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::Mutex;

// ── Trait ──────────────────────────────────────────────────────

/// Abstraction over the channel used to talk to the Codex CLI.
///
/// All methods take `&self`, and implementors must be `Send + Sync`.
/// `CliTransport` (below) is the subprocess-backed implementation.
#[async_trait]
pub trait Transport: Send + Sync {
    /// Spawn the CLI process and start the background reader.
    ///
    /// `CliTransport` returns `Error::AlreadyConnected` when it is already
    /// marked ready and `Error::SpawnFailed` when the process cannot be spawned.
    async fn connect(&self) -> Result<()>;

    /// Write a JSON line to the CLI's stdin.
    ///
    /// `CliTransport` appends the trailing newline itself and flushes after
    /// every call; it returns `Error::NotConnected` when stdin is not available.
    async fn write(&self, data: &str) -> Result<()>;

    /// Get a stream of parsed JSONL messages from stdout.
    /// This takes ownership of the internal receiver — can only be called once.
    ///
    /// With `CliTransport`, a call made when no receiver is available (before
    /// `connect`, or after the receiver was already taken) yields a stream
    /// containing a single `Error::TransportClosed` item.
    fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>>;

    /// Close stdin (signals EOF to the CLI).
    async fn end_input(&self) -> Result<()>;

    /// Send interrupt signal (SIGINT on Unix, CTRL_BREAK on Windows).
    async fn interrupt(&self) -> Result<()>;

    /// Whether the transport is connected and the process is running.
    ///
    /// NOTE(review): `CliTransport` reports a flag that is set by `connect` and
    /// cleared by `close`; it does not probe the process itself.
    fn is_ready(&self) -> bool;

    /// Graceful shutdown: close stdin, wait for process exit, return exit code.
    ///
    /// The exit code is `None` when no code is available (for example when
    /// `CliTransport` had to kill the process after its close timeout).
    async fn close(&self) -> Result<Option<i32>>;

    /// Return the stderr output captured so far.
    ///
    /// `CliTransport` fills its buffer from the subprocess stderr stream and
    /// returns a copy of everything accumulated; this call does not clear it.
    /// The default implementation returns an empty string (e.g. for `MockTransport`).
    fn collected_stderr(&self) -> String {
        String::new()
    }
}

// ── CLI Transport ──────────────────────────────────────────────

/// [`Transport`] implementation that drives the CLI as a child process and
/// exchanges newline-delimited JSON with it over stdin/stdout.
pub struct CliTransport {
    /// Arguments passed to the CLI executable on spawn.
    cli_args: Vec<String>,
    /// Path of the executable to spawn.
    cli_path: std::path::PathBuf,
    /// Environment variables passed to the spawned process via `Command::envs`.
    env: std::collections::HashMap<String, String>,
    /// The running child, stored by `connect()` and taken out by `close()`.
    /// A `std` mutex is used; the guard is never held across an `.await`.
    process: std::sync::Mutex<Option<tokio::process::Child>>,
    /// Write end of the child's stdin; set to `None` by `end_input()` to signal EOF.
    stdin: Mutex<Option<tokio::process::ChildStdin>>,
    /// Receiving half of the message channel; handed out once by `read_messages()`.
    message_rx: Mutex<Option<tokio::sync::mpsc::Receiver<Result<Value>>>>,
    /// Handle of the background stdout-reader task, awaited in `close()`.
    reader_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
    /// Optional user callback invoked with every stderr line.
    stderr_callback: Option<crate::config::StderrCallback>,
    /// Accumulates all stderr lines from the subprocess.  Shared with the
    /// background stderr-reader task so it is available at `close()` time.
    stderr_buf: Arc<std::sync::Mutex<String>>,
    /// Optional token; when cancelled, the reader task sends an interrupt to
    /// the child and stops reading.
    cancel: Option<tokio_util::sync::CancellationToken>,
    /// How long `close()` waits for the child to exit (5 seconds when `None`).
    close_timeout: Option<std::time::Duration>,
    /// Set to `true` at the end of `connect()` and back to `false` by `close()`.
    ready: AtomicBool,
}

impl CliTransport {
    /// Create a transport in the disconnected state.
    ///
    /// Nothing is spawned here: the executable at `cli_path` is only started,
    /// with `cli_args` and `env`, once `connect()` is called.
    ///
    /// * `stderr_callback` – optional hook invoked with each stderr line.
    /// * `cancel` – optional token that makes the stdout reader interrupt the
    ///   child and stop when it is cancelled.
    /// * `close_timeout` – how long `close()` waits for the child to exit;
    ///   `None` selects the 5 second default applied in `close()`.
    pub fn new(
        cli_path: std::path::PathBuf,
        cli_args: Vec<String>,
        env: std::collections::HashMap<String, String>,
        stderr_callback: Option<crate::config::StderrCallback>,
        cancel: Option<tokio_util::sync::CancellationToken>,
        close_timeout: Option<std::time::Duration>,
    ) -> Self {
        // Runtime state starts out empty; `connect()` fills these slots in.
        let process = std::sync::Mutex::new(None);
        let stdin = Mutex::new(None);
        let message_rx = Mutex::new(None);
        let reader_handle = Mutex::new(None);
        let stderr_buf = Arc::new(std::sync::Mutex::new(String::new()));
        let ready = AtomicBool::new(false);

        Self {
            // Caller-supplied configuration.
            cli_path,
            cli_args,
            env,
            stderr_callback,
            cancel,
            close_timeout,
            // Per-connection state.
            process,
            stdin,
            message_rx,
            reader_handle,
            stderr_buf,
            ready,
        }
    }
}

fn send_interrupt_signal(pid: u32) {
    #[cfg(unix)]
    {
        use nix::sys::signal::{Signal, kill};
        use nix::unistd::Pid;
        let _ = kill(Pid::from_raw(pid as i32), Signal::SIGINT);
    }
    #[cfg(windows)]
    {
        unsafe {
            windows_sys::Win32::System::Console::GenerateConsoleCtrlEvent(
                windows_sys::Win32::System::Console::CTRL_BREAK_EVENT,
                pid,
            );
        }
    }
}

#[async_trait]
impl Transport for CliTransport {
    /// Spawn the CLI, wire up its pipes and start the background reader tasks.
    ///
    /// Steps, in order:
    /// 1. refuse to run when the transport is already marked ready;
    /// 2. spawn the executable with piped stdin/stdout/stderr;
    /// 3. start a task that turns stdout lines into JSON values and pushes
    ///    them into a bounded channel (capacity 256);
    /// 4. start a task that collects stderr (when the pipe is available);
    /// 5. store child, stdin, receiver and reader handle, then set `ready`.
    ///
    /// # Errors
    /// * `Error::AlreadyConnected` – `ready` is already set.
    /// * `Error::SpawnFailed` – the process could not be spawned.
    /// * `Error::NotConnected` – the child's stdout or stdin pipe is missing.
    async fn connect(&self) -> Result<()> {
        if self.ready.load(Ordering::Acquire) {
            return Err(Error::AlreadyConnected);
        }

        let mut cmd = tokio::process::Command::new(&self.cli_path);
        cmd.args(&self.cli_args)
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .envs(&self.env);

        // NOTE(review): presumably needed so that `send_interrupt_signal` can
        // target the child by pid with CTRL_BREAK — confirm.
        #[cfg(windows)]
        cmd.creation_flags(windows_sys::Win32::System::Threading::CREATE_NEW_PROCESS_GROUP);

        let mut child = cmd.spawn().map_err(Error::SpawnFailed)?;
        // Captured now so the reader task can signal the child without
        // needing access to `self.process`.
        let child_pid: Option<u32> = child.id();

        // Move the pipes out of the child; stderr is optional, the other two are required.
        let stdout = child.stdout.take().ok_or(Error::NotConnected)?;
        let stdin = child.stdin.take().ok_or(Error::NotConnected)?;
        let stderr = child.stderr.take();

        // Bounded channel between the stdout reader task and `read_messages()`.
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<Value>>(256);

        let cancel_token = self.cancel.clone();
        let reader_handle = tokio::spawn(async move {
            use tokio::io::{AsyncBufReadExt, BufReader};
            let reader = BufReader::new(stdout);
            let mut lines = reader.lines();

            loop {
                // With a cancellation token, race the next line against
                // cancellation; without one, just wait for the next line.
                let line = if let Some(ref token) = cancel_token {
                    tokio::select! {
                        _ = token.cancelled() => {
                            tracing::debug!("Reader cancelled via CancellationToken — sending interrupt");
                            if let Some(pid) = child_pid {
                                send_interrupt_signal(pid);
                            }
                            break;
                        }
                        result = lines.next_line() => result,
                    }
                } else {
                    lines.next_line().await
                };

                match line {
                    Ok(Some(line)) => {
                        // Blank lines carry no message and are skipped.
                        let line = line.trim().to_string();
                        if line.is_empty() {
                            continue;
                        }
                        match serde_json::from_str::<Value>(&line) {
                            Ok(value) => {
                                // A failed send ends the task: nobody can
                                // receive further messages.
                                if tx.send(Ok(value)).await.is_err() {
                                    break;
                                }
                            }
                            Err(e) => {
                                tracing::warn!("JSONL parse error: {e} — line: {line}");
                                let _ = tx.send(Err(crate::errors::Error::Json(e))).await;
                                // Don't break — continue reading next line
                            }
                        }
                    }
                    // `None` means stdout reached EOF.
                    Ok(None) => break,
                    Err(e) => {
                        // Read errors are forwarded to the consumer and end the task.
                        tracing::error!("stdout read error: {e}");
                        let _ = tx.send(Err(crate::errors::Error::ReadFailed(e))).await;
                        break;
                    }
                }
            }
        });

        // Spawn a stderr reader whenever the stderr pipe is available: it
        // buffers all lines for `collected_stderr()` and optionally forwards
        // each line to the user-supplied callback. The task is detached and
        // stops at EOF or on the first read error.
        if let Some(stderr) = stderr {
            let buf = Arc::clone(&self.stderr_buf);
            let cb = self.stderr_callback.as_ref().map(Arc::clone);
            tokio::spawn(async move {
                use tokio::io::{AsyncBufReadExt, BufReader};
                let reader = BufReader::new(stderr);
                let mut lines = reader.lines();
                while let Ok(Some(line)) = lines.next_line().await {
                    if let Some(ref callback) = cb {
                        callback(&line);
                    }
                    // A poisoned lock is recovered rather than propagated.
                    let mut guard = buf.lock().unwrap_or_else(|e| e.into_inner());
                    // Lines are joined with '\n', without a trailing newline.
                    if !guard.is_empty() {
                        guard.push('\n');
                    }
                    guard.push_str(&line);
                }
            });
        }

        // Store the connection state; `ready` is set last, after every slot is filled.
        *self.process.lock().unwrap_or_else(|e| e.into_inner()) = Some(child);
        *self.stdin.lock().await = Some(stdin);
        *self.message_rx.lock().await = Some(rx);
        *self.reader_handle.lock().await = Some(reader_handle);
        self.ready.store(true, Ordering::Release);

        Ok(())
    }

    async fn write(&self, data: &str) -> Result<()> {
        use tokio::io::AsyncWriteExt;
        let mut guard = self.stdin.lock().await;
        let stdin = guard.as_mut().ok_or(Error::NotConnected)?;
        stdin
            .write_all(data.as_bytes())
            .await
            .map_err(Error::WriteFailed)?;
        stdin.write_all(b"\n").await.map_err(Error::WriteFailed)?;
        stdin.flush().await.map_err(Error::WriteFailed)?;
        Ok(())
    }

    /// Hand out the stream of messages parsed from the child's stdout.
    ///
    /// The receiver is moved out of the transport, so only the first call
    /// after `connect()` gets the live stream. When no receiver can be
    /// obtained — none is stored, or the slot is locked right now — the
    /// returned stream yields a single `Error::TransportClosed` item.
    fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>> {
        // Non-blocking attempt: this method is synchronous and cannot await the lock.
        let receiver = self
            .message_rx
            .try_lock()
            .ok()
            .and_then(|mut slot| slot.take());

        if let Some(rx) = receiver {
            Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx))
        } else {
            Box::pin(tokio_stream::iter(std::iter::once(Err(
                crate::errors::Error::TransportClosed,
            ))))
        }
    }

    /// Close the child's stdin by dropping the stored write handle.
    ///
    /// Calling this when stdin is already gone is a no-op; it always
    /// returns `Ok(())`.
    async fn end_input(&self) -> Result<()> {
        // Dropping the handle closes the pipe, which the child sees as EOF.
        drop(self.stdin.lock().await.take());
        Ok(())
    }

    /// Send the platform interrupt signal to the child process.
    ///
    /// Does nothing when no child is stored or the child no longer reports
    /// a pid; it returns `Ok(())` in every case.
    async fn interrupt(&self) -> Result<()> {
        // A poisoned lock is recovered: the stored child is still usable.
        let slot = self
            .process
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let pid = slot.as_ref().and_then(|child| child.id());

        if let Some(pid) = pid {
            send_interrupt_signal(pid);
        }
        Ok(())
    }

    fn is_ready(&self) -> bool {
        self.ready.load(Ordering::Acquire)
    }

    /// Shut the transport down and report the child's exit code.
    ///
    /// Sequence:
    /// 1. close stdin so the child sees EOF;
    /// 2. wait — for at most `close_timeout` (5 seconds when unset) in total —
    ///    for the stdout reader task to finish and the child to exit;
    /// 3. if that does not happen in time, or waiting on the child fails,
    ///    abort the reader task and kill the child.
    ///
    /// The reader join is covered by the timeout as well: the reader only
    /// stops at stdout EOF, on cancellation, or when the message receiver is
    /// gone, so an unbounded join would let a child that ignores EOF (or a
    /// full, unread message channel) block `close()` forever.
    ///
    /// Returns `Ok(Some(code))` when the child exited on its own with an exit
    /// code, and `Ok(None)` when there was no child, no code was available,
    /// or the child had to be killed. The `ready` flag is cleared before
    /// returning.
    async fn close(&self) -> Result<Option<i32>> {
        self.end_input().await?;

        // Take both handles out of their mutexes so that no guard is held
        // across the awaits below.
        let mut reader = self.reader_handle.lock().await.take();
        let mut child = self
            .process
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .take();

        let timeout = self
            .close_timeout
            .unwrap_or(std::time::Duration::from_secs(5));

        // Graceful path: let the reader drain stdout, then collect the exit status.
        let graceful = async {
            if let Some(handle) = reader.as_mut() {
                let _ = handle.await;
            }
            match child.as_mut() {
                Some(process) => process.wait().await.map(Some),
                None => Ok(None),
            }
        };

        let exit_code = match tokio::time::timeout(timeout, graceful).await {
            Ok(Ok(Some(status))) => status.code(),
            Ok(Ok(None)) => None,
            // Timed out, or waiting on the child failed: force everything down.
            Ok(Err(_)) | Err(_) => {
                if let Some(handle) = reader.as_ref() {
                    // No-op when the task already finished.
                    handle.abort();
                }
                if let Some(process) = child.as_mut() {
                    let _ = process.start_kill();
                }
                None
            }
        };

        self.ready.store(false, Ordering::Release);
        Ok(exit_code)
    }

    /// Return a copy of everything the stderr reader has buffered so far.
    ///
    /// The buffer itself is left untouched, so repeated calls return the
    /// accumulated output again.
    fn collected_stderr(&self) -> String {
        // A poisoned lock still holds valid text, so recover the guard.
        let captured = match self.stderr_buf.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        String::clone(&captured)
    }
}