retry_block/persist/mod.rs
//! Tools for persistent retries, which save the status of each retried operation so that
//! unfinished operations can be resumed after a restart
2//!
3//! # Usage
4//!
//! To use this persistent retry module, create a `RetryHandle` from your implementation of the
//! `RetryInjector` trait and a cloneable iterator of retry delays.
7//!
8//! ```
9//! # use retry_block::persist::{RetryHandle, RetryInjector, Status};
10//! # use retry_block::RetryConfig;
11//! # use async_trait::async_trait;
12//! # use std::collections::HashMap;
13//! # use std::sync::Arc;
14//! # use tokio::sync::Mutex;
15//!
16//! struct Injector {
17//! ops: HashMap<u64, (Status<i64, ()>, i64)>,
18//! }
19//!
20//! #[async_trait]
21//! impl<'a> RetryInjector<'a> for Injector {
22//! type Input = i64;
23//! type Output = i64;
24//! type Error = ();
25//! type Id = u64;
26//! type Res = Result<i64, ()>;
27//! async fn load_pending(&mut self) -> Vec<(u64, i64)> {
28//! self.ops
29//! .iter()
30//! .filter(|(_, (state, _))| matches!(state, Status::Pending))
31//! .map(|(id, (_, val))| (id.clone(), val.clone()))
32//! .collect()
33//! }
34//! async fn save_status(&mut self, id: u64, input: i64, status: Status<i64, ()>) {
35//! self.ops.insert(id, (status, input));
36//! }
37//! }
38//!
39//! #[tokio::main]
40//! async fn main() {
41//! let counter = Arc::new(Mutex::new(0));
42//!
43//! let increment = |input| {
44//! let counter = counter.clone();
45//! async move {
46//! let ref mut counter = *counter.lock().await;
47//! *counter += input;
48//! Ok(*counter)
49//! }
50//! };
51//!
52//! let mut handle = RetryHandle::new(
53//! Injector {
54//! ops: HashMap::from([(0u64, (Status::Pending, 3))]),
55//! },
56//! RetryConfig {
57//! count: 10,
58//! min_backoff: 500,
59//! max_backoff: 1000,
60//! },
61//! );
62//! assert_eq!(*counter.lock().await, 0);
63//!
64//! handle.retry_pending(1, &increment).await;
65//! assert_eq!(*counter.lock().await, 3);
66//!
67//! handle.retry(1u64, 6, &increment).await;
68//! assert_eq!(*counter.lock().await, 9);
69//!
70//! let multiply = |input| {
71//! let counter = counter.clone();
72//! async move {
73//! let ref mut counter = *counter.lock().await;
74//! *counter *= input;
75//! Ok(*counter)
76//! }
77//! };
78//! handle.retry(2u64, 2, &multiply).await;
79//! assert_eq!(*counter.lock().await, 18);
80//! }
81//! ```
82//!
83use crate::OperationResult;
84use async_trait::async_trait;
85use futures_util::{Stream, StreamExt};
86use serde::{Deserialize, Serialize};
87use std::future::Future;
88use std::sync::Arc;
89use tokio::sync::Mutex;
90
91#[cfg(test)]
92mod test;
93
/// Status of a persistent retry
///
/// A [`RetryHandle`] records `Pending` for an input before its first attempt and replaces it
/// with `Success` or `Failure` once the retries are over.
///
/// `Debug` is derived (rather than hand-written) so that alternate formatting (`{:#?}`) is
/// propagated to the wrapped output or error.
#[derive(Debug)]
pub enum Status<O, E> {
    /// The operation was started but has not reached a final status yet
    Pending,
    /// The operation succeeded with the given output
    Success(O),
    /// The operation failed permanently, or ran out of retries, with the given error
    Failure(E),
}
114
115/// A trait to specify how to save and retrieve the status of a retried operation
116#[async_trait]
117pub trait RetryInjector<'a>: Sized {
118 /// The input value of a retry operation
119 ///
120 /// Will be saved to repeat the operation
121 type Input: Serialize + Deserialize<'a> + Clone;
122 /// The positive output value of a retry operation
123 ///
124 /// Will be saved if the operation succeeds
125 type Output;
126 /// The negative output value of a retry operation
127 ///
128 /// Will be saved if the operation fails permanently
129 type Error;
130 /// An identifier for a given input
131 ///
132 /// Will be saved to repeat the operation
133 type Id: Clone;
134 /// A `Result` type for the output of the retry operation
135 ///
136 /// typically either:
137 /// * `OperationResult<Self::Ouput, Self::Error>`
138 /// * `Result<Self::Output, Self::Error>`
139 type Res: Into<OperationResult<Self::Output, Self::Error>>;
140
141 /// Return the stored inputs with a status of `Status::Pending`
142 async fn load_pending(&mut self) -> Vec<(Self::Id, Self::Input)>;
143
144 /// Save the status of a given operation
145 async fn save_status(
146 &mut self,
147 id: Self::Id,
148 input: Self::Input,
149 status: Status<Self::Output, Self::Error>,
150 );
151}
152
153/// Persistent retry handle
154pub struct RetryHandle<Inj, Dur> {
155 injector: Inj,
156 durations: Dur,
157}
158
159impl<'a, Inj, Dur> RetryHandle<Inj, Dur>
160where
161 Inj: RetryInjector<'a>,
162 Dur: IntoIterator<Item = std::time::Duration> + Clone,
163{
164 /// Create a new persistent retry handle from an injector and a cloneable delay iterator
165 pub fn new(injector: Inj, durations: Dur) -> Self {
166 Self {
167 injector,
168 durations,
169 }
170 }
171
172 /// Start concurrent persistent retry of pending input loaded from the injector using the given
173 /// operation and concurrency limit
174 pub async fn retry_pending<F>(
175 &mut self,
176 concurrency_limit: usize,
177 operation: &dyn Fn(Inj::Input) -> F,
178 ) where
179 F: Future<Output = Inj::Res>,
180 {
181 let pending = self.injector.load_pending().await;
182 self.retry_stream(tokio_stream::iter(pending), concurrency_limit, operation)
183 .await;
184 }
185
186 /// Start concurrent persistent retry of input loaded from the given stream using the given
187 /// operation and concurrency limit
188 pub async fn retry_stream<F, S>(
189 &mut self,
190 stream: S,
191 concurrency_limit: usize,
192 operation: &dyn Fn(Inj::Input) -> F,
193 ) where
194 F: Future<Output = Inj::Res>,
195 S: Stream<Item = (Inj::Id, Inj::Input)>,
196 {
197 let handle = Arc::new(Mutex::new(self));
198 stream
199 .for_each_concurrent(concurrency_limit, |(id, input)| async {
200 handle.lock().await.retry(id, input, operation).await;
201 })
202 .await;
203 }
204
205 /// Persistently retry a given input (uniquely identified by the given id) using the given
206 /// operation
207 pub async fn retry<F>(
208 &mut self,
209 id: Inj::Id,
210 input: Inj::Input,
211 operation: &dyn Fn(Inj::Input) -> F,
212 ) where
213 F: Future<Output = Inj::Res>,
214 {
215 self.injector
216 .save_status(id.clone(), input.clone(), Status::Pending)
217 .await;
218 let mut it = self.durations.clone().into_iter();
219 let res = loop {
220 match operation(input.clone()).await.into() {
221 OperationResult::Ok(res) => break Ok(res),
222 OperationResult::Err(e) => break Err(e),
223 OperationResult::Retry(e) => {
224 if let Some(duration) = it.next() {
225 tokio::time::sleep(duration).await;
226 } else {
227 break Err(e);
228 }
229 }
230 }
231 };
232
233 let status = match res {
234 Ok(ok) => Status::Success(ok),
235 Err(err) => Status::Failure(err),
236 };
237 self.injector
238 .save_status(id.clone(), input.clone(), status)
239 .await
240 }
241}