use std::time::Instant;
use async_trait::async_trait;
use llmkit_core::{
ChatRequest, ChatResponse, ChatStream, EmbedRequest, EmbedResponse, LlmError, LlmProvider,
LlmResult, TokenUsage,
};
use crate::types::{ChatResponseBody, EmbeddingsRequestBody, EmbeddingsResponseBody};
use crate::{chat, stream};
/// Base URL that `OllamaProvider::new` assigns to a freshly built provider.
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
/// Model name that `OllamaProvider::new` assigns to a freshly built provider;
/// it is the fallback when a request does not carry its own model.
const DEFAULT_MODEL: &str = "llama3.1";
/// `LlmProvider` implementation that sends chat, streaming-chat and embedding
/// requests to an Ollama server over HTTP.
///
/// Build one with `new`, `from_env` or `Default`, then adjust it with the
/// chaining setters `model`, `base_url` and `with_client`.
#[derive(Clone)]
pub struct OllamaProvider {
/// HTTP client used to issue every request.
http: reqwest::Client,
/// URL prefix to which the endpoint paths (`/api/chat`, `/api/embed`) are appended.
base_url: String,
/// Model used when a request does not specify one.
model: String,
}
impl OllamaProvider {
    /// Creates a provider that targets `DEFAULT_BASE_URL` with `DEFAULT_MODEL`
    /// and a freshly constructed `reqwest::Client`.
    pub fn new() -> Self {
        Self {
            http: reqwest::Client::new(),
            base_url: DEFAULT_BASE_URL.to_string(),
            model: DEFAULT_MODEL.to_string(),
        }
    }

    /// Creates a provider like [`OllamaProvider::new`], taking the base URL
    /// from the `OLLAMA_HOST` environment variable when it holds a usable value.
    ///
    /// The value is normalised before use so that the endpoint URLs built from
    /// it are valid:
    ///
    /// * surrounding whitespace and trailing `/` characters are removed;
    /// * a value without a scheme (for example `127.0.0.1:11434`) is prefixed
    ///   with `http://`;
    /// * an unset, non-Unicode, empty or whitespace-only value is ignored and
    ///   the default base URL is kept.
    pub fn from_env() -> Self {
        let configured = std::env::var("OLLAMA_HOST")
            .ok()
            .and_then(|raw| Self::normalize_host(&raw));
        match configured {
            Some(url) => Self::new().base_url(url),
            None => Self::new(),
        }
    }

    /// Sets the model used when a request does not name one.
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Sets the base URL that endpoint paths are appended to.
    ///
    /// Trailing `/` characters are stripped: the endpoint paths already start
    /// with `/`, so keeping them would yield URLs such as `http://host//api/chat`.
    #[must_use]
    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        let mut url: String = base_url.into();
        // Truncate in place rather than allocating a second string.
        let kept = url.trim_end_matches('/').len();
        url.truncate(kept);
        self.base_url = url;
        self
    }

    /// Replaces the HTTP client, e.g. to supply one with custom timeouts.
    #[must_use]
    pub fn with_client(mut self, client: reqwest::Client) -> Self {
        self.http = client;
        self
    }

    /// Returns the model named by `req`, or the provider's model when the
    /// request leaves it unset.
    fn resolved_model(&self, req: &ChatRequest) -> String {
        req.model.as_ref().unwrap_or(&self.model).clone()
    }

    /// Turns a raw `OLLAMA_HOST` value into a base URL, or `None` when the
    /// value names no host at all.
    fn normalize_host(raw: &str) -> Option<String> {
        let trimmed = raw.trim();
        // Split the scheme off first so that stripping trailing slashes can
        // never eat into the `://` separator of a value such as `http://`.
        let (scheme, rest) = trimmed.split_once("://").unwrap_or(("http", trimmed));
        let rest = rest.trim_end_matches('/');
        if scheme.is_empty() || rest.is_empty() {
            return None;
        }
        Some(format!("{scheme}://{rest}"))
    }
}
impl Default for OllamaProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl LlmProvider for OllamaProvider {
    /// Performs a non-streaming chat call against `{base_url}/api/chat`.
    ///
    /// The elapsed time in milliseconds, measured from just before the request
    /// is built and sent until the response body has been decoded, is handed
    /// to `chat::map_response` together with the decoded body.
    ///
    /// # Errors
    ///
    /// Transport and decode failures are converted by `map_reqwest_err`;
    /// non-success HTTP statuses are converted by `check_status`.
    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        let payload = chat::build_request(&req, self.resolved_model(&req), false);

