use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;
use crate::error::{LlmError, Result};
use crate::traits::{
ChatMessage, ChatRole, CompletionOptions, EmbeddingProvider, FunctionCall, LLMProvider,
LLMResponse, StreamChunk as TraitStreamChunk, ToolCall, ToolChoice, ToolDefinition,
};
/// Default base URL of a locally running Ollama daemon.
const DEFAULT_OLLAMA_HOST: &str = "http://localhost:11434";
/// Default chat/completion model used when none is configured.
const DEFAULT_OLLAMA_MODEL: &str = "gemma3:12b";
/// Default embedding model used when none is configured.
const DEFAULT_OLLAMA_EMBEDDING_MODEL: &str = "embeddinggemma:latest";
/// LLM + embedding provider backed by an Ollama HTTP server (usually local).
#[derive(Debug, Clone)]
pub struct OllamaProvider {
// Shared reqwest client (built with a 300 s timeout and proxies disabled).
client: Client,
// Base URL, e.g. `http://localhost:11434`.
host: String,
// Chat/completion model name.
model: String,
// Embedding model name.
embedding_model: String,
// Context window reported to callers and sent to the server as `num_ctx`.
max_context_length: usize,
// Embedding vector size reported to callers (not validated against the server).
embedding_dimension: usize,
}
/// Fluent builder for [`OllamaProvider`]; see [`OllamaProviderBuilder::build`].
#[derive(Debug, Clone)]
pub struct OllamaProviderBuilder {
// Base URL of the Ollama server.
host: String,
// Chat/completion model name.
model: String,
// Embedding model name.
embedding_model: String,
// Context window (`num_ctx`) to advertise/request.
max_context_length: usize,
// Embedding dimensionality to report.
embedding_dimension: usize,
}
impl Default for OllamaProviderBuilder {
fn default() -> Self {
Self {
host: DEFAULT_OLLAMA_HOST.to_string(),
model: DEFAULT_OLLAMA_MODEL.to_string(),
embedding_model: DEFAULT_OLLAMA_EMBEDDING_MODEL.to_string(),
max_context_length: 131072, embedding_dimension: 768, }
}
}
impl OllamaProviderBuilder {
    /// Starts a builder seeded with the default local configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Overrides the Ollama base URL (e.g. `http://localhost:11434`).
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Overrides the chat/completion model name.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Overrides the embedding model name.
    pub fn embedding_model(mut self, model: impl Into<String>) -> Self {
        self.embedding_model = model.into();
        self
    }

    /// Overrides the context window (`num_ctx`) advertised to the server.
    pub fn max_context_length(mut self, length: usize) -> Self {
        self.max_context_length = length;
        self
    }

    /// Overrides the embedding dimensionality reported to callers.
    pub fn embedding_dimension(mut self, dimension: usize) -> Self {
        self.embedding_dimension = dimension;
        self
    }

    /// Finalizes the builder into a provider.
    ///
    /// # Errors
    /// Returns [`LlmError::NetworkError`] if the HTTP client cannot be
    /// constructed.
    pub fn build(self) -> Result<OllamaProvider> {
        // Generous timeout: local model generation can legitimately run for
        // minutes. `no_proxy` keeps localhost traffic off any env-configured
        // HTTP proxy.
        let http_client = Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .no_proxy()
            .build()
            .map_err(|err| LlmError::NetworkError(err.to_string()))?;

        Ok(OllamaProvider {
            client: http_client,
            host: self.host,
            model: self.model,
            embedding_model: self.embedding_model,
            max_context_length: self.max_context_length,
            embedding_dimension: self.embedding_dimension,
        })
    }
}
impl OllamaProvider {
    /// Builds a provider from `OLLAMA_*` environment variables, falling back
    /// to the compiled-in defaults for anything unset or unparsable.
    ///
    /// Recognized variables: `OLLAMA_HOST`, `OLLAMA_MODEL`,
    /// `OLLAMA_EMBEDDING_MODEL`, `OLLAMA_CONTEXT_LENGTH`.
    pub fn from_env() -> Result<Self> {
        // Small helper: env var value or the given fallback string.
        let var_or = |key: &str, fallback: &str| -> String {
            std::env::var(key).unwrap_or_else(|_| fallback.to_string())
        };
        // Non-numeric or missing context length silently falls back to 128K.
        let context_length = std::env::var("OLLAMA_CONTEXT_LENGTH")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok())
            .unwrap_or(131072);

