use crate::error::{OllamaError, Result};
use crate::types::{
ChatRequest, ChatResponse, ErrorResponse, GenerateRequest, GenerateResponse, ModelInfo,
ModelList,
};
use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
/// Base URL used when no host is configured on the builder.
const DEFAULT_HOST: &str = "http://localhost:11434";

/// Asynchronous HTTP client for an Ollama server.
///
/// Cloning is cheap: `reqwest::Client` keeps its connection pool behind a shared
/// handle, so clones reuse the same pool.
#[derive(Debug, Clone)]
pub struct Client {
    /// Underlying HTTP client used for every request.
    http: reqwest::Client,
    /// Base URL (scheme, host and port) that `/api/...` paths are appended to.
    host: String,
}
impl Client {
pub fn builder() -> ClientBuilder {
ClientBuilder::new()
}
pub fn new() -> Self {
ClientBuilder::new().build()
}
pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
let url = format!("{}/api/chat", self.host);
self.post(&url, &request).await
}
pub async fn generate(&self, request: GenerateRequest) -> Result<GenerateResponse> {
let url = format!("{}/api/generate", self.host);
self.post(&url, &request).await
}
pub async fn list_models(&self) -> Result<ModelList> {
let url = format!("{}/api/tags", self.host);
self.get(&url).await
}
pub async fn show_model(&self, name: &str) -> Result<ModelInfo> {
let url = format!("{}/api/show", self.host);
let body = serde_json::json!({ "name": name });
self.post(&url, &body).await
}
pub async fn pull_model(&self, name: &str) -> Result<()> {
let url = format!("{}/api/pull", self.host);
let body = serde_json::json!({ "name": name, "stream": false });
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
let response = self
.http
.post(&url)
.headers(headers)
.json(&body)
.send()
.await
.map_err(|e| {
if e.is_connect() {
OllamaError::ConnectionRefused
} else {
OllamaError::Request(e)
}
})?;
let status = response.status();
if status.is_success() {
Ok(())
} else {
let body = response.text().await?;
Err(OllamaError::Api {
status: status.as_u16(),
message: body,
})
}
}
pub async fn delete_model(&self, name: &str) -> Result<()> {
let url = format!("{}/api/delete", self.host);
let body = serde_json::json!({ "name": name });
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
let response = self
.http
.delete(&url)
.headers(headers)
.json(&body)
.send()
.await
.map_err(|e| {
if e.is_connect() {
OllamaError::ConnectionRefused
} else {
OllamaError::Request(e)
}
})?;
let status = response.status();
if status.is_success() {
Ok(())
} else {
let body = response.text().await?;
if status.as_u16() == 404 {
Err(OllamaError::ModelNotFound(name.to_string()))
} else {
Err(OllamaError::Api {
status: status.as_u16(),
message: body,
})
}
}
}
pub async fn is_running(&self) -> bool {
let url = format!("{}/api/tags", self.host);
self.http.get(&url).send().await.is_ok()
}
async fn get<T>(&self, url: &str) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
tracing::debug!(url = %url, "GET request");
let response = self.http.get(url).send().await.map_err(|e| {
if e.is_connect() {
OllamaError::ConnectionRefused
} else {
OllamaError::Request(e)
}
})?;
self.handle_response(response).await
}
async fn post<T, B>(&self, url: &str, body: &B) -> Result<T>
where
T: serde::de::DeserializeOwned,
B: serde::Serialize,
{
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
tracing::debug!(url = %url, "POST request");
let response = self
.http
.post(url)
.headers(headers)
.json(body)
.send()
.await
.map_err(|e| {
if e.is_connect() {
OllamaError::ConnectionRefused
} else {
OllamaError::Request(e)
}
})?;
self.handle_response(response).await
}
async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
let status = response.status();
let status_code = status.as_u16();
if status.is_success() {
let body = response.text().await?;
tracing::debug!(status = %status_code, "Response received");
serde_json::from_str(&body).map_err(OllamaError::from)
} else {
let body = response.text().await?;
tracing::warn!(status = %status_code, body = %body, "API error");
if let Ok(error_response) = serde_json::from_str::<ErrorResponse>(&body) {
let message = error_response.error;
return Err(if message.contains("not found") {
OllamaError::ModelNotFound(message)
} else {
OllamaError::Api {
status: status_code,
message,
}
});
}
Err(OllamaError::Api {
status: status_code,
message: body,
})
}
}
}
impl Default for Client {
fn default() -> Self {
Self::new()
}
}
/// Builder for a configured [`Client`]; obtain one via `ClientBuilder::new` or
/// `Client::builder`.
#[derive(Debug, Clone)]
pub struct ClientBuilder {
    /// Base URL the built client will send requests to.
    host: String,
}
impl ClientBuilder {
    /// Creates a builder targeting `DEFAULT_HOST` (`http://localhost:11434`).
    pub fn new() -> Self {
        Self {
            host: DEFAULT_HOST.to_string(),
        }
    }

    /// Sets the base URL of the Ollama server, e.g. `"http://custom:8080"`.
    ///
    /// Trailing `/` characters are stripped by [`ClientBuilder::build`].
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Consumes the builder and creates a [`Client`] with a fresh `reqwest::Client`.
    pub fn build(self) -> Client {
        let mut host = self.host;
        // Endpoint URLs are formed as `format!("{}/api/...", host)`. Without this, a
        // trailing slash would yield `//api/...`.
        let trimmed_len = host.trim_end_matches('/').len();
        host.truncate(trimmed_len);
        Client {
            http: reqwest::Client::new(),
            host,
        }
    }
}
impl Default for ClientBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Message;

    /// A host given to the builder is stored on the built client.
    #[test]
    fn test_builder() {
        let custom_host = "http://custom:8080";
        let built = Client::builder().host(custom_host).build();
        assert_eq!(built.host, custom_host);
    }

    /// Without configuration the client targets the local Ollama port.
    #[test]
    fn test_default_host() {
        let expected = "http://localhost:11434";
        assert_eq!(Client::new().host, expected);
    }

    /// `ChatRequest::new` records the model and leaves streaming off.
    #[test]
    fn test_chat_request() {
        let messages = vec![Message::user("Hello")];
        let request = ChatRequest::new("llama3.2", messages);
        assert_eq!(request.model, "llama3.2");
        assert!(!request.stream, "chat requests should not stream by default");
    }
}