Skip to main content

codex_cli_sdk/
config.rs

1use crate::hooks::HookMatcher;
2use crate::mcp::McpServers;
3use serde_json::Value;
4use std::path::PathBuf;
5use std::time::Duration;
6use typed_builder::TypedBuilder;
7
8// ── Crate-level config ─────────────────────────────────────────
9
10/// Top-level configuration for the Codex SDK.
11#[derive(Clone, TypedBuilder)]
12pub struct CodexConfig {
13    /// Path to the `codex` CLI binary. If None, auto-detected via discovery.
14    #[builder(default, setter(strip_option, into))]
15    pub cli_path: Option<PathBuf>,
16
17    /// Environment variables to pass to the Codex CLI process.
18    #[builder(default)]
19    pub env: std::collections::HashMap<String, String>,
20
21    /// TOML config overrides.
22    #[builder(default)]
23    pub config_overrides: ConfigOverrides,
24
25    /// Config profile name (passed via `--profile`).
26    #[builder(default, setter(strip_option, into))]
27    pub profile: Option<String>,
28
29    /// Timeout for CLI process spawn + session init.
30    #[builder(default_code = "Some(Duration::from_secs(30))")]
31    pub connect_timeout: Option<Duration>,
32
33    /// Timeout for graceful shutdown.
34    #[builder(default_code = "Some(Duration::from_secs(10))")]
35    pub close_timeout: Option<Duration>,
36
37    /// Timeout for `codex --version` check.
38    #[builder(default_code = "Some(Duration::from_secs(5))")]
39    pub version_check_timeout: Option<Duration>,
40
41    /// Stderr callback — invoked with each line of stderr from the CLI.
42    #[builder(default, setter(strip_option))]
43    pub stderr_callback: Option<StderrCallback>,
44}
45
/// Shared, thread-safe handler for individual stderr lines from the Codex CLI.
pub type StderrCallback = std::sync::Arc<dyn Fn(&str) + Sync + Send>;

