use crate::{Error, Result};
use reqwest::Client;
use std::sync::Arc;
use std::time::Duration;
use super::ClientConfig;
use super::streaming::ChatStreamBlocking;
/// Timeout (300 seconds) applied to the streaming NDJSON requests in this file,
/// used instead of the timeout taken from the client configuration.
const STREAMING_TIMEOUT: Duration = Duration::from_secs(300);
/// Returns the sub-slice of `s` with leading and trailing ASCII whitespace removed.
///
/// Whitespace is defined by `u8::is_ascii_whitespace`. An input that is empty or
/// consists solely of whitespace yields an empty slice.
fn trim_bytes(s: &[u8]) -> &[u8] {
    let is_content = |b: &u8| !b.is_ascii_whitespace();
    // No non-whitespace byte at all: nothing survives the trim.
    let Some(start) = s.iter().position(is_content) else {
        return &s[s.len()..];
    };
    // A first content byte exists, so a last one does too; `end` is exclusive.
    let end = s.iter().rposition(is_content).map_or(start, |last| last + 1);
    &s[start..end]
}
/// HTTP client handle that bundles a [`ClientConfig`] with a shared async
/// `reqwest::Client`.
///
/// Cloning is cheap with respect to the HTTP client: the `Arc` is cloned, not
/// the underlying `reqwest::Client`.
#[derive(Clone, Debug)]
pub struct OllamaClient {
/// Configuration the client was built from; the request helpers read the
/// timeout and the retry limit from it.
pub(super) config: ClientConfig,
/// Async HTTP client shared between clones of this handle.
pub(super) client: Arc<Client>,
}
impl OllamaClient {
/// Builds a client from an existing configuration.
///
/// The async `reqwest::Client` is created with the timeout reported by
/// `config.timeout()`.
///
/// # Errors
///
/// Returns an error if the underlying `reqwest::Client` cannot be built.
pub fn new(config: ClientConfig) -> Result<Self> {
    let builder = Client::builder().timeout(config.timeout());
    let client = Arc::new(builder.build()?);
    Ok(Self { config, client })
}
/// Builds a client whose configuration is created from `base_url` via
/// `ClientConfig::with_base_url`.
///
/// # Errors
///
/// Returns an error if the configuration is rejected or the HTTP client cannot
/// be built.
pub fn with_base_url(base_url: impl Into<String>) -> Result<Self> {
    let base_url: String = base_url.into();
    Self::new(ClientConfig::with_base_url(base_url)?)
}
/// Builds a client whose configuration is created from `base_url` and `timeout`
/// via `ClientConfig::with_base_url_and_timeout`.
///
/// # Errors
///
/// Returns an error if the configuration is rejected or the HTTP client cannot
/// be built.
pub fn with_base_url_and_timeout(
    base_url: impl Into<String>,
    timeout: Duration,
) -> Result<Self> {
    let base_url: String = base_url.into();
    Self::new(ClientConfig::with_base_url_and_timeout(base_url, timeout)?)
}
/// Builds a client from `ClientConfig::default()`.
///
/// # Errors
///
/// Returns an error if the HTTP client cannot be built.
// The lint is silenced because this constructor is fallible (`Result<Self>`),
// so it cannot be expressed as `Default::default`.
#[allow(clippy::should_implement_trait)]
pub fn default() -> Result<Self> {
    let config = ClientConfig::default();
    Self::new(config)
}
pub(super) async fn get_with_retry<T>(&self, url: &str) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
for attempt in 0..=self.config.max_retries() {
match self.client.get(url).send().await {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
continue;
}
let result = response.json::<T>().await?;
return Ok(result);
}
Err(_e) => {
if attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
pub(super) fn get_blocking_with_retry<T>(&self, url: &str) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
let blocking_client = reqwest::blocking::Client::builder()
.timeout(self.config.timeout())
.build()?;
for attempt in 0..=self.config.max_retries() {
match blocking_client.get(url).send() {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
continue;
}
let result = response.json::<T>()?;
return Ok(result);
}
Err(_e) => {
if attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
pub(super) async fn post_with_retry<R, T>(&self, url: &str, body: &R) -> Result<T>
where
R: serde::Serialize,
T: serde::de::DeserializeOwned,
{
for attempt in 0..=self.config.max_retries() {
match self.client.post(url).json(body).send().await {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
continue;
}
if response.status().is_client_error() {
return Err(Error::HttpStatusError(response.status().as_u16()));
}
let result = response.json::<T>().await?;
return Ok(result);
}
Err(_e) => {
if attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
pub(super) fn post_blocking_with_retry<R, T>(&self, url: &str, body: &R) -> Result<T>
where
R: serde::Serialize,
T: serde::de::DeserializeOwned,
{
let blocking_client = reqwest::blocking::Client::builder()
.timeout(self.config.timeout())
.build()?;
for attempt in 0..=self.config.max_retries() {
match blocking_client.post(url).json(body).send() {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
continue;
}
if response.status().is_client_error() {
return Err(Error::HttpStatusError(response.status().as_u16()));
}
let result = response.json::<T>()?;
return Ok(result);
}
Err(_e) => {
if attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
/// Sends an async `POST` to `url` with `body` serialized as JSON, for calls
/// where only the status matters and the response body is ignored.
///
/// Up to `max_retries + 1` attempts are made. A transport error or a 5xx
/// response triggers a retry after a linearly growing delay
/// (100 ms × attempt number) while retries remain.
///
/// # Errors
///
/// * `Error::HttpStatusError` for any non-success status that is not retried.
/// * `Error::MaxRetriesExceededError` if every attempt failed at the transport
///   level.
#[cfg(feature = "model")]
pub(super) async fn post_empty_with_retry<R>(&self, url: &str, body: &R) -> Result<()>
where
    R: serde::Serialize,
{
    let max_retries = self.config.max_retries();
    for attempt in 0..=max_retries {
        let may_retry = attempt < max_retries;
        let status = match self.client.post(url).json(body).send().await {
            Ok(response) => response.status(),
            Err(_) => {
                // Transport failure: back off if another attempt will follow.
                if may_retry {
                    tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
                }
                continue;
            }
        };
        if may_retry && status.is_server_error() {
            tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
            continue;
        }
        return if status.is_success() {
            Ok(())
        } else {
            Err(Error::HttpStatusError(status.as_u16()))
        };
    }
    Err(Error::MaxRetriesExceededError(max_retries))
}
/// Blocking `POST` of `body` (as JSON) to `url`, for calls where only the
/// status matters and the response body is ignored.
///
/// Up to `max_retries + 1` attempts are made. A transport error or a 5xx
/// response triggers a retry after a linearly growing delay
/// (100 ms × attempt number) while retries remain.
///
/// # Errors
///
/// * The converted `reqwest` error if the blocking client cannot be built.
/// * `Error::HttpStatusError` for any non-success status that is not retried.
/// * `Error::MaxRetriesExceededError` if every attempt failed at the transport
///   level.
#[cfg(feature = "model")]
pub(super) fn post_empty_blocking_with_retry<R>(&self, url: &str, body: &R) -> Result<()>
where
    R: serde::Serialize,
{
    let http = reqwest::blocking::Client::builder()
        .timeout(self.config.timeout())
        .build()?;
    let max_retries = self.config.max_retries();
    for attempt in 0..=max_retries {
        let may_retry = attempt < max_retries;
        let status = match http.post(url).json(body).send() {
            Ok(response) => response.status(),
            Err(_) => {
                // Transport failure: back off if another attempt will follow.
                if may_retry {
                    std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
                }
                continue;
            }
        };
        if may_retry && status.is_server_error() {
            std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
            continue;
        }
        return if status.is_success() {
            Ok(())
        } else {
            Err(Error::HttpStatusError(status.as_u16()))
        };
    }
    Err(Error::MaxRetriesExceededError(max_retries))
}
/// Sends an async `DELETE` to `url` with `body` serialized as JSON, for calls
/// where only the status matters and the response body is ignored.
///
/// Up to `max_retries + 1` attempts are made. A transport error or a 5xx
/// response triggers a retry after a linearly growing delay
/// (100 ms × attempt number) while retries remain.
///
/// # Errors
///
/// * `Error::HttpStatusError` for any non-success status that is not retried.
/// * `Error::MaxRetriesExceededError` if every attempt failed at the transport
///   level.
#[cfg(feature = "model")]
pub(super) async fn delete_empty_with_retry<R>(&self, url: &str, body: &R) -> Result<()>
where
    R: serde::Serialize,
{
    let max_retries = self.config.max_retries();
    for attempt in 0..=max_retries {
        let may_retry = attempt < max_retries;
        let status = match self.client.delete(url).json(body).send().await {
            Ok(response) => response.status(),
            Err(_) => {
                // Transport failure: back off if another attempt will follow.
                if may_retry {
                    tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
                }
                continue;
            }
        };
        if may_retry && status.is_server_error() {
            tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
            continue;
        }
        return if status.is_success() {
            Ok(())
        } else {
            Err(Error::HttpStatusError(status.as_u16()))
        };
    }
    Err(Error::MaxRetriesExceededError(max_retries))
}
/// Blocking `DELETE` to `url` with `body` serialized as JSON, for calls where
/// only the status matters and the response body is ignored.
///
/// Up to `max_retries + 1` attempts are made. A transport error or a 5xx
/// response triggers a retry after a linearly growing delay
/// (100 ms × attempt number) while retries remain.
///
/// # Errors
///
/// * The converted `reqwest` error if the blocking client cannot be built.
/// * `Error::HttpStatusError` for any non-success status that is not retried.
/// * `Error::MaxRetriesExceededError` if every attempt failed at the transport
///   level.
#[cfg(feature = "model")]
pub(super) fn delete_empty_blocking_with_retry<R>(&self, url: &str, body: &R) -> Result<()>
where
    R: serde::Serialize,
{
    let http = reqwest::blocking::Client::builder()
        .timeout(self.config.timeout())
        .build()?;
    let max_retries = self.config.max_retries();
    for attempt in 0..=max_retries {
        let may_retry = attempt < max_retries;
        let status = match http.delete(url).json(body).send() {
            Ok(response) => response.status(),
            Err(_) => {
                // Transport failure: back off if another attempt will follow.
                if may_retry {
                    std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
                }
                continue;
            }
        };
        if may_retry && status.is_server_error() {
            std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
            continue;
        }
        return if status.is_success() {
            Ok(())
        } else {
            Err(Error::HttpStatusError(status.as_u16()))
        };
    }
    Err(Error::MaxRetriesExceededError(max_retries))
}
pub(super) async fn post_ndjson_stream<R, T>(
&self,
url: &str,
body: &R,
) -> Result<tokio::sync::mpsc::Receiver<Result<T>>>
where
R: serde::Serialize + ?Sized,
T: serde::de::DeserializeOwned + Send + 'static,
{
let response = self
.client
.post(url)
.json(body)
.timeout(STREAMING_TIMEOUT)
.send()
.await?;
if !response.status().is_success() {
return Err(Error::HttpStatusError(response.status().as_u16()));
}
let (tx, rx) = tokio::sync::mpsc::channel::<Result<T>>(32);
tokio::spawn(async move {
let mut response = response;
let mut buf: Vec<u8> = Vec::new();
loop {
match response.chunk().await {
Ok(Some(chunk)) => {
buf.extend_from_slice(&chunk);
while let Some(idx) = buf.iter().position(|&b| b == b'\n') {
let mut line: Vec<u8> = buf.drain(..=idx).collect();
if line.last() == Some(&b'\n') {
line.pop();
}
if line.last() == Some(&b'\r') {
line.pop();
}
if line.is_empty() {
continue;
}
match serde_json::from_slice::<T>(&line) {
Ok(v) => {
if tx.send(Ok(v)).await.is_err() {
return;
}
}
Err(e) => {
let _ = tx.send(Err(Error::StreamError(e.to_string()))).await;
return;
}
}
}
}
Ok(None) => {
let trimmed = trim_bytes(&buf);
if !trimmed.is_empty() {
match serde_json::from_slice::<T>(trimmed) {
Ok(v) => {
let _ = tx.send(Ok(v)).await;
}
Err(e) => {
let _ = tx.send(Err(Error::StreamError(e.to_string()))).await;
}
}
}
break;
}
Err(e) => {
let _ = tx.send(Err(Error::StreamError(e.to_string()))).await;
break;
}
}
}
});
Ok(rx)
}
pub(super) fn post_ndjson_stream_blocking<R>(
&self,
url: &str,
body: &R,
) -> Result<ChatStreamBlocking>
where
R: serde::Serialize + ?Sized,
{
let blocking_client = reqwest::blocking::Client::builder()
.timeout(STREAMING_TIMEOUT)
.build()?;
let response = blocking_client.post(url).json(body).send()?;
if !response.status().is_success() {
return Err(Error::HttpStatusError(response.status().as_u16()));
}
Ok(ChatStreamBlocking::new(response))
}
}