lib_client_ollama/
client.rs1use crate::error::{OllamaError, Result};
4use crate::types::{
5 ChatRequest, ChatResponse, ErrorResponse, GenerateRequest, GenerateResponse, ModelInfo,
6 ModelList,
7};
8use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
9
/// Base URL a [`ClientBuilder`] starts with when no host is configured explicitly.
const DEFAULT_HOST: &str = "http://localhost:11434";
11
12pub struct Client {
14 http: reqwest::Client,
15 host: String,
16}
17
18impl Client {
19 pub fn builder() -> ClientBuilder {
21 ClientBuilder::new()
22 }
23
24 pub fn new() -> Self {
26 ClientBuilder::new().build()
27 }
28
29 pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
31 let url = format!("{}/api/chat", self.host);
32 self.post(&url, &request).await
33 }
34
35 pub async fn generate(&self, request: GenerateRequest) -> Result<GenerateResponse> {
37 let url = format!("{}/api/generate", self.host);
38 self.post(&url, &request).await
39 }
40
41 pub async fn list_models(&self) -> Result<ModelList> {
43 let url = format!("{}/api/tags", self.host);
44 self.get(&url).await
45 }
46
47 pub async fn show_model(&self, name: &str) -> Result<ModelInfo> {
49 let url = format!("{}/api/show", self.host);
50 let body = serde_json::json!({ "name": name });
51 self.post(&url, &body).await
52 }
53
54 pub async fn pull_model(&self, name: &str) -> Result<()> {
56 let url = format!("{}/api/pull", self.host);
57 let body = serde_json::json!({ "name": name, "stream": false });
58
59 let mut headers = HeaderMap::new();
60 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
61
62 let response = self
63 .http
64 .post(&url)
65 .headers(headers)
66 .json(&body)
67 .send()
68 .await
69 .map_err(|e| {
70 if e.is_connect() {
71 OllamaError::ConnectionRefused
72 } else {
73 OllamaError::Request(e)
74 }
75 })?;
76
77 let status = response.status();
78 if status.is_success() {
79 Ok(())
80 } else {
81 let body = response.text().await?;
82 Err(OllamaError::Api {
83 status: status.as_u16(),
84 message: body,
85 })
86 }
87 }
88
89 pub async fn delete_model(&self, name: &str) -> Result<()> {
91 let url = format!("{}/api/delete", self.host);
92 let body = serde_json::json!({ "name": name });
93
94 let mut headers = HeaderMap::new();
95 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
96
97 let response = self
98 .http
99 .delete(&url)
100 .headers(headers)
101 .json(&body)
102 .send()
103 .await
104 .map_err(|e| {
105 if e.is_connect() {
106 OllamaError::ConnectionRefused
107 } else {
108 OllamaError::Request(e)
109 }
110 })?;
111
112 let status = response.status();
113 if status.is_success() {
114 Ok(())
115 } else {
116 let body = response.text().await?;
117 if status.as_u16() == 404 {
118 Err(OllamaError::ModelNotFound(name.to_string()))
119 } else {
120 Err(OllamaError::Api {
121 status: status.as_u16(),
122 message: body,
123 })
124 }
125 }
126 }
127
128 pub async fn is_running(&self) -> bool {
130 let url = format!("{}/api/tags", self.host);
131 self.http.get(&url).send().await.is_ok()
132 }
133
134 async fn get<T>(&self, url: &str) -> Result<T>
136 where
137 T: serde::de::DeserializeOwned,
138 {
139 tracing::debug!(url = %url, "GET request");
140
141 let response = self.http.get(url).send().await.map_err(|e| {
142 if e.is_connect() {
143 OllamaError::ConnectionRefused
144 } else {
145 OllamaError::Request(e)
146 }
147 })?;
148
149 self.handle_response(response).await
150 }
151
152 async fn post<T, B>(&self, url: &str, body: &B) -> Result<T>
154 where
155 T: serde::de::DeserializeOwned,
156 B: serde::Serialize,
157 {
158 let mut headers = HeaderMap::new();
159 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
160
161 tracing::debug!(url = %url, "POST request");
162
163 let response = self
164 .http
165 .post(url)
166 .headers(headers)
167 .json(body)
168 .send()
169 .await
170 .map_err(|e| {
171 if e.is_connect() {
172 OllamaError::ConnectionRefused
173 } else {
174 OllamaError::Request(e)
175 }
176 })?;
177
178 self.handle_response(response).await
179 }
180
181 async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
183 where
184 T: serde::de::DeserializeOwned,
185 {
186 let status = response.status();
187 let status_code = status.as_u16();
188
189 if status.is_success() {
190 let body = response.text().await?;
191 tracing::debug!(status = %status_code, "Response received");
192 serde_json::from_str(&body).map_err(OllamaError::from)
193 } else {
194 let body = response.text().await?;
195 tracing::warn!(status = %status_code, body = %body, "API error");
196
197 if let Ok(error_response) = serde_json::from_str::<ErrorResponse>(&body) {
199 let message = error_response.error;
200
201 return Err(if message.contains("not found") {
202 OllamaError::ModelNotFound(message)
203 } else {
204 OllamaError::Api {
205 status: status_code,
206 message,
207 }
208 });
209 }
210
211 Err(OllamaError::Api {
212 status: status_code,
213 message: body,
214 })
215 }
216 }
217}
218
219impl Default for Client {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
/// Builder that collects configuration for a [`Client`] before it is created.
#[derive(Debug, Clone)]
pub struct ClientBuilder {
    /// Base URL the built client will prepend to every endpoint path.
    host: String,
}
229
230impl ClientBuilder {
231 pub fn new() -> Self {
233 Self {
234 host: DEFAULT_HOST.to_string(),
235 }
236 }
237
238 pub fn host(mut self, host: impl Into<String>) -> Self {
240 self.host = host.into();
241 self
242 }
243
244 pub fn build(self) -> Client {
246 Client {
247 http: reqwest::Client::new(),
248 host: self.host,
249 }
250 }
251}
252
253impl Default for ClientBuilder {
254 fn default() -> Self {
255 Self::new()
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Message;

    /// A host supplied through the builder ends up on the built client.
    #[test]
    fn test_builder() {
        let custom_host = "http://custom:8080";
        let built = Client::builder().host(custom_host).build();
        assert_eq!(built.host, custom_host);
    }

    /// A client created without configuration points at the default host.
    #[test]
    fn test_default_host() {
        assert_eq!(Client::new().host, "http://localhost:11434");
    }

    /// A freshly constructed chat request keeps its model name and does not
    /// have streaming enabled.
    #[test]
    fn test_chat_request() {
        let messages = vec![Message::user("Hello")];
        let chat = ChatRequest::new("llama3.2", messages);
        assert_eq!(chat.model, "llama3.2");
        assert!(!chat.stream);
    }
}
282}