1use crate::auth::AuthService;
8use crate::providers::traits::{ChatMessage, Provider, TokenUsage};
9use async_trait::async_trait;
10use base64::Engine;
11use directories::UserDirs;
12use reqwest::Client;
13use serde::{Deserialize, Serialize};
14use std::path::PathBuf;
15use std::sync::Arc;
16
/// Gemini provider: holds the credential resolved at construction time plus
/// the bookkeeping needed by the OAuth code paths.
pub struct GeminiProvider {
    /// Credential selected by the constructor; `None` when nothing was found.
    auth: Option<GeminiAuth>,
    /// Project id cached by `resolve_oauth_project`; reset to `None` when the
    /// OAuth credential is rotated.
    oauth_project: Arc<tokio::sync::Mutex<Option<String>>>,
    /// Credential files returned by `discover_oauth_cred_paths`.
    oauth_cred_paths: Vec<PathBuf>,
    /// Index into `oauth_cred_paths` advanced by `rotate_oauth_credential`.
    oauth_index: Arc<tokio::sync::Mutex<usize>>,
    /// Set by `new_with_auth` only when `auth` is `GeminiAuth::ManagedOAuth`.
    auth_service: Option<AuthService>,
    /// Profile name forwarded to the auth-service lookups.
    auth_profile_override: Option<String>,
}
28
/// Mutable state of a Gemini CLI OAuth credential loaded from disk.
struct OAuthTokenState {
    /// Bearer token currently in use.
    access_token: String,
    /// Token exchanged for a new access token; without it an expired token is an error.
    refresh_token: Option<String>,
    /// OAuth client id sent with refresh requests when present.
    client_id: Option<String>,
    /// OAuth client secret sent with refresh requests when present.
    client_secret: Option<String>,
    /// Expiry as milliseconds since the Unix epoch; `None` means unknown,
    /// which `get_valid_oauth_token` treats as "needs refresh".
    expiry_millis: Option<i64>,
}
38
/// Where the provider's credential came from, and therefore which endpoint
/// and request shape are used.
enum GeminiAuth {
    /// API key passed directly to the constructor.
    ExplicitKey(String),
    /// API key read from the `GEMINI_API_KEY` environment variable.
    EnvGeminiKey(String),
    /// API key read from the `GOOGLE_API_KEY` environment variable.
    EnvGoogleKey(String),
    /// OAuth token state loaded from a Gemini CLI credential file.
    OAuthToken(Arc<tokio::sync::Mutex<OAuthTokenState>>),
    /// OAuth token obtained on demand from the provider's `AuthService`.
    ManagedOAuth,
}
55
56impl GeminiAuth {
57 fn is_api_key(&self) -> bool {
59 matches!(
60 self,
61 GeminiAuth::ExplicitKey(_) | GeminiAuth::EnvGeminiKey(_) | GeminiAuth::EnvGoogleKey(_)
62 )
63 }
64
65 fn is_oauth(&self) -> bool {
67 matches!(self, GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth)
68 }
69
70 fn api_key_credential(&self) -> &str {
72 match self {
73 GeminiAuth::ExplicitKey(s)
74 | GeminiAuth::EnvGeminiKey(s)
75 | GeminiAuth::EnvGoogleKey(s) => s,
76 GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => "",
77 }
78 }
79}
80
/// Request body sent as-is on the API-key path of
/// `build_generate_content_request`.
#[derive(Debug, Serialize, Clone)]
struct GenerateContentRequest {
    /// Conversation turns.
    contents: Vec<Content>,
    /// Optional system prompt; omitted from the JSON when `None`.
    #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
    system_instruction: Option<Content>,
    /// Sampling settings; always serialized.
    #[serde(rename = "generationConfig")]
    generation_config: GenerationConfig,
}
93
/// Envelope built by `build_generate_content_request` for the OAuth variants;
/// it wraps the actual request together with model and project metadata.
#[derive(Debug, Serialize)]
struct InternalGenerateContentEnvelope {
    /// Model name with any `models/` prefix removed.
    model: String,
    /// Project id; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    project: Option<String>,
    /// Per-request identifier (a fresh UUID at the call site); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    user_prompt_id: Option<String>,
    /// The wrapped generate-content request.
    request: InternalGenerateContentRequest,
}
119
/// Inner request of the OAuth envelope. Unlike `GenerateContentRequest`,
/// `generation_config` is optional so it can be left out on a retry.
#[derive(Debug, Serialize)]
struct InternalGenerateContentRequest {
    /// Conversation turns.
    contents: Vec<Content>,
    /// Optional system prompt; omitted from the JSON when `None`.
    #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
    system_instruction: Option<Content>,
    /// Sampling settings; omitted from the JSON when `None`.
    #[serde(rename = "generationConfig", skip_serializing_if = "Option::is_none")]
    generation_config: Option<GenerationConfig>,
}
129
/// One message: an optional role plus its parts.
#[derive(Debug, Serialize, Clone)]
struct Content {
    /// `"user"` or `"model"` for conversation turns; `None` (and omitted from
    /// the JSON) for system instructions.
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
    /// Text and inline-data pieces making up the message.
    parts: Vec<Part>,
}
136
/// A single piece of a message. `untagged` means the variant name is not
/// serialized; only the variant's fields appear in the JSON.
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
enum Part {
    /// Plain text.
    Text { text: String },
    /// Base64 payload with its MIME type.
    Inline { inline_data: InlineData },
}
143
144impl Part {
145 fn text(s: impl Into<String>) -> Self {
146 Part::Text { text: s.into() }
147 }
148}
149
/// Inline binary payload of a [`Part::Inline`].
#[derive(Debug, Serialize, Clone)]
struct InlineData {
    /// MIME type taken from the `data:` URI (the text before the first `;`).
    mime_type: String,
    /// Base64 payload exactly as it appeared after `base64,` in the URI.
    data: String,
}
155
156fn build_parts(content: &str) -> Vec<Part> {
161 let (text, image_refs) = crate::multimodal::parse_image_markers(content);
162 let mut parts = Vec::new();
163 let trimmed = text.trim();
164 if !trimmed.is_empty() {
165 parts.push(Part::text(trimmed));
166 }
167 for uri in &image_refs {
168 if let Some(rest) = uri.strip_prefix("data:") {
169 if let Some(semi_pos) = rest.find(';') {
170 let mime = &rest[..semi_pos];
171 if let Some(b64) = rest[semi_pos + 1..].strip_prefix("base64,") {
172 parts.push(Part::Inline {
173 inline_data: InlineData {
174 mime_type: mime.to_string(),
175 data: b64.to_string(),
176 },
177 });
178 }
179 }
180 }
181 }
182 if parts.is_empty() {
183 parts.push(Part::text(content));
184 }
185 parts
186}
187
/// Sampling settings sent with a generate-content request.
#[derive(Debug, Serialize, Clone)]
struct GenerationConfig {
    /// Sampling temperature supplied by the caller.
    temperature: f64,
    /// Upper bound on generated tokens; serialized as `maxOutputTokens`.
    #[serde(rename = "maxOutputTokens")]
    max_output_tokens: u32,
}
194
/// Response body of a generate-content call. The same shape is used for the
/// direct response and for one nested under a `response` key.
#[derive(Debug, Deserialize)]
struct GenerateContentResponse {
    /// Generated candidates; only the first one is read by the caller.
    candidates: Option<Vec<Candidate>>,
    /// Error object reported inside a response body.
    error: Option<ApiError>,
    /// Nested response; `into_effective_response` unwraps it when present.
    #[serde(default)]
    response: Option<Box<GenerateContentResponse>>,
    /// Token accounting, read from `usageMetadata`.
    #[serde(default, rename = "usageMetadata")]
    usage_metadata: Option<GeminiUsageMetadata>,
}
204
/// Token counts reported with a response; mapped onto `TokenUsage` by
/// `send_generate_content`.
#[derive(Debug, Deserialize)]
struct GeminiUsageMetadata {
    /// Becomes `TokenUsage::input_tokens`.
    #[serde(default, rename = "promptTokenCount")]
    prompt_token_count: Option<u64>,
    /// Becomes `TokenUsage::output_tokens`.
    #[serde(default, rename = "candidatesTokenCount")]
    candidates_token_count: Option<u64>,
}
212
/// Wrapper with a mandatory `response` field.
/// NOTE(review): not referenced in the visible part of this file; the nested
/// shape is handled via `GenerateContentResponse::response` instead — confirm
/// whether this type is still needed.
#[derive(Debug, Deserialize)]
struct InternalGenerateContentResponse {
    /// The wrapped response.
    response: GenerateContentResponse,
}
219
/// One generated candidate.
#[derive(Debug, Deserialize)]
struct Candidate {
    /// Candidate content; defaults to `None` when the key is missing.
    #[serde(default)]
    content: Option<CandidateContent>,
}
225
/// Content of a candidate: the ordered list of response parts.
#[derive(Debug, Deserialize)]
struct CandidateContent {
    /// Parts in the order returned; combined by `effective_text`.
    parts: Vec<ResponsePart>,
}
230
/// One part of a candidate's content.
#[derive(Debug, Deserialize)]
struct ResponsePart {
    /// Text of the part, if any.
    #[serde(default)]
    text: Option<String>,
    /// Flag read from the `thought` key (defaults to `false`); parts with it
    /// set are used by `effective_text` only as a fallback.
    #[serde(default)]
    thought: bool,
}
239
240impl CandidateContent {
241 fn effective_text(self) -> Option<String> {
251 let mut answer_parts: Vec<String> = Vec::new();
252 let mut first_thinking: Option<String> = None;
253
254 for part in self.parts {
255 if let Some(text) = part.text {
256 if text.is_empty() {
257 continue;
258 }
259 if !part.thought {
260 answer_parts.push(text);
261 } else if first_thinking.is_none() {
262 first_thinking = Some(text);
263 }
264 }
265 }
266
267 if answer_parts.is_empty() {
268 first_thinking
269 } else {
270 Some(answer_parts.join(""))
271 }
272 }
273}
274
/// Error object embedded in a response body.
#[derive(Debug, Deserialize)]
struct ApiError {
    /// Message surfaced to the caller as `Gemini API error: <message>`.
    message: String,
}
279
280impl GenerateContentResponse {
281 fn into_effective_response(self) -> Self {
283 match self {
284 Self {
285 response: Some(inner),
286 ..
287 } => *inner,
288 other => other,
289 }
290 }
291}
292
/// Contents of a Gemini CLI `oauth_creds.json` file. Several fields accept a
/// camelCase alias in addition to the snake_case name.
#[derive(Debug, Deserialize)]
struct GeminiCliOAuthCreds {
    /// Bearer token; a credential file without a non-empty one is ignored.
    access_token: Option<String>,
    /// ID token (JWT); used only as a fallback source for the client id.
    #[serde(alias = "idToken")]
    id_token: Option<String>,
    /// Token used to obtain a new access token.
    refresh_token: Option<String>,
    /// OAuth client id stored in the file, if any.
    #[serde(alias = "clientId")]
    client_id: Option<String>,
    /// OAuth client secret stored in the file, if any.
    #[serde(alias = "clientSecret")]
    client_secret: Option<String>,
    /// Expiry used directly as `OAuthTokenState::expiry_millis`
    /// (milliseconds since the Unix epoch).
    #[serde(alias = "expiryDate")]
    expiry_date: Option<i64>,
    /// Expiry as an RFC 3339 timestamp; consulted only when `expiry_date` is absent.
    expiry: Option<String>,
}
314
/// Token endpoint that `refresh_gemini_cli_token` posts the refresh form to.
const GOOGLE_TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";

