Skip to main content

alembic_engine/
apply_retry.rs

1use crate::journal::Journal;
2use crate::{AppliedOp, Op};
3use anyhow::{anyhow, Result};
4use async_trait::async_trait;
5
6#[derive(Debug)]
7pub struct RetryApplyResult {
8    pub applied: Vec<AppliedOp>,
9    pub pending: Vec<Op>,
10}
11
12#[async_trait]
13pub trait RetryApplyDriver {
14    async fn apply_non_delete(&mut self, op: &Op) -> Result<AppliedOp>;
15    fn is_retryable(&self, err: &anyhow::Error) -> bool;
16}
17
18pub async fn apply_non_delete_with_retries(
19    ops: &[Op],
20    mut journal: Option<&mut Journal>,
21    driver: &mut impl RetryApplyDriver,
22) -> Result<RetryApplyResult> {
23    let mut applied = Vec::new();
24    let mut pending: Vec<Op> = ops
25        .iter()
26        .filter(|op| !matches!(op, Op::Delete { .. }))
27        .cloned()
28        .collect();
29
30    if let Some(journal) = journal.as_mut() {
31        let done_ops = journal.done_ops();
32        let done_ops_len = done_ops.len();
33
34        let mut done = done_ops
35            .into_iter()
36            .collect::<std::collections::HashSet<_>>();
37
38        if done.len() != done_ops_len {
39            // the use of a hash set here is an optimization, but it rules out ops with
40            // exactly the same uid, typename and hash.
41            // if there's a need to support such a thing in the future, it can be done by
42            // switching the container for `done` into a type that supports duplicates.
43            return Err(anyhow!("journal contained duplicated ops (same uid, typename and hash) which is not supported"));
44        }
45
46        pending.retain(|op| !done.remove(&(op.uid(), op.type_name().clone(), op.hashed())));
47
48        if !done.is_empty() {
49            return Err(anyhow!(
50                "journal contains done ops that are not present in the provided ops"
51            ));
52        }
53    }
54
55    while !pending.is_empty() {
56        let current = std::mem::take(&mut pending);
57        let applied_before = applied.len();
58
59        for op in current {
60            match driver.apply_non_delete(&op).await {
61                Ok(applied_op) => {
62                    // marked in memory only; the journal is flushed to disk at the exit
63                    // points below, not once per op (per-op saving was a ~100x regression).
64                    if let Some(journal) = journal.as_mut() {
65                        journal.mark_op_as_done(&op)?;
66                    }
67                    applied.push(applied_op);
68                }
69                Err(err) if driver.is_retryable(&err) => pending.push(op),
70                Err(err) => {
71                    // a fatal error is a clean unwind: persist progress before surfacing it
72                    // so the next run can resume from here. don't mask the original error if
73                    // the save itself fails.
74                    if let Some(journal) = journal.as_mut() {
75                        if let Err(save_err) = journal.save() {
76                            tracing::warn!(
77                                error = %save_err,
78                                "failed to persist journal after apply error"
79                            );
80                        }
81                    }
82                    return Err(err);
83                }
84            }
85        }
86
87        // only break if no progress was made (no items applied in this iteration)
88        if applied.len() == applied_before {
89            break;
90        }
91    }
92
93    if let Some(journal) = journal.as_mut() {
94        if journal.is_completed() {
95            journal.delete_backing_file()?;
96        } else {
97            // ops remain pending (stuck with no progress): persist so a re-run can resume
98            journal.save()?;
99        }
100    }
101
102    Ok(RetryApplyResult { applied, pending })
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BackendId;
    use alembic_core::{JsonMap, Key, Object, TypeName, Uid};
    use anyhow::anyhow;
    use rand::rng;
    use rand::seq::SliceRandom;
    use tempfile::tempdir;

    /// Builds a minimal `Op::Create` for a `test.item` object identified by `uid`.
    fn create_op(uid: Uid) -> Op {
        let desired = Object {
            uid,
            type_name: TypeName::new("test.item"),
            key: Key::default(),
            attrs: JsonMap::default(),
            source: None,
        };

        Op::Create {
            uid,
            type_name: TypeName::new("test.item"),
            desired,
        }
    }

    /// Scripted behaviour for [`TestDriver`].
    #[derive(Clone, Copy)]
    enum Mode {
        /// The very first call fails with a retryable error, every later call succeeds.
        RetryThenOk,
        /// Every call fails with a retryable error.
        AlwaysRetry,
        /// Every call fails with a non-retryable error.
        Fatal,
    }

    /// Driver whose outcome depends only on its mode and on how many calls it has seen.
    struct TestDriver {
        attempts: usize,
        mode: Mode,
    }

    impl TestDriver {
        /// Fresh driver with no recorded attempts.
        fn with_mode(mode: Mode) -> Self {
            Self { attempts: 0, mode }
        }
    }

    #[async_trait]
    impl RetryApplyDriver for TestDriver {
        async fn apply_non_delete(&mut self, op: &Op) -> Result<AppliedOp> {
            self.attempts += 1;
            let first_attempt = self.attempts == 1;

            match (self.mode, first_attempt) {
                (Mode::Fatal, _) => Err(anyhow!("boom")),
                (Mode::AlwaysRetry, _) | (Mode::RetryThenOk, true) => {
                    Err(anyhow!("missing referenced uid {}", op.uid()))
                }
                (Mode::RetryThenOk, false) => Ok(AppliedOp {
                    uid: op.uid(),
                    type_name: op.type_name().clone(),
                    backend_id: Some(BackendId::Int(1)),
                }),
            }
        }

        fn is_retryable(&self, err: &anyhow::Error) -> bool {
            let message = err.to_string();
            message.contains("missing referenced uid")
        }
    }

    #[tokio::test]
    async fn retries_then_applies() {
        let ops = vec![
            create_op(Uid::from_u128(1)),
            create_op(Uid::from_u128(2)),
        ];
        let mut driver = TestDriver::with_mode(Mode::RetryThenOk);

        let outcome = apply_non_delete_with_retries(&ops, None, &mut driver)
            .await
            .unwrap();

        // first op fails once, second op succeeds, first op succeeds on the second pass.
        assert_eq!(driver.attempts, 3);
        assert_eq!(outcome.applied.len(), 2);
        assert!(outcome.pending.is_empty());
    }

    #[tokio::test]
    async fn returns_pending_when_stuck() {
        let ops = vec![create_op(Uid::from_u128(1))];
        let mut driver = TestDriver::with_mode(Mode::AlwaysRetry);

        let outcome = apply_non_delete_with_retries(&ops, None, &mut driver)
            .await
            .unwrap();

        assert!(outcome.applied.is_empty());
        assert_eq!(outcome.pending.len(), 1);
    }

    #[tokio::test]
    async fn returns_non_retryable_error() {
        let ops = vec![create_op(Uid::from_u128(1))];
        let mut driver = TestDriver::with_mode(Mode::Fatal);

        let failure = apply_non_delete_with_retries(&ops, None, &mut driver)
            .await
            .unwrap_err();

        assert!(failure.to_string().contains("boom"));
    }

    #[tokio::test]
    async fn ignores_delete_ops() {
        let delete = Op::Delete {
            uid: Uid::from_u128(1),
            type_name: TypeName::new("test.item"),
            key: Key::default(),
            backend_id: None,
        };
        let ops = vec![delete];
        // a fatal driver proves the delete never reaches it: any call would abort the run.
        let mut driver = TestDriver::with_mode(Mode::Fatal);

        let outcome = apply_non_delete_with_retries(&ops, None, &mut driver)
            .await
            .unwrap();

        assert_eq!(driver.attempts, 0);
        assert!(outcome.pending.is_empty());
        assert!(outcome.applied.is_empty());
    }

    /// Driver that succeeds until its countdown reaches zero, then fails with a
    /// non-retryable error; it also remembers everything it applied.
    struct ErraticDriver {
        countdown_to_crash: u32,
        applied_ops: Vec<AppliedOp>,
    }

    impl ErraticDriver {
        /// Driver that fails on call number `countdown_to_crash` and has applied nothing yet.
        fn crashing_on_call(countdown_to_crash: u32) -> Self {
            Self {
                countdown_to_crash,
                applied_ops: Vec::new(),
            }
        }

        /// Uids of the ops applied so far, in application order.
        fn applied_uids(&self) -> Vec<Uid> {
            self.applied_ops.iter().map(|applied| applied.uid).collect()
        }
    }

    #[async_trait]
    impl RetryApplyDriver for ErraticDriver {
        async fn apply_non_delete(&mut self, op: &Op) -> Result<AppliedOp> {
            self.countdown_to_crash -= 1;

            match self.countdown_to_crash {
                0 => Err(anyhow!("planned error")),
                _ => {
                    let applied_op = AppliedOp {
                        uid: op.uid(),
                        type_name: op.type_name().clone(),
                        backend_id: None,
                    };
                    self.applied_ops.push(applied_op.clone());
                    Ok(applied_op)
                }
            }
        }

        fn is_retryable(&self, _err: &anyhow::Error) -> bool {
            false
        }
    }

    #[tokio::test]
    async fn erratic_driver_first_fails_then_succeeds() {
        let uid1 = Uid::from_u128(1);
        let uid2 = Uid::from_u128(2);
        let ops = vec![create_op(uid1), create_op(uid2)];
        let mut driver = ErraticDriver::crashing_on_call(2);
        let dir = tempdir().unwrap();
        let mut journal = Journal::load_or_create(dir.path(), "erratic_driver", &ops).unwrap();

        apply_non_delete_with_retries(&ops, Some(&mut journal), &mut driver)
            .await
            .expect_err("should fail (on second op applied this run)");
        assert_eq!(driver.applied_ops.len(), 1);
        assert!(!journal.is_completed());

        // turn off crashing and run again with the same in-memory journal
        driver.countdown_to_crash = 99999;
        apply_non_delete_with_retries(&ops, Some(&mut journal), &mut driver)
            .await
            .unwrap();

        assert_eq!(driver.applied_uids(), vec![uid1, uid2]);
        assert!(journal.is_completed());
    }

    #[tokio::test]
    async fn resumes_from_disk_after_error() {
        let uid1 = Uid::from_u128(1);
        let uid2 = Uid::from_u128(2);
        let uid3 = Uid::from_u128(3);
        let ops = vec![create_op(uid1), create_op(uid2), create_op(uid3)];
        let dir = tempdir().unwrap();

        // run one fails on its second op. the journal goes out of scope afterwards, standing
        // in for the process exiting, so the resume below can only see what reached the disk.
        {
            let mut driver = ErraticDriver::crashing_on_call(2);
            let mut journal = Journal::load_or_create(dir.path(), "resume_test", &ops).unwrap();

            apply_non_delete_with_retries(&ops, Some(&mut journal), &mut driver)
                .await
                .expect_err("should fail on the second op");

            assert_eq!(driver.applied_ops.len(), 1);
        }

        // run two loads the journal again from the same directory and must only touch the
        // ops that run one did not get through.
        {
            let mut driver = ErraticDriver::crashing_on_call(99999);
            let mut journal = Journal::load_or_create(dir.path(), "resume_test", &ops).unwrap();

            let outcome = apply_non_delete_with_retries(&ops, Some(&mut journal), &mut driver)
                .await
                .unwrap();

            assert_eq!(driver.applied_uids(), vec![uid2, uid3]);
            assert_eq!(outcome.applied.len(), 2);
            assert!(outcome.pending.is_empty());
            assert!(journal.is_completed());
        }
    }

    #[tokio::test]
    async fn erratic_driver_with_shuffled_ops() {
        let mut ops: Vec<Op> = (1..10).map(|i| create_op(Uid::from_u128(i))).collect();

        let mut shuffler = rng();
        ops.shuffle(&mut shuffler);

        let mut driver = ErraticDriver::crashing_on_call(5);
        let dir = tempdir().unwrap();
        let mut journal = Journal::load_or_create(dir.path(), "erratic_driver", &ops).unwrap();

        apply_non_delete_with_retries(&ops, Some(&mut journal), &mut driver)
            .await
            .expect_err("should fail (on fifth op applied this run)");
        assert_eq!(driver.applied_ops.len(), 4);
        assert!(!journal.is_completed());

        // the resumed run sees the same ops in a different order
        ops.shuffle(&mut shuffler);

        // turn off crashing
        driver.countdown_to_crash = 99999;
        apply_non_delete_with_retries(&ops, Some(&mut journal), &mut driver)
            .await
            .unwrap();

        // across both runs every op must have been applied exactly once, in any order.
        let mut applied_uids = driver.applied_uids();
        applied_uids.sort();
        let mut op_uids: Vec<_> = ops.iter().map(|op| op.uid()).collect();
        op_uids.sort();

        assert_eq!(applied_uids, op_uids);
        assert!(journal.is_completed());
    }
}