//! ciab_agent_codex/lib.rs — Codex CLI agent provider for CIAB.

1use async_trait::async_trait;
2use chrono::Utc;
3use serde_json::json;
4use tokio::sync::mpsc;
5use tracing::debug;
6use uuid::Uuid;
7
8use ciab_core::error::{CiabError, CiabResult};
9use ciab_core::traits::agent::AgentProvider;
10use ciab_core::types::agent::{
11    AgentCommand, AgentConfig, AgentHealth, PromptMode, SlashCommand, SlashCommandArg,
12    SlashCommandCategory,
13};
14use ciab_core::types::llm_provider::{AgentLlmCompatibility, LlmProviderKind};
15use ciab_core::types::session::Message;
16use ciab_core::types::stream::{StreamEvent, StreamEventType};
17
/// Agent provider for the OpenAI Codex CLI.
///
/// Stateless unit struct: all behavior lives in the `AgentProvider` impl.
/// Derives added so the public type is printable, trivially copyable, and
/// constructible via `Default` (all zero-cost for a unit struct).
#[derive(Debug, Clone, Copy, Default)]
pub struct CodexProvider;
19
20#[async_trait]
21impl AgentProvider for CodexProvider {
22    fn name(&self) -> &str {
23        "codex"
24    }
25
26    fn base_image(&self) -> &str {
27        "ghcr.io/ciab/codex-sandbox:latest"
28    }
29
30    fn install_commands(&self) -> Vec<String> {
31        vec!["npm install -g @openai/codex".to_string()]
32    }
33
34    fn build_start_command(&self, config: &AgentConfig) -> AgentCommand {
35        // Codex CLI: `codex --quiet --full-auto "prompt"`
36        // The prompt is appended as a positional arg by the session handler (PromptMode::CliArgument).
37        let mut args = vec!["--quiet".to_string()];
38
39        if let Some(ref model) = config.model {
40            args.push("--model".to_string());
41            args.push(model.clone());
42        }
43
44        // Approval mode mapping.
45        if let Some(mode) = config.extra.get("permission_mode").and_then(|v| v.as_str()) {
46            match mode {
47                "auto_approve" | "unrestricted" => {
48                    args.push("--full-auto".to_string());
49                }
50                "approve_edits" => {
51                    args.push("--auto-edit".to_string());
52                }
53                // approve_all → suggest mode (default)
54                // plan_only → suggest mode (closest equivalent)
55                _ => {}
56            }
57        }
58
59        let mut env: std::collections::HashMap<String, String> = Default::default();
60
61        // LLM provider override
62        if let Some(base_url) = config.extra.get("llm_base_url").and_then(|v| v.as_str()) {
63            env.insert("OPENAI_BASE_URL".to_string(), base_url.to_string());
64        }
65        if let Some(api_key) = config.extra.get("llm_api_key").and_then(|v| v.as_str()) {
66            env.insert("OPENAI_API_KEY".to_string(), api_key.to_string());
67        }
68
69        AgentCommand {
70            command: "codex".to_string(),
71            args,
72            env,
73            workdir: None,
74        }
75    }
76
77    fn prompt_mode(&self) -> PromptMode {
78        PromptMode::CliArgument
79    }
80
81    fn required_env_vars(&self) -> Vec<String> {
82        vec!["OPENAI_API_KEY".to_string()]
83    }
84
85    /// Parse Codex CLI output.
86    ///
87    /// Codex outputs a mix of plain text and JSON. In `--quiet` mode it outputs
88    /// the agent's response. We parse each line looking for JSON events, falling
89    /// back to plain text lines as TextDelta.
90    fn parse_output(&self, sandbox_id: &Uuid, raw: &str) -> Vec<StreamEvent> {
91        let mut events = Vec::new();
92
93        for line in raw.lines() {
94            let line = line.trim();
95            if line.is_empty() {
96                continue;
97            }
98
99            let obj: serde_json::Value = match serde_json::from_str(line) {
100                Ok(v) => v,
101                Err(_) => {
102                    // Plain text output from Codex — treat as text delta.
103                    events.push(StreamEvent {
104                        id: Uuid::new_v4().to_string(),
105                        sandbox_id: *sandbox_id,
106                        session_id: None,
107                        event_type: StreamEventType::TextDelta,
108                        data: json!({ "text": format!("{}\n", line) }),
109                        timestamp: Utc::now(),
110                    });
111                    continue;
112                }
113            };
114
115            let event_type = obj.get("type").and_then(|t| t.as_str()).unwrap_or("");
116
117            match event_type {
118                "system" | "init" => {
119                    events.push(StreamEvent {
120                        id: Uuid::new_v4().to_string(),
121                        sandbox_id: *sandbox_id,
122                        session_id: None,
123                        event_type: StreamEventType::Connected,
124                        data: json!({
125                            "session_id": obj.get("session_id"),
126                            "model": obj.get("model"),
127                        }),
128                        timestamp: Utc::now(),
129                    });
130                }
131
132                "message" | "assistant" | "text" => {
133                    let text = obj
134                        .get("text")
135                        .or_else(|| obj.get("content"))
136                        .or_else(|| obj.get("message"))
137                        .and_then(|t| t.as_str())
138                        .unwrap_or("");
139                    if !text.is_empty() {
140                        events.push(StreamEvent {
141                            id: Uuid::new_v4().to_string(),
142                            sandbox_id: *sandbox_id,
143                            session_id: None,
144                            event_type: StreamEventType::TextDelta,
145                            data: json!({ "text": text }),
146                            timestamp: Utc::now(),
147                        });
148                    }
149                }
150
151                "tool_use" | "tool_call" => {
152                    let name = obj
153                        .get("name")
154                        .or_else(|| obj.get("tool_name"))
155                        .and_then(|n| n.as_str())
156                        .unwrap_or("unknown");
157                    events.push(StreamEvent {
158                        id: Uuid::new_v4().to_string(),
159                        sandbox_id: *sandbox_id,
160                        session_id: None,
161                        event_type: StreamEventType::ToolUseStart,
162                        data: json!({
163                            "id": obj.get("id").or_else(|| obj.get("tool_use_id")),
164                            "name": name,
165                            "input": obj.get("input").or_else(|| obj.get("args")).cloned().unwrap_or(json!({})),
166                        }),
167                        timestamp: Utc::now(),
168                    });
169                }
170
171                "tool_result" => {
172                    events.push(StreamEvent {
173                        id: Uuid::new_v4().to_string(),
174                        sandbox_id: *sandbox_id,
175                        session_id: None,
176                        event_type: StreamEventType::ToolResult,
177                        data: json!({
178                            "tool_use_id": obj.get("tool_use_id").or_else(|| obj.get("id")),
179                            "content": obj.get("content").or_else(|| obj.get("output")),
180                            "is_error": obj.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false),
181                        }),
182                        timestamp: Utc::now(),
183                    });
184                }
185
186                "result" => {
187                    if let Some(text) = obj
188                        .get("result")
189                        .or_else(|| obj.get("response"))
190                        .and_then(|r| r.as_str())
191                    {
192                        events.push(StreamEvent {
193                            id: Uuid::new_v4().to_string(),
194                            sandbox_id: *sandbox_id,
195                            session_id: None,
196                            event_type: StreamEventType::TextComplete,
197                            data: json!({
198                                "text": text,
199                                "session_id": obj.get("session_id"),
200                            }),
201                            timestamp: Utc::now(),
202                        });
203                    }
204
205                    events.push(StreamEvent {
206                        id: Uuid::new_v4().to_string(),
207                        sandbox_id: *sandbox_id,
208                        session_id: None,
209                        event_type: StreamEventType::SessionCompleted,
210                        data: json!({
211                            "session_id": obj.get("session_id"),
212                        }),
213                        timestamp: Utc::now(),
214                    });
215                }
216
217                _ => {
218                    events.push(StreamEvent {
219                        id: Uuid::new_v4().to_string(),
220                        sandbox_id: *sandbox_id,
221                        session_id: None,
222                        event_type: StreamEventType::LogLine,
223                        data: obj,
224                        timestamp: Utc::now(),
225                    });
226                }
227            }
228        }
229
230        events
231    }
232
233    fn validate_config(&self, config: &AgentConfig) -> CiabResult<()> {
234        if config.provider != "codex" {
235            return Err(CiabError::ConfigValidationError(format!(
236                "expected provider 'codex', got '{}'",
237                config.provider
238            )));
239        }
240        Ok(())
241    }
242
243    async fn send_message(
244        &self,
245        sandbox_id: &Uuid,
246        session_id: &Uuid,
247        message: &Message,
248        tx: &mpsc::Sender<StreamEvent>,
249    ) -> CiabResult<()> {
250        debug!(
251            sandbox_id = %sandbox_id,
252            session_id = %session_id,
253            "stub: message would be sent via execd"
254        );
255
256        let event = StreamEvent {
257            id: Uuid::new_v4().to_string(),
258            sandbox_id: *sandbox_id,
259            session_id: Some(*session_id),
260            event_type: StreamEventType::TextDelta,
261            data: json!({
262                "text": format!(
263                    "stub: message with {} content part(s) would be sent via execd",
264                    message.content.len()
265                )
266            }),
267            timestamp: Utc::now(),
268        };
269
270        tx.send(event).await.map_err(|e| {
271            CiabError::AgentCommunicationError(format!("failed to send event: {}", e))
272        })?;
273
274        Ok(())
275    }
276
277    async fn interrupt(&self, _sandbox_id: &Uuid) -> CiabResult<()> {
278        Ok(())
279    }
280
281    async fn health_check(&self, _sandbox_id: &Uuid) -> CiabResult<AgentHealth> {
282        Ok(AgentHealth {
283            healthy: true,
284            status: "ok".into(),
285            uptime_secs: None,
286        })
287    }
288
289    fn slash_commands(&self) -> Vec<SlashCommand> {
290        vec![
291            SlashCommand {
292                name: "clear".into(),
293                description: "Clear conversation history".into(),
294                category: SlashCommandCategory::Session,
295                args: vec![],
296                provider_native: false,
297            },
298            SlashCommand {
299                name: "help".into(),
300                description: "Show available commands".into(),
301                category: SlashCommandCategory::Help,
302                args: vec![],
303                provider_native: false,
304            },
305            SlashCommand {
306                name: "model".into(),
307                description: "Switch model".into(),
308                category: SlashCommandCategory::Agent,
309                args: vec![SlashCommandArg {
310                    name: "model".into(),
311                    description: "Model name to switch to".into(),
312                    required: false,
313                }],
314                provider_native: true,
315            },
316            SlashCommand {
317                name: "approval-mode".into(),
318                description: "Set approval mode (suggest, auto-edit, full-auto)".into(),
319                category: SlashCommandCategory::Agent,
320                args: vec![SlashCommandArg {
321                    name: "mode".into(),
322                    description: "Approval mode".into(),
323                    required: false,
324                }],
325                provider_native: true,
326            },
327        ]
328    }
329
330    fn supported_llm_providers(&self) -> Vec<AgentLlmCompatibility> {
331        vec![
332            AgentLlmCompatibility {
333                agent_provider: "codex".to_string(),
334                llm_provider_kind: LlmProviderKind::OpenAi,
335                env_var_mapping: [("OPENAI_API_KEY".to_string(), "{api_key}".to_string())]
336                    .into_iter()
337                    .collect(),
338                supports_model_override: true,
339                notes: Some("Native provider".to_string()),
340            },
341            AgentLlmCompatibility {
342                agent_provider: "codex".to_string(),
343                llm_provider_kind: LlmProviderKind::OpenRouter,
344                env_var_mapping: [
345                    (
346                        "OPENAI_BASE_URL".to_string(),
347                        "https://openrouter.ai/api/v1".to_string(),
348                    ),
349                    ("OPENAI_API_KEY".to_string(), "{api_key}".to_string()),
350                ]
351                .into_iter()
352                .collect(),
353                supports_model_override: true,
354                notes: Some("Via OPENAI_BASE_URL override".to_string()),
355            },
356            // Ollama: Codex uses the OpenAI-compatible endpoint via OPENAI_BASE_URL.
357            AgentLlmCompatibility {
358                agent_provider: "codex".to_string(),
359                llm_provider_kind: LlmProviderKind::Ollama,
360                env_var_mapping: [
361                    ("OPENAI_BASE_URL".to_string(), "{base_url}/v1".to_string()),
362                    ("OPENAI_API_KEY".to_string(), "ollama".to_string()),
363                ]
364                .into_iter()
365                .collect(),
366                supports_model_override: true,
367                notes: Some("Via OPENAI_BASE_URL → Ollama OpenAI-compatible endpoint".to_string()),
368            },
369        ]
370    }
371}