dapr_durabletask/task/
completable_task.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll, Waker};
6
7use crate::api::{DurableTaskError, FailureDetails};
8
9#[derive(Debug, Clone)]
11pub enum TaskResult {
12 Completed(Option<String>),
14 Failed(FailureDetails),
16}
17
18struct CompletableTaskInner {
19 result: Option<TaskResult>,
20 waker: Option<Waker>,
21 completed_during_replay: bool,
25 replay_handle: Option<Arc<AtomicBool>>,
27}
28
29#[derive(Clone)]
35pub struct CompletableTask {
36 inner: Arc<Mutex<CompletableTaskInner>>,
37}
38
39impl CompletableTask {
40 pub fn new() -> Self {
41 Self {
42 inner: Arc::new(Mutex::new(CompletableTaskInner {
43 result: None,
44 waker: None,
45 completed_during_replay: true,
46 replay_handle: None,
47 })),
48 }
49 }
50
51 pub(crate) fn set_replay_handle(&self, handle: Arc<AtomicBool>) {
54 let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
55 inner.replay_handle = Some(handle);
56 }
57
58 pub fn complete(&self, result: Option<String>) {
60 self.complete_with_phase(result, true);
61 }
62
63 pub(crate) fn complete_with_phase(&self, result: Option<String>, during_replay: bool) {
66 let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
67 inner.result = Some(TaskResult::Completed(result));
68 inner.completed_during_replay = during_replay;
69 if let Some(waker) = inner.waker.take() {
70 waker.wake();
71 }
72 }
73
74 pub fn fail(&self, details: FailureDetails) {
76 self.fail_with_phase(details, true);
77 }
78
79 pub(crate) fn fail_with_phase(&self, details: FailureDetails, during_replay: bool) {
82 let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
83 inner.result = Some(TaskResult::Failed(details));
84 inner.completed_during_replay = during_replay;
85 if let Some(waker) = inner.waker.take() {
86 waker.wake();
87 }
88 }
89
90 pub fn is_complete(&self) -> bool {
92 let inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
93 inner.result.is_some()
94 }
95
96 pub fn is_failed(&self) -> bool {
98 let inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
99 matches!(inner.result, Some(TaskResult::Failed(_)))
100 }
101
102 pub fn get_result(&self) -> Option<TaskResult> {
104 let inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
105 inner.result.clone()
106 }
107
108 pub(crate) fn ptr_eq(&self, other: &Self) -> bool {
110 Arc::ptr_eq(&self.inner, &other.inner)
111 }
112}
113
114impl Default for CompletableTask {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120impl Future for CompletableTask {
121 type Output = crate::api::Result<Option<String>>;
122
123 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
124 let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
125 match &inner.result {
126 Some(TaskResult::Completed(value)) => {
127 let value = value.clone();
128 if !inner.completed_during_replay
129 && let Some(handle) = inner.replay_handle.as_ref()
130 {
131 handle.store(false, Ordering::Release);
132 }
133 Poll::Ready(Ok(value))
134 }
135 Some(TaskResult::Failed(details)) => {
136 let details = details.clone();
137 if !inner.completed_during_replay
138 && let Some(handle) = inner.replay_handle.as_ref()
139 {
140 handle.store(false, Ordering::Release);
141 }
142 Poll::Ready(Err(DurableTaskError::TaskFailed {
143 message: details.message.clone(),
144 failure_details: Some(details),
145 }))
146 }
147 None => {
148 inner.waker = Some(cx.waker().clone());
149 Poll::Pending
150 }
151 }
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use std::task::{Wake, Waker};

    /// Returns a waker whose wake-up is a no-op.
    fn noop_waker() -> Waker {
        Waker::noop().clone()
    }

    /// Builds failure details with the given message and error type and no stack trace.
    fn failure(message: &str, error_type: &str) -> FailureDetails {
        FailureDetails {
            message: message.to_string(),
            error_type: error_type.to_string(),
            stack_trace: None,
        }
    }

    /// Polls a fresh clone of `task` once using a no-op waker.
    fn poll_once(task: &CompletableTask) -> Poll<<CompletableTask as Future>::Output> {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut handle = task.clone();
        Pin::new(&mut handle).poll(&mut cx)
    }

    /// Waker that records how many times it has been woken.
    #[derive(Default)]
    struct CountingWaker {
        wakes: AtomicUsize,
    }

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.wakes.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_new_task_is_not_complete() {
        let task = CompletableTask::new();
        assert!(!task.is_complete());
        assert!(!task.is_failed());
        assert!(task.get_result().is_none());
    }

    #[test]
    fn test_complete_task() {
        let task = CompletableTask::new();
        task.complete(Some("42".to_string()));
        assert!(task.is_complete());
        assert!(!task.is_failed());
        match task.get_result() {
            Some(TaskResult::Completed(v)) => assert_eq!(v, Some("42".to_string())),
            _ => panic!("expected Completed"),
        }
    }

    #[test]
    fn test_fail_task() {
        let task = CompletableTask::new();
        task.fail(failure("boom", "Error"));
        assert!(task.is_complete());
        assert!(task.is_failed());
    }

    #[test]
    fn test_poll_pending_then_ready() {
        let task = CompletableTask::new();
        assert!(poll_once(&task).is_pending());

        task.complete(Some("\"hello\"".to_string()));

        match poll_once(&task) {
            Poll::Ready(Ok(v)) => assert_eq!(v, Some("\"hello\"".to_string())),
            other => panic!("expected Ready(Ok), got {other:?}"),
        }
    }

    #[test]
    fn test_poll_failed() {
        let task = CompletableTask::new();
        task.fail(failure("oops", "TestError"));

        match poll_once(&task) {
            Poll::Ready(Err(DurableTaskError::TaskFailed { message, .. })) => {
                assert_eq!(message, "oops");
            }
            other => panic!("expected Ready(Err(TaskFailed)), got {other:?}"),
        }
    }

    #[test]
    fn test_clone_shares_state() {
        let task = CompletableTask::new();
        let clone = task.clone();
        task.complete(Some("shared".to_string()));
        assert!(clone.is_complete());
    }

    #[test]
    fn test_pending_poll_is_woken_on_completion() {
        let task = CompletableTask::new();
        let counter = Arc::new(CountingWaker::default());
        let waker = Waker::from(Arc::clone(&counter));
        let mut cx = Context::from_waker(&waker);

        let mut handle = task.clone();
        assert!(Pin::new(&mut handle).poll(&mut cx).is_pending());
        assert_eq!(counter.wakes.load(Ordering::SeqCst), 0);

        task.complete(None);
        assert_eq!(counter.wakes.load(Ordering::SeqCst), 1);
        assert!(matches!(
            Pin::new(&mut handle).poll(&mut cx),
            Poll::Ready(Ok(None))
        ));
    }

    #[test]
    fn test_completion_outside_replay_clears_replay_flag() {
        let task = CompletableTask::new();
        let replaying = Arc::new(AtomicBool::new(true));
        task.set_replay_handle(Arc::clone(&replaying));

        task.complete_with_phase(Some("1".to_string()), false);
        assert!(
            replaying.load(Ordering::Acquire),
            "the flag is only cleared when the task is polled"
        );

        assert!(poll_once(&task).is_ready());
        assert!(!replaying.load(Ordering::Acquire));
    }

    #[test]
    fn test_completion_during_replay_keeps_replay_flag() {
        let task = CompletableTask::new();
        let replaying = Arc::new(AtomicBool::new(true));
        task.set_replay_handle(Arc::clone(&replaying));

        task.complete(Some("1".to_string()));
        assert!(poll_once(&task).is_ready());
        assert!(replaying.load(Ordering::Acquire));
    }

    #[test]
    fn test_failure_outside_replay_clears_replay_flag() {
        let task = CompletableTask::new();
        let replaying = Arc::new(AtomicBool::new(true));
        task.set_replay_handle(Arc::clone(&replaying));

        task.fail_with_phase(failure("late", "Error"), false);
        match poll_once(&task) {
            Poll::Ready(Err(DurableTaskError::TaskFailed { message, .. })) => {
                assert_eq!(message, "late");
            }
            other => panic!("expected Ready(Err(TaskFailed)), got {other:?}"),
        }
        assert!(!replaying.load(Ordering::Acquire));
    }

    #[test]
    fn test_ptr_eq_identifies_shared_state() {
        let task = CompletableTask::new();
        assert!(task.ptr_eq(&task.clone()));
        assert!(!task.ptr_eq(&CompletableTask::new()));
    }
}