1use std::sync::Arc;
2use std::time::Duration;
3
4use ollama_rs::error::Result as OllamaResult;
5use tokio::sync::Semaphore;
6
7#[cfg(feature = "stream")]
8use std::pin::Pin;
9#[cfg(feature = "stream")]
10use std::task::{Context, Poll};
11
12#[cfg(feature = "stream")]
13use pin_project_lite::pin_project;
14#[cfg(feature = "stream")]
15use tokio::sync::OwnedSemaphorePermit;
16#[cfg(feature = "stream")]
17use tokio_stream::Stream;
18
19use crate::error::{
20 map_ollama_error, ollama_error_is_retryable, runtime_error_is_retryable, Result, RuntimeError,
21};
22
23pub struct ExecutionGuard {
25 semaphore: Arc<Semaphore>,
26 timeout: Duration,
27 max_retries: usize,
28}
29
30impl ExecutionGuard {
31 pub(crate) fn new(max_concurrent: usize, timeout: Duration, max_retries: usize) -> Self {
32 Self {
33 semaphore: Arc::new(Semaphore::new(max_concurrent)),
34 timeout,
35 max_retries,
36 }
37 }
38
39 pub fn max_retries(&self) -> usize {
41 self.max_retries
42 }
43
44 pub async fn run<F, Fut, T>(&self, f: F) -> Result<T>
46 where
47 F: Fn() -> Fut,
48 Fut: std::future::Future<Output = OllamaResult<T>>,
49 {
50 let attempts = self.max_retries.saturating_add(1).max(1);
51
52 for attempt in 0..attempts {
53 let _permit = self
54 .semaphore
55 .acquire()
56 .await
57 .map_err(|_| RuntimeError::Other("execution semaphore closed".into()))?;
58
59 match tokio::time::timeout(self.timeout, f()).await {
60 Ok(Ok(value)) => return Ok(value),
61 Ok(Err(e)) => {
62 if ollama_error_is_retryable(&e) && attempt + 1 < attempts {
63 continue;
64 }
65 return Err(map_ollama_error(e));
66 }
67 Err(_elapsed) => {
68 if runtime_error_is_retryable(&RuntimeError::Timeout) && attempt + 1 < attempts
69 {
70 continue;
71 }
72 return Err(RuntimeError::Timeout);
73 }
74 }
75 }
76
77 Err(RuntimeError::Other("exhausted retries".into()))
78 }
79}
80
#[cfg(feature = "stream")]
pin_project! {
    /// A stream that holds an execution permit for as long as it is alive.
    ///
    /// Items are forwarded unchanged from the inner stream. The permit is
    /// released when this value is dropped, so a dropped stream frees its slot
    /// immediately.
    #[must_use = "streams do nothing unless polled; dropping this releases its execution slot"]
    pub struct GuardedStream<S> {
        // Inner stream, polled in place through pin projection.
        #[pin]
        stream: S,
        // Never read; kept only so the semaphore slot stays occupied until drop.
        _permit: OwnedSemaphorePermit,
    }
}
89
#[cfg(feature = "stream")]
impl<S> Stream for GuardedStream<S>
where
    S: Stream,
{
    type Item = S::Item;

    /// Polls the inner stream through its pin projection.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Projection yields `Pin<&mut S>`, so `S` does not need to be `Unpin`.
        self.project().stream.poll_next(cx)
    }

    /// Forwards the inner stream's size hint instead of the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}
101
#[cfg(feature = "stream")]
impl ExecutionGuard {
    /// Opens a stream while holding an execution permit for the stream's whole lifetime.
    ///
    /// The permit is acquired before `f` is called and moves into the returned
    /// [`GuardedStream`], so the slot is freed only when that stream is dropped.
    /// The configured timeout bounds only the future returned by `f` (opening
    /// the stream), not the consumption of its items. `f` is called exactly
    /// once; this method performs no retries.
    ///
    /// # Errors
    ///
    /// - `RuntimeError::Other` if the semaphore has been closed.
    /// - `RuntimeError::Timeout` if `f()` does not resolve within the timeout.
    /// - The error produced by `map_ollama_error` if `f()` resolves to an error.
    pub async fn run_stream<F, Fut, S>(&self, f: F) -> Result<GuardedStream<S>>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = OllamaResult<S>>,
        S: Stream + Unpin,
    {
        let shared = Arc::clone(&self.semaphore);
        let slot = match shared.acquire_owned().await {
            Ok(slot) => slot,
            Err(_) => return Err(RuntimeError::Other("execution semaphore closed".into())),
        };

        let opened = match tokio::time::timeout(self.timeout, f()).await {
            Ok(outcome) => outcome,
            Err(_) => return Err(RuntimeError::Timeout),
        };

        let inner = opened.map_err(map_ollama_error)?;
        Ok(GuardedStream {
            stream: inner,
            _permit: slot,
        })
    }
}
129
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use ollama_rs::error::OllamaError;

    /// Produces a real reqwest connection error.
    /// Assumes nothing listens on loopback port 1.
    async fn connection_error() -> OllamaError {
        let e = reqwest::get("http://127.0.0.1:1/")
            .await
            .expect_err("expected connection failure");
        OllamaError::from(e)
    }

    /// A transient connection failure is retried and the next attempt succeeds.
    #[tokio::test]
    async fn run_retries_transient_reqwest_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 2);
        let n = AtomicUsize::new(0);
        // `run` yields `()` here. Comparing unit values (clippy `unit_cmp`) asserts
        // nothing, so success is checked via `expect` plus the call count.
        guard
            .run(|| async {
                let i = n.fetch_add(1, Ordering::SeqCst);
                if i == 0 {
                    Err::<(), _>(connection_error().await)
                } else {
                    Ok(())
                }
            })
            .await
            .expect("second attempt should succeed");
        assert_eq!(n.load(Ordering::SeqCst), 2);
    }

    /// With `max_retries == 0`, even a retryable error is attempted only once.
    #[tokio::test]
    async fn run_does_not_retry_when_max_retries_is_zero() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 0);
        let n = AtomicUsize::new(0);
        let result = guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(connection_error().await)
            })
            .await;
        assert!(result.is_err());
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }

    /// A non-retryable error is surfaced after a single attempt.
    #[tokio::test]
    async fn run_does_not_retry_other_error() {
        let guard = ExecutionGuard::new(1, Duration::from_secs(5), 3);
        let n = AtomicUsize::new(0);
        let err = guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(OllamaError::Other("client error shape".into()))
            })
            .await
            .expect_err("expected failure");
        assert!(matches!(err, RuntimeError::Other(_)));
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }

    /// An attempt exceeding the timeout yields `RuntimeError::Timeout`.
    /// No retry happens because the retry budget is zero.
    #[tokio::test]
    async fn run_reports_timeout_when_no_retries_remain() {
        let guard = ExecutionGuard::new(1, Duration::from_millis(10), 0);
        let n = AtomicUsize::new(0);
        let err = guard
            .run(|| async {
                n.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_secs(5)).await;
                Ok::<(), OllamaError>(())
            })
            .await
            .expect_err("expected timeout");
        assert!(matches!(err, RuntimeError::Timeout));
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }
}