use crate::error::{ApiErrorResponse, OllamaError, Result};
use reqwest::{Client, Url};
use std::time::Duration;
use tokio_stream::Stream;
/// Streams a newline-delimited JSON response body as decoded items of type `T`.
///
/// A task started with `tokio::spawn` pulls the body chunk by chunk via
/// `reqwest::Response::chunk` and splits it on `\n`. It sends each non-blank line
/// (ASCII-whitespace trimmed) through an unbounded channel. An item is therefore
/// available as soon as its line has arrived, rather than after the whole body
/// has been downloaded.
///
/// Stream items:
/// - `Ok(T)` for every line that deserializes.
/// - `Err(OllamaError::JsonError)` for a line that fails to deserialize; later lines are
///   still delivered.
/// - `Err(OllamaError::RequestError)` if reading the body fails; the stream then ends.
/// - `Err(OllamaError::StreamError)` if a line is not valid UTF-8; the stream then ends.
///
/// The background task stops at the first failed send after the returned stream has been
/// dropped. Must be called inside a Tokio runtime, because it uses `tokio::spawn`.
pub(crate) fn json_lines_stream<T>(mut response: reqwest::Response) -> impl Stream<Item = Result<T>>
where
    T: serde::de::DeserializeOwned + Send + 'static,
{
    let (tx, rx) = futures::channel::mpsc::unbounded();
    tokio::spawn(async move {
        // Bytes of the current line that has not been terminated yet; carried across chunks
        // because a chunk boundary can fall anywhere inside a line.
        let mut pending: Vec<u8> = Vec::new();
        loop {
            let chunk = match response.chunk().await {
                Ok(Some(chunk)) => chunk,
                Ok(None) => break,
                Err(e) => {
                    let _ = tx.unbounded_send(Err(OllamaError::RequestError(e)));
                    return;
                }
            };
            let mut rest: &[u8] = &chunk;
            while let Some(newline) = rest.iter().position(|&b| b == b'\n') {
                pending.extend_from_slice(&rest[..newline]);
                if !forward_json_line(&tx, &pending) {
                    return;
                }
                pending.clear();
                rest = &rest[newline + 1..];
            }
            pending.extend_from_slice(rest);
        }
        // The last line of the body does not have to be newline-terminated.
        forward_json_line(&tx, &pending);
    });
    rx
}

/// Decodes one NDJSON line and sends the result on `tx`.
///
/// Blank (whitespace-only) lines are skipped. Returns `false` when the caller should stop
/// reading: either the line was not valid UTF-8 (a `StreamError` has been sent) or the
/// receiving side has been dropped.
fn forward_json_line<T>(
    tx: &futures::channel::mpsc::UnboundedSender<Result<T>>,
    line: &[u8],
) -> bool
where
    T: serde::de::DeserializeOwned,
{
    let Ok(text) = std::str::from_utf8(line) else {
        let _ = tx.unbounded_send(Err(OllamaError::StreamError("Invalid UTF-8 in response")));
        return false;
    };
    let text = text.trim_ascii();
    if text.is_empty() {
        return true;
    }
    let item = serde_json::from_str::<T>(text).map_err(OllamaError::JsonError);
    tx.unbounded_send(item).is_ok()
}
/// Converts a non-success HTTP response into an [`OllamaError`].
///
/// The body is read in full; a failure to read it is treated as an empty body. The
/// message is chosen as follows:
/// - Empty body: `"Unknown error"`.
/// - Body parses as an [`ApiErrorResponse`]: its `error` field.
/// - Otherwise: the raw body, lossily decoded as UTF-8.
///
/// If `model` is `Some` and the message contains `"not found"`, the result is
/// [`OllamaError::ModelNotFound`] with that model name. Otherwise it is
/// [`OllamaError::ApiError`] carrying the numeric HTTP status.
pub(crate) async fn handle_error_response(
    response: reqwest::Response,
    model: Option<&str>,
) -> OllamaError {
    let code = response.status().as_u16();
    let body = response.bytes().await.unwrap_or_default();

    let message = if body.is_empty() {
        String::from("Unknown error")
    } else {
        serde_json::from_slice::<ApiErrorResponse>(&body)
            .map(|parsed| parsed.error)
            .unwrap_or_else(|_| String::from_utf8_lossy(&body).into_owned())
    };

    match model {
        Some(name) if message.contains("not found") => {
            OllamaError::ModelNotFound(name.to_string())
        }
        _ => OllamaError::ApiError {
            status: code,
            message,
        },
    }
}
/// HTTP client for an Ollama server.
///
/// Holds the underlying `reqwest` client, the base URL that endpoint paths are joined onto,
/// and the optional bearer token. When built through `ModelClientBuilder::build`, the
/// client also carries the token as a default `Authorization` header.
#[derive(Clone)]
pub struct ModelClient {
    pub(crate) client: Client,
    pub(crate) base_url: Url,
    pub(crate) auth_token: Option<String>,
}

