use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};
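
/// Provider backed by a local [Ollama](https://ollama.com) server.
///
/// Translates OpenAI-style chat, completion, and embedding requests into
/// Ollama's native HTTP API and maps the responses back.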
#[derive(Debug, Clone)]
pub struct OllamaProvider {
base: BaseProvider,
pricing_cache: HashMap<String, ModelPricing>,
}
impl OllamaProvider {
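    /// Creates a provider from `config`, defaulting the base URL to the
    /// standard local Ollama port and clearing the API key (Ollama's API
    /// is unauthenticated).
    ///
    /// A minimal usage sketch, assuming `ProviderConfig` can be built with
    /// struct-update syntax (hypothetical; construct it however this
    /// crate's config layer actually does):
    ///
    /// ```ignore
    /// let config = ProviderConfig {
    ///     name: "ollama-local".to_string(),
    ///     base_url: Some("http://localhost:11434".to_string()),
    ///     ..Default::default()
    /// };
    /// let provider = OllamaProvider::new(&config).await?;
    /// ```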
pub async fn new(config: &ProviderConfig) -> Result<Self> {
let base = BaseProvider::new(config)?;
let base_url = config
.base_url
.clone()
.unwrap_or_else(|| "http://localhost:11434".to_string());
let provider = Self {
base: BaseProvider {
base_url,
                api_key: String::new(), // Ollama is a local server; no API key is needed.
                ..base
},
pricing_cache: Self::initialize_pricing_cache(),
};
info!("Ollama provider '{}' initialized successfully", config.name);
Ok(provider)
}
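
    /// Seeds the pricing cache with commonly pulled Ollama models.
    /// Inference is local, so every entry is priced at zero.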
fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
let mut cache = HashMap::new();
let models = vec![
"llama2",
"llama2:13b",
"llama2:70b",
"codellama",
"codellama:13b",
"codellama:34b",
"mistral",
"mistral:7b",
"neural-chat",
"starling-lm",
"phi",
"orca-mini",
];
for model in models {
cache.insert(
model.to_string(),
ModelPricing {
model: model.to_string(),
                    input_cost_per_1k: 0.0,
                    output_cost_per_1k: 0.0,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
}
cache
}
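
    /// Builds the default request headers; Ollama only needs a JSON
    /// content type, no authorization.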
fn create_headers(&self) -> reqwest::header::HeaderMap {
let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );
headers
}
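
    /// Converts OpenAI-style chat messages into Ollama's `{role, content}`
    /// message objects, flattening multi-part content into plain text.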
fn convert_messages_to_ollama(&self, messages: &[ChatMessage]) -> Vec<serde_json::Value> {
messages
.iter()
.map(|msg| {
let role = match msg.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
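                    // Ollama only understands the three roles above; fold
                    // anything else (tool/function messages) into "user".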
_ => "user",
};
let content = match &msg.content {
Some(MessageContent::Text(text)) => text.clone(),
Some(MessageContent::Parts(parts)) => parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
_ => None,
})
.collect::<Vec<String>>()
.join(" "),
None => String::new(),
};
json!({
"role": role,
"content": content
})
})
.collect()
}
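
    /// Maps Ollama's `/api/chat` response onto an OpenAI-style
    /// `ChatCompletionResponse`, synthesizing an ID and timestamp.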
fn convert_ollama_response_to_openai(
&self,
ollama_response: serde_json::Value,
model: &str,
) -> Result<ChatCompletionResponse> {
let content = ollama_response
.get("message")
.and_then(|msg| msg.get("content"))
.and_then(|content| content.as_str())
.unwrap_or("")
.to_string();
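        // Ollama reports usage as prompt_eval_count / eval_count; map them
        // onto OpenAI's prompt/completion token fields.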
let prompt_eval_count = ollama_response
.get("prompt_eval_count")
.and_then(|c| c.as_u64())
.unwrap_or(0) as u32;
let eval_count = ollama_response
.get("eval_count")
.and_then(|c| c.as_u64())
.unwrap_or(0) as u32;
let usage = Usage {
prompt_tokens: prompt_eval_count,
completion_tokens: eval_count,
total_tokens: prompt_eval_count + eval_count,
prompt_tokens_details: None,
completion_tokens_details: None,
};
Ok(ChatCompletionResponse {
id: format!("chatcmpl-ollama-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: model.to_string(),
choices: vec![ChatChoice {
index: 0,
message: ChatMessage {
role: MessageRole::Assistant,
content: Some(MessageContent::Text(content)),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
},
finish_reason: Some("stop".to_string()),
logprobs: None,
}],
usage: Some(usage),
system_fingerprint: None,
})
}
}
#[async_trait]
impl Provider for OllamaProvider {
fn name(&self) -> &str {
&self.base.name
}
fn provider_type(&self) -> ProviderType {
ProviderType::Custom("ollama".to_string())
}
async fn supports_model(&self, model: &str) -> bool {
self.base.is_model_supported(model) || self.pricing_cache.contains_key(model)
}
    async fn supports_images(&self) -> bool {
        // Only the text chat/generate endpoints are wired up here.
        false
    }

    async fn supports_embeddings(&self) -> bool {
        true
    }

    async fn supports_streaming(&self) -> bool {
        true
    }
async fn list_models(&self) -> Result<Vec<Model>> {
debug!("Listing Ollama models");
let response = self
.base
.client
.get(format!("{}/api/tags", self.base.base_url))
.headers(self.create_headers())
.send()
.await;
match response {
Ok(resp) if resp.status().is_success() => {
let models_response: serde_json::Value =
self.base.parse_json_response(resp).await?;
let mut models = Vec::new();
if let Some(model_list) = models_response.get("models").and_then(|m| m.as_array()) {
for model_data in model_list {
if let Some(name) = model_data.get("name").and_then(|n| n.as_str()) {
models.push(Model {
id: name.to_string(),
object: "model".to_string(),
created: chrono::Utc::now().timestamp() as u64,
owned_by: "ollama".to_string(),
});
}
}
}
Ok(models)
}
_ => {
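                // The server is unreachable or errored; fall back to the
                // static list of known models so discovery still works.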
let known_models = self
.pricing_cache
.keys()
.map(|model| Model {
id: model.clone(),
object: "model".to_string(),
created: chrono::Utc::now().timestamp() as u64,
owned_by: "ollama".to_string(),
})
.collect();
Ok(known_models)
}
}
}
async fn health_check(&self) -> Result<()> {
debug!("Performing Ollama health check");
let response = self
.base
.client
.get(format!("{}/api/tags", self.base.base_url))
.headers(self.create_headers())
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if response.status().is_success() {
Ok(())
} else {
Err(ProviderError::Unavailable(format!(
"Ollama health check failed with status: {}",
response.status()
))
.into())
}
}
async fn chat_completion(
&self,
request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<ChatCompletionResponse> {
debug!("Ollama chat completion for model: {}", request.model);
let messages = self.convert_messages_to_ollama(&request.messages);
let mut body = json!({
"model": request.model,
"messages": messages,
"stream": false
});
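        // Ollama expects sampling parameters nested under "options" rather
        // than as top-level fields.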
if let Some(temperature) = request.temperature {
body["options"] = json!({
"temperature": temperature
});
}
let response = self
.base
.client
.post(format!("{}/api/chat", self.base.base_url))
.headers(self.create_headers())
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(match status.as_u16() {
404 => ProviderError::ModelNotFound(error_text),
400 => ProviderError::InvalidRequest(error_text),
503 => ProviderError::Unavailable(error_text),
_ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
}
.into());
}
let ollama_response: serde_json::Value = self.base.parse_json_response(response).await?;
self.convert_ollama_response_to_openai(ollama_response, &request.model)
}
async fn completion(
&self,
request: CompletionRequest,
_context: RequestContext,
) -> Result<CompletionResponse> {
debug!("Ollama completion for model: {}", request.model);
let mut body = json!({
"model": request.model,
"prompt": request.prompt,
"stream": false
});
let mut options = json!({});
if let Some(temperature) = request.temperature {
options["temperature"] = json!(temperature);
}
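        // Ollama's counterpart to OpenAI's max_tokens is num_predict.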
if let Some(max_tokens) = request.max_tokens {
options["num_predict"] = json!(max_tokens);
}
        // `options` started as json!({}), so as_object() is always Some here.
        if options.as_object().map_or(false, |o| !o.is_empty()) {
body["options"] = options;
}
let response = self
.base
.client
.post(format!("{}/api/generate", self.base.base_url))
.headers(self.create_headers())
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
        // Surface HTTP errors explicitly; Ollama returns JSON error bodies
        // that would otherwise parse "successfully" into an empty completion.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)).into());
        }
        let ollama_response: serde_json::Value = self.base.parse_json_response(response).await?;
let text = ollama_response
.get("response")
.and_then(|r| r.as_str())
.unwrap_or("")
.to_string();
let prompt_eval_count = ollama_response
.get("prompt_eval_count")
.and_then(|c| c.as_u64())
.unwrap_or(0) as u32;
let eval_count = ollama_response
.get("eval_count")
.and_then(|c| c.as_u64())
.unwrap_or(0) as u32;
Ok(CompletionResponse {
id: format!("cmpl-ollama-{}", uuid::Uuid::new_v4()),
object: "text_completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: request.model,
choices: vec![CompletionChoice {
text,
index: 0,
logprobs: None,
finish_reason: Some("stop".to_string()),
}],
usage: Some(Usage {
prompt_tokens: prompt_eval_count,
completion_tokens: eval_count,
total_tokens: prompt_eval_count + eval_count,
prompt_tokens_details: None,
completion_tokens_details: None,
}),
})
}
async fn embedding(
&self,
request: EmbeddingRequest,
_context: RequestContext,
) -> Result<EmbeddingResponse> {
debug!("Ollama embedding for model: {}", request.model);
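        // The /api/embeddings endpoint takes a single "prompt" string;
        // OpenAI-style array inputs are serialized through as-is, so callers
        // should pass one input per request.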
let body = json!({
"model": request.model,
"prompt": request.input
});
let response = self
.base
.client
.post(format!("{}/api/embeddings", self.base.base_url))
.headers(self.create_headers())
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
        // As in completion(), check the status before parsing so error
        // bodies don't silently become an empty embedding.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)).into());
        }
        let ollama_response: serde_json::Value = self.base.parse_json_response(response).await?;
        let embedding_vec = ollama_response
            .get("embedding")
            .and_then(|e| e.as_array())
            .map(|arr| arr.iter().filter_map(|v| v.as_f64()).collect())
            .unwrap_or_default();
let embeddings = vec![EmbeddingObject {
object: "embedding".to_string(),
embedding: embedding_vec,
index: 0,
}];
Ok(EmbeddingResponse {
object: "list".to_string(),
data: embeddings,
model: request.model,
            usage: EmbeddingUsage {
                // Ollama does not report token counts for embeddings.
                prompt_tokens: 0,
                total_tokens: 0,
            },
})
}
async fn image_generation(
&self,
_request: ImageGenerationRequest,
_context: RequestContext,
) -> Result<ImageGenerationResponse> {
        Err(ProviderError::InvalidRequest(
            "Image generation is not supported by Ollama".to_string(),
        )
        .into())
}
async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
if let Some(pricing) = self.pricing_cache.get(model) {
Ok(pricing.clone())
} else {
Ok(ModelPricing {
model: model.to_string(),
input_cost_per_1k: 0.0,
output_cost_per_1k: 0.0,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
})
}
}
async fn calculate_cost(
&self,
_model: &str,
_input_tokens: u32,
_output_tokens: u32,
) -> Result<f64> {
Ok(0.0)
}
}