use crate::error::{ApiErrorResponse, OllamaError, Result};
use reqwest::{Client, Url};
use std::time::Duration;
/// Converts a non-success HTTP response into an [`OllamaError`].
///
/// The whole body is read. If it deserializes as an [`ApiErrorResponse`], its `error` field is
/// used as the message; otherwise the raw body, lossily decoded as UTF-8, is used. An empty body
/// (including one whose read failed) produces the message `"Unknown error"`.
///
/// When `model` is `Some` and the message contains `"not found"`, the result is
/// [`OllamaError::ModelNotFound`] carrying that model name. Every other case yields
/// [`OllamaError::ApiError`] with the response's numeric status code.
pub async fn handle_error_response(
    response: reqwest::Response,
    model: Option<&str>,
) -> OllamaError {
    let status_code = response.status().as_u16();
    // A failed body read degrades to an empty body: the caller is already on an error path.
    let body = response.bytes().await.unwrap_or_default();
    let message = if body.is_empty() {
        "Unknown error".to_string()
    } else {
        serde_json::from_slice::<ApiErrorResponse>(&body)
            .map(|parsed| parsed.error)
            .unwrap_or_else(|_| String::from_utf8_lossy(&body).into_owned())
    };
    match model {
        Some(name) if message.contains("not found") => {
            OllamaError::ModelNotFound(name.to_string())
        }
        _ => OllamaError::ApiError {
            status: status_code,
            message,
        },
    }
}
/// Async client for the Ollama HTTP API.
///
/// Construct one with [`ModelClient::builder`]. The `Debug` representation redacts the bearer
/// token so that logging a client does not leak credentials.
#[derive(Clone)]
pub struct ModelClient {
    /// Underlying HTTP client, configured by the builder with the timeout and, when a token is
    /// set, a default `Authorization` header.
    pub(crate) client: Client,
    /// Base URL that API paths (e.g. `api/version`) are joined onto.
    pub(crate) base_url: Url,
    /// Bearer token supplied to the builder, if any.
    pub(crate) auth_token: Option<String>,
}

impl std::fmt::Debug for ModelClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModelClient")
            .field("client", &self.client)
            .field("base_url", &self.base_url)
            // Show only whether a token is configured, never its value.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// Builder for [`ModelClient`].
///
/// The `Debug` representation redacts the bearer token so that logging a builder does not leak
/// credentials.
#[derive(Clone)]
pub struct ModelClientBuilder {
    /// Server base URL; parsed into a `Url` when the client is built.
    base_url: String,
    /// Timeout applied to the underlying HTTP client.
    timeout: Duration,
    /// Optional bearer token sent in the `Authorization` header.
    auth_token: Option<String>,
}

impl std::fmt::Debug for ModelClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModelClientBuilder")
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            // Show only whether a token is configured, never its value.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
impl Default for ModelClientBuilder {
    /// Targets `http://localhost:11434` with a 300-second timeout (handed to
    /// `Client::builder().timeout` in `build`) and no bearer token.
    fn default() -> Self {
        let base_url = String::from("http://localhost:11434");
        let timeout = Duration::from_secs(5 * 60);
        Self {
            base_url,
            timeout,
            auth_token: None,
        }
    }
}
impl ModelClientBuilder {
    /// Creates a builder holding the [`Default`] settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the server base URL that API paths (e.g. `api/version`) are resolved against.
    ///
    /// The string is parsed in [`build`](Self::build). A missing trailing `/` on the URL path is
    /// added there, so a base such as `http://host/ollama` keeps its `ollama` segment.
    #[must_use]
    pub fn base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

    /// Sets the timeout passed to the underlying `reqwest` client builder.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets a token that is sent as `Authorization: Bearer <token>` on every request.
    #[must_use]
    pub fn auth_token(mut self, token: String) -> Self {
        self.auth_token = Some(token);
        self
    }

    /// Builds the [`ModelClient`].
    ///
    /// # Errors
    ///
    /// - [`OllamaError::ApiError`] with `status: 0` if the token cannot be encoded as an HTTP
    ///   header value.
    /// - [`OllamaError::RequestError`] if the underlying `reqwest::Client` cannot be built.
    /// - [`OllamaError::UrlError`] if the base URL fails to parse.
    pub fn build(self) -> Result<ModelClient> {
        let mut client_builder = Client::builder().timeout(self.timeout);
        if let Some(token) = &self.auth_token {
            let mut auth_value =
                reqwest::header::HeaderValue::from_str(&format!("Bearer {token}")).map_err(
                    |_| OllamaError::ApiError {
                        status: 0,
                        message: "Invalid auth token format".to_string(),
                    },
                )?;
            // Keep the credential out of `Debug` output for headers and requests.
            auth_value.set_sensitive(true);
            let mut headers = reqwest::header::HeaderMap::new();
            headers.insert(reqwest::header::AUTHORIZATION, auth_value);
            client_builder = client_builder.default_headers(headers);
        }
        let client = client_builder.build().map_err(OllamaError::RequestError)?;

        let mut base_url = Url::parse(&self.base_url).map_err(OllamaError::UrlError)?;
        // `Url::join` replaces the last path segment unless the base path ends in `/`. Without
        // this, `http://host/ollama` + `api/version` would resolve to `http://host/api/version`.
        if !base_url.cannot_be_a_base() && !base_url.path().ends_with('/') {
            let path = format!("{}/", base_url.path());
            base_url.set_path(&path);
        }

        Ok(ModelClient {
            client,
            base_url,
            auth_token: self.auth_token,
        })
    }
}
impl ModelClient {
    /// Returns a [`ModelClientBuilder`] preloaded with its default settings.
    pub fn builder() -> ModelClientBuilder {
        ModelClientBuilder::new()
    }

