1use std::future::Future;
21use std::sync::Arc;
22
23use serde::de::DeserializeOwned;
24use serde::Serialize;
25
26use crate::client::{LambdaDurableServiceClient, SharedDurableServiceClient};
27use crate::context::DurableContext;
28use crate::error::{DurableError, ErrorObject};
29use crate::lambda::{DurableExecutionInvocationInput, DurableExecutionInvocationOutput};
30use crate::operation::OperationType;
31use crate::state::{CheckpointBatcherConfig, ExecutionState};
32use crate::termination::TerminationManager;
33
34const SDK_NAME: &str = "durable-execution-sdk-rust";
36
37const SDK_VERSION: &str = env!("CARGO_PKG_VERSION");
39
40const MAX_RESPONSE_SIZE: usize = 6 * 1024 * 1024;
42
43const CHECKPOINT_QUEUE_BUFFER: usize = 100;
45
46const BATCHER_DRAIN_TIMEOUT_SECS: u64 = 5;
48
49pub fn extract_event<E: DeserializeOwned>(
61 input: &DurableExecutionInvocationInput,
62) -> Result<E, DurableExecutionInvocationOutput> {
63 if let Some(value) = &input.input {
65 return serde_json::from_value(value.clone()).map_err(|e| {
66 DurableExecutionInvocationOutput::failed(ErrorObject::new(
67 "DeserializationError",
68 format!("Failed to deserialize event from Input: {}", e),
69 ))
70 });
71 }
72
73 let execution_op = input
75 .initial_execution_state
76 .operations
77 .iter()
78 .find(|op| op.operation_type == OperationType::Execution);
79
80 if let Some(op) = execution_op {
81 if let Some(details) = &op.execution_details {
82 if let Some(payload) = &details.input_payload {
83 return serde_json::from_str::<E>(payload).map_err(|e| {
84 DurableExecutionInvocationOutput::failed(ErrorObject::new(
85 "DeserializationError",
86 format!(
87 "Failed to deserialize event from ExecutionDetails.InputPayload: {}",
88 e
89 ),
90 ))
91 });
92 }
93 }
94 }
95
96 serde_json::from_value(serde_json::Value::Null).map_err(|_| {
98 DurableExecutionInvocationOutput::failed(ErrorObject::new(
99 "DeserializationError",
100 "No input provided and event type does not support default",
101 ))
102 })
103}
104
105async fn process_result<R: Serialize>(
112 result: Result<R, DurableError>,
113 state: &Arc<ExecutionState>,
114 durable_execution_arn: &str,
115) -> DurableExecutionInvocationOutput {
116 match result {
117 Ok(value) => match serde_json::to_string(&value) {
118 Ok(json) => {
119 if json.len() > MAX_RESPONSE_SIZE {
120 checkpoint_large_result(&json, state, durable_execution_arn).await
121 } else {
122 DurableExecutionInvocationOutput::succeeded(Some(json))
123 }
124 }
125 Err(e) => DurableExecutionInvocationOutput::failed(ErrorObject::new(
126 "SerializationError",
127 format!("Failed to serialize result: {}", e),
128 )),
129 },
130 Err(DurableError::Suspend { .. }) => DurableExecutionInvocationOutput::pending(),
131 Err(error) => DurableExecutionInvocationOutput::failed(ErrorObject::from(&error)),
132 }
133}
134
135async fn checkpoint_large_result(
137 json: &str,
138 state: &Arc<ExecutionState>,
139 durable_execution_arn: &str,
140) -> DurableExecutionInvocationOutput {
141 let result_op_id = format!(
142 "__result__{}",
143 crate::replay_safe::uuid_string_from_operation(durable_execution_arn, 0)
144 );
145
146 let update = crate::operation::OperationUpdate::succeed(
147 &result_op_id,
148 OperationType::Execution,
149 Some(json.to_string()),
150 );
151
152 match state.create_checkpoint(update, true).await {
153 Ok(()) => DurableExecutionInvocationOutput::checkpointed_result(&result_op_id, json.len()),
154 Err(e) => DurableExecutionInvocationOutput::failed(ErrorObject::new(
155 "CheckpointError",
156 format!("Failed to checkpoint large result: {}", e),
157 )),
158 }
159}
160
161pub async fn run_durable_handler<E, R, Fut, F>(
195 lambda_event: lambda_runtime::LambdaEvent<DurableExecutionInvocationInput>,
196 handler: F,
197) -> Result<DurableExecutionInvocationOutput, lambda_runtime::Error>
198where
199 E: DeserializeOwned,
200 R: Serialize,
201 Fut: Future<Output = Result<R, DurableError>>,
202 F: FnOnce(E, DurableContext) -> Fut,
203{
204 let (durable_input, lambda_context) = lambda_event.into_parts();
205
206 let user_event: E = match extract_event(&durable_input) {
208 Ok(event) => event,
209 Err(output) => return Ok(output),
210 };
211
212 let termination_mgr = TerminationManager::from_lambda_context(&lambda_context);
214
215 let aws_config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
217 let service_client: SharedDurableServiceClient =
218 Arc::new(LambdaDurableServiceClient::from_aws_config_with_user_agent(
219 &aws_config,
220 SDK_NAME,
221 SDK_VERSION,
222 )?);
223
224 let batcher_config = CheckpointBatcherConfig::default();
226 let (state, mut batcher) = ExecutionState::with_batcher(
227 &durable_input.durable_execution_arn,
228 &durable_input.checkpoint_token,
229 durable_input.initial_execution_state,
230 service_client,
231 batcher_config,
232 CHECKPOINT_QUEUE_BUFFER,
233 );
234 let state = Arc::new(state);
235
236 let batcher_handle = tokio::spawn(async move {
238 batcher.run().await;
239 });
240
241 let durable_ctx = DurableContext::from_lambda_context(state.clone(), lambda_context);
243
244 let output = tokio::select! {
245 result = handler(user_event, durable_ctx) => {
246 process_result(result, &state, &durable_input.durable_execution_arn).await
248 }
249 _ = termination_mgr.wait_for_timeout() => {
250 DurableExecutionInvocationOutput::pending()
252 }
253 };
254
255 drop(state);
257
258 let _ = tokio::time::timeout(
260 std::time::Duration::from_secs(BATCHER_DRAIN_TIMEOUT_SECS),
261 batcher_handle,
262 )
263 .await;
264
265 Ok(output)
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lambda::InitialExecutionState;
    use crate::operation::{ExecutionDetails, Operation};
    use serde::Deserialize;

    /// Durable execution ARN shared by the fixtures in this module.
    const TEST_ARN: &str = "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc";

    /// Event type with required fields, so a `null` fallback cannot satisfy it.
    #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
    struct TestEvent {
        order_id: String,
        amount: f64,
    }

    /// Builds an invocation input with an optional top-level `input` and the given operations.
    fn make_input(
        input: Option<serde_json::Value>,
        operations: Vec<Operation>,
    ) -> DurableExecutionInvocationInput {
        DurableExecutionInvocationInput {
            durable_execution_arn: TEST_ARN.to_string(),
            checkpoint_token: "token".to_string(),
            initial_execution_state: InitialExecutionState {
                operations,
                next_marker: None,
            },
            input,
        }
    }

    /// Builds an `Execution` operation whose details carry the given input payload.
    fn execution_op_with_payload(payload: Option<&str>) -> Operation {
        let mut op = Operation::new("exec-1", OperationType::Execution);
        op.execution_details = Some(ExecutionDetails {
            input_payload: payload.map(String::from),
        });
        op
    }

    /// Builds an execution state backed by the mock service client.
    fn make_state() -> Arc<ExecutionState> {
        let client = Arc::new(crate::client::MockDurableServiceClient::new());
        Arc::new(ExecutionState::new(
            TEST_ARN,
            "token",
            InitialExecutionState::new(),
            client,
        ))
    }

    #[test]
    fn test_extract_event_from_top_level_input() {
        let input = make_input(
            Some(serde_json::json!({"order_id": "ORD-1", "amount": 99.99})),
            vec![],
        );
        let event: TestEvent = extract_event(&input).unwrap();
        assert_eq!(
            event,
            TestEvent {
                order_id: "ORD-1".to_string(),
                amount: 99.99,
            }
        );
    }

    #[test]
    fn test_extract_event_from_execution_details_payload() {
        let op = execution_op_with_payload(Some(r#"{"order_id":"ORD-2","amount":50.0}"#));
        let event: TestEvent = extract_event(&make_input(None, vec![op])).unwrap();
        assert_eq!(
            event,
            TestEvent {
                order_id: "ORD-2".to_string(),
                amount: 50.0,
            }
        );
    }

    #[test]
    fn test_extract_event_falls_back_to_null_for_option() {
        let event: Option<TestEvent> = extract_event(&make_input(None, vec![])).unwrap();
        assert!(event.is_none());
    }

    #[test]
    fn test_extract_event_fails_when_no_input_and_type_requires_fields() {
        let output = extract_event::<TestEvent>(&make_input(None, vec![]))
            .expect_err("TestEvent cannot be built from null");
        assert!(output.is_failed());
        let error = output.error.unwrap();
        assert!(error.error_message.contains("does not support default"));
    }

    #[test]
    fn test_extract_event_top_level_input_takes_priority() {
        let op = execution_op_with_payload(Some(r#"{"order_id":"FROM-PAYLOAD","amount":1.0}"#));
        let input = make_input(
            Some(serde_json::json!({"order_id": "FROM-INPUT", "amount": 2.0})),
            vec![op],
        );
        let event: TestEvent = extract_event(&input).unwrap();
        assert_eq!(event.order_id, "FROM-INPUT");
    }

    #[test]
    fn test_extract_event_bad_top_level_input_returns_error() {
        let input = make_input(Some(serde_json::json!({"wrong_field": true})), vec![]);
        let output = extract_event::<TestEvent>(&input)
            .expect_err("missing required fields must fail");
        assert!(output.is_failed());
        let error = output.error.unwrap();
        assert!(error
            .error_message
            .contains("Failed to deserialize event from Input"));
    }

    #[test]
    fn test_extract_event_bad_payload_returns_error() {
        let op = execution_op_with_payload(Some("not valid json"));
        let output = extract_event::<TestEvent>(&make_input(None, vec![op]))
            .expect_err("invalid JSON payload must fail");
        let error = output.error.unwrap();
        assert!(error.error_message.contains("ExecutionDetails.InputPayload"));
    }

    #[test]
    fn test_extract_event_execution_op_without_details_falls_back() {
        let op = Operation::new("exec-1", OperationType::Execution);
        let event: Option<TestEvent> = extract_event(&make_input(None, vec![op])).unwrap();
        assert!(event.is_none());
    }

    #[test]
    fn test_extract_event_execution_op_without_payload_falls_back() {
        let op = execution_op_with_payload(None);
        let event: Option<TestEvent> = extract_event(&make_input(None, vec![op])).unwrap();
        assert!(event.is_none());
    }

    #[tokio::test]
    async fn test_process_result_success() {
        let state = make_state();
        let output = process_result(Ok("hello"), &state, "test-arn").await;
        assert!(output.is_succeeded());
        assert_eq!(output.result.unwrap(), "\"hello\"");
    }

    #[tokio::test]
    async fn test_process_result_suspend() {
        let state = make_state();
        let result: Result<String, DurableError> = Err(DurableError::suspend());
        let output = process_result(result, &state, "test-arn").await;
        assert!(output.is_pending());
    }

    #[tokio::test]
    async fn test_process_result_error() {
        let state = make_state();
        let result: Result<String, DurableError> = Err(DurableError::execution("something broke"));
        let output = process_result(result, &state, "test-arn").await;
        assert!(output.is_failed());
        let error = output.error.unwrap();
        assert!(error.error_message.contains("something broke"));
    }
}
460}