durable_lambda_core/
backend.rs1use std::time::Duration;
11
12use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
13use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
14use aws_sdk_lambda::types::OperationUpdate;
15
16use crate::error::DurableError;
17
18#[async_trait::async_trait]
38pub trait DurableBackend: Send + Sync {
39 async fn checkpoint(
49 &self,
50 arn: &str,
51 checkpoint_token: &str,
52 updates: Vec<OperationUpdate>,
53 client_token: Option<&str>,
54 ) -> Result<CheckpointDurableExecutionOutput, DurableError>;
55
56 async fn get_execution_state(
66 &self,
67 arn: &str,
68 checkpoint_token: &str,
69 next_marker: &str,
70 max_items: i32,
71 ) -> Result<GetDurableExecutionStateOutput, DurableError>;
72
73 async fn batch_checkpoint(
83 &self,
84 arn: &str,
85 checkpoint_token: &str,
86 updates: Vec<OperationUpdate>,
87 client_token: Option<&str>,
88 ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
89 self.checkpoint(arn, checkpoint_token, updates, client_token)
90 .await
91 }
92}
93
94pub struct RealBackend {
112 client: aws_sdk_lambda::Client,
113}
114
115impl RealBackend {
116 pub fn new(client: aws_sdk_lambda::Client) -> Self {
131 Self { client }
132 }
133}
134
/// Number of retries allowed after the initial attempt.
const MAX_RETRIES: u32 = 3;
/// Backoff ceiling, in milliseconds, for attempt 0; doubled for each later attempt.
const BASE_DELAY_MS: u64 = 100;
/// Upper bound, in milliseconds, on the backoff ceiling for any attempt.
const MAX_DELAY_MS: u64 = 2000;

