1use super::{LlmConfig, LlmError, LlmProvider, Message, Role};
2use crate::index::embedding::{Embedding, EmbeddingError, EmbeddingProvider};
3use serde_json::json;
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
/// An OAuth2 access token paired with the deadline after which it should
/// no longer be used.
struct CachedToken {
    access_token: String,
    expires_at: Instant,
}

impl CachedToken {
    /// Wraps `token`, treating it as expired 300 seconds before the
    /// server-reported lifetime so callers always refresh with margin
    /// to spare. A lifetime of 300 s or less yields an immediately
    /// invalid entry (every call re-fetches).
    fn new(token: String, expires_in_secs: u64) -> Self {
        let usable = Duration::from_secs(expires_in_secs.saturating_sub(300));
        let expires_at = Instant::now() + usable;
        Self {
            access_token: token,
            expires_at,
        }
    }

    /// True while the margin-adjusted deadline has not yet passed.
    fn is_valid(&self) -> bool {
        self.expires_at > Instant::now()
    }
}
30
/// Where bearer tokens for Google APIs come from.
///
/// Every refreshable variant carries its own cache slot so a token is only
/// re-fetched once the previous one nears expiry (see `CachedToken::new`).
enum TokenSource {
    /// Fixed token supplied by the user (e.g. VERTEX_AI_TOKEN); never refreshed.
    Static(String),
    /// OAuth2 refresh-token flow, as found in "authorized_user" credential files.
    RefreshToken {
        client_id: String,
        client_secret: String,
        refresh_token: String,
        cached: Mutex<Option<CachedToken>>,
    },
    /// GCE / Cloud Run / GKE metadata server.
    MetadataServer { cached: Mutex<Option<CachedToken>> },
    /// Shell out to `gcloud auth print-access-token`.
    GcloudSubprocess { cached: Mutex<Option<CachedToken>> },
}
54
55impl TokenSource {
56 fn get_token(&self) -> Result<String, String> {
58 match self {
59 Self::Static(t) => Ok(t.clone()),
60
61 Self::RefreshToken {
62 client_id,
63 client_secret,
64 refresh_token,
65 cached,
66 } => {
67 let mut guard = cached.lock().unwrap();
68 if let Some(ref c) = *guard {
69 if c.is_valid() {
70 return Ok(c.access_token.clone());
71 }
72 }
73 let (token, expires_in) = oauth2_refresh(client_id, client_secret, refresh_token)?;
74 *guard = Some(CachedToken::new(token.clone(), expires_in));
75 Ok(token)
76 }
77
78 Self::MetadataServer { cached } => {
79 let mut guard = cached.lock().unwrap();
80 if let Some(ref c) = *guard {
81 if c.is_valid() {
82 return Ok(c.access_token.clone());
83 }
84 }
85 let (token, expires_in) = metadata_server_token()?;
86 *guard = Some(CachedToken::new(token.clone(), expires_in));
87 Ok(token)
88 }
89
90 Self::GcloudSubprocess { cached } => {
91 let mut guard = cached.lock().unwrap();
92 if let Some(ref c) = *guard {
93 if c.is_valid() {
94 return Ok(c.access_token.clone());
95 }
96 }
97 let token = gcloud_print_access_token()?;
98 *guard = Some(CachedToken::new(token.clone(), 3300));
100 Ok(token)
101 }
102 }
103 }
104}
105
/// Connection settings for Google Vertex AI: the target project and region
/// plus the credential source used to mint bearer tokens.
pub struct VertexAiConfig {
    /// GCP project ID.
    pub project: String,
    /// Vertex region (e.g. "europe-west1") or the special value "global".
    pub location: String,
    // Private on purpose: callers must go through get_token(), which caches.
    token_source: TokenSource,
}
114
115impl VertexAiConfig {
116 pub fn from_env() -> Result<Self, String> {
125 let project = std::env::var("VERTEX_AI_PROJECT")
126 .or_else(|_| std::env::var("GOOGLE_CLOUD_PROJECT"))
127 .map_err(|_| {
128 "Vertex AI project not configured. Set VERTEX_AI_PROJECT \
129 (or GOOGLE_CLOUD_PROJECT) to your GCP project ID."
130 .to_string()
131 })?;
132 let location = std::env::var("VERTEX_AI_LOCATION")
133 .or_else(|_| std::env::var("GOOGLE_CLOUD_LOCATION"))
134 .unwrap_or_else(|_| "europe-west1".into());
135
136 let token_source = resolve_token_source()?;
137 Ok(Self {
138 project,
139 location,
140 token_source,
141 })
142 }
143
144 pub fn get_token(&self) -> Result<String, String> {
146 self.token_source.get_token()
147 }
148}
149
150impl Clone for VertexAiConfig {
153 fn clone(&self) -> Self {
154 let token_source = match &self.token_source {
155 TokenSource::Static(t) => TokenSource::Static(t.clone()),
156 TokenSource::RefreshToken {
157 client_id,
158 client_secret,
159 refresh_token,
160 ..
161 } => TokenSource::RefreshToken {
162 client_id: client_id.clone(),
163 client_secret: client_secret.clone(),
164 refresh_token: refresh_token.clone(),
165 cached: Mutex::new(None),
166 },
167 TokenSource::MetadataServer { .. } => TokenSource::MetadataServer {
168 cached: Mutex::new(None),
169 },
170 TokenSource::GcloudSubprocess { .. } => TokenSource::GcloudSubprocess {
171 cached: Mutex::new(None),
172 },
173 };
174 Self {
175 project: self.project.clone(),
176 location: self.location.clone(),
177 token_source,
178 }
179 }
180}
181
182impl std::fmt::Debug for VertexAiConfig {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 let source = match &self.token_source {
185 TokenSource::Static(_) => "static",
186 TokenSource::RefreshToken { .. } => "refresh_token",
187 TokenSource::MetadataServer { .. } => "metadata_server",
188 TokenSource::GcloudSubprocess { .. } => "gcloud_subprocess",
189 };
190 f.debug_struct("VertexAiConfig")
191 .field("project", &self.project)
192 .field("location", &self.location)
193 .field("token_source", &source)
194 .finish()
195 }
196}
197
198fn resolve_token_source() -> Result<TokenSource, String> {
201 if let Ok(t) = std::env::var("VERTEX_AI_TOKEN") {
203 return Ok(TokenSource::Static(t));
204 }
205
206 if let Ok(path) = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") {
208 if let Ok(source) = load_credentials_file(&path) {
209 return Ok(source);
210 }
211 }
212
213 let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".into());
215 let adc_path =
216 std::path::PathBuf::from(&home).join(".config/gcloud/application_default_credentials.json");
217 if adc_path.exists() {
218 if let Ok(source) = load_credentials_file(adc_path.to_str().unwrap_or("")) {
219 return Ok(source);
220 }
221 }
222
223 if metadata_server_available() {
225 return Ok(TokenSource::MetadataServer {
226 cached: Mutex::new(None),
227 });
228 }
229
230 if gcloud_available() {
232 return Ok(TokenSource::GcloudSubprocess {
233 cached: Mutex::new(None),
234 });
235 }
236
237 Err("No Google credentials found. Options:\n\
238 • Run `gcloud auth application-default login`\n\
239 • Set VERTEX_AI_TOKEN to an access token\n\
240 • Set GOOGLE_APPLICATION_CREDENTIALS to a service account key file\n\
241 • Run on GCE/Cloud Run/GKE (metadata server)"
242 .into())
243}
244
245fn load_credentials_file(path: &str) -> Result<TokenSource, String> {
246 let content = std::fs::read_to_string(path)
247 .map_err(|e| format!("cannot read credentials file {path}: {e}"))?;
248 let creds: serde_json::Value =
249 serde_json::from_str(&content).map_err(|e| format!("credentials JSON parse error: {e}"))?;
250
251 match creds["type"].as_str() {
252 Some("authorized_user") => Ok(TokenSource::RefreshToken {
253 client_id: creds["client_id"]
254 .as_str()
255 .ok_or("missing client_id")?
256 .into(),
257 client_secret: creds["client_secret"]
258 .as_str()
259 .ok_or("missing client_secret")?
260 .into(),
261 refresh_token: creds["refresh_token"]
262 .as_str()
263 .ok_or("missing refresh_token")?
264 .into(),
265 cached: Mutex::new(None),
266 }),
267 Some("service_account") => {
268 Ok(TokenSource::GcloudSubprocess {
272 cached: Mutex::new(None),
273 })
274 }
275 other => Err(format!(
276 "unsupported credentials type: {:?}",
277 other.unwrap_or("missing")
278 )),
279 }
280}
281
282fn oauth2_refresh(
285 client_id: &str,
286 client_secret: &str,
287 refresh_token: &str,
288) -> Result<(String, u64), String> {
289 let client = reqwest::blocking::Client::builder()
290 .timeout(std::time::Duration::from_secs(15))
291 .connect_timeout(std::time::Duration::from_secs(10))
292 .build()
293 .unwrap_or_else(|_| reqwest::blocking::Client::new());
294 let resp = client
295 .post("https://oauth2.googleapis.com/token")
296 .form(&[
297 ("client_id", client_id),
298 ("client_secret", client_secret),
299 ("refresh_token", refresh_token),
300 ("grant_type", "refresh_token"),
301 ])
302 .send()
303 .map_err(|e| format!("token refresh HTTP error: {e}"))?;
304
305 let status = resp.status();
306 let body: serde_json::Value = resp
307 .json()
308 .map_err(|e| format!("token refresh parse error: {e}"))?;
309
310 if !status.is_success() {
311 return Err(format!(
312 "token refresh failed (HTTP {status}): {}",
313 body.get("error_description")
314 .or(body.get("error"))
315 .and_then(|v| v.as_str())
316 .unwrap_or("unknown error")
317 ));
318 }
319
320 let token = body["access_token"]
321 .as_str()
322 .ok_or("token refresh response has no access_token")?
323 .to_string();
324 let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
325 Ok((token, expires_in))
326}
327
328fn metadata_server_token() -> Result<(String, u64), String> {
331 let client = reqwest::blocking::Client::builder()
332 .timeout(Duration::from_secs(5))
333 .build()
334 .unwrap();
335 let resp = client
336 .get("http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token")
337 .header("Metadata-Flavor", "Google")
338 .send()
339 .map_err(|e| format!("metadata server request failed: {e}"))?;
340
341 if !resp.status().is_success() {
342 return Err(format!("metadata server returned HTTP {}", resp.status()));
343 }
344
345 let body: serde_json::Value = resp
346 .json()
347 .map_err(|e| format!("metadata server parse error: {e}"))?;
348 let token = body["access_token"]
349 .as_str()
350 .ok_or("metadata server response has no access_token")?
351 .to_string();
352 let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
353 Ok((token, expires_in))
354}
355
356fn metadata_server_available() -> bool {
357 let client = reqwest::blocking::Client::builder()
358 .timeout(Duration::from_millis(500))
359 .build()
360 .unwrap_or_else(|_| reqwest::blocking::Client::new());
361 client
362 .get("http://metadata.google.internal/")
363 .header("Metadata-Flavor", "Google")
364 .send()
365 .is_ok()
366}
367
/// Returns true when a working `gcloud` CLI is on PATH.
///
/// `output()` succeeds whenever the binary spawns, even if the command
/// itself fails — so also require a zero exit status from `gcloud version`.
fn gcloud_available() -> bool {
    std::process::Command::new("gcloud")
        .arg("version")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false)
}
374
/// Mints an access token by shelling out to `gcloud auth print-access-token`.
///
/// # Errors
/// Returns a message if the subprocess cannot be spawned, exits nonzero
/// (with its stderr included), or prints non-UTF-8 output.
fn gcloud_print_access_token() -> Result<String, String> {
    let output = std::process::Command::new("gcloud")
        .args(["auth", "print-access-token"])
        .output()
        .map_err(|e| format!("gcloud subprocess failed: {e}"))?;

    if output.status.success() {
        let token = std::str::from_utf8(&output.stdout)
            .map_err(|e| format!("gcloud output encoding error: {e}"))?;
        // gcloud prints the token followed by a newline.
        Ok(token.trim().to_string())
    } else {
        let stderr = String::from_utf8_lossy(&output.stderr);
        Err(format!(
            "gcloud auth print-access-token failed: {stderr}. \
             Run `gcloud auth application-default login` to authenticate."
        ))
    }
}
394
/// Chat-completion provider backed by Google-published models on Vertex AI
/// (the `generateContent` endpoint).
pub struct VertexAiLlmProvider {
    config: VertexAiConfig,
    // Reused blocking HTTP client; timeouts are configured in new().
    client: reqwest::blocking::Client,
}
401
402impl VertexAiLlmProvider {
403 pub fn new(config: VertexAiConfig) -> Self {
404 let client = reqwest::blocking::Client::builder()
405 .timeout(std::time::Duration::from_secs(120))
406 .connect_timeout(std::time::Duration::from_secs(15))
407 .build()
408 .expect("failed to build reqwest client");
409 Self { config, client }
410 }
411
412 fn base_url(&self) -> String {
413 if self.config.location == "global" {
414 "https://aiplatform.googleapis.com/v1".into()
415 } else {
416 format!(
417 "https://{}-aiplatform.googleapis.com/v1",
418 self.config.location
419 )
420 }
421 }
422}
423
424impl LlmProvider for VertexAiLlmProvider {
425 fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError> {
426 let url = format!(
427 "{base}/projects/{project}/locations/{location}/publishers/google/models/{model}:generateContent",
428 base = self.base_url(),
429 project = self.config.project,
430 location = self.config.location,
431 model = config.model,
432 );
433
434 let system_instruction: Option<String> = messages
436 .iter()
437 .find(|m| matches!(m.role, Role::System))
438 .map(|m| m.content.clone());
439
440 let contents: Vec<serde_json::Value> = messages
441 .iter()
442 .filter(|m| !matches!(m.role, Role::System))
443 .map(|m| {
444 let role = match m.role {
445 Role::User => "user",
446 Role::Assistant => "model",
447 Role::System => unreachable!(),
448 };
449 json!({
450 "role": role,
451 "parts": [{"text": m.content}]
452 })
453 })
454 .collect();
455
456 let mut body = json!({
457 "contents": contents,
458 "generationConfig": {
459 "maxOutputTokens": config.max_tokens,
460 "temperature": config.temperature,
461 }
462 });
463
464 if let Some(sys) = system_instruction {
465 body["systemInstruction"] = json!({
466 "parts": [{"text": sys}]
467 });
468 }
469
470 let token = self
471 .config
472 .get_token()
473 .map_err(|e| LlmError::Provider(format!("auth error: {e}")))?;
474
475 let response = self
476 .client
477 .post(&url)
478 .bearer_auth(&token)
479 .json(&body)
480 .send()
481 .map_err(|e| LlmError::Http(e.to_string()))?;
482
483 let status = response.status();
484 let text = response.text().map_err(|e| LlmError::Http(e.to_string()))?;
485
486 if !status.is_success() {
487 return Err(LlmError::Provider(format!("HTTP {status}: {text}")));
488 }
489
490 let json: serde_json::Value =
491 serde_json::from_str(&text).map_err(|e| LlmError::Parse(e.to_string()))?;
492
493 json["candidates"][0]["content"]["parts"][0]["text"]
495 .as_str()
496 .map(|s| s.to_string())
497 .ok_or_else(|| LlmError::Parse(format!("unexpected response format: {json}")))
498 }
499}
500
/// Text-embedding provider backed by Vertex AI (the `:predict` endpoint).
pub struct VertexAiEmbeddingProvider {
    config: VertexAiConfig,
    /// Publisher model ID, e.g. "text-embedding-005".
    model: String,
    /// Requested output dimensionality (sent as `outputDimensionality`).
    dimensions: usize,
    // Reused blocking HTTP client; timeouts are configured in new().
    client: reqwest::blocking::Client,
}
509
510impl VertexAiEmbeddingProvider {
511 pub fn new(config: VertexAiConfig, model: Option<String>, dimensions: Option<usize>) -> Self {
512 let client = reqwest::blocking::Client::builder()
513 .timeout(std::time::Duration::from_secs(30))
514 .connect_timeout(std::time::Duration::from_secs(15))
515 .build()
516 .expect("failed to build reqwest client");
517 Self {
518 config,
519 model: model.unwrap_or_else(|| "text-embedding-005".into()),
520 dimensions: dimensions.unwrap_or(256),
521 client,
522 }
523 }
524
525 fn base_url(&self) -> String {
526 if self.config.location == "global" {
527 "https://aiplatform.googleapis.com/v1".into()
528 } else {
529 format!(
530 "https://{}-aiplatform.googleapis.com/v1",
531 self.config.location
532 )
533 }
534 }
535}
536
537impl EmbeddingProvider for VertexAiEmbeddingProvider {
538 fn dimensions(&self) -> usize {
539 self.dimensions
540 }
541
542 fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
543 let url = format!(
544 "{base}/projects/{project}/locations/{location}/publishers/google/models/{model}:predict",
545 base = self.base_url(),
546 project = self.config.project,
547 location = self.config.location,
548 model = self.model,
549 );
550
551 let body = json!({
552 "instances": [{"content": text}],
553 "parameters": {"outputDimensionality": self.dimensions}
554 });
555
556 let token = self
557 .config
558 .get_token()
559 .map_err(|e| EmbeddingError::Provider(format!("auth error: {e}")))?;
560
561 let response = self
562 .client
563 .post(&url)
564 .bearer_auth(&token)
565 .json(&body)
566 .send()
567 .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
568
569 let status = response.status();
570 let text = response
571 .text()
572 .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
573
574 if !status.is_success() {
575 return Err(EmbeddingError::Provider(format!("HTTP {status}: {text}")));
576 }
577
578 let json: serde_json::Value =
579 serde_json::from_str(&text).map_err(|e| EmbeddingError::Provider(e.to_string()))?;
580
581 let values = json["predictions"][0]["embeddings"]["values"]
582 .as_array()
583 .ok_or_else(|| EmbeddingError::Provider("unexpected response format".into()))?;
584
585 values
586 .iter()
587 .map(|v| {
588 v.as_f64()
589 .map(|f| f as f32)
590 .ok_or_else(|| EmbeddingError::Provider("non-numeric embedding value".into()))
591 })
592 .collect()
593 }
594}
595
/// Chat-completion provider for Mistral models served through Vertex AI's
/// `:rawPredict` endpoint.
pub struct MistralLlmProvider {
    config: VertexAiConfig,
    /// Effective serving region; set in new() — "global"/empty locations are
    /// replaced with a concrete region there.
    region: String,
    // Reused blocking HTTP client; timeouts are configured in new().
    client: reqwest::blocking::Client,
}
611
612impl MistralLlmProvider {
613 pub fn new(config: VertexAiConfig) -> Self {
614 let region = if config.location == "global" || config.location.is_empty() {
617 "europe-west4".into()
618 } else {
619 config.location.clone()
620 };
621 let client = reqwest::blocking::Client::builder()
622 .timeout(std::time::Duration::from_secs(120))
623 .connect_timeout(std::time::Duration::from_secs(15))
624 .build()
625 .expect("failed to build reqwest client");
626 Self {
627 config,
628 region,
629 client,
630 }
631 }
632}
633
634impl LlmProvider for MistralLlmProvider {
635 fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError> {
636 let url = format!(
637 "https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/mistralai/models/{model}:rawPredict",
638 region = self.region,
639 project = self.config.project,
640 model = config.model,
641 );
642
643 let msgs: Vec<serde_json::Value> = messages
645 .iter()
646 .map(|m| {
647 let role = match m.role {
648 Role::System => "system",
649 Role::User => "user",
650 Role::Assistant => "assistant",
651 };
652 json!({"role": role, "content": m.content})
653 })
654 .collect();
655
656 let body = json!({
657 "model": config.model,
658 "messages": msgs,
659 "max_tokens": config.max_tokens,
660 "temperature": config.temperature,
661 "stream": false,
662 });
663
664 let token = self
665 .config
666 .get_token()
667 .map_err(|e| LlmError::Provider(format!("auth error: {e}")))?;
668
669 let response = self
670 .client
671 .post(&url)
672 .bearer_auth(&token)
673 .json(&body)
674 .send()
675 .map_err(|e| LlmError::Http(e.to_string()))?;
676
677 let status = response.status();
678 let text = response.text().map_err(|e| LlmError::Http(e.to_string()))?;
679
680 if !status.is_success() {
681 return Err(LlmError::Provider(format!("HTTP {status}: {text}")));
682 }
683
684 let json: serde_json::Value =
685 serde_json::from_str(&text).map_err(|e| LlmError::Parse(e.to_string()))?;
686
687 json["choices"][0]["message"]["content"]
689 .as_str()
690 .map(|s| s.to_string())
691 .ok_or_else(|| LlmError::Parse(format!("unexpected Mistral response: {json}")))
692 }
693}