Skip to main content

ciab_agent_gemini/
lib.rs

1use async_trait::async_trait;
2use chrono::Utc;
3use serde_json::json;
4use tokio::sync::mpsc;
5use tracing::debug;
6use uuid::Uuid;
7
8use ciab_core::error::{CiabError, CiabResult};
9use ciab_core::traits::agent::AgentProvider;
10use ciab_core::types::agent::{
11    AgentCommand, AgentConfig, AgentHealth, PromptMode, SlashCommand, SlashCommandArg,
12    SlashCommandCategory,
13};
14use ciab_core::types::llm_provider::{AgentLlmCompatibility, LlmProviderKind};
15use ciab_core::types::session::Message;
16use ciab_core::types::stream::{StreamEvent, StreamEventType};
17
/// [`AgentProvider`] implementation that drives Google's Gemini CLI (`gemini`).
///
/// The provider holds no state; per-session settings are read from the `AgentConfig`
/// handed to `build_start_command` / `validate_config`, so a single value can be shared
/// or copied freely.
#[derive(Debug, Default, Clone, Copy)]
pub struct GeminiProvider;
19
20#[async_trait]
21impl AgentProvider for GeminiProvider {
22    fn name(&self) -> &str {
23        "gemini"
24    }
25
26    fn base_image(&self) -> &str {
27        "ghcr.io/ciab/gemini-sandbox:latest"
28    }
29
30    fn install_commands(&self) -> Vec<String> {
31        vec!["npm install -g @google/gemini-cli".to_string()]
32    }
33
34    fn build_start_command(&self, config: &AgentConfig) -> AgentCommand {
35        // Gemini CLI: `gemini --output-format stream-json [flags] "prompt"`
36        // The prompt is appended as a positional arg by the session handler (PromptMode::CliArgument).
37        let mut args = vec!["--output-format".to_string(), "stream-json".to_string()];
38
39        if let Some(ref model) = config.model {
40            args.push("--model".to_string());
41            args.push(model.clone());
42        }
43
44        // Permission mode mapping.
45        if let Some(mode) = config.extra.get("permission_mode").and_then(|v| v.as_str()) {
46            match mode {
47                "auto_approve" | "unrestricted" => {
48                    args.push("--yolo".to_string());
49                }
50                "approve_edits" => {
51                    args.push("--approval-mode".to_string());
52                    args.push("auto_edit".to_string());
53                }
54                // approve_all → default (gemini prompts for approval)
55                // plan_only → not directly supported, we pass default
56                _ => {}
57            }
58        }
59
60        // Sandbox mode.
61        if config
62            .extra
63            .get("sandbox")
64            .and_then(|v| v.as_bool())
65            .unwrap_or(false)
66        {
67            args.push("--sandbox".to_string());
68        }
69
70        // Debug mode.
71        if config
72            .extra
73            .get("debug")
74            .and_then(|v| v.as_bool())
75            .unwrap_or(false)
76        {
77            args.push("--debug".to_string());
78        }
79
80        // Resume previous session.
81        if let Some(session_id) = config
82            .extra
83            .get("resume_session_id")
84            .and_then(|v| v.as_str())
85        {
86            args.push("--resume".to_string());
87            args.push(session_id.to_string());
88        }
89
90        // Allowed tools.
91        if !config.allowed_tools.is_empty() {
92            args.push("--allowed-tools".to_string());
93            args.push(config.allowed_tools.join(","));
94        }
95
96        // Extensions.
97        if let Some(extensions) = config.extra.get("extensions").and_then(|v| v.as_str()) {
98            args.push("--extensions".to_string());
99            args.push(extensions.to_string());
100        }
101
102        AgentCommand {
103            command: "gemini".to_string(),
104            args,
105            env: Default::default(),
106            workdir: None,
107        }
108    }
109
110    fn prompt_mode(&self) -> PromptMode {
111        PromptMode::CliArgument
112    }
113
114    fn required_env_vars(&self) -> Vec<String> {
115        vec!["GOOGLE_API_KEY".to_string()]
116    }
117
118    /// Parse Gemini CLI `--output-format stream-json` NDJSON output.
119    ///
120    /// Gemini emits NDJSON with these event types:
121    /// - `init` — session initialization with model, cwd
122    /// - `message` — assistant text content (streamed)
123    /// - `tool_use` — tool invocation started
124    /// - `tool_result` — tool invocation completed with result
125    /// - `error` — error event
126    /// - `result` — final result with response text, stats, session_id
127    ///
128    /// Also handles the alternate format where events use `type: "system"` etc.
129    /// (Gemini CLI format has evolved across versions).
130    fn parse_output(&self, sandbox_id: &Uuid, raw: &str) -> Vec<StreamEvent> {
131        let mut events = Vec::new();
132
133        for line in raw.lines() {
134            let line = line.trim();
135            if line.is_empty() {
136                continue;
137            }
138
139            let obj: serde_json::Value = match serde_json::from_str(line) {
140                Ok(v) => v,
141                Err(_) => {
142                    events.push(StreamEvent {
143                        id: Uuid::new_v4().to_string(),
144                        sandbox_id: *sandbox_id,
145                        session_id: None,
146                        event_type: StreamEventType::LogLine,
147                        data: json!({ "line": line }),
148                        timestamp: Utc::now(),
149                    });
150                    continue;
151                }
152            };
153
154            let event_type = obj.get("type").and_then(|t| t.as_str()).unwrap_or("");
155
156            match event_type {
157                "init" | "system" => {
158                    events.push(StreamEvent {
159                        id: Uuid::new_v4().to_string(),
160                        sandbox_id: *sandbox_id,
161                        session_id: None,
162                        event_type: StreamEventType::Connected,
163                        data: json!({
164                            "session_id": obj.get("session_id"),
165                            "model": obj.get("model"),
166                            "cwd": obj.get("cwd"),
167                            "tools": obj.get("tools"),
168                        }),
169                        timestamp: Utc::now(),
170                    });
171                }
172
173                "message" | "assistant" => {
174                    // Gemini streams assistant text via "message" events.
175                    // Can be a simple text field or content blocks.
176                    if let Some(text) = obj.get("text").and_then(|t| t.as_str()) {
177                        events.push(StreamEvent {
178                            id: Uuid::new_v4().to_string(),
179                            sandbox_id: *sandbox_id,
180                            session_id: None,
181                            event_type: StreamEventType::TextDelta,
182                            data: json!({ "text": text }),
183                            timestamp: Utc::now(),
184                        });
185                    }
186                    // Handle content block array format.
187                    if let Some(content) = obj
188                        .get("message")
189                        .and_then(|m| m.get("content"))
190                        .and_then(|c| c.as_array())
191                    {
192                        for block in content {
193                            let block_type =
194                                block.get("type").and_then(|t| t.as_str()).unwrap_or("");
195                            match block_type {
196                                "text" => {
197                                    if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
198                                        events.push(StreamEvent {
199                                            id: Uuid::new_v4().to_string(),
200                                            sandbox_id: *sandbox_id,
201                                            session_id: None,
202                                            event_type: StreamEventType::TextDelta,
203                                            data: json!({ "text": text }),
204                                            timestamp: Utc::now(),
205                                        });
206                                    }
207                                }
208                                "tool_use" => {
209                                    events.push(StreamEvent {
210                                        id: Uuid::new_v4().to_string(),
211                                        sandbox_id: *sandbox_id,
212                                        session_id: None,
213                                        event_type: StreamEventType::ToolUseStart,
214                                        data: json!({
215                                            "id": block.get("id"),
216                                            "name": block.get("name"),
217                                            "input": block.get("input").cloned().unwrap_or(json!({})),
218                                        }),
219                                        timestamp: Utc::now(),
220                                    });
221                                }
222                                _ => {}
223                            }
224                        }
225                    }
226                    // Handle plain content string.
227                    if let Some(text) = obj
228                        .get("message")
229                        .and_then(|m| m.get("content"))
230                        .and_then(|c| c.as_str())
231                    {
232                        events.push(StreamEvent {
233                            id: Uuid::new_v4().to_string(),
234                            sandbox_id: *sandbox_id,
235                            session_id: None,
236                            event_type: StreamEventType::TextDelta,
237                            data: json!({ "text": text }),
238                            timestamp: Utc::now(),
239                        });
240                    }
241                }
242
243                "tool_use" => {
244                    let name = obj
245                        .get("name")
246                        .or_else(|| obj.get("tool_name"))
247                        .and_then(|n| n.as_str())
248                        .unwrap_or("unknown");
249                    events.push(StreamEvent {
250                        id: Uuid::new_v4().to_string(),
251                        sandbox_id: *sandbox_id,
252                        session_id: None,
253                        event_type: StreamEventType::ToolUseStart,
254                        data: json!({
255                            "id": obj.get("id").or_else(|| obj.get("tool_use_id")),
256                            "name": name,
257                            "input": obj.get("input").or_else(|| obj.get("args")).cloned().unwrap_or(json!({})),
258                        }),
259                        timestamp: Utc::now(),
260                    });
261                }
262
263                "tool_result" => {
264                    events.push(StreamEvent {
265                        id: Uuid::new_v4().to_string(),
266                        sandbox_id: *sandbox_id,
267                        session_id: None,
268                        event_type: StreamEventType::ToolResult,
269                        data: json!({
270                            "tool_use_id": obj.get("tool_use_id").or_else(|| obj.get("id")),
271                            "content": obj.get("content").or_else(|| obj.get("output")).or_else(|| obj.get("result")),
272                            "is_error": obj.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false),
273                        }),
274                        timestamp: Utc::now(),
275                    });
276                }
277
278                "error" => {
279                    events.push(StreamEvent {
280                        id: Uuid::new_v4().to_string(),
281                        sandbox_id: *sandbox_id,
282                        session_id: None,
283                        event_type: StreamEventType::ResultError,
284                        data: json!({
285                            "error_type": "error",
286                            "message": obj.get("message").or_else(|| obj.get("error")),
287                        }),
288                        timestamp: Utc::now(),
289                    });
290
291                    events.push(StreamEvent {
292                        id: Uuid::new_v4().to_string(),
293                        sandbox_id: *sandbox_id,
294                        session_id: None,
295                        event_type: StreamEventType::SessionCompleted,
296                        data: json!({
297                            "session_id": obj.get("session_id"),
298                            "error": true,
299                        }),
300                        timestamp: Utc::now(),
301                    });
302                }
303
304                "result" => {
305                    // Extract response text.
306                    let response_text = obj
307                        .get("response")
308                        .and_then(|r| r.as_str())
309                        .or_else(|| obj.get("result").and_then(|r| r.as_str()));
310
311                    if let Some(text) = response_text {
312                        events.push(StreamEvent {
313                            id: Uuid::new_v4().to_string(),
314                            sandbox_id: *sandbox_id,
315                            session_id: None,
316                            event_type: StreamEventType::TextComplete,
317                            data: json!({
318                                "text": text,
319                                "session_id": obj.get("session_id"),
320                                "stats": obj.get("stats"),
321                            }),
322                            timestamp: Utc::now(),
323                        });
324                    }
325
326                    // Signal session completion.
327                    events.push(StreamEvent {
328                        id: Uuid::new_v4().to_string(),
329                        sandbox_id: *sandbox_id,
330                        session_id: None,
331                        event_type: StreamEventType::SessionCompleted,
332                        data: json!({
333                            "session_id": obj.get("session_id"),
334                            "stats": obj.get("stats"),
335                        }),
336                        timestamp: Utc::now(),
337                    });
338                }
339
340                _ => {
341                    events.push(StreamEvent {
342                        id: Uuid::new_v4().to_string(),
343                        sandbox_id: *sandbox_id,
344                        session_id: None,
345                        event_type: StreamEventType::LogLine,
346                        data: obj,
347                        timestamp: Utc::now(),
348                    });
349                }
350            }
351        }
352
353        events
354    }
355
356    fn validate_config(&self, config: &AgentConfig) -> CiabResult<()> {
357        if config.provider != "gemini" {
358            return Err(CiabError::ConfigValidationError(format!(
359                "expected provider 'gemini', got '{}'",
360                config.provider
361            )));
362        }
363        Ok(())
364    }
365
366    async fn send_message(
367        &self,
368        sandbox_id: &Uuid,
369        session_id: &Uuid,
370        message: &Message,
371        tx: &mpsc::Sender<StreamEvent>,
372    ) -> CiabResult<()> {
373        debug!(
374            sandbox_id = %sandbox_id,
375            session_id = %session_id,
376            "stub: message would be sent via execd"
377        );
378
379        let event = StreamEvent {
380            id: Uuid::new_v4().to_string(),
381            sandbox_id: *sandbox_id,
382            session_id: Some(*session_id),
383            event_type: StreamEventType::TextDelta,
384            data: json!({
385                "text": format!(
386                    "stub: message with {} content part(s) would be sent via execd",
387                    message.content.len()
388                )
389            }),
390            timestamp: Utc::now(),
391        };
392
393        tx.send(event).await.map_err(|e| {
394            CiabError::AgentCommunicationError(format!("failed to send event: {}", e))
395        })?;
396
397        Ok(())
398    }
399
400    async fn interrupt(&self, _sandbox_id: &Uuid) -> CiabResult<()> {
401        Ok(())
402    }
403
404    async fn health_check(&self, _sandbox_id: &Uuid) -> CiabResult<AgentHealth> {
405        Ok(AgentHealth {
406            healthy: true,
407            status: "ok".into(),
408            uptime_secs: None,
409        })
410    }
411
412    fn slash_commands(&self) -> Vec<SlashCommand> {
413        vec![
414            SlashCommand {
415                name: "clear".into(),
416                description: "Clear conversation history".into(),
417                category: SlashCommandCategory::Session,
418                args: vec![],
419                provider_native: false,
420            },
421            SlashCommand {
422                name: "help".into(),
423                description: "Show available commands".into(),
424                category: SlashCommandCategory::Help,
425                args: vec![],
426                provider_native: false,
427            },
428            SlashCommand {
429                name: "model".into(),
430                description: "Switch model".into(),
431                category: SlashCommandCategory::Agent,
432                args: vec![SlashCommandArg {
433                    name: "model".into(),
434                    description: "Model name to switch to".into(),
435                    required: false,
436                }],
437                provider_native: true,
438            },
439            SlashCommand {
440                name: "stats".into(),
441                description: "Show usage statistics".into(),
442                category: SlashCommandCategory::Session,
443                args: vec![],
444                provider_native: true,
445            },
446        ]
447    }
448
449    fn supported_llm_providers(&self) -> Vec<AgentLlmCompatibility> {
450        vec![AgentLlmCompatibility {
451            agent_provider: "gemini".to_string(),
452            llm_provider_kind: LlmProviderKind::Google,
453            env_var_mapping: [("GOOGLE_API_KEY".to_string(), "{api_key}".to_string())]
454                .into_iter()
455                .collect(),
456            supports_model_override: true,
457            notes: Some("Native provider".to_string()),
458        }]
459    }
460}