Skip to main content

durable_execution_sdk/
runtime.rs

1//! Runtime support for the `#[durable_execution]` macro.
2//!
3//! This module contains the runtime logic that was previously generated inline
4//! by the proc macro. By extracting it into library code, we get:
5//!
6//! - Unit-testable runtime logic
7//! - A single copy in the binary (no per-handler duplication)
8//! - Readable error messages pointing to real source files
9//! - Bug fixes ship in the SDK crate, not the macro crate
10//!
11//! Users typically don't interact with this module directly — the
12//! `#[durable_execution]` macro calls [`run_durable_handler`] automatically.
13//! However, advanced users can call it directly to skip the macro.
14//!
15//! # Requirements
16//!
17//! - 15.1: THE Lambda_Integration SHALL provide a `#[durable_execution]` attribute macro for handler functions
18//! - 15.3: THE Lambda_Integration SHALL create ExecutionState and DurableContext for the handler
19
20use std::future::Future;
21use std::sync::Arc;
22
23use serde::de::DeserializeOwned;
24use serde::Serialize;
25
26use crate::client::{LambdaDurableServiceClient, SharedDurableServiceClient};
27use crate::context::DurableContext;
28use crate::error::{DurableError, ErrorObject};
29use crate::lambda::{DurableExecutionInvocationInput, DurableExecutionInvocationOutput};
30use crate::operation::OperationType;
31use crate::state::{CheckpointBatcherConfig, ExecutionState};
32use crate::termination::TerminationManager;
33
/// SDK name for user-agent identification.
///
/// Passed to `LambdaDurableServiceClient::from_aws_config_with_user_agent`
/// by [`run_durable_handler`].
const SDK_NAME: &str = "durable-execution-sdk-rust";

/// SDK version for user-agent identification (from Cargo.toml).
///
/// Resolved at compile time from the `CARGO_PKG_VERSION` environment variable
/// and passed alongside [`SDK_NAME`] when the service client is built.
const SDK_VERSION: &str = env!("CARGO_PKG_VERSION");

/// Maximum response payload size (6MB Lambda limit).
///
/// `process_result` compares the byte length of the serialized handler result
/// against this value; anything strictly larger is checkpointed instead of
/// being returned inline.
const MAX_RESPONSE_SIZE: usize = 6 * 1024 * 1024;

/// Queue buffer size for the checkpoint batcher.
///
/// Passed as the last argument to `ExecutionState::with_batcher`.
const CHECKPOINT_QUEUE_BUFFER: usize = 100;