/// Base of the internal endpoint used for OAuth credentials;
/// `:generateContent` is appended by `build_generate_content_url`.
const CLOUDCODE_PA_ENDPOINT: &str = "https://cloudcode-pa.googleapis.com/v1internal";

/// Endpoint queried by `resolve_oauth_project` to look up the project id.
const LOAD_CODE_ASSIST_ENDPOINT: &str =
    "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist";

/// Base of the public API used with API-key credentials.
const PUBLIC_API_ENDPOINT: &str = "https://generativelanguage.googleapis.com/v1beta";
332
/// Result of a successful OAuth token refresh.
struct RefreshedToken {
    /// The new bearer token.
    access_token: String,
    /// New expiry in milliseconds since the Unix epoch, when the response
    /// carried `expires_in` and the arithmetic did not overflow.
    expiry_millis: Option<i64>,
}
343
344fn refresh_gemini_cli_token(
350 refresh_token: &str,
351 client_id: Option<&str>,
352 client_secret: Option<&str>,
353) -> anyhow::Result<RefreshedToken> {
354 let client = reqwest::blocking::Client::builder()
355 .timeout(std::time::Duration::from_secs(15))
356 .connect_timeout(std::time::Duration::from_secs(5))
357 .build()
358 .unwrap_or_else(|_| reqwest::blocking::Client::new());
359
360 let form = build_oauth_refresh_form(refresh_token, client_id, client_secret);
361
362 let response = client
363 .post(GOOGLE_TOKEN_ENDPOINT)
364 .header("Content-Type", "application/x-www-form-urlencoded")
365 .header("Accept", "application/json")
366 .form(&form)
367 .send()
368 .map_err(|error| anyhow::anyhow!("Gemini CLI OAuth refresh request failed: {error}"))?;
369
370 let status = response.status();
371 let body = response
372 .text()
373 .unwrap_or_else(|_| "<failed to read response body>".to_string());
374
375 if !status.is_success() {
376 anyhow::bail!("Gemini CLI OAuth refresh failed (HTTP {status}): {body}");
377 }
378
379 #[derive(Deserialize)]
380 struct TokenResponse {
381 access_token: Option<String>,
382 expires_in: Option<i64>,
383 }
384
385 let parsed: TokenResponse = serde_json::from_str(&body)
386 .map_err(|_| anyhow::anyhow!("Gemini CLI OAuth refresh response is not valid JSON"))?;
387
388 let access_token = parsed
389 .access_token
390 .filter(|t| !t.trim().is_empty())
391 .ok_or_else(|| anyhow::anyhow!("Gemini CLI OAuth refresh response missing access_token"))?;
392
393 let expiry_millis = parsed.expires_in.and_then(|secs| {
394 let now_millis = std::time::SystemTime::now()
395 .duration_since(std::time::UNIX_EPOCH)
396 .ok()
397 .and_then(|d| i64::try_from(d.as_millis()).ok())?;
398 now_millis.checked_add(secs.checked_mul(1000)?)
399 });
400
401 Ok(RefreshedToken {
402 access_token,
403 expiry_millis,
404 })
405}
406
407fn build_oauth_refresh_form(
408 refresh_token: &str,
409 client_id: Option<&str>,
410 client_secret: Option<&str>,
411) -> Vec<(&'static str, String)> {
412 let mut form = vec![
413 ("grant_type", "refresh_token".to_string()),
414 ("refresh_token", refresh_token.to_string()),
415 ];
416 if let Some(id) = client_id.and_then(GeminiProvider::normalize_non_empty) {
417 form.push(("client_id", id));
418 }
419 if let Some(secret) = client_secret.and_then(GeminiProvider::normalize_non_empty) {
420 form.push(("client_secret", secret));
421 }
422 form
423}
424
425fn extract_client_id_from_id_token(id_token: &str) -> Option<String> {
426 let payload = id_token.split('.').nth(1)?;
427 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
428 .decode(payload)
429 .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(payload))
430 .ok()?;
431
432 #[derive(Deserialize)]
433 struct IdTokenClaims {
434 aud: Option<String>,
435 azp: Option<String>,
436 }
437
438 let claims: IdTokenClaims = serde_json::from_slice(&decoded).ok()?;
439 claims
440 .aud
441 .as_deref()
442 .and_then(GeminiProvider::normalize_non_empty)
443 .or_else(|| {
444 claims
445 .azp
446 .as_deref()
447 .and_then(GeminiProvider::normalize_non_empty)
448 })
449}
450
451async fn refresh_gemini_cli_token_async(
453 refresh_token: &str,
454 client_id: Option<&str>,
455 client_secret: Option<&str>,
456) -> anyhow::Result<RefreshedToken> {
457 let refresh_token = refresh_token.to_string();
458 let client_id = client_id.map(str::to_string);
459 let client_secret = client_secret.map(str::to_string);
460 tokio::task::spawn_blocking(move || {
461 refresh_gemini_cli_token(
462 &refresh_token,
463 client_id.as_deref(),
464 client_secret.as_deref(),
465 )
466 })
467 .await
468 .map_err(|e| anyhow::anyhow!("Token refresh task panicked: {e}"))?
469}
470
471impl GeminiProvider {
472 pub fn new(api_key: Option<&str>) -> Self {
480 let oauth_cred_paths = Self::discover_oauth_cred_paths();
481 let resolved_auth = api_key
482 .and_then(Self::normalize_non_empty)
483 .map(GeminiAuth::ExplicitKey)
484 .or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey))
485 .or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey))
486 .or_else(|| {
487 Self::try_load_gemini_cli_token(oauth_cred_paths.first())
488 .map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state))))
489 });
490
491 Self {
492 auth: resolved_auth,
493 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
494 oauth_cred_paths,
495 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
496 auth_service: None,
497 auth_profile_override: None,
498 }
499 }
500
501 pub fn new_with_auth(
510 api_key: Option<&str>,
511 auth_service: AuthService,
512 profile_override: Option<String>,
513 ) -> Self {
514 let oauth_cred_paths = Self::discover_oauth_cred_paths();
515
516 let resolved_auth = api_key
518 .and_then(Self::normalize_non_empty)
519 .map(GeminiAuth::ExplicitKey)
520 .or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey))
521 .or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey));
522
523 let (auth, use_managed) = if resolved_auth.is_some() {
526 (resolved_auth, false)
527 } else {
528 let has_managed = std::thread::scope(|s| {
531 s.spawn(|| {
532 let rt = tokio::runtime::Builder::new_current_thread()
533 .enable_all()
534 .build()
535 .ok()?;
536 rt.block_on(async {
537 auth_service
538 .get_gemini_profile(profile_override.as_deref())
539 .await
540 .ok()
541 .flatten()
542 })
543 })
544 .join()
545 .ok()
546 .flatten()
547 .is_some()
548 });
549
550 if has_managed {
551 (Some(GeminiAuth::ManagedOAuth), true)
552 } else {
553 let cli_auth = Self::try_load_gemini_cli_token(oauth_cred_paths.first())
555 .map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state))));
556 (cli_auth, false)
557 }
558 };
559
560 Self {
561 auth,
562 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
563 oauth_cred_paths,
564 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
565 auth_service: if use_managed {
566 Some(auth_service)
567 } else {
568 None
569 },
570 auth_profile_override: profile_override,
571 }
572 }
573
574 fn normalize_non_empty(value: &str) -> Option<String> {
575 let trimmed = value.trim();
576 if trimmed.is_empty() {
577 None
578 } else {
579 Some(trimmed.to_string())
580 }
581 }
582
583 fn load_non_empty_env(name: &str) -> Option<String> {
584 std::env::var(name)
585 .ok()
586 .and_then(|value| Self::normalize_non_empty(&value))
587 }
588
589 fn load_gemini_cli_creds(creds_path: &PathBuf) -> Option<GeminiCliOAuthCreds> {
590 if !creds_path.exists() {
591 return None;
592 }
593 let content = std::fs::read_to_string(creds_path).ok()?;
594 serde_json::from_str(&content).ok()
595 }
596
597 fn discover_oauth_cred_paths() -> Vec<PathBuf> {
602 let home = match UserDirs::new() {
603 Some(u) => u.home_dir().to_path_buf(),
604 None => return Vec::new(),
605 };
606
607 let mut paths = Vec::new();
608
609 let primary = home.join(".gemini").join("oauth_creds.json");
610 if primary.exists() {
611 paths.push(primary);
612 }
613
614 if let Ok(entries) = std::fs::read_dir(&home) {
615 let mut extras: Vec<PathBuf> = entries
616 .filter_map(|e| e.ok())
617 .filter_map(|e| {
618 let name = e.file_name().to_string_lossy().to_string();
619 if name.starts_with(".gemini-") && name.ends_with("-home") {
620 let path = e.path().join(".gemini").join("oauth_creds.json");
621 if path.exists() {
622 return Some(path);
623 }
624 }
625 None
626 })
627 .collect();
628 extras.sort();
629 paths.extend(extras);
630 }
631
632 paths
633 }
634
635 fn try_load_gemini_cli_token(path: Option<&PathBuf>) -> Option<OAuthTokenState> {
640 let creds = Self::load_gemini_cli_creds(path?)?;
641
642 let expiry_millis = creds.expiry_date.or_else(|| {
644 creds.expiry.as_deref().and_then(|expiry| {
645 chrono::DateTime::parse_from_rfc3339(expiry)
646 .ok()
647 .map(|dt| dt.timestamp_millis())
648 })
649 });
650
651 let access_token = creds
652 .access_token
653 .and_then(|token| Self::normalize_non_empty(&token))?;
654
655 let id_token_client_id = creds
656 .id_token
657 .as_deref()
658 .and_then(extract_client_id_from_id_token);
659
660 let client_id = Self::load_non_empty_env("GEMINI_OAUTH_CLIENT_ID")
661 .or_else(|| {
662 creds
663 .client_id
664 .as_deref()
665 .and_then(Self::normalize_non_empty)
666 })
667 .or(id_token_client_id);
668 let client_secret = Self::load_non_empty_env("GEMINI_OAUTH_CLIENT_SECRET").or_else(|| {
669 creds
670 .client_secret
671 .as_deref()
672 .and_then(Self::normalize_non_empty)
673 });
674
675 Some(OAuthTokenState {
676 access_token,
677 refresh_token: creds.refresh_token,
678 client_id,
679 client_secret,
680 expiry_millis,
681 })
682 }
683
684 fn gemini_cli_dir() -> Option<PathBuf> {
686 UserDirs::new().map(|u| u.home_dir().join(".gemini"))
687 }
688
689 pub fn has_cli_credentials() -> bool {
691 Self::discover_oauth_cred_paths().iter().any(|path| {
692 Self::load_gemini_cli_creds(path)
693 .and_then(|creds| {
694 creds
695 .access_token
696 .as_deref()
697 .and_then(Self::normalize_non_empty)
698 })
699 .is_some()
700 })
701 }
702
703 pub fn has_any_auth() -> bool {
705 Self::load_non_empty_env("GEMINI_API_KEY").is_some()
706 || Self::load_non_empty_env("GOOGLE_API_KEY").is_some()
707 || Self::has_cli_credentials()
708 }
709
710 pub fn auth_source(&self) -> &'static str {
713 match self.auth.as_ref() {
714 Some(GeminiAuth::ExplicitKey(_)) => "config",
715 Some(GeminiAuth::EnvGeminiKey(_)) => "GEMINI_API_KEY env var",
716 Some(GeminiAuth::EnvGoogleKey(_)) => "GOOGLE_API_KEY env var",
717 Some(GeminiAuth::OAuthToken(_)) => "Gemini CLI OAuth",
718 Some(GeminiAuth::ManagedOAuth) => "auth-profiles",
719 None => "none",
720 }
721 }
722
723 async fn get_valid_oauth_token(
726 state: &Arc<tokio::sync::Mutex<OAuthTokenState>>,
727 ) -> anyhow::Result<String> {
728 let mut guard = state.lock().await;
729
730 let now_millis = std::time::SystemTime::now()
731 .duration_since(std::time::UNIX_EPOCH)
732 .ok()
733 .and_then(|d| i64::try_from(d.as_millis()).ok())
734 .unwrap_or(i64::MAX);
735
736 let needs_refresh = guard
738 .expiry_millis
739 .map_or(true, |exp| exp <= now_millis.saturating_add(60_000));
740
741 if needs_refresh {
742 if let Some(ref refresh_token) = guard.refresh_token {
743 let refreshed = refresh_gemini_cli_token_async(
744 refresh_token,
745 guard.client_id.as_deref(),
746 guard.client_secret.as_deref(),
747 )
748 .await?;
749 tracing::info!("Gemini CLI OAuth token refreshed successfully (runtime)");
750 guard.access_token = refreshed.access_token;
751 guard.expiry_millis = refreshed.expiry_millis;
752 } else {
753 anyhow::bail!(
754 "Gemini CLI OAuth token expired and no refresh_token available — re-run `gemini` to authenticate"
755 );
756 }
757 }
758
759 Ok(guard.access_token.clone())
760 }
761
762 async fn rotate_oauth_credential(
765 &self,
766 state: &Arc<tokio::sync::Mutex<OAuthTokenState>>,
767 ) -> bool {
768 if self.oauth_cred_paths.len() <= 1 {
769 return false;
770 }
771
772 let mut idx = self.oauth_index.lock().await;
773 let start = *idx;
774
775 loop {
776 let next = (*idx + 1) % self.oauth_cred_paths.len();
777 *idx = next;
778
779 if next == start {
780 return false;
781 }
782
783 if let Some(next_state) =
784 Self::try_load_gemini_cli_token(self.oauth_cred_paths.get(next))
785 {
786 {
787 let mut guard = state.lock().await;
788 *guard = next_state;
789 }
790 {
791 let mut cached_project = self.oauth_project.lock().await;
792 *cached_project = None;
793 }
794 tracing::warn!(
795 "Gemini OAuth: rotated credential to {}",
796 self.oauth_cred_paths[next].display()
797 );
798 return true;
799 }
800 }
801 }
802
803 fn format_model_name(model: &str) -> String {
804 if model.starts_with("models/") {
805 model.to_string()
806 } else {
807 format!("models/{model}")
808 }
809 }
810
811 fn format_internal_model_name(model: &str) -> String {
812 model.strip_prefix("models/").unwrap_or(model).to_string()
813 }
814
815 fn build_generate_content_url(model: &str, auth: &GeminiAuth) -> String {
825 match auth {
826 GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => {
827 format!("{CLOUDCODE_PA_ENDPOINT}:generateContent")
830 }
831 _ => {
832 let model_name = Self::format_model_name(model);
833 let base_url = format!("{PUBLIC_API_ENDPOINT}/{model_name}:generateContent");
834
835 if auth.is_api_key() {
836 format!("{base_url}?key={}", auth.api_key_credential())
837 } else {
838 base_url
839 }
840 }
841 }
842 }
843
    /// Builds the HTTP client used for Gemini requests via
    /// `build_runtime_proxy_client_with_timeouts`, under the key
    /// `"provider.gemini"`. A new client is requested on every call.
    ///
    /// NOTE(review): the arguments `120` and `10` look like the total and
    /// connect timeouts in seconds — confirm against the helper's signature.
    fn http_client(&self) -> Client {
        crate::config::build_runtime_proxy_client_with_timeouts("provider.gemini", 120, 10)
    }
