use crate::hooks::HookMatcher;
use crate::mcp::McpServers;
use serde_json::Value;
use std::path::PathBuf;
use std::time::Duration;
use typed_builder::TypedBuilder;
/// Process-level configuration for launching the Codex CLI.
///
/// Construct with `CodexConfig::builder()` or `CodexConfig::default()`; fields left unset on
/// the builder take the defaults noted on each field.
#[derive(Clone, TypedBuilder)]
pub struct CodexConfig {
/// Explicit path to the `codex` executable; defaults to `None`.
/// NOTE(review): presumably `None` means the binary is looked up elsewhere (e.g. on `PATH`) —
/// confirm with the process launcher.
#[builder(default, setter(strip_option, into))]
pub cli_path: Option<PathBuf>,
/// Extra environment variables for the CLI process; defaults to empty. `to_env` layers
/// defaults for `CODEX_INTERNAL_ORIGINATOR_OVERRIDE`, `CI` and `TERM` on top of these
/// without overwriting keys set here.
#[builder(default)]
pub env: std::collections::HashMap<String, String>,
/// Config overrides emitted as `-c key=value` arguments by `apply_overrides`; defaults to
/// `ConfigOverrides::None`.
#[builder(default)]
pub config_overrides: ConfigOverrides,
/// Profile name emitted as `--profile <name>` by `apply_overrides`; defaults to `None`.
#[builder(default, setter(strip_option, into))]
pub profile: Option<String>,
/// Connect timeout; defaults to 30 seconds.
/// NOTE(review): the meaning of `None` (e.g. "no timeout") is decided by the consumer —
/// confirm.
#[builder(default_code = "Some(Duration::from_secs(30))")]
pub connect_timeout: Option<Duration>,
/// Close timeout; defaults to 10 seconds.
#[builder(default_code = "Some(Duration::from_secs(10))")]
pub close_timeout: Option<Duration>,
/// Version-check timeout; defaults to 5 seconds.
#[builder(default_code = "Some(Duration::from_secs(5))")]
pub version_check_timeout: Option<Duration>,
/// Optional stderr callback; defaults to `None`. It has no `Debug` impl, so the manual
/// `Debug` for this struct prints a placeholder instead.
/// NOTE(review): presumably invoked with text read from the CLI's stderr — confirm with the
/// transport.
#[builder(default, setter(strip_option))]
pub stderr_callback: Option<StderrCallback>,
}
/// Shared, reference-counted callback that receives `&str` text (stored in
/// `CodexConfig::stderr_callback`). The `Send + Sync` bounds allow it to be shared and called
/// across threads.
pub type StderrCallback = std::sync::Arc<dyn Fn(&str) + Send + Sync>;
impl std::fmt::Debug for CodexConfig {
    /// Formats every field in declaration order. The stderr callback has no `Debug` impl, so
    /// it is shown as `Some("...")` when present and `None` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: a new field that is not listed here is a compile error
        // rather than a silently missing entry in the debug output.
        let Self {
            cli_path,
            env,
            config_overrides,
            profile,
            connect_timeout,
            close_timeout,
            version_check_timeout,
            stderr_callback,
        } = self;
        let callback_marker: Option<&str> = stderr_callback.is_some().then_some("...");
        f.debug_struct("CodexConfig")
            .field("cli_path", cli_path)
            .field("env", env)
            .field("config_overrides", config_overrides)
            .field("profile", profile)
            .field("connect_timeout", connect_timeout)
            .field("close_timeout", close_timeout)
            .field("version_check_timeout", version_check_timeout)
            .field("stderr_callback", &callback_marker)
            .finish()
    }
}
impl Default for CodexConfig {
    /// A config where every field takes its builder default: 30 s connect, 10 s close and
    /// 5 s version-check timeouts; all other fields empty or `None`.
    fn default() -> Self {
        let builder = <Self>::builder();
        builder.build()
    }
}
/// Per-thread options for a `codex exec` run; `ThreadOptions::to_cli_args` renders them into
/// CLI arguments.
///
/// Construct with `ThreadOptions::builder()` or `ThreadOptions::default()`; fields left unset
/// take the defaults noted on each field.
#[derive(Clone, TypedBuilder)]
pub struct ThreadOptions {
/// Working directory, passed as `--cd <dir>` when set.
#[builder(default, setter(strip_option, into))]
pub working_directory: Option<PathBuf>,
/// Model name, passed as `--model <name>` when set.
#[builder(default, setter(strip_option, into))]
pub model: Option<String>,
/// Sandbox policy, always emitted via `--sandbox`; defaults to `SandboxPolicy::WorkspaceWrite`.
#[builder(default)]
pub sandbox: SandboxPolicy,
/// Approval policy; defaults to `ApprovalPolicy::Never`, which emits no argument.
#[builder(default)]
pub approval: ApprovalPolicy,
/// Extra directories, each passed as `--add-dir <dir>`.
#[builder(default)]
pub additional_directories: Vec<PathBuf>,
/// When true, emits `--skip-git-repo-check`.
#[builder(default)]
pub skip_git_repo_check: bool,
/// When set, emitted as `-c model_reasoning_effort=<level>`.
#[builder(default, setter(strip_option))]
pub reasoning_effort: Option<ReasoningEffort>,
/// When set, emitted as `-c sandbox_workspace_write.network_access=<bool>`.
#[builder(default, setter(strip_option))]
pub network_access: Option<bool>,
/// When set, emitted as `-c web_search=<mode>`.
#[builder(default, setter(strip_option))]
pub web_search: Option<WebSearchMode>,
/// Output schema. Not rendered by `to_cli_args`.
/// NOTE(review): presumably consumed elsewhere — confirm.
#[builder(default, setter(strip_option))]
pub output_schema: Option<OutputSchema>,
/// When true, emits `--ephemeral`.
#[builder(default)]
pub ephemeral: bool,
/// Image paths, each passed as `--image <path>`.
#[builder(default)]
pub images: Vec<PathBuf>,
/// When set, passed as `--local-provider <name>`.
#[builder(default, setter(strip_option, into))]
pub local_provider: Option<String>,
/// When set, emitted as `-c system_prompt=<JSON-escaped string>`.
#[builder(default, setter(strip_option, into))]
pub system_prompt: Option<String>,
/// Turn limit. Not rendered by `to_cli_args` (a unit test asserts this).
/// NOTE(review): presumably enforced by the SDK itself — confirm.
#[builder(default, setter(strip_option))]
pub max_turns: Option<u32>,
/// Token budget. Not rendered by `to_cli_args` (a unit test asserts this).
/// NOTE(review): presumably enforced by the SDK itself — confirm.
#[builder(default, setter(strip_option))]
pub max_budget_tokens: Option<u64>,
/// MCP server definitions; when non-empty, emitted as `-c mcp_servers=<JSON>`.
#[builder(default)]
pub mcp_servers: McpServers,
/// Hook matchers. Not rendered by `to_cli_args`.
#[builder(default)]
pub hooks: Vec<HookMatcher>,
/// Hook timeout; defaults to 30 seconds. Not rendered by `to_cli_args`.
/// NOTE(review): presumably applies to hooks without their own timeout — confirm.
#[builder(default_code = "Duration::from_secs(30)")]
pub default_hook_timeout: Duration,
}
impl std::fmt::Debug for ThreadOptions {
    /// Formats every option in declaration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `ThreadOptions` without listing it here
        // becomes a compile error instead of an incomplete debug dump.
        let Self {
            working_directory,
            model,
            sandbox,
            approval,
            additional_directories,
            skip_git_repo_check,
            reasoning_effort,
            network_access,
            web_search,
            output_schema,
            ephemeral,
            images,
            local_provider,
            system_prompt,
            max_turns,
            max_budget_tokens,
            mcp_servers,
            hooks,
            default_hook_timeout,
        } = self;
        f.debug_struct("ThreadOptions")
            .field("working_directory", working_directory)
            .field("model", model)
            .field("sandbox", sandbox)
            .field("approval", approval)
            .field("additional_directories", additional_directories)
            .field("skip_git_repo_check", skip_git_repo_check)
            .field("reasoning_effort", reasoning_effort)
            .field("network_access", network_access)
            .field("web_search", web_search)
            .field("output_schema", output_schema)
            .field("ephemeral", ephemeral)
            .field("images", images)
            .field("local_provider", local_provider)
            .field("system_prompt", system_prompt)
            .field("max_turns", max_turns)
            .field("max_budget_tokens", max_budget_tokens)
            .field("mcp_servers", mcp_servers)
            .field("hooks", hooks)
            .field("default_hook_timeout", default_hook_timeout)
            .finish()
    }
}
impl Default for ThreadOptions {
    /// Options where every field takes its builder default: workspace-write sandbox, `Never`
    /// approval, a 30 s default hook timeout, and everything else empty, `false` or `None`.
    fn default() -> Self {
        let builder = <Self>::builder();
        builder.build()
    }
}
/// Sandbox policy for a thread, always passed to the CLI via `--sandbox`. See
/// `ThreadOptions::to_cli_args` for the exact value each variant maps to.
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare policies and use them as keys.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum SandboxPolicy {
    /// Restricted sandbox.
    Restricted,
    /// Default policy; rendered as `workspace-write`.
    #[default]
    WorkspaceWrite,
    /// Rendered as `danger-full-access`.
    DangerFullAccess,
}
/// When the agent must ask for approval.
///
/// `ThreadOptions::to_cli_args` emits nothing for `Never` (the default) and a
/// `-c approval_policy=<value>` override for the other variants.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum ApprovalPolicy {
    /// Default; no approval argument is emitted.
    #[default]
    Never,
    /// Rendered as `approval_policy=on-request`.
    OnRequest,
    /// Rendered as `approval_policy=untrusted`.
    UnlessTrusted,
}
/// Reasoning effort level, emitted by `ThreadOptions::to_cli_args` as
/// `-c model_reasoning_effort=<level>` using the lowercase variant name
/// (`minimal`, `low`, `medium`, `high`, `xhigh`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ReasoningEffort {
    Minimal,
    Low,
    Medium,
    High,
    XHigh,
}
/// Web search mode, emitted by `ThreadOptions::to_cli_args` as `-c web_search=<mode>` using
/// the lowercase variant name (`disabled`, `cached`, `live`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WebSearchMode {
    Disabled,
    Cached,
    Live,
}
/// Output schema source for a thread: either an inline JSON value or a path to a schema file.
///
/// Not rendered by `ThreadOptions::to_cli_args`.
/// NOTE(review): presumably consumed elsewhere — confirm. `PartialEq` is derived so options
/// can be compared.
#[derive(Debug, Clone, PartialEq)]
pub enum OutputSchema {
    /// Schema given inline as JSON.
    Inline(Value),
    /// Path to a schema file.
    File(PathBuf),
}
/// Codex config overrides, converted by `ConfigOverrides::to_cli_pairs` into `(key, value)`
/// pairs for `-c key=value` arguments.
///
/// `PartialEq` is derived so configurations can be compared.
#[derive(Debug, Clone, Default, PartialEq)]
pub enum ConfigOverrides {
    /// No overrides (the default).
    #[default]
    None,
    /// Pre-formatted `(key, value)` pairs, emitted verbatim and in order.
    Flat(Vec<(String, String)>),
    /// A JSON document whose nested objects are flattened into dotted keys.
    Json(Value),
}
impl ConfigOverrides {
    /// Converts the overrides into `(key, value)` pairs for `-c key=value` CLI arguments.
    ///
    /// `None` yields no pairs, `Flat` yields a copy of its pairs in order, and `Json` is
    /// flattened into dotted keys by `flatten_json`.
    pub fn to_cli_pairs(&self) -> Vec<(String, String)> {
        let mut pairs = Vec::new();
        match self {
            Self::None => {}
            Self::Flat(flat) => pairs.extend(flat.iter().cloned()),
            Self::Json(json) => flatten_json("", json, &mut pairs),
        }
        pairs
    }
}
/// Flattens a JSON value into `(dotted.key, value)` pairs for Codex `-c key=value` overrides,
/// appending them to `out`.
///
/// Nested objects become dotted key paths: `{"a": {"b": 1}}` yields `("a.b", "1")`. Every
/// other value is rendered as a TOML inline value by `toml_inline_value`, because Codex parses
/// the value side of `-c` as TOML.
///
/// Two kinds of value are skipped:
/// - `null`, because TOML has no null.
/// - a non-object value reached with an empty `prefix` (i.e. a root that is not an object),
///   because it has no key to attach to and would otherwise produce an invalid `=value`
///   override.
fn flatten_json(prefix: &str, value: &Value, out: &mut Vec<(String, String)>) {
    match value {
        Value::Object(map) => {
            for (key, val) in map {
                let full_key = if prefix.is_empty() {
                    key.clone()
                } else {
                    format!("{prefix}.{key}")
                };
                flatten_json(&full_key, val, out);
            }
        }
        Value::Null => {}
        _ if prefix.is_empty() => {}
        other => out.push((prefix.to_string(), toml_inline_value(other))),
    }
}

