1use crate::core::error::{Error, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::time::{Duration, Instant};
5
/// OAuth client ID used when exchanging the refresh token for an access token.
///
/// NOTE(review): this ID/secret pair is embedded in the binary, so it is presumably a
/// public "installed application" OAuth client rather than a confidential one — confirm.
const GEMINI_CLIENT_ID: &str = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com";

/// OAuth client secret paired with [`GEMINI_CLIENT_ID`] in the refresh-token grant.
const GEMINI_CLIENT_SECRET: &str = "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";

/// Base URL of the Code Assist API; endpoints are appended as `/v1internal:<method>`.
const CODE_ASSIST_BASE: &str = "https://cloudcode-pa.googleapis.com";
13
/// An access token kept in memory together with the instant after which it
/// must no longer be handed out.
struct CachedToken {
    /// The bearer token string.
    token: String,
    /// Deadline after which the token is treated as stale and refreshed.
    expires_at: Instant,
}
19
20#[derive(Debug, Serialize)]
22struct InnerRequest {
23 contents: Vec<Content>,
24 #[serde(skip_serializing_if = "Option::is_none", rename = "systemInstruction")]
25 system_instruction: Option<Content>,
26}
27
28#[derive(Debug, Serialize)]
30struct CodeAssistRequest {
31 project: String,
32 model: String,
33 request: InnerRequest,
34}
35
36#[derive(Debug, Serialize)]
37struct Content {
38 role: String,
39 parts: Vec<Part>,
40}
41
42#[derive(Debug, Serialize)]
43struct Part {
44 text: String,
45}
46
47#[derive(Debug, Deserialize)]
49struct CodeAssistResponse {
50 response: Option<GenerateContentResponse>,
51}
52
53#[derive(Debug, Deserialize)]
54struct GenerateContentResponse {
55 candidates: Option<Vec<Candidate>>,
56}
57
58#[derive(Debug, Deserialize)]
59struct Candidate {
60 content: ResponseContent,
61}
62
63#[derive(Debug, Deserialize)]
64struct ResponseContent {
65 parts: Vec<PartResponse>,
66}
67
68#[derive(Debug, Deserialize)]
69struct PartResponse {
70 text: String,
71}
72
73#[derive(Debug, Deserialize)]
74struct TokenResponse {
75 access_token: String,
76 #[allow(dead_code)]
77 expires_in: u64,
78}
79
80#[derive(Debug, Deserialize)]
82struct LoadCodeAssistResponse {
83 #[serde(rename = "cloudaicompanionProject")]
84 cloudaicompanion_project: Option<String>,
85 #[serde(rename = "allowedTiers")]
86 allowed_tiers: Option<Vec<AllowedTier>>,
87}
88
89#[derive(Debug, Deserialize)]
90struct AllowedTier {
91 id: Option<String>,
92 #[serde(rename = "isDefault")]
93 is_default: Option<bool>,
94}
95
96#[derive(Debug, Deserialize)]
98struct OnboardUserResponse {
99 done: Option<bool>,
100 response: Option<OnboardResponseInner>,
101}
102
103#[derive(Debug, Deserialize)]
104struct OnboardResponseInner {
105 #[serde(rename = "cloudaicompanionProject")]
106 cloudaicompanion_project: Option<CloudaicompanionProject>,
107}
108
109#[derive(Debug, Deserialize)]
110struct CloudaicompanionProject {
111 id: Option<String>,
112}
113
114#[derive(Debug, Clone, Serialize)]
116struct CodeAssistMetadata {
117 #[serde(rename = "ideType")]
118 ide_type: String,
119 platform: String,
120 #[serde(rename = "pluginType")]
121 plugin_type: String,
122}
123
124#[derive(Debug, Serialize)]
125struct LoadCodeAssistRequest {
126 metadata: CodeAssistMetadata,
127}
128
129#[derive(Debug, Serialize)]
130struct OnboardUserRequest {
131 #[serde(rename = "tierId")]
132 tier_id: String,
133 metadata: CodeAssistMetadata,
134}
135
136pub struct GeminiClient {
137 client: Client,
138 refresh_token: String,
139 cached_token: std::sync::Mutex<Option<CachedToken>>,
140 cached_project_id: std::sync::Mutex<Option<String>>,
141}
142
143impl GeminiClient {
144 pub fn new(refresh_token: String) -> Self {
145 Self {
146 client: Client::new(),
147 refresh_token,
148 cached_token: std::sync::Mutex::new(None),
149 cached_project_id: std::sync::Mutex::new(None),
150 }
151 }
152
153 async fn get_access_token(&self) -> Result<String> {
155 if let Ok(guard) = self.cached_token.lock() {
157 if let Some(ref cached) = *guard {
158 if Instant::now() < cached.expires_at {
159 return Ok(cached.token.clone());
160 }
161 }
162 }
163
164 let params = [
166 ("client_id", GEMINI_CLIENT_ID),
167 ("client_secret", GEMINI_CLIENT_SECRET),
168 ("refresh_token", &self.refresh_token),
169 ("grant_type", "refresh_token"),
170 ];
171
172 let res = self
173 .client
174 .post("https://oauth2.googleapis.com/token")
175 .form(¶ms)
176 .send()
177 .await
178 .map_err(|e| Error::DaemonError {
179 message: format!("Token refresh failed: {}", e),
180 })?;
181
182 if !res.status().is_success() {
183 let text = res.text().await.unwrap_or_default();
184 return Err(Error::DaemonError {
185 message: format!("Token refresh error: {}", text),
186 });
187 }
188
189 let token_response: TokenResponse = res.json().await.map_err(|e| Error::DaemonError {
190 message: format!("Failed to parse token response: {}", e),
191 })?;
192
193 let expires_at = Instant::now()
195 + Duration::from_secs(token_response.expires_in.saturating_sub(600).max(60));
196 if let Ok(mut guard) = self.cached_token.lock() {
197 *guard = Some(CachedToken {
198 token: token_response.access_token.clone(),
199 expires_at,
200 });
201 }
202
203 Ok(token_response.access_token)
204 }
205
206 async fn get_project_id(&self, access_token: &str) -> Result<String> {
208 if let Ok(guard) = self.cached_project_id.lock() {
210 if let Some(ref project_id) = *guard {
211 return Ok(project_id.clone());
212 }
213 }
214
215 let metadata = CodeAssistMetadata {
216 ide_type: "IDE_UNSPECIFIED".to_string(),
217 platform: "PLATFORM_UNSPECIFIED".to_string(),
218 plugin_type: "GEMINI".to_string(),
219 };
220
221 let load_url = format!("{}/v1internal:loadCodeAssist", CODE_ASSIST_BASE);
223 let load_request = LoadCodeAssistRequest {
224 metadata: metadata.clone(),
225 };
226
227 let res = self
228 .client
229 .post(&load_url)
230 .header("Authorization", format!("Bearer {}", access_token))
231 .header("Content-Type", "application/json")
232 .header("User-Agent", "greppy/0.9.0")
233 .json(&load_request)
234 .send()
235 .await
236 .map_err(|e| Error::DaemonError {
237 message: format!("loadCodeAssist failed: {}", e),
238 })?;
239
240 if res.status().is_success() {
241 if let Ok(load_response) = res.json::<LoadCodeAssistResponse>().await {
242 if let Some(project_id) = load_response.cloudaicompanion_project {
243 if let Ok(mut guard) = self.cached_project_id.lock() {
245 *guard = Some(project_id.clone());
246 }
247 return Ok(project_id);
248 }
249
250 let tier_id = load_response
252 .allowed_tiers
253 .as_ref()
254 .and_then(|tiers| {
255 tiers
256 .iter()
257 .find(|t| t.is_default == Some(true))
258 .or(tiers.first())
259 })
260 .and_then(|t| t.id.clone())
261 .unwrap_or_else(|| "FREE".to_string());
262
263 let onboard_url = format!("{}/v1internal:onboardUser", CODE_ASSIST_BASE);
265 let onboard_request = OnboardUserRequest { tier_id, metadata };
266
267 let onboard_res = self
268 .client
269 .post(&onboard_url)
270 .header("Authorization", format!("Bearer {}", access_token))
271 .header("Content-Type", "application/json")
272 .header("User-Agent", "greppy/0.9.0")
273 .json(&onboard_request)
274 .send()
275 .await
276 .map_err(|e| Error::DaemonError {
277 message: format!("onboardUser failed: {}", e),
278 })?;
279
280 if onboard_res.status().is_success() {
281 if let Ok(onboard_response) = onboard_res.json::<OnboardUserResponse>().await {
282 if onboard_response.done == Some(true) {
283 if let Some(project_id) = onboard_response
284 .response
285 .and_then(|r| r.cloudaicompanion_project)
286 .and_then(|p| p.id)
287 {
288 if let Ok(mut guard) = self.cached_project_id.lock() {
290 *guard = Some(project_id.clone());
291 }
292 return Ok(project_id);
293 }
294 }
295 }
296 }
297 }
298 }
299
300 Err(Error::DaemonError {
301 message: "Failed to get Gemini project ID. You may need to enable Gemini API in Google Cloud Console.".to_string(),
302 })
303 }
304
305 pub async fn rerank(&self, query: &str, chunks: &[String]) -> Result<Vec<usize>> {
308 let access_token = self.get_access_token().await?;
309 let project_id = self.get_project_id(&access_token).await?;
310
311 let system_prompt =
312 "You are a code search reranker. Given a query and numbered code chunks, \
313 return ONLY a JSON array of chunk indices ordered by relevance to the query. \
314 Most relevant first. Example response: [2, 0, 5, 1, 3, 4]";
315
316 let mut user_prompt = format!("Query: {}\n\nCode chunks:\n", query);
317 for (i, chunk) in chunks.iter().enumerate() {
318 user_prompt.push_str(&format!("\n--- Chunk {} ---\n{}\n", i, chunk));
319 }
320 user_prompt.push_str("\nReturn ONLY the JSON array of indices, nothing else.");
321
322 let inner_request = InnerRequest {
324 contents: vec![Content {
325 role: "user".to_string(),
326 parts: vec![Part { text: user_prompt }],
327 }],
328 system_instruction: Some(Content {
329 role: "user".to_string(),
330 parts: vec![Part {
331 text: system_prompt.to_string(),
332 }],
333 }),
334 };
335
336 let request_body = CodeAssistRequest {
338 project: project_id,
339 model: "gemini-2.0-flash".to_string(),
340 request: inner_request,
341 };
342
343 let url = format!("{}/v1internal:generateContent", CODE_ASSIST_BASE);
344
345 let res = self
346 .client
347 .post(&url)
348 .header("Authorization", format!("Bearer {}", access_token))
349 .header("Content-Type", "application/json")
350 .header("User-Agent", "greppy/0.9.0")
351 .header("X-Goog-Api-Client", "greppy/0.9.0")
352 .json(&request_body)
353 .send()
354 .await
355 .map_err(|e| Error::DaemonError {
356 message: format!("API request failed: {}", e),
357 })?;
358
359 if !res.status().is_success() {
360 let text = res.text().await.unwrap_or_default();
361 return Err(Error::DaemonError {
362 message: format!("Gemini API Error: {}", text),
363 });
364 }
365
366 let wrapper: CodeAssistResponse = res.json().await.map_err(|e| Error::DaemonError {
368 message: format!("Failed to parse response: {}", e),
369 })?;
370
371 if let Some(response) = wrapper.response {
373 if let Some(candidates) = response.candidates {
374 if let Some(candidate) = candidates.first() {
375 if let Some(part) = candidate.content.parts.first() {
376 let text = part.text.trim();
377 if let Ok(indices) = serde_json::from_str::<Vec<usize>>(text) {
379 return Ok(indices);
380 }
381 if let Some(start) = text.find('[') {
383 if let Some(end) = text.rfind(']') {
384 let json_str = &text[start..=end];
385 if let Ok(indices) = serde_json::from_str::<Vec<usize>>(json_str) {
386 return Ok(indices);
387 }
388 }
389 }
390 }
391 }
392 }
393 }
394
395 Ok((0..chunks.len()).collect())
397 }
398
399 pub async fn expand_query(&self, query: &str) -> Result<Vec<String>> {
402 use crate::ai::trace_prompts::{
403 build_expansion_prompt, parse_expansion_response, QUERY_EXPANSION_SYSTEM,
404 };
405
406 let access_token = self.get_access_token().await?;
407 let project_id = self.get_project_id(&access_token).await?;
408
409 let inner_request = InnerRequest {
411 contents: vec![Content {
412 role: "user".to_string(),
413 parts: vec![Part {
414 text: build_expansion_prompt(query),
415 }],
416 }],
417 system_instruction: Some(Content {
418 role: "user".to_string(),
419 parts: vec![Part {
420 text: QUERY_EXPANSION_SYSTEM.to_string(),
421 }],
422 }),
423 };
424
425 let request_body = CodeAssistRequest {
427 project: project_id,
428 model: "gemini-2.0-flash".to_string(),
429 request: inner_request,
430 };
431
432 let url = format!("{}/v1internal:generateContent", CODE_ASSIST_BASE);
433
434 let res = self
435 .client
436 .post(&url)
437 .header("Authorization", format!("Bearer {}", access_token))
438 .header("Content-Type", "application/json")
439 .header("User-Agent", "greppy/0.9.0")
440 .header("X-Goog-Api-Client", "greppy/0.9.0")
441 .json(&request_body)
442 .send()
443 .await
444 .map_err(|e| Error::DaemonError {
445 message: format!("API request failed: {}", e),
446 })?;
447
448 if !res.status().is_success() {
449 let text = res.text().await.unwrap_or_default();
450 return Err(Error::DaemonError {
451 message: format!("Gemini API Error: {}", text),
452 });
453 }
454
455 let wrapper: CodeAssistResponse = res.json().await.map_err(|e| Error::DaemonError {
457 message: format!("Failed to parse response: {}", e),
458 })?;
459
460 if let Some(response) = wrapper.response {
462 if let Some(candidates) = response.candidates {
463 if let Some(candidate) = candidates.first() {
464 if let Some(part) = candidate.content.parts.first() {
465 let symbols = parse_expansion_response(&part.text);
466 if !symbols.is_empty() {
467 return Ok(symbols);
468 }
469 }
470 }
471 }
472 }
473
474 Ok(vec![query.to_string()])
476 }
477
478 pub async fn rerank_trace(&self, query: &str, paths: &[String]) -> Result<Vec<usize>> {
481 use crate::ai::trace_prompts::{
482 build_trace_rerank_prompt, parse_rerank_response, TRACE_RERANK_SYSTEM,
483 };
484
485 let access_token = self.get_access_token().await?;
486 let project_id = self.get_project_id(&access_token).await?;
487
488 let inner_request = InnerRequest {
490 contents: vec![Content {
491 role: "user".to_string(),
492 parts: vec![Part {
493 text: build_trace_rerank_prompt(query, paths),
494 }],
495 }],
496 system_instruction: Some(Content {
497 role: "user".to_string(),
498 parts: vec![Part {
499 text: TRACE_RERANK_SYSTEM.to_string(),
500 }],
501 }),
502 };
503
504 let request_body = CodeAssistRequest {
506 project: project_id,
507 model: "gemini-2.0-flash".to_string(),
508 request: inner_request,
509 };
510
511 let url = format!("{}/v1internal:generateContent", CODE_ASSIST_BASE);
512
513 let res = self
514 .client
515 .post(&url)
516 .header("Authorization", format!("Bearer {}", access_token))
517 .header("Content-Type", "application/json")
518 .header("User-Agent", "greppy/0.9.0")
519 .header("X-Goog-Api-Client", "greppy/0.9.0")
520 .json(&request_body)
521 .send()
522 .await
523 .map_err(|e| Error::DaemonError {
524 message: format!("API request failed: {}", e),
525 })?;
526
527 if !res.status().is_success() {
528 let text = res.text().await.unwrap_or_default();
529 return Err(Error::DaemonError {
530 message: format!("Gemini API Error: {}", text),
531 });
532 }
533
534 let wrapper: CodeAssistResponse = res.json().await.map_err(|e| Error::DaemonError {
536 message: format!("Failed to parse response: {}", e),
537 })?;
538
539 if let Some(response) = wrapper.response {
541 if let Some(candidates) = response.candidates {
542 if let Some(candidate) = candidates.first() {
543 if let Some(part) = candidate.content.parts.first() {
544 let indices = parse_rerank_response(&part.text, paths.len());
545 if !indices.is_empty() {
546 return Ok(indices);
547 }
548 }
549 }
550 }
551 }
552
553 Ok((0..paths.len()).collect())
555 }
556}