// lib_client_ollama/client.rs

1//! Ollama API client implementation.
2
3use crate::error::{OllamaError, Result};
4use crate::types::{
5    ChatRequest, ChatResponse, ErrorResponse, GenerateRequest, GenerateResponse, ModelInfo,
6    ModelList,
7};
8use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
9
10const DEFAULT_HOST: &str = "http://localhost:11434";
11
12/// Ollama API client.
13pub struct Client {
14    http: reqwest::Client,
15    host: String,
16}
17
18impl Client {
19    /// Create a new client builder.
20    pub fn builder() -> ClientBuilder {
21        ClientBuilder::new()
22    }
23
24    /// Create a new client with default settings.
25    pub fn new() -> Self {
26        ClientBuilder::new().build()
27    }
28
29    /// Generate a chat completion.
30    pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
31        let url = format!("{}/api/chat", self.host);
32        self.post(&url, &request).await
33    }
34
35    /// Generate a completion (non-chat).
36    pub async fn generate(&self, request: GenerateRequest) -> Result<GenerateResponse> {
37        let url = format!("{}/api/generate", self.host);
38        self.post(&url, &request).await
39    }
40
41    /// List available models.
42    pub async fn list_models(&self) -> Result<ModelList> {
43        let url = format!("{}/api/tags", self.host);
44        self.get(&url).await
45    }
46
47    /// Get information about a specific model.
48    pub async fn show_model(&self, name: &str) -> Result<ModelInfo> {
49        let url = format!("{}/api/show", self.host);
50        let body = serde_json::json!({ "name": name });
51        self.post(&url, &body).await
52    }
53
54    /// Pull a model from the registry.
55    pub async fn pull_model(&self, name: &str) -> Result<()> {
56        let url = format!("{}/api/pull", self.host);
57        let body = serde_json::json!({ "name": name, "stream": false });
58
59        let mut headers = HeaderMap::new();
60        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
61
62        let response = self
63            .http
64            .post(&url)
65            .headers(headers)
66            .json(&body)
67            .send()
68            .await
69            .map_err(|e| {
70                if e.is_connect() {
71                    OllamaError::ConnectionRefused
72                } else {
73                    OllamaError::Request(e)
74                }
75            })?;
76
77        let status = response.status();
78        if status.is_success() {
79            Ok(())
80        } else {
81            let body = response.text().await?;
82            Err(OllamaError::Api {
83                status: status.as_u16(),
84                message: body,
85            })
86        }
87    }
88
89    /// Delete a model.
90    pub async fn delete_model(&self, name: &str) -> Result<()> {
91        let url = format!("{}/api/delete", self.host);
92        let body = serde_json::json!({ "name": name });
93
94        let mut headers = HeaderMap::new();
95        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
96
97        let response = self
98            .http
99            .delete(&url)
100            .headers(headers)
101            .json(&body)
102            .send()
103            .await
104            .map_err(|e| {
105                if e.is_connect() {
106                    OllamaError::ConnectionRefused
107                } else {
108                    OllamaError::Request(e)
109                }
110            })?;
111
112        let status = response.status();
113        if status.is_success() {
114            Ok(())
115        } else {
116            let body = response.text().await?;
117            if status.as_u16() == 404 {
118                Err(OllamaError::ModelNotFound(name.to_string()))
119            } else {
120                Err(OllamaError::Api {
121                    status: status.as_u16(),
122                    message: body,
123                })
124            }
125        }
126    }
127
128    /// Check if Ollama is running.
129    pub async fn is_running(&self) -> bool {
130        let url = format!("{}/api/tags", self.host);
131        self.http.get(&url).send().await.is_ok()
132    }
133
134    /// Send a GET request.
135    async fn get<T>(&self, url: &str) -> Result<T>
136    where
137        T: serde::de::DeserializeOwned,
138    {
139        tracing::debug!(url = %url, "GET request");
140
141        let response = self.http.get(url).send().await.map_err(|e| {
142            if e.is_connect() {
143                OllamaError::ConnectionRefused
144            } else {
145                OllamaError::Request(e)
146            }
147        })?;
148
149        self.handle_response(response).await
150    }
151
152    /// Send a POST request with JSON body.
153    async fn post<T, B>(&self, url: &str, body: &B) -> Result<T>
154    where
155        T: serde::de::DeserializeOwned,
156        B: serde::Serialize,
157    {
158        let mut headers = HeaderMap::new();
159        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
160
161        tracing::debug!(url = %url, "POST request");
162
163        let response = self
164            .http
165            .post(url)
166            .headers(headers)
167            .json(body)
168            .send()
169            .await
170            .map_err(|e| {
171                if e.is_connect() {
172                    OllamaError::ConnectionRefused
173                } else {
174                    OllamaError::Request(e)
175                }
176            })?;
177
178        self.handle_response(response).await
179    }
180
181    /// Handle API response.
182    async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
183    where
184        T: serde::de::DeserializeOwned,
185    {
186        let status = response.status();
187        let status_code = status.as_u16();
188
189        if status.is_success() {
190            let body = response.text().await?;
191            tracing::debug!(status = %status_code, "Response received");
192            serde_json::from_str(&body).map_err(OllamaError::from)
193        } else {
194            let body = response.text().await?;
195            tracing::warn!(status = %status_code, body = %body, "API error");
196
197            // Try to parse error response
198            if let Ok(error_response) = serde_json::from_str::<ErrorResponse>(&body) {
199                let message = error_response.error;
200
201                return Err(if message.contains("not found") {
202                    OllamaError::ModelNotFound(message)
203                } else {
204                    OllamaError::Api {
205                        status: status_code,
206                        message,
207                    }
208                });
209            }
210
211            Err(OllamaError::Api {
212                status: status_code,
213                message: body,
214            })
215        }
216    }
217}
218
219impl Default for Client {
220    fn default() -> Self {
221        Self::new()
222    }
223}
224
225/// Client builder.
/// Builder for [`Client`].
///
/// Every configuration method consumes and returns the builder, so dropping
/// the returned value loses the configuration; hence `#[must_use]`.
#[derive(Debug, Clone)]
#[must_use = "a ClientBuilder does nothing until `.build()` is called"]
pub struct ClientBuilder {
    /// Base URL of the Ollama server the built client will talk to.
    host: String,
}
229
230impl ClientBuilder {
231    /// Create a new client builder.
232    pub fn new() -> Self {
233        Self {
234            host: DEFAULT_HOST.to_string(),
235        }
236    }
237
238    /// Set a custom host URL.
239    pub fn host(mut self, host: impl Into<String>) -> Self {
240        self.host = host.into();
241        self
242    }
243
244    /// Build the client.
245    pub fn build(self) -> Client {
246        Client {
247            http: reqwest::Client::new(),
248            host: self.host,
249        }
250    }
251}
252
253impl Default for ClientBuilder {
254    fn default() -> Self {
255        Self::new()
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Message;

    /// A custom host given to the builder ends up on the built client.
    #[test]
    fn test_builder() {
        let custom_host = "http://custom:8080";
        let built = Client::builder().host(custom_host).build();
        assert_eq!(built.host, custom_host);
    }

    /// With no configuration, the client targets the local Ollama port.
    #[test]
    fn test_default_host() {
        let expected = "http://localhost:11434";
        assert_eq!(Client::new().host, expected);
    }

    /// `ChatRequest::new` records the model and leaves streaming off.
    #[test]
    fn test_chat_request() {
        let messages = vec![Message::user("Hello")];
        let req = ChatRequest::new("llama3.2", messages);
        assert_eq!(req.model, "llama3.2");
        assert!(!req.stream, "stream should default to false");
    }
}