1use crate::error::{ApiErrorResponse, OllamaError, Result};
16use reqwest::Client;
17use reqwest::Url;
18use std::time::Duration;
19use tokio::sync::mpsc;
20use tokio_stream::{Stream, wrappers::UnboundedReceiverStream};
21
22pub(crate) fn json_lines_stream<T>(response: reqwest::Response) -> impl Stream<Item = Result<T>>
23where
24 T: serde::de::DeserializeOwned + Send + 'static,
25{
26 let (tx, rx) = mpsc::unbounded_channel();
27 tokio::spawn(async move {
28 use tokio_stream::StreamExt;
29
30 let mut stream = response.bytes_stream();
31 let mut buf = String::new();
32
33 loop {
34 match stream.next().await {
35 Some(Ok(chunk)) => buf.push_str(&String::from_utf8_lossy(&chunk)),
36 Some(Err(e)) => {
37 let _ = tx.send(Err(OllamaError::RequestError(e)));
38 return;
39 }
40 None => {
41 let remainder = buf.trim();
42 if !remainder.is_empty() {
43 let _ = tx.send(
44 serde_json::from_str::<T>(remainder).map_err(OllamaError::JsonError),
45 );
46 }
47 return;
48 }
49 }
50
51 while let Some(nl) = buf.find('\n') {
52 let rest = buf.split_off(nl + 1);
53 let line = buf.trim();
54 if !line.is_empty() {
55 let item = serde_json::from_str::<T>(line).map_err(OllamaError::JsonError);
56 if tx.send(item).is_err() {
57 return;
58 }
59 }
60 buf = rest;
61 }
62 }
63 });
64 UnboundedReceiverStream::new(rx)
65}
66
67pub(crate) async fn handle_error_response(
69 response: reqwest::Response,
70 model: Option<&str>,
71) -> OllamaError {
72 let status = response.status();
73 let bytes = response.bytes().await.unwrap_or_default();
74 let error_message = if !bytes.is_empty() {
75 match serde_json::from_slice::<ApiErrorResponse>(&bytes) {
76 Ok(api_error) => api_error.error,
77 Err(_) => String::from_utf8_lossy(&bytes).to_string(),
78 }
79 } else {
80 "Unknown error".to_string()
81 };
82
83 if let Some(m) = model
84 && error_message.contains("not found")
85 {
86 return OllamaError::ModelNotFound(m.to_string());
87 }
88
89 OllamaError::ApiError {
90 status: status.as_u16(),
91 message: error_message,
92 }
93}
94
95#[derive(Debug, Clone)]
97pub struct ModelClient {
98 pub(crate) client: Client,
99 pub(crate) base_url: Url,
100 pub(crate) auth_token: Option<String>,
101}
102
/// Collects the settings from which a `ModelClient` is built.
///
/// `Debug` is implemented by hand so that the `auth_token` field is never
/// printed verbatim.
#[derive(Clone)]
pub struct ModelClientBuilder {
    /// Base URL as given by the caller; not parsed until the client is built.
    base_url: String,
    /// Timeout applied to the HTTP client.
    timeout: Duration,
    /// Optional bearer token.
    auth_token: Option<String>,
}

