use async_trait::async_trait;
use tracing::{debug, error};
use adk_gemini::{EmbedBuilder, Gemini, Model, TaskType};
use crate::embedding::EmbeddingProvider;
use crate::error::{RagError, Result};
/// Embedding provider that obtains vectors through a [`Gemini`] client.
///
/// Each request is built from the stored client with the configured task type
/// and, when one has been set, an explicit output dimensionality.
pub struct GeminiEmbeddingProvider {
// Client every embedding request builder is created from.
client: Gemini,
// Task type attached to every embedding request.
task_type: TaskType,
// Optional dimensionality override; forwarded to the request builder only when `Some`.
output_dimensionality: Option<i32>,
// Value returned by `dimensions()`; starts at `DEFAULT_DIMENSIONS` and is
// overwritten by `with_output_dimensionality`.
dimensions: usize,
}
impl GeminiEmbeddingProvider {
const DEFAULT_DIMENSIONS: usize = 3072;
pub fn new(api_key: impl AsRef<str>) -> Result<Self> {
let client = Gemini::with_model(api_key, Model::GeminiEmbedding001).map_err(|e| {
RagError::EmbeddingError {
provider: "Gemini".into(),
message: format!("failed to create Gemini client: {e}"),
}
})?;
Ok(Self {
client,
task_type: TaskType::RetrievalDocument,
output_dimensionality: None,
dimensions: Self::DEFAULT_DIMENSIONS,
})
}
pub fn from_client(client: Gemini) -> Self {
Self {
client,
task_type: TaskType::RetrievalDocument,
output_dimensionality: None,
dimensions: Self::DEFAULT_DIMENSIONS,
}
}
pub fn with_task_type(mut self, task_type: TaskType) -> Self {
self.task_type = task_type;
self
}
pub fn with_output_dimensionality(mut self, dims: i32) -> Self {
self.output_dimensionality = Some(dims);
self.dimensions = dims as usize;
self
}
fn embed_builder(&self) -> EmbedBuilder {
let mut builder = self.client.embed_content().with_task_type(self.task_type.clone());
if let Some(dims) = self.output_dimensionality {
builder = builder.with_output_dimensionality(dims);
}
builder
}
}
#[async_trait]
impl EmbeddingProvider for GeminiEmbeddingProvider {
async fn embed(&self, text: &str) -> Result<Vec<f32>> {
debug!(provider = "Gemini", text_len = text.len(), "embedding single text");
let response = self.embed_builder().with_text(text).execute().await.map_err(|e| {
error!(provider = "Gemini", error = %e, "embedding request failed");
RagError::EmbeddingError { provider: "Gemini".into(), message: format!("{e}") }
})?;
Ok(response.embedding.values)
}
async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
debug!(provider = "Gemini", batch_size = texts.len(), "embedding batch");
let response = self
.embed_builder()
.with_chunks(texts.iter().map(|t| t.to_string()).collect())
.execute_batch()
.await
.map_err(|e| {
error!(provider = "Gemini", error = %e, "batch embedding request failed");
RagError::EmbeddingError { provider: "Gemini".into(), message: format!("{e}") }
})?;
Ok(response.embeddings.into_iter().map(|e| e.values).collect())
}
fn dimensions(&self) -> usize {
self.dimensions
}
}