use crate::error::LlmConnectorError;
use crate::protocols::core::{ErrorMapper, ProviderAdapter};
use crate::types::{ChatRequest, ChatResponse, Choice, Message, Role, Usage};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
/// Converts a wire-format role name into the crate's [`Role`] enum.
///
/// `"system"`, `"assistant"` and `"tool"` map to their matching variants.
/// `"user"` and every unrecognised string (including the empty string and
/// differently-cased spellings) map to [`Role::User`].
fn parse_role(role: &str) -> Role {
    if role == "system" {
        return Role::System;
    }
    if role == "assistant" {
        return Role::Assistant;
    }
    if role == "tool" {
        return Role::Tool;
    }
    // "user" and the fallback case share the same variant.
    Role::User
}
/// JSON body this adapter sends to the chat endpoint (`/api/chat`).
///
/// Optional fields are left out of the serialized payload when they are `None`.
#[derive(Serialize, Debug)]
pub struct OllamaRequest {
    /// Name of the model that should produce the reply.
    model: String,
    /// Conversation messages, in the order they were supplied.
    messages: Vec<OllamaMessage>,
    /// Streaming flag; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    /// Generation options; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<OllamaOptions>,
}
/// One conversation message inside an [`OllamaRequest`].
#[derive(Serialize, Debug)]
pub struct OllamaMessage {
    /// Role name as a plain string (`"system"`, `"user"`, `"assistant"` or `"tool"`).
    role: String,
    /// Text of the message.
    content: String,
}
/// Generation options nested under `options` in an [`OllamaRequest`].
///
/// Every field is optional and is skipped during serialization when `None`.
#[derive(Serialize, Debug)]
pub struct OllamaOptions {
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Nucleus-sampling (`top_p`) value.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// Token budget; this adapter fills it from the chat request's `max_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<u32>,
}
/// Chat reply deserialized from the server's JSON response.
///
/// `model`, `message` and `done` are required; the remaining fields fall back
/// to their defaults when absent.
#[derive(Deserialize, Debug)]
pub struct OllamaResponse {
    /// Model name reported by the server.
    model: String,
    /// Unused timestamp slot (empty string when the key is missing).
    ///
    /// NOTE(review): there is no `rename`, so serde matches a JSON key spelled
    /// exactly `_created_at`; a `created_at` key in the payload would not
    /// populate this field. Confirm whether that is intended.
    #[serde(default)]
    _created_at: String,
    /// The generated message.
    message: OllamaResponseMessage,
    /// Completion flag reported by the server.
    done: bool,
    /// Reported count of generated tokens, if the server sent one.
    #[serde(default)]
    eval_count: Option<u32>,
}
/// Description of a single model as reported by the Ollama server.
///
/// `name`, `model` and `modified_at` must be present when deserializing;
/// every other field is optional.
#[derive(Deserialize, Debug, Clone)]
pub struct OllamaModel {
    /// Model name (this is the value `list_models` returns).
    pub name: String,
    /// Model identifier string.
    pub model: String,
    /// Last-modified timestamp, kept as the raw string sent by the server.
    pub modified_at: String,
    /// Model size, when reported (presumably bytes; verify against the API docs).
    pub size: Option<u64>,
    /// Content digest, when reported.
    pub digest: Option<String>,
    /// Extra format/family metadata, when reported.
    pub details: Option<OllamaModelDetails>,
    /// Expiry timestamp as a raw string, when reported.
    pub expires_at: Option<String>,
}
/// Optional metadata attached to an [`OllamaModel`]; every field may be absent.
#[derive(Deserialize, Debug, Clone)]
pub struct OllamaModelDetails {
    /// Model file format.
    pub format: Option<String>,
    /// Primary model family.
    pub family: Option<String>,
    /// All families the model is tagged with.
    pub families: Option<Vec<String>>,
    /// Parameter count, as a display string.
    pub parameter_size: Option<String>,
    /// Quantization level, as a display string.
    pub quantization_level: Option<String>,
}
/// Envelope returned by the model-listing endpoint (`/api/tags`).
#[derive(Deserialize, Debug)]
pub struct OllamaModelsResponse {
    /// The models known to the server.
    pub models: Vec<OllamaModel>,
}
/// JSON body used for the pull (`/api/pull`) and push (`/api/push`) requests.
///
/// Optional fields are omitted from the payload when `None`.
#[derive(Serialize, Debug)]
pub struct OllamaModelRequest {
    /// Name of the model to transfer.
    pub name: String,
    /// `insecure` flag forwarded to the server; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub insecure: Option<bool>,
    /// Streaming flag for progress updates; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}
/// JSON body used for the delete request (`/api/delete`).
#[derive(Serialize, Debug)]
pub struct OllamaModelDeleteRequest {
    /// Name of the model to remove.
    pub name: String,
}
/// Progress/status record for a model operation.
///
/// NOTE(review): nothing in this file deserializes this type; it presumably
/// mirrors the status objects emitted by pull/push — confirm against callers.
#[derive(Deserialize, Debug)]
pub struct OllamaModelProgressResponse {
    /// Status text (required).
    pub status: String,
    /// Digest the status refers to, when present.
    #[serde(default)]
    pub digest: Option<String>,
    /// Total amount of work, when present.
    #[serde(default)]
    pub total: Option<u64>,
    /// Amount of work completed so far, when present.
    #[serde(default)]
    pub completed: Option<u64>,
}
/// The `message` object carried by an [`OllamaResponse`].
#[derive(Deserialize, Debug)]
pub struct OllamaResponseMessage {
    /// Role name as sent by the server.
    role: String,
    /// Generated text.
    content: String,
}
/// Translates Ollama HTTP and transport failures into [`LlmConnectorError`] values.
pub struct OllamaErrorMapper;

impl ErrorMapper for OllamaErrorMapper {
    /// Maps a non-success HTTP status and its decoded JSON body to an error.
    ///
    /// The detail text is read from the body's `error` field, then from
    /// `message`, and falls back to `"Unknown Ollama error"`.
    ///
    /// * `400` → `InvalidRequest` carrying the detail text.
    /// * `404` → `InvalidRequest` with a fixed "pull the model" hint.
    /// * `429` → `RateLimitError` carrying the detail text.
    /// * `500` → `ServerError` with a fixed hint.
    /// * any other `5xx` → `ServerError` carrying status and detail text, so
    ///   that gateway-style failures (502/503/504) are classified as server
    ///   errors and are therefore retriable.
    /// * everything else → `ProviderError` carrying status and detail text.
    fn map_http_error(status: u16, body: Value) -> LlmConnectorError {
        let error_message = body["error"]
            .as_str()
            .or_else(|| body["message"].as_str())
            .unwrap_or("Unknown Ollama error");

        match status {
            400 => LlmConnectorError::InvalidRequest(format!("Ollama: {}", error_message)),
            404 => LlmConnectorError::InvalidRequest(
                "Ollama: Model not found. Make sure to pull the model first with 'ollama pull <model>'"
                    .to_string(),
            ),
            429 => LlmConnectorError::RateLimitError(format!("Ollama: {}", error_message)),
            500 => LlmConnectorError::ServerError(
                "Ollama: Server error. Is Ollama running on localhost:11434?".to_string(),
            ),
            // Remaining 5xx codes are server-side failures too; keeping them in
            // `ServerError` lets `is_retriable_error` pick them up.
            501..=599 => LlmConnectorError::ServerError(format!(
                "Ollama HTTP {}: {}",
                status, error_message
            )),
            _ => LlmConnectorError::ProviderError(format!(
                "Ollama HTTP {}: {}",
                status, error_message
            )),
        }
    }

