Skip to main content

zag_agent/providers/
gemini.rs

1// provider-updated: 2026-04-05
2use crate::agent::{Agent, ModelSize};
3use crate::output::AgentOutput;
4use crate::providers::common::CommonAgentState;
5use crate::session_log::{
6    BackfilledSession, HistoricalLogAdapter, LiveLogAdapter, LiveLogContext, LogCompleteness,
7    LogEventKind, LogSourceKind, SessionLogMetadata, SessionLogWriter,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use log::info;
12use std::collections::HashSet;
13use tokio::fs;
14use tokio::process::Command;
15
16/// Return the Gemini tmp directory: `~/.gemini/tmp/`.
17pub fn tmp_dir() -> Option<std::path::PathBuf> {
18    dirs::home_dir().map(|h| h.join(".gemini/tmp"))
19}
20
/// Model name used when no explicit model has been configured.
///
/// `"auto"` is a sentinel: when the configured model is `"auto"` (or empty), no
/// `--model` flag is passed to the Gemini CLI, so the CLI picks the model itself.
pub const DEFAULT_MODEL: &str = "auto";
22
/// Model identifiers advertised by this provider (returned from `available_models`).
///
/// The first entry is the `"auto"` sentinel (see [`DEFAULT_MODEL`]); the rest are
/// concrete Gemini model names passed through verbatim via `--model`.
pub const AVAILABLE_MODELS: &[&str] = &[
    "auto",
    "gemini-3.1-pro-preview",
    "gemini-3.1-flash-lite-preview",
    "gemini-3-pro-preview",
    "gemini-3-flash-preview",
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
];
33
/// Agent implementation that drives the `gemini` command-line tool.
pub struct Gemini {
    /// Shared provider configuration: model, system prompt, permission skipping,
    /// extra include directories, output format, and base path.
    pub common: CommonAgentState,
}
37
/// Live log adapter that follows the Gemini CLI chat file written during the current run.
///
/// The chat file is located lazily under [`tmp_dir`] and re-read on every poll.
pub struct GeminiLiveLogAdapter {
    /// Context of the run being logged; its `started_at` is used to ignore chat files
    /// last modified before the run began.
    ctx: LiveLogContext,
    /// Chat file being followed, once one has been discovered.
    session_path: Option<std::path::PathBuf>,
    /// Message ids already handled, so each chat message is emitted at most once.
    emitted_message_ids: std::collections::HashSet<String>,
}
43
/// Historical log adapter that imports existing Gemini CLI chat files found under [`tmp_dir`].
pub struct GeminiHistoricalLogAdapter;
45
46impl Gemini {
47    pub fn new() -> Self {
48        Self {
49            common: CommonAgentState::new(DEFAULT_MODEL),
50        }
51    }
52
53    async fn write_system_file(&self) -> Result<()> {
54        let base = self.common.get_base_path();
55        log::debug!("Writing Gemini system file to {}", base.display());
56        let gemini_dir = base.join(".gemini");
57        fs::create_dir_all(&gemini_dir).await?;
58        fs::write(gemini_dir.join("system.md"), &self.common.system_prompt).await?;
59        Ok(())
60    }
61
62    /// Build the argument list for a run/exec invocation.
63    fn build_run_args(&self, interactive: bool, prompt: Option<&str>) -> Vec<String> {
64        let mut args = Vec::new();
65
66        if self.common.skip_permissions {
67            args.extend(["--approval-mode", "yolo"].map(String::from));
68        }
69
70        if !self.common.model.is_empty() && self.common.model != "auto" {
71            args.extend(["--model".to_string(), self.common.model.clone()]);
72        }
73
74        for dir in &self.common.add_dirs {
75            args.extend(["--include-directories".to_string(), dir.clone()]);
76        }
77
78        if !interactive && let Some(ref format) = self.common.output_format {
79            args.extend(["--output-format".to_string(), format.clone()]);
80        }
81
82        // Note: Gemini CLI does not support --max-turns as a CLI flag.
83        // Max turns is configured via settings.json (maxSessionTurns).
84        // The value is stored but not passed as an argument.
85
86        if let Some(p) = prompt {
87            args.push(p.to_string());
88        }
89
90        args
91    }
92
93    /// Create a `Command` either directly or wrapped in sandbox.
94    fn make_command(&self, agent_args: Vec<String>) -> Command {
95        self.common.make_command("gemini", agent_args)
96    }
97
98    async fn execute(
99        &self,
100        interactive: bool,
101        prompt: Option<&str>,
102    ) -> Result<Option<AgentOutput>> {
103        if !self.common.system_prompt.is_empty() {
104            log::debug!(
105                "Gemini system prompt (written to system.md): {}",
106                self.common.system_prompt
107            );
108            self.write_system_file().await?;
109        }
110
111        let agent_args = self.build_run_args(interactive, prompt);
112        log::debug!("Gemini command: gemini {}", agent_args.join(" "));
113        if let Some(p) = prompt {
114            log::debug!("Gemini user prompt: {}", p);
115        }
116        let mut cmd = self.make_command(agent_args);
117
118        if !self.common.system_prompt.is_empty() {
119            cmd.env("GEMINI_SYSTEM_MD", "true");
120        }
121
122        if interactive {
123            CommonAgentState::run_interactive_command(&mut cmd, "Gemini").await?;
124            Ok(None)
125        } else {
126            self.common
127                .run_non_interactive_simple(&mut cmd, "Gemini")
128                .await
129        }
130    }
131}
132
// Unit tests for this provider live in the sibling file `gemini_tests.rs`.
#[cfg(test)]
#[path = "gemini_tests.rs"]
mod tests;
136
137impl Default for Gemini {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143impl GeminiLiveLogAdapter {
144    pub fn new(ctx: LiveLogContext) -> Self {
145        Self {
146            ctx,
147            session_path: None,
148            emitted_message_ids: HashSet::new(),
149        }
150    }
151
152    fn discover_session_path(&self) -> Option<std::path::PathBuf> {
153        let gemini_tmp = tmp_dir()?;
154        let mut best: Option<(std::time::SystemTime, std::path::PathBuf)> = None;
155        let projects = std::fs::read_dir(gemini_tmp).ok()?;
156        for project in projects.flatten() {
157            let chats = project.path().join("chats");
158            let files = std::fs::read_dir(chats).ok()?;
159            for file in files.flatten() {
160                let path = file.path();
161                let metadata = file.metadata().ok()?;
162                let modified = metadata.modified().ok()?;
163                let started_at = std::time::SystemTime::UNIX_EPOCH
164                    + std::time::Duration::from_secs(self.ctx.started_at.timestamp().max(0) as u64);
165                if modified < started_at {
166                    continue;
167                }
168                if best
169                    .as_ref()
170                    .map(|(current, _)| modified > *current)
171                    .unwrap_or(true)
172                {
173                    best = Some((modified, path));
174                }
175            }
176        }
177        best.map(|(_, path)| path)
178    }
179}
180
181#[async_trait]
182impl LiveLogAdapter for GeminiLiveLogAdapter {
183    async fn poll(&mut self, writer: &SessionLogWriter) -> Result<()> {
184        if self.session_path.is_none() {
185            self.session_path = self.discover_session_path();
186            if let Some(path) = &self.session_path {
187                writer.add_source_path(path.to_string_lossy().to_string())?;
188            }
189        }
190        let Some(path) = self.session_path.as_ref() else {
191            return Ok(());
192        };
193        let content = match std::fs::read_to_string(path) {
194            Ok(content) => content,
195            Err(_) => return Ok(()),
196        };
197        let json: serde_json::Value = match serde_json::from_str(&content) {
198            Ok(json) => json,
199            Err(_) => {
200                writer.emit(
201                    LogSourceKind::ProviderFile,
202                    LogEventKind::ParseWarning {
203                        message: "Failed to parse Gemini chat file".to_string(),
204                        raw: None,
205                    },
206                )?;
207                return Ok(());
208            }
209        };
210        if let Some(session_id) = json.get("sessionId").and_then(|value| value.as_str()) {
211            writer.set_provider_session_id(Some(session_id.to_string()))?;
212        }
213        if let Some(messages) = json.get("messages").and_then(|value| value.as_array()) {
214            for message in messages {
215                let message_id = message
216                    .get("id")
217                    .and_then(|value| value.as_str())
218                    .unwrap_or_default()
219                    .to_string();
220                if message_id.is_empty() || !self.emitted_message_ids.insert(message_id.clone()) {
221                    continue;
222                }
223                match message.get("type").and_then(|value| value.as_str()) {
224                    Some("user") => writer.emit(
225                        LogSourceKind::ProviderFile,
226                        LogEventKind::UserMessage {
227                            role: "user".to_string(),
228                            content: message
229                                .get("content")
230                                .and_then(|value| value.as_str())
231                                .unwrap_or_default()
232                                .to_string(),
233                            message_id: Some(message_id.clone()),
234                        },
235                    )?,
236                    Some("gemini") => {
237                        writer.emit(
238                            LogSourceKind::ProviderFile,
239                            LogEventKind::AssistantMessage {
240                                content: message
241                                    .get("content")
242                                    .and_then(|value| value.as_str())
243                                    .unwrap_or_default()
244                                    .to_string(),
245                                message_id: Some(message_id.clone()),
246                            },
247                        )?;
248                        if let Some(thoughts) =
249                            message.get("thoughts").and_then(|value| value.as_array())
250                        {
251                            for thought in thoughts {
252                                writer.emit(
253                                    LogSourceKind::ProviderFile,
254                                    LogEventKind::Reasoning {
255                                        content: thought
256                                            .get("description")
257                                            .and_then(|value| value.as_str())
258                                            .unwrap_or_default()
259                                            .to_string(),
260                                        message_id: Some(message_id.clone()),
261                                    },
262                                )?;
263                            }
264                        }
265                        writer.emit(
266                            LogSourceKind::ProviderFile,
267                            LogEventKind::ProviderStatus {
268                                message: "Gemini message metadata".to_string(),
269                                data: Some(serde_json::json!({
270                                    "tokens": message.get("tokens"),
271                                    "model": message.get("model"),
272                                })),
273                            },
274                        )?;
275                    }
276                    _ => {}
277                }
278            }
279        }
280
281        Ok(())
282    }
283}
284
285impl HistoricalLogAdapter for GeminiHistoricalLogAdapter {
286    fn backfill(&self, _root: Option<&str>) -> Result<Vec<BackfilledSession>> {
287        let mut sessions = Vec::new();
288        let Some(gemini_tmp) = tmp_dir() else {
289            return Ok(sessions);
290        };
291        let projects = match std::fs::read_dir(gemini_tmp) {
292            Ok(projects) => projects,
293            Err(_) => return Ok(sessions),
294        };
295        for project in projects.flatten() {
296            let chats = project.path().join("chats");
297            let files = match std::fs::read_dir(chats) {
298                Ok(files) => files,
299                Err(_) => continue,
300            };
301            for file in files.flatten() {
302                let path = file.path();
303                info!("Scanning Gemini history: {}", path.display());
304                let content = match std::fs::read_to_string(&path) {
305                    Ok(content) => content,
306                    Err(_) => continue,
307                };
308                let json: serde_json::Value = match serde_json::from_str(&content) {
309                    Ok(json) => json,
310                    Err(_) => continue,
311                };
312                let Some(session_id) = json.get("sessionId").and_then(|value| value.as_str())
313                else {
314                    continue;
315                };
316                let mut events = Vec::new();
317                if let Some(messages) = json.get("messages").and_then(|value| value.as_array()) {
318                    for message in messages {
319                        let message_id = message
320                            .get("id")
321                            .and_then(|value| value.as_str())
322                            .map(str::to_string);
323                        match message.get("type").and_then(|value| value.as_str()) {
324                            Some("user") => events.push((
325                                LogSourceKind::Backfill,
326                                LogEventKind::UserMessage {
327                                    role: "user".to_string(),
328                                    content: message
329                                        .get("content")
330                                        .and_then(|value| value.as_str())
331                                        .unwrap_or_default()
332                                        .to_string(),
333                                    message_id: message_id.clone(),
334                                },
335                            )),
336                            Some("gemini") => {
337                                events.push((
338                                    LogSourceKind::Backfill,
339                                    LogEventKind::AssistantMessage {
340                                        content: message
341                                            .get("content")
342                                            .and_then(|value| value.as_str())
343                                            .unwrap_or_default()
344                                            .to_string(),
345                                        message_id: message_id.clone(),
346                                    },
347                                ));
348                                if let Some(thoughts) =
349                                    message.get("thoughts").and_then(|value| value.as_array())
350                                {
351                                    for thought in thoughts {
352                                        events.push((
353                                            LogSourceKind::Backfill,
354                                            LogEventKind::Reasoning {
355                                                content: thought
356                                                    .get("description")
357                                                    .and_then(|value| value.as_str())
358                                                    .unwrap_or_default()
359                                                    .to_string(),
360                                                message_id: message_id.clone(),
361                                            },
362                                        ));
363                                    }
364                                }
365                            }
366                            _ => {}
367                        }
368                    }
369                }
370                sessions.push(BackfilledSession {
371                    metadata: SessionLogMetadata {
372                        provider: "gemini".to_string(),
373                        wrapper_session_id: session_id.to_string(),
374                        provider_session_id: Some(session_id.to_string()),
375                        workspace_path: None,
376                        command: "backfill".to_string(),
377                        model: None,
378                        resumed: false,
379                        backfilled: true,
380                    },
381                    completeness: LogCompleteness::Full,
382                    source_paths: vec![path.to_string_lossy().to_string()],
383                    events,
384                });
385            }
386        }
387        Ok(sessions)
388    }
389}
390
391#[async_trait]
392impl Agent for Gemini {
393    fn name(&self) -> &str {
394        "gemini"
395    }
396
397    fn default_model() -> &'static str {
398        DEFAULT_MODEL
399    }
400
401    fn model_for_size(size: ModelSize) -> &'static str {
402        match size {
403            ModelSize::Small => "gemini-3.1-flash-lite-preview",
404            ModelSize::Medium => "gemini-2.5-flash",
405            ModelSize::Large => "gemini-3.1-pro-preview",
406        }
407    }
408
409    fn available_models() -> &'static [&'static str] {
410        AVAILABLE_MODELS
411    }
412
413    crate::providers::common::impl_common_agent_setters!();
414
415    fn set_skip_permissions(&mut self, skip: bool) {
416        self.common.skip_permissions = skip;
417    }
418
419    crate::providers::common::impl_as_any!();
420
421    async fn run(&self, prompt: Option<&str>) -> Result<Option<AgentOutput>> {
422        self.execute(false, prompt).await
423    }
424
425    async fn run_interactive(&self, prompt: Option<&str>) -> Result<()> {
426        self.execute(true, prompt).await?;
427        Ok(())
428    }
429
430    async fn run_resume(&self, session_id: Option<&str>, _last: bool) -> Result<()> {
431        let mut args = Vec::new();
432
433        if let Some(id) = session_id {
434            args.extend(["--resume".to_string(), id.to_string()]);
435        } else {
436            args.extend(["--resume".to_string(), "latest".to_string()]);
437        }
438
439        if self.common.skip_permissions {
440            args.extend(["--approval-mode", "yolo"].map(String::from));
441        }
442
443        if !self.common.model.is_empty() && self.common.model != "auto" {
444            args.extend(["--model".to_string(), self.common.model.clone()]);
445        }
446
447        for dir in &self.common.add_dirs {
448            args.extend(["--include-directories".to_string(), dir.clone()]);
449        }
450
451        let mut cmd = self.make_command(args);
452        CommonAgentState::run_interactive_command(&mut cmd, "Gemini").await
453    }
454
455    async fn cleanup(&self) -> Result<()> {
456        log::debug!("Cleaning up Gemini agent resources");
457        let base = self.common.get_base_path();
458        let gemini_dir = base.join(".gemini");
459        let system_file = gemini_dir.join("system.md");
460
461        if system_file.exists() {
462            fs::remove_file(&system_file).await?;
463        }
464
465        if gemini_dir.exists()
466            && fs::read_dir(&gemini_dir)
467                .await?
468                .next_entry()
469                .await?
470                .is_none()
471        {
472            fs::remove_dir(&gemini_dir).await?;
473        }
474
475        Ok(())
476    }
477}