use super::{LlmConfig, LlmError, LlmProvider, Message, Role};
use crate::index::embedding::{Embedding, EmbeddingError, EmbeddingProvider};
use serde_json::json;
use std::sync::Mutex;
use std::time::{Duration, Instant};
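/// An access token paired with the instant at which it should be treated as
/// expired (the reported lifetime minus a five-minute safety margin).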
struct CachedToken {
access_token: String,
expires_at: Instant,
}
impl CachedToken {
fn new(token: String, expires_in_secs: u64) -> Self {
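        // Expire five minutes early so a token is never handed out moments
        // before it lapses; lifetimes at or below the margin saturate to zero
        // and count as already expired.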
let margin = expires_in_secs.saturating_sub(300);
Self {
access_token: token,
expires_at: Instant::now() + Duration::from_secs(margin),
}
}
fn is_valid(&self) -> bool {
Instant::now() < self.expires_at
}
}
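/// Where access tokens come from: a fixed token, the OAuth2 refresh-token
/// flow, the GCE metadata server, or the gcloud CLI. Every variant except
/// `Static` caches the fetched token until shortly before it expires.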
enum TokenSource {
Static(String),
RefreshToken {
client_id: String,
client_secret: String,
refresh_token: String,
cached: Mutex<Option<CachedToken>>,
},
MetadataServer { cached: Mutex<Option<CachedToken>> },
GcloudSubprocess { cached: Mutex<Option<CachedToken>> },
}
impl TokenSource {
    fn get_token(&self) -> Result<String, String> {
        // Every non-static source shares the same cache discipline: reuse the
        // cached token while it is still valid, otherwise fetch and re-cache.
        fn cached_or_fetch(
            cached: &Mutex<Option<CachedToken>>,
            fetch: impl FnOnce() -> Result<(String, u64), String>,
        ) -> Result<String, String> {
            let mut guard = cached.lock().unwrap();
            if let Some(ref c) = *guard {
                if c.is_valid() {
                    return Ok(c.access_token.clone());
                }
            }
            let (token, expires_in) = fetch()?;
            *guard = Some(CachedToken::new(token.clone(), expires_in));
            Ok(token)
        }
        match self {
            Self::Static(t) => Ok(t.clone()),
            Self::RefreshToken {
                client_id,
                client_secret,
                refresh_token,
                cached,
            } => cached_or_fetch(cached, || {
                oauth2_refresh(client_id, client_secret, refresh_token)
            }),
            Self::MetadataServer { cached } => cached_or_fetch(cached, metadata_server_token),
            // gcloud does not report a lifetime; its tokens last about an hour,
            // so assume 55 minutes and let CachedToken's margin trim it further.
            Self::GcloudSubprocess { cached } => {
                cached_or_fetch(cached, || gcloud_print_access_token().map(|t| (t, 3300)))
            }
        }
    }
}
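/// Connection settings for Vertex AI: the GCP project, the location, and a
/// resolved credential source.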
pub struct VertexAiConfig {
pub project: String,
pub location: String,
token_source: TokenSource,
}
impl VertexAiConfig {
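    /// Reads the project from VERTEX_AI_PROJECT (falling back to
    /// GOOGLE_CLOUD_PROJECT), the location from VERTEX_AI_LOCATION (falling
    /// back to GOOGLE_CLOUD_LOCATION, defaulting to europe-west1), and then
    /// resolves a credential source; see `resolve_token_source`.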
pub fn from_env() -> Result<Self, String> {
let project = std::env::var("VERTEX_AI_PROJECT")
.or_else(|_| std::env::var("GOOGLE_CLOUD_PROJECT"))
.map_err(|_| {
"Vertex AI project not configured. Set VERTEX_AI_PROJECT \
(or GOOGLE_CLOUD_PROJECT) to your GCP project ID."
.to_string()
})?;
let location = std::env::var("VERTEX_AI_LOCATION")
.or_else(|_| std::env::var("GOOGLE_CLOUD_LOCATION"))
.unwrap_or_else(|_| "europe-west1".into());
let token_source = resolve_token_source()?;
Ok(Self {
project,
location,
token_source,
})
}
pub fn get_token(&self) -> Result<String, String> {
self.token_source.get_token()
}
}
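// Clone is implemented by hand because `Mutex` is not `Clone`; each clone
// starts with an empty token cache and fetches its own token on first use.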
impl Clone for VertexAiConfig {
fn clone(&self) -> Self {
let token_source = match &self.token_source {
TokenSource::Static(t) => TokenSource::Static(t.clone()),
TokenSource::RefreshToken {
client_id,
client_secret,
refresh_token,
..
} => TokenSource::RefreshToken {
client_id: client_id.clone(),
client_secret: client_secret.clone(),
refresh_token: refresh_token.clone(),
cached: Mutex::new(None),
},
TokenSource::MetadataServer { .. } => TokenSource::MetadataServer {
cached: Mutex::new(None),
},
TokenSource::GcloudSubprocess { .. } => TokenSource::GcloudSubprocess {
cached: Mutex::new(None),
},
};
Self {
project: self.project.clone(),
location: self.location.clone(),
token_source,
}
}
}
impl std::fmt::Debug for VertexAiConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let source = match &self.token_source {
TokenSource::Static(_) => "static",
TokenSource::RefreshToken { .. } => "refresh_token",
TokenSource::MetadataServer { .. } => "metadata_server",
TokenSource::GcloudSubprocess { .. } => "gcloud_subprocess",
};
f.debug_struct("VertexAiConfig")
.field("project", &self.project)
.field("location", &self.location)
.field("token_source", &source)
.finish()
}
}
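/// Resolves credentials in precedence order: an explicit VERTEX_AI_TOKEN,
/// a GOOGLE_APPLICATION_CREDENTIALS file, gcloud's application-default
/// credentials file, the GCE metadata server, and finally the gcloud CLI.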
fn resolve_token_source() -> Result<TokenSource, String> {
if let Ok(t) = std::env::var("VERTEX_AI_TOKEN") {
return Ok(TokenSource::Static(t));
}
if let Ok(path) = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") {
if let Ok(source) = load_credentials_file(&path) {
return Ok(source);
}
}
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".into());
let adc_path =
std::path::PathBuf::from(&home).join(".config/gcloud/application_default_credentials.json");
if adc_path.exists() {
if let Ok(source) = load_credentials_file(adc_path.to_str().unwrap_or("")) {
return Ok(source);
}
}
if metadata_server_available() {
return Ok(TokenSource::MetadataServer {
cached: Mutex::new(None),
});
}
if gcloud_available() {
return Ok(TokenSource::GcloudSubprocess {
cached: Mutex::new(None),
});
}
Err("No Google credentials found. Options:\n\
• Run `gcloud auth application-default login`\n\
• Set VERTEX_AI_TOKEN to an access token\n\
• Set GOOGLE_APPLICATION_CREDENTIALS to a service account key file\n\
• Run on GCE/Cloud Run/GKE (metadata server)"
.into())
}
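/// Parses a Google credentials JSON file. `authorized_user` files carry an
/// OAuth2 refresh token we can exchange directly; `service_account` files
/// are deferred to the gcloud CLI (see below).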
fn load_credentials_file(path: &str) -> Result<TokenSource, String> {
let content = std::fs::read_to_string(path)
.map_err(|e| format!("cannot read credentials file {path}: {e}"))?;
let creds: serde_json::Value =
serde_json::from_str(&content).map_err(|e| format!("credentials JSON parse error: {e}"))?;
match creds["type"].as_str() {
Some("authorized_user") => Ok(TokenSource::RefreshToken {
client_id: creds["client_id"]
.as_str()
.ok_or("missing client_id")?
.into(),
client_secret: creds["client_secret"]
.as_str()
.ok_or("missing client_secret")?
.into(),
refresh_token: creds["refresh_token"]
.as_str()
.ok_or("missing refresh_token")?
.into(),
cached: Mutex::new(None),
}),
Some("service_account") => {
Ok(TokenSource::GcloudSubprocess {
cached: Mutex::new(None),
})
}
other => Err(format!(
"unsupported credentials type: {:?}",
other.unwrap_or("missing")
)),
}
}
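/// Exchanges an OAuth2 refresh token for a fresh access token at Google's
/// token endpoint, returning the token and its lifetime in seconds.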
fn oauth2_refresh(
client_id: &str,
client_secret: &str,
refresh_token: &str,
) -> Result<(String, u64), String> {
    let client = reqwest::blocking::Client::builder()
        .timeout(Duration::from_secs(15))
        .connect_timeout(Duration::from_secs(10))
        .build()
        .unwrap_or_else(|_| reqwest::blocking::Client::new());
let resp = client
.post("https://oauth2.googleapis.com/token")
.form(&[
("client_id", client_id),
("client_secret", client_secret),
("refresh_token", refresh_token),
("grant_type", "refresh_token"),
])
.send()
.map_err(|e| format!("token refresh HTTP error: {e}"))?;
let status = resp.status();
let body: serde_json::Value = resp
.json()
.map_err(|e| format!("token refresh parse error: {e}"))?;
if !status.is_success() {
return Err(format!(
"token refresh failed (HTTP {status}): {}",
body.get("error_description")
.or(body.get("error"))
.and_then(|v| v.as_str())
.unwrap_or("unknown error")
));
}
let token = body["access_token"]
.as_str()
.ok_or("token refresh response has no access_token")?
.to_string();
let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
Ok((token, expires_in))
}
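/// Fetches an access token for the default service account from the GCE
/// metadata server (present on GCE, Cloud Run, and GKE).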
fn metadata_server_token() -> Result<(String, u64), String> {
    let client = reqwest::blocking::Client::builder()
        .timeout(Duration::from_secs(5))
        .build()
        .unwrap_or_else(|_| reqwest::blocking::Client::new());
let resp = client
.get("http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token")
.header("Metadata-Flavor", "Google")
.send()
.map_err(|e| format!("metadata server request failed: {e}"))?;
if !resp.status().is_success() {
return Err(format!("metadata server returned HTTP {}", resp.status()));
}
let body: serde_json::Value = resp
.json()
.map_err(|e| format!("metadata server parse error: {e}"))?;
let token = body["access_token"]
.as_str()
.ok_or("metadata server response has no access_token")?
.to_string();
let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
Ok((token, expires_in))
}
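/// Probes the metadata server with a short timeout. Any HTTP response at
/// all, regardless of status, means we are running on Google infrastructure.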
fn metadata_server_available() -> bool {
let client = reqwest::blocking::Client::builder()
.timeout(Duration::from_millis(500))
.build()
.unwrap_or_else(|_| reqwest::blocking::Client::new());
client
.get("http://metadata.google.internal/")
.header("Metadata-Flavor", "Google")
.send()
.is_ok()
}
fn gcloud_available() -> bool {
    // `output()` succeeds whenever the binary can be spawned, so also check
    // the exit status to avoid mistaking a broken install for a working one.
    std::process::Command::new("gcloud")
        .arg("version")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false)
}
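/// Shells out to `gcloud auth print-access-token` for the active account.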
fn gcloud_print_access_token() -> Result<String, String> {
let out = std::process::Command::new("gcloud")
.args(["auth", "print-access-token"])
.output()
.map_err(|e| format!("gcloud subprocess failed: {e}"))?;
if !out.status.success() {
let stderr = String::from_utf8_lossy(&out.stderr);
return Err(format!(
"gcloud auth print-access-token failed: {stderr}. \
Run `gcloud auth application-default login` to authenticate."
));
}
Ok(std::str::from_utf8(&out.stdout)
.map_err(|e| format!("gcloud output encoding error: {e}"))?
.trim()
.to_string())
}
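/// Chat completions from Google-published models (Gemini) via Vertex AI's
/// `generateContent` endpoint.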
pub struct VertexAiLlmProvider {
config: VertexAiConfig,
client: reqwest::blocking::Client,
}
impl VertexAiLlmProvider {
pub fn new(config: VertexAiConfig) -> Self {
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(120))
            .connect_timeout(Duration::from_secs(15))
            .build()
            .expect("failed to build reqwest client");
Self { config, client }
}
fn base_url(&self) -> String {
if self.config.location == "global" {
"https://aiplatform.googleapis.com/v1".into()
} else {
format!(
"https://{}-aiplatform.googleapis.com/v1",
self.config.location
)
}
}
}
impl LlmProvider for VertexAiLlmProvider {
fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError> {
let url = format!(
"{base}/projects/{project}/locations/{location}/publishers/google/models/{model}:generateContent",
base = self.base_url(),
project = self.config.project,
location = self.config.location,
model = config.model,
);
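        // Gemini takes the system prompt as a separate `systemInstruction`
        // field, so pull out the first system message here and exclude all
        // system messages from `contents` below.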
let system_instruction: Option<String> = messages
.iter()
.find(|m| matches!(m.role, Role::System))
.map(|m| m.content.clone());
let contents: Vec<serde_json::Value> = messages
.iter()
.filter(|m| !matches!(m.role, Role::System))
.map(|m| {
let role = match m.role {
Role::User => "user",
Role::Assistant => "model",
Role::System => unreachable!(),
};
json!({
"role": role,
"parts": [{"text": m.content}]
})
})
.collect();
let mut body = json!({
"contents": contents,
"generationConfig": {
"maxOutputTokens": config.max_tokens,
"temperature": config.temperature,
}
});
if let Some(sys) = system_instruction {
body["systemInstruction"] = json!({
"parts": [{"text": sys}]
});
}
let token = self
.config
.get_token()
.map_err(|e| LlmError::Provider(format!("auth error: {e}")))?;
let response = self
.client
.post(&url)
.bearer_auth(&token)
.json(&body)
.send()
.map_err(|e| LlmError::Http(e.to_string()))?;
let status = response.status();
let text = response.text().map_err(|e| LlmError::Http(e.to_string()))?;
if !status.is_success() {
return Err(LlmError::Provider(format!("HTTP {status}: {text}")));
}
let json: serde_json::Value =
serde_json::from_str(&text).map_err(|e| LlmError::Parse(e.to_string()))?;
json["candidates"][0]["content"]["parts"][0]["text"]
.as_str()
.map(|s| s.to_string())
.ok_or_else(|| LlmError::Parse(format!("unexpected response format: {json}")))
}
}
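/// Text embeddings via Vertex AI's `predict` endpoint, defaulting to
/// `text-embedding-005` truncated to 256 dimensions through the
/// `outputDimensionality` parameter.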
pub struct VertexAiEmbeddingProvider {
config: VertexAiConfig,
model: String,
dimensions: usize,
client: reqwest::blocking::Client,
}
impl VertexAiEmbeddingProvider {
pub fn new(config: VertexAiConfig, model: Option<String>, dimensions: Option<usize>) -> Self {
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(30))
            .connect_timeout(Duration::from_secs(15))
            .build()
            .expect("failed to build reqwest client");
Self {
config,
model: model.unwrap_or_else(|| "text-embedding-005".into()),
dimensions: dimensions.unwrap_or(256),
client,
}
}
fn base_url(&self) -> String {
if self.config.location == "global" {
"https://aiplatform.googleapis.com/v1".into()
} else {
format!(
"https://{}-aiplatform.googleapis.com/v1",
self.config.location
)
}
}
}
impl EmbeddingProvider for VertexAiEmbeddingProvider {
fn dimensions(&self) -> usize {
self.dimensions
}
fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
let url = format!(
"{base}/projects/{project}/locations/{location}/publishers/google/models/{model}:predict",
base = self.base_url(),
project = self.config.project,
location = self.config.location,
model = self.model,
);
let body = json!({
"instances": [{"content": text}],
"parameters": {"outputDimensionality": self.dimensions}
});
let token = self
.config
.get_token()
.map_err(|e| EmbeddingError::Provider(format!("auth error: {e}")))?;
let response = self
.client
.post(&url)
.bearer_auth(&token)
.json(&body)
.send()
.map_err(|e| EmbeddingError::Provider(e.to_string()))?;
let status = response.status();
let text = response
.text()
.map_err(|e| EmbeddingError::Provider(e.to_string()))?;
if !status.is_success() {
return Err(EmbeddingError::Provider(format!("HTTP {status}: {text}")));
}
let json: serde_json::Value =
serde_json::from_str(&text).map_err(|e| EmbeddingError::Provider(e.to_string()))?;
let values = json["predictions"][0]["embeddings"]["values"]
.as_array()
.ok_or_else(|| EmbeddingError::Provider("unexpected response format".into()))?;
values
.iter()
.map(|v| {
v.as_f64()
.map(|f| f as f32)
.ok_or_else(|| EmbeddingError::Provider("non-numeric embedding value".into()))
})
.collect()
}
}
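/// Mistral models on Vertex AI. These use the `rawPredict` endpoint with an
/// OpenAI-style chat schema rather than Gemini's `generateContent` shape.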
pub struct MistralLlmProvider {
config: VertexAiConfig,
region: String,
client: reqwest::blocking::Client,
}
impl MistralLlmProvider {
pub fn new(config: VertexAiConfig) -> Self {
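        // Mistral models on Vertex AI are served from regional endpoints, so
        // substitute a concrete region when the location is `global` or empty.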
let region = if config.location == "global" || config.location.is_empty() {
"europe-west4".into()
} else {
config.location.clone()
};
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(120))
            .connect_timeout(Duration::from_secs(15))
            .build()
            .expect("failed to build reqwest client");
Self {
config,
region,
client,
}
}
}
impl LlmProvider for MistralLlmProvider {
fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError> {
let url = format!(
"https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/mistralai/models/{model}:rawPredict",
region = self.region,
project = self.config.project,
model = config.model,
);
let msgs: Vec<serde_json::Value> = messages
.iter()
.map(|m| {
let role = match m.role {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
};
json!({"role": role, "content": m.content})
})
.collect();
let body = json!({
"model": config.model,
"messages": msgs,
"max_tokens": config.max_tokens,
"temperature": config.temperature,
"stream": false,
});
let token = self
.config
.get_token()
.map_err(|e| LlmError::Provider(format!("auth error: {e}")))?;
let response = self
.client
.post(&url)
.bearer_auth(&token)
.json(&body)
.send()
.map_err(|e| LlmError::Http(e.to_string()))?;
let status = response.status();
let text = response.text().map_err(|e| LlmError::Http(e.to_string()))?;
if !status.is_success() {
return Err(LlmError::Provider(format!("HTTP {status}: {text}")));
}
let json: serde_json::Value =
serde_json::from_str(&text).map_err(|e| LlmError::Parse(e.to_string()))?;
json["choices"][0]["message"]["content"]
.as_str()
.map(|s| s.to_string())
.ok_or_else(|| LlmError::Parse(format!("unexpected Mistral response: {json}")))
}
}
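
// A minimal sanity check for the token-expiry arithmetic above. It exercises
// only types defined in this module, so it needs no network or credentials.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cached_token_applies_expiry_margin() {
        // A one-hour token is valid immediately after caching...
        let fresh = CachedToken::new("tok".into(), 3600);
        assert!(fresh.is_valid());

        // ...while a lifetime at or below the five-minute margin saturates
        // to zero, so the token counts as already expired.
        let stale = CachedToken::new("tok".into(), 300);
        assert!(!stale.is_valid());
    }
}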