// `Debug` is implemented by hand so the bearer token never ends up in logs or panic
// messages. The inner `reqwest::Client` is left out because its debug output may include
// default headers, such as the `Authorization` header.
impl std::fmt::Debug for ModelClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModelClient")
            .field("base_url", &self.base_url)
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .finish_non_exhaustive()
    }
}
/// Configuration for constructing a [`ModelClient`]: server base URL, request timeout, and
/// an optional bearer token.
#[derive(Clone)]
pub struct ModelClientBuilder {
    base_url: String,
    timeout: Duration,
    auth_token: Option<String>,
}

// `Debug` is implemented by hand so a configured bearer token is printed as `<redacted>`
// instead of leaking into logs.
impl std::fmt::Debug for ModelClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModelClientBuilder")
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
impl Default for ModelClientBuilder {
    /// A builder targeting `http://localhost:11434`, with a 300-second timeout and no
    /// authentication token.
    fn default() -> Self {
        const LOCAL_SERVER: &str = "http://localhost:11434";
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);

        Self {
            auth_token: None,
            timeout: FIVE_MINUTES,
            base_url: String::from(LOCAL_SERVER),
        }
    }
}
impl ModelClientBuilder {
    /// Creates a builder with the same settings as `ModelClientBuilder::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the server base URL. The string is only parsed later, by [`Self::build`].
    pub fn base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

    /// Sets the timeout passed to reqwest's `ClientBuilder::timeout` for the built client.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets a token that is sent as `Authorization: Bearer <token>` on every request.
    pub fn auth_token(mut self, token: String) -> Self {
        self.auth_token = Some(token);
        self
    }

    /// Builds the [`ModelClient`].
    ///
    /// A base URL that has a path without a trailing slash (e.g. `http://host/ollama`) is
    /// normalised to end in `/`. Endpoints are then resolved beneath that path instead of
    /// replacing its last segment.
    ///
    /// # Errors
    /// - [`OllamaError::ApiError`] with status `0` if the token cannot form a valid header value.
    /// - [`OllamaError::RequestError`] if the underlying HTTP client cannot be constructed.
    /// - [`OllamaError::UrlError`] if the base URL does not parse.
    pub fn build(self) -> Result<ModelClient> {
        let mut client_builder = Client::builder().timeout(self.timeout);
        if let Some(token) = &self.auth_token {
            let mut auth_value = reqwest::header::HeaderValue::from_str(&format!("Bearer {token}"))
                .map_err(|_| OllamaError::ApiError {
                    status: 0,
                    message: "Invalid auth token format".to_string(),
                })?;
            // Mark the credential as sensitive so header-map debug output hides it.
            auth_value.set_sensitive(true);
            let mut headers = reqwest::header::HeaderMap::new();
            headers.insert(reqwest::header::AUTHORIZATION, auth_value);
            client_builder = client_builder.default_headers(headers);
        }
        let client = client_builder.build().map_err(OllamaError::RequestError)?;

        let mut base_url = Url::parse(&self.base_url).map_err(OllamaError::UrlError)?;
        // `Url::join` replaces the final path segment of a base that lacks a trailing slash.
        // Without this, `http://host/ollama` + `api/version` resolves to
        // `http://host/api/version`.
        if !base_url.cannot_be_a_base() && !base_url.path().ends_with('/') {
            let directory_path = format!("{}/", base_url.path());
            base_url.set_path(&directory_path);
        }

        Ok(ModelClient {
            client,
            base_url,
            auth_token: self.auth_token,
        })
    }
}
impl ModelClient {
    /// Returns a [`ModelClientBuilder`] with default settings.
    pub fn builder() -> ModelClientBuilder {
        ModelClientBuilder::new()
    }

    /// The base URL that endpoint paths such as `api/version` are joined onto.
    pub fn base_url(&self) -> &Url {
        &self.base_url
    }

    /// Whether this client holds a bearer token.
    pub fn is_authenticated(&self) -> bool {
        self.auth_token.is_some()
    }

