durable_execution_sdk_testing/
run_future.rs1use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use crate::error::TestError;
13use crate::test_result::TestResult;
14
15enum RunFutureInner<O> {
17 Spawned(tokio::task::JoinHandle<Result<TestResult<O>, TestError>>),
19 Inline(Pin<Box<dyn Future<Output = Result<TestResult<O>, TestError>>>>),
21}
22
23pub struct RunFuture<O> {
47 inner: RunFutureInner<O>,
48}
49
50impl<O> RunFuture<O> {
51 pub fn new(handle: tokio::task::JoinHandle<Result<TestResult<O>, TestError>>) -> Self {
53 Self {
54 inner: RunFutureInner::Spawned(handle),
55 }
56 }
57
58 pub fn from_future(
63 future: Pin<Box<dyn Future<Output = Result<TestResult<O>, TestError>>>>,
64 ) -> Self {
65 Self {
66 inner: RunFutureInner::Inline(future),
67 }
68 }
69}
70
71impl<O> Future for RunFuture<O> {
72 type Output = Result<TestResult<O>, TestError>;
73
74 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75 match &mut self.inner {
76 RunFutureInner::Spawned(handle) => {
77 let handle = Pin::new(handle);
79 match handle.poll(cx) {
80 Poll::Ready(Ok(result)) => Poll::Ready(result),
81 Poll::Ready(Err(join_error)) => Poll::Ready(Err(
82 TestError::CheckpointServerError(format!("Task failed: {}", join_error)),
83 )),
84 Poll::Pending => Poll::Pending,
85 }
86 }
87 RunFutureInner::Inline(future) => future.as_mut().poll(cx),
88 }
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_result::TestResult;
    use crate::types::ExecutionStatus;

    /// Returns the message of a `CheckpointServerError`, failing the calling
    /// test on any other outcome.
    #[track_caller]
    fn expect_checkpoint_error(result: Result<TestResult<String>, TestError>) -> String {
        match result {
            Err(TestError::CheckpointServerError(msg)) => msg,
            other => panic!("Expected CheckpointServerError, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_run_future_success() {
        let handle =
            tokio::spawn(async { Ok(TestResult::<String>::success("hello".to_string(), vec![])) });
        let future = RunFuture::new(handle);
        let result = future.await.unwrap();
        assert_eq!(result.get_status(), ExecutionStatus::Succeeded);
        assert_eq!(result.get_result().unwrap(), "hello");
    }

    #[tokio::test]
    async fn test_run_future_error() {
        let handle =
            tokio::spawn(async { Err(TestError::CheckpointServerError("test error".to_string())) });
        let future: RunFuture<String> = RunFuture::new(handle);
        let msg = expect_checkpoint_error(future.await);
        // The task's own error must pass through unchanged, not be re-wrapped
        // as a "Task failed" join error.
        assert_eq!(msg, "test error");
    }

    #[tokio::test]
    async fn test_run_future_join_error() {
        let handle: tokio::task::JoinHandle<Result<TestResult<String>, TestError>> =
            tokio::spawn(async {
                panic!("task panicked");
            });
        let future = RunFuture::new(handle);
        let msg = expect_checkpoint_error(future.await);
        assert!(msg.contains("Task failed"), "unexpected message: {}", msg);
    }

    #[tokio::test]
    async fn test_run_future_cancelled_task() {
        // A task that never finishes on its own, aborted before it can run:
        // the resulting cancellation JoinError must also be mapped.
        let handle: tokio::task::JoinHandle<Result<TestResult<String>, TestError>> =
            tokio::spawn(std::future::pending());
        handle.abort();
        let msg = expect_checkpoint_error(RunFuture::new(handle).await);
        assert!(msg.starts_with("Task failed"), "unexpected message: {}", msg);
    }

    #[tokio::test]
    async fn test_run_future_from_future_success() {
        let future = RunFuture::<String>::from_future(Box::pin(async {
            Ok(TestResult::success("inline".to_string(), vec![]))
        }));
        let result = future.await.unwrap();
        assert_eq!(result.get_status(), ExecutionStatus::Succeeded);
        assert_eq!(result.get_result().unwrap(), "inline");
    }

    #[tokio::test]
    async fn test_run_future_from_future_after_pending() {
        // The inline future returns Pending once before completing, which
        // exercises re-polling through RunFuture.
        let future = RunFuture::<String>::from_future(Box::pin(async {
            tokio::task::yield_now().await;
            Ok(TestResult::success("after yield".to_string(), vec![]))
        }));
        let result = future.await.unwrap();
        assert_eq!(result.get_status(), ExecutionStatus::Succeeded);
        assert_eq!(result.get_result().unwrap(), "after yield");
    }

    #[tokio::test]
    async fn test_run_future_from_future_error() {
        let future = RunFuture::<String>::from_future(Box::pin(async {
            Err(TestError::CheckpointServerError("inline error".to_string()))
        }));
        let msg = expect_checkpoint_error(future.await);
        assert_eq!(msg, "inline error");
    }
}