//! Ollama embedding provider
//!
//! Uses a local Ollama server to generate vector embeddings.
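//!
//! Illustrative usage (a sketch: module paths are elided, and it assumes a
//! running Ollama daemon with the `nomic-embed-text` model pulled at the
//! default `http://localhost:11434`):
//!
//! ```ignore
//! let config = Config::default();
//! let provider = OllamaProvider::new(&config);
//! // `embed` is async; await it from an async runtime.
//! let result = provider.embed("hello world", &Sector::Semantic).await?;
//! assert_eq!(result.dim, provider.dimensions());
//! ```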

use crate::core::config::Config;
use crate::core::error::{Error, Result};
use crate::core::types::{EmbeddingResult, Sector};
use crate::memory::embed::{resize_vector, EmbeddingProvider};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

/// Ollama embedding provider
pub struct OllamaProvider {
    /// Base URL of the Ollama server (e.g. `http://localhost:11434`)
    url: String,
    /// Target embedding dimension; raw vectors are resized to match
    dim: usize,
    /// Reused HTTP client, so connections are pooled across requests
    client: reqwest::Client,
}

impl OllamaProvider {
    /// Create a new Ollama provider
    pub fn new(config: &Config) -> Self {
        Self {
            url: config.ollama_url.clone(),
            dim: config.vec_dim,
            client: reqwest::Client::new(),
        }
    }

    /// Get the model name for a sector
    ///
    /// All sectors currently use the same embedding model; the match is
    /// kept exhaustive so adding a sector-specific model later is a
    /// one-arm change.
    fn model_for_sector(&self, sector: &Sector) -> &str {
        match sector {
            Sector::Episodic
            | Sector::Semantic
            | Sector::Procedural
            | Sector::Emotional
            | Sector::Reflective => "nomic-embed-text",
        }
    }
}

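// Wire format of Ollama's legacy `/api/embeddings` endpoint: one model/prompt
// pair in, one vector out, e.g.
//   request:  {"model": "nomic-embed-text", "prompt": "..."}
//   response: {"embedding": [0.12, -0.034, ...]}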
#[derive(Serialize)]
struct EmbeddingRequest {
    model: String,
    prompt: String,
}

#[derive(Deserialize)]
struct EmbeddingResponse {
    embedding: Vec<f32>,
}

#[async_trait]
impl EmbeddingProvider for OllamaProvider {
    async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult> {
        let url = format!("{}/api/embeddings", self.url.trim_end_matches('/'));
        let model = self.model_for_sector(sector);

        let request = EmbeddingRequest {
            model: model.to_string(),
            prompt: text.to_string(),
        };

        let response = self
            .client
            .post(&url)
            // `.json()` serializes the body and sets Content-Type: application/json
            .json(&request)
            .send()
            .await
            .map_err(|e| Error::embedding(format!("Ollama connection failed: {}", e)))?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(Error::embedding(format!(
                "Ollama API error {}: {}",
                status, body
            )));
        }

        let data: EmbeddingResponse = response.json().await.map_err(|e| {
            Error::embedding(format!("Failed to parse Ollama response: {}", e))
        })?;

        // Resize the raw vector to the configured dimension so all providers
        // produce uniformly sized embeddings.
        let vector = resize_vector(&data.embedding, self.dim);
        let dim = vector.len();

        Ok(EmbeddingResult {
            sector: *sector,
            vector,
            dim,
        })
    }

    fn dimensions(&self) -> usize {
        self.dim
    }

    fn name(&self) -> &'static str {
        "ollama"
    }

    fn supports_batch(&self) -> bool {
        false // /api/embeddings accepts a single prompt per request
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_creation() {
        let config = Config::default();
        let provider = OllamaProvider::new(&config);

        assert_eq!(provider.name(), "ollama");
        assert!(!provider.supports_batch());
        assert_eq!(provider.url, "http://localhost:11434");
    }

    #[test]
    fn test_model_for_sector() {
        let config = Config::default();
        let provider = OllamaProvider::new(&config);

        assert_eq!(provider.model_for_sector(&Sector::Semantic), "nomic-embed-text");
        assert_eq!(provider.model_for_sector(&Sector::Episodic), "nomic-embed-text");
    }
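
    /// Live round-trip against a local Ollama server. Ignored by default
    /// because it needs the daemon running with `nomic-embed-text` pulled;
    /// run with `cargo test -- --ignored`. Assumes `tokio` (with the
    /// `macros` feature) is available as a dev-dependency.
    #[tokio::test]
    #[ignore]
    async fn test_embed_live() {
        let config = Config::default();
        let provider = OllamaProvider::new(&config);

        let result = provider
            .embed("hello world", &Sector::Semantic)
            .await
            .expect("requires a running Ollama server");

        assert_eq!(result.dim, provider.dimensions());
        assert_eq!(result.vector.len(), result.dim);
    }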
}