//! openmemory 0.1.1
//!
//! OpenMemory - Cognitive memory system for AI applications
//!
//! Google Gemini embedding provider
//!
//! Uses the Gemini API to generate embeddings with task-specific types.

use crate::core::config::Config;
use crate::core::error::{Error, Result};
use crate::core::types::{EmbeddingResult, Sector};
use crate::memory::embed::{resize_vector, EmbeddingProvider};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

/// Gemini embedding provider
/// Gemini embedding provider
pub struct GeminiProvider {
    // Shared reqwest client; reused across requests for connection pooling.
    client: Client,
    // API key for the Gemini API; may be empty when unconfigured — callers
    // check this before issuing requests.
    api_key: String,
    // Target embedding dimension; API output is resized to this length.
    dim: usize,
    // Gemini embedding model name (e.g. "text-embedding-004").
    model: String,
}

impl GeminiProvider {
    /// Create a new Gemini provider
    pub fn new(config: &Config) -> Self {
        Self {
            client: Client::new(),
            api_key: config.gemini_key.clone().unwrap_or_default(),
            dim: config.vec_dim,
            model: "text-embedding-004".to_string(),
        }
    }

    /// Get task type for sector (Gemini-specific)
    fn task_type_for_sector(sector: &Sector) -> &'static str {
        match sector {
            Sector::Episodic => "RETRIEVAL_DOCUMENT",
            Sector::Semantic => "SEMANTIC_SIMILARITY",
            Sector::Procedural => "RETRIEVAL_DOCUMENT",
            Sector::Emotional => "CLASSIFICATION",
            Sector::Reflective => "SEMANTIC_SIMILARITY",
        }
    }
}

// Request body for the Gemini `batchEmbedContents` endpoint: one
// `EmbedContentRequest` per text to embed.
#[derive(Serialize)]
struct BatchEmbedRequest {
    requests: Vec<EmbedContentRequest>,
}

// A single embedding request: the model path, the content to embed, and
// the task type hint ("taskType" in the API's camelCase JSON).
#[derive(Serialize)]
struct EmbedContentRequest {
    // Fully-qualified model path, e.g. "models/text-embedding-004".
    model: String,
    content: Content,
    #[serde(rename = "taskType")]
    task_type: String,
}

// Content wrapper matching the Gemini API shape: a list of parts.
#[derive(Serialize)]
struct Content {
    parts: Vec<Part>,
}

// A single text part of a content payload.
#[derive(Serialize)]
struct Part {
    text: String,
}

// Response body: one embedding per request, in request order.
#[derive(Deserialize)]
struct BatchEmbedResponse {
    embeddings: Vec<EmbeddingData>,
}

// Raw embedding vector as returned by the API.
#[derive(Deserialize)]
struct EmbeddingData {
    values: Vec<f32>,
}

#[async_trait]
impl EmbeddingProvider for GeminiProvider {
    async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult> {
        if self.api_key.is_empty() {
            return Err(Error::config("Gemini API key not configured"));
        }

        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:batchEmbedContents?key={}",
            self.model, self.api_key
        );

        let request = BatchEmbedRequest {
            requests: vec![EmbedContentRequest {
                model: format!("models/{}", self.model),
                content: Content {
                    parts: vec![Part {
                        text: text.to_string(),
                    }],
                },
                task_type: Self::task_type_for_sector(sector).to_string(),
            }],
        };

        let response = self
            .client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();

            if status.as_u16() == 429 {
                return Err(Error::RateLimit {
                    retry_after_secs: 2,
                });
            }

            return Err(Error::embedding(format!(
                "Gemini API error {}: {}",
                status, body
            )));
        }

        let data: BatchEmbedResponse = response.json().await?;

        let vector = data
            .embeddings
            .first()
            .map(|e| resize_vector(&e.values, self.dim))
            .unwrap_or_else(|| vec![0.0; self.dim]);

        Ok(EmbeddingResult {
            sector: *sector,
            vector: vector.clone(),
            dim: vector.len(),
        })
    }

    async fn embed_batch(&self, texts: &[(&str, &Sector)]) -> Result<Vec<EmbeddingResult>> {
        if self.api_key.is_empty() {
            return Err(Error::config("Gemini API key not configured"));
        }

        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:batchEmbedContents?key={}",
            self.model, self.api_key
        );

        let requests: Vec<EmbedContentRequest> = texts
            .iter()
            .map(|(text, sector)| EmbedContentRequest {
                model: format!("models/{}", self.model),
                content: Content {
                    parts: vec![Part {
                        text: text.to_string(),
                    }],
                },
                task_type: Self::task_type_for_sector(sector).to_string(),
            })
            .collect();

        let request = BatchEmbedRequest { requests };

        // Retry logic with exponential backoff
        let max_retries = 3;
        let mut last_error = None;

        for attempt in 0..max_retries {
            let response = self
                .client
                .post(&url)
                .header("Content-Type", "application/json")
                .json(&request)
                .send()
                .await?;

            if response.status().is_success() {
                let data: BatchEmbedResponse = response.json().await?;

                let results: Vec<EmbeddingResult> = data
                    .embeddings
                    .into_iter()
                    .zip(texts.iter())
                    .map(|(emb, (_, sector))| {
                        let vector = resize_vector(&emb.values, self.dim);
                        EmbeddingResult {
                            sector: **sector,
                            vector: vector.clone(),
                            dim: vector.len(),
                        }
                    })
                    .collect();

                return Ok(results);
            }

            let status = response.status();

            if status.as_u16() == 429 {
                // Rate limited - wait and retry
                let delay = std::time::Duration::from_millis(1000 * 2_u64.pow(attempt as u32));
                tokio::time::sleep(delay).await;
                continue;
            }

            let body = response.text().await.unwrap_or_default();
            last_error = Some(Error::embedding(format!(
                "Gemini API error {}: {}",
                status, body
            )));
            break;
        }

        Err(last_error.unwrap_or_else(|| Error::embedding("Gemini API failed after retries")))
    }

    fn dimensions(&self) -> usize {
        self.dim
    }

    fn name(&self) -> &'static str {
        "gemini"
    }

    fn supports_batch(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A default-config provider reports its identity and batch support.
    #[test]
    fn test_provider_creation() {
        let provider = GeminiProvider::new(&Config::default());

        assert!(provider.supports_batch());
        assert_eq!(provider.name(), "gemini");
    }

    /// Sectors map to the expected Gemini task types.
    #[test]
    fn test_task_type_mapping() {
        let expectations = [
            (Sector::Episodic, "RETRIEVAL_DOCUMENT"),
            (Sector::Semantic, "SEMANTIC_SIMILARITY"),
            (Sector::Emotional, "CLASSIFICATION"),
        ];

        for (sector, expected) in expectations {
            assert_eq!(GeminiProvider::task_type_for_sector(&sector), expected);
        }
    }
}