synaptic_embeddings/
ollama.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::json;
5use synaptic_core::SynapseError;
6use synaptic_models::backend::{ProviderBackend, ProviderRequest};
7
8use crate::Embeddings;
9
/// Connection settings for the Ollama embeddings client.
///
/// Create one with [`OllamaEmbeddingsConfig::new`] and, if the server is not
/// reachable at the default address, override it with
/// [`OllamaEmbeddingsConfig::with_base_url`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OllamaEmbeddingsConfig {
    /// Model name placed in the `"model"` field of each request body.
    pub model: String,
    /// Base URL of the server; the endpoint path is appended to it.
    pub base_url: String,
}

impl OllamaEmbeddingsConfig {
    /// Base URL used when the caller does not supply one.
    const DEFAULT_BASE_URL: &'static str = "http://localhost:11434";

    /// Creates a configuration for `model` that targets the default base URL
    /// (`http://localhost:11434`).
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            base_url: Self::DEFAULT_BASE_URL.to_string(),
        }
    }

    /// Replaces the base URL and returns the updated configuration.
    ///
    /// The value is stored exactly as given; no normalization is applied here.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }
}
28
29pub struct OllamaEmbeddings {
30 config: OllamaEmbeddingsConfig,
31 backend: Arc<dyn ProviderBackend>,
32}
33
34impl OllamaEmbeddings {
35 pub fn new(config: OllamaEmbeddingsConfig, backend: Arc<dyn ProviderBackend>) -> Self {
36 Self { config, backend }
37 }
38}
39
40#[async_trait]
41impl Embeddings for OllamaEmbeddings {
42 async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapseError> {
43 let mut all_embeddings = Vec::with_capacity(texts.len());
44 for text in texts {
45 let embedding = self.embed_query(text).await?;
46 all_embeddings.push(embedding);
47 }
48 Ok(all_embeddings)
49 }
50
51 async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapseError> {
52 let request = ProviderRequest {
53 url: format!("{}/api/embed", self.config.base_url),
54 headers: vec![("Content-Type".to_string(), "application/json".to_string())],
55 body: json!({
56 "model": self.config.model,
57 "input": text,
58 }),
59 };
60
61 let response = self.backend.send(request).await?;
62
63 if response.status != 200 {
64 return Err(SynapseError::Embedding(format!(
65 "Ollama API error ({}): {}",
66 response.status, response.body
67 )));
68 }
69
70 let embeddings = response
71 .body
72 .get("embeddings")
73 .and_then(|e| e.as_array())
74 .and_then(|arr| arr.first())
75 .and_then(|e| e.as_array())
76 .ok_or_else(|| SynapseError::Embedding("missing 'embeddings' field".to_string()))?;
77
78 Ok(embeddings
79 .iter()
80 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
81 .collect())
82 }
83}