use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;
use crate::error::{LlmError, Result};
use crate::traits::{
ChatMessage, ChatRole, CompletionOptions, EmbeddingProvider, FunctionCall, LLMProvider,
LLMResponse, StreamChunk as TraitStreamChunk, ToolCall, ToolChoice, ToolDefinition,
};
/// Default Ollama server address used when no host is configured.
const DEFAULT_OLLAMA_HOST: &str = "http://localhost:11434";
/// Default chat/completion model tag.
const DEFAULT_OLLAMA_MODEL: &str = "gemma3:12b";
/// Default embedding model tag.
const DEFAULT_OLLAMA_EMBEDDING_MODEL: &str = "embeddinggemma:latest";
/// LLM + embedding provider backed by a local (or remote) Ollama server.
///
/// Construct via [`OllamaProviderBuilder`], [`OllamaProvider::from_env`], or
/// [`OllamaProvider::default_local`].
#[derive(Debug, Clone)]
pub struct OllamaProvider {
    // Shared reqwest client, reused across all requests.
    client: Client,
    // Base URL of the Ollama server, e.g. "http://localhost:11434".
    host: String,
    // Model tag used for chat/completion requests.
    model: String,
    // Model tag used for embedding requests.
    embedding_model: String,
    // Context window (num_ctx) sent with every chat request.
    max_context_length: usize,
    // Dimensionality reported for the embedding model.
    embedding_dimension: usize,
}
/// Builder for [`OllamaProvider`]; every field has a sensible local default.
#[derive(Debug, Clone)]
pub struct OllamaProviderBuilder {
    host: String,
    model: String,
    embedding_model: String,
    max_context_length: usize,
    embedding_dimension: usize,
}
impl Default for OllamaProviderBuilder {
fn default() -> Self {
Self {
host: DEFAULT_OLLAMA_HOST.to_string(),
model: DEFAULT_OLLAMA_MODEL.to_string(),
embedding_model: DEFAULT_OLLAMA_EMBEDDING_MODEL.to_string(),
max_context_length: 131072, embedding_dimension: 768, }
}
}
impl OllamaProviderBuilder {
    /// Starts from [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the base URL of the Ollama server.
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Sets the chat/completion model tag.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Sets the embedding model tag.
    pub fn embedding_model(mut self, model: impl Into<String>) -> Self {
        self.embedding_model = model.into();
        self
    }

    /// Sets the context window (num_ctx) sent with chat requests.
    pub fn max_context_length(mut self, length: usize) -> Self {
        self.max_context_length = length;
        self
    }

    /// Sets the dimensionality reported for embeddings.
    pub fn embedding_dimension(mut self, dimension: usize) -> Self {
        self.embedding_dimension = dimension;
        self
    }

    /// Finalizes the builder into an [`OllamaProvider`].
    ///
    /// # Errors
    /// Returns [`LlmError::NetworkError`] if the HTTP client cannot be built.
    pub fn build(self) -> Result<OllamaProvider> {
        // 5-minute timeout accommodates slow local generation; proxies are
        // bypassed since the server is typically on localhost.
        let http = Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .no_proxy()
            .build()
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;
        let Self {
            host,
            model,
            embedding_model,
            max_context_length,
            embedding_dimension,
        } = self;
        Ok(OllamaProvider {
            client: http,
            host,
            model,
            embedding_model,
            max_context_length,
            embedding_dimension,
        })
    }
}
impl OllamaProvider {
    /// Builds a provider from `OLLAMA_HOST`, `OLLAMA_MODEL`,
    /// `OLLAMA_EMBEDDING_MODEL`, and `OLLAMA_CONTEXT_LENGTH`, falling back
    /// to the crate defaults for any unset or unparsable variable.
    pub fn from_env() -> Result<Self> {
        let env_or = |key: &str, fallback: &str| {
            std::env::var(key).unwrap_or_else(|_| fallback.to_string())
        };
        let ctx = std::env::var("OLLAMA_CONTEXT_LENGTH")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok())
            .unwrap_or(131072);
        Self::builder()
            .host(env_or("OLLAMA_HOST", DEFAULT_OLLAMA_HOST))
            .model(env_or("OLLAMA_MODEL", DEFAULT_OLLAMA_MODEL))
            .embedding_model(env_or(
                "OLLAMA_EMBEDDING_MODEL",
                DEFAULT_OLLAMA_EMBEDDING_MODEL,
            ))
            .max_context_length(ctx)
            .build()
    }

    /// Returns a fresh [`OllamaProviderBuilder`].
    pub fn builder() -> OllamaProviderBuilder {
        OllamaProviderBuilder::new()
    }

    /// Provider with all defaults: local server, default models.
    pub fn default_local() -> Result<Self> {
        Self::builder().build()
    }
}
/// Request body for Ollama's `POST /api/chat` endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    // `false` requests a single JSON response; `true` requests NDJSON streaming.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<ChatOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OllamaTool>>,
    // Requests a separate reasoning trace from thinking-capable models.
    #[serde(skip_serializing_if = "Option::is_none")]
    think: Option<bool>,
}

