1use std::collections::VecDeque;
26use std::sync::Mutex;
27
28use async_trait::async_trait;
29
30use crate::{
31 CheckpointResponse, DurableError, DurableServiceClient, GetOperationsResponse, Operation,
32 OperationUpdate,
33};
34
35#[derive(Debug, Clone)]
64pub struct CheckpointCall {
65 pub durable_execution_arn: String,
67 pub checkpoint_token: String,
69 pub operations: Vec<OperationUpdate>,
71}
72
73impl CheckpointCall {
74 pub fn new(
76 durable_execution_arn: impl Into<String>,
77 checkpoint_token: impl Into<String>,
78 operations: Vec<OperationUpdate>,
79 ) -> Self {
80 Self {
81 durable_execution_arn: durable_execution_arn.into(),
82 checkpoint_token: checkpoint_token.into(),
83 operations,
84 }
85 }
86}
87
/// Record of one `get_operations` call captured by [`MockDurableServiceClient`].
///
/// The struct holds only owned strings, so it also derives `PartialEq`/`Eq`.
/// Recorded calls can then be compared directly against expected values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GetOperationsCall {
    /// ARN supplied as `durable_execution_arn`.
    pub durable_execution_arn: String,
    /// Pagination marker supplied as `next_marker`.
    pub next_marker: String,
}

