use async_trait::async_trait;
use gcp_auth::{Token, TokenProvider};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::sync::Semaphore;
use crate::context::TokenCounter;
use crate::embedding::config::{EmbeddingConfig, ParallelConfig, Provider};
use crate::embedding::error::{ApiError, ConfigError, EmbeddingError};
use crate::embedding::provider::{EmbeddingProvider, ProviderMetrics, Vector};
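/// Task type hint sent with each embedding instance. Vertex AI uses it to
/// tune the embedding for its intended use: indexing documents, embedding
/// search queries, or comparing texts for similarity.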
#[derive(Debug, Clone, Copy)]
pub enum TaskType {
RetrievalDocument,
RetrievalQuery,
SemanticSimilarity,
}
impl TaskType {
fn as_str(&self) -> &'static str {
match self {
TaskType::RetrievalDocument => "RETRIEVAL_DOCUMENT",
TaskType::RetrievalQuery => "RETRIEVAL_QUERY",
TaskType::SemanticSimilarity => "SEMANTIC_SIMILARITY",
}
}
}
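// Wire types for the Vertex AI `:predict` JSON body. Serde's default
// snake_case field names match the keys the API expects (`content`,
// `task_type`); responses arrive as `predictions[].embeddings.values`.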
#[derive(Serialize, Clone)]
struct EmbeddingInstance {
content: String,
task_type: &'static str,
}
#[derive(Serialize)]
struct PredictRequest {
instances: Vec<EmbeddingInstance>,
}
#[derive(Deserialize)]
struct Prediction {
embeddings: EmbeddingValues,
}
#[derive(Deserialize)]
struct EmbeddingValues {
values: Vec<f32>,
}
#[derive(Deserialize)]
struct PredictResponse {
predictions: Vec<Prediction>,
}
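/// Embedding provider backed by Google Vertex AI text embedding models.
///
/// Cloning is cheap: the HTTP client, token provider, metrics, and
/// concurrency semaphore are all shared handles.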
#[derive(Clone)]
pub struct GoogleProvider {
client: Client,
project_id: String,
region: String,
model: String,
task_type: TaskType,
token_provider: Arc<dyn TokenProvider>,
metrics: Arc<RwLock<ProviderMetrics>>,
parallel_config: ParallelConfig,
semaphore: Arc<Semaphore>,
}
impl GoogleProvider {
    pub const DEFAULT_MODEL: &'static str = "text-embedding-004";
    pub const DEFAULT_REGION: &'static str = "us-central1";
    /// Upper bound on instances per predict request.
    pub const MAX_BATCH_SIZE: usize = 250;
    /// Per-text token budget; longer texts are truncated before sending.
    const MAX_TOKENS_PER_TEXT: usize = 19_000;
    const REQUEST_TIMEOUT_SECS: u64 = 30;
    /// Longer timeout applied to multi-instance requests.
    const BATCH_TIMEOUT_SECS: u64 = 90;
    const MAX_RETRIES: u32 = 3;
    /// Base delay for exponential backoff (doubles per attempt).
    const BASE_RETRY_DELAY_MS: u64 = 1000;
pub async fn new_with_config(
project_id: String,
credentials_path: PathBuf,
region: String,
model: String,
parallel_config: ParallelConfig,
) -> Result<Self, EmbeddingError> {
if !credentials_path.exists() {
return Err(EmbeddingError::Config(ConfigError::FileError(format!(
"Credentials file not found: {}",
credentials_path.display()
))));
}
        // gcp_auth discovers service-account keys via this environment variable.
        std::env::set_var("GOOGLE_APPLICATION_CREDENTIALS", &credentials_path);
let token_provider = gcp_auth::provider().await.map_err(|e| {
EmbeddingError::Config(ConfigError::InvalidValue {
field: "credentials".to_string(),
reason: format!("Failed to create token provider: {}", e),
})
})?;
let client = Client::builder()
.timeout(Duration::from_secs(Self::REQUEST_TIMEOUT_SECS))
.build()?;
let semaphore = Arc::new(Semaphore::new(parallel_config.max_concurrency));
Ok(Self {
client,
project_id,
region,
model,
task_type: TaskType::RetrievalDocument,
token_provider,
metrics: Arc::new(RwLock::new(ProviderMetrics::default())),
parallel_config,
semaphore,
})
}
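    /// Convenience constructor that uses `ParallelConfig::google_defaults()`.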
pub async fn new(
project_id: String,
credentials_path: PathBuf,
region: String,
model: String,
) -> Result<Self, EmbeddingError> {
Self::new_with_config(
project_id,
credentials_path,
region,
model,
ParallelConfig::google_defaults(),
)
.await
}
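    /// Builds a provider from environment variables. `MAPROOM_`-prefixed
    /// variables take precedence over the plain Google ones; `GOOGLE_REGION`
    /// and `GOOGLE_MODEL` fall back to `DEFAULT_REGION` and `DEFAULT_MODEL`.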
pub async fn from_env() -> Result<Self, EmbeddingError> {
let config = EmbeddingConfig::from_env_with_provider(Some(Provider::Google))?;
let parallel_config = config.parallel;
let credentials_path = std::env::var("MAPROOM_GOOGLE_APPLICATION_CREDENTIALS")
.or_else(|_| std::env::var("GOOGLE_APPLICATION_CREDENTIALS"))
.map_err(|_| {
EmbeddingError::Config(ConfigError::EnvVarNotFound(
"MAPROOM_GOOGLE_APPLICATION_CREDENTIALS or GOOGLE_APPLICATION_CREDENTIALS"
.to_string(),
))
})?;
let project_id = std::env::var("MAPROOM_GOOGLE_PROJECT_ID")
.or_else(|_| std::env::var("GOOGLE_PROJECT_ID"))
.map_err(|_| {
EmbeddingError::Config(ConfigError::EnvVarNotFound(
"MAPROOM_GOOGLE_PROJECT_ID or GOOGLE_PROJECT_ID".to_string(),
))
})?;
let region =
std::env::var("GOOGLE_REGION").unwrap_or_else(|_| Self::DEFAULT_REGION.to_string());
let model =
std::env::var("GOOGLE_MODEL").unwrap_or_else(|_| Self::DEFAULT_MODEL.to_string());
Self::new_with_config(
project_id,
PathBuf::from(credentials_path),
region,
model,
parallel_config,
)
.await
}
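    /// Builds a provider via Application Default Credentials (ADC): a
    /// service-account key file, `gcloud auth application-default login`
    /// credentials, the GCE metadata server, or Workload Identity Federation.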
pub async fn from_adc(
project_id: String,
region: String,
model: String,
parallel_config: ParallelConfig,
) -> Result<Self, EmbeddingError> {
let token_provider = gcp_auth::provider().await.map_err(|e| {
EmbeddingError::Config(ConfigError::MissingConfig(format!(
"No Google credentials found. Tried Application Default Credentials (ADC) but failed: {}\n\
Configure credentials using one of:\n\
1. Set GOOGLE_APPLICATION_CREDENTIALS to your service account JSON key file path\n\
2. Run: gcloud auth application-default login\n\
3. Use GCE metadata server or Workload Identity Federation",
e
)))
})?;
let client = Client::builder()
.timeout(Duration::from_secs(Self::REQUEST_TIMEOUT_SECS))
.build()?;
let semaphore = Arc::new(Semaphore::new(parallel_config.max_concurrency));
Ok(Self {
client,
project_id,
region,
model,
task_type: TaskType::RetrievalDocument,
token_provider,
metrics: Arc::new(RwLock::new(ProviderMetrics::default())),
parallel_config,
semaphore,
})
}
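    /// Sets the task type used for subsequent requests, e.g.
    /// `TaskType::RetrievalQuery` when embedding search queries rather than
    /// documents to be indexed.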
pub fn with_task_type(&mut self, task_type: TaskType) -> &mut Self {
self.task_type = task_type;
self
}
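    /// Obtains an OAuth2 access token scoped to `cloud-platform`. The
    /// `gcp_auth` provider caches tokens internally, so requesting one per
    /// API call is inexpensive.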
async fn get_access_token(&self) -> Result<String, EmbeddingError> {
let scopes = &["https://www.googleapis.com/auth/cloud-platform"];
let token: Arc<Token> = self.token_provider.token(scopes).await.map_err(|e| {
EmbeddingError::Api(ApiError::Authentication(format!(
"Failed to obtain access token: {}. Ensure GOOGLE_APPLICATION_CREDENTIALS \
points to a valid service account key and the service account has \
roles/aiplatform.user role.",
e
)))
})?;
Ok(token.as_str().to_string())
}
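    /// Builds the regional Vertex AI predict endpoint for the configured
    /// project, region, and model.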
fn predict_url(&self) -> String {
format!(
"https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
self.region, self.project_id, self.region, self.model
)
}
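    /// Sends a predict request, retrying retryable failures (network errors
    /// and retryable API errors) with exponential backoff (1s, 2s, 4s).
    /// Metrics are updated on every attempt.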
async fn predict_with_retry(
&self,
instances: Vec<EmbeddingInstance>,
) -> Result<Vec<Vector>, EmbeddingError> {
let mut last_error = None;
for attempt in 0..Self::MAX_RETRIES {
match self.predict_request(instances.clone()).await {
Ok(embeddings) => {
let mut metrics = self.metrics.write().await;
metrics.total_requests += 1;
return Ok(embeddings);
}
Err(e) => {
{
let mut metrics = self.metrics.write().await;
metrics.total_requests += 1;
metrics.failed_requests += 1;
}
let should_retry = match &e {
EmbeddingError::Network(_) => true,
EmbeddingError::Api(api_err) => api_err.is_retryable(),
_ => false,
};
if !should_retry || attempt == Self::MAX_RETRIES - 1 {
return Err(e);
}
last_error = Some(e);
let delay_ms = Self::BASE_RETRY_DELAY_MS * 2u64.pow(attempt);
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
}
}
Err(last_error
.unwrap_or_else(|| EmbeddingError::Other("All retry attempts failed".to_string())))
}
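    /// Performs a single predict call: fetches a token, posts the instances,
    /// maps HTTP error statuses to `ApiError`s, and validates that every
    /// returned vector has the expected dimension.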
async fn predict_request(
&self,
instances: Vec<EmbeddingInstance>,
) -> Result<Vec<Vector>, EmbeddingError> {
let access_token = self.get_access_token().await?;
let request_body = PredictRequest { instances };
let timeout = if request_body.instances.len() > 1 {
Duration::from_secs(Self::BATCH_TIMEOUT_SECS)
} else {
Duration::from_secs(Self::REQUEST_TIMEOUT_SECS)
};
let response = self
.client
.post(self.predict_url())
.header("Authorization", format!("Bearer {}", access_token))
.header("Content-Type", "application/json")
.timeout(timeout)
.json(&request_body)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(EmbeddingError::Api(match status.as_u16() {
401 => ApiError::Authentication(format!(
"Invalid credentials or expired token. Ensure service account has roles/aiplatform.user role. Error: {}",
error_text
)),
403 => ApiError::Authentication(format!(
"Insufficient IAM permissions. Service account needs roles/aiplatform.user role. Error: {}",
error_text
)),
                429 => {
                    let retry_after_ms = 1000;
                    ApiError::RateLimit { retry_after_ms }
                }
503 => ApiError::ServerError {
status: 503,
message: format!("Service temporarily unavailable: {}", error_text),
},
500..=599 => ApiError::ServerError {
status: status.as_u16(),
message: error_text,
},
400 => ApiError::BadRequest(error_text),
_ => ApiError::InvalidResponse(format!("HTTP {}: {}", status, error_text)),
}));
}
let response_body: PredictResponse = response.json().await?;
let embeddings: Vec<Vector> = response_body
.predictions
.into_iter()
.map(|pred| pred.embeddings.values)
.collect();
let expected_dim = self.dimension();
for embedding in embeddings.iter() {
if embedding.len() != expected_dim {
use crate::embedding::error::DimensionMismatchError;
return Err(EmbeddingError::DimensionMismatch(
DimensionMismatchError::new(
expected_dim,
embedding.len(),
"Google".to_string(),
self.model.clone(),
expected_dim,
),
));
}
}
Ok(embeddings)
}
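    /// Embeds up to `MAX_BATCH_SIZE` texts in one request, truncating each
    /// text to `MAX_TOKENS_PER_TEXT` tokens beforehand.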
async fn embed_batch_raw(&self, texts: Vec<String>) -> Result<Vec<Vector>, EmbeddingError> {
if texts.is_empty() {
return Ok(Vec::new());
}
if texts.len() > Self::MAX_BATCH_SIZE {
return Err(EmbeddingError::InvalidInput(format!(
"Batch size {} exceeds maximum of {}",
texts.len(),
Self::MAX_BATCH_SIZE
)));
}
let token_counter = TokenCounter::new();
let instances: Vec<EmbeddingInstance> = texts
.into_iter()
.map(|content| {
let truncated =
token_counter.truncate_to_limit(&content, Self::MAX_TOKENS_PER_TEXT);
if truncated.len() < content.len() {
tracing::warn!(
"Truncated embedding text from {} to {} chars (max {} tokens)",
content.len(),
truncated.len(),
Self::MAX_TOKENS_PER_TEXT
);
}
EmbeddingInstance {
content: truncated,
task_type: self.task_type.as_str(),
}
})
.collect();
self.predict_with_retry(instances).await
}
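    /// Splits a large batch into sub-batches and embeds them concurrently,
    /// bounded by the semaphore. Results carry their sub-batch index and are
    /// re-assembled in input order before merging.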
async fn embed_batch_parallel(
&self,
texts: Vec<String>,
) -> Result<Vec<Vector>, EmbeddingError> {
let total_texts = texts.len();
let sub_batch_size = self
.parallel_config
.sub_batch_size
.min(Self::MAX_BATCH_SIZE);
let sub_batches: Vec<Vec<String>> = texts
.chunks(sub_batch_size)
.map(|chunk| chunk.to_vec())
.collect();
let num_batches = sub_batches.len();
tracing::info!(
"Parallel batch embedding: {} texts in {} sub-batches (size: {}, concurrency: {})",
total_texts,
num_batches,
sub_batch_size,
self.parallel_config.max_concurrency
);
let start = std::time::Instant::now();
let handles: Vec<_> = sub_batches
.into_iter()
.enumerate()
.map(|(idx, batch)| {
let semaphore = self.semaphore.clone();
let this = self.clone();
let batch_size = batch.len();
tokio::spawn(async move {
let _permit = semaphore.acquire().await.unwrap();
let batch_start = std::time::Instant::now();
tracing::debug!("Starting sub-batch {} ({} texts)", idx, batch_size);
let result = this.embed_batch_raw(batch).await;
let elapsed = batch_start.elapsed();
tracing::debug!(
"Sub-batch {} completed in {:.2}s ({} texts)",
idx,
elapsed.as_secs_f64(),
batch_size
);
(idx, result)
})
})
.collect();
let mut results: Vec<(usize, Result<Vec<Vector>, EmbeddingError>)> = Vec::new();
for handle in handles {
            let (idx, result) = handle.await.map_err(|e| {
                EmbeddingError::Other(format!("Embedding task panicked or was cancelled: {}", e))
            })?;
results.push((idx, result));
}
results.sort_by_key(|(idx, _)| *idx);
let mut embeddings = Vec::with_capacity(total_texts);
for (idx, result) in results {
let batch_embeddings = result.map_err(|e| {
EmbeddingError::Api(ApiError::InvalidResponse(format!(
"Sub-batch {} failed: {}",
idx, e
)))
})?;
embeddings.extend(batch_embeddings);
}
let elapsed = start.elapsed();
let throughput = total_texts as f64 / elapsed.as_secs_f64();
tracing::info!(
"Parallel batch completed in {:.2}s ({:.1} texts/sec)",
elapsed.as_secs_f64(),
throughput
);
Ok(embeddings)
}
}
#[async_trait]
impl EmbeddingProvider for GoogleProvider {
async fn embed(&self, text: String) -> Result<Vector, EmbeddingError> {
let instances = vec![EmbeddingInstance {
content: text,
task_type: self.task_type.as_str(),
}];
        let mut embeddings = self.predict_with_retry(instances).await?;
        if embeddings.is_empty() {
            return Err(EmbeddingError::Api(ApiError::InvalidResponse(
                "API returned no embeddings".to_string(),
            )));
        }
        Ok(embeddings.remove(0))
}
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vector>, EmbeddingError> {
if texts.is_empty() {
return Ok(Vec::new());
}
if self.parallel_config.enabled && texts.len() > self.parallel_config.sub_batch_size {
self.embed_batch_parallel(texts).await
} else {
self.embed_batch_raw(texts).await
}
}
    fn dimension(&self) -> usize {
        // text-embedding-004 returns 768-dimensional vectors.
        768
    }
fn provider_name(&self) -> &'static str {
"google"
}
fn metrics(&self) -> Option<ProviderMetrics> {
self.metrics.try_read().ok().map(|m| m.clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_task_type_as_str() {
assert_eq!(TaskType::RetrievalDocument.as_str(), "RETRIEVAL_DOCUMENT");
assert_eq!(TaskType::RetrievalQuery.as_str(), "RETRIEVAL_QUERY");
assert_eq!(TaskType::SemanticSimilarity.as_str(), "SEMANTIC_SIMILARITY");
}
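    // A small sanity sketch mirroring the backoff computation in
    // `predict_with_retry`: the delay doubles per attempt starting from
    // BASE_RETRY_DELAY_MS.
    #[test]
    fn test_retry_backoff_schedule() {
        let delays: Vec<u64> = (0..GoogleProvider::MAX_RETRIES)
            .map(|attempt| GoogleProvider::BASE_RETRY_DELAY_MS * 2u64.pow(attempt))
            .collect();
        assert_eq!(delays, vec![1000, 2000, 4000]);
    }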
    // Note: this test only exercises URL formatting, so it constructs the URL
    // directly and avoids touching process-global credentials state.
    #[test]
    fn test_predict_url_construction() {
let project_id = "my-project";
let region = "us-central1";
let model = "text-embedding-004";
let url = format!(
"https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
region, project_id, region, model
);
assert!(url.contains("us-central1-aiplatform.googleapis.com"));
assert!(url.contains("my-project"));
assert!(url.contains("text-embedding-004"));
assert!(url.contains(":predict"));
}
#[test]
fn test_embedding_instance_serialization() {
let instance = EmbeddingInstance {
content: "test text".to_string(),
task_type: "RETRIEVAL_DOCUMENT",
};
let json = serde_json::to_string(&instance).unwrap();
assert!(json.contains("test text"));
assert!(json.contains("RETRIEVAL_DOCUMENT"));
}
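    // A hedged sketch checking that the full request body serializes to the
    // `{"instances": [{"content": ..., "task_type": ...}]}` shape this module
    // sends to the predict endpoint.
    #[test]
    fn test_predict_request_serialization() {
        let request = PredictRequest {
            instances: vec![EmbeddingInstance {
                content: "hello".to_string(),
                task_type: TaskType::RetrievalQuery.as_str(),
            }],
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains(r#""instances""#));
        assert!(json.contains(r#""content":"hello""#));
        assert!(json.contains(r#""task_type":"RETRIEVAL_QUERY""#));
    }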
#[test]
fn test_predict_response_deserialization() {
let json = r#"{
"predictions": [
{
"embeddings": {
"values": [0.1, 0.2, 0.3]
}
}
]
}"#;
let response: PredictResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.predictions.len(), 1);
assert_eq!(response.predictions[0].embeddings.values.len(), 3);
assert_eq!(response.predictions[0].embeddings.values[0], 0.1);
}
    #[test]
    fn test_default_constants() {
        assert_eq!(GoogleProvider::DEFAULT_MODEL, "text-embedding-004");
        assert_eq!(GoogleProvider::DEFAULT_REGION, "us-central1");
    }
#[test]
fn test_max_batch_size_constant() {
assert_eq!(GoogleProvider::MAX_BATCH_SIZE, 250);
}
#[test]
fn test_google_sub_batch_splitting_exact_boundary() {
let texts: Vec<String> = (0..200).map(|i| format!("text_{}", i)).collect();
let sub_batch_size = 200usize.min(GoogleProvider::MAX_BATCH_SIZE);
let sub_batches: Vec<Vec<String>> = texts
.chunks(sub_batch_size)
.map(|chunk| chunk.to_vec())
.collect();
assert_eq!(
sub_batches.len(),
1,
"200 texts with sub_batch_size=200 should produce 1 sub-batch"
);
assert_eq!(
sub_batches[0].len(),
200,
"Single sub-batch should contain all 200 texts"
);
}
#[test]
fn test_google_sub_batch_splitting_uneven() {
let texts: Vec<String> = (0..450).map(|i| format!("text_{}", i)).collect();
let sub_batch_size = 200usize.min(GoogleProvider::MAX_BATCH_SIZE);
let sub_batches: Vec<Vec<String>> = texts
.chunks(sub_batch_size)
.map(|chunk| chunk.to_vec())
.collect();
assert_eq!(
sub_batches.len(),
3,
"450 texts with sub_batch_size=200 should produce 3 sub-batches"
);
assert_eq!(
sub_batches[0].len(),
200,
"First sub-batch should have 200 texts"
);
assert_eq!(
sub_batches[1].len(),
200,
"Second sub-batch should have 200 texts"
);
assert_eq!(
sub_batches[2].len(),
50,
"Third sub-batch should have remaining 50 texts"
);
}
#[test]
fn test_google_sub_batch_splitting_respects_api_limit() {
let texts: Vec<String> = (0..600).map(|i| format!("text_{}", i)).collect();
let configured_sub_batch_size = 300;
let sub_batch_size = configured_sub_batch_size.min(GoogleProvider::MAX_BATCH_SIZE);
assert_eq!(
sub_batch_size, 250,
"sub_batch_size should be capped at MAX_BATCH_SIZE (250)"
);
let sub_batches: Vec<Vec<String>> = texts
.chunks(sub_batch_size)
.map(|chunk| chunk.to_vec())
.collect();
assert_eq!(
sub_batches.len(),
3,
"600 texts with capped sub_batch_size=250 should produce 3 sub-batches"
);
assert_eq!(
sub_batches[0].len(),
250,
"First sub-batch should have MAX_BATCH_SIZE texts"
);
assert_eq!(
sub_batches[1].len(),
250,
"Second sub-batch should have MAX_BATCH_SIZE texts"
);
assert_eq!(
sub_batches[2].len(),
100,
"Third sub-batch should have remaining 100 texts"
);
for (i, batch) in sub_batches.iter().enumerate() {
assert!(
batch.len() <= GoogleProvider::MAX_BATCH_SIZE,
"Sub-batch {} has {} texts, exceeds MAX_BATCH_SIZE ({})",
i,
batch.len(),
GoogleProvider::MAX_BATCH_SIZE
);
}
}
#[test]
fn test_google_result_merge_ordering_in_order() {
let results: Vec<(usize, Vec<Vec<f32>>)> = vec![
(0, vec![vec![0.0_f32; 768]]),
(1, vec![vec![1.0_f32; 768]]),
(2, vec![vec![2.0_f32; 768]]),
];
let mut sorted_results = results.clone();
sorted_results.sort_by_key(|(idx, _)| *idx);
let embeddings: Vec<Vec<f32>> = sorted_results
.into_iter()
.flat_map(|(_, batch)| batch)
.collect();
assert_eq!(
embeddings.len(),
3,
"Should have 3 embeddings after flattening"
);
assert_eq!(
embeddings[0][0], 0.0,
"First embedding should be from batch 0"
);
assert_eq!(
embeddings[1][0], 1.0,
"Second embedding should be from batch 1"
);
assert_eq!(
embeddings[2][0], 2.0,
"Third embedding should be from batch 2"
);
}
#[test]
fn test_google_result_merge_ordering_out_of_order() {
        let results: Vec<(usize, Vec<Vec<f32>>)> = vec![
            (2, vec![vec![2.0_f32; 768]]),
            (0, vec![vec![0.0_f32; 768]]),
            (1, vec![vec![1.0_f32; 768]]),
        ];
let mut sorted_results = results.clone();
sorted_results.sort_by_key(|(idx, _)| *idx);
assert_eq!(
sorted_results[0].0, 0,
"After sorting, first result should have index 0"
);
assert_eq!(
sorted_results[1].0, 1,
"After sorting, second result should have index 1"
);
assert_eq!(
sorted_results[2].0, 2,
"After sorting, third result should have index 2"
);
let embeddings: Vec<Vec<f32>> = sorted_results
.into_iter()
.flat_map(|(_, batch)| batch)
.collect();
assert_eq!(
embeddings.len(),
3,
"Should have 3 embeddings after flattening"
);
assert_eq!(
embeddings[0][0], 0.0,
"First embedding should be from batch 0 (order preserved)"
);
assert_eq!(
embeddings[1][0], 1.0,
"Second embedding should be from batch 1 (order preserved)"
);
assert_eq!(
embeddings[2][0], 2.0,
"Third embedding should be from batch 2 (order preserved)"
);
}
#[test]
fn test_google_result_merge_ordering_single_batch() {
let texts: Vec<String> = (0..100).map(|i| format!("text_{}", i)).collect();
let sub_batch_size = 200;
let sub_batches: Vec<Vec<String>> = texts
.chunks(sub_batch_size)
.map(|chunk| chunk.to_vec())
.collect();
assert_eq!(
sub_batches.len(),
1,
"100 texts with sub_batch_size=200 should be 1 batch"
);
let results: Vec<(usize, Vec<Vec<f32>>)> =
vec![(0, (0..100).map(|i| vec![i as f32; 768]).collect())];
let mut sorted_results = results.clone();
sorted_results.sort_by_key(|(idx, _)| *idx);
let embeddings: Vec<Vec<f32>> = sorted_results
.into_iter()
.flat_map(|(_, batch)| batch)
.collect();
assert_eq!(
embeddings.len(),
100,
"Should have 100 embeddings after flattening"
);
for (i, embedding) in embeddings.iter().enumerate() {
assert_eq!(
embedding[0], i as f32,
"Embedding at position {} should have value {} (order preserved)",
i, i
);
}
}
}