greppy/ai/
gemini.rs

1use crate::core::error::{Error, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::time::{Duration, Instant};
5
/// Base URL of Google's Cloud Code Assist API; requests go to its
/// `v1internal:*` RPC endpoints.
const CODE_ASSIST_BASE: &str = "https://cloudcode-pa.googleapis.com";

/// OAuth client ID of the Gemini CLI, required when exchanging its refresh
/// token for a new access token.
const GEMINI_CLIENT_ID: &str =
    "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com";

/// OAuth client secret paired with `GEMINI_CLIENT_ID`, sent alongside it on
/// token refresh. (Google's installed-application OAuth flow does not treat
/// such a client secret as confidential.)
const GEMINI_CLIENT_SECRET: &str = "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";
13
/// An OAuth access token remembered between calls, together with the
/// deadline after which it is refreshed instead of reused.
struct CachedToken {
    /// Instant after which the cached token is considered stale.
    expires_at: Instant,
    /// The bearer token sent in the `Authorization` header.
    token: String,
}
19
20// Inner request structure for Cloud Code Assist API
21#[derive(Debug, Serialize)]
22struct InnerRequest {
23    contents: Vec<Content>,
24    #[serde(skip_serializing_if = "Option::is_none", rename = "systemInstruction")]
25    system_instruction: Option<Content>,
26}
27
28// Wrapped request for Cloud Code Assist API
29#[derive(Debug, Serialize)]
30struct CodeAssistRequest {
31    project: String,
32    model: String,
33    request: InnerRequest,
34}
35
36#[derive(Debug, Serialize)]
37struct Content {
38    role: String,
39    parts: Vec<Part>,
40}
41
42#[derive(Debug, Serialize)]
43struct Part {
44    text: String,
45}
46
47// Cloud Code Assist wraps response in a "response" field
48#[derive(Debug, Deserialize)]
49struct CodeAssistResponse {
50    response: Option<GenerateContentResponse>,
51}
52
53#[derive(Debug, Deserialize)]
54struct GenerateContentResponse {
55    candidates: Option<Vec<Candidate>>,
56}
57
58#[derive(Debug, Deserialize)]
59struct Candidate {
60    content: ResponseContent,
61}
62
63#[derive(Debug, Deserialize)]
64struct ResponseContent {
65    parts: Vec<PartResponse>,
66}
67
68#[derive(Debug, Deserialize)]
69struct PartResponse {
70    text: String,
71}
72
73#[derive(Debug, Deserialize)]
74struct TokenResponse {
75    access_token: String,
76    #[allow(dead_code)]
77    expires_in: u64,
78}
79
80// loadCodeAssist response
81#[derive(Debug, Deserialize)]
82struct LoadCodeAssistResponse {
83    #[serde(rename = "cloudaicompanionProject")]
84    cloudaicompanion_project: Option<String>,
85    #[serde(rename = "allowedTiers")]
86    allowed_tiers: Option<Vec<AllowedTier>>,
87}
88
89#[derive(Debug, Deserialize)]
90struct AllowedTier {
91    id: Option<String>,
92    #[serde(rename = "isDefault")]
93    is_default: Option<bool>,
94}
95
96// onboardUser response
97#[derive(Debug, Deserialize)]
98struct OnboardUserResponse {
99    done: Option<bool>,
100    response: Option<OnboardResponseInner>,
101}
102
103#[derive(Debug, Deserialize)]
104struct OnboardResponseInner {
105    #[serde(rename = "cloudaicompanionProject")]
106    cloudaicompanion_project: Option<CloudaicompanionProject>,
107}
108
109#[derive(Debug, Deserialize)]
110struct CloudaicompanionProject {
111    id: Option<String>,
112}
113
114// Metadata for Code Assist requests
115#[derive(Debug, Clone, Serialize)]
116struct CodeAssistMetadata {
117    #[serde(rename = "ideType")]
118    ide_type: String,
119    platform: String,
120    #[serde(rename = "pluginType")]
121    plugin_type: String,
122}
123
124#[derive(Debug, Serialize)]
125struct LoadCodeAssistRequest {
126    metadata: CodeAssistMetadata,
127}
128
129#[derive(Debug, Serialize)]
130struct OnboardUserRequest {
131    #[serde(rename = "tierId")]
132    tier_id: String,
133    metadata: CodeAssistMetadata,
134}
135
136pub struct GeminiClient {
137    client: Client,
138    refresh_token: String,
139    cached_token: std::sync::Mutex<Option<CachedToken>>,
140    cached_project_id: std::sync::Mutex<Option<String>>,
141}
142
143impl GeminiClient {
144    pub fn new(refresh_token: String) -> Self {
145        Self {
146            client: Client::new(),
147            refresh_token,
148            cached_token: std::sync::Mutex::new(None),
149            cached_project_id: std::sync::Mutex::new(None),
150        }
151    }
152
153    /// Get access token, using cache if valid
154    async fn get_access_token(&self) -> Result<String> {
155        // Check cache first
156        if let Ok(guard) = self.cached_token.lock() {
157            if let Some(ref cached) = *guard {
158                if Instant::now() < cached.expires_at {
159                    return Ok(cached.token.clone());
160                }
161            }
162        }
163
164        // Refresh token
165        let params = [
166            ("client_id", GEMINI_CLIENT_ID),
167            ("client_secret", GEMINI_CLIENT_SECRET),
168            ("refresh_token", &self.refresh_token),
169            ("grant_type", "refresh_token"),
170        ];
171
172        let res = self
173            .client
174            .post("https://oauth2.googleapis.com/token")
175            .form(&params)
176            .send()
177            .await
178            .map_err(|e| Error::DaemonError {
179                message: format!("Token refresh failed: {}", e),
180            })?;
181
182        if !res.status().is_success() {
183            let text = res.text().await.unwrap_or_default();
184            return Err(Error::DaemonError {
185                message: format!("Token refresh error: {}", text),
186            });
187        }
188
189        let token_response: TokenResponse = res.json().await.map_err(|e| Error::DaemonError {
190            message: format!("Failed to parse token response: {}", e),
191        })?;
192
193        // Cache the token (expires_in is typically 3600 seconds, cache for 50 min)
194        let expires_at = Instant::now()
195            + Duration::from_secs(token_response.expires_in.saturating_sub(600).max(60));
196        if let Ok(mut guard) = self.cached_token.lock() {
197            *guard = Some(CachedToken {
198                token: token_response.access_token.clone(),
199                expires_at,
200            });
201        }
202
203        Ok(token_response.access_token)
204    }
205
206    /// Get or create a managed project ID for Cloud Code Assist
207    async fn get_project_id(&self, access_token: &str) -> Result<String> {
208        // Check cache first
209        if let Ok(guard) = self.cached_project_id.lock() {
210            if let Some(ref project_id) = *guard {
211                return Ok(project_id.clone());
212            }
213        }
214
215        let metadata = CodeAssistMetadata {
216            ide_type: "IDE_UNSPECIFIED".to_string(),
217            platform: "PLATFORM_UNSPECIFIED".to_string(),
218            plugin_type: "GEMINI".to_string(),
219        };
220
221        // Try to load existing project
222        let load_url = format!("{}/v1internal:loadCodeAssist", CODE_ASSIST_BASE);
223        let load_request = LoadCodeAssistRequest {
224            metadata: metadata.clone(),
225        };
226
227        let res = self
228            .client
229            .post(&load_url)
230            .header("Authorization", format!("Bearer {}", access_token))
231            .header("Content-Type", "application/json")
232            .header("User-Agent", "greppy/0.9.0")
233            .json(&load_request)
234            .send()
235            .await
236            .map_err(|e| Error::DaemonError {
237                message: format!("loadCodeAssist failed: {}", e),
238            })?;
239
240        if res.status().is_success() {
241            if let Ok(load_response) = res.json::<LoadCodeAssistResponse>().await {
242                if let Some(project_id) = load_response.cloudaicompanion_project {
243                    // Cache and return
244                    if let Ok(mut guard) = self.cached_project_id.lock() {
245                        *guard = Some(project_id.clone());
246                    }
247                    return Ok(project_id);
248                }
249
250                // Need to onboard - get default tier
251                let tier_id = load_response
252                    .allowed_tiers
253                    .as_ref()
254                    .and_then(|tiers| {
255                        tiers
256                            .iter()
257                            .find(|t| t.is_default == Some(true))
258                            .or(tiers.first())
259                    })
260                    .and_then(|t| t.id.clone())
261                    .unwrap_or_else(|| "FREE".to_string());
262
263                // Onboard user
264                let onboard_url = format!("{}/v1internal:onboardUser", CODE_ASSIST_BASE);
265                let onboard_request = OnboardUserRequest { tier_id, metadata };
266
267                let onboard_res = self
268                    .client
269                    .post(&onboard_url)
270                    .header("Authorization", format!("Bearer {}", access_token))
271                    .header("Content-Type", "application/json")
272                    .header("User-Agent", "greppy/0.9.0")
273                    .json(&onboard_request)
274                    .send()
275                    .await
276                    .map_err(|e| Error::DaemonError {
277                        message: format!("onboardUser failed: {}", e),
278                    })?;
279
280                if onboard_res.status().is_success() {
281                    if let Ok(onboard_response) = onboard_res.json::<OnboardUserResponse>().await {
282                        if onboard_response.done == Some(true) {
283                            if let Some(project_id) = onboard_response
284                                .response
285                                .and_then(|r| r.cloudaicompanion_project)
286                                .and_then(|p| p.id)
287                            {
288                                // Cache and return
289                                if let Ok(mut guard) = self.cached_project_id.lock() {
290                                    *guard = Some(project_id.clone());
291                                }
292                                return Ok(project_id);
293                            }
294                        }
295                    }
296                }
297            }
298        }
299
300        Err(Error::DaemonError {
301            message: "Failed to get Gemini project ID. You may need to enable Gemini API in Google Cloud Console.".to_string(),
302        })
303    }
304
305    /// Rerank search results by relevance to query
306    /// Returns JSON array of indices in order of relevance: [2, 0, 5, 1, ...]
307    pub async fn rerank(&self, query: &str, chunks: &[String]) -> Result<Vec<usize>> {
308        let access_token = self.get_access_token().await?;
309        let project_id = self.get_project_id(&access_token).await?;
310
311        let system_prompt =
312            "You are a code search reranker. Given a query and numbered code chunks, \
313            return ONLY a JSON array of chunk indices ordered by relevance to the query. \
314            Most relevant first. Example response: [2, 0, 5, 1, 3, 4]";
315
316        let mut user_prompt = format!("Query: {}\n\nCode chunks:\n", query);
317        for (i, chunk) in chunks.iter().enumerate() {
318            user_prompt.push_str(&format!("\n--- Chunk {} ---\n{}\n", i, chunk));
319        }
320        user_prompt.push_str("\nReturn ONLY the JSON array of indices, nothing else.");
321
322        // Build the inner request
323        let inner_request = InnerRequest {
324            contents: vec![Content {
325                role: "user".to_string(),
326                parts: vec![Part { text: user_prompt }],
327            }],
328            system_instruction: Some(Content {
329                role: "user".to_string(),
330                parts: vec![Part {
331                    text: system_prompt.to_string(),
332                }],
333            }),
334        };
335
336        // Wrap for Cloud Code Assist API
337        let request_body = CodeAssistRequest {
338            project: project_id,
339            model: "gemini-2.0-flash".to_string(),
340            request: inner_request,
341        };
342
343        let url = format!("{}/v1internal:generateContent", CODE_ASSIST_BASE);
344
345        let res = self
346            .client
347            .post(&url)
348            .header("Authorization", format!("Bearer {}", access_token))
349            .header("Content-Type", "application/json")
350            .header("User-Agent", "greppy/0.9.0")
351            .header("X-Goog-Api-Client", "greppy/0.9.0")
352            .json(&request_body)
353            .send()
354            .await
355            .map_err(|e| Error::DaemonError {
356                message: format!("API request failed: {}", e),
357            })?;
358
359        if !res.status().is_success() {
360            let text = res.text().await.unwrap_or_default();
361            return Err(Error::DaemonError {
362                message: format!("Gemini API Error: {}", text),
363            });
364        }
365
366        // Cloud Code Assist wraps response in "response" field
367        let wrapper: CodeAssistResponse = res.json().await.map_err(|e| Error::DaemonError {
368            message: format!("Failed to parse response: {}", e),
369        })?;
370
371        // Parse the JSON array from response
372        if let Some(response) = wrapper.response {
373            if let Some(candidates) = response.candidates {
374                if let Some(candidate) = candidates.first() {
375                    if let Some(part) = candidate.content.parts.first() {
376                        let text = part.text.trim();
377                        // Try direct parse
378                        if let Ok(indices) = serde_json::from_str::<Vec<usize>>(text) {
379                            return Ok(indices);
380                        }
381                        // Try to find JSON array in the text
382                        if let Some(start) = text.find('[') {
383                            if let Some(end) = text.rfind(']') {
384                                let json_str = &text[start..=end];
385                                if let Ok(indices) = serde_json::from_str::<Vec<usize>>(json_str) {
386                                    return Ok(indices);
387                                }
388                            }
389                        }
390                    }
391                }
392            }
393        }
394
395        // Fallback: return original order
396        Ok((0..chunks.len()).collect())
397    }
398
399    /// Expand a query into related symbol names for trace operations
400    /// Input: "auth" -> Output: ["auth", "login", "authenticate", "session", ...]
401    pub async fn expand_query(&self, query: &str) -> Result<Vec<String>> {
402        use crate::ai::trace_prompts::{
403            build_expansion_prompt, parse_expansion_response, QUERY_EXPANSION_SYSTEM,
404        };
405
406        let access_token = self.get_access_token().await?;
407        let project_id = self.get_project_id(&access_token).await?;
408
409        // Build the inner request
410        let inner_request = InnerRequest {
411            contents: vec![Content {
412                role: "user".to_string(),
413                parts: vec![Part {
414                    text: build_expansion_prompt(query),
415                }],
416            }],
417            system_instruction: Some(Content {
418                role: "user".to_string(),
419                parts: vec![Part {
420                    text: QUERY_EXPANSION_SYSTEM.to_string(),
421                }],
422            }),
423        };
424
425        // Wrap for Cloud Code Assist API
426        let request_body = CodeAssistRequest {
427            project: project_id,
428            model: "gemini-2.0-flash".to_string(),
429            request: inner_request,
430        };
431
432        let url = format!("{}/v1internal:generateContent", CODE_ASSIST_BASE);
433
434        let res = self
435            .client
436            .post(&url)
437            .header("Authorization", format!("Bearer {}", access_token))
438            .header("Content-Type", "application/json")
439            .header("User-Agent", "greppy/0.9.0")
440            .header("X-Goog-Api-Client", "greppy/0.9.0")
441            .json(&request_body)
442            .send()
443            .await
444            .map_err(|e| Error::DaemonError {
445                message: format!("API request failed: {}", e),
446            })?;
447
448        if !res.status().is_success() {
449            let text = res.text().await.unwrap_or_default();
450            return Err(Error::DaemonError {
451                message: format!("Gemini API Error: {}", text),
452            });
453        }
454
455        // Cloud Code Assist wraps response in "response" field
456        let wrapper: CodeAssistResponse = res.json().await.map_err(|e| Error::DaemonError {
457            message: format!("Failed to parse response: {}", e),
458        })?;
459
460        // Parse the expanded symbols from response
461        if let Some(response) = wrapper.response {
462            if let Some(candidates) = response.candidates {
463                if let Some(candidate) = candidates.first() {
464                    if let Some(part) = candidate.content.parts.first() {
465                        let symbols = parse_expansion_response(&part.text);
466                        if !symbols.is_empty() {
467                            return Ok(symbols);
468                        }
469                    }
470                }
471            }
472        }
473
474        // Fallback: return the original query as a single symbol
475        Ok(vec![query.to_string()])
476    }
477
478    /// Rerank trace invocation paths by relevance to query
479    /// Returns indices in order of relevance: [2, 0, 5, 1, ...]
480    pub async fn rerank_trace(&self, query: &str, paths: &[String]) -> Result<Vec<usize>> {
481        use crate::ai::trace_prompts::{
482            build_trace_rerank_prompt, parse_rerank_response, TRACE_RERANK_SYSTEM,
483        };
484
485        let access_token = self.get_access_token().await?;
486        let project_id = self.get_project_id(&access_token).await?;
487
488        // Build the inner request
489        let inner_request = InnerRequest {
490            contents: vec![Content {
491                role: "user".to_string(),
492                parts: vec![Part {
493                    text: build_trace_rerank_prompt(query, paths),
494                }],
495            }],
496            system_instruction: Some(Content {
497                role: "user".to_string(),
498                parts: vec![Part {
499                    text: TRACE_RERANK_SYSTEM.to_string(),
500                }],
501            }),
502        };
503
504        // Wrap for Cloud Code Assist API
505        let request_body = CodeAssistRequest {
506            project: project_id,
507            model: "gemini-2.0-flash".to_string(),
508            request: inner_request,
509        };
510
511        let url = format!("{}/v1internal:generateContent", CODE_ASSIST_BASE);
512
513        let res = self
514            .client
515            .post(&url)
516            .header("Authorization", format!("Bearer {}", access_token))
517            .header("Content-Type", "application/json")
518            .header("User-Agent", "greppy/0.9.0")
519            .header("X-Goog-Api-Client", "greppy/0.9.0")
520            .json(&request_body)
521            .send()
522            .await
523            .map_err(|e| Error::DaemonError {
524                message: format!("API request failed: {}", e),
525            })?;
526
527        if !res.status().is_success() {
528            let text = res.text().await.unwrap_or_default();
529            return Err(Error::DaemonError {
530                message: format!("Gemini API Error: {}", text),
531            });
532        }
533
534        // Cloud Code Assist wraps response in "response" field
535        let wrapper: CodeAssistResponse = res.json().await.map_err(|e| Error::DaemonError {
536            message: format!("Failed to parse response: {}", e),
537        })?;
538
539        // Parse the reranked indices from response
540        if let Some(response) = wrapper.response {
541            if let Some(candidates) = response.candidates {
542                if let Some(candidate) = candidates.first() {
543                    if let Some(part) = candidate.content.parts.first() {
544                        let indices = parse_rerank_response(&part.text, paths.len());
545                        if !indices.is_empty() {
546                            return Ok(indices);
547                        }
548                    }
549                }
550            }
551        }
552
553        // Fallback: return original order
554        Ok((0..paths.len()).collect())
555    }
556}