    /// The base URL that every API path is resolved against.
    pub fn base_url(&self) -> &Url {
        &self.base_url
    }

    /// Whether a bearer token was configured when this client was built.
    pub fn is_authenticated(&self) -> bool {
        matches!(self.auth_token, Some(_))
    }

    /// Deserializes a successful JSON response body into `T`.
    ///
    /// A non-2xx response is converted by [`handle_error_response`]. `model` is passed through
    /// so that a "not found" message can become [`OllamaError::ModelNotFound`].
    ///
    /// # Errors
    ///
    /// Returns the error built from a non-success status, or [`OllamaError::RequestError`] if
    /// the body cannot be read or decoded.
    pub async fn handle_response<T>(
        &self,
        response: reqwest::Response,
        model: Option<&str>,
    ) -> Result<T>
    where
        for<'a> T: serde::Deserialize<'a>,
    {
        if response.status().is_success() {
            response.json::<T>().await.map_err(OllamaError::RequestError)
        } else {
            Err(handle_error_response(response, model).await)
        }
    }

    /// Accepts any 2xx response without reading its body.
    ///
    /// Any other status is converted by [`handle_error_response`] with no model name.
    pub async fn handle_void_response(&self, response: reqwest::Response) -> Result<()> {
        if response.status().is_success() {
            Ok(())
        } else {
            Err(handle_error_response(response, None).await)
        }
    }

    /// Fetches the server version with `GET api/version`.
    pub async fn get_version(&self) -> Result<crate::model::VersionResponse> {
        let endpoint = self
            .base_url
            .join("api/version")
            .map_err(OllamaError::UrlError)?;
        let reply = self
            .client
            .get(endpoint)
            .send()
            .await
            .map_err(OllamaError::RequestError)?;
        self.handle_response(reply, None).await
    }

    /// Checks for a blob with `HEAD api/blobs/{digest}`.
    ///
    /// Status 200 maps to `Ok(true)` and 404 to `Ok(false)`; any other status becomes an error.
    #[cfg(feature = "local")]
    pub async fn blob_exists(&self, digest: &str) -> Result<bool> {
        let blob_path = format!("api/blobs/{}", digest);
        let endpoint = self
            .base_url
            .join(&blob_path)
            .map_err(OllamaError::UrlError)?;
        let reply = self
            .client
            .head(endpoint)
            .send()
            .await
            .map_err(OllamaError::RequestError)?;
        let code = reply.status();
        if code == reqwest::StatusCode::OK {
            Ok(true)
        } else if code == reqwest::StatusCode::NOT_FOUND {
            Ok(false)
        } else {
            Err(handle_error_response(reply, None).await)
        }
    }

    /// Uploads `content` as the body of `POST api/blobs/{digest}`.
    #[cfg(feature = "local")]
    pub async fn push_blob(&self, digest: &str, content: &[u8]) -> Result<()> {
        let blob_path = format!("api/blobs/{}", digest);
        let endpoint = self
            .base_url
            .join(&blob_path)
            .map_err(OllamaError::UrlError)?;
        let payload = Vec::from(content);
        let reply = self
            .client
            .post(endpoint)
            .body(payload)
            .send()
            .await
            .map_err(OllamaError::RequestError)?;
        self.handle_void_response(reply).await
    }

    /// Sends a non-streaming `generate` request for `model` with an empty prompt.
    ///
    /// Presumably used to make the server load the model without generating text; confirm
    /// against the server's documentation.
    #[cfg(feature = "local")]
    pub async fn load_model(&self, model: &str) -> Result<crate::generate::GenerateResponse> {
        self.generate(crate::generate::GenerateRequest {
            model: String::from(model),
            prompt: String::new(),
            stream: false,
            ..Default::default()
        })
        .await
    }

    /// Sends a non-streaming, empty-prompt `generate` request for `model` with
    /// `keep_alive` set to `"0"`.
    ///
    /// Per the method name, this is intended to make the server unload the model.
    #[cfg(feature = "local")]
    pub async fn unload_model(&self, model: &str) -> Result<crate::generate::GenerateResponse> {
        self.generate(crate::generate::GenerateRequest {
            model: String::from(model),
            prompt: String::new(),
            stream: false,
            keep_alive: Some(String::from("0")),
            ..Default::default()
        })
        .await
    }

    /// Sends a non-streaming `chat` request for `model` with no messages.
    ///
    /// Presumably used to make the server load the model through the chat endpoint.
    #[cfg(feature = "local")]
    pub async fn load_model_chat(&self, model: &str) -> Result<crate::chat::ChatResponse> {
        self.chat(crate::chat::ChatRequest {
            model: String::from(model),
            messages: Vec::new(),
            stream: false,
            ..Default::default()
        })
        .await
    }

    /// Sends a non-streaming `chat` request for `model` with no messages and `keep_alive`
    /// set to `"0"`.
    ///
    /// Per the method name, this is intended to make the server unload the model.
    #[cfg(feature = "local")]
    pub async fn unload_model_chat(&self, model: &str) -> Result<crate::chat::ChatResponse> {
        self.chat(crate::chat::ChatRequest {
            model: String::from(model),
            messages: Vec::new(),
            stream: false,
            keep_alive: Some(String::from("0")),
            ..Default::default()
        })
        .await
    }
}