use ceres_core::HttpConfig;
use ceres_core::error::AppError;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Base URL used by `OllamaClient::with_config` when the caller passes no endpoint.
const DEFAULT_ENDPOINT: &str = "http://localhost:11434";
/// Model name used by `OllamaClient::new`.
const DEFAULT_MODEL: &str = "nomic-embed-text";
/// Lower bound, in seconds, for the HTTP client timeout: `with_config` uses this value
/// whenever the timeout from `HttpConfig::default()` is shorter.
const OLLAMA_TIMEOUT_SECS: u64 = 120;
/// Strips an Ollama tag suffix from a model name.
///
/// Returns everything before the first `:` (for example `"all-minilm:l6-v2"` yields
/// `"all-minilm"`); a name without a `:` is returned unchanged.
fn normalize_model_name(model: &str) -> &str {
    model.split_once(':').map_or(model, |(base, _tag)| base)
}
/// Returns the embedding vector length associated with an Ollama model name.
///
/// Any tag suffix (the part after the first `:`) is ignored for the lookup. Models that are
/// not in the built-in table fall back to 768, and a warning naming the model is logged.
pub fn model_dimension(model: &str) -> usize {
    let known_dimension = match normalize_model_name(model) {
        "nomic-embed-text" => Some(768),
        "mxbai-embed-large" | "snowflake-arctic-embed" => Some(1024),
        "all-minilm" => Some(384),
        _ => None,
    };

    known_dimension.unwrap_or_else(|| {
        // Log the name exactly as the caller supplied it (tag included).
        tracing::warn!(
            model,
            "Unknown Ollama model dimension, defaulting to 768. \
             Set EMBEDDING_MODEL to a known model or verify dimension matches your database."
        );
        768
    })
}
/// HTTP client that requests text embeddings from an Ollama server.
///
/// Constructed through `new`, `with_model`, or `with_config`.
#[derive(Clone, Debug)]
pub struct OllamaClient {
/// Underlying `reqwest` client, built with the timeout stored in `timeout_secs`.
client: Client,
/// Model name sent in every embed request, stored exactly as passed to the constructor.
model: String,
/// Base URL of the server with trailing `/` characters removed.
endpoint: String,
/// Embedding dimension obtained from `model_dimension(model)` at construction time.
dim: usize,
/// Request timeout in seconds; also the value reported in `AppError::Timeout`.
timeout_secs: u64,
}
/// JSON body serialized for the POST to `{endpoint}/api/embed`.
#[derive(Serialize)]
struct EmbedRequest<'a> {
/// Name of the model to run.
model: &'a str,
/// Texts to embed, in the order supplied by the caller.
input: Vec<&'a str>,
}
/// JSON body deserialized from a successful embed response.
#[derive(Deserialize)]
struct EmbedResponse {
/// Embedding vectors returned by the server.
/// NOTE(review): presumably one vector per input text, in input order — confirm against
/// the Ollama API documentation.
embeddings: Vec<Vec<f32>>,
}
/// JSON shape that `map_api_error` tries to parse from the body of a non-success response.
#[derive(Deserialize)]
struct OllamaErrorResponse {
/// Error message text reported by the server.
error: String,
}
impl OllamaClient {
    /// Creates a client for the default model at the default endpoint.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`OllamaClient::with_config`].
    pub fn new() -> Result<Self, AppError> {
        Self::with_config(DEFAULT_MODEL, None)
    }

    /// Creates a client for `model` at the default endpoint.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`OllamaClient::with_config`].
    pub fn with_model(model: &str) -> Result<Self, AppError> {
        Self::with_config(model, None)
    }