impl GetOperationsCall {
    /// Builds a record from the call's arguments, converting both into owned `String`s.
    pub fn new(durable_execution_arn: impl Into<String>, next_marker: impl Into<String>) -> Self {
        Self {
            durable_execution_arn: durable_execution_arn.into(),
            next_marker: next_marker.into(),
        }
    }
}
109
110pub struct MockDurableServiceClient {
159 checkpoint_responses: Mutex<VecDeque<Result<CheckpointResponse, DurableError>>>,
161 get_operations_responses: Mutex<VecDeque<Result<GetOperationsResponse, DurableError>>>,
163 checkpoint_calls: Mutex<Vec<CheckpointCall>>,
165 get_operations_calls: Mutex<Vec<GetOperationsCall>>,
167}
168
169impl MockDurableServiceClient {
170 pub fn new() -> Self {
175 Self {
176 checkpoint_responses: Mutex::new(VecDeque::new()),
177 get_operations_responses: Mutex::new(VecDeque::new()),
178 checkpoint_calls: Mutex::new(Vec::new()),
179 get_operations_calls: Mutex::new(Vec::new()),
180 }
181 }
182
183 pub fn with_checkpoint_response(
199 self,
200 response: Result<CheckpointResponse, DurableError>,
201 ) -> Self {
202 self.checkpoint_responses
203 .lock()
204 .unwrap()
205 .push_back(response);
206 self
207 }
208
209 pub fn with_checkpoint_responses(self, count: usize) -> Self {
223 let mut responses = self.checkpoint_responses.lock().unwrap();
224 for i in 0..count {
225 responses.push_back(Ok(CheckpointResponse::new(format!("token-{}", i))));
226 }
227 drop(responses);
228 self
229 }
230
231 pub fn with_checkpoint_response_with_operations(
247 self,
248 token: impl Into<String>,
249 operations: Vec<Operation>,
250 ) -> Self {
251 use durable_execution_sdk::client::NewExecutionState;
252
253 let response = CheckpointResponse {
254 checkpoint_token: token.into(),
255 new_execution_state: Some(NewExecutionState {
256 operations,
257 next_marker: None,
258 }),
259 };
260 self.checkpoint_responses
261 .lock()
262 .unwrap()
263 .push_back(Ok(response));
264 self
265 }
266
267 pub fn with_get_operations_response(
286 self,
287 response: Result<GetOperationsResponse, DurableError>,
288 ) -> Self {
289 self.get_operations_responses
290 .lock()
291 .unwrap()
292 .push_back(response);
293 self
294 }
295
296 pub fn get_checkpoint_calls(&self) -> Vec<CheckpointCall> {
316 self.checkpoint_calls.lock().unwrap().clone()
317 }
318
319 pub fn get_get_operations_calls(&self) -> Vec<GetOperationsCall> {
324 self.get_operations_calls.lock().unwrap().clone()
325 }
326
327 pub fn clear_checkpoint_calls(&self) {
331 self.checkpoint_calls.lock().unwrap().clear();
332 }
333
334 pub fn clear_get_operations_calls(&self) {
336 self.get_operations_calls.lock().unwrap().clear();
337 }
338
339 pub fn clear_all_calls(&self) {
341 self.clear_checkpoint_calls();
342 self.clear_get_operations_calls();
343 }
344
345 pub fn checkpoint_call_count(&self) -> usize {
347 self.checkpoint_calls.lock().unwrap().len()
348 }
349
350 pub fn get_operations_call_count(&self) -> usize {
352 self.get_operations_calls.lock().unwrap().len()
353 }
354}
355
356impl Default for MockDurableServiceClient {
357 fn default() -> Self {
358 Self::new()
359 }
360}
361
362impl std::fmt::Debug for MockDurableServiceClient {
363 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 f.debug_struct("MockDurableServiceClient")
365 .field(
366 "checkpoint_responses_remaining",
367 &self.checkpoint_responses.lock().unwrap().len(),
368 )
369 .field(
370 "get_operations_responses_remaining",
371 &self.get_operations_responses.lock().unwrap().len(),
372 )
373 .field(
374 "checkpoint_calls_count",
375 &self.checkpoint_calls.lock().unwrap().len(),
376 )
377 .field(
378 "get_operations_calls_count",
379 &self.get_operations_calls.lock().unwrap().len(),
380 )
381 .finish()
382 }
383}
384
385#[async_trait]
386impl DurableServiceClient for MockDurableServiceClient {
387 async fn checkpoint(
388 &self,
389 durable_execution_arn: &str,
390 checkpoint_token: &str,
391 operations: Vec<OperationUpdate>,
392 ) -> Result<CheckpointResponse, DurableError> {
393 self.checkpoint_calls
395 .lock()
396 .unwrap()
397 .push(CheckpointCall::new(
398 durable_execution_arn,
399 checkpoint_token,
400 operations,
401 ));
402
403 let mut responses = self.checkpoint_responses.lock().unwrap();
405 if let Some(response) = responses.pop_front() {
406 response
407 } else {
408 Ok(CheckpointResponse::new("mock-token"))
410 }
411 }
412
413 async fn get_operations(
414 &self,
415 durable_execution_arn: &str,
416 next_marker: &str,
417 ) -> Result<GetOperationsResponse, DurableError> {
418 self.get_operations_calls
420 .lock()
421 .unwrap()
422 .push(GetOperationsCall::new(durable_execution_arn, next_marker));
423
424 let mut responses = self.get_operations_responses.lock().unwrap();
426 if let Some(response) = responses.pop_front() {
427 response
428 } else {
429 Ok(GetOperationsResponse {
431 operations: Vec::new(),
432 next_marker: None,
433 })
434 }
435 }
436}
437
/// Example-based tests for `MockDurableServiceClient`: scripted responses,
/// default fallbacks, and call recording.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OperationType;

    // With nothing queued, `checkpoint` falls back to the "mock-token" response.
    #[tokio::test]
    async fn test_mock_client_default_checkpoint_response() {
        let client = MockDurableServiceClient::new();
        let result = client
            .checkpoint("arn:test", "token-123", vec![])
            .await
            .unwrap();
        assert_eq!(result.checkpoint_token, "mock-token");
    }

    // A single queued response is returned in place of the default.
    #[tokio::test]
    async fn test_mock_client_custom_checkpoint_response() {
        let client = MockDurableServiceClient::new()
            .with_checkpoint_response(Ok(CheckpointResponse::new("custom-token")));

        let result = client
            .checkpoint("arn:test", "token-123", vec![])
            .await
            .unwrap();
        assert_eq!(result.checkpoint_token, "custom-token");
    }

    // Queued responses are consumed in the order they were added (FIFO).
    #[tokio::test]
    async fn test_mock_client_checkpoint_response_order() {
        let client = MockDurableServiceClient::new()
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-1")))
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-2")))
            .with_checkpoint_response(Ok(CheckpointResponse::new("token-3")));

        let r1 = client.checkpoint("arn:test", "t", vec![]).await.unwrap();
        let r2 = client.checkpoint("arn:test", "t", vec![]).await.unwrap();
        let r3 = client.checkpoint("arn:test", "t", vec![]).await.unwrap();

        assert_eq!(r1.checkpoint_token, "token-1");
        assert_eq!(r2.checkpoint_token, "token-2");
        assert_eq!(r3.checkpoint_token, "token-3");
    }

    // A queued `Err` is passed through unchanged, so its retriable flag survives.
    #[tokio::test]
    async fn test_mock_client_checkpoint_error_response() {
        let client = MockDurableServiceClient::new()
            .with_checkpoint_response(Err(DurableError::checkpoint_retriable("Test error")));

        let result = client.checkpoint("arn:test", "token-123", vec![]).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_retriable());
    }

    // The ARN, token and operation updates of a call are all recorded.
    #[tokio::test]
    async fn test_mock_client_records_checkpoint_calls() {
        let client = MockDurableServiceClient::new();

        client
            .checkpoint(
                "arn:aws:lambda:us-east-1:123456789012:function:test",
                "token-123",
                vec![OperationUpdate::start("op-1", OperationType::Step)],
            )
            .await
            .unwrap();

        let calls = client.get_checkpoint_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(
            calls[0].durable_execution_arn,
            "arn:aws:lambda:us-east-1:123456789012:function:test"
        );
        assert_eq!(calls[0].checkpoint_token, "token-123");
        assert_eq!(calls[0].operations.len(), 1);
        assert_eq!(calls[0].operations[0].operation_id, "op-1");
    }

    // Recorded calls keep arrival order.
    #[tokio::test]
    async fn test_mock_client_records_multiple_checkpoint_calls() {
        let client = MockDurableServiceClient::new();

        client
            .checkpoint("arn:test-1", "token-1", vec![])
            .await
            .unwrap();
        client
            .checkpoint("arn:test-2", "token-2", vec![])
            .await
            .unwrap();
        client
            .checkpoint("arn:test-3", "token-3", vec![])
            .await
            .unwrap();

        let calls = client.get_checkpoint_calls();
        assert_eq!(calls.len(), 3);
        assert_eq!(calls[0].checkpoint_token, "token-1");
        assert_eq!(calls[1].checkpoint_token, "token-2");
        assert_eq!(calls[2].checkpoint_token, "token-3");
    }

    // `clear_checkpoint_calls` resets the recorded-call count to zero.
    #[tokio::test]
    async fn test_mock_client_clear_checkpoint_calls() {
        let client = MockDurableServiceClient::new();

        client
            .checkpoint("arn:test", "token", vec![])
            .await
            .unwrap();
        assert_eq!(client.checkpoint_call_count(), 1);

        client.clear_checkpoint_calls();
        assert_eq!(client.checkpoint_call_count(), 0);
    }

    // With nothing queued, `get_operations` returns no operations and no marker.
    #[tokio::test]
    async fn test_mock_client_default_get_operations_response() {
        let client = MockDurableServiceClient::new();
        let result = client
            .get_operations("arn:test", "marker-123")
            .await
            .unwrap();
        assert!(result.operations.is_empty());
        assert!(result.next_marker.is_none());
    }

    // A queued get-operations response is returned as-is.
    #[tokio::test]
    async fn test_mock_client_custom_get_operations_response() {
        let client = MockDurableServiceClient::new().with_get_operations_response(Ok(
            GetOperationsResponse {
                operations: vec![Operation::new("op-1", OperationType::Step)],
                next_marker: Some("next-marker".to_string()),
            },
        ));

        let result = client
            .get_operations("arn:test", "marker-123")
            .await
            .unwrap();
        assert_eq!(result.operations.len(), 1);
        assert_eq!(result.operations[0].operation_id, "op-1");
        assert_eq!(result.next_marker, Some("next-marker".to_string()));
    }

    // The ARN and marker of a `get_operations` call are recorded.
    #[tokio::test]
    async fn test_mock_client_records_get_operations_calls() {
        let client = MockDurableServiceClient::new();

        client
            .get_operations("arn:test", "marker-123")
            .await
            .unwrap();

        let calls = client.get_get_operations_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].durable_execution_arn, "arn:test");
        assert_eq!(calls[0].next_marker, "marker-123");
    }

    // `with_checkpoint_responses(n)` queues tokens "token-0" .. "token-{n-1}".
    #[tokio::test]
    async fn test_mock_client_with_checkpoint_responses() {
        let client = MockDurableServiceClient::new().with_checkpoint_responses(3);

        let r1 = client.checkpoint("arn:test", "t", vec![]).await.unwrap();
        let r2 = client.checkpoint("arn:test", "t", vec![]).await.unwrap();
        let r3 = client.checkpoint("arn:test", "t", vec![]).await.unwrap();

        assert_eq!(r1.checkpoint_token, "token-0");
        assert_eq!(r2.checkpoint_token, "token-1");
        assert_eq!(r3.checkpoint_token, "token-2");
    }

    // Operations attached via `with_checkpoint_response_with_operations`
    // come back inside `new_execution_state`, including callback details.
    #[tokio::test]
    async fn test_mock_client_with_checkpoint_response_with_operations() {
        let mut op = Operation::new("callback-1", OperationType::Callback);
        op.callback_details = Some(durable_execution_sdk::CallbackDetails {
            callback_id: Some("cb-123".to_string()),
            result: None,
            error: None,
        });

        let client = MockDurableServiceClient::new()
            .with_checkpoint_response_with_operations("token-1", vec![op]);

        let result = client.checkpoint("arn:test", "t", vec![]).await.unwrap();

        assert_eq!(result.checkpoint_token, "token-1");
        let state = result.new_execution_state.unwrap();
        assert_eq!(state.operations.len(), 1);
        assert_eq!(state.operations[0].operation_id, "callback-1");
        assert_eq!(
            state.operations[0]
                .callback_details
                .as_ref()
                .unwrap()
                .callback_id,
            Some("cb-123".to_string())
        );
    }

    // Once the queue is drained, `checkpoint` reverts to the default response.
    #[tokio::test]
    async fn test_mock_client_falls_back_to_default_after_configured_responses() {
        let client = MockDurableServiceClient::new()
            .with_checkpoint_response(Ok(CheckpointResponse::new("configured-token")));

        // First call consumes the configured response.
        let r1 = client.checkpoint("arn:test", "t", vec![]).await.unwrap();
        assert_eq!(r1.checkpoint_token, "configured-token");

        // Second call finds the queue empty and gets the default.
        let r2 = client.checkpoint("arn:test", "t", vec![]).await.unwrap();
        assert_eq!(r2.checkpoint_token, "mock-token");
    }

    // The Debug output names the type and the remaining-response counter.
    #[test]
    fn test_mock_client_debug() {
        let client = MockDurableServiceClient::new()
            .with_checkpoint_response(Ok(CheckpointResponse::new("token")));

        let debug_str = format!("{:?}", client);
        assert!(debug_str.contains("MockDurableServiceClient"));
        assert!(debug_str.contains("checkpoint_responses_remaining"));
    }

    // `Default` yields a mock with no recorded calls.
    #[test]
    fn test_mock_client_default() {
        let client = MockDurableServiceClient::default();
        assert_eq!(client.checkpoint_call_count(), 0);
    }
}
668
/// Property-based tests (proptest) for `MockDurableServiceClient`.
///
/// Each property drives the async client on a dedicated tokio runtime that
/// is created per test case.
#[cfg(test)]
mod property_tests {
    use super::*;
    use crate::OperationType;
    use proptest::prelude::*;

