use std::sync::Arc;
use std::time::Duration;
use ollama_rs::error::Result as OllamaResult;
use tokio::sync::Semaphore;
#[cfg(feature = "stream")]
use std::pin::Pin;
#[cfg(feature = "stream")]
use std::task::{Context, Poll};
#[cfg(feature = "stream")]
use pin_project_lite::pin_project;
#[cfg(feature = "stream")]
use tokio::sync::OwnedSemaphorePermit;
#[cfg(feature = "stream")]
use tokio_stream::Stream;
use crate::error::{
map_ollama_error, ollama_error_is_retryable, runtime_error_is_retryable, Result, RuntimeError,
};
/// Wraps calls with a concurrency limit, a per-attempt timeout, and a retry budget.
pub struct ExecutionGuard {
/// Semaphore whose permits bound how many guarded calls may run at the same time.
semaphore: Arc<Semaphore>,
/// Deadline applied to each invocation of the guarded future.
timeout: Duration,
/// Number of additional attempts `run` may make after the first attempt fails.
max_retries: usize,
}
impl ExecutionGuard {
    /// Builds a guard that admits at most `max_concurrent` simultaneous calls, bounds every
    /// attempt by `timeout`, and allows up to `max_retries` extra attempts in [`Self::run`].
    pub(crate) fn new(max_concurrent: usize, timeout: Duration, max_retries: usize) -> Self {
        let semaphore = Arc::new(Semaphore::new(max_concurrent));
        Self {
            semaphore,
            timeout,
            max_retries,
        }
    }

    /// Returns the configured retry count (attempts allowed beyond the first one).
    pub fn max_retries(&self) -> usize {
        self.max_retries
    }

    /// Runs the future produced by `f` under the concurrency limit and timeout, retrying
    /// failures that the crate's retry predicates classify as retryable.
    ///
    /// `f` is invoked once per attempt, for at most `max_retries + 1` attempts. A semaphore
    /// permit is acquired before each attempt and released when that attempt finishes.
    ///
    /// # Errors
    ///
    /// * `RuntimeError::Other("execution semaphore closed")` if a permit cannot be acquired.
    /// * `RuntimeError::Timeout` if an attempt exceeds the timeout and no retry follows.
    /// * The result of `map_ollama_error` if `f` fails and no retry follows.
    pub async fn run<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = OllamaResult<T>>,
    {
        let budget = self.max_retries.saturating_add(1).max(1);
        let mut used = 0usize;

        while used < budget {
            used += 1;
            // True while at least one more attempt remains after the current one.
            let may_retry = used < budget;

            // Held for the duration of this attempt only; dropped on `continue` or return.
            let _slot = match self.semaphore.acquire().await {
                Ok(slot) => slot,
                Err(_) => {
                    return Err(RuntimeError::Other("execution semaphore closed".into()));
                }
            };

            let failure = match tokio::time::timeout(self.timeout, f()).await {
                Ok(Ok(value)) => return Ok(value),
                Ok(Err(source)) => {
                    if ollama_error_is_retryable(&source) && may_retry {
                        continue;
                    }
                    map_ollama_error(source)
                }
                Err(_elapsed) => {
                    if runtime_error_is_retryable(&RuntimeError::Timeout) && may_retry {
                        continue;
                    }
                    RuntimeError::Timeout
                }
            };
            return Err(failure);
        }

        // Every loop iteration either returns or continues with attempts left, so this is
        // only a fallback required by the type checker.
        Err(RuntimeError::Other("exhausted retries".into()))
    }
}
#[cfg(feature = "stream")]
pin_project! {
/// Stream wrapper that carries a semaphore permit alongside the wrapped stream, so the
/// permit stays held for as long as the wrapper is alive.
pub struct GuardedStream<S> {
// The wrapped stream; marked `#[pin]` so it is polled through a pinned projection.
#[pin]
stream: S,
// Never read; stored only so the permit is dropped together with this wrapper.
_permit: OwnedSemaphorePermit,
}
}
#[cfg(feature = "stream")]
impl<S> Stream for GuardedStream<S>
where
    // No `Unpin` bound is needed: the inner stream is reached through the pinned projection.
    S: Stream,
{
    type Item = S::Item;

    /// Polls the wrapped stream and yields its items unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().stream.poll_next(cx)
    }

    /// Forwards the wrapped stream's size hint instead of the trait's default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}
#[cfg(feature = "stream")]
impl ExecutionGuard {
    /// Obtains a stream from `f` under the concurrency limit and returns it wrapped in a
    /// [`GuardedStream`] that keeps the acquired permit.
    ///
    /// The timeout covers only the future returned by `f` (i.e. obtaining the stream), not
    /// the items produced afterwards. `f` is invoked exactly once; `max_retries` is not
    /// consulted here.
    ///
    /// # Errors
    ///
    /// * `RuntimeError::Other("execution semaphore closed")` if a permit cannot be acquired.
    /// * `RuntimeError::Timeout` if `f`'s future does not complete within the timeout.
    /// * The result of `map_ollama_error` if `f`'s future resolves to an error.
    pub async fn run_stream<F, Fut, S>(&self, f: F) -> Result<GuardedStream<S>>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = OllamaResult<S>>,
        S: Stream + Unpin,
    {
        // An owned permit is required because it is moved into the returned wrapper.
        let pool = Arc::clone(&self.semaphore);
        let _permit = match pool.acquire_owned().await {
            Ok(permit) => permit,
            Err(_) => {
                return Err(RuntimeError::Other("execution semaphore closed".into()));
            }
        };

        let stream = match tokio::time::timeout(self.timeout, f()).await {
            Ok(Ok(stream)) => stream,
            Ok(Err(source)) => return Err(map_ollama_error(source)),
            Err(_elapsed) => return Err(RuntimeError::Timeout),
        };

        Ok(GuardedStream { stream, _permit })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use ollama_rs::error::OllamaError;

    /// The first attempt fails with a real `reqwest` error and the second returns `Ok`;
    /// `run` is expected to invoke the closure exactly twice and succeed.
    #[tokio::test]
    async fn run_retries_transient_reqwest_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 2);
        let n = AtomicUsize::new(0);
        // The success value is `()`, so there is nothing to compare; unwrapping is the check.
        guard
            .run(|| async {
                let i = n.fetch_add(1, Ordering::SeqCst);
                if i == 0 {
                    // A request that is expected to fail to connect is used to obtain a
                    // genuine `reqwest::Error` to convert into an `OllamaError`.
                    let e = reqwest::get("http://127.0.0.1:1/")
                        .await
                        .expect_err("expected connection failure");
                    Err::<(), _>(OllamaError::from(e))
                } else {
                    Ok(())
                }
            })
            .await
            .unwrap();
        assert_eq!(n.load(Ordering::SeqCst), 2);
    }

    /// An `OllamaError::Other` failure must surface after a single attempt even though
    /// retries are configured.
    #[tokio::test]
    async fn run_does_not_retry_other_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 3);
        let n = AtomicUsize::new(0);
        let err = guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(OllamaError::Other("client error shape".into()))
            })
            .await
            .expect_err("expected failure");
        assert!(matches!(err, RuntimeError::Other(_)));
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }
}