    /// Creates a client for `model`, talking to `endpoint` (or the default endpoint when `None`).
    ///
    /// Surrounding whitespace and trailing `/` characters are removed from the endpoint before
    /// it is stored. The request timeout is the larger of `OLLAMA_TIMEOUT_SECS` and the
    /// timeout from `HttpConfig::default()`.
    ///
    /// # Errors
    ///
    /// * `AppError::ConfigError` if the endpoint is not a valid URL or its scheme is neither
    ///   `http` nor `https`.
    /// * `AppError::ClientError` if the HTTP client cannot be built.
    pub fn with_config(model: &str, endpoint: Option<&str>) -> Result<Self, AppError> {
        // Trim before validating so the stored endpoint is exactly what was validated.
        // URL parsing tolerates surrounding whitespace, so an untrimmed value (e.g. from an
        // environment variable with a trailing newline) would pass validation and then
        // produce a malformed request URL once "/api/embed" is appended.
        let endpoint = endpoint.unwrap_or(DEFAULT_ENDPOINT).trim();

        let parsed = reqwest::Url::parse(endpoint).map_err(|e| {
            AppError::ConfigError(format!("Invalid Ollama endpoint '{}': {}", endpoint, e))
        })?;
        match parsed.scheme() {
            "http" | "https" => {}
            scheme => {
                return Err(AppError::ConfigError(format!(
                    "Invalid Ollama endpoint scheme '{}'. Only http and https are allowed.",
                    scheme
                )));
            }
        }

        // Never go below the Ollama-specific floor, but honor a longer configured timeout.
        let timeout_secs = HttpConfig::default()
            .timeout
            .as_secs()
            .max(OLLAMA_TIMEOUT_SECS);

        let client = Client::builder()
            .timeout(Duration::from_secs(timeout_secs))
            .build()
            .map_err(|e| AppError::ClientError(e.to_string()))?;

        Ok(Self {
            client,
            model: model.to_string(),
            endpoint: endpoint.trim_end_matches('/').to_string(),
            dim: model_dimension(model),
            timeout_secs,
        })
    }

    /// Returns the model name this client was created with.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Requests the embedding for a single text.
    ///
    /// # Errors
    ///
    /// Returns `AppError::EmptyResponse` if the server returns no embedding, plus any error
    /// from [`OllamaClient::get_embeddings_batch`].
    pub async fn get_embeddings(&self, text: &str) -> Result<Vec<f32>, AppError> {
        let embeddings = self.get_embeddings_batch(&[text]).await?;
        embeddings.into_iter().next().ok_or(AppError::EmptyResponse)
    }

    /// Requests embeddings for `texts` with a single POST to `{endpoint}/api/embed`.
    ///
    /// An empty `texts` slice returns an empty vector without contacting the server. On
    /// success the result holds exactly one vector per input text.
    ///
    /// # Errors
    ///
    /// * `AppError::Timeout` / `AppError::NetworkError` / `AppError::ClientError` when the
    ///   request cannot be sent (see `map_connection_error`).
    /// * `AppError::ClientError` for a non-success HTTP status, an unparsable response body,
    ///   or a response whose embedding count differs from the number of inputs.
    /// * `AppError::EmptyResponse` when the server returns no embeddings at all.
    pub async fn get_embeddings_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, AppError> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let url = format!("{}/api/embed", self.endpoint);
        let request_body = EmbedRequest {
            model: &self.model,
            input: texts.to_vec(),
        };

