use std::collections::HashMap;
use crate::load_balancer::tasks::TaskDefinition;
use crate::providers::instances::{LlmInstance, BaseInstance};
use crate::providers::types::{LlmRequest, LlmResponse, LlmStream, StreamChunk, TokenUsage, Message};
use crate::errors::{LlmError, LlmResult};
use crate::constants;
use async_trait::async_trait;
use reqwest::header;
use serde::{Serialize, Deserialize};
use url::Url;
use futures::StreamExt;
/// An LLM provider instance that sends chat requests to an Ollama server.
pub struct OllamaInstance {
    /// Shared provider state; this file reads the API key, model, HTTP client,
    /// supported tasks and enabled flag from it.
    base: BaseInstance,
    /// URL every request is POSTed to, as resolved by the constructor.
    endpoint_url: String,
}
/// JSON body sent to the Ollama chat endpoint.
#[derive(Serialize)]
struct OllamaRequest {
    /// Name of the model that should answer.
    model: String,
    /// Conversation messages forwarded from the caller's request.
    messages: Vec<Message>,
    /// `true` for streaming requests, `false` for single-response requests.
    stream: bool,
    /// Optional generation settings; left out of the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<OllamaOptions>,
}
/// Generation settings nested under `options` in an [`OllamaRequest`].
///
/// Each field is omitted from the serialized JSON when it is `None`.
#[derive(Serialize, Default)]
struct OllamaOptions {
    /// Filled from the request's `temperature`, when one is given.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Filled from the request's `max_tokens`, when one is given.
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<u32>,
}
/// Body of a non-streaming Ollama chat response.
#[derive(Deserialize, Debug)]
struct OllamaResponse {
    /// Model name reported by the server.
    model: String,
    /// Creation timestamp, kept as the raw string the server sent.
    created_at: String,
    /// The generated message.
    message: Message,
    /// Completion flag reported by the server.
    done: bool,
    /// Reported as the prompt token count; 0 when the field is absent.
    #[serde(default)]
    prompt_eval_count: u32,
    /// Reported as the completion token count; 0 when the field is absent.
    #[serde(default)]
    eval_count: u32,
}
/// One JSON line of a streaming Ollama chat response.
///
/// Every field except `done` may be missing from a line.
#[derive(Deserialize, Debug)]
struct OllamaStreamResponse {
    /// Model name, when the line carries one.
    #[serde(default)]
    model: Option<String>,
    /// Message fragment, when the line carries one.
    #[serde(default)]
    message: Option<Message>,
    /// `true` on the line that ends the stream; token usage is only read then.
    done: bool,
    /// Prompt token count, when present.
    #[serde(default)]
    prompt_eval_count: Option<u32>,
    /// Completion token count, when present.
    #[serde(default)]
    eval_count: Option<u32>,
}
impl OllamaInstance {
    /// Returns `path` rewritten so that it ends in exactly one `/api/chat`.
    ///
    /// Trailing slashes are stripped before the suffix check, so `/api/chat/`
    /// is recognised as already complete instead of being extended to
    /// `/api/chat/api/chat`.
    fn chat_path(path: &str) -> String {
        let trimmed = path.trim_end_matches('/');
        if trimmed.ends_with("/api/chat") {
            trimmed.to_string()
        } else {
            format!("{}/api/chat", trimmed)
        }
    }

