Skip to main content

durable_lambda_core/
context.rs

//! DurableContext — the main context struct passed to handler functions.
//!
//! Owns the replay state machine, backend connection, and execution metadata.
//! Provides methods for all durable operations to interact with the replay engine.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use aws_sdk_lambda::types::Operation;
10use aws_sdk_lambda::types::OperationUpdate;
11
12use crate::backend::DurableBackend;
13use crate::error::DurableError;
14use crate::replay::ReplayEngine;
15use crate::types::{CompensationRecord, ExecutionMode};
16
17/// Main context for a durable execution invocation.
18///
19/// `DurableContext` is created at the start of each Lambda invocation. It loads
20/// the complete operation state from AWS (paginating if necessary), initializes
21/// the replay engine, and provides the interface for durable operations.
22///
23/// # Construction
24///
25/// Use [`DurableContext::new`] to create a context from the invocation payload.
26/// The constructor paginates through all remaining operations automatically.
27///
28/// # Examples
29///
30/// ```no_run
31/// use durable_lambda_core::context::DurableContext;
32/// use durable_lambda_core::backend::RealBackend;
33/// use durable_lambda_core::types::ExecutionMode;
34/// use std::sync::Arc;
35/// use std::collections::HashMap;
36///
37/// # async fn example() -> Result<(), durable_lambda_core::error::DurableError> {
38/// let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
39/// let client = aws_sdk_lambda::Client::new(&config);
40/// let backend = Arc::new(RealBackend::new(client));
41///
42/// let ctx = DurableContext::new(
43///     backend,
44///     "arn:aws:lambda:us-east-1:123456789:durable-execution/my-exec".to_string(),
45///     "initial-token".to_string(),
46///     vec![],       // initial operations from invocation payload
47///     None,         // no more pages
48/// ).await?;
49///
50/// match ctx.execution_mode() {
51///     ExecutionMode::Replaying => println!("Replaying from history"),
52///     ExecutionMode::Executing => println!("Executing new operations"),
53/// }
54/// # Ok(())
55/// # }
56/// ```
57pub struct DurableContext {
58    backend: Arc<dyn DurableBackend>,
59    replay_engine: ReplayEngine,
60    durable_execution_arn: String,
61    checkpoint_token: String,
62    parent_op_id: Option<String>,
63    batch_mode: bool,
64    pending_updates: Vec<OperationUpdate>,
65    /// Registered compensation closures for the saga/compensation pattern.
66    /// Populated by `step_with_compensation` on forward step success.
67    /// Drained and executed in reverse order by `run_compensations`.
68    compensations: Vec<CompensationRecord>,
69}
70
/// Upper bound on the number of items requested per `get_execution_state` page
/// while loading the execution history.
const PAGE_SIZE: i32 = 1_000;
73
74impl DurableContext {
75    /// Create a new `DurableContext` from invocation parameters.
76    ///
77    /// Loads the complete operation state by paginating through
78    /// `get_execution_state` until all pages are fetched. Initializes the
79    /// replay engine with the full operations map.
80    ///
81    /// # Arguments
82    ///
83    /// * `backend` — The durable execution backend (real or mock).
84    /// * `arn` — The durable execution ARN.
85    /// * `checkpoint_token` — The initial checkpoint token from the invocation payload.
86    /// * `initial_operations` — First page of operations from the invocation payload.
87    /// * `next_marker` — Pagination marker for additional pages (`None` if complete).
88    ///
89    /// # Errors
90    ///
91    /// Returns [`DurableError`] if paginating the execution state fails.
92    pub async fn new(
93        backend: Arc<dyn DurableBackend>,
94        arn: String,
95        checkpoint_token: String,
96        initial_operations: Vec<Operation>,
97        next_marker: Option<String>,
98    ) -> Result<Self, DurableError> {
99        let mut operations: HashMap<String, Operation> = initial_operations
100            .into_iter()
101            .map(|op| (op.id().to_string(), op))
102            .collect();
103
104        // Paginate remaining operations.
105        let mut marker = next_marker;
106        while let Some(ref m) = marker {
107            if m.is_empty() {
108                break;
109            }
110            let response = backend
111                .get_execution_state(&arn, &checkpoint_token, m, PAGE_SIZE)
112                .await?;
113
114            for op in response.operations() {
115                operations.insert(op.id().to_string(), op.clone());
116            }
117
118            marker = response.next_marker().map(|s| s.to_string());
119        }
120
121        let replay_engine = ReplayEngine::new(operations, None);
122
123        Ok(Self {
124            backend,
125            replay_engine,
126            durable_execution_arn: arn,
127            checkpoint_token,
128            parent_op_id: None,
129            batch_mode: false,
130            pending_updates: Vec::new(),
131            compensations: Vec::new(),
132        })
133    }
134
135    /// Return the current execution mode (Replaying or Executing).
136    ///
137    /// # Examples
138    ///
139    /// ```no_run
140    /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
141    /// use durable_lambda_core::types::ExecutionMode;
142    /// match ctx.execution_mode() {
143    ///     ExecutionMode::Replaying => { /* returning cached results */ }
144    ///     ExecutionMode::Executing => { /* running new operations */ }
145    /// }
146    /// # }
147    /// ```
148    pub fn execution_mode(&self) -> ExecutionMode {
149        self.replay_engine.execution_mode()
150    }
151
152    /// Return whether the context is currently replaying from history.
153    ///
154    /// # Examples
155    ///
156    /// ```no_run
157    /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
158    /// if ctx.is_replaying() {
159    ///     println!("Replaying cached operations");
160    /// }
161    /// # }
162    /// ```
163    pub fn is_replaying(&self) -> bool {
164        self.replay_engine.is_replaying()
165    }
166
167    /// Return a reference to the durable execution ARN.
168    ///
169    /// # Examples
170    ///
171    /// ```no_run
172    /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
173    /// println!("Execution ARN: {}", ctx.arn());
174    /// # }
175    /// ```
176    pub fn arn(&self) -> &str {
177        &self.durable_execution_arn
178    }
179
180    /// Return the current checkpoint token.
181    ///
182    /// # Examples
183    ///
184    /// ```no_run
185    /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
186    /// let token = ctx.checkpoint_token();
187    /// # }
188    /// ```
189    pub fn checkpoint_token(&self) -> &str {
190        &self.checkpoint_token
191    }
192
193    /// Update the checkpoint token (called after a successful checkpoint).
194    ///
195    /// # Examples
196    ///
197    /// ```no_run
198    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) {
199    /// ctx.set_checkpoint_token("new-token-from-aws".to_string());
200    /// # }
201    /// ```
202    pub fn set_checkpoint_token(&mut self, token: String) {
203        self.checkpoint_token = token;
204    }
205
206    /// Return a reference to the backend.
207    ///
208    /// # Examples
209    ///
210    /// ```no_run
211    /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
212    /// let _backend = ctx.backend();
213    /// # }
214    /// ```
215    pub fn backend(&self) -> &Arc<dyn DurableBackend> {
216        &self.backend
217    }
218
219    /// Return a mutable reference to the replay engine.
220    ///
221    /// # Examples
222    ///
223    /// ```no_run
224    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) {
225    /// let engine = ctx.replay_engine_mut();
226    /// # }
227    /// ```
228    pub fn replay_engine_mut(&mut self) -> &mut ReplayEngine {
229        &mut self.replay_engine
230    }
231
232    /// Create a child context for isolated operation ID namespacing.
233    ///
234    /// The child context shares the same backend and ARN but gets its own
235    /// `ReplayEngine` with a parent-scoped `OperationIdGenerator`. Operations
236    /// within the child context produce deterministic IDs scoped under
237    /// `parent_op_id`, preventing collisions with the parent or sibling contexts.
238    ///
239    /// Used internally by parallel and child_context operations.
240    ///
241    /// # Arguments
242    ///
243    /// * `parent_op_id` — The operation ID that scopes this child context
244    ///
245    /// # Examples
246    ///
247    /// ```no_run
248    /// # async fn example(ctx: &durable_lambda_core::context::DurableContext) {
249    /// let child = ctx.create_child_context("branch-op-id");
250    /// // child operations will have IDs scoped under "branch-op-id"
251    /// # }
252    /// ```
253    pub fn create_child_context(&self, parent_op_id: &str) -> DurableContext {
254        let operations = self.replay_engine.operations().clone();
255        let replay_engine = ReplayEngine::new(operations, Some(parent_op_id.to_string()));
256
257        DurableContext {
258            backend: self.backend.clone(),
259            replay_engine,
260            durable_execution_arn: self.durable_execution_arn.clone(),
261            checkpoint_token: self.checkpoint_token.clone(),
262            parent_op_id: Some(parent_op_id.to_string()),
263            batch_mode: false,
264            pending_updates: Vec::new(),
265            compensations: Vec::new(), // NOT inherited from parent (isolated per context)
266        }
267    }
268
269    /// Return a reference to the replay engine.
270    ///
271    /// # Examples
272    ///
273    /// ```no_run
274    /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
275    /// let engine = ctx.replay_engine();
276    /// assert!(!engine.operations().is_empty() || true);
277    /// # }
278    /// ```
279    pub fn replay_engine(&self) -> &ReplayEngine {
280        &self.replay_engine
281    }
282
283    /// Return the parent operation ID, if this is a child context.
284    ///
285    /// Returns `None` for the root context. Returns the parent's operation ID
286    /// for child contexts created via [`create_child_context`](Self::create_child_context).
287    /// Used by replay-safe logging for hierarchical tracing.
288    ///
289    /// # Examples
290    ///
291    /// ```no_run
292    /// # async fn example(ctx: &durable_lambda_core::context::DurableContext) {
293    /// if let Some(parent_id) = ctx.parent_op_id() {
294    ///     println!("Child context under parent: {parent_id}");
295    /// }
296    /// # }
297    /// ```
298    pub fn parent_op_id(&self) -> Option<&str> {
299        self.parent_op_id.as_deref()
300    }
301
302    /// Enable batch checkpoint mode.
303    ///
304    /// When enabled, step operation checkpoints (START and SUCCEED/FAIL)
305    /// are accumulated in memory instead of being sent immediately.
306    /// Call [`flush_batch`](Self::flush_batch) to send all accumulated
307    /// updates in a single AWS API call.
308    ///
309    /// Batch mode applies only to `step` operations. `wait`, `invoke`,
310    /// and `callback` always send checkpoints immediately because they
311    /// produce suspension errors that require the checkpoint to be
312    /// persisted before the function exits.
313    ///
314    /// # Examples
315    ///
316    /// ```no_run
317    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
318    /// ctx.enable_batch_mode();
319    /// let _: Result<i32, String> = ctx.step("step1", || async { Ok(1) }).await?;
320    /// let _: Result<i32, String> = ctx.step("step2", || async { Ok(2) }).await?;
321    /// ctx.flush_batch().await?;  // sends all updates in one call
322    /// # Ok(())
323    /// # }
324    /// ```
325    pub fn enable_batch_mode(&mut self) {
326        self.batch_mode = true;
327    }
328
329    /// Return whether batch checkpoint mode is active.
330    pub fn is_batch_mode(&self) -> bool {
331        self.batch_mode
332    }
333
334    /// Accumulate an operation update for later batch flush.
335    ///
336    /// Called internally by step operations when batch mode is active.
337    pub(crate) fn push_pending_update(&mut self, update: OperationUpdate) {
338        self.pending_updates.push(update);
339    }
340
341    /// Return the number of pending (unflushed) updates.
342    pub fn pending_update_count(&self) -> usize {
343        self.pending_updates.len()
344    }
345
346    /// Return the number of registered compensation closures.
347    ///
348    /// Useful for asserting compensation registration in tests.
349    ///
350    /// # Examples
351    ///
352    /// ```no_run
353    /// # async fn example(ctx: durable_lambda_core::context::DurableContext) {
354    /// assert_eq!(ctx.compensation_count(), 0);
355    /// # }
356    /// ```
357    pub fn compensation_count(&self) -> usize {
358        self.compensations.len()
359    }
360
361    /// Register a compensation closure after a successful forward step.
362    ///
363    /// Called by `step_with_compensation` when the forward step succeeds.
364    pub(crate) fn push_compensation(&mut self, record: CompensationRecord) {
365        self.compensations.push(record);
366    }
367
368    /// Drain all registered compensations for execution.
369    ///
370    /// Returns the compensations vec (emptying the field) so `run_compensations`
371    /// can execute them. Reversing the returned vec gives LIFO order.
372    pub(crate) fn take_compensations(&mut self) -> Vec<CompensationRecord> {
373        std::mem::take(&mut self.compensations)
374    }
375
376    /// Flush all accumulated checkpoint updates in a single AWS API call.
377    ///
378    /// No-op if no updates are pending. After flushing, the checkpoint
379    /// token is updated from the response.
380    ///
381    /// # Errors
382    ///
383    /// Returns [`DurableError`] if the batch checkpoint call fails.
384    ///
385    /// # Examples
386    ///
387    /// ```no_run
388    /// # async fn example(mut ctx: durable_lambda_core::context::DurableContext) -> Result<(), durable_lambda_core::error::DurableError> {
389    /// ctx.enable_batch_mode();
390    /// // ... run several steps ...
391    /// ctx.flush_batch().await?;
392    /// # Ok(())
393    /// # }
394    /// ```
395    pub async fn flush_batch(&mut self) -> Result<(), DurableError> {
396        if self.pending_updates.is_empty() {
397            return Ok(());
398        }
399        let updates = std::mem::take(&mut self.pending_updates);
400        let response = self
401            .backend()
402            .batch_checkpoint(self.arn(), self.checkpoint_token(), updates, None)
403            .await?;
404        let new_token = response.checkpoint_token().ok_or_else(|| {
405            DurableError::checkpoint_failed(
406                "batch",
407                std::io::Error::new(
408                    std::io::ErrorKind::InvalidData,
409                    "batch checkpoint response missing checkpoint_token",
410                ),
411            )
412        })?;
413        self.set_checkpoint_token(new_token.to_string());
414        Ok(())
415    }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{OperationStatus, OperationType, OperationUpdate};

    /// Canned page of execution state: the operations plus the marker to return.
    type Page = (Vec<Operation>, Option<String>);

    /// In-memory backend that serves canned pages of operations.
    ///
    /// The page to serve is chosen by parsing the incoming marker as an index
    /// into `pages`.
    struct TestBackend {
        pages: Vec<Page>,
    }

    #[async_trait::async_trait]
    impl DurableBackend for TestBackend {
        async fn checkpoint(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            unimplemented!("not needed for context tests")
        }

        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            // Markers that are not numbers select the first page.
            let index = next_marker.parse::<usize>().unwrap_or(0);

            let output = match self.pages.get(index) {
                // Past the last page: answer with an empty, final page.
                None => GetDurableExecutionStateOutput::builder(),
                Some((ops, marker)) => {
                    let with_ops = ops.iter().cloned().fold(
                        GetDurableExecutionStateOutput::builder(),
                        |builder, op| builder.operations(op),
                    );
                    match marker {
                        Some(next) => with_ops.next_marker(next),
                        None => with_ops,
                    }
                }
            };

            Ok(output.build().unwrap())
        }
    }

    /// Build a step operation with the given ID and status, started at epoch 0.
    fn make_op(id: &str, status: OperationStatus) -> Operation {
        let started_at = aws_smithy_types::DateTime::from_secs(0);
        Operation::builder()
            .id(id)
            .r#type(OperationType::Step)
            .status(status)
            .start_timestamp(started_at)
            .build()
            .unwrap()
    }

    /// Wrap canned pages in a trait-object backend.
    fn backend_with(pages: Vec<Page>) -> Arc<dyn DurableBackend> {
        Arc::new(TestBackend { pages })
    }

    /// Build a root context for "arn:test" over a backend that has no pages,
    /// passing no pagination marker.
    async fn context_without_pages(token: &str, initial: Vec<Operation>) -> DurableContext {
        DurableContext::new(
            backend_with(Vec::new()),
            "arn:test".to_string(),
            token.to_string(),
            initial,
            None,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn empty_history_creates_executing_context() {
        let ctx = context_without_pages("tok", Vec::new()).await;

        assert_eq!(ctx.execution_mode(), ExecutionMode::Executing);
        assert!(!ctx.is_replaying());
        assert_eq!(ctx.arn(), "arn:test");
        assert_eq!(ctx.checkpoint_token(), "tok");
    }

    #[tokio::test]
    async fn initial_operations_loaded() {
        let history = vec![make_op("op1", OperationStatus::Succeeded)];
        let ctx = context_without_pages("tok", history).await;

        assert!(ctx.is_replaying());
        assert!(ctx.replay_engine().check_result("op1").is_some());
    }

    #[tokio::test]
    async fn pagination_loads_all_pages() {
        // Page 0 points at page 1; page 1 ends the chain.
        let first_page: Page = (
            vec![make_op("op2", OperationStatus::Succeeded)],
            Some("1".to_string()),
        );
        let last_page: Page = (vec![make_op("op3", OperationStatus::Succeeded)], None);
        let backend = backend_with(vec![first_page, last_page]);

        let ctx = DurableContext::new(
            backend,
            "arn:test".into(),
            "tok".into(),
            vec![make_op("op1", OperationStatus::Succeeded)],
            Some("0".to_string()),
        )
        .await
        .unwrap();

        let engine = ctx.replay_engine();
        assert!(engine.check_result("op1").is_some());
        assert!(engine.check_result("op2").is_some());
        assert!(engine.check_result("op3").is_some());
    }

    #[tokio::test]
    async fn set_checkpoint_token_updates() {
        let mut ctx = context_without_pages("tok1", Vec::new()).await;
        assert_eq!(ctx.checkpoint_token(), "tok1");

        ctx.set_checkpoint_token("tok2".to_string());
        assert_eq!(ctx.checkpoint_token(), "tok2");
    }

    // --- compensation field tests ---

    #[tokio::test]
    async fn new_context_has_empty_compensations() {
        let ctx = context_without_pages("tok", Vec::new()).await;

        assert_eq!(ctx.compensation_count(), 0);
    }

    #[tokio::test]
    async fn create_child_context_has_empty_compensations_not_inherited() {
        use crate::types::CompensationRecord;

        let mut parent = context_without_pages("tok", Vec::new()).await;

        // Give the parent one registered compensation.
        parent.push_compensation(CompensationRecord {
            name: "parent_comp".to_string(),
            forward_result_json: serde_json::Value::Null,
            compensate_fn: Box::new(|_| Box::pin(async { Ok(()) })),
        });
        assert_eq!(parent.compensation_count(), 1);

        // A child derived from it must start with none.
        let child = parent.create_child_context("some-op-id");
        assert_eq!(
            child.compensation_count(),
            0,
            "child context must NOT inherit parent compensations"
        );
    }
}