Skip to main content

codex_cli_sdk/
client.rs

1use crate::callback::EventCallback;
2use crate::config::{CodexConfig, OutputSchema, OutputSchemaFile, ThreadOptions, TurnOptions};
3use crate::discovery;
4use crate::errors::{Error, Result};
5use crate::hooks::{self, HookContext, HookDecision, HookMatcher};
6use crate::permissions::{
7    ApprovalCallback, ApprovalContext, ApprovalResponse, PatchApprovalCallback,
8    PatchApprovalContext, PatchApprovalResponse,
9};
10use crate::transport::{CliTransport, Transport};
11use crate::types::events::{StreamedTurn, ThreadEvent, Turn};
12use crate::types::input::Input;
13
14use serde_json::Value;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::time::Duration;
19use tokio_stream::StreamExt;
20
21// ── TurnGuard ──────────────────────────────────────────────────
22
23/// Resets `turn_in_progress` to `false` and clears the active transport slot
24/// when dropped, ensuring cleanup happens on any exit path (normal completion,
25/// early stream drop, or cancellation).
26struct TurnGuard {
27    flag: Arc<AtomicBool>,
28    active_transport: Arc<std::sync::Mutex<Option<Arc<dyn Transport>>>>,
29}
30
31impl Drop for TurnGuard {
32    fn drop(&mut self) {
33        self.flag.store(false, Ordering::Release);
34        *self
35            .active_transport
36            .lock()
37            .unwrap_or_else(|e| e.into_inner()) = None;
38    }
39}
40
41// ── Codex (factory) ────────────────────────────────────────────
42
43/// Entry point for the Codex SDK.
44///
45/// Creates `Thread` instances for running prompts against the Codex CLI.
46pub struct Codex {
47    config: CodexConfig,
48    cli_path: PathBuf,
49}
50
51impl Codex {
52    /// Create a new Codex instance.
53    ///
54    /// Discovers the CLI binary (or uses `config.cli_path` if set).
55    pub fn new(config: CodexConfig) -> Result<Self> {
56        let cli_path = match &config.cli_path {
57            Some(path) => path.clone(),
58            None => discovery::find_cli()?,
59        };
60        Ok(Self { config, cli_path })
61    }
62
63    /// Start a new thread (conversation).
64    pub fn start_thread(&self, options: ThreadOptions) -> Thread {
65        Thread::new(self.cli_path.clone(), self.config.clone(), options, None)
66    }
67
68    /// Resume a previous thread by ID.
69    pub fn resume_thread(&self, thread_id: impl Into<String>, options: ThreadOptions) -> Thread {
70        Thread::new(
71            self.cli_path.clone(),
72            self.config.clone(),
73            options,
74            Some(thread_id.into()),
75        )
76    }
77
78    /// Get the resolved CLI path.
79    pub fn cli_path(&self) -> &std::path::Path {
80        &self.cli_path
81    }
82
83    /// Check the CLI version.
84    pub async fn version(&self) -> Result<String> {
85        discovery::check_version(&self.cli_path, self.config.version_check_timeout).await
86    }
87}
88
89// ── Thread (session) ───────────────────────────────────────────
90
91/// A conversation thread with the Codex agent.
92///
93/// Each call to `run()` or `run_streamed()` spawns a `codex exec` subprocess.
94pub struct Thread {
95    cli_path: PathBuf,
96    config: CodexConfig,
97    options: ThreadOptions,
98    resume_id: Option<String>,
99    thread_id: Arc<std::sync::Mutex<Option<String>>>,
100    approval_callback: Option<ApprovalCallback>,
101    patch_approval_callback: Option<PatchApprovalCallback>,
102    event_callback: Option<EventCallback>,
103    hooks: Vec<HookMatcher>,
104    default_hook_timeout: Duration,
105    max_turns: Option<u32>,
106    max_budget_tokens: Option<u64>,
107    turn_in_progress: Arc<AtomicBool>,
108    active_transport: Arc<std::sync::Mutex<Option<Arc<dyn Transport>>>>,
109    transport_override: Option<Arc<dyn Transport>>,
110}
111
112impl Thread {
113    fn new(
114        cli_path: PathBuf,
115        config: CodexConfig,
116        mut options: ThreadOptions,
117        resume_id: Option<String>,
118    ) -> Self {
119        // Extract hook/budget fields from options (they don't go to CLI).
120        let hooks = std::mem::take(&mut options.hooks);
121        let default_hook_timeout = options.default_hook_timeout;
122        let max_turns = options.max_turns;
123        let max_budget_tokens = options.max_budget_tokens;
124
125        Self {
126            cli_path,
127            config,
128            options,
129            resume_id,
130            thread_id: Arc::new(std::sync::Mutex::new(None)),
131            approval_callback: None,
132            patch_approval_callback: None,
133            event_callback: None,
134            hooks,
135            default_hook_timeout,
136            max_turns,
137            max_budget_tokens,
138            turn_in_progress: Arc::new(AtomicBool::new(false)),
139            active_transport: Arc::new(std::sync::Mutex::new(None)),
140            transport_override: None,
141        }
142    }
143
144    /// Set an approval callback for handling permission requests.
145    pub fn with_approval_callback(mut self, callback: ApprovalCallback) -> Self {
146        self.approval_callback = Some(callback);
147        self
148    }
149
150    /// Set a patch approval callback for handling file-patch approval requests.
151    pub fn with_patch_approval_callback(mut self, callback: PatchApprovalCallback) -> Self {
152        self.patch_approval_callback = Some(callback);
153        self
154    }
155
156    /// Set an event callback for observing, transforming, or filtering events.
157    pub fn with_event_callback(mut self, callback: EventCallback) -> Self {
158        self.event_callback = Some(callback);
159        self
160    }
161
162    /// Add hooks to this thread.
163    pub fn with_hooks(mut self, hooks: Vec<HookMatcher>) -> Self {
164        self.hooks = hooks;
165        self
166    }
167
168    /// Override the transport used for this thread (useful for testing).
169    ///
170    /// When set, the provided transport is used instead of spawning a real
171    /// `codex` subprocess. The same `connect / write / end_input / read_messages
172    /// / close` call sequence is used regardless.
173    pub fn with_transport(mut self, transport: Arc<dyn Transport>) -> Self {
174        self.transport_override = Some(transport);
175        self
176    }
177
178    /// Get the thread ID (populated after first successful turn).
179    pub fn id(&self) -> Option<String> {
180        self.thread_id
181            .lock()
182            .unwrap_or_else(|e| e.into_inner())
183            .clone()
184            .or_else(|| self.resume_id.clone())
185    }
186
187    /// Interrupt the currently running turn by sending SIGINT to the CLI subprocess.
188    ///
189    /// Equivalent to pressing Ctrl-C. The CLI has `close_timeout` seconds to clean
190    /// up before being force-killed. Returns `Ok(())` if no turn is in progress.
191    pub async fn interrupt(&self) -> Result<()> {
192        let transport = self
193            .active_transport
194            .lock()
195            .unwrap_or_else(|e| e.into_inner())
196            .clone();
197        if let Some(t) = transport {
198            t.interrupt().await?;
199        }
200        Ok(())
201    }
202
203    /// Run a prompt and collect all events into a `Turn`.
204    pub async fn run(
205        &mut self,
206        input: impl Into<Input>,
207        turn_options: TurnOptions,
208    ) -> Result<Turn> {
209        let mut streamed = self.run_streamed(input, turn_options).await?;
210        let mut events = Vec::new();
211        let mut final_response = String::new();
212        let mut usage = None;
213
214        while let Some(event) = streamed.next().await {
215            let event = event?;
216            match &event {
217                ThreadEvent::ItemCompleted {
218                    item: crate::types::items::ThreadItem::AgentMessage { text, .. },
219                } => {
220                    final_response = text.clone();
221                }
222                ThreadEvent::TurnCompleted { usage: u } => {
223                    usage = Some(u.clone());
224                }
225                ThreadEvent::TurnFailed { error } => {
226                    let msg = error.message.clone();
227                    events.push(event);
228                    return Err(Error::Other(msg));
229                }
230                ThreadEvent::Error { message } => {
231                    let msg = message.clone();
232                    events.push(event);
233                    return Err(Error::Other(msg));
234                }
235                _ => {}
236            }
237            events.push(event);
238        }
239
240        Ok(Turn {
241            events,
242            final_response,
243            usage,
244        })
245    }
246
247    /// Run a prompt and return a streaming `StreamedTurn`.
248    pub async fn run_streamed(
249        &mut self,
250        input: impl Into<Input>,
251        turn_options: TurnOptions,
252    ) -> Result<StreamedTurn> {
253        // Guard against concurrent turns.
254        if self
255            .turn_in_progress
256            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
257            .is_err()
258        {
259            return Err(Error::ConcurrentTurn);
260        }
261
262        let input = input.into();
263
264        // Build CLI args
265        let mut args = self.options.to_cli_args();
266        self.config.apply_overrides(&mut args);
267
268        // Resolve output schema (turn-level takes priority over thread-level).
269        let (schema_args, schema_guard, thread_schema_guard) = resolve_output_schema(
270            turn_options.output_schema.as_ref(),
271            &self.options.output_schema,
272        )?;
273        args.extend(schema_args);
274
275        // Handle resume
276        if let Some(ref resume_id) = self.resume_id {
277            args.push("resume".into());
278            args.push(resume_id.clone());
279        }
280
281        // Create transport (use override if set, otherwise spawn CLI subprocess)
282        let transport: Arc<dyn Transport> = match &self.transport_override {
283            Some(t) => Arc::clone(t),
284            None => Arc::new(CliTransport::new(
285                self.cli_path.clone(),
286                args,
287                self.config.to_env(),
288                self.config.stderr_callback.clone(),
289                turn_options.cancel.clone(),
290                self.config.close_timeout,
291            )),
292        };
293
294        // Register the active transport so Thread::interrupt() can reach it.
295        *self
296            .active_transport
297            .lock()
298            .unwrap_or_else(|e| e.into_inner()) = Some(Arc::clone(&transport));
299        let active_transport_slot = Arc::clone(&self.active_transport);
300        let turn_guard = TurnGuard {
301            flag: self.turn_in_progress.clone(),
302            active_transport: active_transport_slot,
303        };
304
305        // Connect (spawn process)
306        let connect_future = transport.connect();
307        match self.config.connect_timeout {
308            Some(timeout) => {
309                tokio::time::timeout(timeout, connect_future)
310                    .await
311                    .map_err(|_| Error::Timeout {
312                        operation: "connect".into(),
313                    })??;
314            }
315            None => connect_future.await?,
316        }
317
318        // Write prompt to stdin
319        let prompt_text = match &input {
320            Input::Text(s) => s.clone(),
321            Input::Items(items) => serde_json::to_string(items)
322                .map_err(|e| Error::Config(format!("failed to serialize input: {e}")))?,
323        };
324
325        transport.write(&prompt_text).await?;
326        transport.end_input().await?;
327
328        // Build event stream
329        let messages = transport.read_messages();
330        let approval_cb = self.approval_callback.clone();
331        let patch_approval_cb = self.patch_approval_callback.clone();
332        let event_cb = self.event_callback.clone();
333        let hooks = self.hooks.clone();
334        let default_hook_timeout = self.default_hook_timeout;
335        let max_turns = self.max_turns;
336        let max_budget_tokens = self.max_budget_tokens;
337        let transport_clone = transport.clone();
338        let thread_id_slot = self.thread_id.clone();
339
340        let stream = async_stream::stream! {
341            // Keep schema guards alive for the duration of the stream.
342            let _schema_guard = schema_guard;
343            let _thread_schema_guard = thread_schema_guard;
344            // Resets turn_in_progress on drop, even if stream is dropped early.
345            let _turn_guard = turn_guard;
346
347            // Helper: read the current thread ID without holding the lock.
348            let get_thread_id = || {
349                thread_id_slot
350                    .lock()
351                    .unwrap_or_else(|e| e.into_inner())
352                    .clone()
353            };
354
355            // Budget tracking state.
356            let mut turn_count: u32 = 0;
357            let mut total_output_tokens: u64 = 0;
358
359            tokio::pin!(messages);
360
361            while let Some(result) = messages.next().await {
362                match result {
363                    Ok(value) => {
364                        let event = match serde_json::from_value::<ThreadEvent>(value.clone()) {
365                            Ok(e) => e,
366                            Err(e) => {
367                                tracing::warn!("Skipping unrecognized event: {e} — raw: {value}");
368                                continue;
369                            }
370                        };
371
372                        // Capture thread ID from ThreadStarted.
373                        if let ThreadEvent::ThreadStarted { ref thread_id } = event {
374                            *thread_id_slot
375                                .lock()
376                                .unwrap_or_else(|e| e.into_inner()) = Some(thread_id.clone());
377                        }
378
379                        // Handle exec approval requests.
380                        if let ThreadEvent::ApprovalRequest(ref req) = event {
381                            let outcome = if let Some(ref cb) = approval_cb {
382                                let ctx = ApprovalContext {
383                                    request: req.clone(),
384                                    thread_id: get_thread_id(),
385                                };
386                                cb(ctx).await
387                            } else {
388                                crate::permissions::ApprovalDecision::Denied.into()
389                            };
390                            let response = ApprovalResponse::new(req.id.clone(), outcome.decision);
391                            if let Err(e) = write_response(&response, &*transport_clone).await {
392                                yield Err(e);
393                                break;
394                            }
395                        }
396
397                        // Handle patch approval requests.
398                        if let ThreadEvent::PatchApprovalRequest(ref req) = event {
399                            let outcome = if let Some(ref cb) = patch_approval_cb {
400                                let ctx = PatchApprovalContext {
401                                    request: req.clone(),
402                                    thread_id: get_thread_id(),
403                                };
404                                cb(ctx).await
405                            } else {
406                                crate::permissions::ApprovalDecision::Denied.into()
407                            };
408                            let response = PatchApprovalResponse::new(req.id.clone(), outcome.decision);
409                            if let Err(e) = write_response(&response, &*transport_clone).await {
410                                yield Err(e);
411                                break;
412                            }
413                        }
414
415                        // ── Hook dispatch ──────────────────────────────────
416                        let event = if !hooks.is_empty() {
417                            let hook_ctx = HookContext {
418                                thread_id: get_thread_id(),
419                                turn_count,
420                            };
421
422                            match hooks::dispatch_hook(&event, &hooks, &hook_ctx, default_hook_timeout).await {
423                                Some(output) => match output.decision {
424                                    HookDecision::Allow => event,
425                                    HookDecision::Block => continue,
426                                    HookDecision::Modify => {
427                                        output.replacement_event.unwrap_or(event)
428                                    }
429                                    HookDecision::Abort => {
430                                        tracing::info!("Hook aborted stream: {:?}", output.reason);
431                                        break;
432                                    }
433                                },
434                                None => event,
435                            }
436                        } else {
437                            event
438                        };
439
440                        // ── Budget enforcement ─────────────────────────────
441                        if let ThreadEvent::TurnCompleted { ref usage } = event {
442                            turn_count += 1;
443                            total_output_tokens += usage.output_tokens;
444
445                            // Yield the TurnCompleted event first, then check limits.
446                            let event = match crate::callback::apply_callback(event, event_cb.as_ref()) {
447                                Some(e) => e,
448                                None => continue,
449                            };
450                            yield Ok(event);
451
452                            if let Some(limit) = max_turns {
453                                if turn_count >= limit {
454                                    tracing::info!("max_turns reached ({turn_count}/{limit}), closing stream");
455                                    break;
456                                }
457                            }
458                            if let Some(budget) = max_budget_tokens {
459                                if total_output_tokens >= budget {
460                                    tracing::info!(
461                                        "max_budget_tokens reached ({total_output_tokens}/{budget}), closing stream"
462                                    );
463                                    break;
464                                }
465                            }
466                            continue;
467                        }
468
469                        // ── Event callback ─────────────────────────────────
470                        let event = match crate::callback::apply_callback(event, event_cb.as_ref()) {
471                            Some(e) => e,
472                            None => continue,
473                        };
474                        yield Ok(event);
475                    }
476                    Err(e) => {
477                        let is_fatal = !matches!(&e, Error::Json(_));
478                        yield Err(e);
479                        if is_fatal {
480                            break;
481                        }
482                    }
483                }
484            }
485
486            match transport_clone.close().await {
487                Ok(Some(code)) if code != 0 => {
488                    yield Err(Error::ProcessExited {
489                        code,
490                        stderr: transport_clone.collected_stderr(),
491                    });
492                }
493                Err(e) => {
494                    yield Err(e);
495                }
496                _ => {}
497            }
498        };
499
500        Ok(StreamedTurn::new(stream))
501    }
502}
503
504// ── Private helpers ────────────────────────────────────────────
505
506/// Serialize `response` to JSON and write it to `transport`.
507async fn write_response<R: serde::Serialize>(
508    response: &R,
509    transport: &dyn crate::transport::Transport,
510) -> Result<()> {
511    let json = serde_json::to_string(response).map_err(Error::Json)?;
512    transport.write(&json).await
513}
514
// ── Output-schema helpers ──────────────────────────────────────
516
517/// Resolve output schema args and temp-file guards.
518///
519/// Turn-level schema takes priority over the thread-level schema.
520/// Returns `(additional_cli_args, turn_guard, thread_guard)`.
521fn resolve_output_schema(
522    turn_schema: Option<&Value>,
523    thread_schema: &Option<OutputSchema>,
524) -> Result<(Vec<String>, OutputSchemaFile, Option<OutputSchemaFile>)> {
525    let turn_guard = OutputSchemaFile::new(turn_schema)?;
526
527    if let Some(path) = turn_guard.path() {
528        // Turn-level inline schema wins — no thread-level guard needed.
529        let args = vec!["--output-schema".into(), path.display().to_string()];
530        return Ok((args, turn_guard, None));
531    }
532
533    // No turn-level schema; fall back to thread-level.
534    match thread_schema {
535        Some(OutputSchema::File(path)) => {
536            let args = vec!["--output-schema".into(), path.display().to_string()];
537            Ok((args, turn_guard, None))
538        }
539        Some(OutputSchema::Inline(value)) => {
540            let thread_guard = OutputSchemaFile::new(Some(value))?;
541            let args = thread_guard
542                .path()
543                .map(|p| vec!["--output-schema".into(), p.display().to_string()])
544                .unwrap_or_default();
545            Ok((args, turn_guard, Some(thread_guard)))
546        }
547        None => Ok((vec![], turn_guard, None)),
548    }
549}
550
551// ── Tests ───────────────────────────────────────────────────────
552
553#[cfg(test)]
554mod tests {
555    use super::*;
556    use crate::testing::builders;
557    use crate::testing::mock_transport::MockTransport;
558    use tokio_stream::StreamExt;
559
560    fn make_thread_with_mock(mock: Arc<MockTransport>) -> Thread {
561        let mut thread = Thread::new(
562            std::path::PathBuf::from("/nonexistent/codex"),
563            CodexConfig::default(),
564            ThreadOptions::default(),
565            None,
566        );
567        thread.transport_override = Some(mock as Arc<dyn Transport>);
568        thread
569    }
570
571    #[tokio::test]
572    async fn test_transport_override_basic_turn() {
573        let mock = Arc::new(MockTransport::new());
574        mock.enqueue_session("thread-1");
575        mock.enqueue_turn_complete("Hello from mock!");
576
577        let mut thread = make_thread_with_mock(Arc::clone(&mock));
578        let turn = thread
579            .run("say hello", TurnOptions::default())
580            .await
581            .unwrap();
582
583        assert_eq!(turn.final_response, "Hello from mock!");
584        assert!(turn.usage.is_some());
585        // Thread ID was captured from ThreadStarted event
586        assert_eq!(thread.id(), Some("thread-1".to_string()));
587    }
588
589    #[tokio::test]
590    async fn test_turn_guard_resets_on_drop() {
591        let mock = Arc::new(MockTransport::new());
592        mock.enqueue_session("thread-2");
593        mock.enqueue_turn_complete("first");
594
595        let mut thread = make_thread_with_mock(Arc::clone(&mock));
596
597        // First turn completes normally
598        thread.run("first", TurnOptions::default()).await.unwrap();
599
600        // After completion, turn_in_progress should be false so we can run again.
601        // Need a new mock since the old one's receiver is consumed.
602        let mock2 = Arc::new(MockTransport::new());
603        mock2.enqueue_session("thread-2");
604        mock2.enqueue_turn_complete("second");
605        thread.transport_override = Some(mock2 as Arc<dyn Transport>);
606
607        let result = thread.run("second", TurnOptions::default()).await;
608        assert!(
609            result.is_ok(),
610            "Second turn should succeed after first completes"
611        );
612    }
613
614    #[tokio::test]
615    async fn test_turn_guard_resets_on_stream_drop() {
616        let mock = Arc::new(MockTransport::new());
617        mock.enqueue_session("thread-3");
618        mock.enqueue_turn_complete("data");
619
620        let mut thread = make_thread_with_mock(Arc::clone(&mock));
621
622        // Start streaming but drop it immediately
623        {
624            let _stream = thread
625                .run_streamed("prompt", TurnOptions::default())
626                .await
627                .unwrap();
628            // _stream dropped here — TurnGuard must reset turn_in_progress
629        }
630
631        // Spin briefly to allow the async drop to propagate (TurnGuard is sync drop, so immediate)
632        assert!(
633            !thread.turn_in_progress.load(Ordering::Acquire),
634            "turn_in_progress should be false after stream drop"
635        );
636
637        // Should be able to run a new turn
638        let mock2 = Arc::new(MockTransport::new());
639        mock2.enqueue_session("thread-3");
640        mock2.enqueue_turn_complete("ok");
641        thread.transport_override = Some(mock2 as Arc<dyn Transport>);
642
643        let result = thread.run("next", TurnOptions::default()).await;
644        assert!(result.is_ok());
645    }
646
647    #[tokio::test]
648    async fn test_approval_with_mock_transport() {
649        use crate::permissions::{ApprovalCallback, ApprovalDecision};
650
651        let mock = Arc::new(MockTransport::new());
652        mock.enqueue_session("thread-4");
653        mock.enqueue_event(builders::approval_request("ap-1", "ls"));
654        mock.enqueue_turn_complete("done");
655
656        let mut thread = make_thread_with_mock(Arc::clone(&mock));
657
658        let callback: ApprovalCallback =
659            Arc::new(|_ctx| Box::pin(async { ApprovalDecision::Approved.into() }));
660        thread.approval_callback = Some(callback);
661
662        let turn = thread.run("do it", TurnOptions::default()).await.unwrap();
663
664        // The approval response should have been written to the mock
665        let written = mock.written_lines();
666        assert!(!written.is_empty(), "approval response should be written");
667        // Written items: prompt + approval response
668        assert!(
669            written.iter().any(|s| s.contains("ap-1")),
670            "approval id should appear in response"
671        );
672        assert_eq!(turn.final_response, "done");
673    }
674
675    #[tokio::test]
676    async fn test_run_returns_error_on_turn_failed() {
677        let mock = Arc::new(MockTransport::new());
678        mock.enqueue_session("thread-err-1");
679        mock.enqueue_event(builders::turn_failed("model overloaded"));
680
681        let mut thread = make_thread_with_mock(Arc::clone(&mock));
682        let result = thread.run("prompt", TurnOptions::default()).await;
683
684        assert!(result.is_err(), "run() should return Err on turn.failed");
685        let err = result.unwrap_err();
686        assert!(
687            err.to_string().contains("model overloaded"),
688            "error should contain the failure message, got: {err}"
689        );
690    }
691
692    #[tokio::test]
693    async fn test_run_returns_error_on_error_event() {
694        let mock = Arc::new(MockTransport::new());
695        mock.enqueue_session("thread-err-2");
696        mock.enqueue_event(builders::error("something broke"));
697
698        let mut thread = make_thread_with_mock(Arc::clone(&mock));
699        let result = thread.run("prompt", TurnOptions::default()).await;
700
701        assert!(result.is_err(), "run() should return Err on error event");
702        let err = result.unwrap_err();
703        assert!(
704            err.to_string().contains("something broke"),
705            "error should contain the message, got: {err}"
706        );
707    }
708
709    #[tokio::test]
710    async fn test_nonzero_exit_code_surfaces() {
711        let mock = Arc::new(MockTransport::new());
712        mock.enqueue_session("thread-exit");
713        mock.enqueue_turn_complete("partial");
714        mock.set_exit_code(1);
715
716        let mut thread = make_thread_with_mock(Arc::clone(&mock));
717        let mut streamed = thread
718            .run_streamed("prompt", TurnOptions::default())
719            .await
720            .unwrap();
721
722        let mut saw_exit_error = false;
723        while let Some(event) = streamed.next().await {
724            if let Err(crate::Error::ProcessExited { code, .. }) = &event {
725                if *code == 1 {
726                    saw_exit_error = true;
727                }
728            }
729        }
730        assert!(
731            saw_exit_error,
732            "stream should yield ProcessExited error for non-zero exit code"
733        );
734    }
735
736    #[tokio::test]
737    async fn test_read_messages_already_consumed() {
738        let mock = MockTransport::new();
739        mock.enqueue_event(serde_json::json!({"type": "turn.started"}));
740        mock.connect().await.unwrap();
741
742        // First call takes the receiver
743        let mut first = mock.read_messages();
744        let _ = first.next().await;
745
746        // Second call should yield TransportClosed error
747        let mut second = mock.read_messages();
748        let result = second.next().await;
749        assert!(result.is_some());
750        let err = result.unwrap();
751        assert!(matches!(err, Err(crate::Error::TransportClosed)));
752    }
753
754    #[tokio::test]
755    async fn test_max_turns_enforced() {
756        let mock = Arc::new(MockTransport::new());
757        mock.enqueue_session("thread-budget");
758        // Enqueue 3 turn completions
759        mock.enqueue_turn_complete("response-1");
760        // For the second turn, manually enqueue turn bookends + message
761        mock.enqueue_event(builders::turn_started());
762        mock.enqueue_event(builders::agent_message_completed("msg-2", "response-2"));
763        mock.enqueue_event(builders::turn_completed(50, 0, 25));
764        // Third turn (should not be reached)
765        mock.enqueue_event(builders::turn_started());
766        mock.enqueue_event(builders::agent_message_completed("msg-3", "response-3"));
767        mock.enqueue_event(builders::turn_completed(50, 0, 25));
768
769        let mut thread = Thread::new(
770            std::path::PathBuf::from("/nonexistent/codex"),
771            CodexConfig::default(),
772            ThreadOptions::builder().max_turns(2u32).build(),
773            None,
774        );
775        thread.transport_override = Some(mock as Arc<dyn Transport>);
776
777        let mut streamed = thread
778            .run_streamed("prompt", TurnOptions::default())
779            .await
780            .unwrap();
781
782        let mut turn_completions = 0;
783        while let Some(event) = streamed.next().await {
784            if let Ok(ThreadEvent::TurnCompleted { .. }) = event {
785                turn_completions += 1;
786            }
787        }
788
789        assert_eq!(turn_completions, 2, "stream should close after max_turns=2");
790    }
791
792    #[tokio::test]
793    async fn test_max_budget_tokens_enforced() {
794        let mock = Arc::new(MockTransport::new());
795        mock.enqueue_session("thread-budget-tok");
796        // First turn: 500 output tokens
797        mock.enqueue_event(builders::agent_message_completed("msg-1", "response"));
798        mock.enqueue_event(builders::turn_completed(100, 0, 500));
799        // Second turn: 600 output tokens (total 1100, exceeds budget of 1000)
800        mock.enqueue_event(builders::turn_started());
801        mock.enqueue_event(builders::agent_message_completed("msg-2", "response-2"));
802        mock.enqueue_event(builders::turn_completed(100, 0, 600));
803        // Third turn (should not be reached)
804        mock.enqueue_event(builders::turn_started());
805        mock.enqueue_event(builders::agent_message_completed("msg-3", "response-3"));
806        mock.enqueue_event(builders::turn_completed(100, 0, 100));
807
808        let mut thread = Thread::new(
809            std::path::PathBuf::from("/nonexistent/codex"),
810            CodexConfig::default(),
811            ThreadOptions::builder().max_budget_tokens(1000u64).build(),
812            None,
813        );
814        thread.transport_override = Some(mock as Arc<dyn Transport>);
815
816        let mut streamed = thread
817            .run_streamed("prompt", TurnOptions::default())
818            .await
819            .unwrap();
820
821        let mut turn_completions = 0;
822        while let Some(event) = streamed.next().await {
823            if let Ok(ThreadEvent::TurnCompleted { .. }) = event {
824                turn_completions += 1;
825            }
826        }
827
828        assert_eq!(
829            turn_completions, 2,
830            "stream should close after exceeding budget on turn 2"
831        );
832    }
833
834    #[tokio::test]
835    async fn test_hook_blocks_event() {
836        use crate::hooks::{HookDecision, HookEvent, HookMatcher, HookOutput};
837
838        let mock = Arc::new(MockTransport::new());
839        mock.enqueue_session("thread-hook");
840        mock.enqueue_event(builders::command_started("cmd-1", "rm -rf /"));
841        mock.enqueue_turn_complete("done");
842
843        let hook = HookMatcher {
844            event: HookEvent::CommandStarted,
845            command_filter: Some("rm".into()),
846            callback: Arc::new(|_input, _ctx| {
847                Box::pin(async {
848                    HookOutput {
849                        decision: HookDecision::Block,
850                        reason: Some("blocked rm".into()),
851                        replacement_event: None,
852                    }
853                })
854            }),
855            timeout: None,
856            on_timeout: Default::default(),
857        };
858
859        let mut thread = Thread::new(
860            std::path::PathBuf::from("/nonexistent/codex"),
861            CodexConfig::default(),
862            ThreadOptions::builder().hooks(vec![hook]).build(),
863            None,
864        );
865        thread.transport_override = Some(mock as Arc<dyn Transport>);
866
867        let mut streamed = thread
868            .run_streamed("prompt", TurnOptions::default())
869            .await
870            .unwrap();
871
872        let mut saw_command_started = false;
873        while let Some(event) = streamed.next().await {
874            if let Ok(ThreadEvent::ItemStarted {
875                item: crate::types::items::ThreadItem::CommandExecution { .. },
876            }) = event
877            {
878                saw_command_started = true;
879            }
880        }
881
882        assert!(
883            !saw_command_started,
884            "command started event should be blocked by hook"
885        );
886    }
887
888    #[tokio::test]
889    async fn test_hooks_persist_across_turns() {
890        use crate::hooks::{HookEvent, HookMatcher, HookOutput};
891        use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
892
893        let call_count = Arc::new(AtomicUsize::new(0));
894        let call_count_clone = Arc::clone(&call_count);
895
896        let hook = HookMatcher {
897            event: HookEvent::TurnCompleted,
898            command_filter: None,
899            callback: Arc::new(move |_input, _ctx| {
900                let c = Arc::clone(&call_count_clone);
901                Box::pin(async move {
902                    c.fetch_add(1, AtomicOrdering::Relaxed);
903                    HookOutput::default()
904                })
905            }),
906            timeout: None,
907            on_timeout: crate::hooks::HookTimeoutBehavior::FailOpen,
908        };
909
910        let mut thread = Thread::new(
911            std::path::PathBuf::from("/nonexistent/codex"),
912            CodexConfig::default(),
913            ThreadOptions::builder().hooks(vec![hook]).build(),
914            None,
915        );
916
917        // First turn
918        let mock1 = Arc::new(MockTransport::new());
919        mock1.enqueue_session("thread-persist-hooks");
920        mock1.enqueue_turn_complete("first");
921        thread.transport_override = Some(mock1 as Arc<dyn Transport>);
922        thread.run("first", TurnOptions::default()).await.unwrap();
923
924        // Second turn — hooks must still fire
925        let mock2 = Arc::new(MockTransport::new());
926        mock2.enqueue_session("thread-persist-hooks");
927        mock2.enqueue_turn_complete("second");
928        thread.transport_override = Some(mock2 as Arc<dyn Transport>);
929        thread.run("second", TurnOptions::default()).await.unwrap();
930
931        assert_eq!(
932            call_count.load(AtomicOrdering::Relaxed),
933            2,
934            "hook should fire on both turns, not just the first"
935        );
936    }
937
938    #[tokio::test]
939    async fn test_thread_interrupt_delegates_to_transport() {
940        use tokio::sync::Barrier;
941
942        let mock = Arc::new(MockTransport::new());
943        mock.enqueue_session("thread-interrupt-1");
944
945        // Enqueue a slow turn: turn.started is already in from enqueue_session,
946        // but don't enqueue turn.completed — stream stays open long enough to interrupt.
947        // We just need the stream alive; the mock will close naturally.
948        mock.enqueue_turn_complete("done");
949
950        let mock_for_assert = Arc::clone(&mock);
951        let barrier = Arc::new(Barrier::new(2));
952        let barrier2 = Arc::clone(&barrier);
953
954        let mut thread = make_thread_with_mock(Arc::clone(&mock));
955        let mut streamed = thread
956            .run_streamed("prompt", TurnOptions::default())
957            .await
958            .unwrap();
959
960        // Spawn a task: wait for barrier then interrupt
961        let thread_ref = &thread;
962        // We call interrupt right away — the transport is registered before connect
963        thread_ref.interrupt().await.unwrap();
964
965        // Drain the stream
966        while let Some(_) = streamed.next().await {}
967
968        assert!(
969            mock_for_assert.interrupt_called(),
970            "interrupt() should have been delegated to the mock transport"
971        );
972        let _ = barrier2; // silence unused warning
973    }
974
975    #[tokio::test]
976    async fn test_thread_interrupt_noop_when_idle() {
977        let thread = Thread::new(
978            std::path::PathBuf::from("/nonexistent/codex"),
979            CodexConfig::default(),
980            ThreadOptions::default(),
981            None,
982        );
983        // No active turn — interrupt should be a no-op returning Ok(())
984        let result = thread.interrupt().await;
985        assert!(
986            result.is_ok(),
987            "interrupt with no active turn should return Ok"
988        );
989    }
990
991    #[tokio::test]
992    async fn test_active_transport_cleared_after_turn() {
993        let mock = Arc::new(MockTransport::new());
994        mock.enqueue_session("thread-clear");
995        mock.enqueue_turn_complete("done");
996
997        let mut thread = make_thread_with_mock(Arc::clone(&mock));
998
999        // Run a complete turn
1000        thread.run("prompt", TurnOptions::default()).await.unwrap();
1001
1002        // After the turn, active_transport slot should be cleared by TurnGuard drop.
1003        // Calling interrupt() should return Ok(()) without calling into mock.
1004        let result = thread.interrupt().await;
1005        assert!(result.is_ok());
1006        assert!(
1007            !mock.interrupt_called(),
1008            "interrupt_called should be false — slot was cleared after turn completed"
1009        );
1010    }
1011
1012    #[tokio::test]
1013    async fn test_hook_aborts_stream() {
1014        use crate::hooks::{HookDecision, HookEvent, HookMatcher, HookOutput};
1015
1016        let mock = Arc::new(MockTransport::new());
1017        mock.enqueue_session("thread-abort");
1018        mock.enqueue_event(builders::command_started("cmd-1", "dangerous"));
1019        // These events should never be reached
1020        mock.enqueue_turn_complete("should not see this");
1021
1022        let hook = HookMatcher {
1023            event: HookEvent::CommandStarted,
1024            command_filter: None,
1025            callback: Arc::new(|_input, _ctx| {
1026                Box::pin(async {
1027                    HookOutput {
1028                        decision: HookDecision::Abort,
1029                        reason: Some("abort!".into()),
1030                        replacement_event: None,
1031                    }
1032                })
1033            }),
1034            timeout: None,
1035            on_timeout: Default::default(),
1036        };
1037
1038        let mut thread = Thread::new(
1039            std::path::PathBuf::from("/nonexistent/codex"),
1040            CodexConfig::default(),
1041            ThreadOptions::builder().hooks(vec![hook]).build(),
1042            None,
1043        );
1044        thread.transport_override = Some(mock as Arc<dyn Transport>);
1045
1046        let mut streamed = thread
1047            .run_streamed("prompt", TurnOptions::default())
1048            .await
1049            .unwrap();
1050
1051        let mut events = vec![];
1052        while let Some(event) = streamed.next().await {
1053            if let Ok(ref e) = event {
1054                events.push(e.clone());
1055            }
1056        }
1057
1058        // Should have ThreadStarted + TurnStarted, then abort before anything else
1059        assert!(
1060            !events
1061                .iter()
1062                .any(|e| matches!(e, ThreadEvent::TurnCompleted { .. })),
1063            "TurnCompleted should not appear — stream was aborted"
1064        );
1065    }
1066}