1use async_trait::async_trait;
2use chrono::Utc;
3use serde_json::json;
4use tokio::sync::mpsc;
5use tracing::debug;
6use uuid::Uuid;
7
8use ciab_core::error::{CiabError, CiabResult};
9use ciab_core::traits::agent::AgentProvider;
10use ciab_core::types::agent::{
11 AgentCommand, AgentConfig, AgentHealth, PromptMode, SlashCommand, SlashCommandArg,
12 SlashCommandCategory,
13};
14use ciab_core::types::llm_provider::{AgentLlmCompatibility, LlmProviderKind};
15use ciab_core::types::session::Message;
16use ciab_core::types::stream::{StreamEvent, StreamEventType};
17
/// Agent provider for the `gemini` command-line agent.
///
/// The type is a stateless unit struct: every provider method derives its
/// result from its arguments or from constants. The common traits are derived
/// so the provider can be logged, copied and default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct GeminiProvider;
19
20#[async_trait]
21impl AgentProvider for GeminiProvider {
22 fn name(&self) -> &str {
23 "gemini"
24 }
25
26 fn base_image(&self) -> &str {
27 "ghcr.io/ciab/gemini-sandbox:latest"
28 }
29
30 fn install_commands(&self) -> Vec<String> {
31 vec!["npm install -g @google/gemini-cli".to_string()]
32 }
33
34 fn build_start_command(&self, config: &AgentConfig) -> AgentCommand {
35 let mut args = vec!["--output-format".to_string(), "stream-json".to_string()];
38
39 if let Some(ref model) = config.model {
40 args.push("--model".to_string());
41 args.push(model.clone());
42 }
43
44 if let Some(mode) = config.extra.get("permission_mode").and_then(|v| v.as_str()) {
46 match mode {
47 "auto_approve" | "unrestricted" => {
48 args.push("--yolo".to_string());
49 }
50 "approve_edits" => {
51 args.push("--approval-mode".to_string());
52 args.push("auto_edit".to_string());
53 }
54 _ => {}
57 }
58 }
59
60 if config
62 .extra
63 .get("sandbox")
64 .and_then(|v| v.as_bool())
65 .unwrap_or(false)
66 {
67 args.push("--sandbox".to_string());
68 }
69
70 if config
72 .extra
73 .get("debug")
74 .and_then(|v| v.as_bool())
75 .unwrap_or(false)
76 {
77 args.push("--debug".to_string());
78 }
79
80 if let Some(session_id) = config
82 .extra
83 .get("resume_session_id")
84 .and_then(|v| v.as_str())
85 {
86 args.push("--resume".to_string());
87 args.push(session_id.to_string());
88 }
89
90 if !config.allowed_tools.is_empty() {
92 args.push("--allowed-tools".to_string());
93 args.push(config.allowed_tools.join(","));
94 }
95
96 if let Some(extensions) = config.extra.get("extensions").and_then(|v| v.as_str()) {
98 args.push("--extensions".to_string());
99 args.push(extensions.to_string());
100 }
101
102 AgentCommand {
103 command: "gemini".to_string(),
104 args,
105 env: Default::default(),
106 workdir: None,
107 }
108 }
109
110 fn prompt_mode(&self) -> PromptMode {
111 PromptMode::CliArgument
112 }
113
114 fn required_env_vars(&self) -> Vec<String> {
115 vec!["GOOGLE_API_KEY".to_string()]
116 }
117
118 fn parse_output(&self, sandbox_id: &Uuid, raw: &str) -> Vec<StreamEvent> {
131 let mut events = Vec::new();
132
133 for line in raw.lines() {
134 let line = line.trim();
135 if line.is_empty() {
136 continue;
137 }
138
139 let obj: serde_json::Value = match serde_json::from_str(line) {
140 Ok(v) => v,
141 Err(_) => {
142 events.push(StreamEvent {
143 id: Uuid::new_v4().to_string(),
144 sandbox_id: *sandbox_id,
145 session_id: None,
146 event_type: StreamEventType::LogLine,
147 data: json!({ "line": line }),
148 timestamp: Utc::now(),
149 });
150 continue;
151 }
152 };
153
154 let event_type = obj.get("type").and_then(|t| t.as_str()).unwrap_or("");
155
156 match event_type {
157 "init" | "system" => {
158 events.push(StreamEvent {
159 id: Uuid::new_v4().to_string(),
160 sandbox_id: *sandbox_id,
161 session_id: None,
162 event_type: StreamEventType::Connected,
163 data: json!({
164 "session_id": obj.get("session_id"),
165 "model": obj.get("model"),
166 "cwd": obj.get("cwd"),
167 "tools": obj.get("tools"),
168 }),
169 timestamp: Utc::now(),
170 });
171 }
172
173 "message" | "assistant" => {
174 if let Some(text) = obj.get("text").and_then(|t| t.as_str()) {
177 events.push(StreamEvent {
178 id: Uuid::new_v4().to_string(),
179 sandbox_id: *sandbox_id,
180 session_id: None,
181 event_type: StreamEventType::TextDelta,
182 data: json!({ "text": text }),
183 timestamp: Utc::now(),
184 });
185 }
186 if let Some(content) = obj
188 .get("message")
189 .and_then(|m| m.get("content"))
190 .and_then(|c| c.as_array())
191 {
192 for block in content {
193 let block_type =
194 block.get("type").and_then(|t| t.as_str()).unwrap_or("");
195 match block_type {
196 "text" => {
197 if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
198 events.push(StreamEvent {
199 id: Uuid::new_v4().to_string(),
200 sandbox_id: *sandbox_id,
201 session_id: None,
202 event_type: StreamEventType::TextDelta,
203 data: json!({ "text": text }),
204 timestamp: Utc::now(),
205 });
206 }
207 }
208 "tool_use" => {
209 events.push(StreamEvent {
210 id: Uuid::new_v4().to_string(),
211 sandbox_id: *sandbox_id,
212 session_id: None,
213 event_type: StreamEventType::ToolUseStart,
214 data: json!({
215 "id": block.get("id"),
216 "name": block.get("name"),
217 "input": block.get("input").cloned().unwrap_or(json!({})),
218 }),
219 timestamp: Utc::now(),
220 });
221 }
222 _ => {}
223 }
224 }
225 }
226 if let Some(text) = obj
228 .get("message")
229 .and_then(|m| m.get("content"))
230 .and_then(|c| c.as_str())
231 {
232 events.push(StreamEvent {
233 id: Uuid::new_v4().to_string(),
234 sandbox_id: *sandbox_id,
235 session_id: None,
236 event_type: StreamEventType::TextDelta,
237 data: json!({ "text": text }),
238 timestamp: Utc::now(),
239 });
240 }
241 }
242
243 "tool_use" => {
244 let name = obj
245 .get("name")
246 .or_else(|| obj.get("tool_name"))
247 .and_then(|n| n.as_str())
248 .unwrap_or("unknown");
249 events.push(StreamEvent {
250 id: Uuid::new_v4().to_string(),
251 sandbox_id: *sandbox_id,
252 session_id: None,
253 event_type: StreamEventType::ToolUseStart,
254 data: json!({
255 "id": obj.get("id").or_else(|| obj.get("tool_use_id")),
256 "name": name,
257 "input": obj.get("input").or_else(|| obj.get("args")).cloned().unwrap_or(json!({})),
258 }),
259 timestamp: Utc::now(),
260 });
261 }
262
263 "tool_result" => {
264 events.push(StreamEvent {
265 id: Uuid::new_v4().to_string(),
266 sandbox_id: *sandbox_id,
267 session_id: None,
268 event_type: StreamEventType::ToolResult,
269 data: json!({
270 "tool_use_id": obj.get("tool_use_id").or_else(|| obj.get("id")),
271 "content": obj.get("content").or_else(|| obj.get("output")).or_else(|| obj.get("result")),
272 "is_error": obj.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false),
273 }),
274 timestamp: Utc::now(),
275 });
276 }
277
278 "error" => {
279 events.push(StreamEvent {
280 id: Uuid::new_v4().to_string(),
281 sandbox_id: *sandbox_id,
282 session_id: None,
283 event_type: StreamEventType::ResultError,
284 data: json!({
285 "error_type": "error",
286 "message": obj.get("message").or_else(|| obj.get("error")),
287 }),
288 timestamp: Utc::now(),
289 });
290
291 events.push(StreamEvent {
292 id: Uuid::new_v4().to_string(),
293 sandbox_id: *sandbox_id,
294 session_id: None,
295 event_type: StreamEventType::SessionCompleted,
296 data: json!({
297 "session_id": obj.get("session_id"),
298 "error": true,
299 }),
300 timestamp: Utc::now(),
301 });
302 }
303
304 "result" => {
305 let response_text = obj
307 .get("response")
308 .and_then(|r| r.as_str())
309 .or_else(|| obj.get("result").and_then(|r| r.as_str()));
310
311 if let Some(text) = response_text {
312 events.push(StreamEvent {
313 id: Uuid::new_v4().to_string(),
314 sandbox_id: *sandbox_id,
315 session_id: None,
316 event_type: StreamEventType::TextComplete,
317 data: json!({
318 "text": text,
319 "session_id": obj.get("session_id"),
320 "stats": obj.get("stats"),
321 }),
322 timestamp: Utc::now(),
323 });
324 }
325
326 events.push(StreamEvent {
328 id: Uuid::new_v4().to_string(),
329 sandbox_id: *sandbox_id,
330 session_id: None,
331 event_type: StreamEventType::SessionCompleted,
332 data: json!({
333 "session_id": obj.get("session_id"),
334 "stats": obj.get("stats"),
335 }),
336 timestamp: Utc::now(),
337 });
338 }
339
340 _ => {
341 events.push(StreamEvent {
342 id: Uuid::new_v4().to_string(),
343 sandbox_id: *sandbox_id,
344 session_id: None,
345 event_type: StreamEventType::LogLine,
346 data: obj,
347 timestamp: Utc::now(),
348 });
349 }
350 }
351 }
352
353 events
354 }
355
356 fn validate_config(&self, config: &AgentConfig) -> CiabResult<()> {
357 if config.provider != "gemini" {
358 return Err(CiabError::ConfigValidationError(format!(
359 "expected provider 'gemini', got '{}'",
360 config.provider
361 )));
362 }
363 Ok(())
364 }
365
366 async fn send_message(
367 &self,
368 sandbox_id: &Uuid,
369 session_id: &Uuid,
370 message: &Message,
371 tx: &mpsc::Sender<StreamEvent>,
372 ) -> CiabResult<()> {
373 debug!(
374 sandbox_id = %sandbox_id,
375 session_id = %session_id,
376 "stub: message would be sent via execd"
377 );
378
379 let event = StreamEvent {
380 id: Uuid::new_v4().to_string(),
381 sandbox_id: *sandbox_id,
382 session_id: Some(*session_id),
383 event_type: StreamEventType::TextDelta,
384 data: json!({
385 "text": format!(
386 "stub: message with {} content part(s) would be sent via execd",
387 message.content.len()
388 )
389 }),
390 timestamp: Utc::now(),
391 };
392
393 tx.send(event).await.map_err(|e| {
394 CiabError::AgentCommunicationError(format!("failed to send event: {}", e))
395 })?;
396
397 Ok(())
398 }
399
400 async fn interrupt(&self, _sandbox_id: &Uuid) -> CiabResult<()> {
401 Ok(())
402 }
403
404 async fn health_check(&self, _sandbox_id: &Uuid) -> CiabResult<AgentHealth> {
405 Ok(AgentHealth {
406 healthy: true,
407 status: "ok".into(),
408 uptime_secs: None,
409 })
410 }
411
412 fn slash_commands(&self) -> Vec<SlashCommand> {
413 vec![
414 SlashCommand {
415 name: "clear".into(),
416 description: "Clear conversation history".into(),
417 category: SlashCommandCategory::Session,
418 args: vec![],
419 provider_native: false,
420 },
421 SlashCommand {
422 name: "help".into(),
423 description: "Show available commands".into(),
424 category: SlashCommandCategory::Help,
425 args: vec![],
426 provider_native: false,
427 },
428 SlashCommand {
429 name: "model".into(),
430 description: "Switch model".into(),
431 category: SlashCommandCategory::Agent,
432 args: vec![SlashCommandArg {
433 name: "model".into(),
434 description: "Model name to switch to".into(),
435 required: false,
436 }],
437 provider_native: true,
438 },
439 SlashCommand {
440 name: "stats".into(),
441 description: "Show usage statistics".into(),
442 category: SlashCommandCategory::Session,
443 args: vec![],
444 provider_native: true,
445 },
446 ]
447 }
448
449 fn supported_llm_providers(&self) -> Vec<AgentLlmCompatibility> {
450 vec![AgentLlmCompatibility {
451 agent_provider: "gemini".to_string(),
452 llm_provider_kind: LlmProviderKind::Google,
453 env_var_mapping: [("GOOGLE_API_KEY".to_string(), "{api_key}".to_string())]
454 .into_iter()
455 .collect(),
456 supports_model_override: true,
457 notes: Some("Native provider".to_string()),
458 }]
459 }
460}