Skip to main content

oai_sdk/
client.rs

1// Copyright 2026 Cloudflavor GmbH
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6
7// http://www.apache.org/licenses/LICENSE-2.0
8
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::error::{ApiErrorResponse, OllamaError, Result};
16use reqwest::Client;
17use reqwest::Url;
18use std::time::Duration;
19use tokio::sync::mpsc;
20use tokio_stream::{Stream, wrappers::UnboundedReceiverStream};
21
22pub(crate) fn json_lines_stream<T>(response: reqwest::Response) -> impl Stream<Item = Result<T>>
23where
24    T: serde::de::DeserializeOwned + Send + 'static,
25{
26    let (tx, rx) = mpsc::unbounded_channel();
27    tokio::spawn(async move {
28        use tokio_stream::StreamExt;
29
30        let mut stream = response.bytes_stream();
31        let mut buf = String::new();
32
33        loop {
34            match stream.next().await {
35                Some(Ok(chunk)) => buf.push_str(&String::from_utf8_lossy(&chunk)),
36                Some(Err(e)) => {
37                    let _ = tx.send(Err(OllamaError::RequestError(e)));
38                    return;
39                }
40                None => {
41                    let remainder = buf.trim();
42                    if !remainder.is_empty() {
43                        let _ = tx.send(
44                            serde_json::from_str::<T>(remainder).map_err(OllamaError::JsonError),
45                        );
46                    }
47                    return;
48                }
49            }
50
51            while let Some(nl) = buf.find('\n') {
52                let rest = buf.split_off(nl + 1);
53                let line = buf.trim();
54                if !line.is_empty() {
55                    let item = serde_json::from_str::<T>(line).map_err(OllamaError::JsonError);
56                    if tx.send(item).is_err() {
57                        return;
58                    }
59                }
60                buf = rest;
61            }
62        }
63    });
64    UnboundedReceiverStream::new(rx)
65}
66
67/// Helper function to handle error responses consistently across the SDK.
68pub(crate) async fn handle_error_response(
69    response: reqwest::Response,
70    model: Option<&str>,
71) -> OllamaError {
72    let status = response.status();
73    let bytes = response.bytes().await.unwrap_or_default();
74    let error_message = if !bytes.is_empty() {
75        match serde_json::from_slice::<ApiErrorResponse>(&bytes) {
76            Ok(api_error) => api_error.error,
77            Err(_) => String::from_utf8_lossy(&bytes).to_string(),
78        }
79    } else {
80        "Unknown error".to_string()
81    };
82
83    if let Some(m) = model
84        && error_message.contains("not found")
85    {
86        return OllamaError::ModelNotFound(m.to_string());
87    }
88
89    OllamaError::ApiError {
90        status: status.as_u16(),
91        message: error_message,
92    }
93}
94
95/// A client for interacting with the Ollama API.
96#[derive(Debug, Clone)]
97pub struct ModelClient {
98    pub(crate) client: Client,
99    pub(crate) base_url: Url,
100    pub(crate) auth_token: Option<String>,
101}
102
/// A builder for creating a `ModelClient`.
///
/// `Debug` is implemented by hand so that the auth token is shown as `<redacted>` instead
/// of being printed in clear text.
#[derive(Clone)]
pub struct ModelClientBuilder {
    /// URL of the Ollama server; parsed when the client is built.
    base_url: String,
    /// Timeout applied to requests made by the built HTTP client.
    timeout: Duration,
    /// Optional bearer token sent in the `Authorization` header.
    auth_token: Option<String>,
}