/// Single chat turn in the Ollama wire format.
#[derive(Debug, Serialize)]
struct OllamaMessage {
    role: String,
    content: String,
    // Base64 image payloads; omitted from JSON entirely when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    images: Option<Vec<String>>,
}

/// Generation options nested under `options` in a chat request.
#[derive(Debug, Serialize)]
struct ChatOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Maximum tokens to generate (`num_predict` in Ollama terms).
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    // Context window size to allocate for the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    num_ctx: Option<usize>,
}

/// Non-streaming response body from `POST /api/chat`.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct ChatResponse {
    model: String,
    message: ResponseMessage,
    done: bool,
    #[serde(default)]
    total_duration: u64,
    // Prompt token count as reported by the server (0 when absent).
    #[serde(default)]
    prompt_eval_count: u32,
    // Generated token count as reported by the server (0 when absent).
    #[serde(default)]
    eval_count: u32,
}
/// Assistant message payload inside a chat response or stream chunk.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[allow(dead_code)]
    role: String,
    // Final answer text.
    content: String,
    // Reasoning trace emitted by thinking-capable models (e.g. deepseek-r1).
    #[serde(default)]
    thinking: Option<String>,
    // Tool invocations requested by the model, if any.
    // Was `skip_serializing_if`, a serialize-only attribute that is inert on
    // this Deserialize-only type; `default` matches the sibling fields.
    #[serde(default)]
    tool_calls: Option<Vec<OllamaToolCall>>,
}
/// One NDJSON line of a streaming `POST /api/chat` response.
#[derive(Debug, Deserialize)]
struct OllamaStreamChunk {
    #[serde(default)]
    message: Option<ResponseMessage>,
    // `true` on the final chunk of a stream.
    #[allow(dead_code)]
    done: bool,
}

/// Tool definition in Ollama's wire format (always `type: "function"` here).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaTool {
    pub r#type: String,
    pub function: OllamaFunction,
}

/// Function schema advertised to the model as a callable tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaFunction {
    pub name: String,
    pub description: String,
    // JSON Schema describing the function's parameters.
    pub parameters: serde_json::Value,
}

/// Tool invocation emitted by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaToolCall {
    pub function: OllamaFunctionCall,
}

/// Name plus structured arguments of a model-requested function call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaFunctionCall {
    pub name: String,
    // Arguments arrive as a JSON value, not a pre-serialized string.
    pub arguments: serde_json::Value,
}

/// Request body for `POST /api/embed`.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    model: String,
    input: Vec<String>,
}

/// Response body for `POST /api/embed`: one vector per input string.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    embeddings: Vec<Vec<f32>>,
}

/// Response body for `GET /api/tags` (installed models).
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaModelsResponse {
    pub models: Vec<OllamaModelInfo>,
}

/// Detailed metadata for an installed model, as reported by `/api/tags`.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaModelDetails {
    #[serde(default)]
    pub parent_model: String,
    #[serde(default)]
    pub format: String,
    #[serde(default)]
    pub family: String,
    #[serde(default)]
    pub families: Vec<String>,
    #[serde(default)]
    pub parameter_size: String,
    #[serde(default)]
    pub quantization_level: String,
}