48impl std::fmt::Debug for CodexConfig {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        f.debug_struct("CodexConfig")
51            .field("cli_path", &self.cli_path)
52            .field("env", &self.env)
53            .field("config_overrides", &self.config_overrides)
54            .field("profile", &self.profile)
55            .field("connect_timeout", &self.connect_timeout)
56            .field("close_timeout", &self.close_timeout)
57            .field("version_check_timeout", &self.version_check_timeout)
58            .field(
59                "stderr_callback",
60                &self.stderr_callback.as_ref().map(|_| "..."),
61            )
62            .finish()
63    }
64}
65
66impl Default for CodexConfig {
67    fn default() -> Self {
68        Self::builder().build()
69    }
70}
71
72// ── Per-thread options ─────────────────────────────────────────
73
74/// Options for a single thread (conversation).
75#[derive(Clone, TypedBuilder)]
76pub struct ThreadOptions {
77    /// Working directory for the thread.
78    #[builder(default, setter(strip_option, into))]
79    pub working_directory: Option<PathBuf>,
80
81    /// Model to use (e.g., "gpt-5-codex", "o4-mini").
82    #[builder(default, setter(strip_option, into))]
83    pub model: Option<String>,
84
85    /// Sandbox policy.
86    #[builder(default)]
87    pub sandbox: SandboxPolicy,
88
89    /// Approval policy.
90    #[builder(default)]
91    pub approval: ApprovalPolicy,
92
93    /// Additional writable directories (passed via `--add-dir`).
94    #[builder(default)]
95    pub additional_directories: Vec<PathBuf>,
96
97    /// Skip git repository check.
98    #[builder(default)]
99    pub skip_git_repo_check: bool,
100
101    /// Reasoning effort level.
102    #[builder(default, setter(strip_option))]
103    pub reasoning_effort: Option<ReasoningEffort>,
104
105    /// Enable network access in sandbox.
106    #[builder(default, setter(strip_option))]
107    pub network_access: Option<bool>,
108
109    /// Web search mode.
110    #[builder(default, setter(strip_option))]
111    pub web_search: Option<WebSearchMode>,
112
113    /// JSON Schema for structured output.
114    #[builder(default, setter(strip_option))]
115    pub output_schema: Option<OutputSchema>,
116
117    /// Ephemeral mode — don't persist session to disk.
118    #[builder(default)]
119    pub ephemeral: bool,
120
121    /// Image file paths to include with the prompt.
122    #[builder(default)]
123    pub images: Vec<PathBuf>,
124
125    /// Use local/OSS provider (lmstudio, ollama).
126    #[builder(default, setter(strip_option, into))]
127    pub local_provider: Option<String>,
128
129    // ── Feature-parity fields (Gap 1, 2, 4, 5) ───────────────
130    /// System prompt override — passed to CLI via `-c system_prompt="..."`.
131    #[builder(default, setter(strip_option, into))]
132    pub system_prompt: Option<String>,
133
134    /// SDK-enforced maximum number of turns. The stream closes after this many
135    /// `TurnCompleted` events.
136    #[builder(default, setter(strip_option))]
137    pub max_turns: Option<u32>,
138
139    /// SDK-enforced token budget. The stream closes when cumulative
140    /// `Usage.output_tokens` exceeds this value.
141    #[builder(default, setter(strip_option))]
142    pub max_budget_tokens: Option<u64>,
143
144    /// MCP server configurations. Serialized to CLI config overrides.
145    #[builder(default)]
146    pub mcp_servers: McpServers,
147
148    /// Hook matchers — evaluated in order on each stream event.
149    #[builder(default)]
150    pub hooks: Vec<HookMatcher>,
151
152    /// Default timeout for hook callbacks (applied when `HookMatcher.timeout` is `None`).
153    #[builder(default_code = "Duration::from_secs(30)")]
154    pub default_hook_timeout: Duration,
155}
156
157impl std::fmt::Debug for ThreadOptions {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        f.debug_struct("ThreadOptions")
160            .field("working_directory", &self.working_directory)
161            .field("model", &self.model)
162            .field("sandbox", &self.sandbox)
163            .field("approval", &self.approval)
164            .field("additional_directories", &self.additional_directories)
165            .field("skip_git_repo_check", &self.skip_git_repo_check)
166            .field("reasoning_effort", &self.reasoning_effort)
167            .field("network_access", &self.network_access)
168            .field("web_search", &self.web_search)
169            .field("output_schema", &self.output_schema)
170            .field("ephemeral", &self.ephemeral)
171            .field("images", &self.images)
172            .field("local_provider", &self.local_provider)
173            .field("system_prompt", &self.system_prompt)
174            .field("max_turns", &self.max_turns)
175            .field("max_budget_tokens", &self.max_budget_tokens)
176            .field("mcp_servers", &self.mcp_servers)
177            .field("hooks", &self.hooks)
178            .field("default_hook_timeout", &self.default_hook_timeout)
179            .finish()
180    }
181}
182
183impl Default for ThreadOptions {
184    fn default() -> Self {
185        Self::builder().build()
186    }
187}
188
// ── Enums ──────────────────────────────────────────────────────

/// Sandbox isolation level.
///
/// This is a plain fieldless value, so it derives `Copy`, equality and
/// hashing alongside `Debug`/`Clone`/`Default`. Callers can then compare
/// policies, store them in sets, and pass them by value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SandboxPolicy {
    /// Read-only access (most restrictive).
    Restricted,
    /// Workspace directory is writable (default).
    #[default]
    WorkspaceWrite,
    /// Full filesystem access (DANGEROUS).
    DangerFullAccess,
}

