Skip to main content

durable_lambda_core/
backend.rs

1//! DurableBackend trait and RealBackend implementation.
2//!
3//! The [`DurableBackend`] trait is the I/O boundary between the replay engine
4//! and AWS. It covers the 2 AWS durable execution API operations used internally
5//! by the SDK: `checkpoint_durable_execution` and `get_durable_execution_state`.
6//!
7//! [`RealBackend`] calls AWS via `aws-sdk-lambda`; `MockBackend` (in
8//! `durable-lambda-testing`) returns pre-loaded data for credential-free testing.
9
10use std::time::Duration;
11
12use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
13use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
14use aws_sdk_lambda::types::OperationUpdate;
15
16use crate::error::DurableError;
17
18/// Define the I/O boundary between the replay engine and the durable execution backend.
19///
20/// This trait abstracts the 2 AWS Lambda durable execution API operations that
21/// the SDK uses internally. Implement this trait for real AWS calls
22/// ([`RealBackend`]) or for testing ([`MockBackend`] in `durable-lambda-testing`).
23///
24/// # Object Safety
25///
26/// This trait is object-safe and designed to be used as
27/// `Arc<dyn DurableBackend + Send + Sync>`.
28///
29/// # Examples
30///
31/// ```
32/// use durable_lambda_core::backend::{DurableBackend, RealBackend};
33///
34/// // RealBackend implements DurableBackend.
35/// fn accepts_backend(_b: &dyn DurableBackend) {}
36/// ```
37#[async_trait::async_trait]
38pub trait DurableBackend: Send + Sync {
39    /// Persist checkpoint updates for a durable execution.
40    ///
41    /// Wraps the `checkpoint_durable_execution` AWS API. Sends a batch of
42    /// [`OperationUpdate`] items and receives a new checkpoint token plus
43    /// any updated execution state.
44    ///
45    /// # Errors
46    ///
47    /// Returns [`DurableError`] if the AWS API call fails after retries.
48    async fn checkpoint(
49        &self,
50        arn: &str,
51        checkpoint_token: &str,
52        updates: Vec<OperationUpdate>,
53        client_token: Option<&str>,
54    ) -> Result<CheckpointDurableExecutionOutput, DurableError>;
55
56    /// Get the current operation state of a durable execution (paginated).
57    ///
58    /// Wraps the `get_durable_execution_state` AWS API. Returns a page of
59    /// [`Operation`](aws_sdk_lambda::types::Operation) items and an optional
60    /// `next_marker` for pagination.
61    ///
62    /// # Errors
63    ///
64    /// Returns [`DurableError`] if the AWS API call fails after retries.
65    async fn get_execution_state(
66        &self,
67        arn: &str,
68        checkpoint_token: &str,
69        next_marker: &str,
70        max_items: i32,
71    ) -> Result<GetDurableExecutionStateOutput, DurableError>;
72
73    /// Persist multiple checkpoint updates in a single AWS API call.
74    ///
75    /// Default implementation delegates to [`checkpoint`](Self::checkpoint),
76    /// making it backward-compatible for existing implementors. Override
77    /// in test mocks to track batch-specific call counts.
78    ///
79    /// # Errors
80    ///
81    /// Returns [`DurableError`] if the underlying AWS API call fails.
82    async fn batch_checkpoint(
83        &self,
84        arn: &str,
85        checkpoint_token: &str,
86        updates: Vec<OperationUpdate>,
87        client_token: Option<&str>,
88    ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
89        self.checkpoint(arn, checkpoint_token, updates, client_token)
90            .await
91    }
92}
93
94/// Real AWS backend that calls `aws-sdk-lambda` durable execution APIs.
95///
96/// Implements [`DurableBackend`] with exponential backoff retry for transient
97/// AWS failures (throttling, server errors, timeouts).
98///
99/// # Examples
100///
101/// ```no_run
102/// use aws_sdk_lambda::Client;
103/// use durable_lambda_core::backend::RealBackend;
104///
105/// # async fn example() {
106/// let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
107/// let client = Client::new(&config);
108/// let backend = RealBackend::new(client);
109/// # }
110/// ```
111pub struct RealBackend {
112    client: aws_sdk_lambda::Client,
113}
114
115impl RealBackend {
116    /// Create a new `RealBackend` wrapping an `aws-sdk-lambda` client.
117    ///
118    /// # Examples
119    ///
120    /// ```no_run
121    /// use aws_sdk_lambda::Client;
122    /// use durable_lambda_core::backend::RealBackend;
123    ///
124    /// # async fn example() {
125    /// let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
126    /// let client = Client::new(&config);
127    /// let backend = RealBackend::new(client);
128    /// # }
129    /// ```
130    pub fn new(client: aws_sdk_lambda::Client) -> Self {
131        Self { client }
132    }
133}
134
/// How many times a transient AWS failure is retried; every backend call
/// therefore makes at most `MAX_RETRIES + 1` attempts.
const MAX_RETRIES: u32 = 3;