impl std::fmt::Debug for ModelClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show only whether a token is present; the secret itself must not
        // reach logs through `{:?}`.
        let token = self.auth_token.as_ref().map(|_| "<redacted>");
        f.debug_struct("ModelClientBuilder")
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .field("auth_token", &token)
            .finish()
    }
}
110
111impl Default for ModelClientBuilder {
112 fn default() -> Self {
113 Self {
114 base_url: "http://localhost:11434".to_string(),
115 timeout: Duration::from_secs(300),
116 auth_token: None,
117 }
118 }
119}
120
121impl ModelClientBuilder {
122 pub fn new() -> Self {
124 Self::default()
125 }
126
127 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
132 self.base_url = base_url.into();
133 self
134 }
135
136 pub fn timeout(mut self, timeout: Duration) -> Self {
138 self.timeout = timeout;
139 self
140 }
141
142 pub fn auth_token(mut self, token: String) -> Self {
147 self.auth_token = Some(token);
148 self
149 }
150
151 pub fn build(self) -> Result<ModelClient> {
153 let mut client_builder = Client::builder().timeout(self.timeout);
154
155 if let Some(token) = &self.auth_token {
156 let mut headers = reqwest::header::HeaderMap::new();
157 let auth_value =
158 format!("Bearer {}", token)
159 .parse()
160 .map_err(|_| OllamaError::ApiError {
161 status: 0,
162 message: "Invalid auth token format".to_string(),
163 })?;
164 headers.insert(reqwest::header::AUTHORIZATION, auth_value);
165 client_builder = client_builder.default_headers(headers);
166 }
167
168 let client = client_builder.build().map_err(OllamaError::RequestError)?;
169 let base_url = Url::parse(&self.base_url).map_err(OllamaError::UrlError)?;
170 Ok(ModelClient {
171 client,
172 base_url,
173 auth_token: self.auth_token,
174 })
175 }
176}
177
178impl ModelClient {
179 pub fn builder() -> ModelClientBuilder {
181 ModelClientBuilder::new()
182 }
183
184 pub fn base_url(&self) -> &Url {
186 &self.base_url
187 }
188
189 pub fn is_authenticated(&self) -> bool {
191 self.auth_token.is_some()
192 }
193
194 pub async fn handle_response<T>(
196 &self,
197 response: reqwest::Response,
198 model: Option<&str>,
199 ) -> Result<T>
200 where
201 for<'a> T: serde::Deserialize<'a>,
202 {
203 let status = response.status();
204 if !status.is_success() {
205 return Err(handle_error_response(response, model).await);
206 }
207
208 response.json().await.map_err(OllamaError::RequestError)
209 }
210
211 pub async fn handle_void_response(&self, response: reqwest::Response) -> Result<()> {
213 let status = response.status();
214 if !status.is_success() {
215 return Err(handle_error_response(response, None).await);
216 }
217 Ok(())
218 }
219
220 pub async fn get_version(&self) -> Result<crate::model::VersionResponse> {
222 let url = self
223 .base_url
224 .join("api/version")
225 .map_err(OllamaError::UrlError)?;
226 let response = self
227 .client
228 .get(url)
229 .send()
230 .await
231 .map_err(OllamaError::RequestError)?;
232
233 self.handle_response(response, None).await
234 }
235
236 #[cfg(feature = "local")]
238 pub async fn blob_exists(&self, digest: &str) -> Result<bool> {
239 let url = self
240 .base_url
241 .join(&format!("api/blobs/{}", digest))
242 .map_err(OllamaError::UrlError)?;
243 let response = self
244 .client
245 .head(url)
246 .send()
247 .await
248 .map_err(OllamaError::RequestError)?;
249
250 match response.status().as_u16() {
251 200 => Ok(true),
252 404 => Ok(false),
253 _ => Err(handle_error_response(response, None).await),
254 }
255 }
256
257 #[cfg(feature = "local")]
259 pub async fn push_blob(&self, digest: &str, content: &[u8]) -> Result<()> {
260 let url = self
261 .base_url
262 .join(&format!("api/blobs/{}", digest))
263 .map_err(OllamaError::UrlError)?;
264 let response = self
265 .client
266 .post(url)
267 .body(content.to_vec())
268 .send()
269 .await
270 .map_err(OllamaError::RequestError)?;
271
272 self.handle_void_response(response).await
273 }
274
275 #[cfg(feature = "local")]
277 pub async fn load_model(&self, model: &str) -> Result<crate::generate::GenerateResponse> {
278 let request = crate::generate::GenerateRequest {
279 model: model.to_string(),
280 prompt: String::new(),
281 stream: false,
282 ..Default::default()
283 };
284
285 self.generate(request).await
286 }
287
288 #[cfg(feature = "local")]
290 pub async fn unload_model(&self, model: &str) -> Result<crate::generate::GenerateResponse> {
291 let request = crate::generate::GenerateRequest {
292 model: model.to_string(),
293 prompt: String::new(),
294 stream: false,
295 keep_alive: Some("0".to_string()),
296 ..Default::default()
297 };
298
299 self.generate(request).await
300 }
301}