use std::time::Duration;
use serde::{Deserialize, Serialize};
use super::{EmbedError, EmbedResult, Embedder};
/// Leading component of every backend id built by `OllamaEmbedder` (`ollama:<model>:<dim>`).
pub const OLLAMA_BACKEND_ID_PREFIX: &str = "ollama";
/// Base URL used by `OllamaEmbedder::default_nomic`.
pub const DEFAULT_OLLAMA_ENDPOINT: &str = "http://localhost:11434";
/// Model name used by `OllamaEmbedder::default_nomic`.
pub const DEFAULT_OLLAMA_EMBED_MODEL: &str = "nomic-embed-text";
/// Vector length used by `OllamaEmbedder::default_nomic` for `nomic-embed-text`.
pub const NOMIC_EMBED_DIM: usize = 768;
/// HTTP timeout (30 seconds, in milliseconds) used unless `with_timeout_ms` overrides it.
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;
/// JSON body posted to Ollama's `/api/embeddings` endpoint.
///
/// Both fields borrow from the caller, so building a request copies no text.
#[derive(Serialize, Debug)]
struct EmbedRequest<'req> {
    /// Name of the Ollama model to run.
    model: &'req str,
    /// Text to embed.
    prompt: &'req str,
}
/// JSON body returned by Ollama's `/api/embeddings`; only the `embedding` field is read.
#[derive(Deserialize, Debug)]
struct EmbedResponse {
    /// Embedding values as decoded from JSON (narrowed to `f32` by `embed`).
    embedding: Vec<f64>,
}
/// Returns `true` when `endpoint` points at a loopback host.
///
/// Accepted hosts are `localhost` (ASCII case-insensitive) and any literal IP address for
/// which `IpAddr::is_loopback` holds (`127.0.0.0/8`, `::1`). IPv6 literals must be
/// bracketed as in URLs (`http://[::1]:11434`).
///
/// Parsing steps:
/// - An optional `https://` or `http://` prefix is stripped.
/// - The authority ends at the first `/`, `?` or `#`.
/// - Any `userinfo@` prefix is discarded, so `http://localhost:x@evil.example` is judged by
///   its real host (`evil.example`).
///
/// Hostnames other than `localhost` are never treated as loopback. In particular
/// `127.0.0.1.evil.example` is rejected, because DNS could resolve it anywhere.
fn is_loopback_endpoint(endpoint: &str) -> bool {
    let rest = endpoint
        .strip_prefix("https://")
        .or_else(|| endpoint.strip_prefix("http://"))
        .unwrap_or(endpoint);
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(rest);
    // Userinfo may contain ':' or look like a host name; the real host follows the last '@'.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);
    let host = match host_port.strip_prefix('[') {
        Some(bracketed) => match bracketed.split_once(']') {
            Some((inner, _)) => inner,
            // Unterminated IPv6 literal: malformed, so refuse it rather than guess.
            None => return false,
        },
        None => host_port.split(':').next().unwrap_or(host_port),
    };
    // Only exact IP literals count; prefix matching on "127." would accept DNS names.
    host.eq_ignore_ascii_case("localhost")
        || matches!(host.parse::<std::net::IpAddr>(), Ok(ip) if ip.is_loopback())
}
/// Embedder that calls a local Ollama server's `/api/embeddings` endpoint.
///
/// Build one with `OllamaEmbedder::new` or `OllamaEmbedder::default_nomic`. The constructor
/// refuses endpoints that are not loopback hosts.
#[derive(Clone, Debug)]
pub struct OllamaEmbedder {
    /// Base URL of the Ollama server; `/api/embeddings` is appended per request.
    endpoint: String,
    /// Ollama model name sent with every request.
    model: String,
    /// Length every returned vector must have.
    dim: usize,
    /// Precomputed `ollama:<model>:<dim>` identifier.
    backend_id: String,
    /// HTTP timeout for each request, in milliseconds.
    timeout_ms: u64,
}
impl OllamaEmbedder {
    /// Creates an embedder for `model` on the Ollama server at `endpoint`.
    ///
    /// Every vector it returns must have length `dim`. Trailing `/` characters are removed
    /// from `endpoint`, so `http://localhost:11434/` and `http://localhost:11434` both produce
    /// request URLs of the form `<endpoint>/api/embeddings`.
    ///
    /// # Errors
    ///
    /// Returns [`EmbedError::InvalidInput`] when any of these holds:
    /// - `endpoint` is empty or whitespace-only;
    /// - `model` is empty or whitespace-only;
    /// - `dim` is zero;
    /// - `endpoint` does not name a loopback host.
    pub fn new(
        endpoint: impl Into<String>,
        model: impl Into<String>,
        dim: usize,
    ) -> EmbedResult<Self> {
        let endpoint = endpoint.into();
        let model = model.into();
        if endpoint.trim().is_empty() {
            return Err(EmbedError::InvalidInput(
                "OllamaEmbedder: endpoint must not be empty".to_string(),
            ));
        }
        if model.trim().is_empty() {
            return Err(EmbedError::InvalidInput(
                "OllamaEmbedder: model must not be empty".to_string(),
            ));
        }
        if dim == 0 {
            return Err(EmbedError::InvalidInput(
                "OllamaEmbedder: dim must be > 0".to_string(),
            ));
        }
        if !is_loopback_endpoint(&endpoint) {
            return Err(EmbedError::InvalidInput(format!(
                "OllamaEmbedder: endpoint must be loopback-only (localhost/127.0.0.1/::1), got `{endpoint}`"
            )));
        }
        // `embed` appends "/api/embeddings"; a trailing slash here would yield "//api/...".
        let endpoint = endpoint.trim_end_matches('/').to_owned();
        // Single source of truth for the id format, shared with `backend_id_for`.
        let backend_id = Self::backend_id_for(&model, dim);
        Ok(Self {
            endpoint,
            model,
            dim,
            backend_id,
            timeout_ms: DEFAULT_TIMEOUT_MS,
        })
    }

