1use crate::callback::EventCallback;
2use crate::config::{CodexConfig, OutputSchema, OutputSchemaFile, ThreadOptions, TurnOptions};
3use crate::discovery;
4use crate::errors::{Error, Result};
5use crate::hooks::{self, HookContext, HookDecision, HookMatcher};
6use crate::permissions::{
7 ApprovalCallback, ApprovalContext, ApprovalResponse, PatchApprovalCallback,
8 PatchApprovalContext, PatchApprovalResponse,
9};
10use crate::transport::{CliTransport, Transport};
11use crate::types::events::{StreamedTurn, ThreadEvent, Turn};
12use crate::types::input::Input;
13
14use serde_json::Value;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::time::Duration;
19use tokio_stream::StreamExt;
20
21struct TurnGuard {
27 flag: Arc<AtomicBool>,
28 active_transport: Arc<std::sync::Mutex<Option<Arc<dyn Transport>>>>,
29}
30
31impl Drop for TurnGuard {
32 fn drop(&mut self) {
33 self.flag.store(false, Ordering::Release);
34 *self
35 .active_transport
36 .lock()
37 .unwrap_or_else(|e| e.into_inner()) = None;
38 }
39}
40
41pub struct Codex {
47 config: CodexConfig,
48 cli_path: PathBuf,
49}
50
51impl Codex {
52 pub fn new(config: CodexConfig) -> Result<Self> {
56 let cli_path = match &config.cli_path {
57 Some(path) => path.clone(),
58 None => discovery::find_cli()?,
59 };
60 Ok(Self { config, cli_path })
61 }
62
63 pub fn start_thread(&self, options: ThreadOptions) -> Thread {
65 Thread::new(self.cli_path.clone(), self.config.clone(), options, None)
66 }
67
68 pub fn resume_thread(&self, thread_id: impl Into<String>, options: ThreadOptions) -> Thread {
70 Thread::new(
71 self.cli_path.clone(),
72 self.config.clone(),
73 options,
74 Some(thread_id.into()),
75 )
76 }
77
78 pub fn cli_path(&self) -> &std::path::Path {
80 &self.cli_path
81 }
82
83 pub async fn version(&self) -> Result<String> {
85 discovery::check_version(&self.cli_path, self.config.version_check_timeout).await
86 }
87}
88
89pub struct Thread {
95 cli_path: PathBuf,
96 config: CodexConfig,
97 options: ThreadOptions,
98 resume_id: Option<String>,
99 thread_id: Arc<std::sync::Mutex<Option<String>>>,
100 approval_callback: Option<ApprovalCallback>,
101 patch_approval_callback: Option<PatchApprovalCallback>,
102 event_callback: Option<EventCallback>,
103 hooks: Vec<HookMatcher>,
104 default_hook_timeout: Duration,
105 max_turns: Option<u32>,
106 max_budget_tokens: Option<u64>,
107 turn_in_progress: Arc<AtomicBool>,
108 active_transport: Arc<std::sync::Mutex<Option<Arc<dyn Transport>>>>,
109 transport_override: Option<Arc<dyn Transport>>,
110}
111
112impl Thread {
113 fn new(
114 cli_path: PathBuf,
115 config: CodexConfig,
116 mut options: ThreadOptions,
117 resume_id: Option<String>,
118 ) -> Self {
119 let hooks = std::mem::take(&mut options.hooks);
121 let default_hook_timeout = options.default_hook_timeout;
122 let max_turns = options.max_turns;
123 let max_budget_tokens = options.max_budget_tokens;
124
125 Self {
126 cli_path,
127 config,
128 options,
129 resume_id,
130 thread_id: Arc::new(std::sync::Mutex::new(None)),
131 approval_callback: None,
132 patch_approval_callback: None,
133 event_callback: None,
134 hooks,
135 default_hook_timeout,
136 max_turns,
137 max_budget_tokens,
138 turn_in_progress: Arc::new(AtomicBool::new(false)),
139 active_transport: Arc::new(std::sync::Mutex::new(None)),
140 transport_override: None,
141 }
142 }
143
144 pub fn with_approval_callback(mut self, callback: ApprovalCallback) -> Self {
146 self.approval_callback = Some(callback);
147 self
148 }
149
150 pub fn with_patch_approval_callback(mut self, callback: PatchApprovalCallback) -> Self {
152 self.patch_approval_callback = Some(callback);
153 self
154 }
155
156 pub fn with_event_callback(mut self, callback: EventCallback) -> Self {
158 self.event_callback = Some(callback);
159 self
160 }
161
162 pub fn with_hooks(mut self, hooks: Vec<HookMatcher>) -> Self {
164 self.hooks = hooks;
165 self
166 }
167
168 pub fn with_transport(mut self, transport: Arc<dyn Transport>) -> Self {
174 self.transport_override = Some(transport);
175 self
176 }
177
178 pub fn id(&self) -> Option<String> {
180 self.thread_id
181 .lock()
182 .unwrap_or_else(|e| e.into_inner())
183 .clone()
184 .or_else(|| self.resume_id.clone())
185 }
186
187 pub async fn interrupt(&self) -> Result<()> {
192 let transport = self
193 .active_transport
194 .lock()
195 .unwrap_or_else(|e| e.into_inner())
196 .clone();
197 if let Some(t) = transport {
198 t.interrupt().await?;
199 }
200 Ok(())
201 }
202
203 pub async fn run(
205 &mut self,
206 input: impl Into<Input>,
207 turn_options: TurnOptions,
208 ) -> Result<Turn> {
209 let mut streamed = self.run_streamed(input, turn_options).await?;
210 let mut events = Vec::new();
211 let mut final_response = String::new();
212 let mut usage = None;
213
214 while let Some(event) = streamed.next().await {
215 let event = event?;
216 match &event {
217 ThreadEvent::ItemCompleted {
218 item: crate::types::items::ThreadItem::AgentMessage { text, .. },
219 } => {
220 final_response = text.clone();
221 }
222 ThreadEvent::TurnCompleted { usage: u } => {
223 usage = Some(u.clone());
224 }
225 ThreadEvent::TurnFailed { error } => {
226 let msg = error.message.clone();
227 events.push(event);
228 return Err(Error::Other(msg));
229 }
230 ThreadEvent::Error { message } => {
231 let msg = message.clone();
232 events.push(event);
233 return Err(Error::Other(msg));
234 }
235 _ => {}
236 }
237 events.push(event);
238 }
239
240 Ok(Turn {
241 events,
242 final_response,
243 usage,
244 })
245 }
246
247 pub async fn run_streamed(
249 &mut self,
250 input: impl Into<Input>,
251 turn_options: TurnOptions,
252 ) -> Result<StreamedTurn> {
253 if self
255 .turn_in_progress
256 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
257 .is_err()
258 {
259 return Err(Error::ConcurrentTurn);
260 }
261
262 let input = input.into();
263
264 let mut args = self.options.to_cli_args();
266 self.config.apply_overrides(&mut args);
267
268 let (schema_args, schema_guard, thread_schema_guard) = resolve_output_schema(
270 turn_options.output_schema.as_ref(),
271 &self.options.output_schema,
272 )?;
273 args.extend(schema_args);
274
275 if let Some(ref resume_id) = self.resume_id {
277 args.push("resume".into());
278 args.push(resume_id.clone());
279 }
280
281 let transport: Arc<dyn Transport> = match &self.transport_override {
283 Some(t) => Arc::clone(t),
284 None => Arc::new(CliTransport::new(
285 self.cli_path.clone(),
286 args,
287 self.config.to_env(),
288 self.config.stderr_callback.clone(),
289 turn_options.cancel.clone(),
290 self.config.close_timeout,
291 )),
292 };
293
294 *self
296 .active_transport
297 .lock()
298 .unwrap_or_else(|e| e.into_inner()) = Some(Arc::clone(&transport));
299 let active_transport_slot = Arc::clone(&self.active_transport);
300 let turn_guard = TurnGuard {
301 flag: self.turn_in_progress.clone(),
302 active_transport: active_transport_slot,
303 };
304
305 let connect_future = transport.connect();
307 match self.config.connect_timeout {
308 Some(timeout) => {
309 tokio::time::timeout(timeout, connect_future)
310 .await
311 .map_err(|_| Error::Timeout {
312 operation: "connect".into(),
313 })??;
314 }
315 None => connect_future.await?,
316 }
317
318 let prompt_text = match &input {
320 Input::Text(s) => s.clone(),
321 Input::Items(items) => serde_json::to_string(items)
322 .map_err(|e| Error::Config(format!("failed to serialize input: {e}")))?,
323 };
324
325 transport.write(&prompt_text).await?;
326 transport.end_input().await?;
327
328 let messages = transport.read_messages();
330 let approval_cb = self.approval_callback.clone();
331 let patch_approval_cb = self.patch_approval_callback.clone();
332 let event_cb = self.event_callback.clone();
333 let hooks = self.hooks.clone();
334 let default_hook_timeout = self.default_hook_timeout;
335 let max_turns = self.max_turns;
336 let max_budget_tokens = self.max_budget_tokens;
337 let transport_clone = transport.clone();
338 let thread_id_slot = self.thread_id.clone();
339
340 let stream = async_stream::stream! {
341 let _schema_guard = schema_guard;
343 let _thread_schema_guard = thread_schema_guard;
344 let _turn_guard = turn_guard;
346
347 let get_thread_id = || {
349 thread_id_slot
350 .lock()
351 .unwrap_or_else(|e| e.into_inner())
352 .clone()
353 };
354
355 let mut turn_count: u32 = 0;
357 let mut total_output_tokens: u64 = 0;
358
359 tokio::pin!(messages);
360
361 while let Some(result) = messages.next().await {
362 match result {
363 Ok(value) => {
364 let event = match serde_json::from_value::<ThreadEvent>(value.clone()) {
365 Ok(e) => e,
366 Err(e) => {
367 tracing::warn!("Skipping unrecognized event: {e} — raw: {value}");
368 continue;
369 }
370 };
371
372 if let ThreadEvent::ThreadStarted { ref thread_id } = event {
374 *thread_id_slot
375 .lock()
376 .unwrap_or_else(|e| e.into_inner()) = Some(thread_id.clone());
377 }
378
379 if let ThreadEvent::ApprovalRequest(ref req) = event {
381 let outcome = if let Some(ref cb) = approval_cb {
382 let ctx = ApprovalContext {
383 request: req.clone(),
384 thread_id: get_thread_id(),
385 };
386 cb(ctx).await
387 } else {
388 crate::permissions::ApprovalDecision::Denied.into()
389 };
390 let response = ApprovalResponse::new(req.id.clone(), outcome.decision);
391 if let Err(e) = write_response(&response, &*transport_clone).await {
392 yield Err(e);
393 break;
394 }
395 }
396
397 if let ThreadEvent::PatchApprovalRequest(ref req) = event {
399 let outcome = if let Some(ref cb) = patch_approval_cb {
400 let ctx = PatchApprovalContext {
401 request: req.clone(),
402 thread_id: get_thread_id(),
403 };
404 cb(ctx).await
405 } else {
406 crate::permissions::ApprovalDecision::Denied.into()
407 };
408 let response = PatchApprovalResponse::new(req.id.clone(), outcome.decision);
409 if let Err(e) = write_response(&response, &*transport_clone).await {
410 yield Err(e);
411 break;
412 }
413 }
414
415 let event = if !hooks.is_empty() {
417 let hook_ctx = HookContext {
418 thread_id: get_thread_id(),
419 turn_count,
420 };
421
422 match hooks::dispatch_hook(&event, &hooks, &hook_ctx, default_hook_timeout).await {
423 Some(output) => match output.decision {
424 HookDecision::Allow => event,
425 HookDecision::Block => continue,
426 HookDecision::Modify => {
427 output.replacement_event.unwrap_or(event)
428 }
429 HookDecision::Abort => {
430 tracing::info!("Hook aborted stream: {:?}", output.reason);
431 break;
432 }
433 },
434 None => event,
435 }
436 } else {
437 event
438 };
439
440 if let ThreadEvent::TurnCompleted { ref usage } = event {
442 turn_count += 1;
443 total_output_tokens += usage.output_tokens;
444
445 let event = match crate::callback::apply_callback(event, event_cb.as_ref()) {
447 Some(e) => e,
448 None => continue,
449 };
450 yield Ok(event);
451
452 if let Some(limit) = max_turns {
453 if turn_count >= limit {
454 tracing::info!("max_turns reached ({turn_count}/{limit}), closing stream");
455 break;
456 }
457 }
458 if let Some(budget) = max_budget_tokens {
459 if total_output_tokens >= budget {
460 tracing::info!(
461 "max_budget_tokens reached ({total_output_tokens}/{budget}), closing stream"
462 );
463 break;
464 }
465 }
466 continue;
467 }
468
469 let event = match crate::callback::apply_callback(event, event_cb.as_ref()) {
471 Some(e) => e,
472 None => continue,
473 };
474 yield Ok(event);
475 }
476 Err(e) => {
477 let is_fatal = !matches!(&e, Error::Json(_));
478 yield Err(e);
479 if is_fatal {
480 break;
481 }
482 }
483 }
484 }
485
486 match transport_clone.close().await {
487 Ok(Some(code)) if code != 0 => {
488 yield Err(Error::ProcessExited {
489 code,
490 stderr: transport_clone.collected_stderr(),
491 });
492 }
493 Err(e) => {
494 yield Err(e);
495 }
496 _ => {}
497 }
498 };
499
500 Ok(StreamedTurn::new(stream))
501 }
502}
503
504async fn write_response<R: serde::Serialize>(
508 response: &R,
509 transport: &dyn crate::transport::Transport,
510) -> Result<()> {
511 let json = serde_json::to_string(response).map_err(Error::Json)?;
512 transport.write(&json).await
513}
514
515fn resolve_output_schema(
522 turn_schema: Option<&Value>,
523 thread_schema: &Option<OutputSchema>,
524) -> Result<(Vec<String>, OutputSchemaFile, Option<OutputSchemaFile>)> {
525 let turn_guard = OutputSchemaFile::new(turn_schema)?;
526
527 if let Some(path) = turn_guard.path() {
528 let args = vec!["--output-schema".into(), path.display().to_string()];
530 return Ok((args, turn_guard, None));
531 }
532
533 match thread_schema {
535 Some(OutputSchema::File(path)) => {
536 let args = vec!["--output-schema".into(), path.display().to_string()];
537 Ok((args, turn_guard, None))
538 }
539 Some(OutputSchema::Inline(value)) => {
540 let thread_guard = OutputSchemaFile::new(Some(value))?;
541 let args = thread_guard
542 .path()
543 .map(|p| vec!["--output-schema".into(), p.display().to_string()])
544 .unwrap_or_default();
545 Ok((args, turn_guard, Some(thread_guard)))
546 }
547 None => Ok((vec![], turn_guard, None)),
548 }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::builders;
    use crate::testing::mock_transport::MockTransport;
    use tokio_stream::StreamExt;

    /// Builds a default thread whose turns run against `mock` instead of
    /// spawning the CLI.
    fn make_thread_with_mock(mock: Arc<MockTransport>) -> Thread {
        let mut thread = Thread::new(
            std::path::PathBuf::from("/nonexistent/codex"),
            CodexConfig::default(),
            ThreadOptions::default(),
            None,
        );
        thread.transport_override = Some(mock as Arc<dyn Transport>);
        thread
    }

    #[tokio::test]
    async fn test_transport_override_basic_turn() {
        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-1");
        mock.enqueue_turn_complete("Hello from mock!");

        let mut thread = make_thread_with_mock(Arc::clone(&mock));
        let turn = thread
            .run("say hello", TurnOptions::default())
            .await
            .unwrap();

        assert_eq!(turn.final_response, "Hello from mock!");
        assert!(turn.usage.is_some());
        assert_eq!(thread.id(), Some("thread-1".to_string()));
    }

    #[tokio::test]
    async fn test_turn_guard_resets_on_drop() {
        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-2");
        mock.enqueue_turn_complete("first");

        let mut thread = make_thread_with_mock(Arc::clone(&mock));

        thread.run("first", TurnOptions::default()).await.unwrap();

        let mock2 = Arc::new(MockTransport::new());
        mock2.enqueue_session("thread-2");
        mock2.enqueue_turn_complete("second");
        thread.transport_override = Some(mock2 as Arc<dyn Transport>);

        let result = thread.run("second", TurnOptions::default()).await;
        assert!(
            result.is_ok(),
            "Second turn should succeed after first completes"
        );
    }

    #[tokio::test]
    async fn test_turn_guard_resets_on_stream_drop() {
        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-3");
        mock.enqueue_turn_complete("data");

        let mut thread = make_thread_with_mock(Arc::clone(&mock));

        {
            // Dropped without being consumed.
            let _stream = thread
                .run_streamed("prompt", TurnOptions::default())
                .await
                .unwrap();
        }

        assert!(
            !thread.turn_in_progress.load(Ordering::Acquire),
            "turn_in_progress should be false after stream drop"
        );

        let mock2 = Arc::new(MockTransport::new());
        mock2.enqueue_session("thread-3");
        mock2.enqueue_turn_complete("ok");
        thread.transport_override = Some(mock2 as Arc<dyn Transport>);

        let result = thread.run("next", TurnOptions::default()).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_approval_with_mock_transport() {
        use crate::permissions::{ApprovalCallback, ApprovalDecision};

        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-4");
        mock.enqueue_event(builders::approval_request("ap-1", "ls"));
        mock.enqueue_turn_complete("done");

        let mut thread = make_thread_with_mock(Arc::clone(&mock));

        let callback: ApprovalCallback =
            Arc::new(|_ctx| Box::pin(async { ApprovalDecision::Approved.into() }));
        thread.approval_callback = Some(callback);

        let turn = thread.run("do it", TurnOptions::default()).await.unwrap();

        let written = mock.written_lines();
        assert!(!written.is_empty(), "approval response should be written");
        assert!(
            written.iter().any(|s| s.contains("ap-1")),
            "approval id should appear in response"
        );
        assert_eq!(turn.final_response, "done");
    }

    #[tokio::test]
    async fn test_run_returns_error_on_turn_failed() {
        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-err-1");
        mock.enqueue_event(builders::turn_failed("model overloaded"));

        let mut thread = make_thread_with_mock(Arc::clone(&mock));
        let result = thread.run("prompt", TurnOptions::default()).await;

        assert!(result.is_err(), "run() should return Err on turn.failed");
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("model overloaded"),
            "error should contain the failure message, got: {err}"
        );
    }

    #[tokio::test]
    async fn test_run_returns_error_on_error_event() {
        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-err-2");
        mock.enqueue_event(builders::error("something broke"));

        let mut thread = make_thread_with_mock(Arc::clone(&mock));
        let result = thread.run("prompt", TurnOptions::default()).await;

        assert!(result.is_err(), "run() should return Err on error event");
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("something broke"),
            "error should contain the message, got: {err}"
        );
    }

    #[tokio::test]
    async fn test_nonzero_exit_code_surfaces() {
        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-exit");
        mock.enqueue_turn_complete("partial");
        mock.set_exit_code(1);

        let mut thread = make_thread_with_mock(Arc::clone(&mock));
        let mut streamed = thread
            .run_streamed("prompt", TurnOptions::default())
            .await
            .unwrap();

        let mut saw_exit_error = false;
        while let Some(event) = streamed.next().await {
            if let Err(crate::Error::ProcessExited { code, .. }) = &event {
                if *code == 1 {
                    saw_exit_error = true;
                }
            }
        }
        assert!(
            saw_exit_error,
            "stream should yield ProcessExited error for non-zero exit code"
        );
    }

    #[tokio::test]
    async fn test_read_messages_already_consumed() {
        let mock = MockTransport::new();
        mock.enqueue_event(serde_json::json!({"type": "turn.started"}));
        mock.connect().await.unwrap();

        let mut first = mock.read_messages();
        let _ = first.next().await;

        // A second reader must not silently re-read the same messages.
        let mut second = mock.read_messages();
        let result = second.next().await;
        assert!(result.is_some());
        let err = result.unwrap();
        assert!(matches!(err, Err(crate::Error::TransportClosed)));
    }

    #[tokio::test]
    async fn test_max_turns_enforced() {
        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-budget");
        mock.enqueue_turn_complete("response-1");
        mock.enqueue_event(builders::turn_started());
        mock.enqueue_event(builders::agent_message_completed("msg-2", "response-2"));
        mock.enqueue_event(builders::turn_completed(50, 0, 25));
        mock.enqueue_event(builders::turn_started());
        mock.enqueue_event(builders::agent_message_completed("msg-3", "response-3"));
        mock.enqueue_event(builders::turn_completed(50, 0, 25));

        let mut thread = Thread::new(
            std::path::PathBuf::from("/nonexistent/codex"),
            CodexConfig::default(),
            ThreadOptions::builder().max_turns(2u32).build(),
            None,
        );
        thread.transport_override = Some(mock as Arc<dyn Transport>);

        let mut streamed = thread
            .run_streamed("prompt", TurnOptions::default())
            .await
            .unwrap();

        let mut turn_completions = 0;
        while let Some(event) = streamed.next().await {
            if let Ok(ThreadEvent::TurnCompleted { .. }) = event {
                turn_completions += 1;
            }
        }

        assert_eq!(turn_completions, 2, "stream should close after max_turns=2");
    }

    #[tokio::test]
    async fn test_max_budget_tokens_enforced() {
        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-budget-tok");
        mock.enqueue_event(builders::agent_message_completed("msg-1", "response"));
        mock.enqueue_event(builders::turn_completed(100, 0, 500));
        mock.enqueue_event(builders::turn_started());
        mock.enqueue_event(builders::agent_message_completed("msg-2", "response-2"));
        mock.enqueue_event(builders::turn_completed(100, 0, 600));
        mock.enqueue_event(builders::turn_started());
        mock.enqueue_event(builders::agent_message_completed("msg-3", "response-3"));
        mock.enqueue_event(builders::turn_completed(100, 0, 100));

        let mut thread = Thread::new(
            std::path::PathBuf::from("/nonexistent/codex"),
            CodexConfig::default(),
            ThreadOptions::builder().max_budget_tokens(1000u64).build(),
            None,
        );
        thread.transport_override = Some(mock as Arc<dyn Transport>);

        let mut streamed = thread
            .run_streamed("prompt", TurnOptions::default())
            .await
            .unwrap();

        let mut turn_completions = 0;
        while let Some(event) = streamed.next().await {
            if let Ok(ThreadEvent::TurnCompleted { .. }) = event {
                turn_completions += 1;
            }
        }

        assert_eq!(
            turn_completions, 2,
            "stream should close after exceeding budget on turn 2"
        );
    }

    #[tokio::test]
    async fn test_hook_blocks_event() {
        use crate::hooks::{HookDecision, HookEvent, HookMatcher, HookOutput};

        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-hook");
        mock.enqueue_event(builders::command_started("cmd-1", "rm -rf /"));
        mock.enqueue_turn_complete("done");

        let hook = HookMatcher {
            event: HookEvent::CommandStarted,
            command_filter: Some("rm".into()),
            callback: Arc::new(|_input, _ctx| {
                Box::pin(async {
                    HookOutput {
                        decision: HookDecision::Block,
                        reason: Some("blocked rm".into()),
                        replacement_event: None,
                    }
                })
            }),
            timeout: None,
            on_timeout: Default::default(),
        };

        let mut thread = Thread::new(
            std::path::PathBuf::from("/nonexistent/codex"),
            CodexConfig::default(),
            ThreadOptions::builder().hooks(vec![hook]).build(),
            None,
        );
        thread.transport_override = Some(mock as Arc<dyn Transport>);

        let mut streamed = thread
            .run_streamed("prompt", TurnOptions::default())
            .await
            .unwrap();

        let mut saw_command_started = false;
        while let Some(event) = streamed.next().await {
            if let Ok(ThreadEvent::ItemStarted {
                item: crate::types::items::ThreadItem::CommandExecution { .. },
            }) = event
            {
                saw_command_started = true;
            }
        }

        assert!(
            !saw_command_started,
            "command started event should be blocked by hook"
        );
    }

    #[tokio::test]
    async fn test_hooks_persist_across_turns() {
        use crate::hooks::{HookEvent, HookMatcher, HookOutput};
        use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};

        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let hook = HookMatcher {
            event: HookEvent::TurnCompleted,
            command_filter: None,
            callback: Arc::new(move |_input, _ctx| {
                let c = Arc::clone(&call_count_clone);
                Box::pin(async move {
                    c.fetch_add(1, AtomicOrdering::Relaxed);
                    HookOutput::default()
                })
            }),
            timeout: None,
            on_timeout: crate::hooks::HookTimeoutBehavior::FailOpen,
        };

        let mut thread = Thread::new(
            std::path::PathBuf::from("/nonexistent/codex"),
            CodexConfig::default(),
            ThreadOptions::builder().hooks(vec![hook]).build(),
            None,
        );

        let mock1 = Arc::new(MockTransport::new());
        mock1.enqueue_session("thread-persist-hooks");
        mock1.enqueue_turn_complete("first");
        thread.transport_override = Some(mock1 as Arc<dyn Transport>);
        thread.run("first", TurnOptions::default()).await.unwrap();

        let mock2 = Arc::new(MockTransport::new());
        mock2.enqueue_session("thread-persist-hooks");
        mock2.enqueue_turn_complete("second");
        thread.transport_override = Some(mock2 as Arc<dyn Transport>);
        thread.run("second", TurnOptions::default()).await.unwrap();

        assert_eq!(
            call_count.load(AtomicOrdering::Relaxed),
            2,
            "hook should fire on both turns, not just the first"
        );
    }

    #[tokio::test]
    async fn test_thread_interrupt_delegates_to_transport() {
        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-interrupt-1");
        mock.enqueue_turn_complete("done");

        let mut thread = make_thread_with_mock(Arc::clone(&mock));
        let mut streamed = thread
            .run_streamed("prompt", TurnOptions::default())
            .await
            .unwrap();

        // The stream is still alive, so its transport is the active one.
        thread.interrupt().await.unwrap();

        while streamed.next().await.is_some() {}

        assert!(
            mock.interrupt_called(),
            "interrupt() should have been delegated to the mock transport"
        );
    }

    #[tokio::test]
    async fn test_thread_interrupt_noop_when_idle() {
        let thread = Thread::new(
            std::path::PathBuf::from("/nonexistent/codex"),
            CodexConfig::default(),
            ThreadOptions::default(),
            None,
        );
        let result = thread.interrupt().await;
        assert!(
            result.is_ok(),
            "interrupt with no active turn should return Ok"
        );
    }

    #[tokio::test]
    async fn test_active_transport_cleared_after_turn() {
        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-clear");
        mock.enqueue_turn_complete("done");

        let mut thread = make_thread_with_mock(Arc::clone(&mock));

        thread.run("prompt", TurnOptions::default()).await.unwrap();

        // The turn is over, so interrupt must not reach the old transport.
        let result = thread.interrupt().await;
        assert!(result.is_ok());
        assert!(
            !mock.interrupt_called(),
            "interrupt_called should be false — slot was cleared after turn completed"
        );
    }

    #[tokio::test]
    async fn test_hook_aborts_stream() {
        use crate::hooks::{HookDecision, HookEvent, HookMatcher, HookOutput};

        let mock = Arc::new(MockTransport::new());
        mock.enqueue_session("thread-abort");
        mock.enqueue_event(builders::command_started("cmd-1", "dangerous"));
        mock.enqueue_turn_complete("should not see this");

        let hook = HookMatcher {
            event: HookEvent::CommandStarted,
            command_filter: None,
            callback: Arc::new(|_input, _ctx| {
                Box::pin(async {
                    HookOutput {
                        decision: HookDecision::Abort,
                        reason: Some("abort!".into()),
                        replacement_event: None,
                    }
                })
            }),
            timeout: None,
            on_timeout: Default::default(),
        };

        let mut thread = Thread::new(
            std::path::PathBuf::from("/nonexistent/codex"),
            CodexConfig::default(),
            ThreadOptions::builder().hooks(vec![hook]).build(),
            None,
        );
        thread.transport_override = Some(mock as Arc<dyn Transport>);

        let mut streamed = thread
            .run_streamed("prompt", TurnOptions::default())
            .await
            .unwrap();

        let mut events = vec![];
        while let Some(event) = streamed.next().await {
            if let Ok(ref e) = event {
                events.push(e.clone());
            }
        }

        assert!(
            !events
                .iter()
                .any(|e| matches!(e, ThreadEvent::TurnCompleted { .. })),
            "TurnCompleted should not appear — stream was aborted"
        );
    }
}