use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use solo_core::{Embedder, Embedding, EmbeddingDtype, Error, Result};
use crate::llm::retry::{
RetryConfig, exp_backoff_with_jitter, is_retryable_reqwest_err, is_retryable_status,
parse_retry_after,
};
/// Model name used by `OllamaEmbedder::with_defaults`.
pub const DEFAULT_OLLAMA_MODEL: &str = "nomic-embed-text";
/// Vector length paired with [`DEFAULT_OLLAMA_MODEL`] by `OllamaEmbedder::with_defaults`.
pub const DEFAULT_OLLAMA_DIM: usize = 768;

/// Default Ollama server root (local host, port 11434).
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
/// Endpoint path appended to the (slash-trimmed) base URL for each embedding request.
const EMBEDDINGS_PATH: &str = "/api/embeddings";
/// Request timeout, in seconds, that `OllamaEmbedder::new` configures on its HTTP client.
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Value reported by `Embedder::version` for this embedder.
const EMBEDDER_VERSION: &str = "v1";
/// Embedder that calls an Ollama server's `/api/embeddings` endpoint.
///
/// Construct with `OllamaEmbedder::new` or `OllamaEmbedder::with_defaults`, then tune with
/// `with_timeout` / `with_retry_config`.
#[derive(Clone)]
pub struct OllamaEmbedder {
    /// Model name sent in every request body.
    model: String,
    /// Precomputed `ollama:<model>` label returned by `Embedder::name`.
    display_name: String,
    /// Server root with trailing slashes removed; the endpoint path is appended to it.
    base_url: String,
    /// Expected length of every embedding vector.
    dim: usize,
    /// HTTP client configured with the request timeout.
    http: reqwest::Client,
    /// Backoff policy applied to retryable HTTP statuses and transport errors.
    retry: RetryConfig,
}
impl OllamaEmbedder {
pub fn new(
base_url: impl Into<String>,
model: impl Into<String>,
dim: usize,
) -> Result<Self> {
let http = reqwest::Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()
.map_err(|e| Error::embedder(format!("build reqwest client: {e}")))?;
let mut base = base_url.into();
while base.ends_with('/') {
base.pop();
}
let model = model.into();
let display_name = format!("ollama:{model}");
Ok(Self {
http,
base_url: base,
model,
dim,
retry: RetryConfig::default(),
display_name,
})
}
pub fn with_defaults() -> Result<Self> {
Self::new(DEFAULT_BASE_URL, DEFAULT_OLLAMA_MODEL, DEFAULT_OLLAMA_DIM)
}
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.http = reqwest::Client::builder()
.timeout(timeout)
.build()
.map_err(|e| Error::embedder(format!("rebuild reqwest client: {e}")))?;
Ok(self)
}
pub fn with_retry_config(mut self, retry: RetryConfig) -> Self {
self.retry = retry;
self
}
pub fn base_url(&self) -> &str {
&self.base_url
}
pub fn model(&self) -> &str {
&self.model
}
pub async fn probe_dim(&self) -> Result<usize> {
let vec = self.embed_one("solo_init_dim_probe").await?;
if vec.is_empty() {
return Err(Error::embedder(
"ollama /api/embeddings returned an empty vector during dim probe",
));
}
Ok(vec.len())
}
async fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
let body = EmbeddingsRequest {
model: &self.model,
prompt: text,
};
let url = format!("{}{}", self.base_url, EMBEDDINGS_PATH);
let mut attempt: u32 = 0;
loop {
let send_res = self
.http
.post(&url)
.header("content-type", "application/json")
.json(&body)
.send()
.await;
match send_res {
Ok(resp) => {
let status = resp.status();
if status.is_success() {
let parsed: EmbeddingsResponse = resp.json().await.map_err(|e| {
Error::embedder(format!("ollama embeddings parse: {e}"))
})?;
if parsed.embedding.is_empty() {
return Err(Error::embedder(
"ollama /api/embeddings returned empty embedding vector",
));
}
return Ok(parsed.embedding);
}
let retry_after_hdr = resp
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let body_text = resp.text().await.unwrap_or_default();
if attempt < self.retry.max_retries
&& is_retryable_status(status.as_u16())
{
let delay = parse_retry_after(
retry_after_hdr.as_deref(),
self.retry.max_delay,
)
.unwrap_or_else(|| {
exp_backoff_with_jitter(attempt + 1, &self.retry)
});
tracing::warn!(
attempt = attempt + 1,
status = %status,
delay_ms = delay.as_millis() as u64,
"ollama embeddings retryable HTTP error; backing off"
);
tokio::time::sleep(delay).await;
attempt += 1;
continue;
}
return Err(Error::embedder(format!(
"ollama embeddings HTTP {}: {}",
status,
truncate(&body_text, 500)
)));
}
Err(e) => {
if attempt < self.retry.max_retries
&& is_retryable_reqwest_err(&e)
{
let delay = exp_backoff_with_jitter(attempt + 1, &self.retry);
tracing::warn!(
attempt = attempt + 1,
error = %e,
delay_ms = delay.as_millis() as u64,
"ollama embeddings retryable network error; backing off"
);
tokio::time::sleep(delay).await;
attempt += 1;
continue;
}
return Err(Error::embedder(format!(
"ollama embeddings request: {e}"
)));
}
}
}
}
}
#[async_trait]
impl Embedder for OllamaEmbedder {
    /// Label of the form `ollama:<model>`, computed once at construction.
    fn name(&self) -> &str {
        self.display_name.as_str()
    }