    /// Creates an embedder for `nomic-embed-text` on `http://localhost:11434` (dimension 768).
    ///
    /// # Errors
    ///
    /// Propagates any error from [`OllamaEmbedder::new`].
    pub fn default_nomic() -> EmbedResult<Self> {
        Self::new(
            DEFAULT_OLLAMA_ENDPOINT,
            DEFAULT_OLLAMA_EMBED_MODEL,
            NOMIC_EMBED_DIM,
        )
    }

    /// Returns this embedder with its per-request HTTP timeout set to `ms` milliseconds.
    #[must_use]
    pub fn with_timeout_ms(mut self, ms: u64) -> Self {
        self.timeout_ms = ms;
        self
    }

    /// Builds the backend id (`ollama:<model>:<dim>`) an embedder for `model`/`dim` reports.
    pub fn backend_id_for(model: &str, dim: usize) -> String {
        format!("{OLLAMA_BACKEND_ID_PREFIX}:{model}:{dim}")
    }
}
impl Embedder for OllamaEmbedder {
fn backend_id(&self) -> &str {
&self.backend_id
}
fn dim(&self) -> usize {
self.dim
}
fn embed(&self, text: &str, tags: &[String]) -> EmbedResult<Vec<f32>> {
let prompt = if tags.is_empty() {
text.to_string()
} else {
format!("{text} | {}", tags.join(" "))
};
let url = format!("{}/api/embeddings", self.endpoint);
let body = EmbedRequest {
model: &self.model,
prompt: &prompt,
};
let timeout = Duration::from_millis(self.timeout_ms);
let agent = ureq::AgentBuilder::new().timeout(timeout).build();
let body_json = serde_json::to_value(&body)
.map_err(|e| EmbedError::Backend(format!("request serialization failed: {e}")))?;
let response = agent
.post(&url)
.send_json(body_json)
.map_err(|err| EmbedError::Backend(format!("Ollama HTTP error: {err}")))?;
if response.status() != 200 {
let status = response.status();
return Err(EmbedError::Backend(format!(
"Ollama returned HTTP {status}"
)));
}
let response_text = response
.into_string()
.map_err(|e| EmbedError::Backend(format!("reading Ollama response body: {e}")))?;
let parsed: EmbedResponse = serde_json::from_str(&response_text)
.map_err(|e| EmbedError::Backend(format!("Ollama response parse: {e}")))?;
let vector: Vec<f32> = parsed.embedding.iter().map(|&v| v as f32).collect();
if vector.len() != self.dim {
return Err(EmbedError::DimensionMismatch {
backend_id: self.backend_id.clone(),
expected: self.dim,
actual: vector.len(),
});
}
Ok(vector)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn constructor_rejects_empty_endpoint() {
        let err = OllamaEmbedder::new("", "nomic-embed-text", 768).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }
    #[test]
    fn constructor_rejects_empty_model() {
        let err = OllamaEmbedder::new("http://localhost:11434", "", 768).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }
    #[test]
    fn constructor_rejects_whitespace_only_model() {
        let err = OllamaEmbedder::new("http://localhost:11434", "   ", 768).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }
    #[test]
    fn constructor_rejects_zero_dim() {
        let err = OllamaEmbedder::new("http://localhost:11434", "nomic-embed-text", 0).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }
    /// The constructor's loopback-only policy must refuse remote hosts and LAN addresses.
    #[test]
    fn constructor_rejects_non_loopback_endpoint() {
        for endpoint in [
            "http://example.com:11434",
            "http://10.0.0.5:11434",
            "https://192.168.1.2",
        ] {
            let err = OllamaEmbedder::new(endpoint, "nomic-embed-text", 768).unwrap_err();
            assert!(
                matches!(err, EmbedError::InvalidInput(_)),
                "expected InvalidInput for {endpoint}, got {err:?}"
            );
        }
    }
    #[test]
    fn constructor_accepts_loopback_endpoints() {
        for endpoint in [
            "http://localhost:11434",
            "http://127.0.0.1:11434",
            "http://[::1]:11434",
        ] {
            assert!(
                OllamaEmbedder::new(endpoint, "nomic-embed-text", 768).is_ok(),
                "expected {endpoint} to be accepted as loopback"
            );
        }
    }
    #[test]
    fn backend_id_encodes_model_and_dim() {
        let e = OllamaEmbedder::new("http://localhost:11434", "nomic-embed-text", 768).unwrap();
        assert_eq!(e.backend_id(), "ollama:nomic-embed-text:768");
        assert_eq!(e.dim(), 768);
    }
    #[test]
    fn backend_id_for_matches_instance() {
        let id = OllamaEmbedder::backend_id_for("nomic-embed-text", 768);
        let e = OllamaEmbedder::default_nomic().unwrap();
        assert_eq!(id, e.backend_id());
    }
    #[test]
    fn default_nomic_has_expected_backend_id() {
        let e = OllamaEmbedder::default_nomic().unwrap();
        assert_eq!(e.backend_id(), "ollama:nomic-embed-text:768");
        assert_eq!(e.dim(), NOMIC_EMBED_DIM);
    }
    #[test]
    fn with_timeout_ms_overrides_default() {
        let e = OllamaEmbedder::default_nomic()
            .unwrap()
            .with_timeout_ms(5_000);
        assert_eq!(e.timeout_ms, 5_000);
    }
}