use std::sync::OnceLock;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::config::OllamaConfig;
use crate::embedder::Embedder;
use crate::error::EmbedError;
use crate::http::{classify_ureq_error, decode_json};
use crate::manifest::EmbedderManifest;
/// Embedder that obtains vectors from an Ollama server over HTTP.
///
/// Instances are created with [`OllamaEmbedder::from_config`].
pub struct OllamaEmbedder {
    /// Model name exactly as configured; sent as the `model` field of each request.
    model_bare: String,
    /// Model name prefixed with `ollama:`; this is what `model()` reports.
    model_fq: String,
    /// Embedding dimension recorded by `embed` from a server response; unset until then.
    dim: OnceLock<u32>,
    /// Full request URL: the configured base URL followed by `/api/embeddings`.
    endpoint: String,
    /// HTTP agent built with the configured timeout and reused for every request.
    agent: ureq::Agent,
}
impl OllamaEmbedder {
    /// Builds an embedder from `config`.
    ///
    /// Only local state is prepared here: the request URL is `config.base_url` with any
    /// trailing `/` removed, followed by `/api/embeddings`, and the HTTP agent is given
    /// `config.timeout_secs` seconds as its timeout. The embedding dimension starts out
    /// unset.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn from_config(config: &OllamaConfig) -> Result<Self, EmbedError> {
        let base = config.base_url.trim_end_matches('/');
        let timeout = Duration::from_secs(config.timeout_secs);

        Ok(Self {
            model_bare: config.model.clone(),
            model_fq: format!("ollama:{}", config.model),
            dim: OnceLock::new(),
            endpoint: format!("{base}/api/embeddings"),
            agent: ureq::AgentBuilder::new().timeout(timeout).build(),
        })
    }
}
impl Embedder for OllamaEmbedder {
    /// Returns the fully-qualified model name, i.e. the configured name prefixed with
    /// `ollama:`.
    fn model(&self) -> &str {
        &self.model_fq
    }

    /// Returns the embedding dimension, or `0` while no dimension has been recorded yet.
    fn dim(&self) -> u32 {
        self.dim.get().copied().unwrap_or(0)
    }

    /// Requests an embedding for `text` from the configured endpoint.
    ///
    /// The request is a JSON `POST` carrying the bare model name and `text` as `prompt`;
    /// the response must be a JSON object with an `embedding` array of numbers.
    ///
    /// The length of the first non-empty embedding is recorded as this embedder's
    /// dimension, and every later response is checked against it. An empty embedding
    /// is never recorded, because `0` is the value `dim()` uses for "unknown"; pinning
    /// it would make every subsequent real embedding fail the dimension check.
    ///
    /// # Errors
    ///
    /// * whatever `classify_ureq_error` produces when the HTTP request fails;
    /// * whatever `decode_json` produces when the response body cannot be decoded;
    /// * [`EmbedError::DimMismatch`] when the returned vector's length differs from the
    ///   recorded dimension.
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedError> {
        #[derive(Serialize)]
        struct Req<'a> {
            model: &'a str,
            prompt: &'a str,
        }

        #[derive(Deserialize)]
        struct Resp {
            embedding: Vec<f32>,
        }

        let body = Req {
            model: &self.model_bare,
            prompt: text,
        };
        let resp = self
            .agent
            .post(&self.endpoint)
            .set("Content-Type", "application/json")
            .send_json(&body)
            .map_err(classify_ureq_error)?;
        let parsed: Resp = decode_json(resp)?;

        // Lengths that do not fit in `u32` saturate to `u32::MAX`.
        let got_dim = u32::try_from(parsed.embedding.len()).unwrap_or(u32::MAX);

        let expected = if got_dim == 0 {
            // Never pin `0`: with no dimension recorded yet there is nothing to compare
            // against, so the empty vector is handed back unchanged.
            match self.dim.get() {
                Some(&known) => known,
                None => return Ok(parsed.embedding),
            }
        } else {
            // `get_or_init` stores `got_dim` only if the cell is still empty and always
            // yields the stored value, so a caller that loses a concurrent first-call
            // race is still validated against the winner's dimension.
            *self.dim.get_or_init(|| got_dim)
        };

        if got_dim != expected {
            return Err(EmbedError::DimMismatch {
                expected,
                got: got_dim,
            });
        }
        Ok(parsed.embedding)
    }

    /// Builds a manifest from the fully-qualified model name and the current dimension
    /// (`0` while unknown).
    fn manifest(&self) -> EmbedderManifest {
        // NOTE(review): the meaning of the literal `0.31` is defined by
        // `EmbedderManifest::new`, which is not visible here — confirm before changing.
        EmbedderManifest::new(self.model_fq.clone(), self.dim(), 0.31)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Building an embedder for a host under the reserved `.invalid` TLD must succeed,
    /// and the dimension must still read as `0` right after construction.
    #[test]
    fn from_config_does_not_contact_network() {
        let config = OllamaConfig {
            model: "nomic-embed-text".into(),
            base_url: "http://definitely-not-reachable.example.invalid:11434".into(),
            ..Default::default()
        };

        let embedder = OllamaEmbedder::from_config(&config).unwrap();

        assert_eq!(embedder.model(), "ollama:nomic-embed-text");
        assert_eq!(embedder.dim(), 0);
    }
}