    /// Constant embedder version string.
    fn version(&self) -> &str {
        EMBEDDER_VERSION
    }

    /// Configured vector length; every returned embedding has exactly this many elements.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Embeddings are always stored as 32-bit floats.
    fn dtype(&self) -> EmbeddingDtype {
        EmbeddingDtype::F32
    }

    /// Embeds each text with one sequential request and packs the floats as little-endian
    /// bytes, preserving input order.
    ///
    /// An empty `texts` slice makes no requests and yields an empty vector. The first failing
    /// request, or any vector whose length differs from `dim`, aborts the whole batch.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
        let mut embeddings: Vec<Embedding> = Vec::with_capacity(texts.len());
        for &prompt in texts.iter() {
            let floats = self.embed_one(prompt).await?;
            let got = floats.len();
            if got != self.dim {
                return Err(Error::embedder(format!(
                    "ollama {} produced {} dims, expected {}",
                    self.model, got, self.dim
                )));
            }
            let data: Vec<u8> = floats.iter().flat_map(|x| x.to_le_bytes()).collect();
            embeddings.push(Embedding {
                dtype: EmbeddingDtype::F32,
                dim: self.dim,
                data,
            });
        }
        Ok(embeddings)
    }
}
/// JSON body for `POST /api/embeddings`: one prompt per request.
///
/// Field names are the wire keys; serde's derive serializes them in declaration order
/// (`model`, then `prompt`).
#[derive(Debug, Serialize)]
struct EmbeddingsRequest<'a> {
/// Model name, borrowed from the embedder for the lifetime of the request.
model: &'a str,
/// Text to embed.
prompt: &'a str,
}
/// The part of the `/api/embeddings` response body this client reads; other fields are ignored.
#[derive(Debug, Deserialize)]
struct EmbeddingsResponse {
/// The embedding vector. With `#[serde(default)]`, a missing field deserializes to an
/// empty vec, which `embed_one` reports as an "empty embedding vector" error rather than
/// a JSON parse error.
#[serde(default)]
embedding: Vec<f32>,
}
/// Shortens `s` to at most `max` characters (Unicode scalar values), for use in error messages.
///
/// Strings of `max` characters or fewer are returned unchanged. A longer string keeps its
/// first `max - 1` characters followed by `…`, so the result is exactly `max` characters.
/// `max == 0` yields an empty string.
fn truncate(s: &str, max: usize) -> String {
    // `nth(max)` is `Some` exactly when `s` has more than `max` chars. It stops scanning
    // there instead of counting a potentially large response body.
    if s.char_indices().nth(max).is_none() {
        return s.to_string();
    }
    // No room even for the ellipsis; avoids the `max - 1` underflow below.
    if max == 0 {
        return String::new();
    }
    // Byte offset where character `max - 1` starts; it exists because `s` has > max chars.
    let cut = s
        .char_indices()
        .nth(max - 1)
        .map_or(s.len(), |(idx, _)| idx);
    let mut out = String::with_capacity(cut + '…'.len_utf8());
    out.push_str(&s[..cut]);
    out.push('…');
    out
}
// Unit tests. HTTP-facing tests run against a local wiremock server; the real-Ollama
// smoke test at the bottom is `#[ignore]`d.
#[cfg(test)]
mod tests {
use super::*;
use wiremock::matchers::{header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
// Deterministic fake vector: element i is (seed + i) * 1e-3, so different seeds
// produce distinguishable first elements.
fn fixture_embedding(seed: u32, dim: usize) -> Vec<f32> {
(0..dim)
.map(|i| ((seed.wrapping_add(i as u32)) as f32) * 1e-3)
.collect()
}
// Embedder pointed at the mock server with `RetryConfig::none()` (presumably zero
// retries), so a single failing response surfaces immediately.
fn embedder_for(server: &MockServer, dim: usize) -> OllamaEmbedder {
OllamaEmbedder::new(server.uri(), "nomic-embed-test", dim)
.unwrap()
.with_retry_config(RetryConfig::none())
}
// Round trip: a 200 JSON vector comes back as little-endian F32 bytes with the
// configured dim, and the request carries a JSON content-type header.
#[tokio::test]
async fn happy_path_returns_embedding_vec() {
let server = MockServer::start().await;
let dim = 8;
let fixture = fixture_embedding(1, dim);
Mock::given(method("POST"))
.and(path("/api/embeddings"))
.and(header("content-type", "application/json"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"embedding": fixture
})))
.expect(1)
.mount(&server)
.await;
let e = embedder_for(&server, dim);
assert_eq!(e.dim(), dim);
assert_eq!(e.dtype(), EmbeddingDtype::F32);
let out = e.embed("hello world").await.expect("embed succeeds");
assert_eq!(out.dim, dim);
assert_eq!(out.dtype, EmbeddingDtype::F32);
assert_eq!(out.data.len(), dim * 4);
let parsed = out.as_f32_slice().expect("F32 slice");
for (i, expected) in fixture.iter().enumerate() {
assert!(
(parsed[i] - expected).abs() < 1e-6,
"dim {i}: got {} expected {}",
parsed[i],
expected
);
}
}
// Each prompt is routed to its own mock via `body_partial_json`, so checking the
// first element of each row proves outputs are returned in input order.
#[tokio::test]
async fn batch_iterates_and_preserves_order() {
let server = MockServer::start().await;
let dim = 4;
let fixture_a = fixture_embedding(10, dim);
let fixture_b = fixture_embedding(20, dim);
let fixture_c = fixture_embedding(30, dim);
Mock::given(method("POST"))
.and(path("/api/embeddings"))
.and(wiremock::matchers::body_partial_json(serde_json::json!({"prompt": "alpha"})))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"embedding": fixture_a
})))
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/api/embeddings"))
.and(wiremock::matchers::body_partial_json(serde_json::json!({"prompt": "beta"})))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"embedding": fixture_b
})))
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/api/embeddings"))
.and(wiremock::matchers::body_partial_json(serde_json::json!({"prompt": "gamma"})))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"embedding": fixture_c
})))
.mount(&server)
.await;
let e = embedder_for(&server, dim);
let out = e
.embed_batch(&["alpha", "beta", "gamma"])
.await
.expect("batch succeeds");
assert_eq!(out.len(), 3);
let a = out[0].as_f32_slice().unwrap();
let b = out[1].as_f32_slice().unwrap();
let c = out[2].as_f32_slice().unwrap();
assert!((a[0] - fixture_a[0]).abs() < 1e-6, "row 0 first elem");
assert!((b[0] - fixture_b[0]).abs() < 1e-6, "row 1 first elem");
assert!((c[0] - fixture_c[0]).abs() < 1e-6, "row 2 first elem");
assert_ne!(a, b);
assert_ne!(b, c);
}
// The first request hits the 503 mock, which is limited by `up_to_n_times(1)`; the
// retry then reaches the 200 mock. Each mock expects exactly one call.
// NOTE(review): the test name says 500, but the mock returns 503.
#[tokio::test]
async fn server_500_retries_then_succeeds() {
let server = MockServer::start().await;
let dim = 4;
let fixture = fixture_embedding(99, dim);
Mock::given(method("POST"))
.and(path("/api/embeddings"))
.respond_with(ResponseTemplate::new(503))
.up_to_n_times(1)
.expect(1)
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/api/embeddings"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"embedding": fixture
})))
.expect(1)
.mount(&server)
.await;
// Millisecond-scale delays keep the test fast.
let retry = RetryConfig {
max_retries: 2,
base_delay: Duration::from_millis(5),
max_delay: Duration::from_millis(20),
};
let e = OllamaEmbedder::new(server.uri(), "nomic-embed-test", dim)
.unwrap()
.with_retry_config(retry);
let out = e.embed("retry test").await.expect("eventual success");
assert_eq!(out.dim, dim);
let parsed = out.as_f32_slice().unwrap();
assert!((parsed[0] - fixture[0]).abs() < 1e-6);
}
// max_retries = 2 means one initial attempt plus two retries: three requests in
// total, which `.expect(3)` pins. The final error carries the HTTP status.
#[tokio::test]
async fn server_500_permanently_fails_after_max_retries() {
let server = MockServer::start().await;
let dim = 4;
Mock::given(method("POST"))
.and(path("/api/embeddings"))
.respond_with(ResponseTemplate::new(500))
.expect(3)
.mount(&server)
.await;
let retry = RetryConfig {
max_retries: 2,
base_delay: Duration::from_millis(5),
max_delay: Duration::from_millis(20),
};
let e = OllamaEmbedder::new(server.uri(), "nomic-embed-test", dim)
.unwrap()
.with_retry_config(retry);
let err = e
.embed("perma fail")
.await
.expect_err("expected error after exhausting retries");
let msg = format!("{err}");
assert!(
msg.contains("ollama embeddings HTTP 500"),
"unexpected error message: {msg}"
);
}
// Accessors and trait metadata; construction makes no network calls.
#[tokio::test]
async fn name_returns_ollama_prefixed_model() {
let e = OllamaEmbedder::new("http://localhost:11434", "nomic-embed-text", 768).unwrap();
assert_eq!(e.name(), "ollama:nomic-embed-text");
assert_eq!(e.version(), "v1");
assert_eq!(e.dim(), 768);
assert_eq!(e.dtype(), EmbeddingDtype::F32);
assert_eq!(e.model(), "nomic-embed-text");
assert_eq!(e.base_url(), "http://localhost:11434");
}
// Pins the default model, dim and base URL used by `with_defaults`.
#[tokio::test]
async fn with_defaults_matches_locked_roadmap_values() {
let e = OllamaEmbedder::with_defaults().unwrap();
assert_eq!(e.name(), "ollama:nomic-embed-text");
assert_eq!(e.dim(), 768);
assert_eq!(e.base_url(), "http://localhost:11434");
}
// Multiple trailing slashes are all removed, not just one.
#[tokio::test]
async fn base_url_trailing_slashes_are_trimmed() {
let e = OllamaEmbedder::new("http://localhost:11434///", "m", 1).unwrap();
assert_eq!(e.base_url(), "http://localhost:11434");
}
// A 200 body without an `embedding` field deserializes to an empty vec
// (`#[serde(default)]`) and must surface as the empty-embedding error, not a panic.
#[tokio::test]
async fn malformed_response_errors_cleanly() {
let server = MockServer::start().await;
let dim = 4;
Mock::given(method("POST"))
.and(path("/api/embeddings"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"not_embedding": [0.1, 0.2, 0.3, 0.4]
})))
.expect(1)
.mount(&server)
.await;
let e = embedder_for(&server, dim);
let err = e
.embed("malformed")
.await
.expect_err("missing embedding field must error, not panic");
let msg = format!("{err}");
assert!(
msg.contains("empty embedding"),
"expected clean empty-vector error, got: {msg}"
);
}
// The server returns 4 values while the embedder is configured for 8. That must be
// an error, not a truncated or padded embedding.
#[tokio::test]
async fn dim_mismatch_surfaces_as_error_not_silent_truncation() {
let server = MockServer::start().await;
let configured_dim = 8;
let server_returned_dim = 4;
let fixture = fixture_embedding(1, server_returned_dim);
Mock::given(method("POST"))
.and(path("/api/embeddings"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"embedding": fixture
})))
.expect(1)
.mount(&server)
.await;
let e = embedder_for(&server, configured_dim);
let err = e
.embed("dim mismatch")
.await
.expect_err("dim mismatch must error");
let msg = format!("{err}");
assert!(
msg.contains("produced 4 dims, expected 8"),
"unexpected dim-mismatch error: {msg}"
);
}
// `probe_dim` reports the server's vector length even when it differs from the
// configured placeholder dim.
#[tokio::test]
async fn probe_dim_reports_server_returned_length_ignoring_configured_dim() {
let server = MockServer::start().await;
let placeholder_dim = 768;
let actual_dim = 384;
let fixture = fixture_embedding(7, actual_dim);
Mock::given(method("POST"))
.and(path("/api/embeddings"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"embedding": fixture
})))
.expect(1)
.mount(&server)
.await;
let e = embedder_for(&server, placeholder_dim);
let probed = e.probe_dim().await.expect("probe ok");
assert_eq!(probed, actual_dim);
}
// An explicit empty `embedding` array must make `probe_dim` fail with an
// "empty"-mentioning error.
#[tokio::test]
async fn probe_dim_surfaces_empty_response_as_error() {
let server = MockServer::start().await;
let empty: Vec<f32> = Vec::new();
Mock::given(method("POST"))
.and(path("/api/embeddings"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"embedding": empty
})))
.expect(1)
.mount(&server)
.await;
let e = embedder_for(&server, 768);
let err = e
.probe_dim()
.await
.expect_err("empty probe response must error");
let msg = format!("{err}");
assert!(
msg.contains("empty"),
"expected empty-vector error, got: {msg}"
);
}
// No mock is mounted. Presumably any request would get wiremock's default non-2xx
// reply and turn the `expect` below into a failure, which implicitly proves no call
// was made.
#[tokio::test]
async fn empty_batch_yields_empty_output_no_http_calls() {
let server = MockServer::start().await;
let e = embedder_for(&server, 768);
let out = e.embed_batch(&[]).await.expect("empty batch is ok");
assert!(out.is_empty());
}
// Needs a running Ollama, so it is ignored by default; run with `--ignored`. The base
// URL and model can be overridden via SOLO_OLLAMA_BASE_URL / SOLO_OLLAMA_EMBED_MODEL.
// Only models with a known dim are exercised; the test returns early for others.
#[tokio::test]
#[ignore]
async fn ollama_embedder_smoke_real_ollama() {
let base_url = std::env::var("SOLO_OLLAMA_BASE_URL")
.unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
let model = std::env::var("SOLO_OLLAMA_EMBED_MODEL")
.unwrap_or_else(|_| DEFAULT_OLLAMA_MODEL.to_string());
let dim = if model == "nomic-embed-text" {
DEFAULT_OLLAMA_DIM
} else if model == "mxbai-embed-large" {
1024
} else {
eprintln!(
"ollama_embedder_smoke_real_ollama: unknown model {model}, \
cannot pick dim; skipping. Override `dim` literal in test \
source to run."
);
return;
};
let e = OllamaEmbedder::new(base_url, model, dim).unwrap();
let out = e
.embed("the quick brown fox jumps over the lazy dog")
.await
.expect("real-Ollama embed");
assert_eq!(out.dim, dim);
assert_eq!(out.dtype, EmbeddingDtype::F32);
let slice = out.as_f32_slice().unwrap();
// A real model should never return an all-zero vector for non-empty text.
let mag: f32 = slice.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(mag > 0.0, "embedding should not be all-zero");
}
}