    /// Creates an Ollama provider instance.
    ///
    /// # Parameters
    /// * `api_key` - sent as a `Bearer` token when non-empty.
    /// * `model` - default model name for requests that do not name one.
    /// * `supported_tasks` - task definitions exposed by this instance.
    /// * `enabled` - whether the instance accepts requests.
    /// * `endpoint_url` - base URL of the Ollama server; `None` selects
    ///   `constants::OLLAMA_API_ENDPOINT`.
    ///
    /// The endpoint is normalised so that its path ends in `/api/chat`. If the
    /// URL cannot be parsed, a warning is printed to stderr and
    /// `constants::OLLAMA_API_ENDPOINT` is used unchanged.
    pub fn new(api_key: String, model: String, supported_tasks: HashMap<String, TaskDefinition>, enabled: bool, endpoint_url: Option<String>) -> Self {
        let base_endpoint = endpoint_url.unwrap_or_else(|| constants::OLLAMA_API_ENDPOINT.to_string());
        let final_endpoint = match Url::parse(&base_endpoint) {
            Ok(mut url) => {
                let chat_path = Self::chat_path(url.path());
                // Only touch the URL when the path actually needs to change.
                if chat_path != url.path() {
                    url.set_path(&chat_path);
                }
                url.to_string()
            }
            Err(_) => {
                eprintln!(
                    "Warning: Invalid Ollama endpoint URL '{}' provided. Falling back to default: {}",
                    base_endpoint, constants::OLLAMA_API_ENDPOINT
                );
                constants::OLLAMA_API_ENDPOINT.to_string()
            }
        };
        let base = BaseInstance::new("ollama".to_string(), api_key, model, supported_tasks, enabled);
        Self {
            base,
            endpoint_url: final_endpoint,
        }
    }
}
#[async_trait]
impl LlmInstance for OllamaInstance {
async fn generate(&self, request: &LlmRequest) -> LlmResult<LlmResponse> {
if !self.base.is_enabled() {
return Err(LlmError::ProviderDisabled("Ollama".to_string()));
}
let mut headers = header::HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
if !self.base.api_key().is_empty() {
match header::HeaderValue::from_str(&format!("Bearer {}", self.base.api_key())) {
Ok(val) => { headers.insert(header::AUTHORIZATION, val); },
Err(e) => return Err(LlmError::ConfigError(format!("Invalid API key format for Ollama: {}", e))),
}
}
let model = request.model.clone().unwrap_or_else(|| self.base.model().to_string());
let mut options = OllamaOptions::default();
if request.temperature.is_some() {
options.temperature = request.temperature;
}
if request.max_tokens.is_some() {
options.num_predict = request.max_tokens;
}
let ollama_request = OllamaRequest {
model,
messages: request.messages.clone(),
stream: false,
options: if options.temperature.is_some() || options.num_predict.is_some() { Some(options) } else { None },
};
let response = self.base.client()
.post(&self.endpoint_url)
.headers(headers)
.json(&ollama_request)
.send()
.await?;
let response_status = response.status();
if !response_status.is_success() {
let error_text = response.text().await
.unwrap_or_else(|_| format!("Unknown error. Status: {}", response_status));
return Err(LlmError::ApiError(format!("Ollama API error: {}", error_text)));
}
let response_text = response.text().await?;
if response_text.is_empty() {
return Err(LlmError::ApiError("Received empty response body from Ollama".to_string()));
}
let ollama_response: OllamaResponse = serde_json::from_str(&response_text)
.map_err(|e| LlmError::ApiError(format!("Failed to parse Ollama JSON response: {}. Body: {}", e, response_text)))?;
let usage = Some(TokenUsage {
prompt_tokens: ollama_response.prompt_eval_count,
completion_tokens: ollama_response.eval_count,
total_tokens: ollama_response.prompt_eval_count + ollama_response.eval_count,
});
Ok(LlmResponse {
content: ollama_response.message.content.clone(),
model: ollama_response.model,
usage,
})
}
async fn generate_stream(&self, request: &LlmRequest) -> LlmResult<LlmStream> {
if !self.base.is_enabled() {
return Err(LlmError::ProviderDisabled("Ollama".to_string()));
}
let mut headers = header::HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
if !self.base.api_key().is_empty() {
match header::HeaderValue::from_str(&format!("Bearer {}", self.base.api_key())) {
Ok(val) => { headers.insert(header::AUTHORIZATION, val); },
Err(e) => return Err(LlmError::ConfigError(format!("Invalid API key format for Ollama: {}", e))),
}
}
let model = request.model.clone().unwrap_or_else(|| self.base.model().to_string());
let mut options = OllamaOptions::default();
if request.temperature.is_some() {
options.temperature = request.temperature;
}
if request.max_tokens.is_some() {
options.num_predict = request.max_tokens;
}
let ollama_request = OllamaRequest {
model,
messages: request.messages.clone(),
stream: true, options: if options.temperature.is_some() || options.num_predict.is_some() { Some(options) } else { None },
};
let response = self.base.client()
.post(&self.endpoint_url)
.headers(headers)
.json(&ollama_request)
.send()
.await?;
let response_status = response.status();
if !response_status.is_success() {
let error_text = response.text().await
.unwrap_or_else(|_| format!("Unknown error. Status: {}", response_status));
return Err(LlmError::ApiError(format!("Ollama API error: {}", error_text)));
}
let byte_stream = response.bytes_stream();
let chunk_stream = byte_stream
.map(|result| result.map_err(|e| LlmError::RequestError(e)))
.flat_map(|result| {
match result {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
let chunks: Vec<Result<StreamChunk, LlmError>> = text
.lines()
.filter_map(|line| {
let line = line.trim();
if line.is_empty() {
return None;
}
match serde_json::from_str::<OllamaStreamResponse>(line) {
Ok(response) => {
let content = response.message
.map(|m| m.content)
.unwrap_or_default();
let usage = if response.done {
let prompt = response.prompt_eval_count.unwrap_or(0);
let completion = response.eval_count.unwrap_or(0);
Some(TokenUsage {
prompt_tokens: prompt,
completion_tokens: completion,
total_tokens: prompt + completion,
})
} else {
None
};
Some(Ok(StreamChunk {
content,
model: response.model,
is_final: response.done,
usage,
}))
}
Err(e) => Some(Err(LlmError::ParseError(
format!("Failed to parse Ollama streaming response: {}", e)
))),
}
})
.collect();
futures::stream::iter(chunks)
}
Err(e) => futures::stream::iter(vec![Err(e)])
}
});
Ok(Box::pin(chunk_stream))
}
fn supports_streaming(&self) -> bool {
true
}
fn get_name(&self) -> &str {
self.base.name()
}
fn get_model(&self) -> &str {
self.base.model()
}
fn get_supported_tasks(&self) -> &HashMap<String, TaskDefinition> {
&self.base.supported_tasks()
}
fn is_enabled(&self) -> bool {
self.base.is_enabled()
}
}