    /// Maps a transport-level `reqwest` failure to an error.
    ///
    /// Timeouts become `TimeoutError`, connection failures become
    /// `ConnectionError` (with a fixed hint), and anything else becomes
    /// `NetworkError`.
    fn map_network_error(error: reqwest::Error) -> LlmConnectorError {
        if error.is_timeout() {
            LlmConnectorError::TimeoutError(format!("Ollama: {}", error))
        } else if error.is_connect() {
            LlmConnectorError::ConnectionError(
                "Ollama: Cannot connect to Ollama server. Is it running on localhost:11434?"
                    .to_string(),
            )
        } else {
            LlmConnectorError::NetworkError(format!("Ollama: {}", error))
        }
    }

    /// Returns `true` for rate-limit, server, timeout and connection errors.
    fn is_retriable_error(error: &LlmConnectorError) -> bool {
        matches!(
            error,
            LlmConnectorError::RateLimitError(_)
                | LlmConnectorError::ServerError(_)
                | LlmConnectorError::TimeoutError(_)
                | LlmConnectorError::ConnectionError(_)
        )
    }
}
/// Protocol adapter state for an Ollama server.
///
/// The base URL is held in an `Arc<str>`, so cloning shares the string
/// instead of copying it.
#[derive(Debug, Clone)]
pub struct OllamaProtocol {
    /// Base URL that endpoint paths such as `/api/chat` are appended to.
    base_url: Arc<str>,
}
impl OllamaProtocol {
pub fn new() -> Self {
Self {
base_url: Arc::from("http://localhost:11434"),
}
}
pub fn with_url(base_url: &str) -> Self {
Self {
base_url: Arc::from(base_url),
}
}
pub async fn list_models(&self, client: &reqwest::Client) -> Result<Vec<String>, LlmConnectorError> {
let base = &self.base_url;
let url = format!("{}/api/tags", base);
let response = client
.get(&url)
.send()
.await
.map_err(|e| OllamaErrorMapper::map_network_error(e))?;
if response.status().is_success() {
let models_response: OllamaModelsResponse = response.json().await
.map_err(|e| LlmConnectorError::ParseError(e.to_string()))?;
Ok(models_response.models.into_iter().map(|m| m.name).collect())
} else {
let status = response.status().as_u16();
let body = response.json().await.unwrap_or_default();
Err(OllamaErrorMapper::map_http_error(status, body))
}
}
pub async fn pull_model(&self, client: &reqwest::Client, model_name: &str) -> Result<(), LlmConnectorError> {
let base = &self.base_url;
let url = format!("{}/api/pull", base);
let request = OllamaModelRequest {
name: model_name.to_string(),
insecure: None,
stream: Some(false),
};
let response = client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| OllamaErrorMapper::map_network_error(e))?;
if response.status().is_success() {
Ok(())
} else {
let status = response.status().as_u16();
let body = response.json().await.unwrap_or_default();
Err(OllamaErrorMapper::map_http_error(status, body))
}
}
pub async fn push_model(&self, client: &reqwest::Client, model_name: &str) -> Result<(), LlmConnectorError> {
let base = &self.base_url;
let url = format!("{}/api/push", base);
let request = OllamaModelRequest {
name: model_name.to_string(),
insecure: None,
stream: Some(false),
};
let response = client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| OllamaErrorMapper::map_network_error(e))?;
if response.status().is_success() {
Ok(())
} else {
let status = response.status().as_u16();
let body = response.json().await.unwrap_or_default();
Err(OllamaErrorMapper::map_http_error(status, body))
}
}
pub async fn delete_model(&self, client: &reqwest::Client, model_name: &str) -> Result<(), LlmConnectorError> {
let base = &self.base_url;
let url = format!("{}/api/delete", base);
let request = OllamaModelDeleteRequest {
name: model_name.to_string(),
};
let response = client
.delete(&url)
.json(&request)
.send()
.await
.map_err(|e| OllamaErrorMapper::map_network_error(e))?;
if response.status().is_success() {
Ok(())
} else {
let status = response.status().as_u16();
let body = response.json().await.unwrap_or_default();
Err(OllamaErrorMapper::map_http_error(status, body))
}
}
pub async fn show_model(&self, client: &reqwest::Client, model_name: &str) -> Result<OllamaModel, LlmConnectorError> {
let base = &self.base_url;
let url = format!("{}/api/show", base);
let request = serde_json::json!({
"name": model_name
});
let response = client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| OllamaErrorMapper::map_network_error(e))?;
if response.status().is_success() {
let model_info: OllamaModel = response.json().await
.map_err(|e| LlmConnectorError::ParseError(e.to_string()))?;
Ok(model_info)
} else {
let status = response.status().as_u16();
let body = response.json().await.unwrap_or_default();
Err(OllamaErrorMapper::map_http_error(status, body))
}
}
}
#[async_trait]
impl ProviderAdapter for OllamaProtocol {
    type RequestType = OllamaRequest;
    type ResponseType = OllamaResponse;
    #[cfg(feature = "streaming")]
    type StreamResponseType = serde_json::Value;
    type ErrorMapperType = OllamaErrorMapper;