    /// Generates 1 to 10 checkpoint tokens, each 1-20 characters from `[a-zA-Z0-9_-]`.
    fn token_sequence_strategy() -> impl Strategy<Value = Vec<String>> {
        prop::collection::vec("[a-zA-Z0-9_-]{1,20}", 1..=10)
    }

    /// Generates the arguments of one `checkpoint` call as a tuple.
    ///
    /// The tuple is an ARN-like string, a checkpoint token, and 0 to 5
    /// `OperationUpdate::start` entries with random ids and operation types.
    fn checkpoint_call_strategy() -> impl Strategy<Value = (String, String, Vec<OperationUpdate>)> {
        (
            // ARN-like string: 10-50 characters including ':' and '/'.
            "[a-zA-Z0-9:/_-]{10,50}",
            // Checkpoint token.
            "[a-zA-Z0-9_-]{1,20}",
            // Operation updates: (id, type) pairs mapped through `OperationUpdate::start`.
            prop::collection::vec(
                (
                    "[a-zA-Z0-9_-]{1,20}",
                    prop_oneof![
                        Just(OperationType::Step),
                        Just(OperationType::Wait),
                        Just(OperationType::Callback),
                        Just(OperationType::Invoke),
                        Just(OperationType::Context),
                    ],
                )
                .prop_map(|(id, op_type)| OperationUpdate::start(id, op_type)),
                0..=5,
            ),
        )
    }

