use async_trait::async_trait;
use reqwest::Client;
use serde_json::{json, Value};
use cognis_core::embeddings::Embeddings;
use cognis_core::error::{CognisError, Result};
/// Builder for [`OllamaEmbeddings`].
///
/// Both settings are optional. Anything left unset is replaced by a default
/// value when the builder is turned into an [`OllamaEmbeddings`].
#[derive(Debug)]
pub struct OllamaEmbeddingsBuilder {
    /// Embedding model name, if one was explicitly configured.
    model: Option<String>,
    /// Base URL of the Ollama server, if one was explicitly configured.
    base_url: Option<String>,
}
impl OllamaEmbeddingsBuilder {
    /// Creates a builder with neither the model nor the base URL configured.
    pub fn new() -> Self {
        Self {
            model: None,
            base_url: None,
        }
    }

    /// Sets the embedding model name, replacing any previously set value.
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Sets the base URL of the Ollama server, replacing any previously set value.
    pub fn base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: Some(url.into()),
            ..self
        }
    }

    /// Consumes the builder and produces an [`OllamaEmbeddings`].
    ///
    /// An unset model falls back to `"nomic-embed-text"` and an unset base URL
    /// falls back to `"http://localhost:11434"`. A new HTTP client is created
    /// for the returned value.
    pub fn build(self) -> OllamaEmbeddings {
        let Self { model, base_url } = self;

        let model = match model {
            Some(name) => name,
            None => String::from("nomic-embed-text"),
        };
        let base_url = match base_url {
            Some(url) => url,
            None => String::from("http://localhost:11434"),
        };

        OllamaEmbeddings {
            model,
            base_url,
            client: Client::new(),
        }
    }
}
impl Default for OllamaEmbeddingsBuilder {
fn default() -> Self {
Self::new()
}
}
/// Embeddings client that talks to an Ollama server over HTTP.
///
/// Holds the model name, the server base URL, and the HTTP client used to
/// send embedding requests.
#[derive(Debug)]
pub struct OllamaEmbeddings {
    /// Name of the embedding model sent in every request payload.
    pub model: String,
    /// Base URL of the Ollama server; the embed endpoint path is appended to it.
    pub base_url: String,
    /// HTTP client used to send requests.
    client: Client,
}
impl OllamaEmbeddings {
pub fn builder() -> OllamaEmbeddingsBuilder {
OllamaEmbeddingsBuilder::new()
}
fn build_payload(&self, texts: &[String]) -> Value {
json!({
"model": self.model,
"input": texts,
})
}
async fn call_api(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
let url = format!("{}/api/embed", self.base_url);
let payload = self.build_payload(&texts);
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.json(&payload)
.send()
.await
.map_err(|e| CognisError::Other(format!("HTTP request failed: {}", e)))?;
let status = response.status().as_u16();
if !(200..300).contains(&status) {
let body = response.text().await.unwrap_or_default();
return Err(CognisError::HttpError { status, body });
}
let body: Value = response
.json()
.await
.map_err(|e| CognisError::Other(format!("Failed to parse response JSON: {}", e)))?;
let embeddings_arr = body
.get("embeddings")
.and_then(|v| v.as_array())
.ok_or_else(|| {
CognisError::Other("Missing 'embeddings' array in Ollama embed response".into())
})?;
let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(embeddings_arr.len());
for item in embeddings_arr {
let vec_arr = item
.as_array()
.ok_or_else(|| CognisError::Other("Expected array for embedding vector".into()))?;
let vec: Vec<f32> = vec_arr
.iter()
.map(|v| {
v.as_f64().map(|f| f as f32).ok_or_else(|| {
CognisError::Other("Non-numeric value in embedding array".into())
})
})
.collect::<Result<Vec<f32>>>()?;
embeddings.push(vec);
}
Ok(embeddings)
}
}
#[async_trait]
impl Embeddings for OllamaEmbeddings {
    /// Embeds a batch of texts via the Ollama server.
    ///
    /// An empty batch yields an empty result without sending a request.
    async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            Ok(Vec::new())
        } else {
            self.call_api(texts).await
        }
    }

    /// Embeds a single query string by sending it as a one-element batch and
    /// returning the first vector of the result.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Embeddings::embed_documents`], and returns
    /// [`CognisError::Other`] if the result contains no vectors.
    async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        let batch = vec![text.to_string()];
        let mut vectors = self.embed_documents(batch).await?.into_iter();
        match vectors.next() {
            Some(first) => Ok(first),
            None => Err(CognisError::Other(
                "Empty embedding response for query".into(),
            )),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An unconfigured builder falls back to the default model and URL.
    #[test]
    fn test_builder_defaults() {
        let client = OllamaEmbeddings::builder().build();

        assert_eq!(client.model, "nomic-embed-text");
        assert_eq!(client.base_url, "http://localhost:11434");
    }

    /// Values passed to the builder end up on the built client unchanged.
    #[test]
    fn test_builder_custom_values() {
        let builder = OllamaEmbeddings::builder()
            .model("mxbai-embed-large")
            .base_url("http://remote-host:11434");
        let client = builder.build();

        assert_eq!(client.model, "mxbai-embed-large");
        assert_eq!(client.base_url, "http://remote-host:11434");
    }

    /// The request payload carries the model name and the input texts.
    #[test]
    fn test_build_payload() {
        let client = OllamaEmbeddings::builder().build();
        let inputs = vec![String::from("hello"), String::from("world")];

        let body = client.build_payload(&inputs);

        assert_eq!(body["model"], "nomic-embed-text");
        assert_eq!(body["input"], json!(["hello", "world"]));
    }

    /// Embedding an empty batch succeeds and returns no vectors.
    #[tokio::test]
    async fn test_embed_documents_empty() {
        let client = OllamaEmbeddings::builder().build();

        let vectors = client.embed_documents(Vec::new()).await.unwrap();

        assert!(vectors.is_empty());
    }

    /// The `Debug` rendering names the type and shows the configured values.
    #[test]
    fn test_debug_output() {
        let client = OllamaEmbeddings::builder()
            .model("test-model")
            .base_url("http://test:11434")
            .build();

        let rendered = format!("{:?}", client);

        for fragment in ["OllamaEmbeddings", "test-model", "http://test:11434"] {
            assert!(rendered.contains(fragment));
        }
    }
}