use std::time::Duration;
use super::helpers::*;
use super::types::*;
/// Chat provider that sends requests to an Ollama server over HTTP via `ureq`.
pub struct OllamaProvider {
    /// Timeout in seconds. `new` applies it as the agent's global timeout, and
    /// `chat` passes it to `agent_with_timeout` together with the request's own
    /// `timeout_secs`.
    default_timeout_secs: u64,
    /// HTTP agent shared by every request this provider issues.
    agent: ureq::Agent,
}
impl OllamaProvider {
    /// Base URL used by `chat` when the request does not provide `base_url`.
    const DEFAULT_BASE_URL: &str = "http://localhost:11434";

    /// Builds a provider with a fresh `ureq` agent whose global timeout is
    /// `default_timeout_secs` seconds.
    pub fn new(default_timeout_secs: u64) -> Self {
        let global_timeout = Duration::from_secs(default_timeout_secs);
        let agent_config = ureq::Agent::config_builder()
            .timeout_global(Some(global_timeout))
            .build();
        let agent = ureq::Agent::new_with_config(agent_config);
        Self::with_agent(agent, default_timeout_secs)
    }

    /// Builds a provider around a caller-supplied `agent`.
    ///
    /// The agent is stored unchanged. `default_timeout_secs` is only recorded
    /// here and is not applied to the agent's configuration.
    pub fn with_agent(agent: ureq::Agent, default_timeout_secs: u64) -> Self {
        Self {
            agent,
            default_timeout_secs,
        }
    }
}
impl LlmProvider for OllamaProvider {
    fn name(&self) -> &str {
        "ollama"
    }

    fn default_base_url(&self) -> Option<&str> {
        Some(Self::DEFAULT_BASE_URL)
    }

    /// Sends a non-streaming chat request to the Ollama `/api/chat` endpoint.
    ///
    /// The request is translated as follows:
    /// - The optional system prompt becomes a leading `system` message.
    /// - `temperature`, `max_tokens` (as `num_predict`), `top_p` and `stop`
    ///   go into the `options` object, along with every key of an
    ///   object-valued `request.extra`.
    /// - The `options` object is omitted entirely when it would be empty.
    ///
    /// The JSON reply is mapped back into a [`ChatResponse`].
    ///
    /// # Errors
    ///
    /// Returns `Err` in any of these cases:
    /// - a message cannot be serialized by `serialize_message_ollama`;
    /// - `post_json` fails;
    /// - `check_api_error` rejects the response body.
    fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, String> {
        // Drop trailing slashes so "http://host:11434/" does not become
        // "http://host:11434//api/chat".
        let base_url = request
            .base_url
            .as_deref()
            .unwrap_or(Self::DEFAULT_BASE_URL)
            .trim_end_matches('/');
        let url = format!("{base_url}/api/chat");

        let mut messages = Vec::new();
        if let Some(sys) = &request.system {
            messages.push(serde_json::json!({"role": "system", "content": sys}));
        }
        for msg in &request.messages {
            messages.push(serialize_message_ollama(msg)?);
        }

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": messages,
            "stream": false,
        });

        let mut options = serde_json::Map::new();
        if let Some(temp) = request.temperature {
            options.insert("temperature".into(), serde_json::json!(temp));
        }
        if let Some(max) = request.max_tokens {
            options.insert("num_predict".into(), serde_json::json!(max));
        }
        if let Some(top_p) = request.top_p {
            options.insert("top_p".into(), serde_json::json!(top_p));
        }
        if let Some(stop) = &request.stop {
            options.insert("stop".into(), serde_json::json!(stop));
        }
        // `extra` keys are inserted last, so they replace any same-named
        // option set above. Non-object `extra` values are ignored.
        if let Some(serde_json::Value::Object(extra_map)) = &request.extra {
            for (k, v) in extra_map {
                options.insert(k.clone(), v.clone());
            }
        }
        if !options.is_empty() {
            body["options"] = serde_json::Value::Object(options);
        }

        let agent =
            agent_with_timeout(&self.agent, self.default_timeout_secs, request.timeout_secs);
        let json = post_json(&agent, &url, &body, &[], request.max_response_bytes)?;
        check_api_error(&json, "ollama")?;

        // A missing `message.content` yields an empty string rather than an error.
        let content = json
            .get("message")
            .and_then(|m| m.get("content"))
            .and_then(|c| c.as_str())
            .unwrap_or_default()
            .to_string();

        let finish_reason = match json.get("done_reason").and_then(|r| r.as_str()) {
            Some("length") => FinishReason::MaxTokens,
            // "stop", any other reason, or a missing field counts as a normal stop.
            _ => FinishReason::Stop,
        };

        // Absent or non-integer counts are reported as 0, saturated into u32.
        let token_count =
            |key: &str| sat_u32(json.get(key).and_then(|n| n.as_u64()).unwrap_or(0));
        let usage = Usage {
            input_tokens: token_count("prompt_eval_count"),
            output_tokens: token_count("eval_count"),
        };

        Ok(ChatResponse {
            content,
            finish_reason,
            usage,
            model: request.model.clone(),
        })
    }
}
/// Converts one [`ChatMessage`] into the JSON object sent in Ollama's
/// `messages` array.
///
/// Plain-text content is sent as-is. For multi-part content:
/// - text parts are joined with `"\n"` into `content`;
/// - base64 image parts are listed under `images`, which is omitted when
///   there are none;
/// - any other part kinds are skipped.
///
/// # Errors
///
/// Returns `Err` if any part is an image URL.
pub(super) fn serialize_message_ollama(msg: &ChatMessage) -> Result<serde_json::Value, String> {
    let role = match msg.role {
        ChatRole::User => "user",
        ChatRole::Assistant => "assistant",
    };

    let parts = match &msg.content {
        ChatContent::Text(s) => return Ok(serde_json::json!({"role": role, "content": s})),
        ChatContent::Parts(parts) => parts,
    };

    // Reject URL-based images before doing any other work.
    if parts
        .iter()
        .any(|part| matches!(part, ContentPart::ImageUrl { .. }))
    {
        return Err("Ollama does not support image URLs; use image_base64 instead".into());
    }

    let mut text_chunks: Vec<&str> = Vec::new();
    let mut image_data: Vec<&str> = Vec::new();
    for part in parts.iter() {
        match part {
            ContentPart::Text { text } => text_chunks.push(text.as_str()),
            ContentPart::ImageBase64 { data, .. } => image_data.push(data.as_str()),
            _ => {}
        }
    }

    let content = text_chunks.join("\n");
    let mut message = serde_json::json!({"role": role, "content": content});
    if !image_data.is_empty() {
        message["images"] = serde_json::json!(image_data);
    }
    Ok(message)
}