// ollama-kit 0.2.1
//
// Runtime control (lifecycle + execution guards) for ollama-rs without wrapping its API.
use std::sync::Arc;
use std::time::Duration;

use ollama_rs::error::Result as OllamaResult;
use tokio::sync::Semaphore;

#[cfg(feature = "stream")]
use std::pin::Pin;
#[cfg(feature = "stream")]
use std::task::{Context, Poll};

#[cfg(feature = "stream")]
use pin_project_lite::pin_project;
#[cfg(feature = "stream")]
use tokio::sync::OwnedSemaphorePermit;
#[cfg(feature = "stream")]
use tokio_stream::Stream;

use crate::error::{
    map_ollama_error, ollama_error_is_retryable, runtime_error_is_retryable, Result, RuntimeError,
};

/// Guards calls into ollama-rs with a per-attempt timeout, a concurrency cap, and retries on
/// transient errors.
///
/// Instances are built inside the crate via `ExecutionGuard::new`.
#[derive(Debug)]
pub struct ExecutionGuard {
    /// Shared permit pool limiting how many guarded calls may be in flight at once.
    semaphore: Arc<Semaphore>,
    /// Timeout applied to each individual attempt, not to the whole retry sequence.
    timeout: Duration,
    /// Additional attempts allowed after the first one fails transiently.
    max_retries: usize,
}

impl ExecutionGuard {
    /// Creates a guard that allows at most `max_concurrent` calls in flight. Each attempt is
    /// bounded by `timeout`, and up to `max_retries` retries follow a transient failure.
    ///
    /// `max_concurrent` is clamped to at least 1. A semaphore with zero permits would make
    /// every guarded call wait forever.
    pub(crate) fn new(max_concurrent: usize, timeout: Duration, max_retries: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent.max(1))),
            timeout,
            max_retries,
        }
    }

    /// Number of retries allowed after the first attempt.
    ///
    /// Guarded calls make up to `max_retries + 1` attempts in total (saturating at `usize::MAX`).
    pub fn max_retries(&self) -> usize {
        self.max_retries
    }

    /// Runs `f` under the guard, retrying transient failures.
    ///
    /// For each attempt:
    /// - A semaphore permit is acquired.
    /// - `f()` is called, and the resulting future is bounded by the configured timeout.
    /// - The permit is released before the next attempt, so other callers waiting on the
    ///   semaphore may run in between.
    ///
    /// A failed attempt is retried only while attempts remain and the failure is classified as
    /// retryable. Ollama errors use `ollama_error_is_retryable`; timeouts use
    /// `runtime_error_is_retryable`.
    ///
    /// # Errors
    ///
    /// - `RuntimeError::Other` if the semaphore has been closed.
    /// - `RuntimeError::Timeout` if the last attempt times out, or a timeout is not retryable.
    /// - The ollama error converted by `map_ollama_error` if the last attempt fails, or the
    ///   failure is not retryable.
    pub async fn run<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = OllamaResult<T>>,
    {
        // `saturating_add` already guarantees at least one attempt, even for `usize::MAX`.
        let attempts = self.max_retries.saturating_add(1);
        let mut attempt = 1;

        loop {
            let is_last = attempt >= attempts;

            // Scoped to this attempt: dropped at the end of the loop body.
            let _permit = self
                .semaphore
                .acquire()
                .await
                .map_err(|_| RuntimeError::Other("execution semaphore closed".into()))?;

            match tokio::time::timeout(self.timeout, f()).await {
                Ok(Ok(value)) => return Ok(value),
                Ok(Err(e)) if is_last || !ollama_error_is_retryable(&e) => {
                    return Err(map_ollama_error(e));
                }
                Err(_elapsed)
                    if is_last || !runtime_error_is_retryable(&RuntimeError::Timeout) =>
                {
                    return Err(RuntimeError::Timeout);
                }
                // Transient failure with attempts remaining: go around again.
                Ok(Err(_)) | Err(_) => {}
            }

            attempt += 1;
        }
    }
}

#[cfg(feature = "stream")]
pin_project! {
    /// Stream returned by `ExecutionGuard::run_stream`. It yields the inner stream's items
    /// while holding one of the guard's concurrency permits.
    ///
    /// The permit is released only when this value is dropped. Reaching the end of the stream
    /// does not release it, so a live stream counts against the concurrency cap until dropped.
    pub struct GuardedStream<S> {
        // The wrapped stream; structurally pinned so it is polled through `Pin<&mut S>`.
        #[pin]
        stream: S,
        // Fields drop in declaration order, so the inner stream is dropped before this permit
        // is returned to the semaphore. Keep this field last.
        _permit: OwnedSemaphorePermit,
    }
}

