use anyhow::{Context, Result};
use serde_json::{json, Value};
use super::super::types::InputType;
use super::{EmbeddingProvider, HTTP_CLIENT};
/// Embedding provider bound to one specific Google embedding model.
///
/// Instances are created through `GoogleProviderImpl::new`, which resolves the
/// vector dimension for the requested model up front.
#[derive(Debug, Clone)]
pub struct GoogleProviderImpl {
    /// Name of the Google model, passed along with every embedding request.
    model_name: String,
    /// Length of the embedding vectors associated with `model_name`.
    dimension: usize,
}
impl GoogleProviderImpl {
    /// Builds a provider for `model`.
    ///
    /// # Errors
    /// Fails when `model` is not one of the names recognised by
    /// `get_model_dimension`.
    pub fn new(model: &str) -> Result<Self> {
        Self::get_model_dimension(model).map(|dimension| Self {
            model_name: model.to_owned(),
            dimension,
        })
    }

    /// Maps a known model name to the length of its embedding vectors.
    ///
    /// # Errors
    /// Returns an error listing the supported models when `model` is unknown.
    fn get_model_dimension(model: &str) -> Result<usize> {
        let dimension = match model {
            "gemini-embedding-001" => 3072,
            "text-embedding-005" | "text-multilingual-embedding-002" => 768,
            _ => {
                return Err(anyhow::anyhow!(
                    "Unsupported Google model: '{}'. Supported models: gemini-embedding-001 (3072d), text-embedding-005 (768d), text-multilingual-embedding-002 (768d)",
                    model
                ));
            }
        };
        Ok(dimension)
    }
}
#[async_trait::async_trait]
impl EmbeddingProvider for GoogleProviderImpl {
    /// Embeds one text with this provider's model.
    ///
    /// The text is forwarded unchanged to `GoogleProvider::generate_embeddings`.
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let model = self.model_name.as_str();
        GoogleProvider::generate_embeddings(text, model).await
    }

    /// Embeds several texts with this provider's model.
    ///
    /// Every text is first passed through `input_type.apply_prefix`, then the
    /// whole list is handed to `GoogleProvider::generate_embeddings_batch`.
    async fn generate_embeddings_batch(
        &self,
        texts: Vec<String>,
        input_type: InputType,
    ) -> Result<Vec<Vec<f32>>> {
        let mut prefixed: Vec<String> = Vec::with_capacity(texts.len());
        for raw in texts {
            prefixed.push(input_type.apply_prefix(&raw));
        }
        GoogleProvider::generate_embeddings_batch(prefixed, &self.model_name).await
    }

    /// Returns the embedding dimension resolved when the provider was built.
    fn get_dimension(&self) -> usize {
        self.dimension
    }

    /// Reports whether the configured model name is one of the known models.
    fn is_model_supported(&self) -> bool {
        const KNOWN_MODELS: [&str; 3] = [
            "gemini-embedding-001",
            "text-embedding-005",
            "text-multilingual-embedding-002",
        ];
        KNOWN_MODELS.contains(&self.model_name.as_str())
    }
}
/// Field-less type that groups the Google embedding helper functions.
///
/// All of its functions are associated functions (none take `self`), so it is
/// used as `GoogleProvider::function(...)` without constructing a value.
pub struct GoogleProvider;
impl GoogleProvider {
/// Lists the model names this provider accepts.
pub fn get_supported_models() -> Vec<&'static str> {
    const MODELS: [&str; 3] = [
        "gemini-embedding-001",
        "text-embedding-005",
        "text-multilingual-embedding-002",
    ];
    MODELS.to_vec()
}
/// Embeds a single piece of text with `model` and returns its vector.
///
/// Delegates to `generate_embeddings_batch` with a one-element batch.
///
/// # Errors
/// Propagates any error from the batch call, and returns
/// "No embeddings found" if the batch call produced no vectors.
pub async fn generate_embeddings(contents: &str, model: &str) -> Result<Vec<f32>> {
    let batch = Self::generate_embeddings_batch(vec![contents.to_string()], model).await?;
    // The batch is owned and discarded afterwards, so move the first vector
    // out instead of cloning every float in it.
    batch
        .into_iter()
        .next()
        .ok_or_else(|| anyhow::anyhow!("No embeddings found"))
}
/// Embeds every entry of `texts` with the given Google `model`.
///
/// One `embedContent` request is sent per text, sequentially, and the
/// resulting vectors are returned in the same order as `texts`.
///
/// # Errors
/// Returns an error when:
/// - the `GOOGLE_API_KEY` environment variable is not set,
/// - a request cannot be sent or the response body is not valid JSON,
/// - the API answers with a non-success HTTP status,
/// - the response has no `embedding.values` array, or that array contains
///   a non-numeric entry.
pub async fn generate_embeddings_batch(
    texts: Vec<String>,
    model: &str,
) -> Result<Vec<Vec<f32>>> {
    let google_api_key = std::env::var("GOOGLE_API_KEY")
        .context("GOOGLE_API_KEY environment variable not set")?;

    // The key is sent as a header rather than a `?key=` query parameter so it
    // never becomes part of the URL, which can end up in error messages and logs.
    let url = format!(
        "https://generativelanguage.googleapis.com/v1beta/models/{}:embedContent",
        model
    );

    let mut all_embeddings = Vec::with_capacity(texts.len());
    for text in texts {
        let response = HTTP_CLIENT
            .post(url.as_str())
            .header("x-goog-api-key", google_api_key.as_str())
            .header("Content-Type", "application/json")
            .json(&json!({
                "content": {
                    "parts": [{
                        "text": text
                    }]
                }
            }))
            .send()
            .await
            .context("Failed to send request to Google embedding API")?;

        let status = response.status();
        if !status.is_success() {
            // Still report the status when the error body itself cannot be read.
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read error body: {}>", e));
            return Err(anyhow::anyhow!(
                "Google API error: HTTP {}: {}",
                status,
                error_text
            ));
        }

        let response_json: Value = response
            .json()
            .await
            .context("Failed to parse Google API response as JSON")?;

        // A non-numeric entry is treated as a malformed response instead of
        // being silently replaced by 0.0, which would corrupt the vector.
        let embedding = response_json["embedding"]["values"]
            .as_array()
            .context("Failed to get embedding values")?
            .iter()
            .map(|v| {
                v.as_f64()
                    // Narrowing to f32 is intentional: vectors are stored as f32.
                    .map(|n| n as f32)
                    .context("Embedding value is not a number")
            })
            .collect::<Result<Vec<f32>>>()?;
        all_embeddings.push(embedding);
    }
    Ok(all_embeddings)
}
}