use std::time::Duration;
use async_trait::async_trait;
use lunaris_core::{Embedder, LunarisError, StorageError};
use serde::{Deserialize, Serialize};
/// Fallback base URL that `OllamaEmbedderOpts::default()` uses for `endpoint`.
const DEFAULT_ENDPOINT: &str = "http://localhost:11434";
/// Fallback model name that `OllamaEmbedderOpts::default()` uses for `model`.
const DEFAULT_MODEL: &str = "embeddinggemma:300m";
/// Embedding width that `OllamaEmbedderOpts::default()` uses for `dim`.
// NOTE(review): presumably the output width of `DEFAULT_MODEL` — confirm.
const DEFAULT_DIM: usize = 768;
/// Value handed to `ClientBuilder::timeout` when `OllamaEmbedder::new` builds its HTTP client.
const HTTP_TIMEOUT: Duration = Duration::from_secs(10);
/// Configuration consumed by `OllamaEmbedder::new`.
#[derive(Clone, Debug)]
pub struct OllamaEmbedderOpts {
/// Base URL of the server; `embed_batch` strips trailing `/` and appends `/api/embed`.
pub endpoint: String,
/// Model name sent as the `model` field of every embed request.
pub model: String,
/// Expected length of each returned embedding; `embed_batch` rejects rows of any other length.
pub dim: usize,
}
impl Default for OllamaEmbedderOpts {
    /// Builds options from the environment, falling back to built-in defaults.
    ///
    /// * `endpoint` comes from `LUNARIS_OLLAMA_URL`, else `DEFAULT_ENDPOINT`.
    /// * `model` comes from `LUNARIS_OLLAMA_MODEL`, else `DEFAULT_MODEL`.
    /// * `dim` is always `DEFAULT_DIM`; there is no environment override for it.
    ///
    /// A variable that is unset, not valid Unicode, or blank (empty or only
    /// whitespace) is treated as absent.
    fn default() -> Self {
        // Reads `key`, ignoring blank values: `VAR=` is a common way to "unset" a
        // variable, and an empty endpoint would produce the unusable URL `/api/embed`.
        fn env_or(key: &str, fallback: &str) -> String {
            std::env::var(key)
                .ok()
                .filter(|value| !value.trim().is_empty())
                .unwrap_or_else(|| fallback.to_string())
        }

        Self {
            endpoint: env_or("LUNARIS_OLLAMA_URL", DEFAULT_ENDPOINT),
            model: env_or("LUNARIS_OLLAMA_MODEL", DEFAULT_MODEL),
            dim: DEFAULT_DIM,
        }
    }
}
/// `Embedder` implementation that requests embeddings over HTTP from the configured endpoint.
#[derive(Clone)]
pub struct OllamaEmbedder {
// HTTP client built once in `new` with `HTTP_TIMEOUT`.
client: reqwest::Client,
// Base URL; `/api/embed` is appended per request.
endpoint: String,
// Model name sent with every request.
model: String,
// Expected embedding length, reported by `dim()` and enforced on responses.
dim: usize,
}
impl OllamaEmbedder {
pub fn new(opts: OllamaEmbedderOpts) -> Result<Self, LunarisError> {
let client = reqwest::Client::builder().timeout(HTTP_TIMEOUT).build().map_err(|e| {
LunarisError::Storage(StorageError::Backend(format!("ollama client: {e}")))
})?;
Ok(Self { client, endpoint: opts.endpoint, model: opts.model, dim: opts.dim })
}
}
/// JSON body that `embed_batch` posts to `{endpoint}/api/embed`.
#[derive(Serialize)]
struct EmbedRequest<'a> {
// Model name, borrowed from the embedder.
model: &'a str,
// Texts to embed, in the order the caller supplied them.
input: Vec<&'a str>,
}
/// JSON body that `embed_batch` parses from a successful response.
#[derive(Deserialize)]
struct EmbedResponse {
// One vector per input is expected; `embed_batch` rejects any other row count.
embeddings: Vec<Vec<f32>>,
}
#[async_trait]
impl Embedder for OllamaEmbedder {
    /// Returns the embedding length this embedder was configured with.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Embeds `inputs` with a single `POST {endpoint}/api/embed` request.
    ///
    /// An empty `inputs` slice returns an empty vector without any network
    /// traffic. On success the result has exactly one row per input and every
    /// row has length `self.dim`.
    ///
    /// # Errors
    ///
    /// Every failure is reported as `LunarisError::Storage(StorageError::Backend(..))`:
    /// * the request could not be sent or timed out;
    /// * the server answered with a non-success status (the message includes the
    ///   status and, when readable, the start of the response body);
    /// * the response body could not be parsed as the expected JSON;
    /// * the number of rows differs from `inputs.len()`, or a row's length
    ///   differs from `self.dim`.
    async fn embed_batch(&self, inputs: &[&str]) -> Result<Vec<Vec<f32>>, LunarisError> {
        // Upper bound on how much of an error body is copied into the message.
        const ERROR_BODY_LIMIT: usize = 512;

        // Wraps a message in the error variant used for every failure in this method.
        fn backend(msg: String) -> LunarisError {
            LunarisError::Storage(StorageError::Backend(msg))
        }

        if inputs.is_empty() {
            return Ok(Vec::new());
        }

        let url = format!("{}/api/embed", self.endpoint.trim_end_matches('/'));
        let body = EmbedRequest { model: &self.model, input: inputs.to_vec() };
        let resp = self
            .client
            .post(&url)
            .json(&body)
            .send()
            .await
            .map_err(|e| backend(format!("ollama: {e}")))?;

        let status = resp.status();
        if !status.is_success() {
            // Best effort: the body usually explains the failure, but a failure to
            // read it must not mask the HTTP status itself.
            let raw = resp.text().await.unwrap_or_default();
            let detail: String = raw.trim().chars().take(ERROR_BODY_LIMIT).collect();
            let msg = if detail.is_empty() {
                format!("ollama: HTTP {status}")
            } else {
                format!("ollama: HTTP {status}: {detail}")
            };
            return Err(backend(msg));
        }

        let parsed: EmbedResponse =
            resp.json().await.map_err(|e| backend(format!("ollama parse: {e}")))?;

        if parsed.embeddings.len() != inputs.len() {
            return Err(backend(format!(
                "ollama returned wrong shape: expected {} rows, got {}",
                inputs.len(),
                parsed.embeddings.len()
            )));
        }

        // Report the first row whose length disagrees with the configured dimension.
        let mismatch =
            parsed.embeddings.iter().enumerate().find(|(_, row)| row.len() != self.dim);
        if let Some((i, row)) = mismatch {
            return Err(backend(format!(
                "ollama returned wrong shape: row {} has dim {} (expected {})",
                i,
                row.len(),
                self.dim
            )));
        }

        Ok(parsed.embeddings)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With no environment overrides, the default options point at the built-in
    /// endpoint and model.
    ///
    /// `OllamaEmbedderOpts::default()` reads `LUNARIS_OLLAMA_URL` and
    /// `LUNARIS_OLLAMA_MODEL`, so each assertion runs only when its variable is
    /// absent. Otherwise the test would fail on any machine that exports an
    /// override. The variables are not mutated here because tests run in
    /// parallel and share the process environment.
    #[test]
    fn opts_default_resolves_to_localhost_ollama() {
        let opts = OllamaEmbedderOpts::default();
        if std::env::var_os("LUNARIS_OLLAMA_URL").is_none() {
            assert_eq!(opts.endpoint, DEFAULT_ENDPOINT);
        }
        if std::env::var_os("LUNARIS_OLLAMA_MODEL").is_none() {
            assert_eq!(opts.model, DEFAULT_MODEL);
        }
        // `dim` has no environment override, so it is always checked.
        assert_eq!(opts.dim, 768);
    }

    /// Building the embedder from default options must not fail.
    #[test]
    fn embedder_construction_succeeds_with_defaults() {
        let _e = OllamaEmbedder::new(OllamaEmbedderOpts::default()).expect("client builds");
    }
}