Skip to main content

llmkit_ollama/
provider.rs

1//! [`OllamaProvider`] — implements [`LlmProvider`] against a local Ollama server.
2
3use std::time::Instant;
4
5use async_trait::async_trait;
6use llmkit_core::{
7    ChatRequest, ChatResponse, ChatStream, EmbedRequest, EmbedResponse, LlmError, LlmProvider,
8    LlmResult, TokenUsage,
9};
10
11use crate::types::{ChatResponseBody, EmbeddingsRequestBody, EmbeddingsResponseBody};
12use crate::{chat, stream};
13
/// Base URL of a locally running Ollama server; used by `OllamaProvider::new`.
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
/// Model name a freshly constructed provider starts with.
const DEFAULT_MODEL: &str = "llama3.1";
16
17/// Ollama provider for local models (Llama, Mistral, …).
18#[derive(Clone)]
19pub struct OllamaProvider {
20    http: reqwest::Client,
21    base_url: String,
22    model: String,
23}
24
25impl OllamaProvider {
26    /// Construct against the default local server (`http://localhost:11434`).
27    pub fn new() -> Self {
28        Self {
29            http: reqwest::Client::new(),
30            base_url: DEFAULT_BASE_URL.to_string(),
31            model: DEFAULT_MODEL.to_string(),
32        }
33    }
34
35    /// Construct from the `OLLAMA_HOST` environment variable, if set.
36    pub fn from_env() -> Self {
37        let mut p = Self::new();
38        if let Ok(host) = std::env::var("OLLAMA_HOST") {
39            p.base_url = host;
40        }
41        p
42    }
43
44    /// Set the default model.
45    pub fn model(mut self, model: impl Into<String>) -> Self {
46        self.model = model.into();
47        self
48    }
49
50    /// Override the base URL.
51    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
52        self.base_url = base_url.into();
53        self
54    }
55
56    /// Provide a custom [`reqwest::Client`].
57    pub fn with_client(mut self, client: reqwest::Client) -> Self {
58        self.http = client;
59        self
60    }
61
62    fn resolved_model(&self, req: &ChatRequest) -> String {
63        req.model.clone().unwrap_or_else(|| self.model.clone())
64    }
65}
66
67impl Default for OllamaProvider {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73#[async_trait]
74impl LlmProvider for OllamaProvider {
75    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
76        let model = self.resolved_model(&req);
77        let body = chat::build_request(&req, model, false);
78
79        let start = Instant::now();
80        let resp = self
81            .http
82            .post(format!("{}/api/chat", self.base_url))
83            .json(&body)
84            .send()
85            .await
86            .map_err(map_reqwest_err)?;
87
88        let resp = check_status(resp).await?;
89        let parsed: ChatResponseBody = resp.json().await.map_err(map_reqwest_err)?;
90        chat::map_response(parsed, start.elapsed().as_millis() as u64)
91    }
92
93    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
94        let model = self.resolved_model(&req);
95        let body = chat::build_request(&req, model, true);
96
97        let resp = self
98            .http
99            .post(format!("{}/api/chat", self.base_url))
100            .json(&body)
101            .send()
102            .await
103            .map_err(map_reqwest_err)?;
104
105        let resp = check_status(resp).await?;
106        Ok(stream::parse(resp))
107    }
108
109    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
110        let model = req.model.clone().unwrap_or_else(|| self.model.clone());
111        let body = EmbeddingsRequestBody { model, input: req.input };
112
113        let resp = self
114            .http
115            .post(format!("{}/api/embed", self.base_url))
116            .json(&body)
117            .send()
118            .await
119            .map_err(map_reqwest_err)?;
120
121        let resp = check_status(resp).await?;
122        let parsed: EmbeddingsResponseBody = resp.json().await.map_err(map_reqwest_err)?;
123
124        Ok(EmbedResponse {
125            provider: "ollama".into(),
126            model: parsed.model,
127            embeddings: parsed.embeddings,
128            usage: TokenUsage::new(parsed.prompt_eval_count.unwrap_or(0), 0),
129        })
130    }
131
132    fn name(&self) -> &'static str {
133        "ollama"
134    }
135
136    fn model(&self) -> &str {
137        &self.model
138    }
139}
140
141fn map_reqwest_err(e: reqwest::Error) -> LlmError {
142    if e.is_timeout() {
143        LlmError::Timeout
144    } else if e.is_connect() {
145        LlmError::Transport(format!("cannot reach Ollama server: {e}"))
146    } else if e.is_decode() {
147        LlmError::Serialization(e.to_string())
148    } else {
149        LlmError::Transport(e.to_string())
150    }
151}
152
153async fn check_status(resp: reqwest::Response) -> LlmResult<reqwest::Response> {
154    let status = resp.status();
155    if status.is_success() {
156        return Ok(resp);
157    }
158    let code = status.as_u16();
159    let message = resp.text().await.unwrap_or_default();
160    Err(match code {
161        404 => LlmError::InvalidRequest(format!("model not found or endpoint missing: {message}")),
162        400 => LlmError::InvalidRequest(message),
163        _ => LlmError::Provider { status: code, message },
164    })
165}