1use std::error::Error as StdError;
45use std::fmt;
46use std::time::Instant;
47
/// Type-erased error used throughout the fallback chains.
///
/// Boxed so providers with unrelated error types can share one signature, and
/// `Send + Sync` so errors can cross thread boundaries.
pub type DynError = Box<dyn StdError + Sync + Send>;
51
52#[derive(Debug)]
54pub struct Attempt {
55 pub name: String,
57 pub error: Option<DynError>,
59 pub duration_ms: f64,
61}
62
63#[derive(Debug)]
65pub struct ChainResult<O> {
66 pub value: O,
68 pub provider: String,
70 pub attempts: Vec<Attempt>,
73}
74
75#[derive(Debug)]
77pub struct AllProvidersFailed {
78 pub attempts: Vec<Attempt>,
80}
81
82impl fmt::Display for AllProvidersFailed {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 let names: Vec<&str> = self.attempts.iter().map(|a| a.name.as_str()).collect();
85 write!(f, "all providers failed: {}", names.join(", "))
86 }
87}
88
89impl StdError for AllProvidersFailed {}
90
91pub type SyncProvider<I, O> = Box<dyn Fn(&I) -> Result<O, DynError> + Send + Sync>;
93
94pub type ShouldFallBack = Box<dyn Fn(&DynError) -> bool + Send + Sync>;
97
98pub type OnFallback = Box<dyn Fn(&str, &DynError, &str) + Send + Sync>;
101
102fn default_should_fall_back(_err: &DynError) -> bool {
103 true
104}
105
106pub struct FallbackChain<I, O> {
113 providers: Vec<(String, SyncProvider<I, O>)>,
114 should_fall_back: ShouldFallBack,
115 on_fallback: Option<OnFallback>,
116}
117
118impl<I, O> FallbackChain<I, O> {
119 pub fn new<S: Into<String>>(
123 providers: Vec<(S, SyncProvider<I, O>)>,
124 ) -> Result<Self, &'static str> {
125 if providers.is_empty() {
126 return Err("providers must be a non-empty list");
127 }
128 let providers = providers
129 .into_iter()
130 .map(|(name, fn_)| (name.into(), fn_))
131 .collect();
132 Ok(Self {
133 providers,
134 should_fall_back: Box::new(default_should_fall_back),
135 on_fallback: None,
136 })
137 }
138
139 pub fn with_should_fall_back<F>(mut self, f: F) -> Self
142 where
143 F: Fn(&DynError) -> bool + Send + Sync + 'static,
144 {
145 self.should_fall_back = Box::new(f);
146 self
147 }
148
149 pub fn with_on_fallback<F>(mut self, f: F) -> Self
152 where
153 F: Fn(&str, &DynError, &str) + Send + Sync + 'static,
154 {
155 self.on_fallback = Some(Box::new(f));
156 self
157 }
158
159 pub fn with_skip<P>(mut self, predicate: P) -> Result<Self, &'static str>
163 where
164 P: Fn(&str) -> bool,
165 {
166 self.providers.retain(|(name, _)| !predicate(name));
167 if self.providers.is_empty() {
168 return Err("with_skip removed all providers");
169 }
170 Ok(self)
171 }
172
173 pub fn names(&self) -> Vec<&str> {
175 self.providers.iter().map(|(n, _)| n.as_str()).collect()
176 }
177
178 pub fn call(&self, input: &I) -> Result<ChainResult<O>, DynError> {
180 let mut failures: Vec<Attempt> = Vec::new();
181 let last = self.providers.len() - 1;
182 for (i, (name, fn_)) in self.providers.iter().enumerate() {
183 let start = Instant::now();
184 match fn_(input) {
185 Ok(value) => {
186 return Ok(ChainResult {
187 value,
188 provider: name.clone(),
189 attempts: failures,
190 });
191 }
192 Err(err) => {
193 let elapsed = start.elapsed().as_secs_f64() * 1000.0;
194 if !(self.should_fall_back)(&err) {
195 return Err(err);
196 }
197 if i < last {
198 if let Some(cb) = &self.on_fallback {
199 let next_name = &self.providers[i + 1].0;
200 cb(name, &err, next_name);
201 }
202 }
203 failures.push(Attempt {
204 name: name.clone(),
205 error: Some(err),
206 duration_ms: elapsed,
207 });
208 }
209 }
210 }
211 Err(Box::new(AllProvidersFailed { attempts: failures }))
212 }
213}
214
#[cfg(feature = "tokio")]
mod async_chain {
    use super::{
        default_should_fall_back, AllProvidersFailed, Attempt, ChainResult, DynError, OnFallback,
        ShouldFallBack,
    };
    use futures::future::BoxFuture;
    use std::time::Instant;

