use async_trait::async_trait;
use futures::stream::StreamExt;
use futures::Stream;
use reqwest::Client;
use std::pin::Pin;
use super::google_shared::{
build_request_body, convert_messages, convert_tools, create_error_message, parse_google_events,
};
use super::shared_client;
use super::{Provider, ProviderError, ProviderEvent, StreamOptions};
use crate::{Api, Context, Model, StopReason};
/// Streaming provider for Google publisher models served through Vertex AI.
///
/// Holds only a reference to the crate-wide HTTP client returned by
/// `shared_client()`, so cloning the provider is cheap.
#[derive(Clone, Debug)]
pub struct VertexProvider {
    client: &'static Client,
}
impl VertexProvider {
pub fn new() -> Self {
Self {
client: shared_client(),
}
}
async fn get_access_token(&self) -> Result<String, ProviderError> {
if let Ok(token) = std::env::var("GOOGLE_ACCESS_TOKEN") {
if !token.is_empty() {
return Ok(token);
}
}
if let Ok(token) = Self::get_gcloud_token().await {
return Ok(token);
}
if let Ok(creds) = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") {
if !creds.is_empty() {
return Self::get_token_from_service_account(&creds).await;
}
}
Err(ProviderError::MissingApiKey)
}
async fn get_gcloud_token() -> Result<String, ProviderError> {
use std::io;
use tokio::process::Command;
let output = Command::new("gcloud")
.args(["auth", "print-access-token"])
.output()
.await
.map_err(ProviderError::IoError)?;
if output.status.success() {
let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !token.is_empty() {
return Ok(token);
}
}
Err(ProviderError::IoError(io::Error::new(
io::ErrorKind::NotFound,
"gcloud token not available",
)))
}
async fn get_token_from_service_account(
credentials_path: &str,
) -> Result<String, ProviderError> {
use std::fs;
use tokio::time::{sleep, Duration};
let creds_json = fs::read_to_string(credentials_path).map_err(ProviderError::IoError)?;
let creds: ServiceAccountCreds =
serde_json::from_str(&creds_json).map_err(|_| ProviderError::InvalidApiKey)?;
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let header = base64_url_encode(&serde_json::json!({"alg": "RS256", "typ": "JWT"}));
let claims = serde_json::json!({
"iss": creds.client_email,
"sub": creds.client_email,
"aud": "https://oauth2.googleapis.com/token",
"iat": now,
"exp": now + 3600,
"scope": "https://www.googleapis.com/auth/cloud-platform"
});
let claims_b64 = base64_url_encode(&claims);
let signature = sign_rs256(&header, &claims_b64, &creds.private_key)?;
let jwt = signature;
let client = shared_client();
let response = client
.post("https://oauth2.googleapis.com/token")
.form(&[
("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
("assertion", &jwt),
])
.send()
.await
.map_err(ProviderError::RequestFailed)?;
if !response.status().is_success() {
return Err(ProviderError::HttpError(
response.status().as_u16(),
response.text().await.unwrap_or_default(),
));
}
let token_response: TokenResponse = response
.json()
.await
.map_err(|e| ProviderError::RequestFailed(e))?;
sleep(Duration::from_secs(60 * 55)).await;
Ok(token_response.access_token)
}
fn get_project_id() -> Result<String, ProviderError> {
std::env::var("GOOGLE_CLOUD_PROJECT")
.or_else(|_| std::env::var("GOOGLE_PROJECT"))
.map_err(|_| ProviderError::MissingApiKey)
}
fn get_region() -> String {
std::env::var("GOOGLE_CLOUD_REGION").unwrap_or_else(|_| "us-central1".to_string())
}
}
impl Default for VertexProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Provider for VertexProvider {
async fn stream(
&self,
model: &Model,
context: &Context,
options: Option<StreamOptions>,
) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
let options = options.unwrap_or_default();
let access_token = self.get_access_token().await?;
let project_id = Self::get_project_id()?;
let region = Self::get_region();
let model_id = &model.id;
let url = format!(
"https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent",
region, project_id, region, model_id
);
let contents = convert_messages(context)?;
let tools_json = convert_tools(&context.tools, false);
let body = build_request_body(
&contents,
context.system_prompt.as_deref(),
tools_json.as_ref(),
options.temperature,
options.max_tokens,
);
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", access_token))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(ProviderError::RequestFailed)?;
if !response.status().is_success() {
let status = response.status();
let body: String = response.text().await.unwrap_or_default();
return Err(ProviderError::HttpError(status.as_u16(), body));
}
let model_name = model.id.clone();
let stream = response.bytes_stream().flat_map(move |chunk| match chunk {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
futures::stream::iter(parse_google_events(
&text,
Api::GoogleVertex,
"vertex",
&model_name,
))
}
Err(e) => futures::stream::iter(vec![ProviderEvent::Error {
reason: StopReason::Error,
error: create_error_message(Api::GoogleVertex, "vertex", &e.to_string()),
}]),
});
Ok(Box::pin(stream))
}
fn name(&self) -> &str {
"vertex"
}
}
/// Serializes `value` as compact JSON and encodes it as unpadded base64url,
/// the segment encoding used for JWT headers and claims.
fn base64_url_encode(value: &serde_json::Value) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine as _;

    // `Value`'s `Display` writes the same compact JSON as `serde_json::to_string`
    // and cannot fail, so there is no error path to unwrap.
    let compact_json = value.to_string();
    URL_SAFE_NO_PAD.encode(compact_json)
}
/// Builds a compact RS256-signed JWT from already-encoded segments.
///
/// Signs `"{header_b64}.{claims_b64}"` with RSASSA-PKCS1-v1_5 over SHA-256 and
/// returns `"{header_b64}.{claims_b64}.{signature}"`, where the signature is
/// unpadded base64url.
///
/// # Errors
///
/// Returns `ProviderError::InvalidApiKey` if `private_key_pem` is not a PKCS#8
/// PEM-encoded RSA private key or the key cannot produce the signature.
fn sign_rs256(
    header_b64: &str,
    claims_b64: &str,
    private_key_pem: &str,
) -> Result<String, ProviderError> {
    use base64::Engine as _;
    use pkcs8::DecodePrivateKey;
    use rsa::{Pkcs1v15Sign, RsaPrivateKey};
    use sha2::{Digest, Sha256};

    // DER `DigestInfo` header for SHA-256 (RFC 8017 §9.2, note 1). RS256 signs
    // `DigestInfo || SHA-256(message)`; an unprefixed signature over the bare hash
    // is rejected by standard JWT verifiers such as Google's token endpoint.
    const SHA256_DIGEST_INFO_PREFIX: [u8; 19] = [
        0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01,
        0x05, 0x00, 0x04, 0x20,
    ];

    let message = format!("{}.{}", header_b64, claims_b64);
    let key = RsaPrivateKey::from_pkcs8_pem(private_key_pem)
        .map_err(|_| ProviderError::InvalidApiKey)?;

    let mut digest_info = SHA256_DIGEST_INFO_PREFIX.to_vec();
    digest_info.extend_from_slice(&Sha256::digest(message.as_bytes()));
    // The prefix is already in `digest_info`, so the padding scheme must not add one.
    let signature = key
        .sign(Pkcs1v15Sign::new_unprefixed(), &digest_info)
        .map_err(|_| ProviderError::InvalidApiKey)?;

    let sig_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(signature);
    Ok(format!("{}.{}", message, sig_b64))
}
/// Successful response body of the OAuth2 token endpoint.
///
/// Only `access_token` is read; the other fields must still be present for
/// deserialization to succeed.
#[derive(serde::Deserialize)]
// `expires_in` and `token_type` are deserialized but never read.
#[allow(dead_code)]
struct TokenResponse {
    /// Bearer token later sent in the `Authorization` header.
    access_token: String,
    /// Lifetime reported by the server (presumably seconds; not read here).
    expires_in: usize,
    /// Token type reported by the server (not read here).
    token_type: String,
}

