use crate::core::config::Config;
use crate::core::error::{Error, Result};
use crate::core::types::{EmbeddingResult, Sector};
use crate::memory::embed::{resize_vector, EmbeddingProvider};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Embedding provider backed by an Ollama server's `/api/embeddings` endpoint.
///
/// Derives `Debug` (useful in logs and test failures) and `Clone` (cheap: `reqwest::Client`
/// is a handle to a shared connection pool, so clones reuse it).
#[derive(Debug, Clone)]
pub struct OllamaProvider {
    /// Base URL of the Ollama server (the default config uses `http://localhost:11434`).
    /// A trailing `/` is tolerated when request URLs are built.
    url: String,
    /// Target dimensionality; every embedding returned by the server is resized to this length.
    dim: usize,
    /// HTTP client stored on the provider so it is reused across requests.
    client: reqwest::Client,
}
impl OllamaProvider {
    /// Builds a provider from the application configuration.
    ///
    /// Reads `config.ollama_url` as the server base URL and `config.vec_dim` as the
    /// dimensionality that returned vectors are resized to, and creates the HTTP client
    /// the provider keeps for its requests.
    pub fn new(config: &Config) -> Self {
        Self {
            url: config.ollama_url.clone(),
            dim: config.vec_dim,
            client: reqwest::Client::new(),
        }
    }

    /// Returns the Ollama model name used to embed text for `sector`.
    ///
    /// Every sector currently maps to `nomic-embed-text`. The match stays exhaustive
    /// (no `_` arm) so that adding a `Sector` variant forces a deliberate model choice here.
    /// The names are string literals, so the result is `&'static str` instead of a borrow
    /// tied to `self`.
    fn model_for_sector(&self, sector: &Sector) -> &'static str {
        match sector {
            Sector::Episodic
            | Sector::Semantic
            | Sector::Procedural
            | Sector::Emotional
            | Sector::Reflective => "nomic-embed-text",
        }
    }
}
/// JSON request body sent to Ollama's `/api/embeddings` endpoint.
///
/// Serialized as `{"model": ..., "prompt": ...}`; field order is kept as declared.
#[derive(Serialize)]
struct EmbeddingRequest {
    /// Name of the Ollama model that should produce the embedding.
    model: String,
    /// The text to embed.
    prompt: String,
}
/// JSON response body returned by Ollama's `/api/embeddings` endpoint.
///
/// Only the `embedding` field is read; any other fields in the response are ignored.
#[derive(Deserialize)]
struct EmbeddingResponse {
    /// Raw embedding vector as produced by the model (before resizing).
    embedding: Vec<f32>,
}
#[async_trait]
impl EmbeddingProvider for OllamaProvider {
    /// Embeds `text` for `sector` with a single POST to `{url}/api/embeddings`.
    ///
    /// The model is chosen by `model_for_sector`, and the server's vector is resized to
    /// the configured dimensionality with `resize_vector`.
    ///
    /// # Errors
    ///
    /// Returns `Error::embedding` when:
    /// - the request cannot be sent (connection failure);
    /// - the server answers with a non-success status (the message includes the status
    ///   and the response body, if it could be read);
    /// - the response body does not deserialize as `{"embedding": [...]}`.
    async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult> {
        // Tolerate a configured base URL that ends with '/'.
        let url = format!("{}/api/embeddings", self.url.trim_end_matches('/'));
        let request = EmbeddingRequest {
            model: self.model_for_sector(sector).to_string(),
            prompt: text.to_string(),
        };

        // `.json()` serializes the body and sets `Content-Type: application/json` itself.
        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| Error::embedding(format!("Ollama connection failed: {}", e)))?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: include the body for diagnostics, but never let a failure to
            // read it mask the status error itself.
            let body = response.text().await.unwrap_or_default();
            return Err(Error::embedding(format!(
                "Ollama API error {}: {}",
                status, body
            )));
        }

        let data: EmbeddingResponse = response.json().await.map_err(|e| {
            Error::embedding(format!("Failed to parse Ollama response: {}", e))
        })?;

        // NOTE(review): how `resize_vector` treats an empty `embedding` array is not visible
        // here — confirm that an empty server response cannot become a silent zero vector.
        let vector = resize_vector(&data.embedding, self.dim);
        // Read the length before moving the vector into the result; no clone needed.
        let dim = vector.len();
        Ok(EmbeddingResult {
            sector: *sector,
            vector,
            dim,
        })
    }

    /// Configured output dimensionality (taken from `config.vec_dim` at construction).
    fn dimensions(&self) -> usize {
        self.dim
    }

    /// Stable identifier for this provider.
    fn name(&self) -> &'static str {
        "ollama"
    }

    /// Texts are embedded one request at a time; no batch endpoint is used.
    fn supports_batch(&self) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default config yields an "ollama" provider without batch support, pointed at
    /// the default URL and reporting the configured dimensionality.
    #[test]
    fn test_provider_creation() {
        let config = Config::default();
        let provider = OllamaProvider::new(&config);
        assert_eq!(provider.name(), "ollama");
        assert!(!provider.supports_batch());
        assert_eq!(provider.url, "http://localhost:11434");
        assert_eq!(provider.dimensions(), config.vec_dim);
    }

    /// Every sector — not just a sample — is mapped to `nomic-embed-text`.
    #[test]
    fn test_model_for_sector() {
        let config = Config::default();
        let provider = OllamaProvider::new(&config);
        let sectors = [
            Sector::Episodic,
            Sector::Semantic,
            Sector::Procedural,
            Sector::Emotional,
            Sector::Reflective,
        ];
        for sector in &sectors {
            assert_eq!(provider.model_for_sector(sector), "nomic-embed-text");
        }
    }
}