use std::time::Duration;
use futures_core::Stream;
use url::Url;
use crate::error::{OllamaError, Result};
use crate::streaming::ndjson_stream;
use crate::types::chat::{ChatRequest, ChatResponse, ChatStreamChunk};
use crate::types::common::{Message, Options, Tool};
use crate::types::embed::{EmbedInput, EmbedRequest, EmbedResponse};
use crate::types::generate::{GenerateRequest, GenerateResponse, GenerateStreamChunk};
use crate::types::models::*;
/// Base URL used by `OllamaClient::new` / `try_new` and the builder default.
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
/// Default total per-request timeout, in seconds.
const DEFAULT_TIMEOUT_SECS: u64 = 120;
/// Default TCP connect timeout, in seconds.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 30;
/// Maximum number of bytes of a server error body kept in an `OllamaError::Api` message
/// (see `sanitize_error_message`).
const MAX_ERROR_MESSAGE_LEN: usize = 1024;
/// Async client for the Ollama HTTP API.
///
/// Holds the `reqwest::Client` used for every call and the server base URL that
/// each API path is resolved against.
#[derive(Debug)]
pub struct OllamaClient {
/// HTTP client used for all requests.
http: reqwest::Client,
/// Server base URL; every constructor validates it through `parse_base_url`
/// (only `http`/`https` are accepted).
base_url: Url,
}
impl OllamaClient {
    /// Creates a client for the default local server (`http://localhost:11434`)
    /// with the default timeouts.
    ///
    /// # Panics
    /// Panics if the HTTP client cannot be built. Use [`OllamaClient::try_new`]
    /// to receive a `Result` instead.
    pub fn new() -> Self {
        Self::try_new().expect("failed to create OllamaClient")
    }

    /// Fallible counterpart of [`OllamaClient::new`].
    ///
    /// # Errors
    /// Returns an error if the HTTP client cannot be built.
    pub fn try_new() -> Result<Self> {
        Self::try_with_base_url(DEFAULT_BASE_URL)
    }

    /// Creates a client for `base_url` with the default timeouts.
    ///
    /// # Panics
    /// Panics if `base_url` is not a valid `http`/`https` URL or the HTTP client
    /// cannot be built. Use [`OllamaClient::try_with_base_url`] to handle errors.
    pub fn with_base_url(base_url: impl Into<String>) -> Self {
        Self::try_with_base_url(base_url).expect("invalid base URL")
    }

    /// Fallible counterpart of [`OllamaClient::with_base_url`].
    ///
    /// Delegates to [`OllamaClientBuilder`] so the default timeouts are
    /// configured in exactly one place.
    ///
    /// # Errors
    /// Returns an error if `base_url` is invalid or the HTTP client cannot be built.
    pub fn try_with_base_url(base_url: impl Into<String>) -> Result<Self> {
        Self::builder().base_url(base_url).build()
    }

    /// Wraps an existing `reqwest::Client` (its own timeouts and settings apply).
    ///
    /// # Panics
    /// Panics if `base_url` is not a valid `http`/`https` URL. Use
    /// [`OllamaClient::try_from_reqwest`] to handle that case.
    pub fn from_reqwest(client: reqwest::Client, base_url: impl Into<String>) -> Self {
        Self::try_from_reqwest(client, base_url).expect("invalid base URL")
    }

    /// Fallible counterpart of [`OllamaClient::from_reqwest`].
    ///
    /// # Errors
    /// Returns an error if `base_url` is not a valid `http`/`https` URL.
    pub fn try_from_reqwest(client: reqwest::Client, base_url: impl Into<String>) -> Result<Self> {
        let base_url = parse_base_url(base_url.into())?;
        Ok(Self {
            http: client,
            base_url,
        })
    }

    /// Returns a builder for configuring the base URL and timeouts.
    pub fn builder() -> OllamaClientBuilder {
        OllamaClientBuilder::default()
    }

    /// Builds the absolute URL for an API `path` (expected to start with `/`).
    ///
    /// Any path prefix on the base URL is preserved. For example, a server mounted
    /// behind a reverse proxy at `https://host/ollama` receives requests at
    /// `https://host/ollama/api/...` rather than `https://host/api/...`.
    pub(crate) fn url(&self, path: &str) -> String {
        // The base path is "/" for a bare host; trimming the trailing slash avoids "//api".
        let prefix = self.base_url.path().trim_end_matches('/');
        let mut url = self.base_url.clone();
        url.set_path(&format!("{prefix}{path}"));
        url.to_string()
    }

