1use std::sync::Arc;
2use std::time::Duration;
3
4use ollama_rs::error::Result as OllamaResult;
5use tokio::sync::Semaphore;
6
7#[cfg(feature = "stream")]
8use std::pin::Pin;
9#[cfg(feature = "stream")]
10use std::task::{Context, Poll};
11
12#[cfg(feature = "stream")]
13use pin_project_lite::pin_project;
14#[cfg(feature = "stream")]
15use tokio::sync::OwnedSemaphorePermit;
16#[cfg(feature = "stream")]
17use tokio_stream::Stream;
18
19use crate::error::{
20 map_ollama_error, ollama_error_is_retryable, runtime_error_is_retryable, Result, RuntimeError,
21};
22
23pub struct ExecutionGuard {
25 semaphore: Arc<Semaphore>,
26 timeout: Duration,
27 max_retries: usize,
28}
29
30impl ExecutionGuard {
31 pub(crate) fn new(max_concurrent: usize, timeout: Duration, max_retries: usize) -> Self {
32 Self {
33 semaphore: Arc::new(Semaphore::new(max_concurrent)),
34 timeout,
35 max_retries,
36 }
37 }
38
39 pub fn max_retries(&self) -> usize {
41 self.max_retries
42 }
43
44 pub async fn run<F, Fut, T>(&self, f: F) -> Result<T>
47 where
48 F: Fn() -> Fut,
49 Fut: std::future::Future<Output = OllamaResult<T>>,
50 {
51 let attempts = self.max_retries.saturating_add(1).max(1);
52
53 for attempt in 0..attempts {
54 let _permit = self
55 .semaphore
56 .acquire()
57 .await
58 .map_err(|_| RuntimeError::Other("execution semaphore closed".into()))?;
59
60 match tokio::time::timeout(self.timeout, f()).await {
61 Ok(Ok(value)) => return Ok(value),
62 Ok(Err(e)) => {
63 if ollama_error_is_retryable(&e) && attempt + 1 < attempts {
64 continue;
65 }
66 return Err(map_ollama_error(e));
67 }
68 Err(_elapsed) => {
69 if runtime_error_is_retryable(&RuntimeError::Timeout) && attempt + 1 < attempts
70 {
71 continue;
72 }
73 return Err(RuntimeError::Timeout);
74 }
75 }
76 }
77
78 Err(RuntimeError::Other("exhausted retries".into()))
79 }
80}
81
#[cfg(feature = "stream")]
pin_project! {
    /// A stream produced under an `ExecutionGuard` that keeps its semaphore
    /// permit for as long as the stream value is alive.
    ///
    /// The permit is released when this value is dropped, not when the inner
    /// stream yields its last item. A finished but still-held stream keeps
    /// occupying a concurrency slot.
    pub struct GuardedStream<S> {
        // Structurally pinned: `project()` yields `Pin<&mut S>` for polling.
        #[pin]
        stream: S,
        // Declared after `stream` on purpose. Fields drop in declaration order,
        // so the permit is returned only after the inner stream has been dropped.
        _permit: OwnedSemaphorePermit,
    }
}
91
#[cfg(feature = "stream")]
/// Forwards polling and size hints to the wrapped stream. The held permit does
/// not affect the items produced.
///
/// No `S: Unpin` bound is needed: the inner stream is reached through pin
/// projection, so it is always polled via `Pin<&mut S>`.
impl<S> Stream for GuardedStream<S>
where
    S: Stream,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().stream.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Without this override the trait default `(0, None)` would hide the
        // inner stream's hint from consumers such as `collect`.
        self.stream.size_hint()
    }
}
103
#[cfg(feature = "stream")]
impl ExecutionGuard {
    /// Opens a stream via `f` under the same concurrency limit, per-attempt
    /// timeout, and retry policy as `run`.
    ///
    /// The timeout covers only `f()` producing the stream, not the items it
    /// later yields. A permit is acquired for every attempt. On success that
    /// permit moves into the returned [`GuardedStream`], so the concurrency slot
    /// stays occupied until the stream is dropped. After a failed attempt, the
    /// permit is released before the next attempt waits for one.
    ///
    /// # Errors
    ///
    /// - `RuntimeError::Other` if the semaphore has been closed.
    /// - `RuntimeError::Timeout` if opening the stream exceeded the timeout and
    ///   either timeouts are not retryable or no retries remain.
    /// - The result of `map_ollama_error` for a non-retryable client error, or
    ///   for a retryable one once the retry budget is spent.
    pub async fn run_stream<F, Fut, S>(&self, f: F) -> Result<GuardedStream<S>>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = OllamaResult<S>>,
        S: Stream + Unpin,
    {
        let mut retries_left = self.max_retries;

        loop {
            let permit = Arc::clone(&self.semaphore)
                .acquire_owned()
                .await
                .map_err(|_| RuntimeError::Other("execution semaphore closed".into()))?;

            match tokio::time::timeout(self.timeout, f()).await {
                Ok(Ok(stream)) => {
                    return Ok(GuardedStream {
                        stream,
                        _permit: permit,
                    });
                }
                Ok(Err(e)) if ollama_error_is_retryable(&e) && retries_left > 0 => {}
                Ok(Err(e)) => return Err(map_ollama_error(e)),
                Err(_elapsed)
                    if runtime_error_is_retryable(&RuntimeError::Timeout) && retries_left > 0 => {}
                Err(_elapsed) => return Err(RuntimeError::Timeout),
            }

            // Give the slot back before queueing for the next attempt's permit.
            drop(permit);
            retries_left -= 1;
        }
    }
}
132
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use ollama_rs::error::OllamaError;

    /// Builds a transport-level client error by connecting to a local port
    /// that is expected to refuse connections.
    async fn connection_refused() -> OllamaError {
        let e = reqwest::get("http://127.0.0.1:1/")
            .await
            .expect_err("expected connection failure");
        OllamaError::from(e)
    }

    #[tokio::test]
    async fn run_retries_transient_reqwest_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 2);
        let n = AtomicUsize::new(0);
        // The result is `()`, so success is asserted by `unwrap` alone.
        // Comparing it with `assert_eq!(out, ())` trips clippy's `unit_cmp`.
        guard
            .run(|| async {
                if n.fetch_add(1, Ordering::SeqCst) == 0 {
                    Err::<(), _>(connection_refused().await)
                } else {
                    Ok(())
                }
            })
            .await
            .unwrap();
        assert_eq!(n.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn run_stops_after_max_retries() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 2);
        let n = AtomicUsize::new(0);
        guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(connection_refused().await)
            })
            .await
            .expect_err("expected failure once retries are exhausted");
        // One initial attempt plus `max_retries` retries.
        assert_eq!(n.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn run_does_not_retry_other_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 3);
        let n = AtomicUsize::new(0);
        let err = guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(OllamaError::Other("client error shape".into()))
            })
            .await
            .expect_err("expected failure");
        assert!(matches!(err, RuntimeError::Other(_)));
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn run_reports_timeout() {
        // No retries, so the first elapsed timeout is final, whatever the
        // retryability of timeouts is.
        let guard = ExecutionGuard::new(1, Duration::from_millis(10), 0);
        let err = guard
            .run(std::future::pending::<OllamaResult<()>>)
            .await
            .expect_err("expected timeout");
        assert!(matches!(err, RuntimeError::Timeout));
    }
}