Skip to main content

durable_lambda_testing/
mock_backend.rs

1//! MockBackend — implements DurableBackend without AWS dependency.
2//!
3//! Returns pre-loaded data for testing. No network calls, no credentials needed.
4//! Records all checkpoint calls for test assertions.
5
6use std::sync::Arc;
7
8use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
9use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
10use aws_sdk_lambda::types::OperationUpdate;
11use durable_lambda_core::backend::DurableBackend;
12use durable_lambda_core::error::DurableError;
13use tokio::sync::Mutex;
14
/// A recorded durable operation for sequence verification (FR39).
///
/// Each time the handler starts a new operation (step, wait, callback, etc.),
/// an `OperationRecord` is captured with the operation name and type.
///
/// # Examples
///
/// ```no_run
/// # async fn example() {
/// use durable_lambda_testing::prelude::*;
///
/// let (mut ctx, _calls, ops) = MockDurableContext::new()
///     .build()
///     .await;
///
/// let _: Result<i32, String> = ctx.step("validate", || async { Ok(42) }).await.unwrap();
///
/// let recorded = ops.lock().await;
/// assert_eq!(recorded[0].name, "validate");
/// assert_eq!(recorded[0].operation_type, "step");
/// # }
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OperationRecord {
    /// The user-provided operation name (e.g., "validate", "cooldown").
    ///
    /// Empty when the recorded update carried no name.
    pub name: String,
    /// The operation type as a lowercase string.
    ///
    /// [`MockBackend`](crate::mock_backend::MockBackend) produces `"step"`,
    /// `"wait"`, `"callback"`, `"invoke"` (chained invoke), or `"unknown"`
    /// for any other operation type.
    pub operation_type: String,
}

