Skip to main content

ollama_kit/
guard.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use ollama_rs::error::Result as OllamaResult;
5use tokio::sync::Semaphore;
6
7#[cfg(feature = "stream")]
8use std::pin::Pin;
9#[cfg(feature = "stream")]
10use std::task::{Context, Poll};
11
12#[cfg(feature = "stream")]
13use pin_project_lite::pin_project;
14#[cfg(feature = "stream")]
15use tokio::sync::OwnedSemaphorePermit;
16#[cfg(feature = "stream")]
17use tokio_stream::Stream;
18
19use crate::error::{
20    map_ollama_error, ollama_error_is_retryable, runtime_error_is_retryable, Result, RuntimeError,
21};
22
23/// Limits concurrent Ollama calls, applies per-attempt timeouts, and retries transient failures.
24pub struct ExecutionGuard {
25    semaphore: Arc<Semaphore>,
26    timeout: Duration,
27    max_retries: usize,
28}
29
30impl ExecutionGuard {
31    pub(crate) fn new(max_concurrent: usize, timeout: Duration, max_retries: usize) -> Self {
32        Self {
33            semaphore: Arc::new(Semaphore::new(max_concurrent)),
34            timeout,
35            max_retries,
36        }
37    }
38
39    /// Total attempts = `max_retries + 1` (the first try plus `max_retries` retries).
40    pub fn max_retries(&self) -> usize {
41        self.max_retries
42    }
43
44    /// Runs a closure returning an ollama-rs future: acquires a permit, applies timeout, maps
45    /// errors to [`RuntimeError`], and retries only when the failure is classified as transient.
46    pub async fn run<F, Fut, T>(&self, f: F) -> Result<T>
47    where
48        F: Fn() -> Fut,
49        Fut: std::future::Future<Output = OllamaResult<T>>,
50    {
51        let attempts = self.max_retries.saturating_add(1).max(1);
52
53        for attempt in 0..attempts {
54            let _permit = self
55                .semaphore
56                .acquire()
57                .await
58                .map_err(|_| RuntimeError::Other("execution semaphore closed".into()))?;
59
60            match tokio::time::timeout(self.timeout, f()).await {
61                Ok(Ok(value)) => return Ok(value),
62                Ok(Err(e)) => {
63                    if ollama_error_is_retryable(&e) && attempt + 1 < attempts {
64                        continue;
65                    }
66                    return Err(map_ollama_error(e));
67                }
68                Err(_elapsed) => {
69                    if runtime_error_is_retryable(&RuntimeError::Timeout) && attempt + 1 < attempts
70                    {
71                        continue;
72                    }
73                    return Err(RuntimeError::Timeout);
74                }
75            }
76        }
77
78        Err(RuntimeError::Other("exhausted retries".into()))
79    }
80}
81
#[cfg(feature = "stream")]
pin_project! {
    /// Holds a concurrency permit for the lifetime of an ollama-rs streaming response.
    ///
    /// Produced by [`ExecutionGuard::run_stream`]. Items from the wrapped stream are passed
    /// through unbuffered; the permit is returned to the guard's semaphore when this value is
    /// dropped, so an abandoned stream frees its concurrency slot as soon as it goes out of
    /// scope.
    #[must_use = "streams do nothing unless polled"]
    pub struct GuardedStream<S> {
        // Pinned so the wrapper can be polled through a structural pin projection.
        // Declared before `_permit`: fields drop in declaration order, so the underlying
        // response is torn down before the concurrency slot is handed to another caller.
        #[pin]
        stream: S,
        // Never read; held only so its `Drop` releases the permit.
        _permit: OwnedSemaphorePermit,
    }
}
91
#[cfg(feature = "stream")]
impl<S> Stream for GuardedStream<S>
where
    S: Stream,
{
    type Item = S::Item;

    /// Polls the wrapped stream. The permit is untouched here; it is released only when the
    /// `GuardedStream` itself is dropped.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `stream` is a `#[pin]` field, so the projection already yields `Pin<&mut S>`;
        // no `Unpin` bound on `S` is needed.
        self.project().stream.poll_next(cx)
    }

    /// Forwards the inner stream's hint; the wrapper neither adds nor drops items.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}
