durable_lambda_core/context.rs
//! DurableContext — the main context struct passed to handler functions.
//!
//! Owns the replay state machine, backend connection, and execution metadata.
//! Provides methods for all durable operations to interact with the replay engine.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use aws_sdk_lambda::types::Operation;
10use aws_sdk_lambda::types::OperationUpdate;
11
12use crate::backend::DurableBackend;
13use crate::error::DurableError;
14use crate::replay::ReplayEngine;
15use crate::types::{CompensationRecord, ExecutionMode};
16
17/// Main context for a durable execution invocation.
18///
19/// `DurableContext` is created at the start of each Lambda invocation. It loads
20/// the complete operation state from AWS (paginating if necessary), initializes
21/// the replay engine, and provides the interface for durable operations.
22///
23/// # Construction
24///
25/// Use [`DurableContext::new`] to create a context from the invocation payload.
26/// The constructor paginates through all remaining operations automatically.
27///
28/// # Examples
29///
30/// ```no_run
31/// use durable_lambda_core::context::DurableContext;
32/// use durable_lambda_core::backend::RealBackend;
33/// use durable_lambda_core::types::ExecutionMode;
34/// use std::sync::Arc;
35/// use std::collections::HashMap;
36///
37/// # async fn example() -> Result<(), durable_lambda_core::error::DurableError> {
38/// let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
39/// let client = aws_sdk_lambda::Client::new(&config);
40/// let backend = Arc::new(RealBackend::new(client));
41///
42/// let ctx = DurableContext::new(
43/// backend,
44/// "arn:aws:lambda:us-east-1:123456789:durable-execution/my-exec".to_string(),
45/// "initial-token".to_string(),
46/// vec![], // initial operations from invocation payload
47/// None, // no more pages
48/// ).await?;
49///
50/// match ctx.execution_mode() {
51/// ExecutionMode::Replaying => println!("Replaying from history"),
52/// ExecutionMode::Executing => println!("Executing new operations"),
53/// }
54/// # Ok(())
55/// # }
56/// ```
57pub struct DurableContext {
58 backend: Arc<dyn DurableBackend>,
59 replay_engine: ReplayEngine,
60 durable_execution_arn: String,
61 checkpoint_token: String,
62 parent_op_id: Option<String>,
63 batch_mode: bool,
64 pending_updates: Vec<OperationUpdate>,
65 /// Registered compensation closures for the saga/compensation pattern.
66 /// Populated by `step_with_compensation` on forward step success.
67 /// Drained and executed in reverse order by `run_compensations`.
68 compensations: Vec<CompensationRecord>,
69}
70
/// Upper bound on the number of operations requested per
/// `get_execution_state` call while loading the execution history.
const PAGE_SIZE: i32 = 1000;
73
74impl DurableContext {
75 /// Create a new `DurableContext` from invocation parameters.
76 ///
77 /// Loads the complete operation state by paginating through
78 /// `get_execution_state` until all pages are fetched. Initializes the
79 /// replay engine with the full operations map.
80 ///
81 /// # Arguments
82 ///
83 /// * `backend` — The durable execution backend (real or mock).
84 /// * `arn` — The durable execution ARN.
85 /// * `checkpoint_token` — The initial checkpoint token from the invocation payload.
86 /// * `initial_operations` — First page of operations from the invocation payload.
87 /// * `next_marker` — Pagination marker for additional pages (`None` if complete).
88 ///
89 /// # Errors
90 ///
91 /// Returns [`DurableError`] if paginating the execution state fails.
92 pub async fn new(
93 backend: Arc<dyn DurableBackend>,
94 arn: String,
95 checkpoint_token: String,
96 initial_operations: Vec<Operation>,
97 next_marker: Option<String>,
98 ) -> Result<Self, DurableError> {
99 let mut operations: HashMap<String, Operation> = initial_operations
100 .into_iter()
101 .map(|op| (op.id().to_string(), op))
102 .collect();
103
104 // Paginate remaining operations.
105 let mut marker = next_marker;
106 while let Some(ref m) = marker {
107 if m.is_empty() {
108 break;
109 }
110 let response = backend
111 .get_execution_state(&arn, &checkpoint_token, m, PAGE_SIZE)
112 .await?;
113
114 for op in response.operations() {
115 operations.insert(op.id().to_string(), op.clone());
116 }
117
118 marker = response.next_marker().map(|s| s.to_string());
119 }
120
121 let replay_engine = ReplayEngine::new(operations, None);
122
123 Ok(Self {
124 backend,
125 replay_engine,
126 durable_execution_arn: arn,
127 checkpoint_token,
128 parent_op_id: None,
129 batch_mode: false,
130 pending_updates: Vec::new(),
131 compensations: Vec::new(),
132 })
133 }
134
135 /// Return the current execution mode (Replaying or Executing).
136 ///
137 /// # Examples
138 ///
139 /// ```no_run
140 /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
141 /// use durable_lambda_core::types::ExecutionMode;
142 /// match ctx.execution_mode() {
143 /// ExecutionMode::Replaying => { /* returning cached results */ }
144 /// ExecutionMode::Executing => { /* running new operations */ }
145 /// }
146 /// # }
147 /// ```
148 pub fn execution_mode(&self) -> ExecutionMode {
149 self.replay_engine.execution_mode()
150 }
151
152 /// Return whether the context is currently replaying from history.
153 ///
154 /// # Examples
155 ///
156 /// ```no_run
157 /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
158 /// if ctx.is_replaying() {
159 /// println!("Replaying cached operations");
160 /// }
161 /// # }
162 /// ```
163 pub fn is_replaying(&self) -> bool {
164 self.replay_engine.is_replaying()
165 }
166
167 /// Return a reference to the durable execution ARN.
168 ///
169 /// # Examples
170 ///
171 /// ```no_run
172 /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
173 /// println!("Execution ARN: {}", ctx.arn());
174 /// # }
175 /// ```
176 pub fn arn(&self) -> &str {
177 &self.durable_execution_arn
178 }
179
180 /// Return the current checkpoint token.
181 ///
182 /// # Examples
183 ///
184 /// ```no_run
185 /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
186 /// let token = ctx.checkpoint_token();
187 /// # }
188 /// ```
189 pub fn checkpoint_token(&self) -> &str {
190 &self.checkpoint_token
191 }
192
193 /// Update the checkpoint token (called after a successful checkpoint).
194 ///
195 /// # Examples
196 ///
197 /// ```no_run
198 /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) {
199 /// ctx.set_checkpoint_token("new-token-from-aws".to_string());
200 /// # }
201 /// ```
202 pub fn set_checkpoint_token(&mut self, token: String) {
203 self.checkpoint_token = token;
204 }
205
206 /// Return a reference to the backend.
207 ///
208 /// # Examples
209 ///
210 /// ```no_run
211 /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
212 /// let _backend = ctx.backend();
213 /// # }
214 /// ```
215 pub fn backend(&self) -> &Arc<dyn DurableBackend> {
216 &self.backend
217 }
218
219 /// Return a mutable reference to the replay engine.
220 ///
221 /// # Examples
222 ///
223 /// ```no_run
224 /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) {
225 /// let engine = ctx.replay_engine_mut();
226 /// # }
227 /// ```
228 pub fn replay_engine_mut(&mut self) -> &mut ReplayEngine {
229 &mut self.replay_engine
230 }
231
232 /// Create a child context for isolated operation ID namespacing.
233 ///
234 /// The child context shares the same backend and ARN but gets its own
235 /// `ReplayEngine` with a parent-scoped `OperationIdGenerator`. Operations
236 /// within the child context produce deterministic IDs scoped under
237 /// `parent_op_id`, preventing collisions with the parent or sibling contexts.
238 ///
239 /// Used internally by parallel and child_context operations.
240 ///
241 /// # Arguments
242 ///
243 /// * `parent_op_id` — The operation ID that scopes this child context
244 ///
245 /// # Examples
246 ///
247 /// ```no_run
248 /// # async fn example(ctx: &durable_lambda_core::context::DurableContext) {
249 /// let child = ctx.create_child_context("branch-op-id");
250 /// // child operations will have IDs scoped under "branch-op-id"
251 /// # }
252 /// ```
253 pub fn create_child_context(&self, parent_op_id: &str) -> DurableContext {
254 let operations = self.replay_engine.operations().clone();
255 let replay_engine = ReplayEngine::new(operations, Some(parent_op_id.to_string()));
256
257 DurableContext {
258 backend: self.backend.clone(),
259 replay_engine,
260 durable_execution_arn: self.durable_execution_arn.clone(),
261 checkpoint_token: self.checkpoint_token.clone(),
262 parent_op_id: Some(parent_op_id.to_string()),
263 batch_mode: false,
264 pending_updates: Vec::new(),
265 compensations: Vec::new(), // NOT inherited from parent (isolated per context)
266 }
267 }
268
269 /// Return a reference to the replay engine.
270 ///
271 /// # Examples
272 ///
273 /// ```no_run
274 /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
275 /// let engine = ctx.replay_engine();
276 /// assert!(!engine.operations().is_empty() || true);
277 /// # }
278 /// ```
279 pub fn replay_engine(&self) -> &ReplayEngine {
280 &self.replay_engine
281 }
282
283 /// Return the parent operation ID, if this is a child context.
284 ///
285 /// Returns `None` for the root context. Returns the parent's operation ID
286 /// for child contexts created via [`create_child_context`](Self::create_child_context).
287 /// Used by replay-safe logging for hierarchical tracing.
288 ///
289 /// # Examples
290 ///
291 /// ```no_run
292 /// # async fn example(ctx: &durable_lambda_core::context::DurableContext) {
293 /// if let Some(parent_id) = ctx.parent_op_id() {
294 /// println!("Child context under parent: {parent_id}");
295 /// }
296 /// # }
297 /// ```
298 pub fn parent_op_id(&self) -> Option<&str> {
299 self.parent_op_id.as_deref()
300 }
301
302 /// Enable batch checkpoint mode.
303 ///
304 /// When enabled, step operation checkpoints (START and SUCCEED/FAIL)
305 /// are accumulated in memory instead of being sent immediately.
306 /// Call [`flush_batch`](Self::flush_batch) to send all accumulated
307 /// updates in a single AWS API call.
308 ///
309 /// Batch mode applies only to `step` operations. `wait`, `invoke`,
310 /// and `callback` always send checkpoints immediately because they
311 /// produce suspension errors that require the checkpoint to be
312 /// persisted before the function exits.
313 ///
314 /// # Examples
315 ///
316 /// ```no_run
317 /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
318 /// ctx.enable_batch_mode();
319 /// let _: Result<i32, String> = ctx.step("step1", || async { Ok(1) }).await?;
320 /// let _: Result<i32, String> = ctx.step("step2", || async { Ok(2) }).await?;
321 /// ctx.flush_batch().await?; // sends all updates in one call
322 /// # Ok(())
323 /// # }
324 /// ```
325 pub fn enable_batch_mode(&mut self) {
326 self.batch_mode = true;
327 }
328
329 /// Return whether batch checkpoint mode is active.
330 pub fn is_batch_mode(&self) -> bool {
331 self.batch_mode
332 }
333
334 /// Accumulate an operation update for later batch flush.
335 ///
336 /// Called internally by step operations when batch mode is active.
337 pub(crate) fn push_pending_update(&mut self, update: OperationUpdate) {
338 self.pending_updates.push(update);
339 }
340
341 /// Return the number of pending (unflushed) updates.
342 pub fn pending_update_count(&self) -> usize {
343 self.pending_updates.len()
344 }
345
346 /// Return the number of registered compensation closures.
347 ///
348 /// Useful for asserting compensation registration in tests.
349 ///
350 /// # Examples
351 ///
352 /// ```no_run
353 /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
354 /// assert_eq!(ctx.compensation_count(), 0);
355 /// # }
356 /// ```
357 pub fn compensation_count(&self) -> usize {
358 self.compensations.len()
359 }
360
361 /// Register a compensation closure after a successful forward step.
362 ///
363 /// Called by `step_with_compensation` when the forward step succeeds.
364 pub(crate) fn push_compensation(&mut self, record: CompensationRecord) {
365 self.compensations.push(record);
366 }
367
368 /// Drain all registered compensations for execution.
369 ///
370 /// Returns the compensations vec (emptying the field) so `run_compensations`
371 /// can execute them. Reversing the returned vec gives LIFO order.
372 pub(crate) fn take_compensations(&mut self) -> Vec<CompensationRecord> {
373 std::mem::take(&mut self.compensations)
374 }
375
376 /// Flush all accumulated checkpoint updates in a single AWS API call.
377 ///
378 /// No-op if no updates are pending. After flushing, the checkpoint
379 /// token is updated from the response.
380 ///
381 /// # Errors
382 ///
383 /// Returns [`DurableError`] if the batch checkpoint call fails.
384 ///
385 /// # Examples
386 ///
387 /// ```no_run
388 /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
389 /// ctx.enable_batch_mode();
390 /// // ... run several steps ...
391 /// ctx.flush_batch().await?;
392 /// # Ok(())
393 /// # }
394 /// ```
395 pub async fn flush_batch(&mut self) -> Result<(), DurableError> {
396 if self.pending_updates.is_empty() {
397 return Ok(());
398 }
399 let updates = std::mem::take(&mut self.pending_updates);
400 let response = self
401 .backend()
402 .batch_checkpoint(self.arn(), self.checkpoint_token(), updates, None)
403 .await?;
404 let new_token = response.checkpoint_token().ok_or_else(|| {
405 DurableError::checkpoint_failed(
406 "batch",
407 std::io::Error::new(
408 std::io::ErrorKind::InvalidData,
409 "batch checkpoint response missing checkpoint_token",
410 ),
411 )
412 })?;
413 self.set_checkpoint_token(new_token.to_string());
414 Ok(())
415 }
416}
417
418#[cfg(test)]
419mod tests {
420 use super::*;
421 use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
422 use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
423 use aws_sdk_lambda::types::{OperationStatus, OperationType, OperationUpdate};
424 /// A simple mock backend for testing context construction.
425 struct TestBackend {
426 pages: Vec<(Vec<Operation>, Option<String>)>,
427 }
428
429 #[async_trait::async_trait]
430 impl DurableBackend for TestBackend {
431 async fn checkpoint(
432 &self,
433 _arn: &str,
434 _checkpoint_token: &str,
435 _updates: Vec<OperationUpdate>,
436 _client_token: Option<&str>,
437 ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
438 unimplemented!("not needed for context tests")
439 }
440
441 async fn get_execution_state(
442 &self,
443 _arn: &str,
444 _checkpoint_token: &str,
445 next_marker: &str,
446 _max_items: i32,
447 ) -> Result<GetDurableExecutionStateOutput, DurableError> {
448 let page_idx: usize = next_marker.parse().unwrap_or(0);
449 if page_idx >= self.pages.len() {
450 return Ok(GetDurableExecutionStateOutput::builder().build().unwrap());
451 }
452 let (ops, marker) = &self.pages[page_idx];
453 let mut builder = GetDurableExecutionStateOutput::builder();
454 for op in ops {
455 builder = builder.operations(op.clone());
456 }
457 if let Some(m) = marker {
458 builder = builder.next_marker(m);
459 }
460 Ok(builder.build().unwrap())
461 }
462 }
463
464 fn make_op(id: &str, status: OperationStatus) -> Operation {
465 Operation::builder()
466 .id(id)
467 .r#type(OperationType::Step)
468 .status(status)
469 .start_timestamp(aws_smithy_types::DateTime::from_secs(0))
470 .build()
471 .unwrap()
472 }
473
474 #[tokio::test]
475 async fn empty_history_creates_executing_context() {
476 let backend = Arc::new(TestBackend { pages: vec![] });
477 let ctx = DurableContext::new(backend, "arn:test".into(), "tok".into(), vec![], None)
478 .await
479 .unwrap();
480
481 assert_eq!(ctx.execution_mode(), ExecutionMode::Executing);
482 assert!(!ctx.is_replaying());
483 assert_eq!(ctx.arn(), "arn:test");
484 assert_eq!(ctx.checkpoint_token(), "tok");
485 }
486
487 #[tokio::test]
488 async fn initial_operations_loaded() {
489 let backend = Arc::new(TestBackend { pages: vec![] });
490 let ops = vec![make_op("op1", OperationStatus::Succeeded)];
491 let ctx = DurableContext::new(backend, "arn:test".into(), "tok".into(), ops, None)
492 .await
493 .unwrap();
494
495 assert!(ctx.is_replaying());
496 assert!(ctx.replay_engine().check_result("op1").is_some());
497 }
498
499 #[tokio::test]
500 async fn pagination_loads_all_pages() {
501 let backend = Arc::new(TestBackend {
502 pages: vec![
503 (
504 vec![make_op("op2", OperationStatus::Succeeded)],
505 Some("1".to_string()),
506 ),
507 (vec![make_op("op3", OperationStatus::Succeeded)], None),
508 ],
509 });
510
511 let initial = vec![make_op("op1", OperationStatus::Succeeded)];
512 let ctx = DurableContext::new(
513 backend,
514 "arn:test".into(),
515 "tok".into(),
516 initial,
517 Some("0".to_string()),
518 )
519 .await
520 .unwrap();
521
522 assert!(ctx.replay_engine().check_result("op1").is_some());
523 assert!(ctx.replay_engine().check_result("op2").is_some());
524 assert!(ctx.replay_engine().check_result("op3").is_some());
525 }
526
527 #[tokio::test]
528 async fn set_checkpoint_token_updates() {
529 let backend = Arc::new(TestBackend { pages: vec![] });
530 let mut ctx = DurableContext::new(backend, "arn:test".into(), "tok1".into(), vec![], None)
531 .await
532 .unwrap();
533
534 assert_eq!(ctx.checkpoint_token(), "tok1");
535 ctx.set_checkpoint_token("tok2".to_string());
536 assert_eq!(ctx.checkpoint_token(), "tok2");
537 }
538
539 // --- compensation field tests ---
540
541 #[tokio::test]
542 async fn new_context_has_empty_compensations() {
543 let backend = Arc::new(TestBackend { pages: vec![] });
544 let ctx = DurableContext::new(backend, "arn:test".into(), "tok".into(), vec![], None)
545 .await
546 .unwrap();
547
548 assert_eq!(ctx.compensation_count(), 0);
549 }
550
551 #[tokio::test]
552 async fn create_child_context_has_empty_compensations_not_inherited() {
553 use crate::types::CompensationRecord;
554
555 let backend = Arc::new(TestBackend { pages: vec![] });
556 let mut ctx = DurableContext::new(backend, "arn:test".into(), "tok".into(), vec![], None)
557 .await
558 .unwrap();
559
560 // Register a compensation on the parent
561 let record = CompensationRecord {
562 name: "parent_comp".to_string(),
563 forward_result_json: serde_json::Value::Null,
564 compensate_fn: Box::new(|_| Box::pin(async { Ok(()) })),
565 };
566 ctx.push_compensation(record);
567 assert_eq!(ctx.compensation_count(), 1);
568
569 // Child context should start with 0 compensations
570 let child = ctx.create_child_context("some-op-id");
571 assert_eq!(
572 child.compensation_count(),
573 0,
574 "child context must NOT inherit parent compensations"
575 );
576 }
577}