use crate::callback::EventCallback;
use crate::config::{CodexConfig, OutputSchema, OutputSchemaFile, ThreadOptions, TurnOptions};
use crate::discovery;
use crate::errors::{Error, Result};
use crate::hooks::{self, HookContext, HookDecision, HookMatcher};
use crate::permissions::{
ApprovalCallback, ApprovalContext, ApprovalResponse, PatchApprovalCallback,
PatchApprovalContext, PatchApprovalResponse,
};
use crate::transport::{CliTransport, Transport};
use crate::types::events::{StreamedTurn, ThreadEvent, Turn};
use crate::types::input::Input;
use serde_json::Value;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio_stream::StreamExt;
/// Drop guard that marks a turn as finished.
///
/// While a turn is running, `flag` is `true` and `active_transport` holds the
/// transport serving that turn. Dropping the guard resets both, so a later
/// turn can start even if this one ended early (error return, dropped stream).
struct TurnGuard {
    /// Shared "turn in progress" flag; stored back to `false` on drop.
    flag: Arc<AtomicBool>,
    /// Shared slot holding the running turn's transport; set to `None` on drop.
    active_transport: Arc<std::sync::Mutex<Option<Arc<dyn Transport>>>>,
}

impl Drop for TurnGuard {
    fn drop(&mut self) {
        // Clear the transport slot BEFORE releasing the flag. The flag is what
        // `run_streamed` claims with a compare-exchange; if it were released
        // first, a new turn could claim it and publish its own transport into
        // the slot, only for this drop to wipe that fresh entry afterwards.
        //
        // A poisoned mutex is recovered instead of unwrapped so that `drop`
        // never panics.
        *self
            .active_transport
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner()) = None;
        self.flag.store(false, Ordering::Release);
    }
}
/// Client handle used to create [`Thread`]s that talk to a Codex CLI binary.
pub struct Codex {
    /// Configuration cloned into every thread created from this client.
    config: CodexConfig,
    /// CLI executable path: the configured one, or the one found by discovery.
    cli_path: PathBuf,
}

impl Codex {
    /// Creates a client from `config`.
    ///
    /// Uses `config.cli_path` when it is set; otherwise asks
    /// `discovery::find_cli` for a path.
    ///
    /// # Errors
    ///
    /// Propagates the error from `discovery::find_cli` when no path is
    /// configured and the lookup fails.
    pub fn new(config: CodexConfig) -> Result<Self> {
        let cli_path = if let Some(explicit) = &config.cli_path {
            explicit.clone()
        } else {
            discovery::find_cli()?
        };
        Ok(Self { config, cli_path })
    }

    /// Creates a thread with no id to resume.
    pub fn start_thread(&self, options: ThreadOptions) -> Thread {
        self.build_thread(options, None)
    }

    /// Creates a thread that resumes the thread identified by `thread_id`.
    pub fn resume_thread(&self, thread_id: impl Into<String>, options: ThreadOptions) -> Thread {
        let resume_id = thread_id.into();
        self.build_thread(options, Some(resume_id))
    }

    /// Returns the CLI path this client hands to its threads.
    pub fn cli_path(&self) -> &std::path::Path {
        self.cli_path.as_path()
    }

    /// Runs `discovery::check_version` against the CLI path, using the
    /// configured `version_check_timeout`.
    pub async fn version(&self) -> Result<String> {
        let timeout = self.config.version_check_timeout;
        discovery::check_version(&self.cli_path, timeout).await
    }

