use crate::{Error, Result};
use reqwest::Client;
use std::sync::Arc;
use std::time::Duration;
use super::ClientConfig;
/// Client that issues HTTP requests using a shared async `reqwest` client and
/// a `ClientConfig` (whose `timeout()` and `max_retries()` drive every request).
///
/// The async client is held in an `Arc`, so cloning an `OllamaClient` shares
/// the same underlying `reqwest::Client` instance instead of building a new one.
#[derive(Debug, Clone)]
pub struct OllamaClient {
    /// Client settings; request timeout and retry budget are read from here.
    pub(super) config: ClientConfig,
    /// Shared async HTTP client used by the non-blocking request helpers.
    pub(super) client: Arc<Client>,
}
impl OllamaClient {
pub fn new(config: ClientConfig) -> Result<Self> {
let client = Client::builder().timeout(config.timeout()).build()?;
Ok(Self {
config,
client: Arc::new(client),
})
}
pub fn with_base_url(base_url: impl Into<String>) -> Result<Self> {
let config = ClientConfig::with_base_url(base_url.into())?;
Self::new(config)
}
pub fn with_base_url_and_timeout(
base_url: impl Into<String>,
timeout: Duration,
) -> Result<Self> {
let config = ClientConfig::with_base_url_and_timeout(base_url.into(), timeout)?;
Self::new(config)
}
#[allow(clippy::should_implement_trait)]
pub fn default() -> Result<Self> {
Self::new(ClientConfig::default())
}
pub(super) async fn get_with_retry<T>(&self, url: &str) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
for attempt in 0..=self.config.max_retries() {
match self.client.get(url).send().await {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
continue;
}
let result = response.json::<T>().await?;
return Ok(result);
}
Err(_e) => {
if attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
pub(super) fn get_blocking_with_retry<T>(&self, url: &str) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
let blocking_client = reqwest::blocking::Client::builder()
.timeout(self.config.timeout())
.build()?;
for attempt in 0..=self.config.max_retries() {
match blocking_client.get(url).send() {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
continue;
}
let result = response.json::<T>()?;
return Ok(result);
}
Err(_e) => {
if attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
pub(super) async fn post_with_retry<R, T>(&self, url: &str, body: &R) -> Result<T>
where
R: serde::Serialize,
T: serde::de::DeserializeOwned,
{
for attempt in 0..=self.config.max_retries() {
match self.client.post(url).json(body).send().await {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
continue;
}
if response.status().is_client_error() {
return Err(Error::HttpStatusError(response.status().as_u16()));
}
let result = response.json::<T>().await?;
return Ok(result);
}
Err(_e) => {
if attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
pub(super) fn post_blocking_with_retry<R, T>(&self, url: &str, body: &R) -> Result<T>
where
R: serde::Serialize,
T: serde::de::DeserializeOwned,
{
let blocking_client = reqwest::blocking::Client::builder()
.timeout(self.config.timeout())
.build()?;
for attempt in 0..=self.config.max_retries() {
match blocking_client.post(url).json(body).send() {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
continue;
}
if response.status().is_client_error() {
return Err(Error::HttpStatusError(response.status().as_u16()));
}
let result = response.json::<T>()?;
return Ok(result);
}
Err(_e) => {
if attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
#[cfg(feature = "model")]
pub(super) async fn post_empty_with_retry<R>(&self, url: &str, body: &R) -> Result<()>
where
R: serde::Serialize,
{
for attempt in 0..=self.config.max_retries() {
match self.client.post(url).json(body).send().await {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
continue;
}
if response.status().is_success() {
return Ok(());
}
return Err(Error::HttpStatusError(response.status().as_u16()));
}
Err(_e) => {
if attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
#[cfg(feature = "model")]
pub(super) fn post_empty_blocking_with_retry<R>(&self, url: &str, body: &R) -> Result<()>
where
R: serde::Serialize,
{
let blocking_client = reqwest::blocking::Client::builder()
.timeout(self.config.timeout())
.build()?;
for attempt in 0..=self.config.max_retries() {
match blocking_client.post(url).json(body).send() {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
continue;
}
if response.status().is_success() {
return Ok(());
}
return Err(Error::HttpStatusError(response.status().as_u16()));
}
Err(_e) => {
if attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
#[cfg(feature = "model")]
pub(super) async fn delete_empty_with_retry<R>(&self, url: &str, body: &R) -> Result<()>
where
R: serde::Serialize,
{
for attempt in 0..=self.config.max_retries() {
match self.client.delete(url).json(body).send().await {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
continue;
}
if response.status().is_success() {
return Ok(());
}
return Err(Error::HttpStatusError(response.status().as_u16()));
}
Err(_e) => {
if attempt < self.config.max_retries() {
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
#[cfg(feature = "model")]
pub(super) fn delete_empty_blocking_with_retry<R>(&self, url: &str, body: &R) -> Result<()>
where
R: serde::Serialize,
{
let blocking_client = reqwest::blocking::Client::builder()
.timeout(self.config.timeout())
.build()?;
for attempt in 0..=self.config.max_retries() {
match blocking_client.delete(url).json(body).send() {
Ok(response) => {
if response.status().is_server_error() && attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
continue;
}
if response.status().is_success() {
return Ok(());
}
return Err(Error::HttpStatusError(response.status().as_u16()));
}
Err(_e) => {
if attempt < self.config.max_retries() {
std::thread::sleep(Duration::from_millis(100 * (attempt as u64 + 1)));
}
}
}
}
Err(Error::MaxRetriesExceededError(self.config.max_retries()))
}
}