        Self::builder()
            .host(var_or("OLLAMA_HOST", DEFAULT_OLLAMA_HOST))
            .model(var_or("OLLAMA_MODEL", DEFAULT_OLLAMA_MODEL))
            .embedding_model(var_or(
                "OLLAMA_EMBEDDING_MODEL",
                DEFAULT_OLLAMA_EMBEDDING_MODEL,
            ))
            .max_context_length(context_length)
            .build()
    }

    /// Returns a fresh builder for fine-grained configuration.
    pub fn builder() -> OllamaProviderBuilder {
        OllamaProviderBuilder::new()
    }

    /// Convenience constructor: a provider pointed at the default local daemon.
    pub fn default_local() -> Result<Self> {
        Self::builder().build()
    }
}
/// Request body for `POST /api/chat`.
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
messages: Vec<OllamaMessage>,
// `false` for a one-shot response, `true` for NDJSON streaming.
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
options: Option<ChatOptions>,
// `tools` (function definitions) and `think` (reasoning-model flag) share
// the source line below; both are omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<OllamaTool>>, #[serde(skip_serializing_if = "Option::is_none")]
think: Option<bool>,
// Set to the string `"json"` to request JSON-mode output.
#[serde(skip_serializing_if = "Option::is_none")]
format: Option<serde_json::Value>,
}
/// A single message in Ollama's chat wire format.
#[derive(Debug, Serialize)]
struct OllamaMessage {
role: String,
content: String,
// Raw base64 image payloads (no data-URI prefix).
#[serde(skip_serializing_if = "Option::is_none")]
images: Option<Vec<String>>,
// Present only on assistant messages that requested tool calls.
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<OllamaToolCall>>,
// Name of the tool that produced a `tool`-role result message.
#[serde(skip_serializing_if = "Option::is_none")]
tool_name: Option<String>,
}
/// Sampling/runtime options nested under `options` in the chat request.
#[derive(Debug, Serialize)]
struct ChatOptions {
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
// Maximum number of tokens to generate (the `max_tokens` equivalent).
#[serde(skip_serializing_if = "Option::is_none")]
num_predict: Option<i32>,
// Stop sequences that terminate generation.
#[serde(skip_serializing_if = "Option::is_none")]
stop: Option<Vec<String>>,
// Context window size to run the model with.
#[serde(skip_serializing_if = "Option::is_none")]
num_ctx: Option<usize>,
}
/// Non-streaming response body from `POST /api/chat`.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct ChatResponse {
model: String,
message: ResponseMessage,
done: bool,
// e.g. "stop"; absent on some server versions, hence `Option` + default.
#[serde(default)]
done_reason: Option<String>,
// Timing/usage counters default to 0 when the server omits them.
#[serde(default)]
total_duration: u64,
#[serde(default)]
prompt_eval_count: u32,
#[serde(default)]
eval_count: u32,
}
/// Message payload returned by Ollama inside a chat (or stream) response.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[allow(dead_code)]
    role: String,
    /// Visible assistant text (may be empty for pure tool-call turns).
    content: String,
    /// Chain-of-thought text emitted by thinking-capable models.
    #[serde(default)]
    thinking: Option<String>,
    /// Tool invocations requested by the model, if any.
    // This struct only derives `Deserialize`, so the previous
    // `skip_serializing_if` attribute was inert (it only affects
    // serialization); `default` is the intended deserialize-side behavior
    // and matches the sibling fields.
    #[serde(default)]
    tool_calls: Option<Vec<OllamaToolCall>>,
}
/// One NDJSON line of a streaming `/api/chat` response.
#[derive(Debug, Deserialize)]
struct OllamaStreamChunk {
// Incremental message delta; absent on the terminal chunk of some servers.
#[serde(default)]
message: Option<ResponseMessage>,
// `true` on the terminal chunk of the stream.
done: bool,
#[serde(default)]
done_reason: Option<String>,
}
/// Tool definition in Ollama's wire format: `{"type": "function", "function": …}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaTool {
pub r#type: String,
pub function: OllamaFunction,
}
/// Function schema attached to a tool definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaFunction {
pub name: String,
pub description: String,
// JSON Schema describing the function's parameters.
pub parameters: serde_json::Value,
}
/// A tool invocation as Ollama represents it (note: no call id on the wire).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaToolCall {
pub function: OllamaFunctionCall,
}
/// Name + structured arguments of a tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaFunctionCall {
pub name: String,
// Arguments as a JSON value (object), not a serialized string.
pub arguments: serde_json::Value,
}
/// Request body for `POST /api/embed` (batch embedding).
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
model: String,
input: Vec<String>,
}
/// Response body from `POST /api/embed`: one vector per input string.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
embeddings: Vec<Vec<f32>>,
}
/// Response body from `GET /api/tags`: the locally installed models.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaModelsResponse {
pub models: Vec<OllamaModelInfo>,
}
/// Detailed metadata for one installed model; all fields default to empty
/// when the server omits them.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaModelDetails {
#[serde(default)]
pub parent_model: String,
#[serde(default)]
pub format: String,
#[serde(default)]
pub family: String,
#[serde(default)]
pub families: Vec<String>,
#[serde(default)]
pub parameter_size: String,
#[serde(default)]
pub quantization_level: String,
}
/// One entry of the `/api/tags` model list.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaModelInfo {
pub name: String,
#[serde(default)]
pub model: String,
// RFC 3339 timestamp string as reported by the server.
#[serde(default)]
pub modified_at: String,
// On-disk size in bytes.
#[serde(default)]
pub size: u64,
#[serde(default)]
pub digest: String,
#[serde(default)]
pub details: Option<OllamaModelDetails>,
}
impl OllamaProvider {
    /// Maps the provider-agnostic [`ChatRole`] onto Ollama's role strings.
    ///
    /// Ollama has no `function` role, so [`ChatRole::Function`] is sent as
    /// `user`; [`ChatRole::Tool`] maps to the native `tool` role.
    fn convert_role(role: &ChatRole) -> &'static str {
        match role {
            ChatRole::System => "system",
            ChatRole::User => "user",
            ChatRole::Assistant => "assistant",
            ChatRole::Tool => "tool",
            ChatRole::Function => "user",
        }
    }

    /// Converts generic [`ChatMessage`]s into Ollama wire messages.
    ///
    /// * `images` is forwarded only when present and non-empty (raw base64).
    /// * `tool_calls` is forwarded only for assistant messages; an argument
    ///   string that is not valid JSON degrades to an empty JSON object.
    /// * `tool_name` is forwarded only for tool-result messages.
    fn convert_messages(messages: &[ChatMessage]) -> Vec<OllamaMessage> {
        messages
            .iter()
            .map(|msg| {
                let images = msg.images.as_ref().and_then(|imgs| {
                    if imgs.is_empty() {
                        None
                    } else {
                        Some(imgs.iter().map(|img| img.data.clone()).collect::<Vec<_>>())
                    }
                });
                let tool_calls = if msg.role == ChatRole::Assistant {
                    msg.tool_calls.as_ref().and_then(|tcs| {
                        if tcs.is_empty() {
                            None
                        } else {
                            Some(
                                tcs.iter()
                                    .map(|tc| OllamaToolCall {
                                        function: OllamaFunctionCall {
                                            name: tc.function.name.clone(),
                                            // Lazily build the fallback object so the
                                            // success path does not allocate it
                                            // (clippy::or_fun_call).
                                            arguments: serde_json::from_str(
                                                &tc.function.arguments,
                                            )
                                            .unwrap_or_else(|_| {
                                                serde_json::Value::Object(
                                                    serde_json::Map::new(),
                                                )
                                            }),
                                        },
                                    })
                                    .collect::<Vec<_>>(),
                            )
                        }
                    })
                } else {
                    None
                };
                let tool_name = if msg.role == ChatRole::Tool {
                    msg.name.clone()
                } else {
                    None
                };
                OllamaMessage {
                    role: Self::convert_role(&msg.role).to_string(),
                    content: msg.content.clone(),
                    images,
                    tool_calls,
                    tool_name,
                }
            })
            .collect()
    }

    /// Converts tool definitions into Ollama's
    /// `{"type": "function", "function": …}` shape.
    fn convert_tools(tools: &[ToolDefinition]) -> Vec<OllamaTool> {
        tools
            .iter()
            .map(|tool| OllamaTool {
                r#type: "function".to_string(),
                function: OllamaFunction {
                    name: tool.function.name.clone(),
                    description: tool.function.description.clone(),
                    parameters: tool.function.parameters.clone(),
                },
            })
            .collect()
    }

    /// Fetches the models installed on the server via `GET /api/tags`.
    ///
    /// # Errors
    /// [`LlmError::NetworkError`] on transport failure, [`LlmError::ApiError`]
    /// on a non-2xx status or an unparsable body.
    pub async fn list_models(&self) -> Result<OllamaModelsResponse> {
        let url = format!("{}/api/tags", self.host);
        debug!(url = %url, "Fetching Ollama models list");
        let response = self
            .client
            .get(&url)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "Ollama API error ({}): {}",
                status, error_text
            )));
        }
        let models: OllamaModelsResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ApiError(format!("Failed to parse models response: {}", e)))?;
        debug!("Ollama returned {} models", models.models.len());
        Ok(models)
    }

    /// Returns the configured base URL.
    pub fn host(&self) -> &str {
        &self.host
    }

    /// Heuristic: does this model family support the `think` request flag?
    /// Matched case-insensitively against known reasoning-model name stems.
    fn is_thinking_model(model: &str) -> bool {
        let model_lower = model.to_lowercase();
        model_lower.contains("deepseek-r1")
            || model_lower.contains("qwen3")
            || model_lower.contains("qwq")
            || model_lower.contains("openthinker")
            || model_lower.contains("phi4-reasoning")
            || model_lower.contains("magistral")
            || model_lower.contains("cogito")
            || model_lower.contains("gpt-oss")
    }
}
// Chat/completion surface of the Ollama HTTP API (`POST /api/chat`).
#[async_trait]
impl LLMProvider for OllamaProvider {
/// Provider identifier used in logs and registries.
fn name(&self) -> &str {
"ollama"
}
/// Chat model name this provider targets.
fn model(&self) -> &str {
&self.model
}
/// Context window reported to callers (also sent to the server as `num_ctx`).
fn max_context_length(&self) -> usize {
self.max_context_length
}
/// One-shot completion with default options; delegates to
/// `complete_with_options`.
async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
self.complete_with_options(prompt, &CompletionOptions::default())
.await
}
/// Wraps the prompt (plus optional system prompt) as chat messages and
/// delegates to `chat`.
async fn complete_with_options(
&self,
prompt: &str,
options: &CompletionOptions,
) -> Result<LLMResponse> {
let mut messages = Vec::new();
if let Some(system) = &options.system_prompt {
messages.push(ChatMessage::system(system));
}
messages.push(ChatMessage::user(prompt));
self.chat(&messages, Some(options)).await
}
/// Non-streaming chat via `POST /api/chat`.
///
/// Maps completion options onto Ollama's `options` object, requests JSON
/// output when `response_format` is `json`/`json_object`, and enables the
/// `think` flag for known reasoning models.
async fn chat(
&self,
messages: &[ChatMessage],
options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
debug!(
"Ollama chat request: {} messages to model {}",
messages.len(),
self.model
);
let url = format!("{}/api/chat", self.host);
let opts = options.cloned().unwrap_or_default();
let chat_options = ChatOptions {
temperature: opts.temperature,
num_predict: opts.max_tokens.map(|t| t as i32),
stop: opts.stop.clone(),
num_ctx: Some(self.max_context_length),
};
// Ollama only supports a blanket "json" format hint; any other
// response_format value is ignored.
let format = opts.response_format.as_deref().and_then(|f| match f {
"json_object" | "json" => Some(serde_json::Value::String("json".to_string())),
_ => None,
});
let think = Self::is_thinking_model(&self.model);
let request = ChatRequest {
model: self.model.clone(),
messages: Self::convert_messages(messages),
stream: false,
options: Some(chat_options),
tools: None, think: if think { Some(true) } else { None },
format,
};
let response = self
.client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::NetworkError(e.to_string()))?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::ApiError(format!(
"Ollama API error ({}): {}",
status, error_text
)));
}
let response: ChatResponse = response
.json()
.await
.map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
let mut llm_response = LLMResponse::new(response.message.content, response.model)
.with_usage(
response.prompt_eval_count as usize,
response.eval_count as usize,
);
if let Some(done_reason) = response.done_reason {
llm_response = llm_response.with_finish_reason(done_reason);
}
if let Some(thinking) = &response.message.thinking {
if !thinking.is_empty() {
llm_response = llm_response.with_thinking_content(thinking.clone());
// Rough estimate: ~4 characters per token.
let thinking_tokens = thinking.len() / 4;
llm_response = llm_response.with_thinking_tokens(thinking_tokens);
}
}
Ok(llm_response)
}
/// Streaming completion: sends a single user message with `stream: true`
/// and yields the `message.content` of each NDJSON line.
///
/// NOTE(review): each network chunk is parsed line-by-line in isolation;
/// a JSON line split across two chunks will fail to parse and its content
/// is silently dropped. Consider carrying a partial-line buffer across
/// chunks (e.g. via `scan`) — confirm whether chunk boundaries align with
/// lines in practice.
async fn stream(&self, prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
use futures::StreamExt;
debug!("Ollama stream request: prompt to model {}", self.model);
let url = format!("{}/api/chat", self.host);
let chat_options = ChatOptions {
temperature: None,
num_predict: None,
stop: None,
num_ctx: Some(self.max_context_length),
};
let think = Self::is_thinking_model(&self.model);
let request = ChatRequest {
model: self.model.clone(),
messages: vec![OllamaMessage {
role: "user".to_string(),
content: prompt.to_string(),
images: None,
tool_calls: None,
tool_name: None,
}],
stream: true,
options: Some(chat_options),
tools: None, think: if think { Some(true) } else { None },
format: None,
};
let response = self
.client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::NetworkError(e.to_string()))?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::ApiError(format!(
"Ollama API error ({}): {}",
status, error_text
)));
}
let stream = response.bytes_stream();
let mapped_stream = stream.map(|chunk_result| {
match chunk_result {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
let mut content = String::new();
// Concatenate the content of every parsable NDJSON line in
// this chunk; unparsable lines are skipped.
for line in text.lines() {
if line.is_empty() {
continue;
}
if let Ok(chunk) = serde_json::from_str::<OllamaStreamChunk>(line) {
if let Some(msg) = chunk.message {
content.push_str(&msg.content);
}
}
}
Ok(content)
}
Err(e) => Err(LlmError::NetworkError(e.to_string())),
}
});
Ok(mapped_stream.boxed())
}
fn supports_streaming(&self) -> bool {
true
}
fn supports_function_calling(&self) -> bool {
true }
fn supports_tool_streaming(&self) -> bool {
true }
fn supports_json_mode(&self) -> bool {
true }
/// Non-streaming chat with tool definitions attached.
///
/// `tool_choice` is accepted for interface parity but ignored — Ollama has
/// no equivalent knob. Ollama tool calls carry no ids on the wire, so a
/// fresh UUID is minted per call for upstream correlation.
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
tools: &[ToolDefinition],
_tool_choice: Option<ToolChoice>, options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
let url = format!("{}/api/chat", self.host);
let opts = options.cloned().unwrap_or_default();
let chat_options = ChatOptions {
temperature: opts.temperature,
num_predict: opts.max_tokens.map(|t| t as i32),
stop: opts.stop.clone(),
num_ctx: Some(self.max_context_length),
};
// Omit the tools key entirely when no tools are supplied.
let ollama_tools = if !tools.is_empty() {
Some(Self::convert_tools(tools))
} else {
None
};
let think = Self::is_thinking_model(&self.model);
let request = ChatRequest {
model: self.model.clone(),
messages: Self::convert_messages(messages),
stream: false,
options: Some(chat_options),
tools: ollama_tools,
think: if think { Some(true) } else { None },
format: None, };
let response = self
.client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::NetworkError(e.to_string()))?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::ApiError(format!(
"Ollama API error ({}): {}",
status, error_text
)));
}
let response: ChatResponse = response
.json()
.await
.map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
// Re-serialize structured arguments back to a JSON string, as the
// generic ToolCall type carries arguments as text.
let tool_calls: Vec<ToolCall> = response
.message
.tool_calls
.unwrap_or_default()
.into_iter()
.map(|tc| ToolCall {
id: uuid::Uuid::new_v4().to_string(),
call_type: "function".to_string(),
function: FunctionCall {
name: tc.function.name,
arguments: serde_json::to_string(&tc.function.arguments).unwrap_or_default(),
},
thought_signature: None,
})
.collect();
let mut llm_response = LLMResponse::new(response.message.content, response.model)
.with_usage(
response.prompt_eval_count as usize,
response.eval_count as usize,
)
.with_tool_calls(tool_calls);
if let Some(done_reason) = response.done_reason {
llm_response = llm_response.with_finish_reason(done_reason);
}
if let Some(thinking) = &response.message.thinking {
if !thinking.is_empty() {
llm_response = llm_response.with_thinking_content(thinking.clone());
// Rough estimate: ~4 characters per token.
let thinking_tokens = thinking.len() / 4;
llm_response = llm_response.with_thinking_tokens(thinking_tokens);
}
}
Ok(llm_response)
}
/// Streaming chat with tools: yields typed [`TraitStreamChunk`]s.
///
/// Per parsed line, at most one chunk is emitted with precedence
/// thinking > first tool call > content > done; any extra payloads on the
/// same line are dropped, and only the first tool call of a line is
/// surfaced (always at index 0).
///
/// NOTE(review): as in `stream`, a JSON line split across network chunks
/// is silently dropped, and a chunk with no parsable line yields an empty
/// `Content` chunk — confirm downstream consumers tolerate empty deltas.
async fn chat_with_tools_stream(
&self,
messages: &[ChatMessage],
tools: &[ToolDefinition],
_tool_choice: Option<ToolChoice>,
options: Option<&CompletionOptions>,
) -> Result<BoxStream<'static, Result<TraitStreamChunk>>> {
use futures::StreamExt;
let url = format!("{}/api/chat", self.host);
let opts = options.cloned().unwrap_or_default();
let chat_options = ChatOptions {
temperature: opts.temperature,
num_predict: opts.max_tokens.map(|t| t as i32),
stop: opts.stop.clone(),
num_ctx: Some(self.max_context_length),
};
let ollama_tools = if !tools.is_empty() {
Some(Self::convert_tools(tools))
} else {
None
};
let think = Self::is_thinking_model(&self.model);
let request = ChatRequest {
model: self.model.clone(),
messages: Self::convert_messages(messages),
stream: true,
options: Some(chat_options),
tools: ollama_tools,
think: if think { Some(true) } else { None },
format: None, };
let response = self
.client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::NetworkError(e.to_string()))?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::ApiError(format!(
"Ollama API error ({}): {}",
status, error_text
)));
}
let stream = response.bytes_stream();
let mapped_stream = stream.map(|chunk_result| {
match chunk_result {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
for line in text.lines() {
if line.is_empty() {
continue;
}
if let Ok(chunk) = serde_json::from_str::<OllamaStreamChunk>(line) {
if let Some(msg) = chunk.message {
// Thinking deltas take precedence over tool
// calls and content on the same line.
if let Some(thinking) = &msg.thinking {
if !thinking.is_empty() {
// Rough estimate: ~4 characters per token.
let tokens_used = thinking.len() / 4;
return Ok(TraitStreamChunk::ThinkingContent {
text: thinking.clone(),
tokens_used: Some(tokens_used),
budget_total: None,
});
}
}
if let Some(tool_calls) = msg.tool_calls {
if let Some(tc) = tool_calls.first() {
return Ok(TraitStreamChunk::ToolCallDelta {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
function_name: Some(tc.function.name.clone()),
function_arguments: serde_json::to_string(
&tc.function.arguments,
)
.ok(),
thought_signature: None,
});
}
}
if !msg.content.is_empty() {
return Ok(TraitStreamChunk::Content(msg.content));
}
}
if chunk.done {
let reason =
chunk.done_reason.unwrap_or_else(|| "stop".to_string());
return Ok(TraitStreamChunk::Finished {
reason,
ttft_ms: None,
usage: None,
});
}
}
}
// Fallthrough: nothing usable in this chunk — emit an empty
// content delta to keep the stream alive.
Ok(TraitStreamChunk::Content(String::new()))
}
Err(e) => Err(LlmError::NetworkError(e.to_string())),
}
});
Ok(mapped_stream.boxed())
}
}
// Embedding surface of the Ollama HTTP API (`POST /api/embed`).
#[async_trait]
impl EmbeddingProvider for OllamaProvider {
/// Provider identifier used in logs and registries.
fn name(&self) -> &str {
"ollama"
}
// Intentionally returns the *embedding* model, not `self.model` —
// this is the EmbeddingProvider view of the same struct.
#[allow(clippy::misnamed_getters)]
fn model(&self) -> &str {
&self.embedding_model
}
/// Embedding vector size as configured (not validated against the server).
fn dimension(&self) -> usize {
self.embedding_dimension
}
/// Fixed token budget per input.
/// NOTE(review): hard-coded; assumes the embedding model accepts 8192
/// tokens — confirm for smaller embedding models.
fn max_tokens(&self) -> usize {
8192 }
/// Embeds a batch of texts via `POST /api/embed`, one vector per input.
///
/// Returns an empty vec for empty input without touching the network.
///
/// # Errors
/// [`LlmError::NetworkError`] on transport failure, [`LlmError::ApiError`]
/// on a non-2xx status or an unparsable body.
async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(vec![]);
}
debug!(
"Ollama embedding request: {} texts with model {}",
texts.len(),
self.embedding_model
);
let url = format!("{}/api/embed", self.host);
let request = EmbeddingRequest {
model: self.embedding_model.clone(),
input: texts.to_vec(),
};
let response = self
.client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::NetworkError(e.to_string()))?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::ApiError(format!(
"Ollama API error ({}): {}",
status, error_text
)));
}
let response: EmbeddingResponse = response
.json()
.await
.map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
debug!(
"Ollama embedding response: {} embeddings",
response.embeddings.len()
);
Ok(response.embeddings)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_builder_creation() {
let provider = OllamaProviderBuilder::new()
.host("http://localhost:11434")
.model("mistral")
.embedding_model("nomic-embed-text")
.build()
.unwrap();
assert_eq!(LLMProvider::name(&provider), "ollama");
assert_eq!(LLMProvider::model(&provider), "mistral");
assert_eq!(EmbeddingProvider::model(&provider), "nomic-embed-text");
}
#[test]
fn test_default_builder() {
let provider = OllamaProviderBuilder::new().build().unwrap();
assert_eq!(LLMProvider::name(&provider), "ollama");
assert_eq!(LLMProvider::model(&provider), "gemma3:12b");
assert_eq!(EmbeddingProvider::model(&provider), "embeddinggemma:latest");
assert_eq!(provider.max_context_length(), 131072);
}
#[test]
fn test_message_conversion() {
let messages = vec![
ChatMessage::system("You are helpful"),
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
];
let converted = OllamaProvider::convert_messages(&messages);
assert_eq!(converted.len(), 3);
assert_eq!(converted[0].role, "system");
assert_eq!(converted[1].role, "user");
assert_eq!(converted[2].role, "assistant");
}
#[tokio::test]
async fn test_embed_empty_input() {
let provider = OllamaProviderBuilder::new().build().unwrap();
let result = provider.embed(&[]).await;
assert!(result.is_ok());
assert!(result.unwrap().is_empty());
}
#[test]
fn test_is_thinking_model_deepseek_r1() {
assert!(OllamaProvider::is_thinking_model("deepseek-r1:8b"));
assert!(OllamaProvider::is_thinking_model("deepseek-r1:70b"));
assert!(OllamaProvider::is_thinking_model("DEEPSEEK-R1:latest"));
}
#[test]
fn test_is_thinking_model_qwen3() {
assert!(OllamaProvider::is_thinking_model("qwen3:8b"));
assert!(OllamaProvider::is_thinking_model("qwen3:32b"));
assert!(OllamaProvider::is_thinking_model("QWEN3:latest"));
}
#[test]
fn test_is_thinking_model_others() {
assert!(OllamaProvider::is_thinking_model("qwq:32b"));
assert!(OllamaProvider::is_thinking_model("openthinker:7b"));
assert!(OllamaProvider::is_thinking_model("phi4-reasoning:14b"));
assert!(OllamaProvider::is_thinking_model("magistral:24b"));
assert!(OllamaProvider::is_thinking_model("cogito:8b"));
assert!(OllamaProvider::is_thinking_model("gpt-oss:20b"));
}
#[test]
fn test_is_thinking_model_non_thinking() {
assert!(!OllamaProvider::is_thinking_model("llama3.2:8b"));
assert!(!OllamaProvider::is_thinking_model("gemma3:12b"));
assert!(!OllamaProvider::is_thinking_model("mistral:7b"));
assert!(!OllamaProvider::is_thinking_model("codellama:34b"));
}
#[test]
fn test_response_with_thinking_parsing() {
let json = r#"{
"role": "assistant",
"content": "The answer is 3.",
"thinking": "Let me count the r's in strawberry: s-t-r-a-w-b-e-r-r-y. That's 3 r's."
}"#;
let msg: ResponseMessage = serde_json::from_str(json).unwrap();
assert_eq!(msg.content, "The answer is 3.");
assert!(msg.thinking.is_some());
assert!(msg.thinking.unwrap().contains("Let me count"));
}
#[test]
fn test_response_without_thinking_parsing() {
let json = r#"{
"role": "assistant",
"content": "Hello, how can I help you?"
}"#;
let msg: ResponseMessage = serde_json::from_str(json).unwrap();
assert_eq!(msg.content, "Hello, how can I help you?");
assert!(msg.thinking.is_none());
}
#[test]
fn test_chat_request_with_think_serialization() {
let request = ChatRequest {
model: "deepseek-r1:8b".to_string(),
messages: vec![OllamaMessage {
role: "user".to_string(),
content: "How many r's in strawberry?".to_string(),
images: None,
tool_calls: None,
tool_name: None,
}],
stream: false,
options: None,
tools: None,
think: Some(true),
format: None,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"think\":true"));
}
#[test]
fn test_chat_request_without_think_serialization() {
let request = ChatRequest {
model: "llama3.2:8b".to_string(),
messages: vec![OllamaMessage {
role: "user".to_string(),
content: "Hello".to_string(),
images: None,
tool_calls: None,
tool_name: None,
}],
stream: false,
options: None,
tools: None,
think: None,
format: None,
};
let json = serde_json::to_string(&request).unwrap();
assert!(!json.contains("think"));
}
#[test]
fn test_chat_options_num_ctx_serialization() {
let options = ChatOptions {
temperature: None,
num_predict: None,
stop: None,
num_ctx: Some(65536),
};
let json = serde_json::to_string(&options).unwrap();
assert!(
json.contains("\"num_ctx\":65536"),
"num_ctx should be serialized in options: {}",
json
);
}
#[test]
fn test_chat_options_num_ctx_omitted_when_none() {
let options = ChatOptions {
temperature: None,
num_predict: None,
stop: None,
num_ctx: None,
};
let json = serde_json::to_string(&options).unwrap();
assert!(
!json.contains("num_ctx"),
"num_ctx should be omitted when None: {}",
json
);
}
#[test]
fn test_constants() {
assert_eq!(DEFAULT_OLLAMA_HOST, "http://localhost:11434");
assert_eq!(DEFAULT_OLLAMA_MODEL, "gemma3:12b");
assert_eq!(DEFAULT_OLLAMA_EMBEDDING_MODEL, "embeddinggemma:latest");
}
#[test]
fn test_builder_default_values() {
let builder = OllamaProviderBuilder::default();
assert_eq!(builder.host, "http://localhost:11434");
assert_eq!(builder.model, "gemma3:12b");
assert_eq!(builder.embedding_model, "embeddinggemma:latest");
assert_eq!(builder.max_context_length, 131072);
assert_eq!(builder.embedding_dimension, 768);
}
#[test]
fn test_builder_custom_context_length() {
let provider = OllamaProviderBuilder::new()
.max_context_length(65536)
.build()
.unwrap();
assert_eq!(provider.max_context_length(), 65536);
}
#[test]
fn test_builder_custom_embedding_dimension() {
let provider = OllamaProviderBuilder::new()
.embedding_dimension(1536)
.build()
.unwrap();
assert_eq!(provider.dimension(), 1536);
}
#[test]
fn test_default_local_creation() {
let provider = OllamaProvider::default_local().unwrap();
assert_eq!(LLMProvider::name(&provider), "ollama");
assert_eq!(LLMProvider::model(&provider), "gemma3:12b");
assert_eq!(provider.host, "http://localhost:11434");
}
#[test]
fn test_supports_streaming() {
let provider = OllamaProviderBuilder::new().build().unwrap();
assert!(provider.supports_streaming());
}
#[test]
fn test_supports_json_mode() {
let provider = OllamaProviderBuilder::new().build().unwrap();
assert!(provider.supports_json_mode());
}
#[test]
fn test_embedding_provider_name() {
let provider = OllamaProviderBuilder::new().build().unwrap();
assert_eq!(EmbeddingProvider::name(&provider), "ollama");
}
#[test]
fn test_embedding_provider_dimension() {
let provider = OllamaProviderBuilder::new()
.embedding_dimension(1024)
.build()
.unwrap();
assert_eq!(provider.dimension(), 1024);
}
#[test]
fn test_embedding_provider_max_tokens() {
let provider = OllamaProviderBuilder::new().build().unwrap();
assert_eq!(provider.max_tokens(), 8192);
}
#[test]
fn test_message_conversion_tool_role() {
let messages = vec![ChatMessage::tool_result("tool-1", "Tool output")];
let converted = OllamaProvider::convert_messages(&messages);
assert_eq!(converted.len(), 1);
assert_eq!(converted[0].role, "tool");
assert_eq!(converted[0].content, "Tool output");
}
#[test]
fn test_ollama_message_serialization() {
let msg = OllamaMessage {
role: "user".to_string(),
content: "Hello world".to_string(),
images: None,
tool_calls: None,
tool_name: None,
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("\"role\":\"user\""));
assert!(json.contains("\"content\":\"Hello world\""));
}
#[test]
fn test_from_env_uses_defaults() {
std::env::remove_var("OLLAMA_HOST");
std::env::remove_var("OLLAMA_MODEL");
std::env::remove_var("OLLAMA_EMBEDDING_MODEL");
std::env::remove_var("OLLAMA_CONTEXT_LENGTH");
let provider = OllamaProvider::from_env().unwrap();
assert_eq!(provider.host, "http://localhost:11434");
assert_eq!(LLMProvider::model(&provider), "gemma3:12b");
assert_eq!(EmbeddingProvider::model(&provider), "embeddinggemma:latest");
assert_eq!(provider.max_context_length(), 131072);
}
#[test]
fn test_chat_options_temperature_serialization() {
let options = ChatOptions {
temperature: Some(0.7),
num_predict: Some(1024),
stop: Some(vec!["END".to_string()]),
num_ctx: Some(32768),
};
let json = serde_json::to_string(&options).unwrap();
assert!(json.contains("\"temperature\":0.7"));
assert!(json.contains("\"num_predict\":1024"));
assert!(json.contains("\"stop\":[\"END\"]"));
assert!(json.contains("\"num_ctx\":32768"));
}
#[test]
fn test_convert_messages_with_image_populates_images_field() {
use crate::traits::ImageData;
let img = ImageData::new("base64abc", "image/png");
let messages = vec![ChatMessage::user_with_images("describe this", vec![img])];
let converted = OllamaProvider::convert_messages(&messages);
assert_eq!(converted.len(), 1);
let images = converted[0].images.as_ref().expect("images must be Some");
assert_eq!(images.len(), 1);
assert_eq!(
images[0], "base64abc",
"Should be raw base64 without data-URI prefix"
);
}
#[test]
fn test_convert_messages_without_image_omits_images_field() {
let messages = vec![ChatMessage::user("no image here")];
let converted = OllamaProvider::convert_messages(&messages);
assert!(
converted[0].images.is_none(),
"images must be None for text-only messages"
);
let json = serde_json::to_string(&converted[0]).unwrap();
assert!(
!json.contains("\"images\""),
"images key must not appear in JSON for text-only message"
);
}
#[test]
fn test_ollama_message_with_images_serialization() {
let msg = OllamaMessage {
role: "user".to_string(),
content: "what is this?".to_string(),
images: Some(vec!["base64data".to_string()]),
tool_calls: None,
tool_name: None,
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("\"images\":[\"base64data\"]"));
}
#[test]
fn test_convert_role_tool_maps_to_tool() {
assert_eq!(OllamaProvider::convert_role(&ChatRole::Tool), "tool");
}
#[test]
fn test_convert_role_function_maps_to_user() {
assert_eq!(OllamaProvider::convert_role(&ChatRole::Function), "user");
}
#[test]
fn test_tool_role_not_downgraded_to_user_in_json() {
let messages = vec![ChatMessage::tool_result("call-1", "sunny")];
let converted = OllamaProvider::convert_messages(&messages);
let json = serde_json::to_string(&converted[0]).unwrap();
assert!(
json.contains("\"role\":\"tool\""),
"Tool messages must use 'tool' role in Ollama API, got: {}",
json
);
assert!(
!json.contains("\"role\":\"user\""),
"Tool messages must NOT be downgraded to 'user' role, got: {}",
json
);
}
#[test]
fn test_convert_messages_propagates_tool_calls_from_assistant() {
let tc = ToolCall {
id: "call-abc".to_string(),
call_type: "function".to_string(),
function: FunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"city":"Tokyo"}"#.to_string(),
},
thought_signature: None,
};
let msg = ChatMessage::assistant_with_tools("", vec![tc]);
let converted = OllamaProvider::convert_messages(&[msg]);
let ollama_tool_calls = converted[0]
.tool_calls
.as_ref()
.expect("assistant tool_calls must be Some after conversion");
assert_eq!(ollama_tool_calls.len(), 1);
assert_eq!(ollama_tool_calls[0].function.name, "get_weather");
assert!(ollama_tool_calls[0].function.arguments.is_object());
assert_eq!(
ollama_tool_calls[0].function.arguments["city"],
serde_json::Value::String("Tokyo".to_string())
);
}
#[test]
fn test_convert_messages_assistant_without_tool_calls_has_none() {
let msg = ChatMessage::assistant("Hello!");
let converted = OllamaProvider::convert_messages(&[msg]);
assert!(
converted[0].tool_calls.is_none(),
"tool_calls must be None for plain assistant messages"
);
}
#[test]
fn test_assistant_tool_calls_serialized_in_json() {
let tc = ToolCall {
id: "call-1".to_string(),
call_type: "function".to_string(),
function: FunctionCall {
name: "fn_name".to_string(),
arguments: r#"{"x":1}"#.to_string(),
},
thought_signature: None,
};
let msg = ChatMessage::assistant_with_tools("", vec![tc]);
let converted = OllamaProvider::convert_messages(&[msg]);
let json = serde_json::to_string(&converted[0]).unwrap();
assert!(
json.contains("\"tool_calls\""),
"tool_calls must appear in serialized JSON for assistant messages"
);
assert!(
json.contains("\"fn_name\""),
"function name must be serialized"
);
}
#[test]
fn test_tool_calls_not_set_for_non_assistant_messages() {
let messages = vec![
ChatMessage::system("sys"),
ChatMessage::user("hello"),
ChatMessage::tool_result("id-1", "result"),
];
let converted = OllamaProvider::convert_messages(&messages);
for m in &converted {
assert!(
m.tool_calls.is_none(),
"tool_calls must be None for non-assistant messages, got role={}",
m.role
);
}
}
#[test]
fn test_tool_name_set_from_message_name_field() {
let mut msg = ChatMessage::tool_result("call-1", "11 degrees celsius");
msg.name = Some("get_weather".to_string());
let converted = OllamaProvider::convert_messages(&[msg]);
assert_eq!(
converted[0].tool_name.as_deref(),
Some("get_weather"),
"tool_name must be populated from ChatMessage::name for tool messages"
);
}
#[test]
fn test_tool_name_none_when_not_set() {
    // Without ChatMessage::name, the converted tool message carries no tool_name.
    let converted =
        OllamaProvider::convert_messages(&[ChatMessage::tool_result("call-1", "some result")]);
    assert!(
        converted[0].tool_name.is_none(),
        "tool_name must be None when ChatMessage::name is not set"
    );
}
#[test]
fn test_tool_name_omitted_for_non_tool_messages() {
    // Only tool-result messages get a tool_name; every other role leaves it unset.
    let converted = OllamaProvider::convert_messages(&[
        ChatMessage::system("sys"),
        ChatMessage::user("hi"),
        ChatMessage::assistant("hello"),
    ]);
    for message in &converted {
        assert!(
            message.tool_name.is_none(),
            "tool_name must be None for non-tool messages, role={}",
            message.role
        );
    }
}
#[test]
fn test_tool_name_not_in_json_when_absent() {
    // A tool message converted without a name must not serialize a tool_name
    // key at all (as opposed to emitting an explicit null).
    let msg = ChatMessage::tool_result("call-1", "result");
    let converted = OllamaProvider::convert_messages(&[msg]);
    let json = serde_json::to_string(&converted[0]).unwrap();
    // Check for the quoted JSON key rather than the bare substring: a bare
    // `contains("tool_name")` would fail spuriously if any field value
    // (e.g. the message content) happened to contain the text "tool_name".
    assert!(
        !json.contains("\"tool_name\""),
        "tool_name must not appear in JSON when it is None: {}",
        json
    );
}
#[test]
fn test_tool_name_in_json_when_set() {
    // A named tool result serializes its tool_name as a key/value pair.
    let mut tool_msg = ChatMessage::tool_result("call-1", "result");
    tool_msg.name = Some("my_function".to_string());
    let converted = OllamaProvider::convert_messages(&[tool_msg]);
    let serialized = serde_json::to_string(&converted[0]).unwrap();
    assert!(
        serialized.contains("\"tool_name\":\"my_function\""),
        "tool_name must appear in JSON when set: {}",
        serialized
    );
}
#[test]
fn test_chat_response_parses_done_reason() {
    // A full response body that includes done_reason deserializes with the
    // reason preserved.
    let payload = r#"{
        "model": "llama3.2",
        "created_at": "2025-01-01T00:00:00Z",
        "message": {"role":"assistant","content":"Hello"},
        "done": true,
        "done_reason": "stop",
        "total_duration": 1000000,
        "prompt_eval_count": 10,
        "eval_count": 5
    }"#;
    let parsed: ChatResponse = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.done_reason.as_deref(), Some("stop"));
}
#[test]
fn test_chat_response_done_reason_defaults_to_none() {
    // A body that omits done_reason must still deserialize, leaving the
    // field as None rather than erroring.
    let payload = r#"{
        "model": "llama3.2",
        "created_at": "2025-01-01T00:00:00Z",
        "message": {"role":"assistant","content":"Hello"},
        "done": true
    }"#;
    let parsed: ChatResponse = serde_json::from_str(payload).unwrap();
    assert!(parsed.done_reason.is_none());
}
#[test]
fn test_stream_chunk_parses_done_reason() {
    // A terminal stream chunk carries done=true plus the stop reason.
    let parsed: OllamaStreamChunk =
        serde_json::from_str(r#"{"done":true,"done_reason":"length"}"#).unwrap();
    assert!(parsed.done);
    assert_eq!(parsed.done_reason.as_deref(), Some("length"));
}
#[test]
fn test_stream_chunk_done_reason_defaults_none() {
    // Mid-stream chunks omit done_reason; the field must default to None.
    let parsed: OllamaStreamChunk =
        serde_json::from_str(r#"{"message":{"role":"assistant","content":"hi"},"done":false}"#)
            .unwrap();
    assert!(!parsed.done);
    assert!(parsed.done_reason.is_none());
}
#[test]
fn test_chat_request_format_json_serialization() {
    // A request whose `format` is the string "json" must serialize that
    // field so the server receives it.
    let user_message = OllamaMessage {
        role: "user".to_string(),
        content: "Give JSON".to_string(),
        images: None,
        tool_calls: None,
        tool_name: None,
    };
    let request = ChatRequest {
        model: "llama3.2".to_string(),
        messages: vec![user_message],
        stream: false,
        options: None,
        tools: None,
        think: None,
        format: Some(serde_json::Value::String("json".to_string())),
    };
    let serialized = serde_json::to_string(&request).unwrap();
    assert!(
        serialized.contains("\"format\":\"json\""),
        "format:'json' must appear in serialized request: {}",
        serialized
    );
}
#[test]
fn test_chat_request_format_absent_when_none() {
    // When `format` is None the serialized request must omit the key
    // entirely (not emit an explicit null).
    let request = ChatRequest {
        model: "llama3.2".to_string(),
        messages: vec![OllamaMessage {
            role: "user".to_string(),
            content: "Hello".to_string(),
            images: None,
            tool_calls: None,
            tool_name: None,
        }],
        stream: false,
        options: None,
        tools: None,
        think: None,
        format: None,
    };
    let json = serde_json::to_string(&request).unwrap();
    // Check for the quoted JSON key rather than the bare substring: a bare
    // `contains("format")` would fail spuriously if any field value (e.g.
    // the message content) happened to contain the word "format".
    assert!(
        !json.contains("\"format\""),
        "format field must not appear when None: {}",
        json
    );
}
#[test]
fn test_multi_turn_tool_conversation_converts_correctly() {
    // Full round trip: user question -> assistant tool call -> tool result ->
    // final assistant answer. Each turn must keep its role and tool metadata.
    let call = ToolCall {
        id: "call-xyz".to_string(),
        call_type: "function".to_string(),
        function: FunctionCall {
            name: "get_weather".to_string(),
            arguments: r#"{"city":"Toronto"}"#.to_string(),
        },
        thought_signature: None,
    };
    let mut weather_result = ChatMessage::tool_result("call-xyz", "11 degrees celsius");
    weather_result.name = Some("get_weather".to_string());
    let conversation = vec![
        ChatMessage::user("What is the weather in Toronto?"),
        ChatMessage::assistant_with_tools("", vec![call]),
        weather_result,
        ChatMessage::assistant("The current temperature in Toronto is 11°C."),
    ];
    let converted = OllamaProvider::convert_messages(&conversation);
    assert_eq!(converted.len(), 4);
    let roles: Vec<&str> = converted.iter().map(|m| m.role.as_str()).collect();
    assert_eq!(roles, ["user", "assistant", "tool", "assistant"]);
    assert!(converted[1].tool_calls.is_some());
    assert_eq!(converted[2].tool_name.as_deref(), Some("get_weather"));
    assert!(converted[3].tool_calls.is_none());
}
}