use crate::core::config::Config;
use crate::core::error::{Error, Result};
use crate::core::types::{EmbeddingResult, Sector};
use crate::memory::embed::{resize_vector, EmbeddingProvider};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// Embedding provider that talks to the Gemini `batchEmbedContents` HTTP endpoint.
pub struct GeminiProvider {
    /// HTTP client used for every request issued by this provider.
    client: Client,
    /// Gemini API key; an empty string means "not configured" and makes the
    /// embedding methods fail with a configuration error.
    api_key: String,
    /// Length every returned vector is resized to.
    dim: usize,
    /// Model identifier interpolated into the endpoint URL and request body.
    model: String,
}
impl GeminiProvider {
    /// Builds a provider from `config`.
    ///
    /// A missing `gemini_key` is stored as an empty string rather than rejected
    /// here; the embedding methods check for it before sending a request.
    pub fn new(config: &Config) -> Self {
        let client = Client::new();
        let api_key = config.gemini_key.clone().unwrap_or_default();
        let model = String::from("text-embedding-004");

        Self {
            client,
            api_key,
            dim: config.vec_dim,
            model,
        }
    }

    /// Maps a memory sector to the `taskType` string sent with its embedding request.
    ///
    /// The match is intentionally exhaustive (no wildcard arm) so that adding a
    /// `Sector` variant forces a decision here at compile time.
    fn task_type_for_sector(sector: &Sector) -> &'static str {
        match sector {
            Sector::Episodic | Sector::Procedural => "RETRIEVAL_DOCUMENT",
            Sector::Semantic | Sector::Reflective => "SEMANTIC_SIMILARITY",
            Sector::Emotional => "CLASSIFICATION",
        }
    }
}
/// JSON body of a `batchEmbedContents` call: one entry per text to embed.
#[derive(Serialize)]
struct BatchEmbedRequest {
    /// Individual embedding requests, in the order the inputs were supplied.
    requests: Vec<EmbedContentRequest>,
}
/// A single embedding request inside a batch.
///
/// Field names are serialized in camelCase, so `task_type` goes over the wire
/// as `taskType` while `model` and `content` keep their names.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct EmbedContentRequest {
    /// Model resource name, written as `models/<model>` by the provider.
    model: String,
    /// The text to embed, wrapped in the API's content/parts structure.
    content: Content,
    /// Task type string chosen from the memory sector.
    task_type: String,
}
/// Content wrapper of an embedding request; the provider always sends exactly one part.
#[derive(Serialize)]
struct Content {
    /// Parts making up the content.
    parts: Vec<Part>,
}
/// One text part of a [`Content`] value.
#[derive(Serialize)]
struct Part {
    /// Raw text to embed.
    text: String,
}
/// Successful response body of a `batchEmbedContents` call.
#[derive(Deserialize)]
struct BatchEmbedResponse {
    /// Embeddings returned by the API; the provider pairs them with the inputs
    /// by position.
    embeddings: Vec<EmbeddingData>,
}
/// A single embedding as returned by the API.
#[derive(Deserialize)]
struct EmbeddingData {
    /// Raw vector components, before being resized to the configured dimension.
    values: Vec<f32>,
}
#[async_trait]
impl EmbeddingProvider for GeminiProvider {
async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult> {
if self.api_key.is_empty() {
return Err(Error::config("Gemini API key not configured"));
}
let url = format!(
"https://generativelanguage.googleapis.com/v1beta/models/{}:batchEmbedContents?key={}",
self.model, self.api_key
);
let request = BatchEmbedRequest {
requests: vec![EmbedContentRequest {
model: format!("models/{}", self.model),
content: Content {
parts: vec![Part {
text: text.to_string(),
}],
},
task_type: Self::task_type_for_sector(sector).to_string(),
}],
};
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
if status.as_u16() == 429 {
return Err(Error::RateLimit {
retry_after_secs: 2,
});
}
return Err(Error::embedding(format!(
"Gemini API error {}: {}",
status, body
)));
}
let data: BatchEmbedResponse = response.json().await?;
let vector = data
.embeddings
.first()
.map(|e| resize_vector(&e.values, self.dim))
.unwrap_or_else(|| vec![0.0; self.dim]);
Ok(EmbeddingResult {
sector: *sector,
vector: vector.clone(),
dim: vector.len(),
})
}
async fn embed_batch(&self, texts: &[(&str, &Sector)]) -> Result<Vec<EmbeddingResult>> {
if self.api_key.is_empty() {
return Err(Error::config("Gemini API key not configured"));
}
if texts.is_empty() {
return Ok(Vec::new());
}
let url = format!(
"https://generativelanguage.googleapis.com/v1beta/models/{}:batchEmbedContents?key={}",
self.model, self.api_key
);
let requests: Vec<EmbedContentRequest> = texts
.iter()
.map(|(text, sector)| EmbedContentRequest {
model: format!("models/{}", self.model),
content: Content {
parts: vec![Part {
text: text.to_string(),
}],
},
task_type: Self::task_type_for_sector(sector).to_string(),
})
.collect();
let request = BatchEmbedRequest { requests };
let max_retries = 3;
let mut last_error = None;
for attempt in 0..max_retries {
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.json(&request)
.send()
.await?;
if response.status().is_success() {
let data: BatchEmbedResponse = response.json().await?;
let results: Vec<EmbeddingResult> = data
.embeddings
.into_iter()
.zip(texts.iter())
.map(|(emb, (_, sector))| {
let vector = resize_vector(&emb.values, self.dim);
EmbeddingResult {
sector: **sector,
vector: vector.clone(),
dim: vector.len(),
}
})
.collect();
return Ok(results);
}
let status = response.status();
if status.as_u16() == 429 {
let delay = std::time::Duration::from_millis(1000 * 2_u64.pow(attempt as u32));
tokio::time::sleep(delay).await;
continue;
}
let body = response.text().await.unwrap_or_default();
last_error = Some(Error::embedding(format!(
"Gemini API error {}: {}",
status, body
)));
break;
}
Err(last_error.unwrap_or_else(|| Error::embedding("Gemini API failed after retries")))
}
fn dimensions(&self) -> usize {
self.dim
}
fn name(&self) -> &'static str {
"gemini"
}
fn supports_batch(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A provider built from the default config reports its static metadata and
    /// the configured vector dimension.
    #[test]
    fn test_provider_creation() {
        let config = Config::default();
        let provider = GeminiProvider::new(&config);
        assert_eq!(provider.name(), "gemini");
        assert!(provider.supports_batch());
        assert_eq!(provider.dimensions(), config.vec_dim);
    }

    /// Every sector maps to the expected Gemini task type. All five variants
    /// are listed so a changed mapping for any of them is caught.
    #[test]
    fn test_task_type_mapping() {
        let expected = [
            (Sector::Episodic, "RETRIEVAL_DOCUMENT"),
            (Sector::Semantic, "SEMANTIC_SIMILARITY"),
            (Sector::Procedural, "RETRIEVAL_DOCUMENT"),
            (Sector::Emotional, "CLASSIFICATION"),
            (Sector::Reflective, "SEMANTIC_SIMILARITY"),
        ];
        for (sector, task_type) in &expected {
            assert_eq!(GeminiProvider::task_type_for_sector(sector), *task_type);
        }
    }
}