Skip to main content

retry_block/persist/
mod.rs

1//! Tools for persistent retries that save the retry status to be continued on a restart
2//!
3//! # Usage
4//!
5//! To use this persistent retry module, you need to create a `RetryHandle` associated with your
6//! implementation of the `RetryInjector` trait.
7//!
8//! ```
9//! # use retry_block::persist::{RetryHandle, RetryInjector, Status};
10//! # use retry_block::RetryConfig;
11//! # use async_trait::async_trait;
12//! # use std::collections::HashMap;
13//! # use std::sync::Arc;
14//! # use tokio::sync::Mutex;
15//!
16//! struct Injector {
17//!     ops: HashMap<u64, (Status<i64, ()>, i64)>,
18//! }
19//!
20//! #[async_trait]
21//! impl<'a> RetryInjector<'a> for Injector {
22//!     type Input = i64;
23//!     type Output = i64;
24//!     type Error = ();
25//!     type Id = u64;
26//!     type Res = Result<i64, ()>;
27//!     async fn load_pending(&mut self) -> Vec<(u64, i64)> {
28//!         self.ops
29//!             .iter()
30//!             .filter(|(_, (state, _))| matches!(state, Status::Pending))
31//!             .map(|(id, (_, val))| (id.clone(), val.clone()))
32//!             .collect()
33//!     }
34//!     async fn save_status(&mut self, id: u64, input: i64, status: Status<i64, ()>) {
35//!         self.ops.insert(id, (status, input));
36//!     }
37//! }
38//!
39//! #[tokio::main]
40//! async fn main() {
41//!     let counter = Arc::new(Mutex::new(0));
42//!
43//!     let increment = |input| {
44//!         let counter = counter.clone();
45//!         async move {
46//!             let ref mut counter = *counter.lock().await;
47//!             *counter += input;
48//!             Ok(*counter)
49//!         }
50//!     };
51//!
52//!     let mut handle = RetryHandle::new(
53//!         Injector {
54//!             ops: HashMap::from([(0u64, (Status::Pending, 3))]),
55//!         },
56//!         RetryConfig {
57//!             count: 10,
58//!             min_backoff: 500,
59//!             max_backoff: 1000,
60//!         },
61//!     );
62//!     assert_eq!(*counter.lock().await, 0);
63//!
64//!     handle.retry_pending(1, &increment).await;
65//!     assert_eq!(*counter.lock().await, 3);
66//!
67//!     handle.retry(1u64, 6, &increment).await;
68//!     assert_eq!(*counter.lock().await, 9);
69//!
70//!     let multiply = |input| {
71//!         let counter = counter.clone();
72//!         async move {
73//!             let ref mut counter = *counter.lock().await;
74//!             *counter *= input;
75//!             Ok(*counter)
76//!         }
77//!     };
78//!     handle.retry(2u64, 2, &multiply).await;
79//!     assert_eq!(*counter.lock().await, 18);
80//! }
81//! ```
82//!
83use crate::OperationResult;
84use async_trait::async_trait;
85use futures_util::{Stream, StreamExt};
86use serde::{Deserialize, Serialize};
87use std::future::Future;
88use std::sync::Arc;
89use tokio::sync::Mutex;
90
91#[cfg(test)]
92mod test;
93
/// Status of a persistent retry
///
/// `O` is the value kept when the operation succeeds and `E` the error kept when it fails.
///
/// `Debug` is derived: it is available whenever `O: Debug` and `E: Debug`, prints
/// `Pending`, `Success(..)` or `Failure(..)`, and forwards formatter flags (for example the
/// alternate `{:#?}` pretty-printing mode) to the wrapped value.
#[derive(Debug)]
pub enum Status<O, E> {
    /// No final result has been recorded for the operation yet
    Pending,
    /// The operation finished and produced the wrapped output
    Success(O),
    /// The operation finished with the wrapped error
    Failure(E),
}
114
115/// A trait to specify how to save and retrieve the status of a retried operation
116#[async_trait]
117pub trait RetryInjector<'a>: Sized {
118    /// The input value of a retry operation
119    ///
120    /// Will be saved to repeat the operation
121    type Input: Serialize + Deserialize<'a> + Clone;
122    /// The positive output value of a retry operation
123    ///
124    /// Will be saved if the operation succeeds
125    type Output;
126    /// The negative output value of a retry operation
127    ///
128    /// Will be saved if the operation fails permanently
129    type Error;
130    /// An identifier for a given input
131    ///
132    /// Will be saved to repeat the operation
133    type Id: Clone;
134    /// A `Result` type for the output of the retry operation
135    ///
136    /// typically either:
137    /// * `OperationResult<Self::Ouput, Self::Error>`
138    /// * `Result<Self::Output, Self::Error>`
139    type Res: Into<OperationResult<Self::Output, Self::Error>>;
140
141    /// Return the stored inputs with a status of `Status::Pending`
142    async fn load_pending(&mut self) -> Vec<(Self::Id, Self::Input)>;
143
144    /// Save the status of a given operation
145    async fn save_status(
146        &mut self,
147        id: Self::Id,
148        input: Self::Input,
149        status: Status<Self::Output, Self::Error>,
150    );
151}
152
/// Handle driving persistent retries
///
/// Pairs the injector that stores operation statuses with the delays to wait between attempts.
pub struct RetryHandle<Inj, Dur> {
    /// Backend used to load pending inputs and to save operation statuses
    injector: Inj,
    /// Delays between attempts; cloned and iterated afresh for every retried input
    durations: Dur,
}
158
159impl<'a, Inj, Dur> RetryHandle<Inj, Dur>
160where
161    Inj: RetryInjector<'a>,
162    Dur: IntoIterator<Item = std::time::Duration> + Clone,
163{
164    /// Create a new persistent retry handle from an injector and a cloneable delay iterator
165    pub fn new(injector: Inj, durations: Dur) -> Self {
166        Self {
167            injector,
168            durations,
169        }
170    }
171
172    /// Start concurrent persistent retry of pending input loaded from the injector using the given
173    /// operation and concurrency limit
174    pub async fn retry_pending<F>(
175        &mut self,
176        concurrency_limit: usize,
177        operation: &dyn Fn(Inj::Input) -> F,
178    ) where
179        F: Future<Output = Inj::Res>,
180    {
181        let pending = self.injector.load_pending().await;
182        self.retry_stream(tokio_stream::iter(pending), concurrency_limit, operation)
183            .await;
184    }
185
186    /// Start concurrent persistent retry of input loaded from the given stream using the given
187    /// operation and concurrency limit
188    pub async fn retry_stream<F, S>(
189        &mut self,
190        stream: S,
191        concurrency_limit: usize,
192        operation: &dyn Fn(Inj::Input) -> F,
193    ) where
194        F: Future<Output = Inj::Res>,
195        S: Stream<Item = (Inj::Id, Inj::Input)>,
196    {
197        let handle = Arc::new(Mutex::new(self));
198        stream
199            .for_each_concurrent(concurrency_limit, |(id, input)| async {
200                handle.lock().await.retry(id, input, operation).await;
201            })
202            .await;
203    }
204
205    /// Persistently retry a given input (uniquely identified by the given id) using the given
206    /// operation
207    pub async fn retry<F>(
208        &mut self,
209        id: Inj::Id,
210        input: Inj::Input,
211        operation: &dyn Fn(Inj::Input) -> F,
212    ) where
213        F: Future<Output = Inj::Res>,
214    {
215        self.injector
216            .save_status(id.clone(), input.clone(), Status::Pending)
217            .await;
218        let mut it = self.durations.clone().into_iter();
219        let res = loop {
220            match operation(input.clone()).await.into() {
221                OperationResult::Ok(res) => break Ok(res),
222                OperationResult::Err(e) => break Err(e),
223                OperationResult::Retry(e) => {
224                    if let Some(duration) = it.next() {
225                        tokio::time::sleep(duration).await;
226                    } else {
227                        break Err(e);
228                    }
229                }
230            }
231        };
232
233        let status = match res {
234            Ok(ok) => Status::Success(ok),
235            Err(err) => Status::Failure(err),
236        };
237        self.injector
238            .save_status(id.clone(), input.clone(), status)
239            .await
240    }
241}