impl OperationRecord {
    /// Format as `"type:name"` for use with assertion helpers.
    ///
    /// This is exactly the [`Display`](std::fmt::Display) output of the
    /// record; both share a single formatting implementation so the two
    /// representations cannot drift apart.
    ///
    /// # Examples
    ///
    /// ```
    /// use durable_lambda_testing::mock_backend::OperationRecord;
    ///
    /// let record = OperationRecord {
    ///     name: "validate".to_string(),
    ///     operation_type: "step".to_string(),
    /// };
    /// assert_eq!(record.to_type_name(), "step:validate");
    /// ```
    pub fn to_type_name(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for OperationRecord {
    /// Writes the record as `"type:name"`, e.g. `step:validate`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}:{}", self.operation_type, self.name)
    }
}
69
70/// A captured checkpoint call for test assertions.
71///
72/// Each time the handler checkpoints an operation (START, SUCCEED, FAIL, RETRY),
73/// a `CheckpointCall` is recorded with the full details.
74///
75/// # Examples
76///
77/// ```no_run
78/// # async fn example() {
79/// use durable_lambda_testing::prelude::*;
80///
81/// let (mut ctx, calls, _ops) = MockDurableContext::new()
82///     .build()
83///     .await;
84///
85/// // ... run handler ...
86///
87/// let captured = calls.lock().await;
88/// assert_eq!(captured.len(), 2); // START + SUCCEED
89/// # }
90/// ```
91#[derive(Debug, Clone)]
92pub struct CheckpointCall {
93    /// The durable execution ARN passed to checkpoint.
94    pub arn: String,
95    /// The checkpoint token passed to checkpoint.
96    pub checkpoint_token: String,
97    /// The operation updates sent in this checkpoint.
98    pub updates: Vec<OperationUpdate>,
99}
100
101/// Mock implementation of [`DurableBackend`] for testing.
102///
103/// Records all checkpoint calls and returns configurable responses.
104/// Never makes AWS API calls — pure in-memory mock.
105///
106/// Typically created via [`MockDurableContext`](crate::mock_context::MockDurableContext)
107/// rather than directly.
108///
109/// # Examples
110///
111/// ```no_run
112/// # async fn example() {
113/// use durable_lambda_testing::mock_backend::MockBackend;
114/// use std::sync::Arc;
115///
116/// let (backend, calls, _ops) = MockBackend::new("mock-token");
117/// let backend = Arc::new(backend);
118/// // Use with DurableContext::new(backend, ...)
119/// # }
120/// ```
121/// Shared recorder for checkpoint calls.
122///
123/// Use this handle to inspect all checkpoint API calls made during a test.
124///
125/// # Examples
126///
127/// ```no_run
128/// # async fn example() {
129/// use durable_lambda_testing::prelude::*;
130///
131/// let (mut ctx, calls, _ops) = MockDurableContext::new().build().await;
132/// // ... run handler operations ...
133/// let captured: Vec<_> = calls.lock().await.clone();
134/// assert!(!captured.is_empty());
135/// # }
136/// ```
137pub type CheckpointRecorder = Arc<Mutex<Vec<CheckpointCall>>>;
138
139/// Shared recorder for operation sequence tracking.
140///
141/// Use this handle to inspect the sequence of durable operations
142/// started during a test.
143///
144/// # Examples
145///
146/// ```no_run
147/// # async fn example() {
148/// use durable_lambda_testing::prelude::*;
149///
150/// let (mut ctx, _calls, ops) = MockDurableContext::new().build().await;
151/// // ... run handler operations ...
152/// let recorded: Vec<_> = ops.lock().await.clone();
153/// assert!(!recorded.is_empty());
154/// # }
155/// ```
156pub type OperationRecorder = Arc<Mutex<Vec<OperationRecord>>>;
157
158/// Shared counter for batch checkpoint calls.
159pub type BatchCallCounter = Arc<Mutex<usize>>;
160
161pub struct MockBackend {
162    calls: CheckpointRecorder,
163    operations: OperationRecorder,
164    checkpoint_token: String,
165    batch_call_count: BatchCallCounter,
166}
167
168impl MockBackend {
169    /// Create a new `MockBackend` with the given checkpoint token.
170    ///
171    /// Returns the backend, a checkpoint call recorder, and an operation
172    /// sequence recorder for test assertions.
173    ///
174    /// # Examples
175    ///
176    /// ```no_run
177    /// # async fn example() {
178    /// use durable_lambda_testing::mock_backend::MockBackend;
179    ///
180    /// let (backend, calls, _ops) = MockBackend::new("token-123");
181    /// // calls can be inspected after running the handler
182    /// # }
183    /// ```
184    pub fn new(checkpoint_token: &str) -> (Self, CheckpointRecorder, OperationRecorder) {
185        let calls = Arc::new(Mutex::new(Vec::new()));
186        let operations = Arc::new(Mutex::new(Vec::new()));
187        let backend = Self {
188            calls: calls.clone(),
189            operations: operations.clone(),
190            checkpoint_token: checkpoint_token.to_string(),
191            batch_call_count: Arc::new(Mutex::new(0)),
192        };
193        (backend, calls, operations)
194    }
195
196    /// Return the batch checkpoint call counter for test assertions.
197    pub fn batch_call_counter(&self) -> BatchCallCounter {
198        self.batch_call_count.clone()
199    }
200}
201
202#[async_trait::async_trait]
203impl DurableBackend for MockBackend {
204    async fn checkpoint(
205        &self,
206        arn: &str,
207        checkpoint_token: &str,
208        updates: Vec<OperationUpdate>,
209        _client_token: Option<&str>,
210    ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
211        // Record operation sequence from START actions (one per logical operation).
212        for update in &updates {
213            if update.action() == &aws_sdk_lambda::types::OperationAction::Start {
214                let op_type = match update.r#type() {
215                    aws_sdk_lambda::types::OperationType::Step => "step",
216                    aws_sdk_lambda::types::OperationType::Wait => "wait",
217                    aws_sdk_lambda::types::OperationType::Callback => "callback",
218                    aws_sdk_lambda::types::OperationType::ChainedInvoke => "invoke",
219                    _ => "unknown",
220                };
221                let name = update.name().unwrap_or("").to_string();
222                self.operations.lock().await.push(OperationRecord {
223                    name,
224                    operation_type: op_type.to_string(),
225                });
226            }
227        }
228
229        self.calls.lock().await.push(CheckpointCall {
230            arn: arn.to_string(),
231            checkpoint_token: checkpoint_token.to_string(),
232            updates,
233        });
234        Ok(CheckpointDurableExecutionOutput::builder()
235            .checkpoint_token(&self.checkpoint_token)
236            .build())
237    }
238
239    async fn batch_checkpoint(
240        &self,
241        arn: &str,
242        checkpoint_token: &str,
243        updates: Vec<OperationUpdate>,
244        _client_token: Option<&str>,
245    ) -> Result<
246        aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput,
247        DurableError,
248    > {
249        *self.batch_call_count.lock().await += 1;
250        // Record individual operations for sequence tracking (same as checkpoint).
251        for update in &updates {
252            if update.action() == &aws_sdk_lambda::types::OperationAction::Start {
253                let op_type = match update.r#type() {
254                    aws_sdk_lambda::types::OperationType::Step => "step",
255                    aws_sdk_lambda::types::OperationType::Wait => "wait",
256                    aws_sdk_lambda::types::OperationType::Callback => "callback",
257                    aws_sdk_lambda::types::OperationType::ChainedInvoke => "invoke",
258                    _ => "unknown",
259                };
260                let name = update.name().unwrap_or("").to_string();
261                self.operations.lock().await.push(OperationRecord {
262                    name,
263                    operation_type: op_type.to_string(),
264                });
265            }
266        }
267        self.calls.lock().await.push(CheckpointCall {
268            arn: arn.to_string(),
269            checkpoint_token: checkpoint_token.to_string(),
270            updates,
271        });
272        Ok(
273            aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput::builder()
274                .checkpoint_token(&self.checkpoint_token)
275                .build(),
276        )
277    }
278
279    async fn get_execution_state(
280        &self,
281        _arn: &str,
282        _checkpoint_token: &str,
283        _next_marker: &str,
284        _max_items: i32,
285    ) -> Result<GetDurableExecutionStateOutput, DurableError> {
286        Ok(GetDurableExecutionStateOutput::builder()
287            .build()
288            .expect("empty execution state"))
289    }
290}