1pub(crate) mod data;
7
8use crate::attachments::validate_request_attachments;
9use crate::provider::LlmProvider;
10use crate::streaming::{StreamBox, StreamDelta, StreamErrorKind};
11use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ChatResponse, ThinkingConfig};
12use anyhow::Result;
13use async_trait::async_trait;
14use data::{
15 ApiContent, ApiFunctionCallingConfig, ApiGenerateContentRequest, ApiGenerateContentResponse,
16 ApiGenerationConfig, ApiPart, ApiUsageMetadata, build_api_contents, build_content_blocks,
17 convert_tools_to_config, gemini_response_schema, map_finish_reason, map_thinking_config,
18};
19use reqwest::StatusCode;
20
/// Default root of the Gemini REST API; `/models/{model}:...` paths are appended to it.
const API_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";

/// Connect timeout, in seconds, configured on the shared HTTP client.
const CONNECT_TIMEOUT_SECS: u64 = 30;
/// TCP keepalive value, in seconds, configured on the shared HTTP client.
const TCP_KEEPALIVE_SECS: u64 = 30;
/// Per-request timeout, in seconds, set on non-streaming chat requests (five minutes).
const CHAT_READ_TIMEOUT_SECS: u64 = 5 * 60;
31
32fn build_http_client() -> reqwest::Client {
35 reqwest::Client::builder()
36 .connect_timeout(std::time::Duration::from_secs(CONNECT_TIMEOUT_SECS))
37 .tcp_keepalive(std::time::Duration::from_secs(TCP_KEEPALIVE_SECS))
38 .build()
39 .unwrap_or_else(|error| {
40 log::warn!(
41 "failed to build Gemini HTTP client with timeouts ({error}); using default client"
42 );
43 reqwest::Client::new()
44 })
45}
46
// --- Gemini 3.x model identifiers ---

/// Model id string for Gemini 3.1 Pro (preview).
pub const MODEL_GEMINI_31_PRO: &str = "gemini-3.1-pro-preview";
/// Model id string for Gemini 3.1 Flash-Lite (preview).
pub const MODEL_GEMINI_31_FLASH_LITE: &str = "gemini-3.1-flash-lite-preview";

/// Model id string for Gemini 3 Flash (preview).
pub const MODEL_GEMINI_3_FLASH: &str = "gemini-3-flash-preview";

/// Model id string for Gemini 3 Pro.
// NOTE(review): the other 3.x ids in this file end in `-preview` and do not
// spell out `.0`; confirm this id against the published model list.
pub const MODEL_GEMINI_3_PRO: &str = "gemini-3.0-pro";

// --- Gemini 2.x model identifiers ---

/// Model id string for Gemini 2.5 Flash.
pub const MODEL_GEMINI_25_FLASH: &str = "gemini-2.5-flash";
/// Model id string for Gemini 2.5 Pro.
pub const MODEL_GEMINI_25_PRO: &str = "gemini-2.5-pro";

