construct/providers/
gemini_cli.rs1use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
38use async_trait::async_trait;
39use std::path::PathBuf;
40use tokio::io::AsyncWriteExt;
41use tokio::process::Command;
42use tokio::time::{Duration, timeout};
43
/// Environment variable that overrides the path of the Gemini CLI executable.
pub const GEMINI_CLI_PATH_ENV: &str = "GEMINI_CLI_PATH";

/// Executable name used when `GEMINI_CLI_PATH` is unset or blank.
const DEFAULT_GEMINI_CLI_BINARY: &str = "gemini";

/// Model name that is never forwarded to the CLI (no `--model` flag is passed for it).
const DEFAULT_MODEL_MARKER: &str = "default";

/// How long to wait for a single CLI invocation before giving up (two minutes).
const GEMINI_CLI_REQUEST_TIMEOUT: Duration = Duration::from_secs(2 * 60);

/// Maximum number of stderr characters quoted in error messages.
const MAX_GEMINI_CLI_STDERR_CHARS: usize = 512;

/// The only temperature values this provider accepts.
const GEMINI_CLI_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];

/// Tolerance used when matching a requested temperature against the supported values.
const TEMP_EPSILON: f64 = 1.0e-9;
59
/// Chat provider that runs the Gemini CLI as a child process for every request.
#[derive(Debug, Clone)]
pub struct GeminiCliProvider {
    /// Path, or bare command name, of the Gemini CLI executable. A bare name is resolved
    /// through `PATH` when the process is spawned.
    binary_path: PathBuf,
}
68
69impl GeminiCliProvider {
70 pub fn new() -> Self {
75 let binary_path = std::env::var(GEMINI_CLI_PATH_ENV)
76 .ok()
77 .filter(|path| !path.trim().is_empty())
78 .map(PathBuf::from)
79 .unwrap_or_else(|| PathBuf::from(DEFAULT_GEMINI_CLI_BINARY));
80
81 Self { binary_path }
82 }
83
84 fn should_forward_model(model: &str) -> bool {
86 let trimmed = model.trim();
87 !trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
88 }
89
90 fn supports_temperature(temperature: f64) -> bool {
91 GEMINI_CLI_SUPPORTED_TEMPERATURES
92 .iter()
93 .any(|v| (temperature - v).abs() < TEMP_EPSILON)
94 }
95
96 fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
97 if !temperature.is_finite() {
98 anyhow::bail!("Gemini CLI provider received non-finite temperature value");
99 }
100 if !Self::supports_temperature(temperature) {
101 anyhow::bail!(
102 "temperature unsupported by Gemini CLI: {temperature}. \
103 Supported values: 0.7 or 1.0"
104 );
105 }
106 Ok(())
107 }
108
109 fn redact_stderr(stderr: &[u8]) -> String {
110 let text = String::from_utf8_lossy(stderr);
111 let trimmed = text.trim();
112 if trimmed.is_empty() {
113 return String::new();
114 }
115 if trimmed.chars().count() <= MAX_GEMINI_CLI_STDERR_CHARS {
116 return trimmed.to_string();
117 }
118 let clipped: String = trimmed.chars().take(MAX_GEMINI_CLI_STDERR_CHARS).collect();
119 format!("{clipped}...")
120 }
121
122 async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
125 let mut cmd = Command::new(&self.binary_path);
126 cmd.arg("--print");
127
128 if Self::should_forward_model(model) {
129 cmd.arg("--model").arg(model);
130 }
131
132 cmd.arg("-");
134 cmd.kill_on_drop(true);
135 cmd.stdin(std::process::Stdio::piped());
136 cmd.stdout(std::process::Stdio::piped());
137 cmd.stderr(std::process::Stdio::piped());
138
139 let mut child = cmd.spawn().map_err(|err| {
140 anyhow::anyhow!(
141 "Failed to spawn Gemini CLI binary at {}: {err}. \
142 Ensure `gemini` is installed and in PATH, or set GEMINI_CLI_PATH.",
143 self.binary_path.display()
144 )
145 })?;
146
147 if let Some(mut stdin) = child.stdin.take() {
148 stdin.write_all(message.as_bytes()).await.map_err(|err| {
149 anyhow::anyhow!("Failed to write prompt to Gemini CLI stdin: {err}")
150 })?;
151 stdin.shutdown().await.map_err(|err| {
152 anyhow::anyhow!("Failed to finalize Gemini CLI stdin stream: {err}")
153 })?;
154 }
155
156 let output = timeout(GEMINI_CLI_REQUEST_TIMEOUT, child.wait_with_output())
157 .await
158 .map_err(|_| {
159 anyhow::anyhow!(
160 "Gemini CLI request timed out after {:?} (binary: {})",
161 GEMINI_CLI_REQUEST_TIMEOUT,
162 self.binary_path.display()
163 )
164 })?
165 .map_err(|err| anyhow::anyhow!("Gemini CLI process failed: {err}"))?;
166
167 if !output.status.success() {
168 let code = output.status.code().unwrap_or(-1);
169 let stderr_excerpt = Self::redact_stderr(&output.stderr);
170 let stderr_note = if stderr_excerpt.is_empty() {
171 String::new()
172 } else {
173 format!(" Stderr: {stderr_excerpt}")
174 };
175 anyhow::bail!(
176 "Gemini CLI exited with non-zero status {code}. \
177 Check that Gemini CLI is authenticated and the CLI is supported.{stderr_note}"
178 );
179 }
180
181 let text = String::from_utf8(output.stdout)
182 .map_err(|err| anyhow::anyhow!("Gemini CLI produced non-UTF-8 output: {err}"))?;
183
184 Ok(text.trim().to_string())
185 }
186}
187
188impl Default for GeminiCliProvider {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
194#[async_trait]
195impl Provider for GeminiCliProvider {
196 async fn chat_with_system(
197 &self,
198 system_prompt: Option<&str>,
199 message: &str,
200 model: &str,
201 temperature: f64,
202 ) -> anyhow::Result<String> {
203 Self::validate_temperature(temperature)?;
204
205 let full_message = match system_prompt {
206 Some(system) if !system.is_empty() => {
207 format!("{system}\n\n{message}")
208 }
209 _ => message.to_string(),
210 };
211
212 self.invoke_cli(&full_message, model).await
213 }
214
215 async fn chat(
216 &self,
217 request: ChatRequest<'_>,
218 model: &str,
219 temperature: f64,
220 ) -> anyhow::Result<ChatResponse> {
221 let text = self
222 .chat_with_history(request.messages, model, temperature)
223 .await?;
224
225 Ok(ChatResponse {
226 text: Some(text),
227 tool_calls: Vec::new(),
228 usage: Some(TokenUsage::default()),
229 reasoning_content: None,
230 })
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

    /// Serializes the tests in this module that mutate process environment variables.
    ///
    /// A poisoned lock is recovered instead of propagated. A single failing env test must
    /// not turn every later env test into an unrelated "lock poisoned" panic.
    fn env_lock() -> MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets or clears an environment variable for the guard's lifetime.
    ///
    /// On drop, the previous value is restored (including non-UTF-8 values), even when the
    /// test panics. Create it while holding `env_lock()`, and bind it *after* the lock guard
    /// so it is dropped first.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: env-mutating tests in this module are serialized by `env_lock()`.
            // NOTE(review): this assumes no other test in the crate reads or writes the
            // same variable concurrently — confirm.
            unsafe { std::env::set_var(key, value) };
            Self { key, original }
        }

        fn remove(key: &'static str) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: same reasoning as in `EnvVarGuard::set`.
            unsafe { std::env::remove_var(key) };
            Self { key, original }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `EnvVarGuard::set`; the guard is dropped while the
            // `env_lock()` guard is still held.
            match self.original.take() {
                Some(value) => unsafe { std::env::set_var(self.key, value) },
                None => unsafe { std::env::remove_var(self.key) },
            }
        }
    }

    #[test]
    fn new_uses_env_override() {
        let _lock = env_lock();
        let _env = EnvVarGuard::set(GEMINI_CLI_PATH_ENV, "/usr/local/bin/gemini");
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/gemini"));
    }

    #[test]
    fn new_defaults_to_gemini() {
        let _lock = env_lock();
        let _env = EnvVarGuard::remove(GEMINI_CLI_PATH_ENV);
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("gemini"));
    }

    #[test]
    fn new_ignores_blank_env_override() {
        let _lock = env_lock();
        let _env = EnvVarGuard::set(GEMINI_CLI_PATH_ENV, "   ");
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("gemini"));
    }

    #[test]
    fn should_forward_model_standard() {
        assert!(GeminiCliProvider::should_forward_model("gemini-2.5-pro"));
        assert!(GeminiCliProvider::should_forward_model("gemini-2.5-flash"));
    }

    #[test]
    fn should_not_forward_default_model() {
        assert!(!GeminiCliProvider::should_forward_model(
            DEFAULT_MODEL_MARKER
        ));
        assert!(!GeminiCliProvider::should_forward_model("  default  "));
        assert!(!GeminiCliProvider::should_forward_model(""));
        assert!(!GeminiCliProvider::should_forward_model("   "));
    }

    #[test]
    fn validate_temperature_allows_defaults() {
        assert!(GeminiCliProvider::validate_temperature(0.7).is_ok());
        assert!(GeminiCliProvider::validate_temperature(1.0).is_ok());
    }

    #[test]
    fn validate_temperature_rejects_custom_value() {
        let err = GeminiCliProvider::validate_temperature(0.2).unwrap_err();
        assert!(
            err.to_string()
                .contains("temperature unsupported by Gemini CLI")
        );
    }

    #[test]
    fn validate_temperature_rejects_non_finite() {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let err = GeminiCliProvider::validate_temperature(value).unwrap_err();
            assert!(err.to_string().contains("non-finite"), "value: {value}");
        }
    }

    #[test]
    fn redact_stderr_trims_short_output() {
        assert_eq!(GeminiCliProvider::redact_stderr(b""), "");
        assert_eq!(GeminiCliProvider::redact_stderr(b"  \n\t "), "");
        assert_eq!(GeminiCliProvider::redact_stderr(b"  boom \n"), "boom");
    }

    #[test]
    fn redact_stderr_keeps_output_at_limit() {
        let exact = "a".repeat(MAX_GEMINI_CLI_STDERR_CHARS);
        assert_eq!(GeminiCliProvider::redact_stderr(exact.as_bytes()), exact);
    }

    #[test]
    fn redact_stderr_clips_long_output_on_char_boundary() {
        let long = "é".repeat(MAX_GEMINI_CLI_STDERR_CHARS + 10);
        let clipped = GeminiCliProvider::redact_stderr(long.as_bytes());
        assert!(clipped.ends_with("..."));
        assert_eq!(clipped.chars().count(), MAX_GEMINI_CLI_STDERR_CHARS + 3);
    }

    #[tokio::test]
    async fn invoke_missing_binary_returns_error() {
        let provider = GeminiCliProvider {
            binary_path: PathBuf::from("/nonexistent/path/to/gemini"),
        };
        let result = provider.invoke_cli("hello", "default").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Failed to spawn Gemini CLI binary"),
            "unexpected error message: {msg}"
        );
    }

    #[tokio::test]
    async fn chat_with_system_validates_temperature_before_spawning() {
        // The binary does not exist, so a spawn attempt would produce a different error.
        let provider = GeminiCliProvider {
            binary_path: PathBuf::from("/nonexistent/path/to/gemini"),
        };
        let msg = provider
            .chat_with_system(None, "hello", "default", 0.2)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("temperature unsupported by Gemini CLI"),
            "unexpected error message: {msg}"
        );
    }
}