847
848 async fn resolve_oauth_project(&self, token: &str) -> anyhow::Result<String> {
851 let project_seed = Self::load_non_empty_env("GOOGLE_CLOUD_PROJECT")
852 .or_else(|| Self::load_non_empty_env("GOOGLE_CLOUD_PROJECT_ID"));
853 let project_seed_for_request = project_seed.clone();
854 let duet_project_for_request = project_seed.clone();
855
856 {
858 let cached = self.oauth_project.lock().await;
859 if let Some(ref project) = *cached {
860 return Ok(project.clone());
861 }
862 }
863
864 let client = self.http_client();
866 let response = client
867 .post(LOAD_CODE_ASSIST_ENDPOINT)
868 .bearer_auth(token)
869 .json(&serde_json::json!({
870 "cloudaicompanionProject": project_seed_for_request,
871 "metadata": {
872 "ideType": "GEMINI_CLI",
873 "platform": "PLATFORM_UNSPECIFIED",
874 "pluginType": "GEMINI",
875 "duetProject": duet_project_for_request,
876 }
877 }))
878 .send()
879 .await?;
880
881 if !response.status().is_success() {
882 let status = response.status();
883 let body = response.text().await.unwrap_or_default();
884 if let Some(seed) = project_seed {
885 tracing::warn!(
886 "loadCodeAssist failed (HTTP {status}); using GOOGLE_CLOUD_PROJECT fallback"
887 );
888 return Ok(seed);
889 }
890 anyhow::bail!("loadCodeAssist failed (HTTP {status}): {body}");
891 }
892
893 #[derive(Deserialize)]
894 struct LoadCodeAssistResponse {
895 #[serde(rename = "cloudaicompanionProject")]
896 cloudaicompanion_project: Option<String>,
897 }
898
899 let result: LoadCodeAssistResponse = response.json().await?;
900 let project = result
901 .cloudaicompanion_project
902 .filter(|p| !p.trim().is_empty())
903 .or(project_seed)
904 .ok_or_else(|| anyhow::anyhow!("loadCodeAssist response missing project context"))?;
905
906 {
908 let mut cached = self.oauth_project.lock().await;
909 *cached = Some(project.clone());
910 }
911
912 Ok(project)
913 }
914
915 fn build_generate_content_request(
920 &self,
921 auth: &GeminiAuth,
922 url: &str,
923 request: &GenerateContentRequest,
924 model: &str,
925 include_generation_config: bool,
926 project: Option<&str>,
927 oauth_token: Option<&str>,
928 ) -> reqwest::RequestBuilder {
929 let req = self.http_client().post(url).json(request);
930 match auth {
931 GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => {
932 let token = oauth_token.unwrap_or_default();
933 let internal_request = InternalGenerateContentEnvelope {
936 model: Self::format_internal_model_name(model),
937 project: project.map(|value| value.to_string()),
938 user_prompt_id: Some(uuid::Uuid::new_v4().to_string()),
939 request: InternalGenerateContentRequest {
940 contents: request.contents.clone(),
941 system_instruction: request.system_instruction.clone(),
942 generation_config: if include_generation_config {
943 Some(request.generation_config.clone())
944 } else {
945 None
946 },
947 },
948 };
949 self.http_client()
950 .post(url)
951 .json(&internal_request)
952 .bearer_auth(token)
953 }
954 _ => req,
955 }
956 }
957
958 fn should_retry_oauth_without_generation_config(
959 status: reqwest::StatusCode,
960 error_text: &str,
961 ) -> bool {
962 if status != reqwest::StatusCode::BAD_REQUEST {
963 return false;
964 }
965
966 error_text.contains("Unknown name \"generationConfig\"")
967 || error_text.contains("Unknown name 'generationConfig'")
968 || error_text.contains(r#"Unknown name \"generationConfig\""#)
969 }
970
971 fn should_rotate_oauth_on_error(status: reqwest::StatusCode, error_text: &str) -> bool {
972 status == reqwest::StatusCode::TOO_MANY_REQUESTS
973 || status == reqwest::StatusCode::SERVICE_UNAVAILABLE
974 || status.is_server_error()
975 || error_text.contains("RESOURCE_EXHAUSTED")
976 }
977}
978
979impl GeminiProvider {
980 async fn send_generate_content(
981 &self,
982 contents: Vec<Content>,
983 system_instruction: Option<Content>,
984 model: &str,
985 temperature: f64,
986 ) -> anyhow::Result<(String, Option<TokenUsage>)> {
987 let auth = self.auth.as_ref().ok_or_else(|| {
988 anyhow::anyhow!(
989 "Gemini API key not found. Options:\n\
990 1. Set GEMINI_API_KEY env var\n\
991 2. Run `gemini` CLI to authenticate (tokens will be reused)\n\
992 3. Run `construct auth login --provider gemini`\n\
993 4. Get an API key from https://aistudio.google.com/app/apikey\n\
994 5. Run `construct onboard` to configure"
995 )
996 })?;
997
998 let oauth_state = match auth {
999 GeminiAuth::OAuthToken(state) => Some(state.clone()),
1000 _ => None,
1001 };
1002
1003 let (mut oauth_token, mut project) = match auth {
1005 GeminiAuth::OAuthToken(state) => {
1006 let token = Self::get_valid_oauth_token(state).await?;
1007 let proj = self.resolve_oauth_project(&token).await?;
1008 (Some(token), Some(proj))
1009 }
1010 GeminiAuth::ManagedOAuth => {
1011 let auth_service = self
1012 .auth_service
1013 .as_ref()
1014 .ok_or_else(|| anyhow::anyhow!("ManagedOAuth requires auth_service"))?;
1015 let token = auth_service
1016 .get_valid_gemini_access_token(self.auth_profile_override.as_deref())
1017 .await?
1018 .ok_or_else(|| {
1019 anyhow::anyhow!(
1020 "Gemini auth profile not found. Run `construct auth login --provider gemini`."
1021 )
1022 })?;
1023 let proj = self.resolve_oauth_project(&token).await?;
1024 (Some(token), Some(proj))
1025 }
1026 _ => (None, None),
1027 };
1028
1029 let request = GenerateContentRequest {
1030 contents,
1031 system_instruction,
1032 generation_config: GenerationConfig {
1033 temperature,
1034 max_output_tokens: 8192,
1035 },
1036 };
1037
1038 let url = Self::build_generate_content_url(model, auth);
1039
1040 let mut response = self
1041 .build_generate_content_request(
1042 auth,
1043 &url,
1044 &request,
1045 model,
1046 true,
1047 project.as_deref(),
1048 oauth_token.as_deref(),
1049 )
1050 .send()
1051 .await?;
1052
1053 if !response.status().is_success() {
1054 let status = response.status();
1055 let error_text = response.text().await.unwrap_or_default();
1056
1057 if auth.is_oauth() && Self::should_rotate_oauth_on_error(status, &error_text) {
1058 let can_retry = match auth {
1061 GeminiAuth::OAuthToken(_) => {
1062 if let Some(state) = oauth_state.as_ref() {
1063 self.rotate_oauth_credential(state).await
1064 } else {
1065 false
1066 }
1067 }
1068 GeminiAuth::ManagedOAuth => true, _ => false,
1070 };
1071
1072 if can_retry {
1073 let (new_token, new_project) = match auth {
1075 GeminiAuth::OAuthToken(state) => {
1076 let token = Self::get_valid_oauth_token(state).await?;
1077 let proj = self.resolve_oauth_project(&token).await?;
1078 (token, proj)
1079 }
1080 GeminiAuth::ManagedOAuth => {
1081 let auth_service = self.auth_service.as_ref().unwrap();
1082 let token = auth_service
1083 .get_valid_gemini_access_token(
1084 self.auth_profile_override.as_deref(),
1085 )
1086 .await?
1087 .ok_or_else(|| anyhow::anyhow!("Gemini auth profile not found"))?;
1088 let proj = self.resolve_oauth_project(&token).await?;
1089 (token, proj)
1090 }
1091 _ => unreachable!(),
1092 };
1093 oauth_token = Some(new_token);
1094 project = Some(new_project);
1095 response = self
1096 .build_generate_content_request(
1097 auth,
1098 &url,
1099 &request,
1100 model,
1101 true,
1102 project.as_deref(),
1103 oauth_token.as_deref(),
1104 )
1105 .send()
1106 .await?;
1107 } else {
1108 anyhow::bail!("Gemini API error ({status}): {error_text}");
1109 }
1110 } else if auth.is_oauth()
1111 && Self::should_retry_oauth_without_generation_config(status, &error_text)
1112 {
1113 tracing::warn!(
1114 "Gemini OAuth internal endpoint rejected generationConfig; retrying without generationConfig"
1115 );
1116 response = self
1117 .build_generate_content_request(
1118 auth,
1119 &url,
1120 &request,
1121 model,
1122 false,
1123 project.as_deref(),
1124 oauth_token.as_deref(),
1125 )
1126 .send()
1127 .await?;
1128 } else {
1129 anyhow::bail!("Gemini API error ({status}): {error_text}");
1130 }
1131 }
1132
1133 if !response.status().is_success() {
1134 let status = response.status();
1135 let error_text = response.text().await.unwrap_or_default();
1136 if auth.is_oauth()
1137 && Self::should_retry_oauth_without_generation_config(status, &error_text)
1138 {
1139 tracing::warn!(
1140 "Gemini OAuth internal endpoint rejected generationConfig; retrying without generationConfig"
1141 );
1142 response = self
1143 .build_generate_content_request(
1144 auth,
1145 &url,
1146 &request,
1147 model,
1148 false,
1149 project.as_deref(),
1150 oauth_token.as_deref(),
1151 )
1152 .send()
1153 .await?;
1154 } else {
1155 anyhow::bail!("Gemini API error ({status}): {error_text}");
1156 }
1157 }
1158
1159 if !response.status().is_success() {
1160 let status = response.status();
1161 let error_text = response.text().await.unwrap_or_default();
1162 anyhow::bail!("Gemini API error ({status}): {error_text}");
1163 }
1164
1165 let result: GenerateContentResponse = response.json().await?;
1166 if let Some(err) = &result.error {
1167 anyhow::bail!("Gemini API error: {}", err.message);
1168 }
1169 let result = result.into_effective_response();
1170 if let Some(err) = result.error {
1171 anyhow::bail!("Gemini API error: {}", err.message);
1172 }
1173
1174 let usage = result.usage_metadata.map(|u| TokenUsage {
1175 input_tokens: u.prompt_token_count,
1176 output_tokens: u.candidates_token_count,
1177 cached_input_tokens: None,
1178 });
1179
1180 let text = result
1181 .candidates
1182 .and_then(|c| c.into_iter().next())
1183 .and_then(|c| c.content)
1184 .and_then(|c| c.effective_text())
1185 .ok_or_else(|| anyhow::anyhow!("No response from Gemini"))?;
1186
1187 Ok((text, usage))
1188 }
1189}
1190
1191#[async_trait]
1192impl Provider for GeminiProvider {
1193 fn capabilities(&self) -> crate::providers::traits::ProviderCapabilities {
1194 crate::providers::traits::ProviderCapabilities {
1195 vision: true,
1196 native_tool_calling: false,
1197 prompt_caching: false,
1198 }
1199 }
1200
1201 async fn chat_with_system(
1202 &self,
1203 system_prompt: Option<&str>,
1204 message: &str,
1205 model: &str,
1206 temperature: f64,
1207 ) -> anyhow::Result<String> {
1208 let system_instruction = system_prompt.map(|sys| Content {
1209 role: None,
1210 parts: vec![Part::text(sys)],
1211 });
1212
1213 let contents = vec![Content {
1214 role: Some("user".to_string()),
1215 parts: build_parts(message),
1216 }];
1217
1218 let (text, _usage) = self
1219 .send_generate_content(contents, system_instruction, model, temperature)
1220 .await?;
1221 Ok(text)
1222 }
1223
1224 async fn chat_with_history(
1225 &self,
1226 messages: &[ChatMessage],
1227 model: &str,
1228 temperature: f64,
1229 ) -> anyhow::Result<String> {
1230 let mut system_parts: Vec<&str> = Vec::new();
1231 let mut contents: Vec<Content> = Vec::new();
1232
1233 for msg in messages {
1234 match msg.role.as_str() {
1235 "system" => {
1236 system_parts.push(&msg.content);
1237 }
1238 "user" => {
1239 contents.push(Content {
1240 role: Some("user".to_string()),
1241 parts: build_parts(&msg.content),
1242 });
1243 }
1244 "assistant" => {
1245 contents.push(Content {
1247 role: Some("model".to_string()),
1248 parts: vec![Part::text(&msg.content)],
1249 });
1250 }
1251 _ => {}
1252 }
1253 }
1254
1255 let system_instruction = if system_parts.is_empty() {
1256 None
1257 } else {
1258 Some(Content {
1259 role: None,
1260 parts: vec![Part::text(system_parts.join("\n\n"))],
1261 })
1262 };
1263
1264 let (text, _usage) = self
1265 .send_generate_content(contents, system_instruction, model, temperature)
1266 .await?;
1267 Ok(text)
1268 }
1269
1270 async fn warmup(&self) -> anyhow::Result<()> {
1271 if let Some(auth) = self.auth.as_ref() {
1272 match auth {
1273 GeminiAuth::ManagedOAuth => {
1274 let auth_service = self
1277 .auth_service
1278 .as_ref()
1279 .ok_or_else(|| anyhow::anyhow!("ManagedOAuth requires auth_service"))?;
1280
1281 let _token = auth_service
1282 .get_valid_gemini_access_token(self.auth_profile_override.as_deref())
1283 .await?
1284 .ok_or_else(|| {
1285 anyhow::anyhow!(
1286 "Gemini auth profile not found or expired. Run: construct auth login --provider gemini"
1287 )
1288 })?;
1289
1290 }
1294 GeminiAuth::OAuthToken(_) => {
1295 }
1298 _ => {
1299 let url = if auth.is_api_key() {
1301 format!(
1302 "https://generativelanguage.googleapis.com/v1beta/models?key={}",
1303 auth.api_key_credential()
1304 )
1305 } else {
1306 "https://generativelanguage.googleapis.com/v1beta/models".to_string()
1307 };
1308
1309 self.http_client()
1310 .get(&url)
1311 .send()
1312 .await?
1313 .error_for_status()?;
1314 }
1315 }
1316 }
1317 Ok(())
1318 }
1319}
1320
1321#[cfg(test)]
1322mod tests {
1323 use super::*;
1324 use reqwest::{StatusCode, header::AUTHORIZATION};
1325
1326 fn test_oauth_auth(token: &str) -> GeminiAuth {
1328 GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(OAuthTokenState {
1329 access_token: token.to_string(),
1330 refresh_token: None,
1331 client_id: None,
1332 client_secret: None,
1333 expiry_millis: None,
1334 })))
1335 }
1336
1337 fn test_provider(auth: Option<GeminiAuth>) -> GeminiProvider {
1338 GeminiProvider {
1339 auth,
1340 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
1341 oauth_cred_paths: Vec::new(),
1342 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
1343 auth_service: None,
1344 auth_profile_override: None,
1345 }
1346 }
1347
1348 #[test]
1349 fn normalize_non_empty_trims_and_filters() {
1350 assert_eq!(
1351 GeminiProvider::normalize_non_empty(" value "),
1352 Some("value".into())
1353 );
1354 assert_eq!(GeminiProvider::normalize_non_empty(""), None);
1355 assert_eq!(GeminiProvider::normalize_non_empty(" \t\n"), None);
1356 }
1357
1358 #[test]
1359 fn oauth_refresh_form_uses_provided_client_credentials() {
1360 let form = build_oauth_refresh_form("refresh-token", Some("client-id"), Some("secret"));
1361 let map: std::collections::HashMap<_, _> = form.into_iter().collect();
1362 assert_eq!(map.get("grant_type"), Some(&"refresh_token".to_string()));
1363 assert_eq!(map.get("refresh_token"), Some(&"refresh-token".to_string()));
1364 assert_eq!(map.get("client_id"), Some(&"client-id".to_string()));
1365 assert_eq!(map.get("client_secret"), Some(&"secret".to_string()));
1366 }
1367
1368 #[test]
1369 fn oauth_refresh_form_omits_client_credentials_when_missing() {
1370 let form = build_oauth_refresh_form("refresh-token", None, None);
1371 let map: std::collections::HashMap<_, _> = form.into_iter().collect();
1372 assert!(!map.contains_key("client_id"));
1373 assert!(!map.contains_key("client_secret"));
1374 }
1375
1376 #[test]
1377 fn extract_client_id_from_id_token_prefers_aud_claim() {
1378 let payload = serde_json::json!({
1379 "aud": "aud-client-id",
1380 "azp": "azp-client-id"
1381 });
1382 let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
1383 .encode(serde_json::to_vec(&payload).unwrap());
1384 let token = format!("header.{payload_b64}.sig");
1385
1386 assert_eq!(
1387 extract_client_id_from_id_token(&token),
1388 Some("aud-client-id".to_string())
1389 );
1390 }
1391
1392 #[test]
1393 fn extract_client_id_from_id_token_uses_azp_when_aud_missing() {
1394 let payload = serde_json::json!({
1395 "azp": "azp-client-id"
1396 });
1397 let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
1398 .encode(serde_json::to_vec(&payload).unwrap());
1399 let token = format!("header.{payload_b64}.sig");
1400
1401 assert_eq!(
1402 extract_client_id_from_id_token(&token),
1403 Some("azp-client-id".to_string())
1404 );
1405 }
1406
1407 #[test]
1408 fn extract_client_id_from_id_token_returns_none_for_invalid_tokens() {
1409 assert_eq!(extract_client_id_from_id_token("invalid"), None);
1410 assert_eq!(extract_client_id_from_id_token("a.b.c"), None);
1411 }
1412
1413 #[test]
1414 fn try_load_cli_token_derives_client_id_from_id_token_when_missing() {
1415 let payload = serde_json::json!({ "aud": "derived-client-id" });
1416 let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
1417 .encode(serde_json::to_vec(&payload).unwrap());
1418 let id_token = format!("header.{payload_b64}.sig");
1419
1420 let file = tempfile::NamedTempFile::new().unwrap();
1421 let json = format!(
1422 r#"{{
1423 "access_token": "ya29.test-access",
1424 "refresh_token": "1//test-refresh",
1425 "id_token": "{id_token}"
1426 }}"#
1427 );
1428 std::fs::write(file.path(), json).unwrap();
1429
1430 let path = file.path().to_path_buf();
1431 let state = GeminiProvider::try_load_gemini_cli_token(Some(&path)).unwrap();
1432 assert_eq!(state.client_id.as_deref(), Some("derived-client-id"));
1433 assert_eq!(state.client_secret, None);
1434 }
1435
1436 #[test]
1437 fn provider_creates_without_key() {
1438 let provider = GeminiProvider::new(None);
1439 let _ = provider.auth_source();
1441 }
1442
1443 #[test]
1444 fn provider_creates_with_key() {
1445 let provider = GeminiProvider::new(Some("test-api-key"));
1446 assert!(matches!(
1447 provider.auth,
1448 Some(GeminiAuth::ExplicitKey(ref key)) if key == "test-api-key"
1449 ));
1450 }
1451
1452 #[test]
1453 fn provider_rejects_empty_key() {
1454 let provider = GeminiProvider::new(Some(""));
1455 assert!(!matches!(provider.auth, Some(GeminiAuth::ExplicitKey(_))));
1456 }
1457
1458 #[test]
1459 fn gemini_cli_dir_returns_path() {
1460 let dir = GeminiProvider::gemini_cli_dir();
1461 if UserDirs::new().is_some() {
1463 assert!(dir.is_some());
1464 assert!(dir.unwrap().ends_with(".gemini"));
1465 }
1466 }
1467
1468 #[test]
1469 fn auth_source_explicit_key() {
1470 let provider = test_provider(Some(GeminiAuth::ExplicitKey("key".into())));
1471 assert_eq!(provider.auth_source(), "config");
1472 }
1473
1474 #[test]
1475 fn auth_source_none_without_credentials() {
1476 let provider = test_provider(None);
1477 assert_eq!(provider.auth_source(), "none");
1478 }
1479
1480 #[test]
1481 fn auth_source_oauth() {
1482 let provider = test_provider(Some(test_oauth_auth("ya29.mock")));
1483 assert_eq!(provider.auth_source(), "Gemini CLI OAuth");
1484 }
1485
1486 #[test]
1487 fn model_name_formatting() {
1488 assert_eq!(
1489 GeminiProvider::format_model_name("gemini-2.0-flash"),
1490 "models/gemini-2.0-flash"
1491 );
1492 assert_eq!(
1493 GeminiProvider::format_model_name("models/gemini-1.5-pro"),
1494 "models/gemini-1.5-pro"
1495 );
1496 assert_eq!(
1497 GeminiProvider::format_internal_model_name("models/gemini-2.5-flash"),
1498 "gemini-2.5-flash"
1499 );
1500 assert_eq!(
1501 GeminiProvider::format_internal_model_name("gemini-2.5-flash"),
1502 "gemini-2.5-flash"
1503 );
1504 }
1505
1506 #[test]
1507 fn api_key_url_includes_key_query_param() {
1508 let auth = GeminiAuth::ExplicitKey("api-key-123".into());
1509 let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1510 assert!(url.contains(":generateContent?key=api-key-123"));
1511 }
1512
1513 #[test]
1514 fn oauth_url_uses_internal_endpoint() {
1515 let auth = test_oauth_auth("ya29.test-token");
1516 let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1517 assert!(url.starts_with("https://cloudcode-pa.googleapis.com/v1internal"));
1518 assert!(url.ends_with(":generateContent"));
1519 assert!(!url.contains("generativelanguage.googleapis.com"));
1520 assert!(!url.contains("?key="));
1521 }
1522
1523 #[test]
1524 fn api_key_url_uses_public_endpoint() {
1525 let auth = GeminiAuth::ExplicitKey("api-key-123".into());
1526 let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1527 assert!(url.contains("generativelanguage.googleapis.com/v1beta"));
1528 assert!(url.contains("models/gemini-2.0-flash"));
1529 }
1530
1531 #[test]
1532 fn oauth_request_uses_bearer_auth_header() {
1533 let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
1534 let auth = test_oauth_auth("ya29.mock-token");
1535 let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1536 let body = GenerateContentRequest {
1537 contents: vec![Content {
1538 role: Some("user".into()),
1539 parts: vec![Part::text("hello")],
1540 }],
1541 system_instruction: None,
1542 generation_config: GenerationConfig {
1543 temperature: 0.7,
1544 max_output_tokens: 8192,
1545 },
1546 };
1547
1548 let request = provider
1549 .build_generate_content_request(
1550 &auth,
1551 &url,
1552 &body,
1553 "gemini-2.0-flash",
1554 true,
1555 Some("test-project"),
1556 Some("ya29.mock-token"),
1557 )
1558 .build()
1559 .unwrap();
1560
1561 assert_eq!(
1562 request
1563 .headers()
1564 .get(AUTHORIZATION)
1565 .and_then(|h| h.to_str().ok()),
1566 Some("Bearer ya29.mock-token")
1567 );
1568 }
1569
1570 #[test]
1571 fn oauth_request_wraps_payload_in_request_envelope() {
1572 let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
1573 let auth = test_oauth_auth("ya29.mock-token");
1574 let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1575 let body = GenerateContentRequest {
1576 contents: vec![Content {
1577 role: Some("user".into()),
1578 parts: vec![Part::text("hello")],
1579 }],
1580 system_instruction: None,
1581 generation_config: GenerationConfig {
1582 temperature: 0.7,
1583 max_output_tokens: 8192,
1584 },
1585 };
1586
1587 let request = provider
1588 .build_generate_content_request(
1589 &auth,
1590 &url,
1591 &body,
1592 "models/gemini-2.0-flash",
1593 true,
1594 Some("test-project"),
1595 Some("ya29.mock-token"),
1596 )
1597 .build()
1598 .unwrap();
1599
1600 let payload = request
1601 .body()
1602 .and_then(|b| b.as_bytes())
1603 .expect("json request body should be bytes");
1604 let json: serde_json::Value = serde_json::from_slice(payload).unwrap();
1605
1606 assert_eq!(json["model"], "gemini-2.0-flash");
1607 assert!(json.get("generationConfig").is_none());
1608 assert!(json.get("request").is_some());
1609 assert!(json["request"].get("generationConfig").is_some());
1610 }
1611
1612 #[test]
1613 fn api_key_request_does_not_set_bearer_header() {
1614 let provider = test_provider(Some(GeminiAuth::ExplicitKey("api-key-123".into())));
1615 let auth = GeminiAuth::ExplicitKey("api-key-123".into());
1616 let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1617 let body = GenerateContentRequest {
1618 contents: vec![Content {
1619 role: Some("user".into()),
1620 parts: vec![Part::text("hello")],
1621 }],
1622 system_instruction: None,
1623 generation_config: GenerationConfig {
1624 temperature: 0.7,
1625 max_output_tokens: 8192,
1626 },
1627 };
1628
1629 let request = provider
1630 .build_generate_content_request(
1631 &auth,
1632 &url,
1633 &body,
1634 "gemini-2.0-flash",
1635 true,
1636 None,
1637 None,
1638 )
1639 .build()
1640 .unwrap();
1641
1642 assert!(request.headers().get(AUTHORIZATION).is_none());
1643 }
1644
1645 #[test]
1646 fn request_serialization() {
1647 let request = GenerateContentRequest {
1648 contents: vec![Content {
1649 role: Some("user".to_string()),
1650 parts: vec![Part::text("Hello")],
1651 }],
1652 system_instruction: Some(Content {
1653 role: None,
1654 parts: vec![Part::text("You are helpful")],
1655 }),
1656 generation_config: GenerationConfig {
1657 temperature: 0.7,
1658 max_output_tokens: 8192,
1659 },
1660 };
1661
1662 let json = serde_json::to_string(&request).unwrap();
1663 assert!(json.contains("\"role\":\"user\""));
1664 assert!(json.contains("\"text\":\"Hello\""));
1665 assert!(json.contains("\"systemInstruction\""));
1666 assert!(!json.contains("\"system_instruction\""));
1667 assert!(json.contains("\"temperature\":0.7"));
1668 assert!(json.contains("\"maxOutputTokens\":8192"));
1669 }
1670
1671 #[test]
1672 fn internal_request_includes_model() {
1673 let request = InternalGenerateContentEnvelope {
1674 model: "gemini-3-pro-preview".to_string(),
1675 project: Some("test-project".to_string()),
1676 user_prompt_id: Some("prompt-123".to_string()),
1677 request: InternalGenerateContentRequest {
1678 contents: vec![Content {
1679 role: Some("user".to_string()),
1680 parts: vec![Part::text("Hello")],
1681 }],
1682 system_instruction: None,
1683 generation_config: Some(GenerationConfig {
1684 temperature: 0.7,
1685 max_output_tokens: 8192,
1686 }),
1687 },
1688 };
1689
1690 let json = serde_json::to_string(&request).unwrap();
1691 assert!(json.contains("\"model\":\"gemini-3-pro-preview\""));
1692 assert!(json.contains("\"request\""));
1693 assert!(json.contains("\"generationConfig\""));
1694 assert!(json.contains("\"maxOutputTokens\":8192"));
1695 assert!(json.contains("\"user_prompt_id\":\"prompt-123\""));
1696 assert!(json.contains("\"project\":\"test-project\""));
1697 assert!(json.contains("\"role\":\"user\""));
1698 assert!(json.contains("\"temperature\":0.7"));
1699 }
1700
1701 #[test]
1702 fn internal_request_omits_generation_config_when_none() {
1703 let request = InternalGenerateContentEnvelope {
1704 model: "gemini-3-pro-preview".to_string(),
1705 project: Some("test-project".to_string()),
1706 user_prompt_id: None,
1707 request: InternalGenerateContentRequest {
1708 contents: vec![Content {
1709 role: Some("user".to_string()),
1710 parts: vec![Part::text("Hello")],
1711 }],
1712 system_instruction: None,
1713 generation_config: None,
1714 },
1715 };
1716
1717 let json = serde_json::to_string(&request).unwrap();
1718 assert!(!json.contains("generationConfig"));
1719 assert!(json.contains("\"model\":\"gemini-3-pro-preview\""));
1720 }
1721
1722 #[test]
1723 fn internal_request_includes_project() {
1724 let request = InternalGenerateContentEnvelope {
1725 model: "gemini-2.5-flash".to_string(),
1726 project: Some("my-gcp-project-id".to_string()),
1727 user_prompt_id: None,
1728 request: InternalGenerateContentRequest {
1729 contents: vec![Content {
1730 role: Some("user".to_string()),
1731 parts: vec![Part::text("Hello")],
1732 }],
1733 system_instruction: None,
1734 generation_config: None,
1735 },
1736 };
1737
1738 let json = serde_json::to_string(&request).unwrap();
1739 assert!(json.contains("\"project\":\"my-gcp-project-id\""));
1740 }
1741
1742 #[test]
1743 fn internal_response_deserialize_nested() {
1744 let json = r#"{
1745 "response": {
1746 "candidates": [{
1747 "content": {
1748 "parts": [{"text": "Hello from internal API!"}]
1749 }
1750 }]
1751 }
1752 }"#;
1753
1754 let internal: InternalGenerateContentResponse = serde_json::from_str(json).unwrap();
1755 let text = internal
1756 .response
1757 .candidates
1758 .unwrap()
1759 .into_iter()
1760 .next()
1761 .unwrap()
1762 .content
1763 .unwrap()
1764 .parts
1765 .into_iter()
1766 .next()
1767 .unwrap()
1768 .text;
1769 assert_eq!(text, Some("Hello from internal API!".to_string()));
1770 }
1771
1772 #[test]
1773 fn creds_deserialize_with_expiry_date() {
1774 let json = r#"{
1775 "access_token": "ya29.test-token",
1776 "refresh_token": "1//test-refresh",
1777 "expiry_date": 4102444800000
1778 }"#;
1779
1780 let creds: GeminiCliOAuthCreds = serde_json::from_str(json).unwrap();
1781 assert_eq!(creds.access_token.as_deref(), Some("ya29.test-token"));
1782 assert_eq!(creds.refresh_token.as_deref(), Some("1//test-refresh"));
1783 assert_eq!(creds.expiry_date, Some(4_102_444_800_000));
1784 assert!(creds.expiry.is_none());
1785 }
1786
1787 #[test]
1788 fn creds_deserialize_accepts_camel_case_fields() {
1789 let json = r#"{
1790 "access_token": "ya29.test-token",
1791 "idToken": "header.payload.sig",
1792 "refresh_token": "1//test-refresh",
1793 "clientId": "test-client-id",
1794 "clientSecret": "test-client-secret",
1795 "expiryDate": 4102444800000
1796 }"#;
1797
1798 let creds: GeminiCliOAuthCreds = serde_json::from_str(json).unwrap();
1799 assert_eq!(creds.id_token.as_deref(), Some("header.payload.sig"));
1800 assert_eq!(creds.client_id.as_deref(), Some("test-client-id"));
1801 assert_eq!(creds.client_secret.as_deref(), Some("test-client-secret"));
1802 assert_eq!(creds.expiry_date, Some(4_102_444_800_000));
1803 }
1804
1805 #[test]
1806 fn oauth_retry_detection_for_generation_config_rejection() {
1807 let err =
1809 "Invalid JSON payload received. Unknown name \"generationConfig\": Cannot find field.";
1810 assert!(
1811 GeminiProvider::should_retry_oauth_without_generation_config(
1812 StatusCode::BAD_REQUEST,
1813 err
1814 )
1815 );
1816 let err_json = r#"Invalid JSON payload received. Unknown name \"generationConfig\": Cannot find field."#;
1818 assert!(
1819 GeminiProvider::should_retry_oauth_without_generation_config(
1820 StatusCode::BAD_REQUEST,
1821 err_json
1822 )
1823 );
1824 assert!(
1825 !GeminiProvider::should_retry_oauth_without_generation_config(
1826 StatusCode::UNAUTHORIZED,
1827 err
1828 )
1829 );
1830 assert!(
1831 !GeminiProvider::should_retry_oauth_without_generation_config(
1832 StatusCode::BAD_REQUEST,
1833 "something else"
1834 )
1835 );
1836 }
1837
1838 #[test]
1839 fn response_deserialization() {
1840 let json = r#"{
1841 "candidates": [{
1842 "content": {
1843 "parts": [{"text": "Hello there!"}]
1844 }
1845 }]
1846 }"#;
1847
1848 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
1849 assert!(response.candidates.is_some());
1850 let text = response
1851 .candidates
1852 .unwrap()
1853 .into_iter()
1854 .next()
1855 .unwrap()
1856 .content
1857 .unwrap()
1858 .parts
1859 .into_iter()
1860 .next()
1861 .unwrap()
1862 .text;
1863 assert_eq!(text, Some("Hello there!".to_string()));
1864 }
1865
1866 #[test]
1867 fn error_response_deserialization() {
1868 let json = r#"{
1869 "error": {
1870 "message": "Invalid API key"
1871 }
1872 }"#;
1873
1874 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
1875 assert!(response.error.is_some());
1876 assert_eq!(response.error.unwrap().message, "Invalid API key");
1877 }
1878
1879 #[test]
1880 fn internal_response_deserialization() {
1881 let json = r#"{
1882 "response": {
1883 "candidates": [{
1884 "content": {
1885 "parts": [{"text": "Hello from internal"}]
1886 }
1887 }]
1888 }
1889 }"#;
1890
1891 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
1892 let text = response
1893 .into_effective_response()
1894 .candidates
1895 .unwrap()
1896 .into_iter()
1897 .next()
1898 .unwrap()
1899 .content
1900 .unwrap()
1901 .parts
1902 .into_iter()
1903 .next()
1904 .unwrap()
1905 .text;
1906 assert_eq!(text, Some("Hello from internal".to_string()));
1907 }
1908
1909 #[test]
1912 fn thinking_response_extracts_non_thinking_text() {
1913 let json = r#"{
1914 "candidates": [{
1915 "content": {
1916 "parts": [
1917 {"thought": true, "text": "Let me think about this..."},
1918 {"text": "The answer is 42."},
1919 {"thoughtSignature": "c2lnbmF0dXJl"}
1920 ]
1921 }
1922 }]
1923 }"#;
1924
1925 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
1926 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
1927 let text = candidate.content.unwrap().effective_text();
1928 assert_eq!(text, Some("The answer is 42.".to_string()));
1929 }
1930
1931 #[test]
1932 fn non_thinking_response_unaffected() {
1933 let json = r#"{
1934 "candidates": [{
1935 "content": {
1936 "parts": [{"text": "Hello there!"}]
1937 }
1938 }]
1939 }"#;
1940
1941 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
1942 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
1943 let text = candidate.content.unwrap().effective_text();
1944 assert_eq!(text, Some("Hello there!".to_string()));
1945 }
1946
1947 #[test]
1948 fn thinking_only_response_falls_back_to_thinking_text() {
1949 let json = r#"{
1950 "candidates": [{
1951 "content": {
1952 "parts": [
1953 {"thought": true, "text": "I need more context..."},
1954 {"thoughtSignature": "c2lnbmF0dXJl"}
1955 ]
1956 }
1957 }]
1958 }"#;
1959
1960 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
1961 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
1962 let text = candidate.content.unwrap().effective_text();
1963 assert_eq!(text, Some("I need more context...".to_string()));
1964 }
1965
1966 #[test]
1967 fn empty_parts_returns_none() {
1968 let json = r#"{
1969 "candidates": [{
1970 "content": {
1971 "parts": []
1972 }
1973 }]
1974 }"#;
1975
1976 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
1977 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
1978 let text = candidate.content.unwrap().effective_text();
1979 assert_eq!(text, None);
1980 }
1981
1982 #[test]
1983 fn multiple_text_parts_concatenated() {
1984 let json = r#"{
1985 "candidates": [{
1986 "content": {
1987 "parts": [
1988 {"text": "Part one. "},
1989 {"text": "Part two."}
1990 ]
1991 }
1992 }]
1993 }"#;
1994
1995 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
1996 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
1997 let text = candidate.content.unwrap().effective_text();
1998 assert_eq!(text, Some("Part one. Part two.".to_string()));
1999 }
2000
2001 #[test]
2002 fn thought_signature_only_parts_skipped() {
2003 let json = r#"{
2004 "candidates": [{
2005 "content": {
2006 "parts": [
2007 {"thoughtSignature": "c2lnbmF0dXJl"}
2008 ]
2009 }
2010 }]
2011 }"#;
2012
2013 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2014 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2015 let text = candidate.content.unwrap().effective_text();
2016 assert_eq!(text, None);
2017 }
2018
2019 #[test]
2020 fn internal_response_thinking_model() {
2021 let json = r#"{
2022 "response": {
2023 "candidates": [{
2024 "content": {
2025 "parts": [
2026 {"thought": true, "text": "reasoning..."},
2027 {"text": "final answer"}
2028 ]
2029 }
2030 }]
2031 }
2032 }"#;
2033
2034 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2035 let effective = response.into_effective_response();
2036 let candidate = effective.candidates.unwrap().into_iter().next().unwrap();
2037 let text = candidate.content.unwrap().effective_text();
2038 assert_eq!(text, Some("final answer".to_string()));
2039 }
2040
2041 #[tokio::test]
2042 async fn warmup_without_key_is_noop() {
2043 let provider = test_provider(None);
2044 let result = provider.warmup().await;
2045 assert!(result.is_ok());
2046 }
2047
2048 #[tokio::test]
2049 async fn warmup_oauth_is_noop() {
2050 let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
2051 let result = provider.warmup().await;
2052 assert!(result.is_ok());
2053 }
2054
2055 #[test]
2056 fn discover_oauth_cred_paths_does_not_panic() {
2057 let _paths = GeminiProvider::discover_oauth_cred_paths();
2058 }
2059
2060 #[tokio::test]
2061 async fn rotate_oauth_without_alternatives_returns_false() {
2062 let state = Arc::new(tokio::sync::Mutex::new(OAuthTokenState {
2063 access_token: "ya29.mock".to_string(),
2064 refresh_token: None,
2065 client_id: None,
2066 client_secret: None,
2067 expiry_millis: None,
2068 }));
2069 let provider = test_provider(Some(GeminiAuth::OAuthToken(state.clone())));
2070 assert!(!provider.rotate_oauth_credential(&state).await);
2071 }
2072
2073 #[test]
2074 fn response_parses_usage_metadata() {
2075 let json = r#"{
2076 "candidates": [{"content": {"parts": [{"text": "Hello"}]}}],
2077 "usageMetadata": {"promptTokenCount": 120, "candidatesTokenCount": 40}
2078 }"#;
2079 let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
2080 let usage = resp.usage_metadata.unwrap();
2081 assert_eq!(usage.prompt_token_count, Some(120));
2082 assert_eq!(usage.candidates_token_count, Some(40));
2083 }
2084
2085 #[test]
2086 fn response_parses_without_usage_metadata() {
2087 let json = r#"{"candidates": [{"content": {"parts": [{"text": "Hello"}]}}]}"#;
2088 let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
2089 assert!(resp.usage_metadata.is_none());
2090 }
2091
2092 #[tokio::test]
2094 async fn warmup_managed_oauth_requires_auth_service() {
2095 let provider = GeminiProvider {
2096 auth: Some(GeminiAuth::ManagedOAuth),
2097 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
2098 oauth_cred_paths: Vec::new(),
2099 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
2100 auth_service: None, auth_profile_override: None,
2102 };
2103
2104 let result = provider.warmup().await;
2105 assert!(result.is_err());
2106 assert!(
2107 result
2108 .unwrap_err()
2109 .to_string()
2110 .contains("ManagedOAuth requires auth_service")
2111 );
2112 }
2113
2114 #[tokio::test]
2116 async fn warmup_cli_oauth_skips_validation() {
2117 let provider = test_provider(Some(test_oauth_auth("fake_token")));
2118 let result = provider.warmup().await;
2119 assert!(result.is_ok());
2121 }
2122
2123 #[test]
2126 fn part_text_serializes_as_text_object() {
2127 let part = Part::text("hello");
2128 let json = serde_json::to_value(&part).unwrap();
2129 assert_eq!(json, serde_json::json!({"text": "hello"}));
2130 }
2131
2132 #[test]
2133 fn part_inline_serializes_as_inline_data_object() {
2134 let part = Part::Inline {
2135 inline_data: InlineData {
2136 mime_type: "image/png".to_string(),
2137 data: "iVBOR...".to_string(),
2138 },
2139 };
2140 let json = serde_json::to_value(&part).unwrap();
2141 assert_eq!(
2142 json,
2143 serde_json::json!({"inline_data": {"mime_type": "image/png", "data": "iVBOR..."}})
2144 );
2145 }
2146
2147 #[test]
2148 fn part_text_constructor_accepts_string_and_str() {
2149 let from_str = Part::text("hello");
2150 let from_string = Part::text(String::from("hello"));
2151 assert_eq!(
2153 serde_json::to_value(&from_str).unwrap(),
2154 serde_json::to_value(&from_string).unwrap(),
2155 );
2156 }
2157
2158 #[test]
2159 fn content_with_mixed_parts_serializes_correctly() {
2160 let content = Content {
2161 role: Some("user".to_string()),
2162 parts: vec![
2163 Part::text("Describe this image:"),
2164 Part::Inline {
2165 inline_data: InlineData {
2166 mime_type: "image/jpeg".to_string(),
2167 data: "/9j/4AAQ...".to_string(),
2168 },
2169 },
2170 ],
2171 };
2172 let json = serde_json::to_value(&content).unwrap();
2173 let parts = json["parts"].as_array().unwrap();
2174 assert_eq!(parts.len(), 2);
2175 assert!(parts[0].get("text").is_some());
2176 assert!(parts[1].get("inline_data").is_some());
2177 }
2178
2179 #[test]
2182 fn build_parts_plain_text_returns_single_text_part() {
2183 let parts = build_parts("Hello, world!");
2184 assert_eq!(parts.len(), 1);
2185 assert_eq!(
2186 serde_json::to_value(&parts[0]).unwrap(),
2187 serde_json::json!({"text": "Hello, world!"})
2188 );
2189 }
2190
2191 #[test]
2192 fn build_parts_empty_string_returns_single_text_part() {
2193 let parts = build_parts("");
2194 assert_eq!(parts.len(), 1);
2195 assert_eq!(
2197 serde_json::to_value(&parts[0]).unwrap(),
2198 serde_json::json!({"text": ""})
2199 );
2200 }
2201
2202 #[test]
2203 fn build_parts_extracts_data_uri_as_inline_part() {
2204 let content = "Check this [IMAGE:data:image/png;base64,iVBORw0KGgo=]";
2205 let parts = build_parts(content);
2206 assert_eq!(parts.len(), 2);
2207 assert_eq!(
2209 serde_json::to_value(&parts[0]).unwrap(),
2210 serde_json::json!({"text": "Check this"})
2211 );
2212 assert_eq!(
2214 serde_json::to_value(&parts[1]).unwrap(),
2215 serde_json::json!({"inline_data": {"mime_type": "image/png", "data": "iVBORw0KGgo="}})
2216 );
2217 }
2218
2219 #[test]
2220 fn build_parts_multiple_images() {
2221 let content = "Image A: [IMAGE:data:image/png;base64,AAAA] Image B: [IMAGE:data:image/jpeg;base64,BBBB]";
2222 let parts = build_parts(content);
2223 assert_eq!(parts.len(), 3); let inline_parts: Vec<_> = parts
2226 .iter()
2227 .filter(|p| matches!(p, Part::Inline { .. }))
2228 .collect();
2229 assert_eq!(inline_parts.len(), 2);
2230 }
2231
2232 #[test]
2233 fn build_parts_ignores_non_data_uri_markers() {
2234 let content = "Look [IMAGE:/tmp/photo.png]";
2237 let parts = build_parts(content);
2238 for part in &parts {
2241 assert!(matches!(part, Part::Text { .. }));
2242 }
2243 }
2244
2245 #[test]
2246 fn build_parts_image_only_still_produces_inline_part() {
2247 let content = "[IMAGE:data:image/gif;base64,R0lGODlh]";
2248 let parts = build_parts(content);
2249 assert_eq!(parts.len(), 1);
2251 assert!(matches!(&parts[0], Part::Inline { .. }));
2252 }
2253
2254 #[test]
2257 fn chat_with_history_maps_roles_correctly() {
2258 let user_parts = build_parts("Hello [IMAGE:data:image/png;base64,AA==]");
2265 assert!(user_parts.iter().any(|p| matches!(p, Part::Inline { .. })));
2266
2267 let assistant_part = Part::text("I see the image");
2269 assert!(matches!(assistant_part, Part::Text { .. }));
2270
2271 let system_part = Part::text("You are helpful");
2273 assert!(matches!(system_part, Part::Text { .. }));
2274 }
2275}