    /// Provider identifier: always `"ollama"`.
    fn name(&self) -> &str {
        "ollama"
    }

    /// Chat endpoint URL: `<base>/api/chat`, where `<base>` is the override in
    /// `base_url` if given, otherwise the adapter's configured base URL.
    fn endpoint_url(&self, base_url: &Option<String>) -> String {
        let base = base_url.as_deref().unwrap_or(&self.base_url);
        format!("{}/api/chat", base)
    }

    /// Model-listing endpoint URL: `<base>/api/tags` (always `Some`).
    fn models_endpoint_url(&self, base_url: &Option<String>) -> Option<String> {
        let base = base_url.as_deref().unwrap_or(&self.base_url);
        Some(format!("{}/api/tags", base))
    }

    /// Converts a generic [`ChatRequest`] into the Ollama wire format.
    ///
    /// The `stream` flag is always written into the payload. Ollama's chat
    /// endpoint streams by default when the field is absent, so omitting it
    /// would make a non-streaming call receive a sequence of JSON objects
    /// instead of the single object [`OllamaResponse`] expects.
    fn build_request_data(&self, request: &ChatRequest, stream: bool) -> Self::RequestType {
        let messages = request
            .messages
            .iter()
            .map(|msg| OllamaMessage {
                role: match msg.role {
                    Role::System => "system".to_string(),
                    Role::User => "user".to_string(),
                    Role::Assistant => "assistant".to_string(),
                    Role::Tool => "tool".to_string(),
                },
                content: msg.content.clone(),
            })
            .collect();

        OllamaRequest {
            model: request.model.clone(),
            messages,
            stream: Some(stream),
            options: Some(OllamaOptions {
                temperature: request.temperature,
                top_p: request.top_p,
                num_predict: request.max_tokens,
            }),
        }
    }

