// ollama_kit/guard.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use ollama_rs::error::Result as OllamaResult;
5use tokio::sync::Semaphore;
6
7#[cfg(feature = "stream")]
8use std::pin::Pin;
9#[cfg(feature = "stream")]
10use std::task::{Context, Poll};
11
12#[cfg(feature = "stream")]
13use pin_project_lite::pin_project;
14#[cfg(feature = "stream")]
15use tokio::sync::OwnedSemaphorePermit;
16#[cfg(feature = "stream")]
17use tokio_stream::Stream;
18
19use crate::error::{
20    map_ollama_error, ollama_error_is_retryable, runtime_error_is_retryable, Result, RuntimeError,
21};
22
23/// Per-attempt timeouts, concurrency cap, retries on transient errors.
24pub struct ExecutionGuard {
25    semaphore: Arc<Semaphore>,
26    timeout: Duration,
27    max_retries: usize,
28}
29
30impl ExecutionGuard {
31    pub(crate) fn new(max_concurrent: usize, timeout: Duration, max_retries: usize) -> Self {
32        Self {
33            semaphore: Arc::new(Semaphore::new(max_concurrent)),
34            timeout,
35            max_retries,
36        }
37    }
38
39    /// Uses `max_retries + 1` attempts total.
40    pub fn max_retries(&self) -> usize {
41        self.max_retries
42    }
43
44    /// Acquires semaphore permit and timeout-wraps each try; retries only transient failures.
45    pub async fn run<F, Fut, T>(&self, f: F) -> Result<T>
46    where
47        F: Fn() -> Fut,
48        Fut: std::future::Future<Output = OllamaResult<T>>,
49    {
50        let attempts = self.max_retries.saturating_add(1).max(1);
51
52        for attempt in 0..attempts {
53            let _permit = self
54                .semaphore
55                .acquire()
56                .await
57                .map_err(|_| RuntimeError::Other("execution semaphore closed".into()))?;
58
59            match tokio::time::timeout(self.timeout, f()).await {
60                Ok(Ok(value)) => return Ok(value),
61                Ok(Err(e)) => {
62                    if ollama_error_is_retryable(&e) && attempt + 1 < attempts {
63                        continue;
64                    }
65                    return Err(map_ollama_error(e));
66                }
67                Err(_elapsed) => {
68                    if runtime_error_is_retryable(&RuntimeError::Timeout) && attempt + 1 < attempts
69                    {
70                        continue;
71                    }
72                    return Err(RuntimeError::Timeout);
73                }
74            }
75        }
76
77        Err(RuntimeError::Other("exhausted retries".into()))
78    }
79}
80
#[cfg(feature = "stream")]
pin_project! {
    /// Stream returned by [`ExecutionGuard::run_stream`] that owns one of the guard's
    /// concurrency permits.
    ///
    /// Items are forwarded unchanged from the inner stream. The permit is released only when
    /// this value is dropped, so a stream kept alive but no longer polled still occupies a slot.
    #[must_use = "streams do nothing unless polled; dropping this releases its permit"]
    pub struct GuardedStream<S> {
        // Declared before `_permit`: fields drop in declaration order, so the inner stream is
        // dropped before the permit is returned to the semaphore.
        #[pin]
        stream: S,
        _permit: OwnedSemaphorePermit,
    }
}

#[cfg(feature = "stream")]
/// Forwards polling to the inner stream.
///
/// No `Unpin` bound is needed: `stream` is structurally pinned via `pin_project!`, so
/// `project()` yields a `Pin<&mut S>` for any `S`.
impl<S> Stream for GuardedStream<S>
where
    S: Stream,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().stream.poll_next(cx)
    }

    /// Reports the inner stream's bounds instead of the trait default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

#[cfg(feature = "stream")]
impl ExecutionGuard {
    /// Acquires a concurrency permit, then opens a stream by awaiting `f()` under the guard's
    /// timeout.
    ///
    /// The timeout bounds only stream setup (the future returned by `f`); consuming the
    /// returned [`GuardedStream`] is not time-limited. The permit moves into the
    /// [`GuardedStream`] and is released when that stream is dropped. Setup is attempted exactly
    /// once: unlike [`run`](Self::run), no retries are made.
    ///
    /// # Errors
    ///
    /// - `RuntimeError::Other("execution semaphore closed")` if a permit cannot be acquired.
    /// - `RuntimeError::Timeout` if setup does not finish within the timeout.
    /// - `map_ollama_error(e)` if setup returns an error `e`.
    ///
    /// On any error the permit is dropped immediately.
    pub async fn run_stream<F, Fut, S>(&self, f: F) -> Result<GuardedStream<S>>
    where
        // `f` is invoked once, so `FnOnce` suffices (every `Fn` closure still qualifies).
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = OllamaResult<S>>,
        S: Stream,
    {
        let permit = self
            .semaphore
            .clone()
            .acquire_owned()
            .await
            .map_err(|_| RuntimeError::Other("execution semaphore closed".into()))?;

        let stream = tokio::time::timeout(self.timeout, f())
            .await
            .map_err(|_| RuntimeError::Timeout)?
            .map_err(map_ollama_error)?;

        Ok(GuardedStream {
            stream,
            _permit: permit,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use ollama_rs::error::OllamaError;

    // A connection error on the first attempt should be retried; the second attempt succeeds.
    #[tokio::test]
    async fn run_retries_transient_reqwest_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 2);
        let n = AtomicUsize::new(0);
        guard
            .run(|| async {
                let i = n.fetch_add(1, Ordering::SeqCst);
                if i == 0 {
                    // Assumes nothing listens on 127.0.0.1:1, giving a fast connect failure.
                    let e = reqwest::get("http://127.0.0.1:1/")
                        .await
                        .expect_err("expected connection failure");
                    Err::<(), _>(OllamaError::from(e))
                } else {
                    Ok(())
                }
            })
            .await
            .expect("second attempt should succeed");
        assert_eq!(n.load(Ordering::SeqCst), 2);
    }

    // A non-transient error must surface after a single attempt despite remaining retries.
    #[tokio::test]
    async fn run_does_not_retry_other_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 3);
        let n = AtomicUsize::new(0);
        let err = guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(OllamaError::Other("client error shape".into()))
            })
            .await
            .expect_err("expected failure");
        assert!(matches!(err, RuntimeError::Other(_)));
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }

    // With no retries configured, an attempt exceeding the timeout yields `Timeout` after one call.
    #[tokio::test]
    async fn run_reports_timeout_without_retries() {
        let guard = ExecutionGuard::new(1, Duration::from_millis(10), 0);
        let n = AtomicUsize::new(0);
        let err = guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_secs(5)).await;
                Ok::<(), OllamaError>(())
            })
            .await
            .expect_err("expected timeout");
        assert!(matches!(err, RuntimeError::Timeout));
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }
}