Skip to main content

synaptic_embeddings/
ollama.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::json;
5use synaptic_core::SynapticError;
6use synaptic_models::backend::{ProviderBackend, ProviderRequest};
7
8use crate::Embeddings;
9
/// Configuration for [`OllamaEmbeddings`]: which model to request and where the
/// Ollama server is reachable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OllamaEmbeddingsConfig {
    /// Model name, sent as the `model` field of every embedding request.
    pub model: String,
    /// Base URL of the Ollama server; the `/api/embed` path is appended to it.
    pub base_url: String,
}

impl OllamaEmbeddingsConfig {
    /// Base URL used by [`OllamaEmbeddingsConfig::new`].
    pub const DEFAULT_BASE_URL: &'static str = "http://localhost:11434";

    /// Creates a config for `model` that targets [`Self::DEFAULT_BASE_URL`].
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            base_url: Self::DEFAULT_BASE_URL.to_string(),
        }
    }

    /// Returns this config with `base_url` replaced, leaving `model` unchanged.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }
}
28
29pub struct OllamaEmbeddings {
30    config: OllamaEmbeddingsConfig,
31    backend: Arc<dyn ProviderBackend>,
32}
33
34impl OllamaEmbeddings {
35    pub fn new(config: OllamaEmbeddingsConfig, backend: Arc<dyn ProviderBackend>) -> Self {
36        Self { config, backend }
37    }
38}
39
40#[async_trait]
41impl Embeddings for OllamaEmbeddings {
42    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
43        let mut all_embeddings = Vec::with_capacity(texts.len());
44        for text in texts {
45            let embedding = self.embed_query(text).await?;
46            all_embeddings.push(embedding);
47        }
48        Ok(all_embeddings)
49    }
50
51    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
52        let request = ProviderRequest {
53            url: format!("{}/api/embed", self.config.base_url),
54            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
55            body: json!({
56                "model": self.config.model,
57                "input": text,
58            }),
59        };
60
61        let response = self.backend.send(request).await?;
62
63        if response.status != 200 {
64            return Err(SynapticError::Embedding(format!(
65                "Ollama API error ({}): {}",
66                response.status, response.body
67            )));
68        }
69
70        let embeddings = response
71            .body
72            .get("embeddings")
73            .and_then(|e| e.as_array())
74            .and_then(|arr| arr.first())
75            .and_then(|e| e.as_array())
76            .ok_or_else(|| SynapticError::Embedding("missing 'embeddings' field".to_string()))?;
77
78        Ok(embeddings
79            .iter()
80            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
81            .collect())
82    }
83}