    /// Deserializes a successful response body as JSON into `T`.
    ///
    /// A non-success status is converted with `handle_error_response`, passing `model`
    /// along so a "not found" message can be reported as `OllamaError::ModelNotFound`.
    ///
    /// # Errors
    /// The converted API error for non-success statuses, or `OllamaError::RequestError`
    /// if reading or decoding the body fails.
    pub async fn handle_response<T>(
        &self,
        response: reqwest::Response,
        model: Option<&str>,
    ) -> Result<T>
    where
        for<'a> T: serde::Deserialize<'a>,
    {
        let status = response.status();
        if !status.is_success() {
            return Err(handle_error_response(response, model).await);
        }
        response.json().await.map_err(OllamaError::RequestError)
    }

    /// Returns `Ok(())` for a success status and ignores the body; otherwise returns the
    /// error produced by `handle_error_response`.
    pub async fn handle_void_response(&self, response: reqwest::Response) -> Result<()> {
        let status = response.status();
        if !status.is_success() {
            return Err(handle_error_response(response, None).await);
        }
        Ok(())
    }

    /// Fetches `GET api/version`.
    pub async fn get_version(&self) -> Result<crate::model::VersionResponse> {
        let url = self
            .base_url
            .join("api/version")
            .map_err(OllamaError::UrlError)?;
        let response = self
            .client
            .get(url)
            .send()
            .await
            .map_err(OllamaError::RequestError)?;
        self.handle_response(response, None).await
    }

    /// Builds `<base>/api/blobs/<digest>`, with `digest` added as a single percent-encoded
    /// path segment. A digest containing `/`, `?` or `#` therefore cannot change the path
    /// or add a query or fragment.
    #[cfg(feature = "local")]
    fn blob_url(&self, digest: &str) -> Result<Url> {
        let mut url = self
            .base_url
            .join("api/blobs/")
            .map_err(OllamaError::UrlError)?;
        url.path_segments_mut()
            // `join` with a relative path only succeeds on hierarchical URLs, which always
            // have path segments.
            .expect("URL produced by joining a relative path has path segments")
            .pop_if_empty()
            .push(digest);
        Ok(url)
    }

    /// Checks for a blob with `HEAD api/blobs/<digest>`.
    ///
    /// Returns `Ok(true)` on 200 and `Ok(false)` on 404. Any other status is converted with
    /// `handle_error_response`.
    #[cfg(feature = "local")]
    pub async fn blob_exists(&self, digest: &str) -> Result<bool> {
        let url = self.blob_url(digest)?;
        let response = self
            .client
            .head(url)
            .send()
            .await
            .map_err(OllamaError::RequestError)?;
        match response.status().as_u16() {
            200 => Ok(true),
            404 => Ok(false),
            _ => Err(handle_error_response(response, None).await),
        }
    }

    /// Uploads `content` with `POST api/blobs/<digest>`.
    ///
    /// The bytes are copied into an owned buffer because the request body must own its data.
    #[cfg(feature = "local")]
    pub async fn push_blob(&self, digest: &str, content: &[u8]) -> Result<()> {
        let url = self.blob_url(digest)?;
        let response = self
            .client
            .post(url)
            .body(content.to_vec())
            .send()
            .await
            .map_err(OllamaError::RequestError)?;
        self.handle_void_response(response).await
    }

    /// Sends a non-streaming `generate` request with an empty prompt for `model`.
    ///
    /// NOTE(review): this relies on the Ollama server treating an empty prompt as a load
    /// request. Confirm this against the server version in use.
    #[cfg(feature = "local")]
    pub async fn load_model(&self, model: &str) -> Result<crate::generate::GenerateResponse> {
        let request = crate::generate::GenerateRequest {
            model: model.to_string(),
            prompt: String::new(),
            stream: false,
            ..Default::default()
        };
        self.generate(request).await
    }

    /// Sends a non-streaming `generate` request with an empty prompt and `keep_alive` set
    /// to `"0"` for `model`.
    ///
    /// NOTE(review): this relies on the server unloading the model when `keep_alive` is 0.
    /// Confirm this against the Ollama API docs.
    #[cfg(feature = "local")]
    pub async fn unload_model(&self, model: &str) -> Result<crate::generate::GenerateResponse> {
        let request = crate::generate::GenerateRequest {
            model: model.to_string(),
            prompt: String::new(),
            stream: false,
            keep_alive: Some("0".to_string()),
            ..Default::default()
        };
        self.generate(request).await
    }
}