    /// Shared HTTP client used by the request builders.
    pub(crate) fn http(&self) -> &reqwest::Client {
        &self.http
    }

    /// Starts a `POST /api/chat` request.
    pub fn chat(&self) -> ChatRequestBuilder<'_> {
        ChatRequestBuilder {
            client: self,
            model: String::new(),
            messages: Vec::new(),
            format: None,
            tools: None,
            think: None,
            options: None,
            keep_alive: None,
        }
    }

    /// Starts a `POST /api/generate` request.
    pub fn generate(&self) -> GenerateRequestBuilder<'_> {
        GenerateRequestBuilder {
            client: self,
            model: String::new(),
            prompt: None,
            suffix: None,
            images: None,
            format: None,
            system: None,
            think: None,
            raw: None,
            keep_alive: None,
            options: None,
            context: None,
        }
    }

    /// Starts a `POST /api/embed` request. The input defaults to a single empty string.
    pub fn embed(&self) -> EmbedRequestBuilder<'_> {
        EmbedRequestBuilder {
            client: self,
            model: String::new(),
            input: EmbedInput::Single(String::new()),
            truncate: None,
            dimensions: None,
            keep_alive: None,
            options: None,
        }
    }

    /// Lists local models (`GET /api/tags`).
    ///
    /// # Errors
    /// Transport failures, HTTP error statuses, or an undecodable body.
    pub async fn list_models(&self) -> Result<ListModelsResponse> {
        let response = self.http.get(self.url("/api/tags")).send().await?;
        let response = check_api_error(response).await?;
        Ok(response.json().await?)
    }

    /// Starts a `POST /api/show` request.
    pub fn show_model(&self) -> ShowModelRequestBuilder<'_> {
        ShowModelRequestBuilder {
            client: self,
            model: String::new(),
            verbose: None,
        }
    }

    /// Copies model `source` to `destination` (`POST /api/copy`).
    ///
    /// # Errors
    /// Transport failures or HTTP error statuses.
    pub async fn copy_model(
        &self,
        source: impl Into<String>,
        destination: impl Into<String>,
    ) -> Result<()> {
        let request = CopyModelRequest {
            source: source.into(),
            destination: destination.into(),
        };
        let response = self
            .http
            .post(self.url("/api/copy"))
            .json(&request)
            .send()
            .await?;
        check_api_error(response).await?;
        Ok(())
    }

    /// Deletes `model` (`DELETE /api/delete` with a JSON body).
    ///
    /// # Errors
    /// Transport failures or HTTP error statuses.
    pub async fn delete_model(&self, model: impl Into<String>) -> Result<()> {
        let request = DeleteModelRequest {
            model: model.into(),
        };
        let response = self
            .http
            .delete(self.url("/api/delete"))
            .json(&request)
            .send()
            .await?;
        check_api_error(response).await?;
        Ok(())
    }

    /// Starts a `POST /api/pull` request.
    pub fn pull_model(&self) -> PullModelRequestBuilder<'_> {
        PullModelRequestBuilder {
            client: self,
            model: String::new(),
            insecure: None,
        }
    }

    /// Starts a `POST /api/push` request.
    pub fn push_model(&self) -> PushModelRequestBuilder<'_> {
        PushModelRequestBuilder {
            client: self,
            model: String::new(),
            insecure: None,
        }
    }

    /// Starts a `POST /api/create` request.
    pub fn create_model(&self) -> CreateModelRequestBuilder<'_> {
        CreateModelRequestBuilder {
            client: self,
            model: String::new(),
            from: None,
            system: None,
            template: None,
            parameters: None,
            quantize: None,
            license: None,
            messages: None,
        }
    }

    /// Lists running models (`GET /api/ps`).
    ///
    /// # Errors
    /// Transport failures, HTTP error statuses, or an undecodable body.
    pub async fn list_running(&self) -> Result<ListRunningResponse> {
        let response = self.http.get(self.url("/api/ps")).send().await?;
        let response = check_api_error(response).await?;
        Ok(response.json().await?)
    }

    /// Checks whether a blob exists (`HEAD /api/blobs/{digest}`).
    ///
    /// Returns `Ok(true)` for 200 and `Ok(false)` for 404. Any other error status
    /// becomes an `Err`; any other non-error status is reported as `Ok(false)`.
    ///
    /// # Errors
    /// An invalid `digest`, transport failures, or HTTP error statuses other than 404.
    pub async fn check_blob(&self, digest: &str) -> Result<bool> {
        validate_digest(digest)?;
        let response = self
            .http
            .head(self.url(&format!("/api/blobs/{digest}")))
            .send()
            .await?;
        match response.status().as_u16() {
            200 => Ok(true),
            404 => Ok(false),
            _ => {
                check_api_error(response).await?;
                Ok(false)
            }
        }
    }

    /// Uploads `data` as the blob named `digest` (`POST /api/blobs/{digest}`).
    ///
    /// # Errors
    /// An invalid `digest`, transport failures, or HTTP error statuses.
    pub async fn upload_blob(&self, digest: &str, data: Vec<u8>) -> Result<()> {
        validate_digest(digest)?;
        let response = self
            .http
            .post(self.url(&format!("/api/blobs/{digest}")))
            .body(data)
            .send()
            .await?;
        check_api_error(response).await?;
        Ok(())
    }

    /// Fetches the server version (`GET /api/version`).
    ///
    /// # Errors
    /// Transport failures, HTTP error statuses, or an undecodable body.
    pub async fn version(&self) -> Result<VersionResponse> {
        let response = self.http.get(self.url("/api/version")).send().await?;
        let response = check_api_error(response).await?;
        Ok(response.json().await?)
    }
}
impl Default for OllamaClient {
fn default() -> Self {
Self::new()
}
}
/// Builder for a `POST /api/chat` request, created by `OllamaClient::chat`.
///
/// Nothing is sent until `send` or `send_stream` is awaited.
#[must_use]
#[derive(Debug)]
pub struct ChatRequestBuilder<'a> {
/// Client the request is sent through.
client: &'a OllamaClient,
/// Model name; empty until `model` is called.
model: String,
/// Conversation messages sent with the request.
messages: Vec<Message>,
/// Optional `format` field, passed through as raw JSON.
format: Option<serde_json::Value>,
/// Optional tools the model may call.
tools: Option<Vec<Tool>>,
/// Optional `think` field; the `think` setter stores a JSON boolean here.
think: Option<serde_json::Value>,
/// Optional model options (temperature, top_k, ...).
options: Option<Options>,
/// Optional `keep_alive` value, passed through as a string.
keep_alive: Option<String>,
}
impl<'a> ChatRequestBuilder<'a> {
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn messages(mut self, messages: Vec<Message>) -> Self {
self.messages = messages;
self
}
pub fn format(mut self, format: serde_json::Value) -> Self {
self.format = Some(format);
self
}
pub fn tools(mut self, tools: Vec<Tool>) -> Self {
self.tools = Some(tools);
self
}
pub fn think(mut self, think: bool) -> Self {
self.think = Some(serde_json::Value::Bool(think));
self
}
pub fn options(mut self, options: Options) -> Self {
self.options = Some(options);
self
}
pub fn temperature(mut self, temp: f64) -> Self {
self.options
.get_or_insert_with(Options::default)
.temperature = Some(temp);
self
}
pub fn top_k(mut self, top_k: u32) -> Self {
self.options.get_or_insert_with(Options::default).top_k = Some(top_k);
self
}
pub fn top_p(mut self, top_p: f64) -> Self {
self.options.get_or_insert_with(Options::default).top_p = Some(top_p);
self
}
pub fn num_ctx(mut self, num_ctx: u32) -> Self {
self.options.get_or_insert_with(Options::default).num_ctx = Some(num_ctx);
self
}
pub fn keep_alive(mut self, keep_alive: impl Into<String>) -> Self {
self.keep_alive = Some(keep_alive.into());
self
}
fn build_request(&self, stream: bool) -> ChatRequest {
ChatRequest {
model: self.model.clone(),
messages: self.messages.clone(),
stream: Some(stream),
format: self.format.clone(),
tools: self.tools.clone(),
think: self.think.clone(),
options: self.options.clone(),
keep_alive: self.keep_alive.clone(),
}
}
pub async fn send(self) -> Result<ChatResponse> {
let request = self.build_request(false);
let response = self
.client
.http()
.post(self.client.url("/api/chat"))
.json(&request)
.send()
.await?;
let response = check_api_error(response).await?;
let body = response.text().await?;
Ok(serde_json::from_str(&body)?)
}
pub async fn send_stream(self) -> Result<impl Stream<Item = Result<ChatStreamChunk>>> {
let request = self.build_request(true);
let response = self
.client
.http()
.post(self.client.url("/api/chat"))
.json(&request)
.send()
.await?;
let response = check_api_error(response).await?;
Ok(ndjson_stream(response))
}
}
/// Builder for a `POST /api/generate` request, created by `OllamaClient::generate`.
///
/// Nothing is sent until `send` or `send_stream` is awaited.
#[must_use]
#[derive(Debug)]
pub struct GenerateRequestBuilder<'a> {
/// Client the request is sent through.
client: &'a OllamaClient,
/// Model name; empty until `model` is called.
model: String,
/// Optional prompt text.
prompt: Option<String>,
/// Optional `suffix` field.
suffix: Option<String>,
/// Optional images, as strings (NOTE(review): presumably base64-encoded — confirm).
images: Option<Vec<String>>,
/// Optional `format` field, passed through as raw JSON.
format: Option<serde_json::Value>,
/// Optional system prompt.
system: Option<String>,
/// Optional `think` field; the `think` setter stores a JSON boolean here.
think: Option<serde_json::Value>,
/// Optional `raw` flag.
raw: Option<bool>,
/// Optional `keep_alive` value, passed through as a string.
keep_alive: Option<String>,
/// Optional model options.
options: Option<Options>,
/// Optional `context` token list (presumably taken from a previous response — verify).
context: Option<Vec<i64>>,
}
impl<'a> GenerateRequestBuilder<'a> {
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
self.prompt = Some(prompt.into());
self
}
pub fn suffix(mut self, suffix: impl Into<String>) -> Self {
self.suffix = Some(suffix.into());
self
}
pub fn images(mut self, images: Vec<String>) -> Self {
self.images = Some(images);
self
}
pub fn format(mut self, format: serde_json::Value) -> Self {
self.format = Some(format);
self
}
pub fn system(mut self, system: impl Into<String>) -> Self {
self.system = Some(system.into());
self
}
pub fn think(mut self, think: bool) -> Self {
self.think = Some(serde_json::Value::Bool(think));
self
}
pub fn raw(mut self, raw: bool) -> Self {
self.raw = Some(raw);
self
}
pub fn keep_alive(mut self, keep_alive: impl Into<String>) -> Self {
self.keep_alive = Some(keep_alive.into());
self
}
pub fn options(mut self, options: Options) -> Self {
self.options = Some(options);
self
}
pub fn temperature(mut self, temp: f64) -> Self {
self.options
.get_or_insert_with(Options::default)
.temperature = Some(temp);
self
}
pub fn context(mut self, context: Vec<i64>) -> Self {
self.context = Some(context);
self
}
fn build_request(&self, stream: bool) -> GenerateRequest {
GenerateRequest {
model: self.model.clone(),
prompt: self.prompt.clone(),
suffix: self.suffix.clone(),
images: self.images.clone(),
format: self.format.clone(),
system: self.system.clone(),
stream: Some(stream),
think: self.think.clone(),
raw: self.raw,
keep_alive: self.keep_alive.clone(),
options: self.options.clone(),
context: self.context.clone(),
}
}
pub async fn send(self) -> Result<GenerateResponse> {
let request = self.build_request(false);
let response = self
.client
.http()
.post(self.client.url("/api/generate"))
.json(&request)
.send()
.await?;
let response = check_api_error(response).await?;
let body = response.text().await?;
Ok(serde_json::from_str(&body)?)
}
pub async fn send_stream(self) -> Result<impl Stream<Item = Result<GenerateStreamChunk>>> {
let request = self.build_request(true);
let response = self
.client
.http()
.post(self.client.url("/api/generate"))
.json(&request)
.send()
.await?;
let response = check_api_error(response).await?;
Ok(ndjson_stream(response))
}
}
/// Builder for a `POST /api/embed` request, created by `OllamaClient::embed`.
///
/// Nothing is sent until `send` is awaited.
#[must_use]
#[derive(Debug)]
pub struct EmbedRequestBuilder<'a> {
/// Client the request is sent through.
client: &'a OllamaClient,
/// Model name; empty until `model` is called.
model: String,
/// Text(s) to embed; `OllamaClient::embed` initialises this to a single empty string.
input: EmbedInput,
/// Optional `truncate` flag.
truncate: Option<bool>,
/// Optional `dimensions` value.
dimensions: Option<u32>,
/// Optional `keep_alive` value, passed through as a string.
keep_alive: Option<String>,
/// Optional model options.
options: Option<Options>,
}
impl<'a> EmbedRequestBuilder<'a> {
    /// Chooses the embedding model.
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..self
        }
    }

    /// Sets the text (single or batch) to embed.
    pub fn input(self, input: impl Into<EmbedInput>) -> Self {
        Self {
            input: input.into(),
            ..self
        }
    }

    /// Sets the `truncate` flag.
    pub fn truncate(self, truncate: bool) -> Self {
        Self {
            truncate: Some(truncate),
            ..self
        }
    }

    /// Sets the `dimensions` field.
    pub fn dimensions(self, dimensions: u32) -> Self {
        Self {
            dimensions: Some(dimensions),
            ..self
        }
    }

    /// Sets the `keep_alive` field.
    pub fn keep_alive(self, keep_alive: impl Into<String>) -> Self {
        Self {
            keep_alive: Some(keep_alive.into()),
            ..self
        }
    }

    /// Replaces the model [`Options`].
    pub fn options(self, options: Options) -> Self {
        Self {
            options: Some(options),
            ..self
        }
    }

    /// POSTs to `/api/embed` and decodes the JSON body into an [`EmbedResponse`].
    ///
    /// # Errors
    /// Transport failures, HTTP error statuses, or an undecodable body.
    pub async fn send(self) -> Result<EmbedResponse> {
        let Self {
            client,
            model,
            input,
            truncate,
            dimensions,
            keep_alive,
            options,
        } = self;
        let payload = EmbedRequest {
            model,
            input,
            truncate,
            dimensions,
            keep_alive,
            options,
        };
        let reply = client
            .http()
            .post(client.url("/api/embed"))
            .json(&payload)
            .send()
            .await?;
        let text = check_api_error(reply).await?.text().await?;
        Ok(serde_json::from_str(&text)?)
    }
}
/// Builder for a `POST /api/show` request, created by `OllamaClient::show_model`.
///
/// Nothing is sent until `send` is awaited.
#[must_use]
#[derive(Debug)]
pub struct ShowModelRequestBuilder<'a> {
/// Client the request is sent through.
client: &'a OllamaClient,
/// Model name; empty until `model` is called.
model: String,
/// Optional `verbose` flag.
verbose: Option<bool>,
}
impl<'a> ShowModelRequestBuilder<'a> {
    /// Selects the model whose details are requested.
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..self
        }
    }

    /// Sets the `verbose` flag.
    pub fn verbose(self, verbose: bool) -> Self {
        Self {
            verbose: Some(verbose),
            ..self
        }
    }

    /// POSTs to `/api/show` and decodes the JSON body into a [`ShowModelResponse`].
    ///
    /// # Errors
    /// Transport failures, HTTP error statuses, or an undecodable body.
    pub async fn send(self) -> Result<ShowModelResponse> {
        let Self {
            client,
            model,
            verbose,
        } = self;
        let payload = ShowModelRequest { model, verbose };
        let reply = client
            .http()
            .post(client.url("/api/show"))
            .json(&payload)
            .send()
            .await?;
        let text = check_api_error(reply).await?.text().await?;
        Ok(serde_json::from_str(&text)?)
    }
}
/// Builder for a `POST /api/pull` request, created by `OllamaClient::pull_model`.
///
/// Nothing is sent until `send` or `send_stream` is awaited.
#[must_use]
#[derive(Debug)]
pub struct PullModelRequestBuilder<'a> {
/// Client the request is sent through.
client: &'a OllamaClient,
/// Model name; empty until `model` is called.
model: String,
/// Optional `insecure` flag.
insecure: Option<bool>,
}
impl<'a> PullModelRequestBuilder<'a> {
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn insecure(mut self, insecure: bool) -> Self {
self.insecure = Some(insecure);
self
}
pub async fn send(self) -> Result<PullModelStatus> {
let request = PullModelRequest {
model: self.model,
insecure: self.insecure,
stream: Some(false),
};
let response = self
.client
.http()
.post(self.client.url("/api/pull"))
.json(&request)
.send()
.await?;
let response = check_api_error(response).await?;
let body = response.text().await?;
Ok(serde_json::from_str(&body)?)
}
pub async fn send_stream(self) -> Result<impl Stream<Item = Result<PullModelStatus>>> {
let request = PullModelRequest {
model: self.model,
insecure: self.insecure,
stream: Some(true),
};
let response = self
.client
.http()
.post(self.client.url("/api/pull"))
.json(&request)
.send()
.await?;
let response = check_api_error(response).await?;
Ok(ndjson_stream(response))
}
}
/// Builder for a `POST /api/push` request, created by `OllamaClient::push_model`.
///
/// Nothing is sent until `send` or `send_stream` is awaited.
#[must_use]
#[derive(Debug)]
pub struct PushModelRequestBuilder<'a> {
/// Client the request is sent through.
client: &'a OllamaClient,
/// Model name; empty until `model` is called.
model: String,
/// Optional `insecure` flag.
insecure: Option<bool>,
}
impl<'a> PushModelRequestBuilder<'a> {
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn insecure(mut self, insecure: bool) -> Self {
self.insecure = Some(insecure);
self
}
pub async fn send(self) -> Result<PushModelStatus> {
let request = PushModelRequest {
model: self.model,
insecure: self.insecure,
stream: Some(false),
};
let response = self
.client
.http()
.post(self.client.url("/api/push"))
.json(&request)
.send()
.await?;
let response = check_api_error(response).await?;
let body = response.text().await?;
Ok(serde_json::from_str(&body)?)
}
pub async fn send_stream(self) -> Result<impl Stream<Item = Result<PushModelStatus>>> {
let request = PushModelRequest {
model: self.model,
insecure: self.insecure,
stream: Some(true),
};
let response = self
.client
.http()
.post(self.client.url("/api/push"))
.json(&request)
.send()
.await?;
let response = check_api_error(response).await?;
Ok(ndjson_stream(response))
}
}
/// Builder for a `POST /api/create` request, created by `OllamaClient::create_model`.
///
/// Nothing is sent until `send` or `send_stream` is awaited.
#[must_use]
#[derive(Debug)]
pub struct CreateModelRequestBuilder<'a> {
/// Client the request is sent through.
client: &'a OllamaClient,
/// Name of the model to create; empty until `model` is called.
model: String,
/// Optional `from` field, set by `from_model`.
from: Option<String>,
/// Optional system prompt.
system: Option<String>,
/// Optional prompt template.
template: Option<String>,
/// Optional `parameters`, passed through as raw JSON.
parameters: Option<serde_json::Value>,
/// Optional `quantize` value.
quantize: Option<String>,
/// Optional `license`, passed through as raw JSON.
license: Option<serde_json::Value>,
/// Optional messages.
messages: Option<Vec<Message>>,
}
impl<'a> CreateModelRequestBuilder<'a> {
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn from_model(mut self, from: impl Into<String>) -> Self {
self.from = Some(from.into());
self
}
pub fn system(mut self, system: impl Into<String>) -> Self {
self.system = Some(system.into());
self
}
pub fn template(mut self, template: impl Into<String>) -> Self {
self.template = Some(template.into());
self
}
pub fn parameters(mut self, parameters: serde_json::Value) -> Self {
self.parameters = Some(parameters);
self
}
pub fn quantize(mut self, quantize: impl Into<String>) -> Self {
self.quantize = Some(quantize.into());
self
}
pub fn license(mut self, license: serde_json::Value) -> Self {
self.license = Some(license);
self
}
pub fn messages(mut self, messages: Vec<Message>) -> Self {
self.messages = Some(messages);
self
}
pub async fn send(self) -> Result<CreateModelStatus> {
let request = self.build_request(false);
let response = self
.client
.http()
.post(self.client.url("/api/create"))
.json(&request)
.send()
.await?;
let response = check_api_error(response).await?;
let body = response.text().await?;
Ok(serde_json::from_str(&body)?)
}
pub async fn send_stream(self) -> Result<impl Stream<Item = Result<CreateModelStatus>>> {
let request = self.build_request(true);
let response = self
.client
.http()
.post(self.client.url("/api/create"))
.json(&request)
.send()
.await?;
let response = check_api_error(response).await?;
Ok(ndjson_stream(response))
}
fn build_request(&self, stream: bool) -> CreateModelRequest {
CreateModelRequest {
model: self.model.clone(),
from: self.from.clone(),
system: self.system.clone(),
template: self.template.clone(),
parameters: self.parameters.clone(),
quantize: self.quantize.clone(),
license: self.license.clone(),
messages: self.messages.clone(),
stream: Some(stream),
}
}
}
/// Turns an HTTP error status into `OllamaError::Api`; successful responses pass through.
///
/// For 4xx/5xx statuses the body is read on a best-effort basis and cleaned with
/// `sanitize_error_message`. If the result is blank, the status's canonical reason
/// phrase (e.g. "Not Found") is used instead, so the message is never empty. Bodies
/// are empty, for example, for the `HEAD` request made by `check_blob`.
///
/// # Errors
/// Returns `OllamaError::Api { status, message }` for client or server error statuses.
pub(crate) async fn check_api_error(response: reqwest::Response) -> Result<reqwest::Response> {
    let status = response.status();
    if !(status.is_client_error() || status.is_server_error()) {
        return Ok(response);
    }
    // A failed body read must not mask the HTTP error itself.
    let raw = response
        .text()
        .await
        .unwrap_or_else(|_| "unknown error".to_string());
    let mut message = sanitize_error_message(&raw);
    if message.trim().is_empty() {
        message = status
            .canonical_reason()
            .unwrap_or("unknown error")
            .to_string();
    }
    Err(OllamaError::Api {
        status: status.as_u16(),
        message,
    })
}
/// Makes an untrusted server error body safe to embed in an error value.
///
/// Bodies longer than `MAX_ERROR_MESSAGE_LEN` bytes are cut at the nearest UTF-8
/// character boundary at or before that limit, and `...(truncated)` is appended.
/// Every control character except `'\n'` is then removed.
fn sanitize_error_message(raw: &str) -> String {
    let bounded = if raw.len() <= MAX_ERROR_MESSAGE_LEN {
        raw.to_string()
    } else {
        // Step back to a char boundary so the slice never splits a multi-byte character.
        // Index 0 is always a boundary, so the search cannot come up empty.
        let cut = (0..=MAX_ERROR_MESSAGE_LEN)
            .rev()
            .find(|&idx| raw.is_char_boundary(idx))
            .unwrap_or(0);
        format!("{}...(truncated)", &raw[..cut])
    };
    bounded
        .chars()
        .filter(|&ch| ch == '\n' || !ch.is_control())
        .collect()
}
/// Parses and validates a user-supplied base URL.
///
/// Trailing `/` characters are stripped before parsing. Only the `http` and
/// `https` schemes are accepted.
///
/// # Errors
/// `OllamaError::InvalidBaseUrl` if the string does not parse as a URL or uses
/// any other scheme.
fn parse_base_url(raw: String) -> Result<Url> {
    let candidate = raw.trim_end_matches('/');
    let parsed = match Url::parse(candidate) {
        Ok(url) => url,
        Err(err) => {
            return Err(OllamaError::InvalidBaseUrl(format!("{candidate}: {err}")));
        }
    };
    if matches!(parsed.scheme(), "http" | "https") {
        return Ok(parsed);
    }
    Err(OllamaError::InvalidBaseUrl(format!(
        "unsupported scheme '{}', expected 'http' or 'https'",
        parsed.scheme()
    )))
}
/// Checks that `digest` has the `<algorithm>:<hex>` shape before it is
/// interpolated into a `/api/blobs/{digest}` URL path.
///
/// The algorithm (the text before the first `:`) must be non-empty ASCII
/// alphanumeric, e.g. `sha256`. The remainder must be non-empty ASCII hex.
/// Because this is a whitelist, it also rejects:
/// - path separators,
/// - query and fragment markers,
/// - whitespace and control characters,
/// - `%` escapes, which a server could decode into `/`.
///
/// # Errors
/// `OllamaError::InvalidDigest` describing the first problem found.
fn validate_digest(digest: &str) -> Result<()> {
    if digest.is_empty() {
        return Err(OllamaError::InvalidDigest("digest is empty".into()));
    }
    let (algorithm, hex) = match digest.split_once(':') {
        Some((algorithm, hex)) if !algorithm.is_empty() && !hex.is_empty() => (algorithm, hex),
        _ => {
            return Err(OllamaError::InvalidDigest(format!(
                "expected '<algorithm>:<hex>' format, got '{digest}'"
            )));
        }
    };
    let algorithm_ok = algorithm.chars().all(|c| c.is_ascii_alphanumeric());
    let hex_ok = hex.chars().all(|c| c.is_ascii_hexdigit());
    if !(algorithm_ok && hex_ok) {
        return Err(OllamaError::InvalidDigest(format!(
            "digest contains invalid characters: '{digest}'"
        )));
    }
    Ok(())
}
/// Configures and builds an [`OllamaClient`]; obtained from `OllamaClient::builder`.
#[derive(Debug)]
pub struct OllamaClientBuilder {
/// Server base URL; validated by `parse_base_url` in `build`.
base_url: String,
/// Total per-request timeout; `Duration::ZERO` means no timeout is configured.
timeout: Duration,
/// TCP connect timeout.
connect_timeout: Duration,
}
impl Default for OllamaClientBuilder {
fn default() -> Self {
Self {
base_url: DEFAULT_BASE_URL.to_string(),
timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
connect_timeout: Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS),
}
}
}
impl OllamaClientBuilder {
    /// Sets the server base URL. It must use `http` or `https`; this is checked in [`Self::build`].
    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Sets the total per-request timeout. `Duration::ZERO` disables it.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets the TCP connect timeout. `Duration::ZERO` disables it, matching
    /// [`Self::timeout`]. A literal zero timeout would make every connection
    /// attempt fail immediately.
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Validates the base URL and builds the client.
    ///
    /// # Errors
    /// `OllamaError::InvalidBaseUrl` for a bad base URL, or `OllamaError::Http`
    /// if the underlying `reqwest::Client` cannot be built.
    pub fn build(self) -> Result<OllamaClient> {
        let base_url = parse_base_url(self.base_url)?;
        let mut builder = reqwest::Client::builder();
        // Zero is treated as "not configured" for both timeouts.
        if !self.timeout.is_zero() {
            builder = builder.timeout(self.timeout);
        }
        if !self.connect_timeout.is_zero() {
            builder = builder.connect_timeout(self.connect_timeout);
        }
        let http = builder.build().map_err(OllamaError::Http)?;
        Ok(OllamaClient { http, base_url })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- parse_base_url -------------------------------------------------

    #[test]
    fn parse_base_url_valid_http() {
        let outcome = parse_base_url(String::from("http://localhost:11434"));
        assert!(outcome.is_ok(), "plain http URL should be accepted");
    }

    #[test]
    fn parse_base_url_valid_https() {
        let outcome = parse_base_url(String::from("https://ollama.example.com"));
        assert!(outcome.is_ok(), "https URL should be accepted");
    }

    #[test]
    fn parse_base_url_rejects_ftp() {
        let outcome = parse_base_url(String::from("ftp://evil.com"));
        assert!(outcome.is_err(), "non-http scheme must be rejected");
    }

    #[test]
    fn parse_base_url_rejects_garbage() {
        let outcome = parse_base_url(String::from("not a url"));
        assert!(outcome.is_err(), "unparseable input must be rejected");
    }

    // --- validate_digest ------------------------------------------------

    #[test]
    fn validate_digest_valid() {
        let digest = "sha256:abcdef0123456789";
        assert!(validate_digest(digest).is_ok());
    }

    #[test]
    fn validate_digest_empty() {
        let digest = "";
        assert!(validate_digest(digest).is_err());
    }

    #[test]
    fn validate_digest_no_colon() {
        let digest = "sha256abcdef";
        assert!(validate_digest(digest).is_err());
    }

    #[test]
    fn validate_digest_with_slash() {
        let digest = "sha256:abc/def";
        assert!(validate_digest(digest).is_err());
    }

    #[test]
    fn validate_digest_with_query() {
        let digest = "sha256:abc?x=1";
        assert!(validate_digest(digest).is_err());
    }

    // --- sanitize_error_message -----------------------------------------

    #[test]
    fn sanitize_error_truncation() {
        let long = "a".repeat(2000);
        let sanitized = sanitize_error_message(&long);
        assert!(sanitized.len() < long.len());
        assert!(sanitized.ends_with("...(truncated)"));
    }

    #[test]
    fn sanitize_error_strips_control_chars() {
        let sanitized = sanitize_error_message("error\x00with\x01control\nchars");
        assert_eq!(sanitized, "errorwithcontrol\nchars");
    }

    // --- OllamaClientBuilder --------------------------------------------

    #[test]
    fn builder_default() {
        let outcome = OllamaClient::builder().build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn builder_custom_timeout() {
        let outcome = OllamaClient::builder()
            .connect_timeout(Duration::from_secs(10))
            .timeout(Duration::from_secs(300))
            .build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn builder_invalid_url() {
        let outcome = OllamaClient::builder().base_url("ftp://bad").build();
        assert!(outcome.is_err());
    }
}