use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use crate::api::{DurableTaskError, FailureDetails};
#[derive(Debug, Clone)]
pub enum TaskResult {
Completed(Option<String>),
Failed(FailureDetails),
}
struct CompletableTaskInner {
result: Option<TaskResult>,
waker: Option<Waker>,
completed_during_replay: bool,
replay_handle: Option<Arc<AtomicBool>>,
}
#[derive(Clone)]
pub struct CompletableTask {
inner: Arc<Mutex<CompletableTaskInner>>,
}
impl CompletableTask {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(CompletableTaskInner {
result: None,
waker: None,
completed_during_replay: true,
replay_handle: None,
})),
}
}
pub(crate) fn set_replay_handle(&self, handle: Arc<AtomicBool>) {
let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
inner.replay_handle = Some(handle);
}
pub fn complete(&self, result: Option<String>) {
self.complete_with_phase(result, true);
}
pub(crate) fn complete_with_phase(&self, result: Option<String>, during_replay: bool) {
let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
inner.result = Some(TaskResult::Completed(result));
inner.completed_during_replay = during_replay;
if let Some(waker) = inner.waker.take() {
waker.wake();
}
}
pub fn fail(&self, details: FailureDetails) {
self.fail_with_phase(details, true);
}
pub(crate) fn fail_with_phase(&self, details: FailureDetails, during_replay: bool) {
let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
inner.result = Some(TaskResult::Failed(details));
inner.completed_during_replay = during_replay;
if let Some(waker) = inner.waker.take() {
waker.wake();
}
}
pub fn is_complete(&self) -> bool {
let inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
inner.result.is_some()
}
pub fn is_failed(&self) -> bool {
let inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
matches!(inner.result, Some(TaskResult::Failed(_)))
}
pub fn get_result(&self) -> Option<TaskResult> {
let inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
inner.result.clone()
}
pub(crate) fn ptr_eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.inner, &other.inner)
}
}
impl Default for CompletableTask {
fn default() -> Self {
Self::new()
}
}
impl Future for CompletableTask {
type Output = crate::api::Result<Option<String>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
match &inner.result {
Some(TaskResult::Completed(value)) => {
let value = value.clone();
if !inner.completed_during_replay
&& let Some(handle) = inner.replay_handle.as_ref()
{
handle.store(false, Ordering::Release);
}
Poll::Ready(Ok(value))
}
Some(TaskResult::Failed(details)) => {
let details = details.clone();
if !inner.completed_during_replay
&& let Some(handle) = inner.replay_handle.as_ref()
{
handle.store(false, Ordering::Release);
}
Poll::Ready(Err(DurableTaskError::TaskFailed {
message: details.message.clone(),
failure_details: Some(details),
}))
}
None => {
inner.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::task::Waker;
fn noop_waker() -> Waker {
Waker::noop().clone()
}
#[test]
fn test_new_task_is_not_complete() {
let task = CompletableTask::new();
assert!(!task.is_complete());
assert!(!task.is_failed());
assert!(task.get_result().is_none());
}
#[test]
fn test_complete_task() {
let task = CompletableTask::new();
task.complete(Some("42".to_string()));
assert!(task.is_complete());
assert!(!task.is_failed());
match task.get_result() {
Some(TaskResult::Completed(v)) => assert_eq!(v, Some("42".to_string())),
_ => panic!("expected Completed"),
}
}
#[test]
fn test_fail_task() {
let task = CompletableTask::new();
let details = FailureDetails {
message: "boom".to_string(),
error_type: "Error".to_string(),
stack_trace: None,
};
task.fail(details);
assert!(task.is_complete());
assert!(task.is_failed());
}
#[test]
fn test_poll_pending_then_ready() {
let task = CompletableTask::new();
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let mut t = task.clone();
assert!(Pin::new(&mut t).poll(&mut cx).is_pending());
task.complete(Some("\"hello\"".to_string()));
let mut t2 = task.clone();
match Pin::new(&mut t2).poll(&mut cx) {
Poll::Ready(Ok(v)) => assert_eq!(v, Some("\"hello\"".to_string())),
other => panic!("expected Ready(Ok), got {other:?}"),
}
}
#[test]
fn test_poll_failed() {
let task = CompletableTask::new();
let details = FailureDetails {
message: "oops".to_string(),
error_type: "TestError".to_string(),
stack_trace: None,
};
task.fail(details);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let mut t = task.clone();
match Pin::new(&mut t).poll(&mut cx) {
Poll::Ready(Err(DurableTaskError::TaskFailed { message, .. })) => {
assert_eq!(message, "oops");
}
other => panic!("expected Ready(Err(TaskFailed)), got {other:?}"),
}
}
#[test]
fn test_clone_shares_state() {
let task = CompletableTask::new();
let clone = task.clone();
task.complete(Some("shared".to_string()));
assert!(clone.is_complete());
}
}