/// Model id string for Gemini 2.0 Flash.
pub const MODEL_GEMINI_2_FLASH: &str = "gemini-2.0-flash";
/// Model id string for Gemini 2.0 Flash-Lite.
pub const MODEL_GEMINI_2_FLASH_LITE: &str = "gemini-2.0-flash-lite";
64
65#[derive(Clone)]
67pub struct GeminiProvider {
68 client: reqwest::Client,
69 api_key: String,
70 model: String,
71 base_url: String,
72 thinking: Option<ThinkingConfig>,
73 use_header_auth: bool,
76 extra_headers: Vec<(String, String)>,
78}
79
80impl GeminiProvider {
81 pub const API_KEY_ENV: &'static str = "GEMINI_API_KEY";
83
84 #[must_use]
86 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
87 Self {
88 client: build_http_client(),
89 api_key: api_key.into(),
90 model: model.into(),
91 base_url: API_BASE_URL.to_owned(),
92 thinking: None,
93 use_header_auth: true,
94 extra_headers: Vec::new(),
95 }
96 }
97
98 fn effective_max_tokens(&self, request: &ChatRequest) -> u32 {
105 if request.max_tokens_explicit {
106 request.max_tokens
107 } else {
108 self.default_max_tokens()
109 }
110 }
111
112 #[must_use]
120 pub fn from_env() -> Self {
121 Self::try_from_env().unwrap_or_else(|e| panic!("{e}"))
122 }
123
124 pub fn try_from_env() -> Result<Self> {
131 let api_key = std::env::var(Self::API_KEY_ENV).map_err(|_| {
132 anyhow::anyhow!("environment variable `{}` is not set", Self::API_KEY_ENV)
133 })?;
134 Ok(Self::flash(api_key))
135 }
136
137 #[must_use]
139 pub fn flash(api_key: impl Into<String>) -> Self {
140 Self::new(api_key, MODEL_GEMINI_3_FLASH)
141 }
142
143 #[must_use]
145 pub fn flash_lite_31(api_key: String) -> Self {
146 Self::new(api_key, MODEL_GEMINI_31_FLASH_LITE.to_owned())
147 }
148
149 #[must_use]
151 pub fn flash_lite(api_key: String) -> Self {
152 Self::new(api_key, MODEL_GEMINI_2_FLASH_LITE.to_owned())
153 }
154
155 #[must_use]
157 pub fn pro_31(api_key: String) -> Self {
158 Self::new(api_key, MODEL_GEMINI_31_PRO.to_owned())
159 }
160
161 #[must_use]
163 pub fn pro(api_key: String) -> Self {
164 Self::new(api_key, MODEL_GEMINI_31_PRO.to_owned())
165 }
166
167 #[must_use]
169 pub const fn with_thinking(mut self, thinking: ThinkingConfig) -> Self {
170 self.thinking = Some(thinking);
171 self
172 }
173
174 #[must_use]
176 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
177 self.base_url = base_url.into();
178 self
179 }
180
181 #[must_use]
184 pub const fn with_header_auth(mut self) -> Self {
185 self.use_header_auth = true;
186 self
187 }
188
189 #[must_use]
191 pub fn with_extra_headers(mut self, headers: Vec<(String, String)>) -> Self {
192 self.extra_headers = headers;
193 self
194 }
195
196 fn apply_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
199 let builder = if self.api_key.is_empty() {
200 builder
201 } else if self.use_header_auth {
202 builder.header("x-goog-api-key", &self.api_key)
203 } else {
204 builder.query(&[("key", &self.api_key)])
205 };
206 self.extra_headers
207 .iter()
208 .fold(builder, |b, (k, v)| b.header(k.as_str(), v.as_str()))
209 }
210}
211
212#[async_trait]
213#[allow(clippy::too_many_lines)]
214impl LlmProvider for GeminiProvider {
215 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
216 let thinking = match self.resolve_thinking_config(request.thinking.as_ref()) {
217 Ok(thinking) => thinking,
218 Err(error) => return Ok(ChatOutcome::InvalidRequest(error.to_string())),
219 };
220 if let Err(error) = validate_request_attachments(self.provider(), self.model(), &request) {
221 return Ok(ChatOutcome::InvalidRequest(error.to_string()));
222 }
223 let contents = build_api_contents(&request.messages);
224 let tools = request
225 .tools
226 .as_ref()
227 .map(|t| convert_tools_to_config(t.clone()));
228 let tool_config = request
229 .tool_choice
230 .as_ref()
231 .map(ApiFunctionCallingConfig::from_tool_choice);
232 let system_instruction = if request.system.is_empty() {
233 None
234 } else {
235 Some(ApiContent {
236 role: None,
237 parts: vec![ApiPart::Text {
238 text: request.system.clone(),
239 thought_signature: None,
240 }],
241 })
242 };
243
244 let thinking_config = thinking.as_ref().map(map_thinking_config);
245 let (response_mime_type, response_schema) =
246 request.response_format.as_ref().map_or((None, None), |rf| {
247 (
248 Some("application/json"),
249 Some(gemini_response_schema(&rf.schema)),
250 )
251 });
252
253 let max_tokens = self.effective_max_tokens(&request);
254 let api_request = ApiGenerateContentRequest {
255 contents: &contents,
256 system_instruction: system_instruction.as_ref(),
257 tools: tools.as_ref().map(std::slice::from_ref),
258 tool_config,
259 generation_config: Some(ApiGenerationConfig {
260 max_output_tokens: Some(max_tokens),
261 thinking_config,
262 response_mime_type,
263 response_schema,
264 }),
265 cached_content: request.cached_content.as_deref(),
266 };
267
268 log::debug!(
269 "Gemini LLM request model={} max_tokens={}",
270 self.model,
271 max_tokens
272 );
273
274 let builder = self
275 .client
276 .post(format!(
277 "{}/models/{}:generateContent",
278 self.base_url, self.model
279 ))
280 .header("Content-Type", "application/json")
281 .timeout(std::time::Duration::from_secs(CHAT_READ_TIMEOUT_SECS));
282 let response = self
283 .apply_auth(builder)
284 .json(&api_request)
285 .send()
286 .await
287 .map_err(|e| anyhow::anyhow!("request failed: {e}"))?;
288
289 let status = response.status();
290 let bytes = response
291 .bytes()
292 .await
293 .map_err(|e| anyhow::anyhow!("failed to read response body: {e}"))?;
294
295 log::debug!(
296 "Gemini LLM response status={} body_len={}",
297 status,
298 bytes.len()
299 );
300
301 if status == StatusCode::TOO_MANY_REQUESTS {
302 return Ok(ChatOutcome::RateLimited);
303 }
304
305 if status.is_server_error() {
306 let body = String::from_utf8_lossy(&bytes);
307 log::error!("Gemini server error status={status} body={body}");
308 return Ok(ChatOutcome::ServerError(body.into_owned()));
309 }
310
311 if status.is_client_error() {
312 let body = String::from_utf8_lossy(&bytes);
313 log::warn!("Gemini client error status={status} body={body}");
314 return Ok(ChatOutcome::InvalidRequest(body.into_owned()));
315 }
316
317 let api_response: ApiGenerateContentResponse = serde_json::from_slice(&bytes)
318 .map_err(|e| anyhow::anyhow!("failed to parse response: {e}"))?;
319
320 let candidate = api_response
321 .candidates
322 .into_iter()
323 .next()
324 .ok_or_else(|| anyhow::anyhow!("no candidates in response"))?;
325
326 let content = build_content_blocks(&candidate.content);
327
328 if content.is_empty() && !candidate.content.parts.is_empty() {
329 log::warn!(
330 "Gemini parts not converted to content blocks raw_parts={:?}",
331 candidate.content.parts
332 );
333 }
334
335 let has_tool_calls = content
336 .iter()
337 .any(|b| matches!(b, agent_sdk_foundation::llm::ContentBlock::ToolUse { .. }));
338
339 let stop_reason = candidate
340 .finish_reason
341 .as_ref()
342 .map(|r| map_finish_reason(r, has_tool_calls));
343
344 let usage = api_response
345 .usage_metadata
346 .unwrap_or(ApiUsageMetadata {
347 prompt: 0,
348 candidates: 0,
349 cached_content: 0,
350 })
351 .into_usage();
352
353 Ok(ChatOutcome::Success(ChatResponse {
354 id: String::new(),
355 content,
356 model: self.model.clone(),
357 stop_reason,
358 usage,
359 }))
360 }
361
362 fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
363 Box::pin(async_stream::stream! {
364 let thinking = match self.resolve_thinking_config(request.thinking.as_ref()) {
365 Ok(thinking) => thinking,
366 Err(error) => {
367 yield Ok(StreamDelta::Error {
368 message: error.to_string(),
369 kind: StreamErrorKind::InvalidRequest,
370 });
371 return;
372 }
373 };
374 if let Err(error) = validate_request_attachments(self.provider(), self.model(), &request) {
375 yield Ok(StreamDelta::Error {
376 message: error.to_string(),
377 kind: StreamErrorKind::InvalidRequest,
378 });
379 return;
380 }
381 let contents = build_api_contents(&request.messages);
382 let tools = request
383 .tools
384 .as_ref()
385 .map(|t| convert_tools_to_config(t.clone()));
386 let tool_config = request
387 .tool_choice
388 .as_ref()
389 .map(ApiFunctionCallingConfig::from_tool_choice);
390 let system_instruction = if request.system.is_empty() {
391 None
392 } else {
393 Some(ApiContent {
394 role: None,
395 parts: vec![ApiPart::Text {
396 text: request.system.clone(),
397 thought_signature: None,
398 }],
399 })
400 };
401
402 let thinking_config = thinking.as_ref().map(map_thinking_config);
403 let (response_mime_type, response_schema) = request
404 .response_format
405 .as_ref()
406 .map_or((None, None), |rf| {
407 (
408 Some("application/json"),
409 Some(gemini_response_schema(&rf.schema)),
410 )
411 });
412
413 let max_tokens = self.effective_max_tokens(&request);
414 let api_request = ApiGenerateContentRequest {
415 contents: &contents,
416 system_instruction: system_instruction.as_ref(),
417 tools: tools.as_ref().map(std::slice::from_ref),
418 tool_config,
419 generation_config: Some(ApiGenerationConfig {
420 max_output_tokens: Some(max_tokens),
421 thinking_config,
422 response_mime_type,
423 response_schema,
424 }),
425 cached_content: request.cached_content.as_deref(),
426 };
427
428 log::debug!(
429 "Gemini streaming LLM request model={} max_tokens={}",
430 self.model,
431 max_tokens
432 );
433
434 let stream_builder = self
435 .client
436 .post(format!(
437 "{}/models/{}:streamGenerateContent",
438 self.base_url, self.model
439 ))
440 .header("Content-Type", "application/json")
441 .query(&[("alt", "sse")]);
442 let response = match self
443 .apply_auth(stream_builder)
444 .json(&api_request)
445 .send()
446 .await
447 {
448 Ok(r) => r,
449 Err(e) => {
450 yield Err(anyhow::anyhow!("request failed: {e}"));
452 return;
453 }
454 };
455
456 let status = response.status();
457 if !status.is_success() {
458 let body = response.text().await.unwrap_or_default();
459 let kind = if status == StatusCode::TOO_MANY_REQUESTS {
460 StreamErrorKind::RateLimited
461 } else if status.is_server_error() {
462 StreamErrorKind::ServerError
463 } else {
464 StreamErrorKind::InvalidRequest
465 };
466 log::warn!("Gemini error status={status} body={body}");
467 yield Ok(StreamDelta::Error {
468 message: body,
469 kind,
470 });
471 return;
472 }
473
474 let mut inner = data::stream_gemini_response(response);
475 while let Some(item) = futures::StreamExt::next(&mut inner).await {
476 yield item;
477 }
478 })
479 }
480
481 fn model(&self) -> &str {
482 &self.model
483 }
484
485 fn provider(&self) -> &'static str {
486 "gemini"
487 }
488
489 fn configured_thinking(&self) -> Option<&ThinkingConfig> {
490 self.thinking.as_ref()
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the placeholder API key used by the constructor tests.
    fn api_key() -> String {
        String::from("test-api-key")
    }

    /// Asserts that `provider` targets `expected_model` and reports the Gemini provider name.
    fn assert_targets(provider: &GeminiProvider, expected_model: &str) {
        assert_eq!(provider.model(), expected_model);
        assert_eq!(provider.provider(), "gemini");
    }

    /// Builds a minimal single-message request with the given token budget.
    fn request_with_max_tokens(max_tokens: u32, explicit: bool) -> ChatRequest {
        let messages = vec![agent_sdk_foundation::llm::Message::user("hi")];
        ChatRequest {
            system: String::new(),
            messages,
            tools: None,
            max_tokens,
            max_tokens_explicit: explicit,
            session_id: None,
            cached_content: None,
            thinking: None,
            tool_choice: None,
            response_format: None,
        }
    }

    #[test]
    fn test_new_creates_provider_with_custom_model() {
        let provider = GeminiProvider::new(api_key(), String::from("custom-model"));
        assert_targets(&provider, "custom-model");
    }

    #[test]
    fn test_flash_factory_creates_flash_provider() {
        assert_targets(&GeminiProvider::flash(api_key()), MODEL_GEMINI_3_FLASH);
    }

    #[test]
    fn test_flash_lite_factory_creates_flash_lite_provider() {
        assert_targets(
            &GeminiProvider::flash_lite(api_key()),
            MODEL_GEMINI_2_FLASH_LITE,
        );
    }

    #[test]
    fn test_flash_lite_31_factory_creates_flash_lite_provider() {
        assert_targets(
            &GeminiProvider::flash_lite_31(api_key()),
            MODEL_GEMINI_31_FLASH_LITE,
        );
    }

    #[test]
    fn test_pro_factory_creates_pro_provider() {
        assert_targets(&GeminiProvider::pro(api_key()), MODEL_GEMINI_31_PRO);
    }

    #[test]
    fn test_pro_31_factory_creates_pro_provider() {
        assert_targets(&GeminiProvider::pro_31(api_key()), MODEL_GEMINI_31_PRO);
    }

    #[test]
    fn test_model_constants_have_expected_values() {
        // Pins the exact id strings so an accidental edit is caught.
        assert_eq!(MODEL_GEMINI_31_PRO, "gemini-3.1-pro-preview");
        assert_eq!(MODEL_GEMINI_31_FLASH_LITE, "gemini-3.1-flash-lite-preview");
        assert_eq!(MODEL_GEMINI_3_FLASH, "gemini-3-flash-preview");
        assert_eq!(MODEL_GEMINI_3_PRO, "gemini-3.0-pro");
        assert_eq!(MODEL_GEMINI_25_FLASH, "gemini-2.5-flash");
        assert_eq!(MODEL_GEMINI_25_PRO, "gemini-2.5-pro");
        assert_eq!(MODEL_GEMINI_2_FLASH, "gemini-2.0-flash");
        assert_eq!(MODEL_GEMINI_2_FLASH_LITE, "gemini-2.0-flash-lite");
    }

    #[test]
    fn test_gemini_20_models_reject_thinking() {
        let provider = GeminiProvider::flash_lite(api_key());
        let requested = ThinkingConfig::new(10_000);

        let error = provider
            .validate_thinking_config(Some(&requested))
            .unwrap_err();

        assert!(error.to_string().contains("thinking is not supported"));
    }

    #[test]
    fn test_default_uses_header_auth() {
        let provider = GeminiProvider::new(String::from("test-key"), String::from("model"));
        assert!(
            provider.use_header_auth,
            "Default should use header auth for security"
        );
    }

    #[test]
    fn test_provider_is_cloneable() {
        let original = GeminiProvider::new(api_key(), String::from("test-model"));
        let copy = original.clone();

        assert_eq!(original.model(), copy.model());
        assert_eq!(original.provider(), copy.provider());
    }

    #[test]
    fn test_effective_max_tokens_honors_explicit_budget() {
        let provider = GeminiProvider::pro(api_key());
        let explicit_request = request_with_max_tokens(123, true);

        assert_eq!(provider.effective_max_tokens(&explicit_request), 123);
    }

    #[test]
    fn test_effective_max_tokens_uses_default_when_implicit() {
        let provider = GeminiProvider::pro(api_key());
        // The request's own value must be ignored when it was not set explicitly.
        let implicit_request = request_with_max_tokens(4096, false);

        let expected = provider.default_max_tokens();
        assert_eq!(provider.effective_max_tokens(&implicit_request), expected);
    }
}