/// Starting backoff window in milliseconds (100ms), doubled for each retry.
const BASE_DELAY_MS: u64 = 100;

/// Upper bound in milliseconds (2s) applied to the backoff window.
const MAX_DELAY_MS: u64 = 2_000;
141
142/// Compute backoff delay with jitter for a given attempt (0-indexed).
143///
144/// Uses "full jitter" strategy: uniform random delay in `[0, min(cap, base * 2^attempt)]`.
145/// Entropy comes from `SystemTime` nanoseconds — sufficient for retry decorrelation,
146/// not cryptographic.
147fn backoff_delay(attempt: u32) -> Duration {
148    let base = BASE_DELAY_MS.saturating_mul(1u64 << attempt);
149    let capped = base.min(MAX_DELAY_MS);
150    // Use system time nanoseconds as cheap entropy source for jitter.
151    let nanos = std::time::SystemTime::now()
152        .duration_since(std::time::UNIX_EPOCH)
153        .unwrap_or_default()
154        .subsec_nanos() as u64;
155    let jittered = if capped > 0 { nanos % capped } else { 0 };
156    Duration::from_millis(jittered)
157}
158
159/// Check if an error is retryable (only AWS transient errors qualify).
160///
161/// Only `AwsSdkOperation` and `AwsSdk` errors can represent transient AWS
162/// failures. All other `DurableError` variants (replay mismatches,
163/// serialization errors, etc.) are deterministic and must not be retried.
164fn is_retryable_error(err: &DurableError) -> bool {
165    match err {
166        DurableError::AwsSdkOperation(source) => {
167            let msg = source.to_string().to_lowercase();
168            msg.contains("throttl")
169                || msg.contains("rate exceeded")
170                || msg.contains("too many requests")
171                || msg.contains("service unavailable")
172                || msg.contains("internal server error")
173                || msg.contains("timed out")
174                || msg.contains("timeout")
175        }
176        DurableError::AwsSdk(sdk_err) => {
177            let msg = sdk_err.to_string().to_lowercase();
178            msg.contains("throttl")
179                || msg.contains("service unavailable")
180                || msg.contains("timed out")
181        }
182        // All other variants are deterministic errors -- never retry.
183        _ => false,
184    }
185}
186
187#[async_trait::async_trait]
188impl DurableBackend for RealBackend {
189    async fn checkpoint(
190        &self,
191        arn: &str,
192        checkpoint_token: &str,
193        updates: Vec<OperationUpdate>,
194        client_token: Option<&str>,
195    ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
196        let mut last_err = None;
197
198        for attempt in 0..=MAX_RETRIES {
199            let mut builder = self
200                .client
201                .checkpoint_durable_execution()
202                .durable_execution_arn(arn)
203                .checkpoint_token(checkpoint_token)
204                .set_updates(Some(updates.clone()));
205
206            if let Some(token) = client_token {
207                builder = builder.client_token(token);
208            }
209
210            match builder.send().await {
211                Ok(output) => return Ok(output),
212                Err(e) => {
213                    let durable_err = DurableError::aws_sdk_operation(e);
214                    if attempt < MAX_RETRIES && is_retryable_error(&durable_err) {
215                        tokio::time::sleep(backoff_delay(attempt)).await;
216                        last_err = Some(durable_err);
217                        continue;
218                    }
219                    return Err(durable_err);
220                }
221            }
222        }
223
224        Err(last_err.unwrap())
225    }
226
227    async fn get_execution_state(
228        &self,
229        arn: &str,
230        checkpoint_token: &str,
231        next_marker: &str,
232        max_items: i32,
233    ) -> Result<GetDurableExecutionStateOutput, DurableError> {
234        let mut last_err = None;
235
236        for attempt in 0..=MAX_RETRIES {
237            let result = self
238                .client
239                .get_durable_execution_state()
240                .durable_execution_arn(arn)
241                .checkpoint_token(checkpoint_token)
242                .marker(next_marker)
243                .max_items(max_items)
244                .send()
245                .await;
246
247            match result {
248                Ok(output) => return Ok(output),
249                Err(e) => {
250                    let durable_err = DurableError::aws_sdk_operation(e);
251                    if attempt < MAX_RETRIES && is_retryable_error(&durable_err) {
252                        tokio::time::sleep(backoff_delay(attempt)).await;
253                        last_err = Some(durable_err);
254                        continue;
255                    }
256                    return Err(durable_err);
257                }
258            }
259        }
260
261        Err(last_err.unwrap())
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;
    use std::sync::Arc;

    #[test]
    fn durable_backend_is_object_safe() {
        // Verify DurableBackend can be used as a trait object.
        fn _accepts_dyn(_b: Arc<dyn DurableBackend>) {}
    }

    #[test]
    fn real_backend_implements_durable_backend() {
        // Compile-time check: RealBackend can be used wherever a backend is required.
        fn _assert_backend<T: DurableBackend>() {}
        _assert_backend::<RealBackend>();
    }

    #[test]
    fn real_backend_is_send_sync() {
        fn _assert_send_sync<T: Send + Sync>() {}
        _assert_send_sync::<RealBackend>();
    }

    #[test]
    fn backoff_delay_within_bounds() {
        // Full jitter draws from the half-open range [0, min(cap, base * 2^attempt)).
        for attempt in 0..=MAX_RETRIES {
            let d = backoff_delay(attempt);
            let base = BASE_DELAY_MS.saturating_mul(1u64 << attempt);
            let capped = base.min(MAX_DELAY_MS);
            assert!(
                d.as_millis() < u128::from(capped),
                "attempt {attempt}: delay {}ms not below cap {capped}ms",
                d.as_millis()
            );
        }
    }

    // --- Retry classification: only AWS SDK error variants may be retried ---

    #[test]
    fn is_retryable_detects_throttling() {
        let err = DurableError::aws_sdk_operation(io::Error::new(
            io::ErrorKind::Other,
            "Throttling: Rate exceeded",
        ));
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_detects_timeout() {
        let err = DurableError::aws_sdk_operation(io::Error::new(
            io::ErrorKind::TimedOut,
            "connection timed out",
        ));
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_rejects_non_transient() {
        let err = DurableError::replay_mismatch("Step", "Wait", 0);
        assert!(!is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_ignores_checkpoint_failed_with_throttle() {
        // CheckpointFailed is not an AWS SDK variant, so it must NOT be retried
        // even when its message mentions throttling.
        let err = DurableError::checkpoint_failed(
            "test",
            io::Error::new(io::ErrorKind::Other, "Throttling: Rate exceeded"),
        );
        assert!(!is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_ignores_serialization_errors() {
        let serde_err = serde_json::from_str::<i32>("bad").unwrap_err();
        let err = DurableError::serialization("MyType", serde_err);
        assert!(!is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_detects_service_unavailable() {
        let err = DurableError::aws_sdk_operation(io::Error::new(
            io::ErrorKind::Other,
            "service unavailable",
        ));
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_detects_rate_exceeded() {
        let err =
            DurableError::aws_sdk_operation(io::Error::new(io::ErrorKind::Other, "rate exceeded"));
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_detects_internal_server_error() {
        let err = DurableError::aws_sdk_operation(io::Error::new(
            io::ErrorKind::Other,
            "internal server error",
        ));
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_rejects_callback_failed() {
        let err = DurableError::callback_failed("op", "cb-1", "external system rejected");
        assert!(!is_retryable_error(&err));
    }
}