        // `json` serializes the body and sets `Content-Type: application/json` itself, so no
        // explicit header is needed.
        let response = self
            .client
            .post(&url)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| self.map_connection_error(e))?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body is treated as empty text.
            let error_text = response.text().await.unwrap_or_default();
            return Err(self.map_api_error(status.as_u16(), &error_text));
        }

        let embed_response: EmbedResponse = response.json().await.map_err(|e| {
            AppError::ClientError(format!("Failed to parse Ollama response: {}", e))
        })?;

        // Callers pair embeddings with inputs by position, so a short or long response must
        // be rejected here rather than silently misaligning vectors and texts.
        let embeddings = embed_response.embeddings;
        if embeddings.is_empty() {
            return Err(AppError::EmptyResponse);
        }
        if embeddings.len() != texts.len() {
            return Err(AppError::ClientError(format!(
                "Ollama returned {} embeddings for {} inputs",
                embeddings.len(),
                texts.len()
            )));
        }

        Ok(embeddings)
    }

    /// Converts a transport-level `reqwest` error into an `AppError`.
    ///
    /// Timeouts become `AppError::Timeout` carrying the configured timeout, connection
    /// failures become `AppError::NetworkError` with a hint to start the server, and
    /// everything else becomes `AppError::ClientError`.
    fn map_connection_error(&self, err: reqwest::Error) -> AppError {
        if err.is_timeout() {
            AppError::Timeout(self.timeout_secs)
        } else if err.is_connect() {
            AppError::NetworkError(format!(
                "Cannot connect to Ollama at {}. Is it running? Try: ollama serve",
                self.endpoint
            ))
        } else {
            AppError::ClientError(format!("Ollama request failed: {}", err))
        }
    }

    /// Converts a non-success HTTP response into an `AppError::ClientError`.
    ///
    /// The body is parsed as `{"error": "..."}` when possible; otherwise the raw text is used
    /// together with the status code. A 404 whose message says the model was not found gets
    /// an `ollama pull` hint, any other 404 gets an endpoint hint.
    fn map_api_error(&self, status_code: u16, error_text: &str) -> AppError {
        let message = serde_json::from_str::<OllamaErrorResponse>(error_text)
            .map(|e| e.error)
            .unwrap_or_else(|_| format!("HTTP {}: {}", status_code, error_text));

        if status_code == 404 {
            // Compare case-insensitively: a "not found" that also mentions "model" or this
            // client's model name is treated as a missing model rather than a wrong endpoint.
            let lower_message = message.to_lowercase();
            let lower_model = self.model.to_lowercase();
            let is_model_not_found = lower_message.contains("not found")
                && (lower_message.contains("model") || lower_message.contains(&lower_model));

            if is_model_not_found {
                return AppError::ClientError(format!(
                    "Ollama model '{}' not found. Try: ollama pull {}",
                    self.model, self.model
                ));
            }
            return AppError::ClientError(format!(
                "Received 404 from Ollama at {}: {}. Check that the endpoint is correct.",
                self.endpoint, message
            ));
        }

        AppError::ClientError(format!("Ollama error: {}", message))
    }
}
impl ceres_core::traits::EmbeddingProvider for OllamaClient {
    /// Identifier of this provider.
    fn name(&self) -> &'static str {
        "ollama"
    }

    /// Embedding dimension determined when the client was constructed.
    fn dimension(&self) -> usize {
        self.dim
    }

    /// Batch size limit reported for this provider.
    fn max_batch_size(&self) -> usize {
        const BATCH_LIMIT: usize = 512;
        BATCH_LIMIT
    }

    /// Embeds one text by delegating to `get_embeddings`.
    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
        self.get_embeddings(text).await
    }

    /// Embeds several texts by delegating to `get_embeddings_batch`.
    async fn generate_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, AppError> {
        // The batch call takes borrowed string slices, so view each owned `String` as `&str`.
        let borrowed = texts.iter().map(String::as_str).collect::<Vec<&str>>();
        self.get_embeddings_batch(&borrowed).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client with the default model and endpoint, panicking if construction fails.
    fn default_client() -> OllamaClient {
        OllamaClient::new().unwrap()
    }

    /// Known models map to their dimensions; unknown models fall back to 768.
    #[test]
    fn test_model_dimension() {
        let cases = [
            ("nomic-embed-text", 768),
            ("mxbai-embed-large", 1024),
            ("snowflake-arctic-embed", 1024),
            ("all-minilm", 384),
            ("unknown-model", 768),
        ];
        for (model, expected) in cases {
            assert_eq!(model_dimension(model), expected);
        }
    }

    /// A `:tag` suffix does not affect the dimension lookup.
    #[test]
    fn test_model_dimension_with_tags() {
        let cases = [
            ("nomic-embed-text:latest", 768),
            ("snowflake-arctic-embed:335m", 1024),
            ("mxbai-embed-large:latest", 1024),
            ("all-minilm:l6-v2", 384),
        ];
        for (model, expected) in cases {
            assert_eq!(model_dimension(model), expected);
        }
    }

    /// `new` succeeds and applies the default model, dimension, and endpoint.
    #[test]
    fn test_new_client() {
        let built = OllamaClient::new();
        assert!(built.is_ok());

        let client = built.unwrap();
        assert_eq!(client.model(), "nomic-embed-text");
        assert_eq!(client.dim, 768);
        assert_eq!(client.endpoint, "http://localhost:11434");
    }

    /// `with_model` stores the model and derives its dimension.
    #[test]
    fn test_client_with_model() {
        let model = "mxbai-embed-large";
        let client = OllamaClient::with_model(model).unwrap();
        assert_eq!(client.model(), model);
        assert_eq!(client.dim, 1024);
    }

    /// `with_config` stores a custom endpoint alongside the model.
    #[test]
    fn test_client_with_config() {
        let endpoint = Some("http://myhost:11434");
        let client = OllamaClient::with_config("nomic-embed-text", endpoint).unwrap();
        assert_eq!(client.endpoint, "http://myhost:11434");
        assert_eq!(client.model(), "nomic-embed-text");
    }

    /// A trailing slash on the endpoint is dropped.
    #[test]
    fn test_endpoint_trailing_slash_normalized() {
        let endpoint = Some("http://localhost:11434/");
        let client = OllamaClient::with_config("nomic-embed-text", endpoint).unwrap();
        assert_eq!(client.endpoint, "http://localhost:11434");
    }

    /// Schemes other than http/https are rejected with a message mentioning the scheme.
    #[test]
    fn test_invalid_endpoint_scheme() {
        let outcome = OllamaClient::with_config("nomic-embed-text", Some("ftp://localhost:11434"));
        assert!(outcome.is_err());

        let rendered = outcome.unwrap_err().to_string();
        assert!(rendered.contains("scheme"));
    }

    /// A string that is not a URL is rejected.
    #[test]
    fn test_invalid_endpoint_url() {
        let outcome = OllamaClient::with_config("nomic-embed-text", Some("not a url"));
        assert!(outcome.is_err());
    }

    /// The request body serializes both the model name and every input text.
    #[test]
    fn test_request_serialization() {
        let body = EmbedRequest {
            model: "nomic-embed-text",
            input: vec!["Hello world", "Test input"],
        };
        let json = serde_json::to_string(&body).unwrap();
        for needle in ["nomic-embed-text", "Hello world", "Test input"] {
            assert!(json.contains(needle));
        }
    }

    /// The `EmbeddingProvider` implementation reports the expected metadata.
    #[test]
    fn test_trait_implementation() {
        use ceres_core::traits::EmbeddingProvider;

        let client = default_client();
        assert_eq!(client.name(), "ollama");
        assert_eq!(client.dimension(), 768);
        assert_eq!(client.max_batch_size(), 512);
    }

    /// A 404 that names the missing model produces the `ollama pull` hint.
    #[test]
    fn test_map_api_error_model_not_found() {
        let body = r#"{"error":"model \"nomic-embed-text\" not found"}"#;
        let rendered = default_client().map_api_error(404, body).to_string();
        assert!(rendered.contains("not found"));
        assert!(rendered.contains("ollama pull"));
    }

    /// A 404 without any mention of the model points at the endpoint instead.
    #[test]
    fn test_map_api_error_generic_404() {
        let body = r#"{"error":"Not Found"}"#;
        let rendered = default_client().map_api_error(404, body).to_string();
        assert!(!rendered.contains("ollama pull"));
        assert!(rendered.contains("endpoint"));
    }

    /// Other statuses pass the server's message through.
    #[test]
    fn test_map_api_error_generic() {
        let body = r#"{"error":"internal server error"}"#;
        let rendered = default_client().map_api_error(500, body).to_string();
        assert!(rendered.contains("internal server error"));
    }
}