103
#[cfg(feature = "stream")]
impl ExecutionGuard {
    /// Opens an ollama-rs streaming response under the concurrency limit.
    ///
    /// Each establishment attempt acquires an owned permit and bounds `f()` by the configured
    /// timeout. On success the permit moves into the returned [`GuardedStream`] and is held
    /// until that stream is dropped. Items are passed through without buffering, and the
    /// timeout does not apply to them.
    ///
    /// Establishment failures are retried with the same policy as [`ExecutionGuard::run`]:
    /// only failures classified as transient, and at most `max_retries` times. Retrying is
    /// safe here because nothing has been handed to the caller yet.
    ///
    /// # Errors
    ///
    /// - [`RuntimeError::Timeout`] if the last permitted establishment attempt times out.
    /// - The [`map_ollama_error`] mapping of the last (or first non-retryable) ollama-rs error.
    /// - [`RuntimeError::Other`] if the internal semaphore has been closed.
    pub async fn run_stream<F, Fut, S>(&self, f: F) -> Result<GuardedStream<S>>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = OllamaResult<S>>,
        S: Stream + Unpin,
    {
        // Saturating so `usize::MAX` retries cannot overflow; always at least one attempt.
        let attempts = self.max_retries.saturating_add(1);
        let mut attempt = 0usize;

        loop {
            attempt += 1;
            let retries_left = attempt < attempts;

            let permit = self
                .semaphore
                .clone()
                .acquire_owned()
                .await
                .map_err(|_| RuntimeError::Other("execution semaphore closed".into()))?;

            match tokio::time::timeout(self.timeout, f()).await {
                Ok(Ok(stream)) => {
                    return Ok(GuardedStream {
                        stream,
                        _permit: permit,
                    });
                }
                Ok(Err(e)) => {
                    if !(ollama_error_is_retryable(&e) && retries_left) {
                        return Err(map_ollama_error(e));
                    }
                }
                Err(_elapsed) => {
                    if !(runtime_error_is_retryable(&RuntimeError::Timeout) && retries_left) {
                        return Err(RuntimeError::Timeout);
                    }
                }
            }
            // Falling through drops `permit`, freeing the slot before the next attempt.
        }
    }
}
132
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use ollama_rs::error::OllamaError;

    /// A connection failure is treated as transient: the first attempt fails, the retry
    /// succeeds, and the closure is invoked exactly twice.
    #[tokio::test]
    async fn run_retries_transient_reqwest_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 2);
        let n = AtomicUsize::new(0);
        guard
            .run(|| async {
                let i = n.fetch_add(1, Ordering::SeqCst);
                if i == 0 {
                    // Assumes nothing listens on 127.0.0.1:1, so the request fails to connect.
                    let e = reqwest::get("http://127.0.0.1:1/")
                        .await
                        .expect_err("expected connection failure");
                    Err::<(), _>(OllamaError::from(e))
                } else {
                    Ok(())
                }
            })
            .await
            .expect("retry after transient failure should succeed");
        assert_eq!(n.load(Ordering::SeqCst), 2);
    }

    /// Errors not classified as transient are returned after a single attempt.
    #[tokio::test]
    async fn run_does_not_retry_other_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 3);
        let n = AtomicUsize::new(0);
        let err = guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(OllamaError::Other("client error shape".into()))
            })
            .await
            .expect_err("expected failure");
        assert!(matches!(err, RuntimeError::Other(_)));
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }

    /// With no retries configured, an attempt that outlives the timeout surfaces as
    /// `RuntimeError::Timeout` after exactly one invocation.
    #[tokio::test]
    async fn run_reports_timeout_without_retries() {
        let guard = ExecutionGuard::new(1, Duration::from_millis(10), 0);
        let n = AtomicUsize::new(0);
        let err = guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                // Far longer than the guard's timeout; the future is dropped when it fires.
                tokio::time::sleep(Duration::from_secs(5)).await;
                Ok::<(), OllamaError>(())
            })
            .await
            .expect_err("expected timeout");
        assert!(matches!(err, RuntimeError::Timeout));
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }
}