    /// Boxed asynchronous provider: borrows the input and returns a boxed
    /// future that resolves to the provider's output or a boxed error.
    pub type AsyncProvider<I, O> =
        Box<dyn for<'a> Fn(&'a I) -> BoxFuture<'a, Result<O, DynError>> + Send + Sync>;

    /// Wraps a function that returns a future into an [`AsyncProvider`].
    ///
    /// `Fut` must be `'static`, so the future cannot keep the `&I` borrow
    /// alive. Copy or clone whatever it needs from the input before the
    /// future is built.
    pub fn async_provider<I, O, F, Fut>(f: F) -> AsyncProvider<I, O>
    where
        F: for<'a> Fn(&'a I) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<O, DynError>> + Send + 'static,
        I: 'static,
    {
        Box::new(move |i: &I| {
            let fut = f(i);
            Box::pin(fut) as BoxFuture<'_, _>
        })
    }

    /// Async counterpart of `FallbackChain`: awaits each provider in order
    /// until one succeeds.
    pub struct AsyncFallbackChain<I, O> {
        /// `(name, provider)` pairs in call order. The constructors and
        /// `with_skip` keep this non-empty.
        providers: Vec<(String, AsyncProvider<I, O>)>,
        /// Policy deciding whether an error moves on to the next provider.
        should_fall_back: ShouldFallBack,
        /// Optional hook notified each time the chain moves on.
        on_fallback: Option<OnFallback>,
    }

    impl<I: Send + Sync, O: Send> AsyncFallbackChain<I, O> {
        /// Builds a chain that tries `providers` in the given order.
        ///
        /// The chain starts with a policy that falls back on every error and
        /// has no `on_fallback` callback.
        ///
        /// # Errors
        ///
        /// Returns `Err("providers must be a non-empty list")` when
        /// `providers` is empty.
        pub fn new<S: Into<String>>(
            providers: Vec<(S, AsyncProvider<I, O>)>,
        ) -> Result<Self, &'static str> {
            if providers.is_empty() {
                return Err("providers must be a non-empty list");
            }
            let providers = providers
                .into_iter()
                .map(|(name, fn_)| (name.into(), fn_))
                .collect();
            Ok(Self {
                providers,
                should_fall_back: Box::new(default_should_fall_back),
                on_fallback: None,
            })
        }

        /// Replaces the fallback policy.
        ///
        /// `f` is called with each provider error. Returning `true` moves on
        /// to the next provider; returning `false` makes
        /// [`call`](Self::call) return that error immediately.
        #[must_use]
        pub fn with_should_fall_back<F>(mut self, f: F) -> Self
        where
            F: Fn(&DynError) -> bool + Send + Sync + 'static,
        {
            self.should_fall_back = Box::new(f);
            self
        }

        /// Installs a callback invoked each time the chain moves from a
        /// failed provider to the next one.
        ///
        /// Arguments are `(failed_provider_name, error, next_provider_name)`.
        /// It is not invoked when the last provider fails, nor when the
        /// fallback policy declines to fall back.
        #[must_use]
        pub fn with_on_fallback<F>(mut self, f: F) -> Self
        where
            F: Fn(&str, &DynError, &str) + Send + Sync + 'static,
        {
            self.on_fallback = Some(Box::new(f));
            self
        }

        /// Removes every provider whose name satisfies `predicate`.
        ///
        /// # Errors
        ///
        /// Returns `Err("with_skip removed all providers")` when no provider
        /// is left. The chain is consumed in that case.
        pub fn with_skip<P>(mut self, predicate: P) -> Result<Self, &'static str>
        where
            P: Fn(&str) -> bool,
        {
            self.providers.retain(|(name, _)| !predicate(name));
            if self.providers.is_empty() {
                return Err("with_skip removed all providers");
            }
            Ok(self)
        }

        /// Returns the provider names in call order.
        pub fn names(&self) -> Vec<&str> {
            self.providers.iter().map(|(n, _)| n.as_str()).collect()
        }

        /// Awaits each provider in order with `input` until one succeeds.
        ///
        /// On success the [`ChainResult`] holds the value, the winning
        /// provider's name and one [`Attempt`] per provider that failed
        /// before it. The successful call is not recorded.
        ///
        /// # Errors
        ///
        /// * If the fallback policy returns `false` for an error, that error
        ///   is returned unchanged and earlier attempts are discarded.
        /// * If every provider fails, a boxed [`AllProvidersFailed`] holding
        ///   one attempt per provider is returned.
        pub async fn call(&self, input: &I) -> Result<ChainResult<O>, DynError> {
            let mut failures: Vec<Attempt> = Vec::with_capacity(self.providers.len());
            for (i, (name, provider_fn)) in self.providers.iter().enumerate() {
                let start = Instant::now();
                let err = match provider_fn(input).await {
                    Ok(value) => {
                        return Ok(ChainResult {
                            value,
                            provider: name.clone(),
                            attempts: failures,
                        });
                    }
                    Err(err) => err,
                };
                // Measured before the policy and callback run, so the
                // duration covers only the awaited provider future.
                let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
                if !(self.should_fall_back)(&err) {
                    return Err(err);
                }
                // `get` yields `None` for the last provider: there is nothing
                // to fall back to, so the callback is skipped.
                if let (Some(cb), Some((next_name, _))) =
                    (&self.on_fallback, self.providers.get(i + 1))
                {
                    cb(name, &err, next_name);
                }
                failures.push(Attempt {
                    name: name.clone(),
                    error: Some(err),
                    duration_ms,
                });
            }
            Err(Box::new(AllProvidersFailed { attempts: failures }))
        }
    }
}
338
339#[cfg(feature = "tokio")]
340pub use async_chain::{async_provider, AsyncFallbackChain, AsyncProvider};
341
#[cfg(feature = "serde")]
mod serde_impls {
    use super::Attempt;
    use serde::Serialize;

    /// Serializable snapshot of an [`Attempt`], with the error rendered as
    /// its `Display` text.
    #[derive(Debug, Serialize)]
    pub struct AttemptView {
        /// Name of the provider that was called.
        pub name: String,
        /// `Display` text of the error, or `None` if the attempt has no error.
        pub error: Option<String>,
        /// Elapsed time of the provider call, in milliseconds.
        pub duration_ms: f64,
    }

    impl From<&Attempt> for AttemptView {
        fn from(attempt: &Attempt) -> Self {
            let Attempt {
                name,
                error,
                duration_ms,
            } = attempt;
            let rendered = error.as_ref().map(ToString::to_string);
            AttemptView {
                name: name.to_owned(),
                error: rendered,
                duration_ms: *duration_ms,
            }
        }
    }
}
366
367#[cfg(feature = "serde")]
368pub use serde_impls::AttemptView;