/// Timeout in seconds for waiting on the batcher to drain during cleanup.
///
/// [`run_durable_handler`] waits at most this long for the spawned batcher
/// task to finish before returning its output.
const BATCHER_DRAIN_TIMEOUT_SECS: u64 = 5;
48
49/// Extracts the user's event from a [`DurableExecutionInvocationInput`].
50///
51/// Tries these sources in order:
52/// 1. Top-level `Input` field (JSON value)
53/// 2. `ExecutionDetails.InputPayload` from the EXECUTION operation (JSON string)
54/// 3. `null` deserialization (for types with defaults, e.g. `Option<T>` or `()`)
55///
56/// # Errors
57///
58/// Returns a [`DurableExecutionInvocationOutput`] with `FAILED` status if
59/// deserialization fails from all sources.
60pub fn extract_event<E: DeserializeOwned>(
61    input: &DurableExecutionInvocationInput,
62) -> Result<E, DurableExecutionInvocationOutput> {
63    // Try top-level Input first
64    if let Some(value) = &input.input {
65        return serde_json::from_value(value.clone()).map_err(|e| {
66            DurableExecutionInvocationOutput::failed(ErrorObject::new(
67                "DeserializationError",
68                format!("Failed to deserialize event from Input: {}", e),
69            ))
70        });
71    }
72
73    // Try ExecutionDetails.InputPayload from the EXECUTION operation
74    let execution_op = input
75        .initial_execution_state
76        .operations
77        .iter()
78        .find(|op| op.operation_type == OperationType::Execution);
79
80    if let Some(op) = execution_op {
81        if let Some(details) = &op.execution_details {
82            if let Some(payload) = &details.input_payload {
83                return serde_json::from_str::<E>(payload).map_err(|e| {
84                    DurableExecutionInvocationOutput::failed(ErrorObject::new(
85                        "DeserializationError",
86                        format!(
87                            "Failed to deserialize event from ExecutionDetails.InputPayload: {}",
88                            e
89                        ),
90                    ))
91                });
92            }
93        }
94    }
95
96    // Fall back to null deserialization (supports Option<T>, (), etc.)
97    serde_json::from_value(serde_json::Value::Null).map_err(|_| {
98        DurableExecutionInvocationOutput::failed(ErrorObject::new(
99            "DeserializationError",
100            "No input provided and event type does not support default",
101        ))
102    })
103}
104
105/// Processes the handler result into a [`DurableExecutionInvocationOutput`].
106///
107/// Handles three cases:
108/// - `Ok(value)` → serialize, checkpoint if >6MB, return `SUCCEEDED`
109/// - `Err(Suspend)` → return `PENDING`
110/// - `Err(other)` → return `FAILED` with error details
111async fn process_result<R: Serialize>(
112    result: Result<R, DurableError>,
113    state: &Arc<ExecutionState>,
114    durable_execution_arn: &str,
115) -> DurableExecutionInvocationOutput {
116    match result {
117        Ok(value) => match serde_json::to_string(&value) {
118            Ok(json) => {
119                if json.len() > MAX_RESPONSE_SIZE {
120                    checkpoint_large_result(&json, state, durable_execution_arn).await
121                } else {
122                    DurableExecutionInvocationOutput::succeeded(Some(json))
123                }
124            }
125            Err(e) => DurableExecutionInvocationOutput::failed(ErrorObject::new(
126                "SerializationError",
127                format!("Failed to serialize result: {}", e),
128            )),
129        },
130        Err(DurableError::Suspend { .. }) => DurableExecutionInvocationOutput::pending(),
131        Err(error) => DurableExecutionInvocationOutput::failed(ErrorObject::from(&error)),
132    }
133}
134
135/// Checkpoints a large result (>6MB) and returns a reference to it.
136async fn checkpoint_large_result(
137    json: &str,
138    state: &Arc<ExecutionState>,
139    durable_execution_arn: &str,
140) -> DurableExecutionInvocationOutput {
141    let result_op_id = format!(
142        "__result__{}",
143        crate::replay_safe::uuid_string_from_operation(durable_execution_arn, 0)
144    );
145
146    let update = crate::operation::OperationUpdate::succeed(
147        &result_op_id,
148        OperationType::Execution,
149        Some(json.to_string()),
150    );
151
152    match state.create_checkpoint(update, true).await {
153        Ok(()) => DurableExecutionInvocationOutput::checkpointed_result(&result_op_id, json.len()),
154        Err(e) => DurableExecutionInvocationOutput::failed(ErrorObject::new(
155            "CheckpointError",
156            format!("Failed to checkpoint large result: {}", e),
157        )),
158    }
159}
160
161/// Runs a durable execution handler within the Lambda runtime.
162///
163/// This is the core runtime function that the `#[durable_execution]` macro delegates to.
164/// It handles the full lifecycle:
165///
166/// 1. Extract the user's event from the Lambda input
167/// 2. Set up `ExecutionState`, checkpoint batcher, and `DurableContext`
168/// 3. Call the user's handler
169/// 4. Process the result (serialize, checkpoint large results, map errors)
170/// 5. Clean up (drain batcher, drop state)
171///
172/// # Type Parameters
173///
174/// - `E`: The user's event type (must implement `DeserializeOwned`)
175/// - `R`: The user's result type (must implement `Serialize`)
176/// - `Fut`: The future returned by the handler
177/// - `F`: The handler function
178///
179/// # Example
180///
181/// ```rust,ignore
182/// use durable_execution_sdk::runtime::run_durable_handler;
183///
184/// // Called automatically by #[durable_execution], but can be used directly:
185/// pub async fn my_handler(
186///     event: LambdaEvent<DurableExecutionInvocationInput>,
187/// ) -> Result<DurableExecutionInvocationOutput, lambda_runtime::Error> {
188///     run_durable_handler(event, |event: MyEvent, ctx| async move {
189///         let result = ctx.step(|_| Ok(42), None).await?;
190///         Ok(MyResult { value: result })
191///     }).await
192/// }
193/// ```
194pub async fn run_durable_handler<E, R, Fut, F>(
195    lambda_event: lambda_runtime::LambdaEvent<DurableExecutionInvocationInput>,
196    handler: F,
197) -> Result<DurableExecutionInvocationOutput, lambda_runtime::Error>
198where
199    E: DeserializeOwned,
200    R: Serialize,
201    Fut: Future<Output = Result<R, DurableError>>,
202    F: FnOnce(E, DurableContext) -> Fut,
203{
204    let (durable_input, lambda_context) = lambda_event.into_parts();
205
206    // Extract the user's event
207    let user_event: E = match extract_event(&durable_input) {
208        Ok(event) => event,
209        Err(output) => return Ok(output),
210    };
211
212    // Create termination manager from Lambda context
213    let termination_mgr = TerminationManager::from_lambda_context(&lambda_context);
214
215    // Create the service client
216    let aws_config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
217    let service_client: SharedDurableServiceClient =
218        Arc::new(LambdaDurableServiceClient::from_aws_config_with_user_agent(
219            &aws_config,
220            SDK_NAME,
221            SDK_VERSION,
222        )?);
223
224    // Create ExecutionState with batcher
225    let batcher_config = CheckpointBatcherConfig::default();
226    let (state, mut batcher) = ExecutionState::with_batcher(
227        &durable_input.durable_execution_arn,
228        &durable_input.checkpoint_token,
229        durable_input.initial_execution_state,
230        service_client,
231        batcher_config,
232        CHECKPOINT_QUEUE_BUFFER,
233    );
234    let state = Arc::new(state);
235
236    // Spawn the checkpoint batcher task
237    let batcher_handle = tokio::spawn(async move {
238        batcher.run().await;
239    });
240
241    // Create DurableContext and call the handler, racing against timeout
242    let durable_ctx = DurableContext::from_lambda_context(state.clone(), lambda_context);
243
244    let output = tokio::select! {
245        result = handler(user_event, durable_ctx) => {
246            // Handler completed normally (Req 5.3)
247            process_result(result, &state, &durable_input.durable_execution_arn).await
248        }
249        _ = termination_mgr.wait_for_timeout() => {
250            // Timeout approaching — flush pending checkpoints and return PENDING (Req 5.2)
251            DurableExecutionInvocationOutput::pending()
252        }
253    };
254
255    // Drop the state to close the checkpoint queue and stop the batcher
256    drop(state);
257
258    // Wait for batcher to finish (with timeout)
259    let _ = tokio::time::timeout(
260        std::time::Duration::from_secs(BATCHER_DRAIN_TIMEOUT_SECS),
261        batcher_handle,
262    )
263    .await;
264
265    Ok(output)
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lambda::InitialExecutionState;
    use crate::operation::{ExecutionDetails, Operation};
    use serde::Deserialize;

    /// ARN shared by every fixture in this module.
    const TEST_ARN: &str = "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc";

    /// Event type with required fields, so `null` cannot deserialize into it.
    #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
    struct TestEvent {
        order_id: String,
        amount: f64,
    }

    /// Builds an invocation input with the given top-level input and operations.
    fn make_input(
        input: Option<serde_json::Value>,
        operations: Vec<Operation>,
    ) -> DurableExecutionInvocationInput {
        let initial_execution_state = InitialExecutionState {
            operations,
            next_marker: None,
        };
        DurableExecutionInvocationInput {
            durable_execution_arn: TEST_ARN.to_string(),
            checkpoint_token: "token".to_string(),
            initial_execution_state,
            input,
        }
    }

    /// Builds an EXECUTION operation (`exec-1`) with the given details.
    fn execution_op(details: Option<ExecutionDetails>) -> Operation {
        let mut op = Operation::new("exec-1", OperationType::Execution);
        op.execution_details = details;
        op
    }

    /// Builds an EXECUTION operation whose details carry the given payload.
    fn execution_op_with_payload(payload: Option<&str>) -> Operation {
        execution_op(Some(ExecutionDetails {
            input_payload: payload.map(|p| p.to_string()),
        }))
    }

    /// Builds an `ExecutionState` backed by the mock service client.
    fn mock_state() -> Arc<ExecutionState> {
        let client = Arc::new(crate::client::MockDurableServiceClient::new());
        Arc::new(ExecutionState::new(
            TEST_ARN,
            "token",
            InitialExecutionState::new(),
            client,
        ))
    }

    // ========================================================================
    // extract_event tests
    // ========================================================================

    #[test]
    fn test_extract_event_from_top_level_input() {
        let top_level = serde_json::json!({"order_id": "ORD-1", "amount": 99.99});
        let input = make_input(Some(top_level), vec![]);

        let event: TestEvent = extract_event(&input).unwrap();

        assert_eq!(event.order_id, "ORD-1");
        assert_eq!(event.amount, 99.99);
    }

    #[test]
    fn test_extract_event_from_execution_details_payload() {
        let op = execution_op_with_payload(Some(r#"{"order_id":"ORD-2","amount":50.0}"#));
        let input = make_input(None, vec![op]);

        let event: TestEvent = extract_event(&input).unwrap();

        assert_eq!(event.order_id, "ORD-2");
        assert_eq!(event.amount, 50.0);
    }

    #[test]
    fn test_extract_event_falls_back_to_null_for_option() {
        let input = make_input(None, vec![]);

        let event: Option<TestEvent> = extract_event(&input).unwrap();

        assert!(event.is_none());
    }

    #[test]
    fn test_extract_event_fails_when_no_input_and_type_requires_fields() {
        let input = make_input(None, vec![]);

        let result: Result<TestEvent, _> = extract_event(&input);
        assert!(result.is_err());

        let output = result.unwrap_err();
        assert!(output.is_failed());
        let message = output.error.unwrap().error_message;
        assert!(message.contains("does not support default"));
    }

    #[test]
    fn test_extract_event_top_level_input_takes_priority() {
        let op = execution_op_with_payload(Some(r#"{"order_id":"FROM-PAYLOAD","amount":1.0}"#));
        let top_level = serde_json::json!({"order_id": "FROM-INPUT", "amount": 2.0});
        let input = make_input(Some(top_level), vec![op]);

        let event: TestEvent = extract_event(&input).unwrap();

        assert_eq!(event.order_id, "FROM-INPUT");
    }

    #[test]
    fn test_extract_event_bad_top_level_input_returns_error() {
        let input = make_input(Some(serde_json::json!({"wrong_field": true})), vec![]);

        let result: Result<TestEvent, _> = extract_event(&input);
        assert!(result.is_err());

        let output = result.unwrap_err();
        assert!(output.is_failed());
        let message = output.error.unwrap().error_message;
        assert!(message.contains("Failed to deserialize event from Input"));
    }

    #[test]
    fn test_extract_event_bad_payload_returns_error() {
        let op = execution_op_with_payload(Some("not valid json"));
        let input = make_input(None, vec![op]);

        let result: Result<TestEvent, _> = extract_event(&input);
        assert!(result.is_err());

        let output = result.unwrap_err();
        let message = output.error.unwrap().error_message;
        assert!(message.contains("ExecutionDetails.InputPayload"));
    }

    #[test]
    fn test_extract_event_execution_op_without_details_falls_back() {
        // EXECUTION operation present, but no execution_details set
        let input = make_input(None, vec![execution_op(None)]);

        let event: Option<TestEvent> = extract_event(&input).unwrap();

        assert!(event.is_none());
    }

    #[test]
    fn test_extract_event_execution_op_without_payload_falls_back() {
        // execution_details present, but input_payload is None
        let input = make_input(None, vec![execution_op_with_payload(None)]);

        let event: Option<TestEvent> = extract_event(&input).unwrap();

        assert!(event.is_none());
    }

    // ========================================================================
    // process_result tests
    // ========================================================================

    #[tokio::test]
    async fn test_process_result_success() {
        let state = mock_state();

        let output = process_result(Ok("hello"), &state, "test-arn").await;

        assert!(output.is_succeeded());
        assert_eq!(output.result.unwrap(), "\"hello\"");
    }

    #[tokio::test]
    async fn test_process_result_suspend() {
        let state = mock_state();
        let suspended: Result<String, DurableError> = Err(DurableError::suspend());

        let output = process_result(suspended, &state, "test-arn").await;

        assert!(output.is_pending());
    }

    #[tokio::test]
    async fn test_process_result_error() {
        let state = mock_state();
        let failure: Result<String, DurableError> = Err(DurableError::execution("something broke"));

        let output = process_result(failure, &state, "test-arn").await;

        assert!(output.is_failed());
        let message = output.error.unwrap().error_message;
        assert!(message.contains("something broke"));
    }
}