    proptest! {
        /// Property: queued checkpoint responses are returned in exactly the
        /// order they were queued.
        #[test]
        fn prop_mock_client_response_order(tokens in token_sequence_strategy()) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                // Queue one response per generated token.
                let mut client = MockDurableServiceClient::new();
                for token in &tokens {
                    client = client.with_checkpoint_response(Ok(CheckpointResponse::new(token.clone())));
                }

                // Drain exactly as many responses as were queued.
                let mut received_tokens = Vec::new();
                for _ in 0..tokens.len() {
                    let response = client.checkpoint("arn:test", "t", vec![]).await.unwrap();
                    received_tokens.push(response.checkpoint_token);
                }

                prop_assert_eq!(received_tokens, tokens);
                // The async block returns a Result so `prop_assert_eq!` can short-circuit.
                Ok(())
            })?;
        }

        /// Property: every `checkpoint` call is recorded in order, with its ARN,
        /// token, and each operation's id and type preserved.
        #[test]
        fn prop_mock_client_call_recording(
            calls in prop::collection::vec(checkpoint_call_strategy(), 1..=10)
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let client = MockDurableServiceClient::new();

                // Only the recording matters here, so the responses are ignored.
                for (arn, token, ops) in &calls {
                    let _ = client.checkpoint(arn, token, ops.clone()).await;
                }

                let recorded_calls = client.get_checkpoint_calls();

                // One record per call made.
                prop_assert_eq!(recorded_calls.len(), calls.len());

                // Pairwise comparison of expected vs. recorded calls, in order.
                for (i, ((expected_arn, expected_token, expected_ops), recorded)) in
                    calls.iter().zip(recorded_calls.iter()).enumerate()
                {
                    prop_assert_eq!(
                        &recorded.durable_execution_arn,
                        expected_arn,
                        "Call {} ARN mismatch",
                        i
                    );
                    prop_assert_eq!(
                        &recorded.checkpoint_token,
                        expected_token,
                        "Call {} token mismatch",
                        i
                    );
                    prop_assert_eq!(
                        recorded.operations.len(),
                        expected_ops.len(),
                        "Call {} operations count mismatch",
                        i
                    );

                    // Field-wise check per operation. Only id and type are compared.
                    for (j, (expected_op, recorded_op)) in
                        expected_ops.iter().zip(recorded.operations.iter()).enumerate()
                    {
                        prop_assert_eq!(
                            &recorded_op.operation_id,
                            &expected_op.operation_id,
                            "Call {} operation {} ID mismatch",
                            i,
                            j
                        );
                        prop_assert_eq!(
                            recorded_op.operation_type,
                            expected_op.operation_type,
                            "Call {} operation {} type mismatch",
                            i,
                            j
                        );
                    }
                }

                Ok(())
            })?;
        }
    }
}