/// Returns a jittered backoff delay for the given zero-based `attempt`.
///
/// The ceiling for an attempt is `BASE_DELAY_MS * 2^attempt`, capped at
/// `MAX_DELAY_MS`. The returned delay is a value in `[0, ceiling)` milliseconds,
/// derived from the sub-second nanoseconds of the current system time.
///
/// Any `attempt` value is accepted: once the exponential term no longer fits
/// in a `u64` the ceiling simply saturates at `MAX_DELAY_MS`.
fn backoff_delay(attempt: u32) -> Duration {
    // `1u64 << attempt` overflows the shift for `attempt >= 64` (a panic in
    // debug builds, a wrapped shift amount in release), so saturate instead.
    let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
    let ceiling_ms = BASE_DELAY_MS.saturating_mul(multiplier).min(MAX_DELAY_MS);
    if ceiling_ms == 0 {
        // Avoids a modulo by zero should the constants ever be set to 0.
        return Duration::ZERO;
    }

    // A clock set before the Unix epoch falls back to a zero duration,
    // which yields a zero delay rather than an error.
    let nanos = u64::from(
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .subsec_nanos(),
    );
    Duration::from_millis(nanos % ceiling_ms)
}
158
159fn is_retryable_error(err: &DurableError) -> bool {
165 match err {
166 DurableError::AwsSdkOperation(source) => {
167 let msg = source.to_string().to_lowercase();
168 msg.contains("throttl")
169 || msg.contains("rate exceeded")
170 || msg.contains("too many requests")
171 || msg.contains("service unavailable")
172 || msg.contains("internal server error")
173 || msg.contains("timed out")
174 || msg.contains("timeout")
175 }
176 DurableError::AwsSdk(sdk_err) => {
177 let msg = sdk_err.to_string().to_lowercase();
178 msg.contains("throttl")
179 || msg.contains("service unavailable")
180 || msg.contains("timed out")
181 }
182 _ => false,
184 }
185}
186
187#[async_trait::async_trait]
188impl DurableBackend for RealBackend {
189 async fn checkpoint(
190 &self,
191 arn: &str,
192 checkpoint_token: &str,
193 updates: Vec<OperationUpdate>,
194 client_token: Option<&str>,
195 ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
196 let mut last_err = None;
197
198 for attempt in 0..=MAX_RETRIES {
199 let mut builder = self
200 .client
201 .checkpoint_durable_execution()
202 .durable_execution_arn(arn)
203 .checkpoint_token(checkpoint_token)
204 .set_updates(Some(updates.clone()));
205
206 if let Some(token) = client_token {
207 builder = builder.client_token(token);
208 }
209
210 match builder.send().await {
211 Ok(output) => return Ok(output),
212 Err(e) => {
213 let durable_err = DurableError::aws_sdk_operation(e);
214 if attempt < MAX_RETRIES && is_retryable_error(&durable_err) {
215 tokio::time::sleep(backoff_delay(attempt)).await;
216 last_err = Some(durable_err);
217 continue;
218 }
219 return Err(durable_err);
220 }
221 }
222 }
223
224 Err(last_err.unwrap())
225 }
226
227 async fn get_execution_state(
228 &self,
229 arn: &str,
230 checkpoint_token: &str,
231 next_marker: &str,
232 max_items: i32,
233 ) -> Result<GetDurableExecutionStateOutput, DurableError> {
234 let mut last_err = None;
235
236 for attempt in 0..=MAX_RETRIES {
237 let result = self
238 .client
239 .get_durable_execution_state()
240 .durable_execution_arn(arn)
241 .checkpoint_token(checkpoint_token)
242 .marker(next_marker)
243 .max_items(max_items)
244 .send()
245 .await;
246
247 match result {
248 Ok(output) => return Ok(output),
249 Err(e) => {
250 let durable_err = DurableError::aws_sdk_operation(e);
251 if attempt < MAX_RETRIES && is_retryable_error(&durable_err) {
252 tokio::time::sleep(backoff_delay(attempt)).await;
253 last_err = Some(durable_err);
254 continue;
255 }
256 return Err(durable_err);
257 }
258 }
259 }
260
261 Err(last_err.unwrap())
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;
    use std::sync::Arc;

    /// Builds a `DurableError` via `aws_sdk_operation` from an `io::Error`
    /// with the given kind and message.
    fn operation_error(kind: io::ErrorKind, message: &str) -> DurableError {
        DurableError::aws_sdk_operation(io::Error::new(kind, message))
    }

    #[test]
    fn durable_backend_is_object_safe() {
        // Compiles only if the trait can be used as a trait object.
        fn _accepts_dyn(_b: Arc<dyn DurableBackend>) {}
    }

    #[test]
    fn real_backend_is_send_sync() {
        // Compile-time check: instantiation fails unless the bounds hold.
        fn _assert_send_sync<T: Send + Sync>() {}
        _assert_send_sync::<RealBackend>();
    }

    #[test]
    fn backoff_delay_within_bounds() {
        for attempt in 0..=MAX_RETRIES {
            let delay_ms = backoff_delay(attempt).as_millis();
            let capped = BASE_DELAY_MS
                .saturating_mul(1u64 << attempt)
                .min(MAX_DELAY_MS);
            assert!(
                delay_ms <= capped as u128,
                "attempt {attempt}: delay {}ms exceeds cap {capped}ms",
                delay_ms
            );
        }
    }

    #[test]
    fn is_retryable_detects_throttling() {
        let err = operation_error(io::ErrorKind::Other, "Throttling: Rate exceeded");
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_detects_timeout() {
        let err = operation_error(io::ErrorKind::TimedOut, "connection timed out");
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_rejects_non_transient() {
        let err = DurableError::replay_mismatch("Step", "Wait", 0);
        assert!(!is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_ignores_checkpoint_failed_with_throttle() {
        // A throttling message inside a non-SDK variant must not trigger a retry.
        let inner = io::Error::new(io::ErrorKind::Other, "Throttling: Rate exceeded");
        let err = DurableError::checkpoint_failed("test", inner);
        assert!(!is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_ignores_serialization_errors() {
        let serde_err = serde_json::from_str::<i32>("bad").unwrap_err();
        let err = DurableError::serialization("MyType", serde_err);
        assert!(!is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_detects_service_unavailable() {
        let err = operation_error(io::ErrorKind::Other, "service unavailable");
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_detects_rate_exceeded() {
        let err = operation_error(io::ErrorKind::Other, "rate exceeded");
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_detects_internal_server_error() {
        let err = operation_error(io::ErrorKind::Other, "internal server error");
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_rejects_callback_failed() {
        let err = DurableError::callback_failed("op", "cb-1", "external system rejected");
        assert!(!is_retryable_error(&err));
    }
}