use crate::provider_api::{
FinishReason, LlmError, LlmErrorKind, LlmProvider, LlmRequest, LlmResponse, TokenUsage,
};
use converge_core::capability::{
CapabilityError, CapabilityErrorKind, EmbedInput, EmbedRequest, EmbedResponse, EmbedUsage,
Embedding, Modality,
};
use serde::{Deserialize, Serialize};
/// Base URL that [`OllamaProvider::new`] connects to when no explicit URL is given.
pub const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
/// Blocking HTTP client for an Ollama server, bound to a single model name.
///
/// Implements both [`LlmProvider`] (chat completion) and [`Embedding`]
/// (text embeddings) against the same server.
pub struct OllamaProvider {
// Model name sent in every request body.
model: String,
// Shared blocking HTTP client (built once in `with_url`).
client: reqwest::blocking::Client,
// Server base URL; endpoint paths such as `/api/chat` are appended to it.
base_url: String,
// Lazily computed, cached answer of `supports_embedding`.
embedding_support: std::sync::OnceLock<bool>,
}
impl OllamaProvider {
    /// Creates a provider for `model` that talks to [`DEFAULT_OLLAMA_URL`].
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be constructed.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self::with_url(DEFAULT_OLLAMA_URL, model)
    }

    /// Creates a provider for `model` that talks to the Ollama server at `url`.
    ///
    /// Trailing slashes on `url` are removed so that endpoint paths can be
    /// appended without producing `//api/...`. Requests use a 120 second
    /// timeout.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be constructed.
    #[must_use]
    pub fn with_url(url: impl Into<String>, model: impl Into<String>) -> Self {
        let mut base_url: String = url.into();
        // Every endpoint is built as `{base_url}/api/...`; a trailing slash in
        // the configured URL would otherwise yield a double slash in the path.
        let trimmed_len = base_url.trim_end_matches('/').len();
        base_url.truncate(trimmed_len);

        Self {
            model: model.into(),
            client: reqwest::blocking::Client::builder()
                .timeout(std::time::Duration::from_secs(120))
                .build()
                .expect("Failed to create HTTP client"),
            base_url,
            embedding_support: std::sync::OnceLock::new(),
        }
    }

    /// Asks the server for metadata about the configured model (`POST /api/show`).
    ///
    /// # Errors
    ///
    /// * a network error if the request cannot be sent (with a dedicated
    ///   message when the connection itself fails),
    /// * a parse error if a successful response cannot be decoded,
    /// * [`LlmErrorKind::ModelNotFound`] on HTTP 404,
    /// * a provider error carrying status and body for any other status.
    pub fn health_check(&self) -> Result<ModelInfo, LlmError> {
        let url = format!("{}/api/show", self.base_url);
        let response = self
            .client
            .post(&url)
            .json(&serde_json::json!({"name": &self.model}))
            .send()
            .map_err(|e| {
                if e.is_connect() {
                    LlmError::network(format!(
                        "Cannot connect to Ollama at {}. Is it running?",
                        self.base_url
                    ))
                } else {
                    LlmError::network(format!("Ollama request failed: {e}"))
                }
            })?;

        let status = response.status();
        if status.is_success() {
            return response
                .json::<ModelInfo>()
                .map_err(|e| LlmError::parse(format!("Failed to parse model info: {e}")));
        }

        if status.as_u16() == 404 {
            return Err(LlmError {
                kind: LlmErrorKind::ModelNotFound,
                message: format!(
                    "Model '{}' not found. Try: ollama pull {}",
                    self.model, self.model
                ),
                retryable: false,
            });
        }

        let body = response.text().unwrap_or_default();
        Err(LlmError::provider(format!(
            "Ollama returned status {status}: {body}"
        )))
    }

    /// Reports whether the configured model is believed to support embeddings.
    ///
    /// The answer is computed on the first call and cached for the lifetime of
    /// this provider. Model names containing a well-known embedding model name
    /// are accepted without contacting the server; otherwise the result of
    /// [`health_check`](Self::health_check) is consulted. A failed health
    /// check yields `false`, and that `false` is cached as well.
    #[must_use]
    pub fn supports_embedding(&self) -> bool {
        const KNOWN_EMBEDDING_MODELS: [&str; 6] = [
            "nomic-embed-text",
            "mxbai-embed-large",
            "bge-m3",
            "bge-large",
            "all-minilm",
            "snowflake-arctic-embed",
        ];

        *self.embedding_support.get_or_init(|| {
            if KNOWN_EMBEDDING_MODELS
                .iter()
                .any(|known| self.model.contains(*known))
            {
                return true;
            }
            // NOTE(review): any model that reports a non-empty `families`
            // list is treated as embedding-capable; confirm this heuristic is
            // intended, since it is not specific to embedding models.
            self.health_check()
                .is_ok_and(|info| !info.details.families.is_empty())
        })
    }

    /// Lists the models known to the server (`GET /api/tags`).
    ///
    /// # Errors
    ///
    /// Returns a network error if the request cannot be sent, a parse error if
    /// the response cannot be decoded, and a provider error that includes the
    /// HTTP status and response body for any non-success status.
    pub fn list_models(&self) -> Result<Vec<ModelListEntry>, LlmError> {
        let url = format!("{}/api/tags", self.base_url);
        let response = self
            .client
            .get(&url)
            .send()
            .map_err(|e| LlmError::network(format!("Failed to list models: {e}")))?;

        let status = response.status();
        if !status.is_success() {
            // Surface status and body, matching `health_check`, so failures
            // are diagnosable.
            let body = response.text().unwrap_or_default();
            return Err(LlmError::provider(format!(
                "Failed to list models: Ollama returned status {status}: {body}"
            )));
        }

        let list: ModelList = response
            .json()
            .map_err(|e| LlmError::parse(format!("Failed to parse model list: {e}")))?;
        Ok(list.models)
    }
}
impl LlmProvider for OllamaProvider {
    /// Provider identifier; always `"ollama"`.
    fn name(&self) -> &'static str {
        "ollama"
    }

    /// Name of the model this provider sends requests for.
    fn model(&self) -> &str {
        &self.model
    }

    /// Runs a single non-streaming chat completion (`POST /api/chat`).
    ///
    /// The optional system prompt is sent as a `system` message followed by
    /// the user prompt as a `user` message.
    ///
    /// # Errors
    ///
    /// * a network error if the request cannot be sent,
    /// * a parse error if a successful response cannot be decoded,
    /// * [`LlmErrorKind::ModelNotFound`] on HTTP 404,
    /// * a provider error carrying status and body for any other status.
    fn complete(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError> {
        let url = format!("{}/api/chat", self.base_url);

        let mut messages = Vec::with_capacity(2);
        if let Some(system) = &request.system {
            messages.push(OllamaMessage {
                role: "system",
                content: system.clone(),
            });
        }
        messages.push(OllamaMessage {
            role: "user",
            content: request.prompt.clone(),
        });

        // Clamp instead of casting: a wrapping cast would turn a very large
        // limit into a negative `num_predict`, which is not a token limit.
        let num_predict = i32::try_from(request.max_tokens).unwrap_or(i32::MAX);
        let stop =
            (!request.stop_sequences.is_empty()).then(|| request.stop_sequences.clone());

        let body = OllamaChatRequest {
            model: &self.model,
            messages,
            stream: false,
            options: Some(OllamaOptions {
                temperature: request.temperature,
                num_predict: Some(num_predict),
                stop,
            }),
        };

        let response = self
            .client
            .post(&url)
            .json(&body)
            .send()
            .map_err(|e| LlmError::network(format!("Ollama request failed: {e}")))?;

        let status = response.status();
        if !status.is_success() {
            let body = response.text().unwrap_or_default();
            return Err(if status.as_u16() == 404 {
                LlmError {
                    kind: LlmErrorKind::ModelNotFound,
                    message: format!("Model '{}' not found", self.model),
                    retryable: false,
                }
            } else {
                LlmError::provider(format!("Ollama returned status {status}: {body}"))
            });
        }

        let ollama_response: OllamaChatResponse = response
            .json()
            .map_err(|e| LlmError::parse(format!("Failed to parse response: {e}")))?;

        // Missing counters are reported as zero.
        let prompt_tokens = ollama_response.prompt_eval_count.unwrap_or(0);
        let completion_tokens = ollama_response.eval_count.unwrap_or(0);

        // NOTE(review): only `done` is inspected here. Ollama also appears to
        // report a `done_reason` ("stop" / "length"); confirm whether
        // truncation should be detected from that field instead.
        let finish_reason = if ollama_response.done {
            FinishReason::Stop
        } else {
            FinishReason::MaxTokens
        };

        Ok(LlmResponse {
            content: ollama_response.message.content,
            model: ollama_response.model,
            usage: TokenUsage {
                prompt_tokens,
                completion_tokens,
                // Saturate rather than overflow on pathological counters.
                total_tokens: prompt_tokens.saturating_add(completion_tokens),
            },
            finish_reason,
        })
    }
}
impl Embedding for OllamaProvider {
    /// Provider identifier; always `"ollama"`.
    fn name(&self) -> &'static str {
        "ollama"
    }

    /// Only text inputs are supported.
    fn modalities(&self) -> Vec<Modality> {
        vec![Modality::Text]
    }

    /// Dimension count reported when no embedding has been produced yet.
    fn default_dimensions(&self) -> usize {
        768
    }

    /// Embeds every input in `request`, one HTTP call per input
    /// (`POST /api/embeddings`), preserving input order.
    ///
    /// `dimensions` in the response is the length of the first returned
    /// vector, or `0` when there were no inputs. Token usage is always
    /// reported as `0`.
    ///
    /// # Errors
    ///
    /// * an unsupported-modality error for the first non-text input,
    /// * a network error if a request cannot be sent,
    /// * a provider error if a response cannot be decoded or the server
    ///   answers with a non-success status.
    ///
    /// Processing stops at the first error; earlier results are discarded.
    fn embed(&self, request: &EmbedRequest) -> Result<EmbedResponse, CapabilityError> {
        let url = format!("{}/api/embeddings", self.base_url);
        let mut embeddings = Vec::with_capacity(request.inputs.len());

        for input in &request.inputs {
            // Borrow the text; the request body only needs a `&str`.
            let text: &str = match input {
                EmbedInput::Text(t) => t,
                other => {
                    return Err(CapabilityError::unsupported_modality(other.modality()));
                }
            };

            let body = OllamaEmbedRequest {
                model: &self.model,
                prompt: text,
            };

            let response = self
                .client
                .post(&url)
                .json(&body)
                .send()
                .map_err(|e| CapabilityError::network(format!("Embedding request failed: {e}")))?;

            let status = response.status();
            if !status.is_success() {
                let body = response.text().unwrap_or_default();
                return Err(CapabilityError {
                    kind: CapabilityErrorKind::ProviderError,
                    message: format!("Ollama embedding failed with status {status}: {body}"),
                    retryable: false,
                });
            }

            let embed_response: OllamaEmbedResponse =
                response.json().map_err(|e| CapabilityError {
                    kind: CapabilityErrorKind::ProviderError,
                    message: format!("Failed to parse embedding response: {e}"),
                    retryable: false,
                })?;
            embeddings.push(embed_response.embedding);
        }

        let dimensions = embeddings.first().map_or(0, Vec::len);
        Ok(EmbedResponse {
            embeddings,
            model: self.model.clone(),
            dimensions,
            usage: Some(EmbedUsage { total_tokens: 0 }),
        })
    }
}
/// JSON body posted to `/api/chat` by `complete`.
#[derive(Serialize)]
struct OllamaChatRequest<'a> {
// Model name, borrowed from the provider.
model: &'a str,
// Conversation messages in send order.
messages: Vec<OllamaMessage>,
// `complete` always sends `false` here.
stream: bool,
// Omitted from the JSON entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
options: Option<OllamaOptions>,
}
/// One chat message in an [`OllamaChatRequest`].
#[derive(Serialize)]
struct OllamaMessage {
// Speaker role; this file only sends "system" and "user".
role: &'static str,
// Message text.
content: String,
}
/// Generation options sent inside an [`OllamaChatRequest`].
#[derive(Serialize)]
struct OllamaOptions {
// Sampling temperature, copied from the request.
temperature: f64,
// Filled from the request's `max_tokens`; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
num_predict: Option<i32>,
// Stop sequences; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
stop: Option<Vec<String>>,
}
/// Subset of the `/api/chat` response body that `complete` reads.
#[derive(Deserialize)]
struct OllamaChatResponse {
// Model name echoed by the server.
model: String,
// The generated message.
message: OllamaResponseMessage,
// Completion flag; `complete` maps `true` to `FinishReason::Stop`.
done: bool,
// Used as the prompt token count; `None` when absent from the response.
#[serde(default)]
prompt_eval_count: Option<u32>,
// Used as the completion token count; `None` when absent from the response.
#[serde(default)]
eval_count: Option<u32>,
}
/// Message object nested in an [`OllamaChatResponse`]; only its text is read.
#[derive(Deserialize)]
struct OllamaResponseMessage {
// Generated text returned to the caller as the completion content.
content: String,
}
/// JSON body posted to `/api/embeddings` for a single text input.
#[derive(Serialize)]
struct OllamaEmbedRequest<'a> {
// Model name, borrowed from the provider.
model: &'a str,
// Text to embed.
prompt: &'a str,
}
/// Subset of the `/api/embeddings` response body that `embed` reads.
#[derive(Deserialize)]
struct OllamaEmbedResponse {
// Embedding vector for the single submitted prompt.
embedding: Vec<f32>,
}
/// Model metadata decoded from the `/api/show` response in `health_check`.
///
/// `modelfile` and `details` are required for decoding to succeed;
/// `parameters` and `template` may be absent.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelInfo {
/// Value of the response's `modelfile` field.
pub modelfile: String,
/// Value of the response's `parameters` field, if present.
pub parameters: Option<String>,
/// Value of the response's `template` field, if present.
pub template: Option<String>,
/// Nested `details` object.
pub details: ModelDetails,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelDetails {
pub format: String,
#[serde(default)]
pub families: Vec<String>,
pub parameter_size: Option<String>,
pub quantization_level: Option<String>,
}
/// Envelope of the `/api/tags` response decoded by `list_models`.
#[derive(Deserialize)]
struct ModelList {
// Entries returned to the caller of `list_models`.
models: Vec<ModelListEntry>,
}
/// One model entry returned by `list_models`.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelListEntry {
/// Model name as reported by the server.
pub name: String,
/// Value of the `size` field (presumably bytes on disk; confirm against the Ollama API).
pub size: u64,
/// Value of the `modified_at` field, kept as the raw string from the response.
pub modified_at: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider_api::LlmProvider;

    /// The provider reports its fixed name and the configured model.
    #[test]
    fn provider_name_and_model() {
        let provider = OllamaProvider::new("qwen2.5:7b");
        assert_eq!(LlmProvider::name(&provider), "ollama");
        assert_eq!(provider.model(), "qwen2.5:7b");
    }

    /// A custom base URL is stored as given.
    #[test]
    fn custom_url() {
        let provider = OllamaProvider::with_url("http://gpu-server:11434", "llama3.2:8b");
        assert_eq!(provider.base_url, "http://gpu-server:11434");
    }

    /// Only the text modality is advertised.
    #[test]
    fn embedding_modalities() {
        let provider = OllamaProvider::new("nomic-embed-text");
        let modalities = Embedding::modalities(&provider);
        assert_eq!(modalities, vec![Modality::Text]);
    }

    /// Well-known embedding model names are recognised by
    /// `supports_embedding`. The name match returns before any health check,
    /// so no server is needed.
    #[test]
    fn known_embedding_models_detected() {
        let embedding_models = [
            "nomic-embed-text",
            "nomic-embed-text:latest",
            "mxbai-embed-large",
            "bge-m3:latest",
        ];
        for model in embedding_models {
            let provider = OllamaProvider::new(model);
            assert!(
                provider.supports_embedding(),
                "expected '{model}' to be detected as an embedding model"
            );
        }
    }
}