/// Summary entry for an installed model, as reported by `/api/tags`.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaModelInfo {
    pub name: String,
    #[serde(default)]
    pub model: String,
    #[serde(default)]
    pub modified_at: String,
    // On-disk size in bytes.
    #[serde(default)]
    pub size: u64,
    #[serde(default)]
    pub digest: String,
    #[serde(default)]
    pub details: Option<OllamaModelDetails>,
}
impl OllamaProvider {
    /// Maps a [`ChatRole`] onto the role string Ollama understands.
    /// Tool/function results are sent as "user" turns, like user messages.
    fn convert_role(role: &ChatRole) -> &'static str {
        match role {
            ChatRole::System => "system",
            ChatRole::Assistant => "assistant",
            ChatRole::User | ChatRole::Tool | ChatRole::Function => "user",
        }
    }

    /// Converts trait-level chat messages into the Ollama wire format,
    /// attaching image payloads only when present and non-empty.
    fn convert_messages(messages: &[ChatMessage]) -> Vec<OllamaMessage> {
        let mut converted = Vec::with_capacity(messages.len());
        for msg in messages {
            let images = match msg.images.as_ref() {
                Some(imgs) if !imgs.is_empty() => {
                    Some(imgs.iter().map(|img| img.data.clone()).collect::<Vec<_>>())
                }
                _ => None,
            };
            converted.push(OllamaMessage {
                role: Self::convert_role(&msg.role).to_string(),
                content: msg.content.clone(),
                images,
            });
        }
        converted
    }

    /// Converts trait tool definitions into Ollama's function-tool schema.
    fn convert_tools(tools: &[ToolDefinition]) -> Vec<OllamaTool> {
        let mut out = Vec::with_capacity(tools.len());
        for tool in tools {
            out.push(OllamaTool {
                r#type: "function".to_string(),
                function: OllamaFunction {
                    name: tool.function.name.clone(),
                    description: tool.function.description.clone(),
                    parameters: tool.function.parameters.clone(),
                },
            });
        }
        out
    }

    /// Queries `GET /api/tags` and returns the models installed on the server.
    ///
    /// # Errors
    /// [`LlmError::NetworkError`] on transport failure; [`LlmError::ApiError`]
    /// on a non-success status or an unparsable body.
    pub async fn list_models(&self) -> Result<OllamaModelsResponse> {
        let url = format!("{}/api/tags", self.host);
        debug!(url = %url, "Fetching Ollama models list");
        let resp = self
            .client
            .get(&url)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;
        let status = resp.status();
        if !status.is_success() {
            let error_text = resp.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "Ollama API error ({}): {}",
                status, error_text
            )));
        }
        let models: OllamaModelsResponse = resp
            .json()
            .await
            .map_err(|e| LlmError::ApiError(format!("Failed to parse models response: {}", e)))?;
        debug!("Ollama returned {} models", models.models.len());
        Ok(models)
    }

    /// Base URL of the configured Ollama server.
    pub fn host(&self) -> &str {
        self.host.as_str()
    }

    /// Whether `model` belongs to a family known to emit a separate
    /// thinking/reasoning trace; such models get `think: true` in requests.
    /// Matching is case-insensitive substring search on the model tag.
    fn is_thinking_model(model: &str) -> bool {
        const THINKING_FAMILIES: [&str; 8] = [
            "deepseek-r1",
            "qwen3",
            "qwq",
            "openthinker",
            "phi4-reasoning",
            "magistral",
            "cogito",
            "gpt-oss",
        ];
        let needle = model.to_lowercase();
        THINKING_FAMILIES
            .iter()
            .any(|family| needle.contains(*family))
    }
}
#[async_trait]
impl LLMProvider for OllamaProvider {
    /// Provider identifier used in logs and registries.
    fn name(&self) -> &str {
        "ollama"
    }

    /// Chat/completion model tag this provider targets.
    fn model(&self) -> &str {
        &self.model
    }

    /// Context window advertised to callers and sent as `num_ctx`.
    fn max_context_length(&self) -> usize {
        self.max_context_length
    }

