//! noether_engine/llm/vertex.rs — Vertex AI providers: Gemini chat, text
//! embeddings, and Mistral (OpenAI-compatible `rawPredict`), with layered
//! Google credential resolution (static token, ADC refresh token, metadata
//! server, `gcloud` subprocess).
1use super::{LlmConfig, LlmError, LlmProvider, Message, Role};
2use crate::index::embedding::{Embedding, EmbeddingError, EmbeddingProvider};
3use serde_json::json;
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
7// ── Token source ──────────────────────────────────────────────────────────────
8
/// Cached access token with expiry tracking.
struct CachedToken {
    /// The raw OAuth2 Bearer token string.
    access_token: String,
    /// Deadline after which the token must be re-fetched. Set ~5 minutes
    /// before the token's real expiry (see `CachedToken::new`) so we never
    /// race against a token dying mid-request.
    expires_at: Instant,
}
15
16impl CachedToken {
17    fn new(token: String, expires_in_secs: u64) -> Self {
18        // Refresh 5 minutes early to avoid using a token that's about to expire.
19        let margin = expires_in_secs.saturating_sub(300);
20        Self {
21            access_token: token,
22            expires_at: Instant::now() + Duration::from_secs(margin),
23        }
24    }
25
26    fn is_valid(&self) -> bool {
27        Instant::now() < self.expires_at
28    }
29}
30
/// How the provider obtains a Bearer token.
///
/// Resolution order in `VertexAiConfig::from_env()`:
///   1. `VERTEX_AI_TOKEN` env var → `Static` (no refresh, works for 1-hour tokens)
///   2. `GOOGLE_APPLICATION_CREDENTIALS` file / `~/.config/gcloud/application_default_credentials.json`
///      with `type: "authorized_user"` → `RefreshToken` (auto-refreshes every ~55 min)
///   3. GCE/Cloud Run/GKE metadata server → `MetadataServer` (auto-refreshes, zero config)
///   4. `gcloud auth print-access-token` subprocess → `GcloudSubprocess` (local dev fallback)
///
/// Each auto-refreshing variant carries its own `Mutex<Option<CachedToken>>`
/// cache; see `TokenSource::get_token` for the refresh logic.
enum TokenSource {
    /// Explicit static token — no auto-refresh. Fine for short-lived CLI invocations.
    Static(String),
    /// OAuth2 refresh token flow (ADC user credentials or `authorized_user` service files).
    RefreshToken {
        client_id: String,
        client_secret: String,
        refresh_token: String,
        /// Last minted access token, refreshed lazily on expiry.
        cached: Mutex<Option<CachedToken>>,
    },
    /// GCE instance metadata server — zero-config inside Google Cloud.
    MetadataServer { cached: Mutex<Option<CachedToken>> },
    /// `gcloud auth print-access-token` subprocess — local dev fallback when no ADC file.
    GcloudSubprocess { cached: Mutex<Option<CachedToken>> },
}
54
55impl TokenSource {
56    /// Obtain a valid access token, refreshing if necessary.
57    fn get_token(&self) -> Result<String, String> {
58        match self {
59            Self::Static(t) => Ok(t.clone()),
60
61            Self::RefreshToken {
62                client_id,
63                client_secret,
64                refresh_token,
65                cached,
66            } => {
67                let mut guard = cached.lock().unwrap();
68                if let Some(ref c) = *guard {
69                    if c.is_valid() {
70                        return Ok(c.access_token.clone());
71                    }
72                }
73                let (token, expires_in) = oauth2_refresh(client_id, client_secret, refresh_token)?;
74                *guard = Some(CachedToken::new(token.clone(), expires_in));
75                Ok(token)
76            }
77
78            Self::MetadataServer { cached } => {
79                let mut guard = cached.lock().unwrap();
80                if let Some(ref c) = *guard {
81                    if c.is_valid() {
82                        return Ok(c.access_token.clone());
83                    }
84                }
85                let (token, expires_in) = metadata_server_token()?;
86                *guard = Some(CachedToken::new(token.clone(), expires_in));
87                Ok(token)
88            }
89
90            Self::GcloudSubprocess { cached } => {
91                let mut guard = cached.lock().unwrap();
92                if let Some(ref c) = *guard {
93                    if c.is_valid() {
94                        return Ok(c.access_token.clone());
95                    }
96                }
97                let token = gcloud_print_access_token()?;
98                // gcloud tokens last ~1h; cache for 55 minutes.
99                *guard = Some(CachedToken::new(token.clone(), 3300));
100                Ok(token)
101            }
102        }
103    }
104}
105
106// ── VertexAiConfig ────────────────────────────────────────────────────────────
107
/// Configuration for Vertex AI providers.
///
/// Shared by the Gemini, embedding, and Mistral providers: holds the GCP
/// project/location plus the (private) credential source used to mint
/// Bearer tokens.
pub struct VertexAiConfig {
    /// GCP project ID (`VERTEX_AI_PROJECT` / `GOOGLE_CLOUD_PROJECT`).
    pub project: String,
    /// Vertex AI location, e.g. "europe-west1", or "global".
    pub location: String,
    // How Bearer tokens are obtained; see `TokenSource`.
    token_source: TokenSource,
}
114
115impl VertexAiConfig {
116    /// Load from environment variables.
117    ///
118    /// Token resolution order:
119    ///   1. `VERTEX_AI_TOKEN` — explicit static token
120    ///   2. `GOOGLE_APPLICATION_CREDENTIALS` file (authorized_user or service account key)
121    ///   3. `~/.config/gcloud/application_default_credentials.json` (ADC)
122    ///   4. GCE/Cloud Run metadata server (http://metadata.google.internal/...)
123    ///   5. `gcloud auth print-access-token` subprocess
124    pub fn from_env() -> Result<Self, String> {
125        let project = std::env::var("VERTEX_AI_PROJECT")
126            .or_else(|_| std::env::var("GOOGLE_CLOUD_PROJECT"))
127            .map_err(|_| {
128                "Vertex AI project not configured. Set VERTEX_AI_PROJECT \
129                 (or GOOGLE_CLOUD_PROJECT) to your GCP project ID."
130                    .to_string()
131            })?;
132        let location = std::env::var("VERTEX_AI_LOCATION")
133            .or_else(|_| std::env::var("GOOGLE_CLOUD_LOCATION"))
134            .unwrap_or_else(|_| "europe-west1".into());
135
136        let token_source = resolve_token_source()?;
137        Ok(Self {
138            project,
139            location,
140            token_source,
141        })
142    }
143
144    /// Get a valid access token, auto-refreshing if the current one has expired.
145    pub fn get_token(&self) -> Result<String, String> {
146        self.token_source.get_token()
147    }
148}
149
150// Manual Clone: we need to clone config for the providers, but Mutex isn't Clone.
151// We just start with a fresh empty cache in the clone.
152impl Clone for VertexAiConfig {
153    fn clone(&self) -> Self {
154        let token_source = match &self.token_source {
155            TokenSource::Static(t) => TokenSource::Static(t.clone()),
156            TokenSource::RefreshToken {
157                client_id,
158                client_secret,
159                refresh_token,
160                ..
161            } => TokenSource::RefreshToken {
162                client_id: client_id.clone(),
163                client_secret: client_secret.clone(),
164                refresh_token: refresh_token.clone(),
165                cached: Mutex::new(None),
166            },
167            TokenSource::MetadataServer { .. } => TokenSource::MetadataServer {
168                cached: Mutex::new(None),
169            },
170            TokenSource::GcloudSubprocess { .. } => TokenSource::GcloudSubprocess {
171                cached: Mutex::new(None),
172            },
173        };
174        Self {
175            project: self.project.clone(),
176            location: self.location.clone(),
177            token_source,
178        }
179    }
180}
181
182impl std::fmt::Debug for VertexAiConfig {
183    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184        let source = match &self.token_source {
185            TokenSource::Static(_) => "static",
186            TokenSource::RefreshToken { .. } => "refresh_token",
187            TokenSource::MetadataServer { .. } => "metadata_server",
188            TokenSource::GcloudSubprocess { .. } => "gcloud_subprocess",
189        };
190        f.debug_struct("VertexAiConfig")
191            .field("project", &self.project)
192            .field("location", &self.location)
193            .field("token_source", &source)
194            .finish()
195    }
196}
197
198// ── Token resolution ──────────────────────────────────────────────────────────
199
200fn resolve_token_source() -> Result<TokenSource, String> {
201    // 1. Explicit static token
202    if let Ok(t) = std::env::var("VERTEX_AI_TOKEN") {
203        return Ok(TokenSource::Static(t));
204    }
205
206    // 2. GOOGLE_APPLICATION_CREDENTIALS file
207    if let Ok(path) = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") {
208        if let Ok(source) = load_credentials_file(&path) {
209            return Ok(source);
210        }
211    }
212
213    // 3. ADC file (~/.config/gcloud/application_default_credentials.json)
214    let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".into());
215    let adc_path =
216        std::path::PathBuf::from(&home).join(".config/gcloud/application_default_credentials.json");
217    if adc_path.exists() {
218        if let Ok(source) = load_credentials_file(adc_path.to_str().unwrap_or("")) {
219            return Ok(source);
220        }
221    }
222
223    // 4. GCE / Cloud Run / GKE metadata server
224    if metadata_server_available() {
225        return Ok(TokenSource::MetadataServer {
226            cached: Mutex::new(None),
227        });
228    }
229
230    // 5. gcloud subprocess (local dev fallback)
231    if gcloud_available() {
232        return Ok(TokenSource::GcloudSubprocess {
233            cached: Mutex::new(None),
234        });
235    }
236
237    Err("No Google credentials found. Options:\n\
238         • Run `gcloud auth application-default login`\n\
239         • Set VERTEX_AI_TOKEN to an access token\n\
240         • Set GOOGLE_APPLICATION_CREDENTIALS to a service account key file\n\
241         • Run on GCE/Cloud Run/GKE (metadata server)"
242        .into())
243}
244
245fn load_credentials_file(path: &str) -> Result<TokenSource, String> {
246    let content = std::fs::read_to_string(path)
247        .map_err(|e| format!("cannot read credentials file {path}: {e}"))?;
248    let creds: serde_json::Value =
249        serde_json::from_str(&content).map_err(|e| format!("credentials JSON parse error: {e}"))?;
250
251    match creds["type"].as_str() {
252        Some("authorized_user") => Ok(TokenSource::RefreshToken {
253            client_id: creds["client_id"]
254                .as_str()
255                .ok_or("missing client_id")?
256                .into(),
257            client_secret: creds["client_secret"]
258                .as_str()
259                .ok_or("missing client_secret")?
260                .into(),
261            refresh_token: creds["refresh_token"]
262                .as_str()
263                .ok_or("missing refresh_token")?
264                .into(),
265            cached: Mutex::new(None),
266        }),
267        Some("service_account") => {
268            // Service accounts on non-GCE machines need JWT → token exchange.
269            // We delegate to `gcloud auth print-access-token` which handles this
270            // transparently when GOOGLE_APPLICATION_CREDENTIALS is set.
271            Ok(TokenSource::GcloudSubprocess {
272                cached: Mutex::new(None),
273            })
274        }
275        other => Err(format!(
276            "unsupported credentials type: {:?}",
277            other.unwrap_or("missing")
278        )),
279    }
280}
281
282/// Exchange a refresh token for an access token via the Google OAuth2 endpoint.
283/// Returns `(access_token, expires_in_seconds)`.
284fn oauth2_refresh(
285    client_id: &str,
286    client_secret: &str,
287    refresh_token: &str,
288) -> Result<(String, u64), String> {
289    let client = reqwest::blocking::Client::builder()
290        .timeout(std::time::Duration::from_secs(15))
291        .connect_timeout(std::time::Duration::from_secs(10))
292        .build()
293        .unwrap_or_else(|_| reqwest::blocking::Client::new());
294    let resp = client
295        .post("https://oauth2.googleapis.com/token")
296        .form(&[
297            ("client_id", client_id),
298            ("client_secret", client_secret),
299            ("refresh_token", refresh_token),
300            ("grant_type", "refresh_token"),
301        ])
302        .send()
303        .map_err(|e| format!("token refresh HTTP error: {e}"))?;
304
305    let status = resp.status();
306    let body: serde_json::Value = resp
307        .json()
308        .map_err(|e| format!("token refresh parse error: {e}"))?;
309
310    if !status.is_success() {
311        return Err(format!(
312            "token refresh failed (HTTP {status}): {}",
313            body.get("error_description")
314                .or(body.get("error"))
315                .and_then(|v| v.as_str())
316                .unwrap_or("unknown error")
317        ));
318    }
319
320    let token = body["access_token"]
321        .as_str()
322        .ok_or("token refresh response has no access_token")?
323        .to_string();
324    let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
325    Ok((token, expires_in))
326}
327
328/// Fetch a token from the GCE instance metadata server.
329/// Returns `(access_token, expires_in_seconds)`.
330fn metadata_server_token() -> Result<(String, u64), String> {
331    let client = reqwest::blocking::Client::builder()
332        .timeout(Duration::from_secs(5))
333        .build()
334        .unwrap();
335    let resp = client
336        .get("http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token")
337        .header("Metadata-Flavor", "Google")
338        .send()
339        .map_err(|e| format!("metadata server request failed: {e}"))?;
340
341    if !resp.status().is_success() {
342        return Err(format!("metadata server returned HTTP {}", resp.status()));
343    }
344
345    let body: serde_json::Value = resp
346        .json()
347        .map_err(|e| format!("metadata server parse error: {e}"))?;
348    let token = body["access_token"]
349        .as_str()
350        .ok_or("metadata server response has no access_token")?
351        .to_string();
352    let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
353    Ok((token, expires_in))
354}
355
356fn metadata_server_available() -> bool {
357    let client = reqwest::blocking::Client::builder()
358        .timeout(Duration::from_millis(500))
359        .build()
360        .unwrap_or_else(|_| reqwest::blocking::Client::new());
361    client
362        .get("http://metadata.google.internal/")
363        .header("Metadata-Flavor", "Google")
364        .send()
365        .is_ok()
366}
367
/// Whether a working `gcloud` binary is on PATH.
///
/// `output()` succeeding only proves the process could be spawned; require a
/// zero exit status from `gcloud version` so a broken installation isn't
/// mistaken for a usable credential source.
fn gcloud_available() -> bool {
    std::process::Command::new("gcloud")
        .arg("version")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false)
}
374
/// Run `gcloud auth print-access-token` and return the trimmed stdout.
///
/// # Errors
/// Fails when the subprocess cannot be spawned, exits non-zero (stderr is
/// included in the message), or emits non-UTF-8 output.
fn gcloud_print_access_token() -> Result<String, String> {
    let output = std::process::Command::new("gcloud")
        .args(["auth", "print-access-token"])
        .output()
        .map_err(|e| format!("gcloud subprocess failed: {e}"))?;

    if output.status.success() {
        let stdout = std::str::from_utf8(&output.stdout)
            .map_err(|e| format!("gcloud output encoding error: {e}"))?;
        Ok(stdout.trim().to_string())
    } else {
        let stderr = String::from_utf8_lossy(&output.stderr);
        Err(format!(
            "gcloud auth print-access-token failed: {stderr}. \
             Run `gcloud auth application-default login` to authenticate."
        ))
    }
}
394
/// Vertex AI LLM provider for Gemini models.
///
/// Talks to the `generateContent` endpoint. Uses the global host
/// (https://aiplatform.googleapis.com/v1/...) when the configured location
/// is "global", otherwise the regional `<location>-aiplatform` host
/// (see `base_url`).
pub struct VertexAiLlmProvider {
    /// Project/location/credentials used for every request.
    config: VertexAiConfig,
    /// Blocking HTTP client with a 120 s request timeout (see `new`).
    client: reqwest::blocking::Client,
}
401
402impl VertexAiLlmProvider {
403    pub fn new(config: VertexAiConfig) -> Self {
404        let client = reqwest::blocking::Client::builder()
405            .timeout(std::time::Duration::from_secs(120))
406            .connect_timeout(std::time::Duration::from_secs(15))
407            .build()
408            .expect("failed to build reqwest client");
409        Self { config, client }
410    }
411
412    fn base_url(&self) -> String {
413        if self.config.location == "global" {
414            "https://aiplatform.googleapis.com/v1".into()
415        } else {
416            format!(
417                "https://{}-aiplatform.googleapis.com/v1",
418                self.config.location
419            )
420        }
421    }
422}
423
424impl LlmProvider for VertexAiLlmProvider {
425    fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError> {
426        let url = format!(
427            "{base}/projects/{project}/locations/{location}/publishers/google/models/{model}:generateContent",
428            base = self.base_url(),
429            project = self.config.project,
430            location = self.config.location,
431            model = config.model,
432        );
433
434        // Convert messages to Gemini format
435        let system_instruction: Option<String> = messages
436            .iter()
437            .find(|m| matches!(m.role, Role::System))
438            .map(|m| m.content.clone());
439
440        let contents: Vec<serde_json::Value> = messages
441            .iter()
442            .filter(|m| !matches!(m.role, Role::System))
443            .map(|m| {
444                let role = match m.role {
445                    Role::User => "user",
446                    Role::Assistant => "model",
447                    Role::System => unreachable!(),
448                };
449                json!({
450                    "role": role,
451                    "parts": [{"text": m.content}]
452                })
453            })
454            .collect();
455
456        let mut body = json!({
457            "contents": contents,
458            "generationConfig": {
459                "maxOutputTokens": config.max_tokens,
460                "temperature": config.temperature,
461            }
462        });
463
464        if let Some(sys) = system_instruction {
465            body["systemInstruction"] = json!({
466                "parts": [{"text": sys}]
467            });
468        }
469
470        let token = self
471            .config
472            .get_token()
473            .map_err(|e| LlmError::Provider(format!("auth error: {e}")))?;
474
475        let response = self
476            .client
477            .post(&url)
478            .bearer_auth(&token)
479            .json(&body)
480            .send()
481            .map_err(|e| LlmError::Http(e.to_string()))?;
482
483        let status = response.status();
484        let text = response.text().map_err(|e| LlmError::Http(e.to_string()))?;
485
486        if !status.is_success() {
487            return Err(LlmError::Provider(format!("HTTP {status}: {text}")));
488        }
489
490        let json: serde_json::Value =
491            serde_json::from_str(&text).map_err(|e| LlmError::Parse(e.to_string()))?;
492
493        // Extract text from Gemini response
494        json["candidates"][0]["content"]["parts"][0]["text"]
495            .as_str()
496            .map(|s| s.to_string())
497            .ok_or_else(|| LlmError::Parse(format!("unexpected response format: {json}")))
498    }
499}
500
/// Vertex AI embedding provider.
///
/// Calls the text-embedding model's `:predict` endpoint; uses the global
/// host when the configured location is "global", otherwise the regional
/// host (see `base_url`).
pub struct VertexAiEmbeddingProvider {
    /// Project/location/credentials used for every request.
    config: VertexAiConfig,
    /// Embedding model name (defaults to "text-embedding-005" in `new`).
    model: String,
    /// Requested output dimensionality (defaults to 256 in `new`).
    dimensions: usize,
    /// Blocking HTTP client with a 30 s request timeout (see `new`).
    client: reqwest::blocking::Client,
}
509
510impl VertexAiEmbeddingProvider {
511    pub fn new(config: VertexAiConfig, model: Option<String>, dimensions: Option<usize>) -> Self {
512        let client = reqwest::blocking::Client::builder()
513            .timeout(std::time::Duration::from_secs(30))
514            .connect_timeout(std::time::Duration::from_secs(15))
515            .build()
516            .expect("failed to build reqwest client");
517        Self {
518            config,
519            model: model.unwrap_or_else(|| "text-embedding-005".into()),
520            dimensions: dimensions.unwrap_or(256),
521            client,
522        }
523    }
524
525    fn base_url(&self) -> String {
526        if self.config.location == "global" {
527            "https://aiplatform.googleapis.com/v1".into()
528        } else {
529            format!(
530                "https://{}-aiplatform.googleapis.com/v1",
531                self.config.location
532            )
533        }
534    }
535}
536
537impl EmbeddingProvider for VertexAiEmbeddingProvider {
538    fn dimensions(&self) -> usize {
539        self.dimensions
540    }
541
542    fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
543        let url = format!(
544            "{base}/projects/{project}/locations/{location}/publishers/google/models/{model}:predict",
545            base = self.base_url(),
546            project = self.config.project,
547            location = self.config.location,
548            model = self.model,
549        );
550
551        let body = json!({
552            "instances": [{"content": text}],
553            "parameters": {"outputDimensionality": self.dimensions}
554        });
555
556        let token = self
557            .config
558            .get_token()
559            .map_err(|e| EmbeddingError::Provider(format!("auth error: {e}")))?;
560
561        let response = self
562            .client
563            .post(&url)
564            .bearer_auth(&token)
565            .json(&body)
566            .send()
567            .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
568
569        let status = response.status();
570        let text = response
571            .text()
572            .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
573
574        if !status.is_success() {
575            return Err(EmbeddingError::Provider(format!("HTTP {status}: {text}")));
576        }
577
578        let json: serde_json::Value =
579            serde_json::from_str(&text).map_err(|e| EmbeddingError::Provider(e.to_string()))?;
580
581        let values = json["predictions"][0]["embeddings"]["values"]
582            .as_array()
583            .ok_or_else(|| EmbeddingError::Provider("unexpected response format".into()))?;
584
585        values
586            .iter()
587            .map(|v| {
588                v.as_f64()
589                    .map(|f| f as f32)
590                    .ok_or_else(|| EmbeddingError::Provider("non-numeric embedding value".into()))
591            })
592            .collect()
593    }
594}
595
596// ── Mistral on Vertex AI ────────────────────────────────────────────────────
597
/// Vertex AI LLM provider for Mistral models (mistral-small-2503, mistral-medium-3, codestral-2).
///
/// Mistral uses the OpenAI-compatible `rawPredict` endpoint and is only available in
/// `us-central1` and `europe-west4` (not `global`). Models must be enabled from the
/// Model Garden console before use.
///
/// Model name detection: model names containing "mistral" or "codestral" route here.
/// NOTE(review): that routing happens outside this file — confirm against the caller.
pub struct MistralLlmProvider {
    /// Project/credentials used for every request.
    config: VertexAiConfig,
    /// Resolved region: `new` falls back to europe-west4 when
    /// config.location is "global" or empty.
    region: String,
    /// Blocking HTTP client with a 120 s request timeout (see `new`).
    client: reqwest::blocking::Client,
}
611
612impl MistralLlmProvider {
613    pub fn new(config: VertexAiConfig) -> Self {
614        // Mistral doesn't support "global" — fall back to europe-west4 (enabled by default).
615        // us-central1 also works if explicitly set and the model is enabled there.
616        let region = if config.location == "global" || config.location.is_empty() {
617            "europe-west4".into()
618        } else {
619            config.location.clone()
620        };
621        let client = reqwest::blocking::Client::builder()
622            .timeout(std::time::Duration::from_secs(120))
623            .connect_timeout(std::time::Duration::from_secs(15))
624            .build()
625            .expect("failed to build reqwest client");
626        Self {
627            config,
628            region,
629            client,
630        }
631    }
632}
633
634impl LlmProvider for MistralLlmProvider {
635    fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError> {
636        let url = format!(
637            "https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/mistralai/models/{model}:rawPredict",
638            region = self.region,
639            project = self.config.project,
640            model = config.model,
641        );
642
643        // OpenAI-compatible message format
644        let msgs: Vec<serde_json::Value> = messages
645            .iter()
646            .map(|m| {
647                let role = match m.role {
648                    Role::System => "system",
649                    Role::User => "user",
650                    Role::Assistant => "assistant",
651                };
652                json!({"role": role, "content": m.content})
653            })
654            .collect();
655
656        let body = json!({
657            "model": config.model,
658            "messages": msgs,
659            "max_tokens": config.max_tokens,
660            "temperature": config.temperature,
661            "stream": false,
662        });
663
664        let token = self
665            .config
666            .get_token()
667            .map_err(|e| LlmError::Provider(format!("auth error: {e}")))?;
668
669        let response = self
670            .client
671            .post(&url)
672            .bearer_auth(&token)
673            .json(&body)
674            .send()
675            .map_err(|e| LlmError::Http(e.to_string()))?;
676
677        let status = response.status();
678        let text = response.text().map_err(|e| LlmError::Http(e.to_string()))?;
679
680        if !status.is_success() {
681            return Err(LlmError::Provider(format!("HTTP {status}: {text}")));
682        }
683
684        let json: serde_json::Value =
685            serde_json::from_str(&text).map_err(|e| LlmError::Parse(e.to_string()))?;
686
687        // OpenAI-compatible response: choices[0].message.content
688        json["choices"][0]["message"]["content"]
689            .as_str()
690            .map(|s| s.to_string())
691            .ok_or_else(|| LlmError::Parse(format!("unexpected Mistral response: {json}")))
692    }
693}