Skip to main content

aster/providers/
gemini_cli.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::json;
4use std::ffi::OsString;
5use std::path::PathBuf;
6use std::process::Stdio;
7use tokio::io::{AsyncBufReadExt, BufReader};
8use tokio::process::Command;
9
10use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
11use super::errors::ProviderError;
12use super::utils::{filter_extensions_from_system_prompt, RequestLog};
13use crate::config::base::GeminiCliCommand;
14use crate::config::search_path::SearchPaths;
15use crate::config::Config;
16use crate::conversation::message::{Message, MessageContent};
17use crate::model::ModelConfig;
18use crate::providers::base::ConfigKey;
19use crate::subprocess::configure_command_no_window;
20use rmcp::model::Role;
21use rmcp::model::Tool;
22
/// Default model advertised in this provider's metadata.
pub const GEMINI_CLI_DEFAULT_MODEL: &str = "gemini-2.5-pro";

/// Models advertised in the provider metadata.
///
/// Only a configured model whose name appears here is forwarded to the CLI with `-m`;
/// any other name is left for the CLI to resolve with its own default.
pub const GEMINI_CLI_KNOWN_MODELS: &[&str] = &[
    // Highest-capability entry; also the default above.
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
];

/// Documentation link advertised in the provider metadata.
pub const GEMINI_CLI_DOC_URL: &str = "https://ai.google.dev/gemini-api/docs";
31
32#[derive(Debug, serde::Serialize)]
33pub struct GeminiCliProvider {
34    command: PathBuf,
35    model: ModelConfig,
36    #[serde(skip)]
37    name: String,
38}
39
40impl GeminiCliProvider {
41    pub async fn from_env(model: ModelConfig) -> Result<Self> {
42        let config = Config::global();
43        let command: OsString = config.get_gemini_cli_command().unwrap_or_default().into();
44        let resolved_command = SearchPaths::builder().with_npm().resolve(command)?;
45
46        Ok(Self {
47            command: resolved_command,
48            model,
49            name: Self::metadata().name,
50        })
51    }
52
53    /// Execute gemini CLI command with simple text prompt
54    async fn execute_command(
55        &self,
56        system: &str,
57        messages: &[Message],
58        _tools: &[Tool],
59    ) -> Result<Vec<String>, ProviderError> {
60        // Create a simple prompt combining system + conversation
61        let mut full_prompt = String::new();
62
63        let filtered_system = filter_extensions_from_system_prompt(system);
64        full_prompt.push_str(&filtered_system);
65        full_prompt.push_str("\n\n");
66
67        // Add conversation history
68        for message in messages.iter().filter(|m| m.is_agent_visible()) {
69            let role_prefix = match message.role {
70                Role::User => "Human: ",
71                Role::Assistant => "Assistant: ",
72            };
73            full_prompt.push_str(role_prefix);
74
75            for content in &message.content {
76                if let MessageContent::Text(text_content) = content {
77                    full_prompt.push_str(&text_content.text);
78                    full_prompt.push('\n');
79                }
80            }
81            full_prompt.push('\n');
82        }
83
84        full_prompt.push_str("Assistant: ");
85
86        if std::env::var("ASTER_GEMINI_CLI_DEBUG").is_ok() {
87            println!("=== GEMINI CLI PROVIDER DEBUG ===");
88            println!("Command: {:?}", self.command);
89            println!("Full prompt: {}", full_prompt);
90            println!("================================");
91        }
92
93        let mut cmd = Command::new(&self.command);
94        configure_command_no_window(&mut cmd);
95
96        if let Ok(path) = SearchPaths::builder().with_npm().path() {
97            cmd.env("PATH", path);
98        }
99
100        // Only pass model parameter if it's in the known models list
101        if GEMINI_CLI_KNOWN_MODELS.contains(&self.model.model_name.as_str()) {
102            cmd.arg("-m").arg(&self.model.model_name);
103        }
104
105        if cfg!(windows) {
106            let sanitized_prompt = full_prompt.replace("\r\n", "\\n").replace('\n', "\\n");
107
108            cmd.arg("-p").arg(&sanitized_prompt).arg("--yolo");
109        } else {
110            cmd.arg("-p").arg(&full_prompt).arg("--yolo");
111        }
112
113        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
114
115        let mut child = cmd.spawn().map_err(|e| {
116            ProviderError::RequestFailed(format!(
117                "Failed to spawn Gemini CLI command '{:?}': {}. \
118                Make sure the Gemini CLI is installed and available in the configured search paths.",
119                self.command, e
120            ))
121        })?;
122
123        let stdout = child
124            .stdout
125            .take()
126            .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?;
127
128        let mut reader = BufReader::new(stdout);
129        let mut lines = Vec::new();
130        let mut line = String::new();
131
132        loop {
133            line.clear();
134            match reader.read_line(&mut line).await {
135                Ok(0) => break, // EOF
136                Ok(_) => {
137                    let trimmed = line.trim();
138                    if !trimmed.is_empty() && !trimmed.starts_with("Loaded cached credentials") {
139                        lines.push(trimmed.to_string());
140                    }
141                }
142                Err(e) => {
143                    return Err(ProviderError::RequestFailed(format!(
144                        "Failed to read output: {}",
145                        e
146                    )));
147                }
148            }
149        }
150
151        let exit_status = child.wait().await.map_err(|e| {
152            ProviderError::RequestFailed(format!("Failed to wait for command: {}", e))
153        })?;
154
155        if !exit_status.success() {
156            return Err(ProviderError::RequestFailed(format!(
157                "Command failed with exit code: {:?}",
158                exit_status.code()
159            )));
160        }
161
162        tracing::debug!(
163            "Gemini CLI executed successfully, got {} lines",
164            lines.len()
165        );
166
167        Ok(lines)
168    }
169
170    /// Parse simple text response
171    fn parse_response(&self, lines: &[String]) -> Result<(Message, Usage), ProviderError> {
172        // Join all lines into a single response
173        let response_text = lines.join("\n");
174
175        if response_text.trim().is_empty() {
176            return Err(ProviderError::RequestFailed(
177                "Empty response from gemini command".to_string(),
178            ));
179        }
180
181        let message = Message::new(
182            Role::Assistant,
183            chrono::Utc::now().timestamp(),
184            vec![MessageContent::text(response_text)],
185        );
186
187        let usage = Usage::default(); // No usage info available for gemini CLI
188
189        Ok((message, usage))
190    }
191
192    /// Generate a simple session description without calling subprocess
193    fn generate_simple_session_description(
194        &self,
195        messages: &[Message],
196    ) -> Result<(Message, ProviderUsage), ProviderError> {
197        // Extract the first user message text
198        let description = messages
199            .iter()
200            .find(|m| m.role == Role::User)
201            .and_then(|m| {
202                m.content.iter().find_map(|c| match c {
203                    MessageContent::Text(text_content) => Some(&text_content.text),
204                    _ => None,
205                })
206            })
207            .map(|text| {
208                // Take first few words, limit to 4 words
209                text.split_whitespace()
210                    .take(4)
211                    .collect::<Vec<_>>()
212                    .join(" ")
213            })
214            .unwrap_or_else(|| "Simple task".to_string());
215
216        if std::env::var("ASTER_GEMINI_CLI_DEBUG").is_ok() {
217            println!("=== GEMINI CLI PROVIDER DEBUG ===");
218            println!("Generated simple session description: {}", description);
219            println!("Skipped subprocess call for session description");
220            println!("================================");
221        }
222
223        let message = Message::new(
224            Role::Assistant,
225            chrono::Utc::now().timestamp(),
226            vec![MessageContent::text(description.clone())],
227        );
228
229        let usage = Usage::default();
230
231        Ok((
232            message,
233            ProviderUsage::new(self.model.model_name.clone(), usage),
234        ))
235    }
236}
237
238#[async_trait]
239impl Provider for GeminiCliProvider {
240    fn metadata() -> ProviderMetadata {
241        ProviderMetadata::new(
242            "gemini-cli",
243            "Gemini CLI",
244            "Execute Gemini models via gemini CLI tool",
245            GEMINI_CLI_DEFAULT_MODEL,
246            GEMINI_CLI_KNOWN_MODELS.to_vec(),
247            GEMINI_CLI_DOC_URL,
248            vec![ConfigKey::from_value_type::<GeminiCliCommand>(true, false)],
249        )
250    }
251
252    fn get_name(&self) -> &str {
253        &self.name
254    }
255
256    fn get_model_config(&self) -> ModelConfig {
257        // Return the model config with appropriate context limit for Gemini models
258        self.model.clone()
259    }
260
261    #[tracing::instrument(
262        skip(self, _model_config, system, messages, tools),
263        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
264    )]
265    async fn complete_with_model(
266        &self,
267        _model_config: &ModelConfig,
268        system: &str,
269        messages: &[Message],
270        tools: &[Tool],
271    ) -> Result<(Message, ProviderUsage), ProviderError> {
272        // Check if this is a session description request (short system prompt asking for 4 words or less)
273        if system.contains("four words or less") || system.contains("4 words or less") {
274            return self.generate_simple_session_description(messages);
275        }
276
277        // Create a dummy payload for debug tracing
278        let payload = json!({
279            "command": self.command,
280            "model": self.model.model_name,
281            "system": system,
282            "messages": messages.len()
283        });
284
285        let mut log = RequestLog::start(&self.model, &payload).map_err(|e| {
286            ProviderError::RequestFailed(format!("Failed to start request log: {}", e))
287        })?;
288
289        let lines = self.execute_command(system, messages, tools).await?;
290
291        let (message, usage) = self.parse_response(&lines)?;
292
293        let response = json!({
294            "lines": lines.len(),
295            "usage": usage
296        });
297
298        log.write(&response, Some(&usage)).map_err(|e| {
299            ProviderError::RequestFailed(format!("Failed to write request log: {}", e))
300        })?;
301
302        Ok((
303            message,
304            ProviderUsage::new(self.model.model_name.clone(), usage),
305        ))
306    }
307}