    /// Single-prompt completion using default options.
    async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
        self.complete_with_options(prompt, &CompletionOptions::default())
            .await
    }

    /// Wraps `prompt` (plus the optional system prompt) into a message list
    /// and delegates to [`Self::chat`].
    async fn complete_with_options(
        &self,
        prompt: &str,
        options: &CompletionOptions,
    ) -> Result<LLMResponse> {
        let mut messages = Vec::new();
        if let Some(system) = &options.system_prompt {
            messages.push(ChatMessage::system(system));
        }
        messages.push(ChatMessage::user(prompt));
        self.chat(&messages, Some(options)).await
    }

    /// Non-streaming chat via `POST /api/chat`. Thinking-capable models get
    /// `think: true` so the reasoning trace comes back separately.
    ///
    /// # Errors
    /// [`LlmError::NetworkError`] on transport failure; [`LlmError::ApiError`]
    /// on a non-success status or an unparsable body.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse> {
        debug!(
            "Ollama chat request: {} messages to model {}",
            messages.len(),
            self.model
        );
        let url = format!("{}/api/chat", self.host);
        let opts = options.cloned().unwrap_or_default();
        let chat_options = ChatOptions {
            temperature: opts.temperature,
            num_predict: opts.max_tokens.map(|t| t as i32),
            stop: opts.stop.clone(),
            num_ctx: Some(self.max_context_length),
        };
        let think = Self::is_thinking_model(&self.model);
        let request = ChatRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            stream: false,
            options: Some(chat_options),
            tools: None,
            // `think` is omitted (not `Some(false)`) for non-thinking models.
            think: if think { Some(true) } else { None },
        };
        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "Ollama API error ({}): {}",
                status, error_text
            )));
        }
        let response: ChatResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
        let mut llm_response = LLMResponse::new(response.message.content, response.model)
            .with_usage(
                response.prompt_eval_count as usize,
                response.eval_count as usize,
            );
        if let Some(thinking) = &response.message.thinking {
            if !thinking.is_empty() {
                llm_response = llm_response.with_thinking_content(thinking.clone());
                // Rough heuristic: ~4 characters per token.
                let thinking_tokens = thinking.len() / 4;
                llm_response = llm_response.with_thinking_tokens(thinking_tokens);
            }
        }
        Ok(llm_response)
    }

    /// Streaming completion: posts a single user message with `stream: true`
    /// and maps each NDJSON body chunk to its concatenated content text.
    async fn stream(&self, prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
        use futures::StreamExt;
        debug!("Ollama stream request: prompt to model {}", self.model);
        let url = format!("{}/api/chat", self.host);
        let chat_options = ChatOptions {
            temperature: None,
            num_predict: None,
            stop: None,
            num_ctx: Some(self.max_context_length),
        };
        let think = Self::is_thinking_model(&self.model);
        let request = ChatRequest {
            model: self.model.clone(),
            messages: vec![OllamaMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
                images: None,
            }],
            stream: true,
            options: Some(chat_options),
            tools: None,
            think: if think { Some(true) } else { None },
        };
        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "Ollama API error ({}): {}",
                status, error_text
            )));
        }
        let stream = response.bytes_stream();
        // NOTE(review): each network chunk is parsed independently, so an
        // NDJSON line split across two chunks fails to parse and its content
        // is silently dropped. A carry-over line buffer would fix this —
        // confirm whether chunk boundaries always align with lines here.
        let mapped_stream = stream.map(|chunk_result| {
            match chunk_result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    let mut content = String::new();
                    for line in text.lines() {
                        if line.is_empty() {
                            continue;
                        }
                        // Unparsable lines are skipped rather than surfaced.
                        if let Ok(chunk) = serde_json::from_str::<OllamaStreamChunk>(line) {
                            if let Some(msg) = chunk.message {
                                content.push_str(&msg.content);
                            }
                        }
                    }
                    Ok(content)
                }
                Err(e) => Err(LlmError::NetworkError(e.to_string())),
            }
        });
        Ok(mapped_stream.boxed())
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_function_calling(&self) -> bool {
        true
    }

    fn supports_tool_streaming(&self) -> bool {
        true
    }

    /// Non-streaming chat with tool definitions attached.
    ///
    /// `tool_choice` is accepted for interface parity but not forwarded —
    /// the request carries no equivalent field. Ollama tool calls have no
    /// ids, so a fresh UUID is generated for each.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        _tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse> {
        let url = format!("{}/api/chat", self.host);
        let opts = options.cloned().unwrap_or_default();
        let chat_options = ChatOptions {
            temperature: opts.temperature,
            num_predict: opts.max_tokens.map(|t| t as i32),
            stop: opts.stop.clone(),
            num_ctx: Some(self.max_context_length),
        };
        let ollama_tools = if !tools.is_empty() {
            Some(Self::convert_tools(tools))
        } else {
            None
        };
        let think = Self::is_thinking_model(&self.model);
        let request = ChatRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            stream: false,
            options: Some(chat_options),
            tools: ollama_tools,
            think: if think { Some(true) } else { None },
        };
        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "Ollama API error ({}): {}",
                status, error_text
            )));
        }
        let response: ChatResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
        let tool_calls: Vec<ToolCall> = response
            .message
            .tool_calls
            .unwrap_or_default()
            .into_iter()
            .map(|tc| ToolCall {
                // Ollama does not supply call ids; synthesize one.
                id: uuid::Uuid::new_v4().to_string(),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: tc.function.name,
                    // Arguments arrive as a JSON value; re-serialize to the
                    // string form the trait expects.
                    arguments: serde_json::to_string(&tc.function.arguments).unwrap_or_default(),
                },
                thought_signature: None,
            })
            .collect();
        let mut llm_response = LLMResponse::new(response.message.content, response.model)
            .with_usage(
                response.prompt_eval_count as usize,
                response.eval_count as usize,
            )
            .with_tool_calls(tool_calls);
        if let Some(thinking) = &response.message.thinking {
            if !thinking.is_empty() {
                llm_response = llm_response.with_thinking_content(thinking.clone());
                // Rough heuristic: ~4 characters per token.
                let thinking_tokens = thinking.len() / 4;
                llm_response = llm_response.with_thinking_tokens(thinking_tokens);
            }
        }
        Ok(llm_response)
    }

    /// Streaming chat with tools: maps each NDJSON body chunk to at most one
    /// [`TraitStreamChunk`] (thinking > tool call > content > done, in that
    /// priority order).
    async fn chat_with_tools_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        _tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> Result<BoxStream<'static, Result<TraitStreamChunk>>> {
        use futures::StreamExt;
        let url = format!("{}/api/chat", self.host);
        let opts = options.cloned().unwrap_or_default();
        let chat_options = ChatOptions {
            temperature: opts.temperature,
            num_predict: opts.max_tokens.map(|t| t as i32),
            stop: opts.stop.clone(),
            num_ctx: Some(self.max_context_length),
        };
        let ollama_tools = if !tools.is_empty() {
            Some(Self::convert_tools(tools))
        } else {
            None
        };
        let think = Self::is_thinking_model(&self.model);
        let request = ChatRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            stream: true,
            options: Some(chat_options),
            tools: ollama_tools,
            think: if think { Some(true) } else { None },
        };
        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "Ollama API error ({}): {}",
                status, error_text
            )));
        }
        let stream = response.bytes_stream();
        // NOTE(review): the early `return`s below emit only the FIRST
        // interesting line per network chunk; any later NDJSON lines in the
        // same chunk are discarded. Also shares the split-line hazard noted
        // in `stream()`. A stateful line-buffering flat_map would fix both.
        let mapped_stream = stream.map(|chunk_result| {
            match chunk_result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    for line in text.lines() {
                        if line.is_empty() {
                            continue;
                        }
                        if let Ok(chunk) = serde_json::from_str::<OllamaStreamChunk>(line) {
                            if let Some(msg) = chunk.message {
                                if let Some(thinking) = &msg.thinking {
                                    if !thinking.is_empty() {
                                        // Rough heuristic: ~4 characters per token.
                                        let tokens_used = thinking.len() / 4;
                                        return Ok(TraitStreamChunk::ThinkingContent {
                                            text: thinking.clone(),
                                            tokens_used: Some(tokens_used),
                                            budget_total: None,
                                        });
                                    }
                                }
                                // Only the first tool call of a chunk is forwarded.
                                if let Some(tool_calls) = msg.tool_calls {
                                    if let Some(tc) = tool_calls.first() {
                                        return Ok(TraitStreamChunk::ToolCallDelta {
                                            index: 0,
                                            id: Some(uuid::Uuid::new_v4().to_string()),
                                            function_name: Some(tc.function.name.clone()),
                                            function_arguments: serde_json::to_string(
                                                &tc.function.arguments,
                                            )
                                            .ok(),
                                            thought_signature: None,
                                        });
                                    }
                                }
                                if !msg.content.is_empty() {
                                    return Ok(TraitStreamChunk::Content(msg.content));
                                }
                            }
                            if chunk.done {
                                return Ok(TraitStreamChunk::Finished {
                                    reason: "stop".to_string(),
                                    ttft_ms: None,
                                });
                            }
                        }
                    }
                    // Nothing recognizable in this chunk: emit an empty
                    // content delta so the stream keeps flowing.
                    Ok(TraitStreamChunk::Content(String::new()))
                }
                Err(e) => Err(LlmError::NetworkError(e.to_string())),
            }
        });
        Ok(mapped_stream.boxed())
    }
}
#[async_trait]
impl EmbeddingProvider for OllamaProvider {
    /// Provider identifier used in logs and registries.
    fn name(&self) -> &str {
        "ollama"
    }