impl std::fmt::Debug for ModelClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModelClientBuilder")
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            // Only reveal whether a token is configured, never the secret itself.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
110
111impl Default for ModelClientBuilder {
112    fn default() -> Self {
113        Self {
114            base_url: "http://localhost:11434".to_string(),
115            timeout: Duration::from_secs(300),
116            auth_token: None,
117        }
118    }
119}
120
121impl ModelClientBuilder {
122    /// Create a new builder with default settings.
123    pub fn new() -> Self {
124        Self::default()
125    }
126
127    /// Set the base URL for the Ollama API.
128    ///
129    /// Defaults to `http://localhost:11434` for local instances.
130    /// Use `https://ollama.com` for cloud access.
131    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
132        self.base_url = base_url.into();
133        self
134    }
135
136    /// Set the timeout for API requests.
137    pub fn timeout(mut self, timeout: Duration) -> Self {
138        self.timeout = timeout;
139        self
140    }
141
142    /// Set an authentication token for cloud access.
143    ///
144    /// Required when accessing cloud-hosted models, publishing models,
145    /// or downloading private models.
146    pub fn auth_token(mut self, token: String) -> Self {
147        self.auth_token = Some(token);
148        self
149    }
150
151    /// Build the `ModelClient`.
152    pub fn build(self) -> Result<ModelClient> {
153        let mut client_builder = Client::builder().timeout(self.timeout);
154
155        if let Some(token) = &self.auth_token {
156            let mut headers = reqwest::header::HeaderMap::new();
157            let auth_value =
158                format!("Bearer {}", token)
159                    .parse()
160                    .map_err(|_| OllamaError::ApiError {
161                        status: 0,
162                        message: "Invalid auth token format".to_string(),
163                    })?;
164            headers.insert(reqwest::header::AUTHORIZATION, auth_value);
165            client_builder = client_builder.default_headers(headers);
166        }
167
168        let client = client_builder.build().map_err(OllamaError::RequestError)?;
169        let base_url = Url::parse(&self.base_url).map_err(OllamaError::UrlError)?;
170        Ok(ModelClient {
171            client,
172            base_url,
173            auth_token: self.auth_token,
174        })
175    }
176}
177
178impl ModelClient {
179    /// Create a new builder for a `ModelClient`.
180    pub fn builder() -> ModelClientBuilder {
181        ModelClientBuilder::new()
182    }
183
184    /// Get the configured base URL.
185    pub fn base_url(&self) -> &Url {
186        &self.base_url
187    }
188
189    /// Check if authentication is configured.
190    pub fn is_authenticated(&self) -> bool {
191        self.auth_token.is_some()
192    }
193
194    /// Helper method to handle responses consistently.
195    pub async fn handle_response<T>(
196        &self,
197        response: reqwest::Response,
198        model: Option<&str>,
199    ) -> Result<T>
200    where
201        for<'a> T: serde::Deserialize<'a>,
202    {
203        let status = response.status();
204        if !status.is_success() {
205            return Err(handle_error_response(response, model).await);
206        }
207
208        response.json().await.map_err(OllamaError::RequestError)
209    }
210
211    /// Helper method to handle responses that return nothing (Result<()>).
212    pub async fn handle_void_response(&self, response: reqwest::Response) -> Result<()> {
213        let status = response.status();
214        if !status.is_success() {
215            return Err(handle_error_response(response, None).await);
216        }
217        Ok(())
218    }
219
220    /// Get the version of the Ollama API.
221    pub async fn get_version(&self) -> Result<crate::model::VersionResponse> {
222        let url = self
223            .base_url
224            .join("api/version")
225            .map_err(OllamaError::UrlError)?;
226        let response = self
227            .client
228            .get(url)
229            .send()
230            .await
231            .map_err(OllamaError::RequestError)?;
232
233        self.handle_response(response, None).await
234    }
235
236    /// Check if a blob exists.
237    #[cfg(feature = "local")]
238    pub async fn blob_exists(&self, digest: &str) -> Result<bool> {
239        let url = self
240            .base_url
241            .join(&format!("api/blobs/{}", digest))
242            .map_err(OllamaError::UrlError)?;
243        let response = self
244            .client
245            .head(url)
246            .send()
247            .await
248            .map_err(OllamaError::RequestError)?;
249
250        match response.status().as_u16() {
251            200 => Ok(true),
252            404 => Ok(false),
253            _ => Err(handle_error_response(response, None).await),
254        }
255    }
256
257    /// Push a blob to the Ollama server.
258    #[cfg(feature = "local")]
259    pub async fn push_blob(&self, digest: &str, content: &[u8]) -> Result<()> {
260        let url = self
261            .base_url
262            .join(&format!("api/blobs/{}", digest))
263            .map_err(OllamaError::UrlError)?;
264        let response = self
265            .client
266            .post(url)
267            .body(content.to_vec())
268            .send()
269            .await
270            .map_err(OllamaError::RequestError)?;
271
272        self.handle_void_response(response).await
273    }
274
275    /// Load a model into memory by sending an empty prompt.
276    #[cfg(feature = "local")]
277    pub async fn load_model(&self, model: &str) -> Result<crate::generate::GenerateResponse> {
278        let request = crate::generate::GenerateRequest {
279            model: model.to_string(),
280            prompt: String::new(),
281            stream: false,
282            ..Default::default()
283        };
284
285        self.generate(request).await
286    }
287
288    /// Unload a model from memory by setting keep_alive to "0".
289    #[cfg(feature = "local")]
290    pub async fn unload_model(&self, model: &str) -> Result<crate::generate::GenerateResponse> {
291        let request = crate::generate::GenerateRequest {
292            model: model.to_string(),
293            prompt: String::new(),
294            stream: false,
295            keep_alive: Some("0".to_string()),
296            ..Default::default()
297        };
298
299        self.generate(request).await
300    }
301}