/// Forwards everything to the inner stream; the held permit never affects polling.
///
/// `S` does not need to be `Unpin`: the `#[pin]` projection hands out `Pin<&mut S>` safely.
/// `GuardedStream<S>` is itself `Unpin` whenever `S` is.
#[cfg(feature = "stream")]
impl<S> Stream for GuardedStream<S>
where
    S: Stream,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().stream.poll_next(cx)
    }

    /// Reports the inner stream's bounds unchanged; the wrapper neither adds nor drops items.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

#[cfg(feature = "stream")]
impl ExecutionGuard {
    /// Sets up a stream under the guard and returns it wrapped in a [`GuardedStream`].
    ///
    /// Each setup attempt works the same way as [`ExecutionGuard::run`]:
    /// - A permit is acquired.
    /// - `f()` is bounded by the configured timeout.
    /// - A transient setup failure is retried up to `max_retries` times.
    ///
    /// The timeout covers stream *setup* only. Once `f()` has resolved, consuming the stream is
    /// not time-limited.
    ///
    /// Permit handling:
    /// - The permit from the successful attempt moves into the returned stream and is held
    ///   until that stream is dropped.
    /// - Permits from failed attempts are released before the next attempt.
    ///
    /// # Errors
    ///
    /// - `RuntimeError::Other` if the semaphore has been closed.
    /// - `RuntimeError::Timeout` if the last setup attempt times out, or a timeout is not
    ///   retryable.
    /// - The ollama error converted by `map_ollama_error` if the last setup attempt fails, or
    ///   the failure is not retryable.
    pub async fn run_stream<F, Fut, S>(&self, f: F) -> Result<GuardedStream<S>>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = OllamaResult<S>>,
        S: Stream + Unpin,
    {
        // Same attempt budget as `run`: always at least one attempt.
        let attempts = self.max_retries.saturating_add(1);
        let mut attempt = 1;

        loop {
            let is_last = attempt >= attempts;

            // An owned permit is needed so it can outlive this call inside `GuardedStream`.
            let permit = Arc::clone(&self.semaphore)
                .acquire_owned()
                .await
                .map_err(|_| RuntimeError::Other("execution semaphore closed".into()))?;

            match tokio::time::timeout(self.timeout, f()).await {
                Ok(Ok(stream)) => {
                    return Ok(GuardedStream {
                        stream,
                        _permit: permit,
                    });
                }
                Ok(Err(e)) if is_last || !ollama_error_is_retryable(&e) => {
                    return Err(map_ollama_error(e));
                }
                Err(_elapsed)
                    if is_last || !runtime_error_is_retryable(&RuntimeError::Timeout) =>
                {
                    return Err(RuntimeError::Timeout);
                }
                // Transient setup failure with attempts remaining. `permit` is dropped at the
                // end of this iteration, before the next acquire.
                Ok(Err(_)) | Err(_) => {}
            }

            attempt += 1;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use ollama_rs::error::OllamaError;

    /// Produces a real transport-level error by requesting a local port expected to be closed.
    async fn connection_refused() -> OllamaError {
        let e = reqwest::get("http://127.0.0.1:1/")
            .await
            .expect_err("expected connection failure");
        OllamaError::from(e)
    }

    #[test]
    fn max_retries_reports_configured_value() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 4);
        assert_eq!(guard.max_retries(), 4);
    }

    #[tokio::test]
    async fn run_retries_transient_reqwest_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 2);
        let n = AtomicUsize::new(0);
        // The success value is `()`, so only success itself is checked: comparing unit values
        // trips clippy's `unit_cmp` lint.
        guard
            .run(|| async {
                if n.fetch_add(1, Ordering::SeqCst) == 0 {
                    Err(connection_refused().await)
                } else {
                    Ok(())
                }
            })
            .await
            .expect("second attempt should succeed");
        assert_eq!(n.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn run_with_zero_retries_makes_single_attempt() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 0);
        let n = AtomicUsize::new(0);
        let result = guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(connection_refused().await)
            })
            .await;
        assert!(result.is_err());
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn run_does_not_retry_other_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 3);
        let n = AtomicUsize::new(0);
        let err = guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(OllamaError::Other("client error shape".into()))
            })
            .await
            .expect_err("expected failure");
        assert!(matches!(err, RuntimeError::Other(_)));
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }
}