    /// Shared constructor path for `start_thread` and `resume_thread`.
    fn build_thread(&self, options: ThreadOptions, resume_id: Option<String>) -> Thread {
        let cli_path = self.cli_path.clone();
        let config = self.config.clone();
        Thread::new(cli_path, config, options, resume_id)
    }
}
/// A conversation handle that runs turns through a [`Transport`].
///
/// Created by `Codex::start_thread` / `Codex::resume_thread` and customised
/// with the `with_*` builder methods.
pub struct Thread {
/// Path of the CLI executable handed to `CliTransport::new`.
cli_path: PathBuf,
/// Client configuration (argument overrides, environment, timeouts, stderr callback).
config: CodexConfig,
/// Per-thread options; turned into CLI arguments at the start of each turn.
options: ThreadOptions,
/// Id of an existing thread to resume, or `None` for a thread started fresh.
resume_id: Option<String>,
/// Thread id taken from the `ThreadStarted` event, once one has been seen.
thread_id: Arc<std::sync::Mutex<Option<String>>>,
/// Callback consulted for `ApprovalRequest` events; without one the request is denied.
approval_callback: Option<ApprovalCallback>,
/// Callback consulted for `PatchApprovalRequest` events; without one the request is denied.
patch_approval_callback: Option<PatchApprovalCallback>,
/// Optional callback applied to each event before it is yielded.
event_callback: Option<EventCallback>,
/// Hooks dispatched for each event (taken out of `ThreadOptions::hooks` in `new`).
hooks: Vec<HookMatcher>,
/// Timeout passed to `hooks::dispatch_hook`.
default_hook_timeout: Duration,
/// Close the stream once this many `TurnCompleted` events have been counted.
max_turns: Option<u32>,
/// Close the stream once the summed `output_tokens` reach this value.
max_budget_tokens: Option<u64>,
/// Set while a turn is running; used to reject overlapping turns.
turn_in_progress: Arc<AtomicBool>,
/// Transport of the running turn, if any; this is what `interrupt` targets.
active_transport: Arc<std::sync::Mutex<Option<Arc<dyn Transport>>>>,
/// When set, used instead of constructing a `CliTransport`.
transport_override: Option<Arc<dyn Transport>>,
}
impl Thread {
/// Builds a thread with no callbacks, no transport override and no turn running.
///
/// The hooks, hook timeout and turn/token limits are lifted out of `options`
/// into dedicated fields; `options.hooks` is left empty as a result.
fn new(
    cli_path: PathBuf,
    config: CodexConfig,
    mut options: ThreadOptions,
    resume_id: Option<String>,
) -> Self {
    Self {
        // Everything read from `options` comes first, because `options`
        // itself is moved into the struct right after.
        hooks: std::mem::take(&mut options.hooks),
        default_hook_timeout: options.default_hook_timeout,
        max_turns: options.max_turns,
        max_budget_tokens: options.max_budget_tokens,
        options,
        cli_path,
        config,
        resume_id,
        thread_id: Arc::new(std::sync::Mutex::new(None)),
        approval_callback: None,
        patch_approval_callback: None,
        event_callback: None,
        turn_in_progress: Arc::new(AtomicBool::new(false)),
        active_transport: Arc::new(std::sync::Mutex::new(None)),
        transport_override: None,
    }
}
/// Returns the thread with `callback` installed as the handler for
/// `ApprovalRequest` events, replacing any previous one.
pub fn with_approval_callback(self, callback: ApprovalCallback) -> Self {
    let mut thread = self;
    thread.approval_callback = Some(callback);
    thread
}
/// Returns the thread with `callback` installed as the handler for
/// `PatchApprovalRequest` events, replacing any previous one.
pub fn with_patch_approval_callback(self, callback: PatchApprovalCallback) -> Self {
    let mut thread = self;
    thread.patch_approval_callback = Some(callback);
    thread
}
/// Returns the thread with `callback` installed as the event callback,
/// replacing any previous one.
pub fn with_event_callback(self, callback: EventCallback) -> Self {
    let mut thread = self;
    thread.event_callback = Some(callback);
    thread
}
/// Returns the thread with its hook list replaced by `hooks`.
///
/// Hooks that came in through `ThreadOptions` are discarded, not merged.
pub fn with_hooks(self, hooks: Vec<HookMatcher>) -> Self {
    let mut thread = self;
    thread.hooks = hooks;
    thread
}
/// Returns the thread configured to use `transport` for its turns instead of
/// constructing a `CliTransport`.
pub fn with_transport(self, transport: Arc<dyn Transport>) -> Self {
    let mut thread = self;
    thread.transport_override = Some(transport);
    thread
}
/// Returns the thread id, if one is known.
///
/// Prefers the id recorded from a `ThreadStarted` event and falls back to the
/// id this thread was asked to resume. A poisoned lock is recovered rather
/// than treated as an error.
pub fn id(&self) -> Option<String> {
    let recorded = {
        let slot = self
            .thread_id
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        slot.clone()
    };
    match recorded {
        Some(id) => Some(id),
        None => self.resume_id.clone(),
    }
}
/// Forwards an interrupt to the transport of the running turn.
///
/// Does nothing and returns `Ok(())` when no transport is registered.
///
/// # Errors
///
/// Propagates the error returned by the transport's `interrupt`.
pub async fn interrupt(&self) -> Result<()> {
    // Copy the handle out inside a scope so the std mutex guard is released
    // before the `.await` below.
    let current = {
        let slot = self
            .active_transport
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        slot.clone()
    };
    match current {
        Some(transport) => {
            transport.interrupt().await?;
            Ok(())
        }
        None => Ok(()),
    }
}
pub async fn run(
&mut self,
input: impl Into<Input>,
turn_options: TurnOptions,
) -> Result<Turn> {
let mut streamed = self.run_streamed(input, turn_options).await?;
let mut events = Vec::new();
let mut final_response = String::new();
let mut usage = None;
while let Some(event) = streamed.next().await {
let event = event?;
match &event {
ThreadEvent::ItemCompleted {
item: crate::types::items::ThreadItem::AgentMessage { text, .. },
} => {
final_response = text.clone();
}
ThreadEvent::TurnCompleted { usage: u } => {
usage = Some(u.clone());
}
ThreadEvent::TurnFailed { error } => {
let msg = error.message.clone();
events.push(event);
return Err(Error::Other(msg));
}
ThreadEvent::Error { message } => {
let msg = message.clone();
events.push(event);
return Err(Error::Other(msg));
}
_ => {}
}
events.push(event);
}
Ok(Turn {
events,
final_response,
usage,
})
}
/// Starts a turn and returns its events as a stream.
///
/// Steps performed before the stream is returned:
/// 1. Claim the `turn_in_progress` flag (only one turn may run at a time).
/// 2. Build the CLI arguments (thread options, config overrides, output
///    schema, `resume <id>` when a thread id is known).
/// 3. Pick the transport (override or a new `CliTransport`), publish it for
///    `interrupt`, and connect — bounded by `connect_timeout` when set.
/// 4. Write the prompt and close the input side.
///
/// The returned stream then, for every message read from the transport:
/// records the thread id on `ThreadStarted`, answers approval requests
/// (denying when no callback is installed), dispatches hooks, applies the
/// event callback, and enforces `max_turns` / `max_budget_tokens`. When the
/// loop ends the transport is closed and a non-zero exit code is reported as
/// `Error::ProcessExited`.
///
/// # Errors
///
/// - `Error::ConcurrentTurn` if a turn is already in progress.
/// - Any error from output-schema resolution, connect, write or end-of-input.
/// - `Error::Timeout` if connecting exceeds `connect_timeout`.
/// - `Error::Config` if item input cannot be serialized.
pub async fn run_streamed(
    &mut self,
    input: impl Into<Input>,
    turn_options: TurnOptions,
) -> Result<StreamedTurn> {
    if self
        .turn_in_progress
        .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
        .is_err()
    {
        return Err(Error::ConcurrentTurn);
    }
    // Arm the guard immediately after claiming the flag. Every early return
    // below (schema resolution, connect, write, ...) then releases the flag;
    // otherwise a single failed setup would leave the thread permanently
    // rejecting turns with `ConcurrentTurn`.
    let turn_guard = TurnGuard {
        flag: Arc::clone(&self.turn_in_progress),
        active_transport: Arc::clone(&self.active_transport),
    };

    let input = input.into();
    let mut args = self.options.to_cli_args();
    self.config.apply_overrides(&mut args);
    let (schema_args, schema_guard, thread_schema_guard) = resolve_output_schema(
        turn_options.output_schema.as_ref(),
        &self.options.output_schema,
    )?;
    args.extend(schema_args);
    // Resume with the id captured from an earlier `ThreadStarted` event when
    // there is one, falling back to the explicit resume id, so that follow-up
    // turns on this thread continue the same conversation.
    if let Some(resume_id) = self.id() {
        args.push("resume".into());
        args.push(resume_id);
    }

    let transport: Arc<dyn Transport> = match &self.transport_override {
        Some(t) => Arc::clone(t),
        None => Arc::new(CliTransport::new(
            self.cli_path.clone(),
            args,
            self.config.to_env(),
            self.config.stderr_callback.clone(),
            turn_options.cancel.clone(),
            self.config.close_timeout,
        )),
    };
    // Publish the transport so `interrupt` can reach it; the guard clears it.
    *self
        .active_transport
        .lock()
        .unwrap_or_else(|e| e.into_inner()) = Some(Arc::clone(&transport));

    let connect_future = transport.connect();
    match self.config.connect_timeout {
        Some(timeout) => {
            tokio::time::timeout(timeout, connect_future)
                .await
                .map_err(|_| Error::Timeout {
                    operation: "connect".into(),
                })??;
        }
        None => connect_future.await?,
    }

    // `input` is owned, so plain text is moved rather than cloned.
    let prompt_text = match input {
        Input::Text(s) => s,
        Input::Items(items) => serde_json::to_string(&items)
            .map_err(|e| Error::Config(format!("failed to serialize input: {e}")))?,
    };
    transport.write(&prompt_text).await?;
    transport.end_input().await?;
    let messages = transport.read_messages();

    // Everything the stream needs is cloned out of `self`, so the stream does
    // not borrow the thread.
    let approval_cb = self.approval_callback.clone();
    let patch_approval_cb = self.patch_approval_callback.clone();
    let event_cb = self.event_callback.clone();
    let hooks = self.hooks.clone();
    let default_hook_timeout = self.default_hook_timeout;
    let max_turns = self.max_turns;
    let max_budget_tokens = self.max_budget_tokens;
    let transport_clone = Arc::clone(&transport);
    let thread_id_slot = Arc::clone(&self.thread_id);

    let stream = async_stream::stream! {
        // Held for the lifetime of the stream: the schema guards and the turn
        // guard are released only when the stream finishes or is dropped.
        let _schema_guard = schema_guard;
        let _thread_schema_guard = thread_schema_guard;
        let _turn_guard = turn_guard;
        let get_thread_id = || {
            thread_id_slot
                .lock()
                .unwrap_or_else(|e| e.into_inner())
                .clone()
        };
        let mut turn_count: u32 = 0;
        let mut total_output_tokens: u64 = 0;
        tokio::pin!(messages);
        while let Some(result) = messages.next().await {
            match result {
                Ok(value) => {
                    let event = match serde_json::from_value::<ThreadEvent>(value.clone()) {
                        Ok(e) => e,
                        Err(e) => {
                            tracing::warn!("Skipping unrecognized event: {e} — raw: {value}");
                            continue;
                        }
                    };
                    if let ThreadEvent::ThreadStarted { ref thread_id } = event {
                        *thread_id_slot
                            .lock()
                            .unwrap_or_else(|e| e.into_inner()) = Some(thread_id.clone());
                    }
                    // Approval requests are answered before hooks or callbacks
                    // see the event.
                    if let ThreadEvent::ApprovalRequest(ref req) = event {
                        let outcome = if let Some(ref cb) = approval_cb {
                            let ctx = ApprovalContext {
                                request: req.clone(),
                                thread_id: get_thread_id(),
                            };
                            cb(ctx).await
                        } else {
                            crate::permissions::ApprovalDecision::Denied.into()
                        };
                        let response = ApprovalResponse::new(req.id.clone(), outcome.decision);
                        if let Err(e) = write_response(&response, &*transport_clone).await {
                            yield Err(e);
                            break;
                        }
                    }
                    if let ThreadEvent::PatchApprovalRequest(ref req) = event {
                        let outcome = if let Some(ref cb) = patch_approval_cb {
                            let ctx = PatchApprovalContext {
                                request: req.clone(),
                                thread_id: get_thread_id(),
                            };
                            cb(ctx).await
                        } else {
                            crate::permissions::ApprovalDecision::Denied.into()
                        };
                        let response = PatchApprovalResponse::new(req.id.clone(), outcome.decision);
                        if let Err(e) = write_response(&response, &*transport_clone).await {
                            yield Err(e);
                            break;
                        }
                    }
                    // Hooks may pass, drop, replace, or abort on the event.
                    let event = if !hooks.is_empty() {
                        let hook_ctx = HookContext {
                            thread_id: get_thread_id(),
                            turn_count,
                        };
                        match hooks::dispatch_hook(&event, &hooks, &hook_ctx, default_hook_timeout).await {
                            Some(output) => match output.decision {
                                HookDecision::Allow => event,
                                HookDecision::Block => continue,
                                HookDecision::Modify => {
                                    output.replacement_event.unwrap_or(event)
                                }
                                HookDecision::Abort => {
                                    tracing::info!("Hook aborted stream: {:?}", output.reason);
                                    break;
                                }
                            },
                            None => event,
                        }
                    } else {
                        event
                    };
                    if let ThreadEvent::TurnCompleted { ref usage } = event {
                        // Saturate rather than overflow on pathological counts.
                        turn_count = turn_count.saturating_add(1);
                        total_output_tokens =
                            total_output_tokens.saturating_add(usage.output_tokens);
                        // The callback may filter the event out, but the limits
                        // below are enforced either way.
                        if let Some(e) = crate::callback::apply_callback(event, event_cb.as_ref()) {
                            yield Ok(e);
                        }
                        if let Some(limit) = max_turns {
                            if turn_count >= limit {
                                tracing::info!("max_turns reached ({turn_count}/{limit}), closing stream");
                                break;
                            }
                        }
                        if let Some(budget) = max_budget_tokens {
                            if total_output_tokens >= budget {
                                tracing::info!(
                                    "max_budget_tokens reached ({total_output_tokens}/{budget}), closing stream"
                                );
                                break;
                            }
                        }
                        continue;
                    }
                    let event = match crate::callback::apply_callback(event, event_cb.as_ref()) {
                        Some(e) => e,
                        None => continue,
                    };
                    yield Ok(event);
                }
                Err(e) => {
                    // JSON errors are reported and reading continues; any
                    // other error ends the stream.
                    let is_fatal = !matches!(&e, Error::Json(_));
                    yield Err(e);
                    if is_fatal {
                        break;
                    }
                }
            }
        }
        // Always close the transport once the loop is left; a non-zero exit
        // code is surfaced as a final error item.
        match transport_clone.close().await {
            Ok(Some(code)) if code != 0 => {
                yield Err(Error::ProcessExited {
                    code,
                    stderr: transport_clone.collected_stderr(),
                });
            }
            Err(e) => {
                yield Err(e);
            }
            _ => {}
        }
    };
    Ok(StreamedTurn::new(stream))
}
}
async fn write_response<R: serde::Serialize>(
response: &R,
transport: &dyn crate::transport::Transport,
) -> Result<()> {
let json = serde_json::to_string(response).map_err(Error::Json)?;
transport.write(&json).await
}
/// Works out the `--output-schema` arguments for a turn.
///
/// A turn-level schema that yields a file path wins. Otherwise the
/// thread-level schema is used: a `File` path is passed through, an `Inline`
/// value is written out via a second `OutputSchemaFile`. With neither, no
/// arguments are produced.
///
/// Returns `(cli_args, turn_schema_file, thread_schema_file)`. The schema file
/// values are returned so the caller can keep them alive while the arguments
/// are in use.
///
/// # Errors
///
/// Propagates failures from `OutputSchemaFile::new`.
fn resolve_output_schema(
    turn_schema: Option<&Value>,
    thread_schema: &Option<OutputSchema>,
) -> Result<(Vec<String>, OutputSchemaFile, Option<OutputSchemaFile>)> {
    /// Builds the two-element flag list for an already displayable path.
    fn schema_flag(shown: impl std::fmt::Display) -> Vec<String> {
        vec![String::from("--output-schema"), shown.to_string()]
    }

    let turn_file = OutputSchemaFile::new(turn_schema)?;
    let turn_args = turn_file.path().map(|p| schema_flag(p.display()));
    if let Some(args) = turn_args {
        return Ok((args, turn_file, None));
    }

    let (args, thread_file) = match thread_schema {
        None => (Vec::new(), None),
        Some(OutputSchema::File(path)) => (schema_flag(path.display()), None),
        Some(OutputSchema::Inline(value)) => {
            let file = OutputSchemaFile::new(Some(value))?;
            let args = match file.path() {
                Some(p) => schema_flag(p.display()),
                None => Vec::new(),
            };
            (args, Some(file))
        }
    };
    Ok((args, turn_file, thread_file))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::testing::builders;
use crate::testing::mock_transport::MockTransport;
use tokio_stream::StreamExt;
/// Builds a default-configured thread that talks to `mock` instead of a CLI.
fn make_thread_with_mock(mock: Arc<MockTransport>) -> Thread {
    let transport: Arc<dyn Transport> = mock;
    let cli_path = std::path::PathBuf::from("/nonexistent/codex");
    Thread::new(
        cli_path,
        CodexConfig::default(),
        ThreadOptions::default(),
        None,
    )
    .with_transport(transport)
}
/// A turn run through the mock transport yields the response, usage and id.
#[tokio::test]
async fn test_transport_override_basic_turn() {
    let transport = Arc::new(MockTransport::new());
    transport.enqueue_session("thread-1");
    transport.enqueue_turn_complete("Hello from mock!");

    let mut thread = make_thread_with_mock(Arc::clone(&transport));
    let outcome = thread.run("say hello", TurnOptions::default()).await;
    let turn = outcome.unwrap();

    assert_eq!(turn.final_response, "Hello from mock!");
    assert!(turn.usage.is_some());
    assert_eq!(thread.id().as_deref(), Some("thread-1"));
}
/// After one turn completes, a second turn on the same thread is accepted.
#[tokio::test]
async fn test_turn_guard_resets_on_drop() {
    let first_transport = Arc::new(MockTransport::new());
    first_transport.enqueue_session("thread-2");
    first_transport.enqueue_turn_complete("first");
    let mut thread = make_thread_with_mock(Arc::clone(&first_transport));
    thread.run("first", TurnOptions::default()).await.unwrap();

    let second_transport = Arc::new(MockTransport::new());
    second_transport.enqueue_session("thread-2");
    second_transport.enqueue_turn_complete("second");
    let second_transport: Arc<dyn Transport> = second_transport;
    thread.transport_override = Some(second_transport);

    let second_outcome = thread.run("second", TurnOptions::default()).await;
    assert!(
        second_outcome.is_ok(),
        "Second turn should succeed after first completes"
    );
}
/// Dropping an unconsumed stream clears the in-progress flag.
#[tokio::test]
async fn test_turn_guard_resets_on_stream_drop() {
    let first_transport = Arc::new(MockTransport::new());
    first_transport.enqueue_session("thread-3");
    first_transport.enqueue_turn_complete("data");
    let mut thread = make_thread_with_mock(Arc::clone(&first_transport));

    // Start a turn and throw the stream away without polling it.
    let abandoned = thread
        .run_streamed("prompt", TurnOptions::default())
        .await
        .unwrap();
    drop(abandoned);

    let still_running = thread.turn_in_progress.load(Ordering::Acquire);
    assert!(
        !still_running,
        "turn_in_progress should be false after stream drop"
    );

    let next_transport = Arc::new(MockTransport::new());
    next_transport.enqueue_session("thread-3");
    next_transport.enqueue_turn_complete("ok");
    let next_transport: Arc<dyn Transport> = next_transport;
    thread.transport_override = Some(next_transport);

    let follow_up = thread.run("next", TurnOptions::default()).await;
    assert!(follow_up.is_ok());
}
/// An approval request is answered by writing a response with the request id.
#[tokio::test]
async fn test_approval_with_mock_transport() {
    use crate::permissions::{ApprovalCallback, ApprovalDecision};

    let mock = Arc::new(MockTransport::new());
    mock.enqueue_session("thread-4");
    mock.enqueue_event(builders::approval_request("ap-1", "ls"));
    mock.enqueue_turn_complete("done");

    let approve_all: ApprovalCallback =
        Arc::new(|_ctx| Box::pin(async { ApprovalDecision::Approved.into() }));
    let mut thread =
        make_thread_with_mock(Arc::clone(&mock)).with_approval_callback(approve_all);

    let turn = thread.run("do it", TurnOptions::default()).await.unwrap();

    let written = mock.written_lines();
    assert!(!written.is_empty(), "approval response should be written");
    let mentions_request = written.iter().any(|line| line.contains("ap-1"));
    assert!(mentions_request, "approval id should appear in response");
    assert_eq!(turn.final_response, "done");
}
/// A `turn.failed` event makes `run` fail with the failure message.
#[tokio::test]
async fn test_run_returns_error_on_turn_failed() {
    let transport = Arc::new(MockTransport::new());
    transport.enqueue_session("thread-err-1");
    transport.enqueue_event(builders::turn_failed("model overloaded"));
    let mut thread = make_thread_with_mock(Arc::clone(&transport));

    let outcome = thread.run("prompt", TurnOptions::default()).await;

    assert!(outcome.is_err(), "run() should return Err on turn.failed");
    let rendered = outcome.unwrap_err().to_string();
    assert!(
        rendered.contains("model overloaded"),
        "error should contain the failure message, got: {rendered}"
    );
}
#[tokio::test]
async fn test_run_returns_error_on_error_event() {
let mock = Arc::new(MockTransport::new());
mock.enqueue_session("thread-err-2");
mock.enqueue_event(builders::error("something broke"));
let mut thread = make_thread_with_mock(Arc::clone(&mock));
let result = thread.run("prompt", TurnOptions::default()).await;
assert!(result.is_err(), "run() should return Err on error event");
let err = result.unwrap_err();
assert!(
err.to_string().contains("something broke"),
"error should contain the message, got: {err}"
);
}
#[tokio::test]
async fn test_nonzero_exit_code_surfaces() {
let mock = Arc::new(MockTransport::new());
mock.enqueue_session("thread-exit");
mock.enqueue_turn_complete("partial");
mock.set_exit_code(1);
let mut thread = make_thread_with_mock(Arc::clone(&mock));
let mut streamed = thread
.run_streamed("prompt", TurnOptions::default())
.await
.unwrap();
let mut saw_exit_error = false;
while let Some(event) = streamed.next().await {
if let Err(crate::Error::ProcessExited { code, .. }) = &event {
if *code == 1 {
saw_exit_error = true;
}
}
}
assert!(
saw_exit_error,
"stream should yield ProcessExited error for non-zero exit code"
);
}
#[tokio::test]
async fn test_read_messages_already_consumed() {
let mock = MockTransport::new();
mock.enqueue_event(serde_json::json!({"type": "turn.started"}));
mock.connect().await.unwrap();
let mut first = mock.read_messages();
let _ = first.next().await;
let mut second = mock.read_messages();
let result = second.next().await;
assert!(result.is_some());
let err = result.unwrap();
assert!(matches!(err, Err(crate::Error::TransportClosed)));
}
#[tokio::test]
async fn test_max_turns_enforced() {
let mock = Arc::new(MockTransport::new());
mock.enqueue_session("thread-budget");
mock.enqueue_turn_complete("response-1");
mock.enqueue_event(builders::turn_started());
mock.enqueue_event(builders::agent_message_completed("msg-2", "response-2"));
mock.enqueue_event(builders::turn_completed(50, 0, 25));
mock.enqueue_event(builders::turn_started());
mock.enqueue_event(builders::agent_message_completed("msg-3", "response-3"));
mock.enqueue_event(builders::turn_completed(50, 0, 25));
let mut thread = Thread::new(
std::path::PathBuf::from("/nonexistent/codex"),
CodexConfig::default(),
ThreadOptions::builder().max_turns(2u32).build(),
None,
);
thread.transport_override = Some(mock as Arc<dyn Transport>);
let mut streamed = thread
.run_streamed("prompt", TurnOptions::default())
.await
.unwrap();
let mut turn_completions = 0;
while let Some(event) = streamed.next().await {
if let Ok(ThreadEvent::TurnCompleted { .. }) = event {
turn_completions += 1;
}
}
assert_eq!(turn_completions, 2, "stream should close after max_turns=2");
}
#[tokio::test]
async fn test_max_budget_tokens_enforced() {
let mock = Arc::new(MockTransport::new());
mock.enqueue_session("thread-budget-tok");
mock.enqueue_event(builders::agent_message_completed("msg-1", "response"));
mock.enqueue_event(builders::turn_completed(100, 0, 500));
mock.enqueue_event(builders::turn_started());
mock.enqueue_event(builders::agent_message_completed("msg-2", "response-2"));
mock.enqueue_event(builders::turn_completed(100, 0, 600));
mock.enqueue_event(builders::turn_started());
mock.enqueue_event(builders::agent_message_completed("msg-3", "response-3"));
mock.enqueue_event(builders::turn_completed(100, 0, 100));
let mut thread = Thread::new(
std::path::PathBuf::from("/nonexistent/codex"),
CodexConfig::default(),
ThreadOptions::builder().max_budget_tokens(1000u64).build(),
None,
);
thread.transport_override = Some(mock as Arc<dyn Transport>);
let mut streamed = thread
.run_streamed("prompt", TurnOptions::default())
.await
.unwrap();
let mut turn_completions = 0;
while let Some(event) = streamed.next().await {
if let Ok(ThreadEvent::TurnCompleted { .. }) = event {
turn_completions += 1;
}
}
assert_eq!(
turn_completions, 2,
"stream should close after exceeding budget on turn 2"
);
}
#[tokio::test]
async fn test_hook_blocks_event() {
use crate::hooks::{HookDecision, HookEvent, HookMatcher, HookOutput};
let mock = Arc::new(MockTransport::new());
mock.enqueue_session("thread-hook");
mock.enqueue_event(builders::command_started("cmd-1", "rm -rf /"));
mock.enqueue_turn_complete("done");
let hook = HookMatcher {
event: HookEvent::CommandStarted,
command_filter: Some("rm".into()),
callback: Arc::new(|_input, _ctx| {
Box::pin(async {
HookOutput {
decision: HookDecision::Block,
reason: Some("blocked rm".into()),
replacement_event: None,
}
})
}),
timeout: None,
on_timeout: Default::default(),
};
let mut thread = Thread::new(
std::path::PathBuf::from("/nonexistent/codex"),
CodexConfig::default(),
ThreadOptions::builder().hooks(vec![hook]).build(),
None,
);
thread.transport_override = Some(mock as Arc<dyn Transport>);
let mut streamed = thread
.run_streamed("prompt", TurnOptions::default())
.await
.unwrap();
let mut saw_command_started = false;
while let Some(event) = streamed.next().await {
if let Ok(ThreadEvent::ItemStarted {
item: crate::types::items::ThreadItem::CommandExecution { .. },
}) = event
{
saw_command_started = true;
}
}
assert!(
!saw_command_started,
"command started event should be blocked by hook"
);
}
/// Hooks supplied through `ThreadOptions` fire on every turn of the thread.
#[tokio::test]
async fn test_hooks_persist_across_turns() {
    use crate::hooks::{HookEvent, HookMatcher, HookOutput};
    use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};

    let invocations = Arc::new(AtomicUsize::new(0));
    let shared_counter = Arc::clone(&invocations);
    let counting_hook = HookMatcher {
        event: HookEvent::TurnCompleted,
        command_filter: None,
        timeout: None,
        on_timeout: crate::hooks::HookTimeoutBehavior::FailOpen,
        callback: Arc::new(move |_input, _ctx| {
            let counter = Arc::clone(&shared_counter);
            Box::pin(async move {
                counter.fetch_add(1, AtomicOrdering::Relaxed);
                HookOutput::default()
            })
        }),
    };

    let mut thread = Thread::new(
        std::path::PathBuf::from("/nonexistent/codex"),
        CodexConfig::default(),
        ThreadOptions::builder().hooks(vec![counting_hook]).build(),
        None,
    );

    // Run two turns, each on its own freshly scripted mock transport.
    for label in ["first", "second"] {
        let transport = Arc::new(MockTransport::new());
        transport.enqueue_session("thread-persist-hooks");
        transport.enqueue_turn_complete(label);
        thread.transport_override = Some(transport as Arc<dyn Transport>);
        thread.run(label, TurnOptions::default()).await.unwrap();
    }

    assert_eq!(
        invocations.load(AtomicOrdering::Relaxed),
        2,
        "hook should fire on both turns, not just the first"
    );
}
#[tokio::test]
async fn test_thread_interrupt_delegates_to_transport() {
use tokio::sync::Barrier;
let mock = Arc::new(MockTransport::new());
mock.enqueue_session("thread-interrupt-1");
mock.enqueue_turn_complete("done");
let mock_for_assert = Arc::clone(&mock);
let barrier = Arc::new(Barrier::new(2));
let barrier2 = Arc::clone(&barrier);
let mut thread = make_thread_with_mock(Arc::clone(&mock));
let mut streamed = thread
.run_streamed("prompt", TurnOptions::default())
.await
.unwrap();
let thread_ref = &thread;
thread_ref.interrupt().await.unwrap();
while let Some(_) = streamed.next().await {}
assert!(
mock_for_assert.interrupt_called(),
"interrupt() should have been delegated to the mock transport"
);
let _ = barrier2; }
/// Interrupting a thread that has never run a turn succeeds.
#[tokio::test]
async fn test_thread_interrupt_noop_when_idle() {
    let cli_path = std::path::PathBuf::from("/nonexistent/codex");
    let idle_thread = Thread::new(
        cli_path,
        CodexConfig::default(),
        ThreadOptions::default(),
        None,
    );

    let outcome = idle_thread.interrupt().await;

    assert!(
        outcome.is_ok(),
        "interrupt with no active turn should return Ok"
    );
}
/// Once a turn has finished, `interrupt` no longer reaches its transport.
#[tokio::test]
async fn test_active_transport_cleared_after_turn() {
    let transport = Arc::new(MockTransport::new());
    transport.enqueue_session("thread-clear");
    transport.enqueue_turn_complete("done");
    let mut thread = make_thread_with_mock(Arc::clone(&transport));
    thread.run("prompt", TurnOptions::default()).await.unwrap();

    let outcome = thread.interrupt().await;

    assert!(outcome.is_ok());
    let was_interrupted = transport.interrupt_called();
    assert!(
        !was_interrupted,
        "interrupt_called should be false — slot was cleared after turn completed"
    );
}
#[tokio::test]
async fn test_hook_aborts_stream() {
use crate::hooks::{HookDecision, HookEvent, HookMatcher, HookOutput};
let mock = Arc::new(MockTransport::new());
mock.enqueue_session("thread-abort");
mock.enqueue_event(builders::command_started("cmd-1", "dangerous"));
mock.enqueue_turn_complete("should not see this");
let hook = HookMatcher {
event: HookEvent::CommandStarted,
command_filter: None,
callback: Arc::new(|_input, _ctx| {
Box::pin(async {
HookOutput {
decision: HookDecision::Abort,
reason: Some("abort!".into()),
replacement_event: None,
}
})
}),
timeout: None,
on_timeout: Default::default(),
};
let mut thread = Thread::new(
std::path::PathBuf::from("/nonexistent/codex"),
CodexConfig::default(),
ThreadOptions::builder().hooks(vec![hook]).build(),
None,
);
thread.transport_override = Some(mock as Arc<dyn Transport>);
let mut streamed = thread
.run_streamed("prompt", TurnOptions::default())
.await
.unwrap();
let mut events = vec![];
while let Some(event) = streamed.next().await {
if let Ok(ref e) = event {
events.push(e.clone());
}
}
assert!(
!events
.iter()
.any(|e| matches!(e, ThreadEvent::TurnCompleted { .. })),
"TurnCompleted should not appear — stream was aborted"
);
}
}