    /// Returns the embedding model tag (not the chat model).
    #[allow(clippy::misnamed_getters)]
    fn model(&self) -> &str {
        self.embedding_model.as_str()
    }

    /// Configured embedding dimensionality.
    fn dimension(&self) -> usize {
        self.embedding_dimension
    }

    /// Fixed per-input token ceiling for embedding requests.
    fn max_tokens(&self) -> usize {
        8192
    }

    /// Embeds `texts` via `POST /api/embed`, returning one vector per input.
    /// An empty slice short-circuits without touching the network.
    ///
    /// # Errors
    /// [`LlmError::NetworkError`] on transport failure; [`LlmError::ApiError`]
    /// on a non-success status or an unparsable body.
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        debug!(
            "Ollama embedding request: {} texts with model {}",
            texts.len(),
            self.embedding_model
        );
        let endpoint = format!("{}/api/embed", self.host);
        let payload = EmbeddingRequest {
            model: self.embedding_model.clone(),
            input: texts.to_vec(),
        };
        let http_response = self
            .client
            .post(&endpoint)
            .json(&payload)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;
        let status = http_response.status();
        if !status.is_success() {
            let error_text = http_response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "Ollama API error ({}): {}",
                status, error_text
            )));
        }
        let parsed: EmbeddingResponse = http_response
            .json()
            .await
            .map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
        debug!(
            "Ollama embedding response: {} embeddings",
            parsed.embeddings.len()
        );
        Ok(parsed.embeddings)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests: builder wiring, message/tool conversion, thinking-model
    //! detection, and serde shapes. No test talks to a live Ollama server.
    use super::*;

    #[test]
    fn test_builder_creation() {
        let provider = OllamaProviderBuilder::new()
            .host("http://localhost:11434")
            .model("mistral")
            .embedding_model("nomic-embed-text")
            .build()
            .unwrap();
        assert_eq!(LLMProvider::name(&provider), "ollama");
        assert_eq!(LLMProvider::model(&provider), "mistral");
        assert_eq!(EmbeddingProvider::model(&provider), "nomic-embed-text");
    }

    #[test]
    fn test_default_builder() {
        let provider = OllamaProviderBuilder::new().build().unwrap();
        assert_eq!(LLMProvider::name(&provider), "ollama");
        assert_eq!(LLMProvider::model(&provider), "gemma3:12b");
        assert_eq!(EmbeddingProvider::model(&provider), "embeddinggemma:latest");
        assert_eq!(provider.max_context_length(), 131072);
    }

    #[test]
    fn test_message_conversion() {
        let messages = vec![
            ChatMessage::system("You are helpful"),
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
        ];
        let converted = OllamaProvider::convert_messages(&messages);
        assert_eq!(converted.len(), 3);
        assert_eq!(converted[0].role, "system");
        assert_eq!(converted[1].role, "user");
        assert_eq!(converted[2].role, "assistant");
    }

    // Empty input must return Ok(vec![]) without any network call.
    #[tokio::test]
    async fn test_embed_empty_input() {
        let provider = OllamaProviderBuilder::new().build().unwrap();
        let result = provider.embed(&[]).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_is_thinking_model_deepseek_r1() {
        assert!(OllamaProvider::is_thinking_model("deepseek-r1:8b"));
        assert!(OllamaProvider::is_thinking_model("deepseek-r1:70b"));
        // Matching is case-insensitive.
        assert!(OllamaProvider::is_thinking_model("DEEPSEEK-R1:latest"));
    }

    #[test]
    fn test_is_thinking_model_qwen3() {
        assert!(OllamaProvider::is_thinking_model("qwen3:8b"));
        assert!(OllamaProvider::is_thinking_model("qwen3:32b"));
        assert!(OllamaProvider::is_thinking_model("QWEN3:latest"));
    }

    #[test]
    fn test_is_thinking_model_others() {
        assert!(OllamaProvider::is_thinking_model("qwq:32b"));
        assert!(OllamaProvider::is_thinking_model("openthinker:7b"));
        assert!(OllamaProvider::is_thinking_model("phi4-reasoning:14b"));
        assert!(OllamaProvider::is_thinking_model("magistral:24b"));
        assert!(OllamaProvider::is_thinking_model("cogito:8b"));
        assert!(OllamaProvider::is_thinking_model("gpt-oss:20b"));
    }

    #[test]
    fn test_is_thinking_model_non_thinking() {
        assert!(!OllamaProvider::is_thinking_model("llama3.2:8b"));
        assert!(!OllamaProvider::is_thinking_model("gemma3:12b"));
        assert!(!OllamaProvider::is_thinking_model("mistral:7b"));
        assert!(!OllamaProvider::is_thinking_model("codellama:34b"));
    }

    #[test]
    fn test_response_with_thinking_parsing() {
        let json = r#"{
            "role": "assistant",
            "content": "The answer is 3.",
            "thinking": "Let me count the r's in strawberry: s-t-r-a-w-b-e-r-r-y. That's 3 r's."
        }"#;
        let msg: ResponseMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.content, "The answer is 3.");
        assert!(msg.thinking.is_some());
        assert!(msg.thinking.unwrap().contains("Let me count"));
    }

    #[test]
    fn test_response_without_thinking_parsing() {
        let json = r#"{
            "role": "assistant",
            "content": "Hello, how can I help you?"
        }"#;
        let msg: ResponseMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.content, "Hello, how can I help you?");
        assert!(msg.thinking.is_none());
    }

    #[test]
    fn test_chat_request_with_think_serialization() {
        let request = ChatRequest {
            model: "deepseek-r1:8b".to_string(),
            messages: vec![OllamaMessage {
                role: "user".to_string(),
                content: "How many r's in strawberry?".to_string(),
                images: None,
            }],
            stream: false,
            options: None,
            tools: None,
            think: Some(true),
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"think\":true"));
    }

    #[test]
    fn test_chat_request_without_think_serialization() {
        let request = ChatRequest {
            model: "llama3.2:8b".to_string(),
            messages: vec![OllamaMessage {
                role: "user".to_string(),
                content: "Hello".to_string(),
                images: None,
            }],
            stream: false,
            options: None,
            tools: None,
            think: None,
        };
        let json = serde_json::to_string(&request).unwrap();
        // `think: None` must be omitted entirely, not serialized as null.
        assert!(!json.contains("think"));
    }

    #[test]
    fn test_chat_options_num_ctx_serialization() {
        let options = ChatOptions {
            temperature: None,
            num_predict: None,
            stop: None,
            num_ctx: Some(65536),
        };
        let json = serde_json::to_string(&options).unwrap();
        assert!(
            json.contains("\"num_ctx\":65536"),
            "num_ctx should be serialized in options: {}",
            json
        );
    }

    #[test]
    fn test_chat_options_num_ctx_omitted_when_none() {
        let options = ChatOptions {
            temperature: None,
            num_predict: None,
            stop: None,
            num_ctx: None,
        };
        let json = serde_json::to_string(&options).unwrap();
        assert!(
            !json.contains("num_ctx"),
            "num_ctx should be omitted when None: {}",
            json
        );
    }

    #[test]
    fn test_constants() {
        assert_eq!(DEFAULT_OLLAMA_HOST, "http://localhost:11434");
        assert_eq!(DEFAULT_OLLAMA_MODEL, "gemma3:12b");
        assert_eq!(DEFAULT_OLLAMA_EMBEDDING_MODEL, "embeddinggemma:latest");
    }

    #[test]
    fn test_builder_default_values() {
        let builder = OllamaProviderBuilder::default();
        assert_eq!(builder.host, "http://localhost:11434");
        assert_eq!(builder.model, "gemma3:12b");
        assert_eq!(builder.embedding_model, "embeddinggemma:latest");
        assert_eq!(builder.max_context_length, 131072);
        assert_eq!(builder.embedding_dimension, 768);
    }

    #[test]
    fn test_builder_custom_context_length() {
        let provider = OllamaProviderBuilder::new()
            .max_context_length(65536)
            .build()
            .unwrap();
        assert_eq!(provider.max_context_length(), 65536);
    }

    #[test]
    fn test_builder_custom_embedding_dimension() {
        let provider = OllamaProviderBuilder::new()
            .embedding_dimension(1536)
            .build()
            .unwrap();
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn test_default_local_creation() {
        let provider = OllamaProvider::default_local().unwrap();
        assert_eq!(LLMProvider::name(&provider), "ollama");
        assert_eq!(LLMProvider::model(&provider), "gemma3:12b");
        assert_eq!(provider.host, "http://localhost:11434");
    }

    #[test]
    fn test_supports_streaming() {
        let provider = OllamaProviderBuilder::new().build().unwrap();
        assert!(provider.supports_streaming());
    }

    #[test]
    fn test_supports_json_mode() {
        let provider = OllamaProviderBuilder::new().build().unwrap();
        assert!(!provider.supports_json_mode());
    }

    #[test]
    fn test_embedding_provider_name() {
        let provider = OllamaProviderBuilder::new().build().unwrap();
        assert_eq!(EmbeddingProvider::name(&provider), "ollama");
    }

    #[test]
    fn test_embedding_provider_dimension() {
        let provider = OllamaProviderBuilder::new()
            .embedding_dimension(1024)
            .build()
            .unwrap();
        assert_eq!(provider.dimension(), 1024);
    }

    #[test]
    fn test_embedding_provider_max_tokens() {
        let provider = OllamaProviderBuilder::new().build().unwrap();
        assert_eq!(provider.max_tokens(), 8192);
    }

    // Tool results have no dedicated Ollama role and are sent as "user".
    #[test]
    fn test_message_conversion_tool_role() {
        let messages = vec![ChatMessage::tool_result("tool-1", "Tool output")];
        let converted = OllamaProvider::convert_messages(&messages);
        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0].role, "user");
        assert_eq!(converted[0].content, "Tool output");
    }

    #[test]
    fn test_ollama_message_serialization() {
        let msg = OllamaMessage {
            role: "user".to_string(),
            content: "Hello world".to_string(),
            images: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"Hello world\""));
    }

    // NOTE(review): mutates process-wide env vars, which can race with other
    // tests running in parallel that read OLLAMA_* — consider serializing.
    #[test]
    fn test_from_env_uses_defaults() {
        std::env::remove_var("OLLAMA_HOST");
        std::env::remove_var("OLLAMA_MODEL");
        std::env::remove_var("OLLAMA_EMBEDDING_MODEL");
        std::env::remove_var("OLLAMA_CONTEXT_LENGTH");
        let provider = OllamaProvider::from_env().unwrap();
        assert_eq!(provider.host, "http://localhost:11434");
        assert_eq!(LLMProvider::model(&provider), "gemma3:12b");
        assert_eq!(EmbeddingProvider::model(&provider), "embeddinggemma:latest");
        assert_eq!(provider.max_context_length(), 131072);
    }

    #[test]
    fn test_chat_options_temperature_serialization() {
        let options = ChatOptions {
            temperature: Some(0.7),
            num_predict: Some(1024),
            stop: Some(vec!["END".to_string()]),
            num_ctx: Some(32768),
        };
        let json = serde_json::to_string(&options).unwrap();
        assert!(json.contains("\"temperature\":0.7"));
        assert!(json.contains("\"num_predict\":1024"));
        assert!(json.contains("\"stop\":[\"END\"]"));
        assert!(json.contains("\"num_ctx\":32768"));
    }

    #[test]
    fn test_convert_messages_with_image_populates_images_field() {
        use crate::traits::ImageData;
        let img = ImageData::new("base64abc", "image/png");
        let messages = vec![ChatMessage::user_with_images("describe this", vec![img])];
        let converted = OllamaProvider::convert_messages(&messages);
        assert_eq!(converted.len(), 1);
        let images = converted[0].images.as_ref().expect("images must be Some");
        assert_eq!(images.len(), 1);
        assert_eq!(
            images[0], "base64abc",
            "Should be raw base64 without data-URI prefix"
        );
    }

    #[test]
    fn test_convert_messages_without_image_omits_images_field() {
        let messages = vec![ChatMessage::user("no image here")];
        let converted = OllamaProvider::convert_messages(&messages);
        assert!(
            converted[0].images.is_none(),
            "images must be None for text-only messages"
        );
        let json = serde_json::to_string(&converted[0]).unwrap();
        assert!(
            !json.contains("\"images\""),
            "images key must not appear in JSON for text-only message"
        );
    }

    #[test]
    fn test_ollama_message_with_images_serialization() {
        let msg = OllamaMessage {
            role: "user".to_string(),
            content: "what is this?".to_string(),
            images: Some(vec!["base64data".to_string()]),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"images\":[\"base64data\"]"));
    }
}