        let started = Instant::now();
        let endpoint = format!("{}/api/chat", self.base_url);
        let request = self.http.post(endpoint).json(&payload);
        let raw = request.send().await.map_err(map_reqwest_err)?;
        let accepted = check_status(raw).await?;
        let decoded: ChatResponseBody = accepted.json().await.map_err(map_reqwest_err)?;

        let latency_ms = started.elapsed().as_millis() as u64;
        chat::map_response(decoded, latency_ms)
    }

    /// Performs a streaming chat call against `{base_url}/api/chat` and returns
    /// the stream that `stream::parse` builds from the HTTP response.
    ///
    /// # Errors
    ///
    /// Fails before any stream is returned if the request cannot be sent or
    /// the server answers with a non-success status.
    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        let payload = chat::build_request(&req, self.resolved_model(&req), true);

        let endpoint = format!("{}/api/chat", self.base_url);
        let request = self.http.post(endpoint).json(&payload);
        let raw = request.send().await.map_err(map_reqwest_err)?;
        let accepted = check_status(raw).await?;

        Ok(stream::parse(accepted))
    }

    /// Requests embeddings from `{base_url}/api/embed`.
    ///
    /// The request's model is used when set, otherwise the provider's model.
    /// The returned usage is built with the response's `prompt_eval_count`
    /// (or 0 when absent) as the first `TokenUsage::new` argument and 0 as the
    /// second.
    ///
    /// # Errors
    ///
    /// Transport and decode failures are converted by `map_reqwest_err`;
    /// non-success HTTP statuses are converted by `check_status`.
    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
        let model = req.model.as_ref().unwrap_or(&self.model).clone();
        let payload = EmbeddingsRequestBody {
            model,
            input: req.input,
        };

        let endpoint = format!("{}/api/embed", self.base_url);
        let request = self.http.post(endpoint).json(&payload);
        let raw = request.send().await.map_err(map_reqwest_err)?;
        let accepted = check_status(raw).await?;
        let decoded: EmbeddingsResponseBody =
            accepted.json().await.map_err(map_reqwest_err)?;

        let usage = TokenUsage::new(decoded.prompt_eval_count.unwrap_or(0), 0);
        Ok(EmbedResponse {
            provider: "ollama".into(),
            model: decoded.model,
            embeddings: decoded.embeddings,
            usage,
        })
    }

    /// Identifier of this provider: always `"ollama"`.
    fn name(&self) -> &'static str {
        "ollama"
    }

    /// The provider-level model, used when a request does not name one.
    fn model(&self) -> &str {
        self.model.as_str()
    }
}
/// Converts a `reqwest::Error` into the matching `LlmError`.
///
/// The checks run in a fixed order, so an error that satisfies several of
/// them is classified by the first match: timeout, then connection failure,
/// then body decoding; anything else is reported as a generic transport error.
fn map_reqwest_err(e: reqwest::Error) -> LlmError {
    if e.is_timeout() {
        return LlmError::Timeout;
    }
    if e.is_connect() {
        return LlmError::Transport(format!("cannot reach Ollama server: {e}"));
    }

    let detail = e.to_string();
    if e.is_decode() {
        LlmError::Serialization(detail)
    } else {
        LlmError::Transport(detail)
    }
}
/// Passes a successful response through unchanged and turns any other
/// response into an `LlmError`.
///
/// For a non-success status the body is read as text (an unreadable body
/// becomes the empty string) and used as the error message:
///
/// * 404 → `LlmError::InvalidRequest` with a "model not found or endpoint
///   missing" prefix;
/// * 400 → `LlmError::InvalidRequest` carrying the body as-is;
/// * anything else → `LlmError::Provider` with the numeric status and body.
async fn check_status(resp: reqwest::Response) -> LlmResult<reqwest::Response> {
    let status = resp.status();
    if !status.is_success() {
        let code = status.as_u16();
        // Reading the body consumes the response; that is fine because this
        // branch always returns an error.
        let message = resp.text().await.unwrap_or_default();
        let failure = if code == 404 {
            LlmError::InvalidRequest(format!("model not found or endpoint missing: {message}"))
        } else if code == 400 {
            LlmError::InvalidRequest(message)
        } else {
            LlmError::Provider { status: code, message }
        };
        return Err(failure);
    }
    Ok(resp)
}