/// When to ask for user approval of tool calls.
///
/// A fieldless value type: derives `Copy`, equality and hashing in addition
/// to `Debug`/`Clone`/`Default` so policies can be compared and stored directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ApprovalPolicy {
    /// Auto-approve everything (default for exec mode).
    #[default]
    Never,
    /// Model decides when to ask.
    OnRequest,
    /// Only auto-approve known-safe read-only commands.
    UnlessTrusted,
}

/// Reasoning effort level, emitted as the `model_reasoning_effort` config
/// override.
///
/// A fieldless value type: derives `Copy`, equality and hashing in addition
/// to `Debug`/`Clone` so levels can be compared and stored directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReasoningEffort {
    Minimal,
    Low,
    Medium,
    High,
    XHigh,
}

/// Web search configuration, emitted as the `web_search` config override.
///
/// A fieldless value type: derives `Copy`, equality and hashing in addition
/// to `Debug`/`Clone` so modes can be compared and stored directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WebSearchMode {
    Disabled,
    Cached,
    Live,
}

233/// Output schema — either an inline JSON object or a file path.
234#[derive(Debug, Clone)]
235pub enum OutputSchema {
236    /// Inline JSON schema (written to temp file automatically).
237    Inline(Value),
238    /// Path to an existing schema file.
239    File(PathBuf),
240}
241
242/// Config overrides — either flat key-value pairs or nested JSON.
243#[derive(Debug, Clone, Default)]
244pub enum ConfigOverrides {
245    #[default]
246    None,
247    /// Flat key-value pairs: `[("model", "o4-mini")]`
248    Flat(Vec<(String, String)>),
249    /// Nested JSON object — recursively flattened to dot-notation.
250    Json(Value),
251}
252
253impl ConfigOverrides {
254    /// Flatten to CLI `-c key=value` pairs.
255    pub fn to_cli_pairs(&self) -> Vec<(String, String)> {
256        match self {
257            Self::None => vec![],
258            Self::Flat(pairs) => pairs.clone(),
259            Self::Json(value) => {
260                let mut result = vec![];
261                flatten_json("", value, &mut result);
262                result
263            }
264        }
265    }
266}
267
268/// Recursively flatten nested JSON to dot-notation key-value pairs.
269fn flatten_json(prefix: &str, value: &Value, out: &mut Vec<(String, String)>) {
270    match value {
271        Value::Object(map) => {
272            for (key, val) in map {
273                let full_key = if prefix.is_empty() {
274                    key.clone()
275                } else {
276                    format!("{prefix}.{key}")
277                };
278                flatten_json(&full_key, val, out);
279            }
280        }
281        Value::Array(arr) => {
282            let formatted: Vec<String> = arr
283                .iter()
284                .map(|v| match v {
285                    // serde_json::to_string produces a properly escaped JSON string.
286                    Value::String(s) => {
287                        serde_json::to_string(s).expect("infallible: String serialization")
288                    }
289                    other => other.to_string(),
290                })
291                .collect();
292            out.push((prefix.to_string(), format!("[{}]", formatted.join(", "))));
293        }
294        // The Codex CLI `-c` flag expects strings wrapped in double quotes
295        // (e.g. `key="value"`), but bare booleans and numbers (e.g. `key=true`).
296        // Use serde_json to produce a properly escaped JSON string literal.
297        Value::String(s) => out.push((
298            prefix.to_string(),
299            serde_json::to_string(s).expect("infallible: String serialization"),
300        )),
301        Value::Number(n) => out.push((prefix.to_string(), n.to_string())),
302        Value::Bool(b) => out.push((prefix.to_string(), b.to_string())),
303        Value::Null => {}
304    }
305}
306
307// ── Per-turn options ───────────────────────────────────────────
308
309/// Options for a single turn (run/run_streamed call).
310#[derive(Debug, Default)]
311pub struct TurnOptions {
312    /// JSON Schema for structured output (overrides `ThreadOptions.output_schema`).
313    pub output_schema: Option<Value>,
314    /// Cancellation token — abort the turn mid-stream.
315    pub cancel: Option<tokio_util::sync::CancellationToken>,
316}
317
318// ── Output schema temp file (RAII guard) ───────────────────────
319
320/// Manages a temp file for inline JSON schemas.
321pub(crate) struct OutputSchemaFile {
322    _temp_dir: Option<tempfile::TempDir>,
323    schema_path: Option<PathBuf>,
324}
325
326impl OutputSchemaFile {
327    pub fn new(schema: Option<&Value>) -> crate::Result<Self> {
328        match schema {
329            None => Ok(Self {
330                _temp_dir: None,
331                schema_path: None,
332            }),
333            Some(value) => {
334                if !value.is_object() {
335                    return Err(crate::Error::Config(
336                        "output schema must be a JSON object".into(),
337                    ));
338                }
339                let temp_dir = tempfile::Builder::new()
340                    .prefix("codex-output-schema-")
341                    .tempdir()
342                    .map_err(|e| crate::Error::Config(format!("failed to create temp dir: {e}")))?;
343                let schema_path = temp_dir.path().join("schema.json");
344                let bytes = serde_json::to_vec(value).map_err(|e| {
345                    crate::Error::Config(format!("failed to serialize schema: {e}"))
346                })?;
347                std::fs::write(&schema_path, bytes)
348                    .map_err(|e| crate::Error::Config(format!("failed to write schema: {e}")))?;
349                Ok(Self {
350                    schema_path: Some(schema_path),
351                    _temp_dir: Some(temp_dir),
352                })
353            }
354        }
355    }
356
357    pub fn path(&self) -> Option<&std::path::Path> {
358        self.schema_path.as_deref()
359    }
360}
361
362// ── CLI argument generation ────────────────────────────────────
363
364impl ThreadOptions {
365    /// Convert thread options to CLI arguments for `codex exec`.
366    pub fn to_cli_args(&self) -> Vec<String> {
367        let mut args = vec!["exec".to_string(), "--json".to_string()];
368
369        if let Some(ref model) = self.model {
370            args.extend(["--model".into(), model.clone()]);
371        }
372
373        match &self.sandbox {
374            SandboxPolicy::Restricted => {
375                args.extend(["--sandbox".into(), "restricted".into()]);
376            }
377            SandboxPolicy::WorkspaceWrite => {
378                args.extend(["--sandbox".into(), "workspace-write".into()]);
379            }
380            SandboxPolicy::DangerFullAccess => {
381                args.extend(["--sandbox".into(), "danger-full-access".into()]);
382            }
383        }
384
385        if let Some(ref cwd) = self.working_directory {
386            args.extend(["--cd".into(), cwd.display().to_string()]);
387        }
388
389        for dir in &self.additional_directories {
390            args.extend(["--add-dir".into(), dir.display().to_string()]);
391        }
392
393        if self.skip_git_repo_check {
394            args.push("--skip-git-repo-check".into());
395        }
396
397        if self.ephemeral {
398            args.push("--ephemeral".into());
399        }
400
401        for img in &self.images {
402            args.extend(["--image".into(), img.display().to_string()]);
403        }
404
405        if let Some(ref provider) = self.local_provider {
406            args.extend(["--local-provider".into(), provider.clone()]);
407        }
408
409        match &self.approval {
410            ApprovalPolicy::Never => {}
411            ApprovalPolicy::OnRequest => {
412                args.extend(["-c".into(), "approval_policy=on-request".into()]);
413            }
414            ApprovalPolicy::UnlessTrusted => {
415                args.extend(["-c".into(), "approval_policy=untrusted".into()]);
416            }
417        }
418
419        if let Some(ref effort) = self.reasoning_effort {
420            let val = match effort {
421                ReasoningEffort::Minimal => "minimal",
422                ReasoningEffort::Low => "low",
423                ReasoningEffort::Medium => "medium",
424                ReasoningEffort::High => "high",
425                ReasoningEffort::XHigh => "xhigh",
426            };
427            args.extend(["-c".into(), format!("model_reasoning_effort={val}")]);
428        }
429
430        if let Some(network) = self.network_access {
431            args.extend([
432                "-c".into(),
433                format!("sandbox_workspace_write.network_access={network}"),
434            ]);
435        }
436
437        if let Some(ref ws) = self.web_search {
438            let val = match ws {
439                WebSearchMode::Disabled => "disabled",
440                WebSearchMode::Cached => "cached",
441                WebSearchMode::Live => "live",
442            };
443            args.extend(["-c".into(), format!("web_search={val}")]);
444        }
445
446        // System prompt → -c system_prompt="..."
447        // serde_json::to_string produces a properly escaped JSON string literal,
448        // handling embedded quotes, backslashes, newlines, etc.
449        if let Some(ref prompt) = self.system_prompt {
450            let escaped = serde_json::to_string(prompt).expect("infallible: String serialization");
451            args.extend(["-c".into(), format!("system_prompt={escaped}")]);
452        }
453
454        // MCP servers → -c mcp_servers=<json>
455        // Note: max_turns and max_budget_tokens are SDK-enforced, not CLI args.
456        if !self.mcp_servers.is_empty() {
457            if let Ok(json) = serde_json::to_string(&self.mcp_servers) {
458                args.extend(["-c".into(), format!("mcp_servers={json}")]);
459            }
460        }
461
462        args
463    }
464}
465
466impl CodexConfig {
467    /// Merge crate-level config overrides into CLI args.
468    pub fn apply_overrides(&self, args: &mut Vec<String>) {
469        if let Some(ref profile) = self.profile {
470            args.extend(["--profile".into(), profile.clone()]);
471        }
472        for (key, val) in self.config_overrides.to_cli_pairs() {
473            args.extend(["-c".into(), format!("{key}={val}")]);
474        }
475    }
476
477    /// Build environment variables for the subprocess.
478    pub fn to_env(&self) -> std::collections::HashMap<String, String> {
479        let mut env = self.env.clone();
480        env.entry("CODEX_INTERNAL_ORIGINATOR_OVERRIDE".into())
481            .or_insert_with(|| "codex_cli_sdk_rs".into());
482        env.entry("CI".into()).or_insert_with(|| "true".into());
483        env.entry("TERM".into()).or_insert_with(|| "xterm".into());
484        env
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether `needle` appears as a whole argument in `args`.
    fn has_arg(args: &[String], needle: &str) -> bool {
        args.iter().any(|arg| arg == needle)
    }

    /// The first argument that begins with `prefix`, if any.
    fn arg_with_prefix<'a>(args: &'a [String], prefix: &str) -> Option<&'a String> {
        args.iter().find(|arg| arg.starts_with(prefix))
    }

    #[test]
    fn default_thread_options_cli_args() {
        let args = ThreadOptions::default().to_cli_args();
        assert_eq!(args[0], "exec");
        assert_eq!(args[1], "--json");
        assert!(has_arg(&args, "--sandbox"));
        assert!(has_arg(&args, "workspace-write"));
    }

    #[test]
    fn full_thread_options_cli_args() {
        let args = ThreadOptions::builder()
            .model("o4-mini")
            .sandbox(SandboxPolicy::DangerFullAccess)
            .ephemeral(true)
            .skip_git_repo_check(true)
            .reasoning_effort(ReasoningEffort::High)
            .network_access(true)
            .web_search(WebSearchMode::Live)
            .build()
            .to_cli_args();

        for expected in [
            "--model",
            "o4-mini",
            "danger-full-access",
            "--ephemeral",
            "--skip-git-repo-check",
        ] {
            assert!(has_arg(&args, expected), "missing {expected}");
        }
    }

    #[test]
    fn flatten_json_nested() {
        let pairs = ConfigOverrides::Json(serde_json::json!({
            "sandbox_workspace_write": {
                "network_access": true
            }
        }))
        .to_cli_pairs();
        assert_eq!(
            pairs,
            vec![(
                "sandbox_workspace_write.network_access".to_string(),
                "true".to_string()
            )]
        );
    }

    #[test]
    fn config_to_env_sets_defaults() {
        let env = CodexConfig::default().to_env();
        assert_eq!(env.get("CI").map(String::as_str), Some("true"));
        assert_eq!(env.get("TERM").map(String::as_str), Some("xterm"));
        assert!(env.contains_key("CODEX_INTERNAL_ORIGINATOR_OVERRIDE"));
    }

    #[test]
    fn output_schema_file_creates_temp() {
        let schema = serde_json::json!({"type": "object", "properties": {}});
        let guard = OutputSchemaFile::new(Some(&schema)).unwrap();
        let path = guard.path().expect("inline schema should produce a file");
        assert!(path.exists());
    }

    #[test]
    fn output_schema_file_rejects_non_object() {
        let schema = serde_json::json!("not an object");
        assert!(OutputSchemaFile::new(Some(&schema)).is_err());
    }

    #[test]
    fn system_prompt_cli_arg() {
        let args = ThreadOptions::builder()
            .system_prompt("You are a helpful assistant")
            .build()
            .to_cli_args();
        assert!(has_arg(&args, "-c"));
        let found = args.iter().any(|arg| {
            arg.contains("system_prompt=") && arg.contains("You are a helpful assistant")
        });
        assert!(found);
    }

    #[test]
    fn system_prompt_with_special_chars_is_escaped() {
        let args = ThreadOptions::builder()
            .system_prompt(r#"Say "hello" and use \n newlines"#)
            .build()
            .to_cli_args();
        let arg = arg_with_prefix(&args, "system_prompt=").expect("system_prompt arg missing");
        // The value must be a valid JSON string literal (no raw unescaped quotes).
        let literal = &arg["system_prompt=".len()..];
        let decoded: String = serde_json::from_str(literal)
            .expect("system_prompt value should be valid JSON string");
        assert!(decoded.contains('"'));
        assert!(decoded.contains('\\'));
    }

    #[test]
    fn flatten_json_escapes_string_values() {
        let pairs = ConfigOverrides::Json(serde_json::json!({
            "key": "val\"ue with \"quotes\" and \\backslash"
        }))
        .to_cli_pairs();
        assert_eq!(pairs.len(), 1);
        // The value must be a valid JSON string literal.
        let decoded: String = serde_json::from_str(&pairs[0].1)
            .expect("flattened string value should be valid JSON string");
        assert!(decoded.contains('"'));
    }

    #[test]
    fn mcp_servers_cli_arg() {
        use crate::mcp::McpServerConfig;

        let mut servers = crate::mcp::McpServers::new();
        servers.insert(
            "fs".into(),
            McpServerConfig::new("npx").with_args(["-y", "fs-server"]),
        );

        let args = ThreadOptions::builder().mcp_servers(servers).build().to_cli_args();
        assert!(arg_with_prefix(&args, "mcp_servers=").is_some());
    }

    #[test]
    fn max_turns_not_in_cli_args() {
        let args = ThreadOptions::builder().max_turns(5).build().to_cli_args();
        // max_turns is enforced by the SDK, so it must never reach the CLI.
        assert!(args.iter().all(|arg| !arg.contains("max_turns")));
    }

    #[test]
    fn max_budget_tokens_not_in_cli_args() {
        let args = ThreadOptions::builder()
            .max_budget_tokens(10000)
            .build()
            .to_cli_args();
        assert!(args.iter().all(|arg| !arg.contains("max_budget")));
    }

    #[test]
    fn default_hook_timeout_is_30s() {
        let timeout = ThreadOptions::default().default_hook_timeout;
        assert_eq!(timeout, std::time::Duration::from_secs(30));
    }
}