// Manual `Debug` so the bearer token never ends up in logs or panic messages.
impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("access_token", &"<redacted>")
            .field("expires_in", &self.expires_in)
            .field("token_type", &self.token_type)
            .finish()
    }
}
/// Google service-account JSON key file.
///
/// Only `client_email` and `private_key` are read when building the JWT
/// assertion; every listed field must nevertheless be present for
/// deserialization to succeed.
#[derive(serde::Deserialize)]
// Most fields are parsed only to validate the file's shape and are not otherwise read.
#[allow(dead_code)]
struct ServiceAccountCreds {
    /// The JSON `type` field (`type` is a Rust keyword, hence the rename).
    #[serde(rename = "type")]
    _type: String,
    project_id: String,
    private_key_id: String,
    /// PEM private key, parsed as PKCS#8 by `sign_rs256`.
    private_key: String,
    /// Used as both the `iss` and `sub` claim of the JWT assertion.
    client_email: String,
    client_id: String,
    auth_uri: String,
    /// NOTE(review): not used; the token endpoint URL is hard-coded where the
    /// assertion is exchanged.
    token_uri: String,
}

// Manual `Debug` so the private key never ends up in logs or panic messages.
impl std::fmt::Debug for ServiceAccountCreds {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServiceAccountCreds")
            .field("_type", &self._type)
            .field("project_id", &self.project_id)
            .field("private_key_id", &self.private_key_id)
            .field("private_key", &"<redacted>")
            .field("client_email", &self.client_email)
            .field("client_id", &self.client_id)
            .field("auth_uri", &self.auth_uri)
            .field("token_uri", &self.token_uri)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Context, Message};

    #[test]
    fn test_vertex_provider_name() {
        let provider = VertexProvider::new();
        assert_eq!(provider.name(), "vertex");
    }

    /// A single user message becomes one `user` content entry with a text part.
    #[test]
    fn test_build_vertex_contents_with_text() {
        let mut ctx = Context::new();
        ctx.add_message(Message::user("Hello, world!"));
        let contents = convert_messages(&ctx).unwrap();
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[0]["parts"][0]["text"], "Hello, world!");
    }

    /// Tools are emitted under `functionDeclarations`, keeping their names.
    #[test]
    fn test_build_vertex_tools() {
        let tools = vec![crate::Tool::new(
            "get_weather",
            "Get weather for a location",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "location": { "type": "string", "description": "The city name" }
                },
                "required": ["location"]
            }),
        )];
        let tools_json = convert_tools(&tools, false).unwrap();
        let declarations = tools_json[0]["functionDeclarations"].as_array().unwrap();
        assert_eq!(declarations.len(), 1);
        assert_eq!(declarations[0]["name"], "get_weather");
    }

    /// An SSE `data:` line with a text part yields a `TextDelta` first.
    #[test]
    fn test_parse_vertex_events_basic_text() {
        let sse_data = r#"data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}"#;
        let events = parse_google_events(sse_data, Api::GoogleVertex, "vertex", "gemini-1.5-pro");
        assert!(!events.is_empty());
        if let ProviderEvent::TextDelta { delta, .. } = &events[0] {
            assert_eq!(delta, "Hello");
        } else {
            panic!("Expected TextDelta event");
        }
    }

    #[test]
    fn test_get_region_default() {
        // NOTE(review): mutates the process-wide environment; tests run in
        // parallel, so any other test reading GOOGLE_CLOUD_REGION could race.
        std::env::remove_var("GOOGLE_CLOUD_REGION");
        assert_eq!(VertexProvider::get_region(), "us-central1");
    }

    #[test]
    fn test_create_error_message() {
        let msg = create_error_message(Api::GoogleVertex, "vertex", "Something went wrong");
        assert_eq!(msg.provider, "vertex");
        assert_eq!(msg.api, Api::GoogleVertex);
        assert_eq!(msg.stop_reason, StopReason::Error);
    }
}