/// Renders a JSON value as a TOML inline value.
///
/// - Strings use JSON escaping; the escapes serde_json emits (`\"`, `\\`, `\n`, `\uXXXX`, ...)
///   are also valid in TOML basic strings.
/// - Arrays are rendered element by element.
/// - Objects become inline tables (`{key = value, ...}`), because JSON object syntax is not
///   valid TOML. `null` members are omitted, since TOML has no null.
/// - Numbers and booleans use their JSON text.
/// - A `null` array element is emitted as `null`, which TOML cannot parse; avoid such arrays.
fn toml_inline_value(value: &Value) -> String {
    match value {
        Value::String(s) => serde_json::to_string(s).expect("infallible: String serialization"),
        Value::Array(items) => {
            let rendered: Vec<String> = items.iter().map(toml_inline_value).collect();
            format!("[{}]", rendered.join(", "))
        }
        Value::Object(map) => {
            let rendered: Vec<String> = map
                .iter()
                .filter(|(_, v)| !v.is_null())
                .map(|(k, v)| format!("{} = {}", toml_key(k), toml_inline_value(v)))
                .collect();
            format!("{{{}}}", rendered.join(", "))
        }
        other => other.to_string(),
    }
}

/// Renders an inline-table key: bare when it uses only TOML bare-key characters
/// (`A-Z a-z 0-9 _ -`), otherwise as a quoted, escaped string.
fn toml_key(key: &str) -> String {
    let is_bare = !key.is_empty()
        && key
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-');
    if is_bare {
        key.to_string()
    } else {
        serde_json::to_string(key).expect("infallible: String serialization")
    }
}
/// Per-turn options; the default has no output schema and no cancellation token.
#[derive(Debug, Default)]
pub struct TurnOptions {
/// Optional JSON output schema for this turn.
/// NOTE(review): presumably written to disk via `OutputSchemaFile`, which rejects non-object
/// values — confirm with the caller.
pub output_schema: Option<Value>,
/// Optional cancellation token.
/// NOTE(review): presumably cancels the in-flight turn when triggered — confirm with the
/// caller.
pub cancel: Option<tokio_util::sync::CancellationToken>,
}
/// Guard for an output schema written to a temporary file.
///
/// Both fields are `Some` when a schema was written, and both are `None` otherwise.
pub(crate) struct OutputSchemaFile {
// Never read. It is held so the temporary directory (and the `schema.json` inside it) stays
// alive as long as this guard; `tempfile::TempDir` removes its directory when dropped.
_temp_dir: Option<tempfile::TempDir>,
// Path to `schema.json` inside `_temp_dir`, exposed through `path()`.
schema_path: Option<PathBuf>,
}
impl OutputSchemaFile {
pub fn new(schema: Option<&Value>) -> crate::Result<Self> {
match schema {
None => Ok(Self {
_temp_dir: None,
schema_path: None,
}),
Some(value) => {
if !value.is_object() {
return Err(crate::Error::Config(
"output schema must be a JSON object".into(),
));
}
let temp_dir = tempfile::Builder::new()
.prefix("codex-output-schema-")
.tempdir()
.map_err(|e| crate::Error::Config(format!("failed to create temp dir: {e}")))?;
let schema_path = temp_dir.path().join("schema.json");
let bytes = serde_json::to_vec(value).map_err(|e| {
crate::Error::Config(format!("failed to serialize schema: {e}"))
})?;
std::fs::write(&schema_path, bytes)
.map_err(|e| crate::Error::Config(format!("failed to write schema: {e}")))?;
Ok(Self {
schema_path: Some(schema_path),
_temp_dir: Some(temp_dir),
})
}
}
}
pub fn path(&self) -> Option<&std::path::Path> {
self.schema_path.as_deref()
}
}
impl ThreadOptions {
    /// Builds the argument vector for a `codex exec --json` invocation.
    ///
    /// Options with a dedicated CLI flag are emitted as that flag: `--model`, `--sandbox`,
    /// `--cd`, `--add-dir`, `--skip-git-repo-check`, `--ephemeral`, `--image` and
    /// `--local-provider`. The remaining options are emitted as `-c key=value` config
    /// overrides.
    ///
    /// `output_schema`, `max_turns`, `max_budget_tokens`, `hooks` and `default_hook_timeout`
    /// are not rendered here. NOTE(review): presumably the caller handles them — confirm.
    pub fn to_cli_args(&self) -> Vec<String> {
        let mut args = vec!["exec".to_string(), "--json".to_string()];
        if let Some(ref model) = self.model {
            args.extend(["--model".into(), model.clone()]);
        }
        // The Codex CLI's `--sandbox` flag accepts only `read-only`, `workspace-write` and
        // `danger-full-access`. `restricted` is not among them and is rejected by the CLI's
        // argument parser, so `Restricted` maps to `read-only`.
        let sandbox = match self.sandbox {
            SandboxPolicy::Restricted => "read-only",
            SandboxPolicy::WorkspaceWrite => "workspace-write",
            SandboxPolicy::DangerFullAccess => "danger-full-access",
        };
        args.extend(["--sandbox".into(), sandbox.into()]);
        if let Some(ref cwd) = self.working_directory {
            args.extend(["--cd".into(), cwd.display().to_string()]);
        }
        for dir in &self.additional_directories {
            args.extend(["--add-dir".into(), dir.display().to_string()]);
        }
        if self.skip_git_repo_check {
            args.push("--skip-git-repo-check".into());
        }
        if self.ephemeral {
            args.push("--ephemeral".into());
        }
        for img in &self.images {
            args.extend(["--image".into(), img.display().to_string()]);
        }
        if let Some(ref provider) = self.local_provider {
            args.extend(["--local-provider".into(), provider.clone()]);
        }
        // `Never` emits nothing. NOTE(review): presumably this relies on the CLI's default
        // for `exec` — confirm.
        let approval = match self.approval {
            ApprovalPolicy::Never => None,
            ApprovalPolicy::OnRequest => Some("on-request"),
            ApprovalPolicy::UnlessTrusted => Some("untrusted"),
        };
        if let Some(val) = approval {
            args.extend(["-c".into(), format!("approval_policy={val}")]);
        }
        if let Some(ref effort) = self.reasoning_effort {
            let val = match effort {
                ReasoningEffort::Minimal => "minimal",
                ReasoningEffort::Low => "low",
                ReasoningEffort::Medium => "medium",
                ReasoningEffort::High => "high",
                ReasoningEffort::XHigh => "xhigh",
            };
            args.extend(["-c".into(), format!("model_reasoning_effort={val}")]);
        }
        if let Some(network) = self.network_access {
            args.extend([
                "-c".into(),
                format!("sandbox_workspace_write.network_access={network}"),
            ]);
        }
        if let Some(ref ws) = self.web_search {
            let val = match ws {
                WebSearchMode::Disabled => "disabled",
                WebSearchMode::Cached => "cached",
                WebSearchMode::Live => "live",
            };
            args.extend(["-c".into(), format!("web_search={val}")]);
        }
        if let Some(ref prompt) = self.system_prompt {
            // JSON escaping yields a quoted string that is also a valid TOML basic string.
            let escaped = serde_json::to_string(prompt).expect("infallible: String serialization");
            args.extend(["-c".into(), format!("system_prompt={escaped}")]);
        }
        if !self.mcp_servers.is_empty() {
            // NOTE(review): serialization errors are silently skipped here. Also, the value is
            // JSON object syntax, which is not a TOML inline table; if the CLI parses `-c`
            // values as TOML this may be taken as a plain string — confirm against the CLI.
            if let Ok(json) = serde_json::to_string(&self.mcp_servers) {
                args.extend(["-c".into(), format!("mcp_servers={json}")]);
            }
        }
        args
    }
}
impl CodexConfig {
pub fn apply_overrides(&self, args: &mut Vec<String>) {
if let Some(ref profile) = self.profile {
args.extend(["--profile".into(), profile.clone()]);
}
for (key, val) in self.config_overrides.to_cli_pairs() {
args.extend(["-c".into(), format!("{key}={val}")]);
}
}
pub fn to_env(&self) -> std::collections::HashMap<String, String> {
let mut env = self.env.clone();
env.entry("CODEX_INTERNAL_ORIGINATOR_OVERRIDE".into())
.or_insert_with(|| "codex_cli_sdk_rs".into());
env.entry("CI".into()).or_insert_with(|| "true".into());
env.entry("TERM".into()).or_insert_with(|| "xterm".into());
env
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `args` contains an element exactly equal to `needle`.
    fn has_arg(args: &[String], needle: &str) -> bool {
        args.iter().any(|arg| arg == needle)
    }

    #[test]
    fn default_thread_options_cli_args() {
        let args = ThreadOptions::default().to_cli_args();
        assert_eq!(args[0], "exec");
        assert_eq!(args[1], "--json");
        assert!(has_arg(&args, "--sandbox"));
        assert!(has_arg(&args, "workspace-write"));
    }

    #[test]
    fn full_thread_options_cli_args() {
        let args = ThreadOptions::builder()
            .model("o4-mini")
            .sandbox(SandboxPolicy::DangerFullAccess)
            .ephemeral(true)
            .skip_git_repo_check(true)
            .reasoning_effort(ReasoningEffort::High)
            .network_access(true)
            .web_search(WebSearchMode::Live)
            .build()
            .to_cli_args();
        for expected in [
            "--model",
            "o4-mini",
            "danger-full-access",
            "--ephemeral",
            "--skip-git-repo-check",
        ] {
            assert!(has_arg(&args, expected), "missing {expected} in {args:?}");
        }
    }

    #[test]
    fn flatten_json_nested() {
        let overrides = ConfigOverrides::Json(serde_json::json!({
            "sandbox_workspace_write": { "network_access": true }
        }));
        let pairs = overrides.to_cli_pairs();
        assert_eq!(
            pairs,
            vec![(
                "sandbox_workspace_write.network_access".to_string(),
                "true".to_string()
            )]
        );
    }

    #[test]
    fn config_to_env_sets_defaults() {
        let env = CodexConfig::default().to_env();
        assert_eq!(env.get("CI").map(String::as_str), Some("true"));
        assert_eq!(env.get("TERM").map(String::as_str), Some("xterm"));
        assert!(env.contains_key("CODEX_INTERNAL_ORIGINATOR_OVERRIDE"));
    }

    #[test]
    fn output_schema_file_creates_temp() {
        let schema = serde_json::json!({"type": "object", "properties": {}});
        let guard = OutputSchemaFile::new(Some(&schema)).unwrap();
        let path = guard.path().expect("schema path should be set");
        assert!(path.exists());
    }

    #[test]
    fn output_schema_file_rejects_non_object() {
        let schema = serde_json::json!("not an object");
        assert!(OutputSchemaFile::new(Some(&schema)).is_err());
    }

    #[test]
    fn system_prompt_cli_arg() {
        let args = ThreadOptions::builder()
            .system_prompt("You are a helpful assistant")
            .build()
            .to_cli_args();
        assert!(has_arg(&args, "-c"));
        let found = args.iter().any(|arg| {
            arg.contains("system_prompt=") && arg.contains("You are a helpful assistant")
        });
        assert!(found);
    }

    #[test]
    fn system_prompt_with_special_chars_is_escaped() {
        let args = ThreadOptions::builder()
            .system_prompt(r#"Say "hello" and use \n newlines"#)
            .build()
            .to_cli_args();
        let json_value = args
            .iter()
            .find_map(|arg| arg.strip_prefix("system_prompt="))
            .expect("system_prompt arg missing");
        let parsed: String = serde_json::from_str(json_value)
            .expect("system_prompt value should be valid JSON string");
        assert!(parsed.contains('"'));
        assert!(parsed.contains('\\'));
    }

    #[test]
    fn flatten_json_escapes_string_values() {
        let pairs = ConfigOverrides::Json(serde_json::json!({
            "key": "val\"ue with \"quotes\" and \\backslash"
        }))
        .to_cli_pairs();
        assert_eq!(pairs.len(), 1);
        let parsed: String = serde_json::from_str(&pairs[0].1)
            .expect("flattened string value should be valid JSON string");
        assert!(parsed.contains('"'));
    }

    #[test]
    fn mcp_servers_cli_arg() {
        use crate::mcp::McpServerConfig;
        let mut servers = crate::mcp::McpServers::new();
        servers.insert(
            "fs".into(),
            McpServerConfig::new("npx").with_args(["-y", "fs-server"]),
        );
        let args = ThreadOptions::builder()
            .mcp_servers(servers)
            .build()
            .to_cli_args();
        assert!(args.iter().any(|arg| arg.starts_with("mcp_servers=")));
    }

    #[test]
    fn max_turns_not_in_cli_args() {
        let args = ThreadOptions::builder().max_turns(5).build().to_cli_args();
        assert!(args.iter().all(|arg| !arg.contains("max_turns")));
    }

    #[test]
    fn max_budget_tokens_not_in_cli_args() {
        let args = ThreadOptions::builder()
            .max_budget_tokens(10000)
            .build()
            .to_cli_args();
        assert!(args.iter().all(|arg| !arg.contains("max_budget")));
    }

    #[test]
    fn default_hook_timeout_is_30s() {
        let timeout = ThreadOptions::default().default_hook_timeout;
        assert_eq!(timeout, std::time::Duration::from_secs(30));
    }
}