    /// Converts an [`OllamaResponse`] into the generic [`ChatResponse`].
    ///
    /// * `id` is synthesized as `ollama-<model>` and `created` is the local
    ///   clock at parse time.
    /// * `finish_reason` is `"stop"` when `done` is true, otherwise `"length"`.
    /// * `prompt_tokens` is always `0` because [`OllamaResponse`] carries no
    ///   prompt token count; `completion_tokens` and `total_tokens` both equal
    ///   `eval_count` (or `0` when it is absent).
    fn parse_response_data(&self, response: Self::ResponseType) -> ChatResponse {
        let first_content = response.message.content.clone();
        let generated_tokens = response.eval_count.unwrap_or(0);
        let finish_reason = if response.done { "stop" } else { "length" };

        ChatResponse {
            id: format!("ollama-{}", response.model),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: response.model,
            choices: vec![Choice {
                index: 0,
                message: Message {
                    role: parse_role(&response.message.role),
                    content: response.message.content,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    ..Default::default()
                },
                finish_reason: Some(finish_reason.to_string()),
                logprobs: None,
            }],
            content: first_content,
            usage: Some(Usage {
                prompt_tokens: 0,
                completion_tokens: generated_tokens,
                total_tokens: generated_tokens,
                prompt_cache_hit_tokens: None,
                prompt_cache_miss_tokens: None,
                prompt_tokens_details: None,
                completion_tokens_details: None,
            }),
            system_fingerprint: None,
        }
    }

    /// Produces a placeholder streaming chunk.
    ///
    /// NOTE(review): the incoming chunk is ignored — the result always has an
    /// empty delta, empty `content`, and the fixed id/model strings below.
    /// Real stream parsing still needs to be implemented.
    #[cfg(feature = "streaming")]
    fn parse_stream_response_data(&self, _response: Self::StreamResponseType) -> crate::types::StreamingResponse {
        crate::types::StreamingResponse {
            id: "ollama-stream".to_string(),
            object: "chat.completion.chunk".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "ollama".to_string(),
            choices: vec![crate::types::StreamingChoice {
                index: 0,
                delta: crate::types::Delta {
                    role: None,
                    content: None,
                    tool_calls: None,
                    reasoning_content: None,
                    ..Default::default()
                },
                finish_reason: None,
                logprobs: None,
            }],
            content: String::new(),
            reasoning_content: None,
            usage: None,
            system_fingerprint: None,
        }
    }
}
/// Convenience constructor: returns the same value as [`OllamaProtocol::new`].
pub fn ollama() -> OllamaProtocol {
    let protocol = OllamaProtocol::new();
    protocol
}
/// The crate's `GenericProvider` specialised to the [`OllamaProtocol`] adapter.
pub type OllamaProvider = crate::protocols::core::GenericProvider<OllamaProtocol>;