use std::sync::LazyLock;
use qubit_cas::{
FastCasPolicy,
FastCasState,
};
use qubit_state_machine::FastStateMachine;
use super::task_status::{
TASK_STATUS_COUNT,
TaskStatus,
};
/// Number of `TaskStatusEvent` variants; handed to the state machine builder as its
/// `event_count`, so every event discriminant must be below this value.
const TASK_STATUS_EVENT_COUNT: usize = 6;

/// Shared task status state machine, built by `build_task_status_machine` on first access
/// and reused by every `AtomicTaskStatus`.
static TASK_STATUS_MACHINE: LazyLock<FastStateMachine> =
    LazyLock::new(build_task_status_machine);
/// Events fed to the task status state machine.
///
/// Discriminants are explicit because `as usize` of a variant is used directly as the
/// event index registered with the machine; they must stay dense, start at 0, and stay
/// below `TASK_STATUS_EVENT_COUNT`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(usize)]
enum TaskStatusEvent {
    /// Registered for `Pending -> Running`.
    Start = 0,
    /// Registered for `Pending -> Cancelled`.
    CancelPending = 1,
    /// Registered for `Running -> Succeeded`.
    CompleteSucceeded = 2,
    /// Registered for `Running -> Failed`.
    CompleteFailed = 3,
    /// Registered for `Running -> Panicked`.
    CompletePanicked = 4,
    /// Registered for both `Pending -> Dropped` and `Running -> Dropped`.
    DropUnfinished = 5,
}
impl TaskStatusEvent {
    /// Returns this event's index as registered with the task status state machine.
    #[inline]
    const fn as_usize(self) -> usize {
        self as usize
    }

    /// Maps a completion status to the event that moves a `Running` task into it.
    ///
    /// Only `Succeeded`, `Failed`, and `Panicked` have a completion event; every other
    /// status (`Pending`, `Running`, `Cancelled`, `Dropped`) yields `None`.
    const fn from_completion_status(status: TaskStatus) -> Option<Self> {
        let event = match status {
            TaskStatus::Succeeded => Self::CompleteSucceeded,
            TaskStatus::Failed => Self::CompleteFailed,
            TaskStatus::Panicked => Self::CompletePanicked,
            // Not reachable through a completion event; listed explicitly so that adding
            // a new `TaskStatus` variant forces this mapping to be revisited.
            TaskStatus::Pending
            | TaskStatus::Running
            | TaskStatus::Cancelled
            | TaskStatus::Dropped => return None,
        };
        Some(event)
    }
}
/// Builds the transition table shared by every `AtomicTaskStatus`.
///
/// Registered transitions:
/// - `Pending --Start--> Running`
/// - `Pending --CancelPending--> Cancelled`
/// - `Running --CompleteSucceeded--> Succeeded`
/// - `Running --CompleteFailed--> Failed`
/// - `Running --CompletePanicked--> Panicked`
/// - `Pending | Running --DropUnfinished--> Dropped`
///
/// `Pending` is the initial state. `Succeeded`, `Failed`, `Panicked`, `Cancelled`, and
/// `Dropped` are registered as final states; no other (state, event) pair is registered.
///
/// # Panics
///
/// Panics if the builder's `build()` fails, which would mean this table itself is wrong.
fn build_task_status_machine() -> FastStateMachine {
    let [pending, running, succeeded, failed, panicked, cancelled, dropped] = [
        TaskStatus::Pending,
        TaskStatus::Running,
        TaskStatus::Succeeded,
        TaskStatus::Failed,
        TaskStatus::Panicked,
        TaskStatus::Cancelled,
        TaskStatus::Dropped,
    ]
    .map(|status| status.as_usize());

    let start = TaskStatusEvent::Start.as_usize();
    let cancel_pending = TaskStatusEvent::CancelPending.as_usize();
    let complete_succeeded = TaskStatusEvent::CompleteSucceeded.as_usize();
    let complete_failed = TaskStatusEvent::CompleteFailed.as_usize();
    let complete_panicked = TaskStatusEvent::CompletePanicked.as_usize();
    let drop_unfinished = TaskStatusEvent::DropUnfinished.as_usize();

    FastStateMachine::builder()
        .state_count(TASK_STATUS_COUNT)
        .event_count(TASK_STATUS_EVENT_COUNT)
        .initial_state(pending)
        .final_states(&[succeeded, failed, panicked, cancelled, dropped])
        // NOTE(review): `spin(16)` is the CAS contention policy; see qubit_cas for the
        // exact meaning of the 16.
        .cas_policy(FastCasPolicy::spin(16))
        .transition(pending, start, running)
        .transition(pending, cancel_pending, cancelled)
        .transition(running, complete_succeeded, succeeded)
        .transition(running, complete_failed, failed)
        .transition(running, complete_panicked, panicked)
        .transition(pending, drop_unfinished, dropped)
        .transition(running, drop_unfinished, dropped)
        .build()
        .expect("task status state machine must be valid")
}
pub(crate) struct AtomicTaskStatus {
value: FastCasState,
}
impl AtomicTaskStatus {
#[inline]
pub(crate) fn new(status: TaskStatus) -> Self {
Self {
value: FastCasState::new(status.as_usize()),
}
}
#[inline]
pub(crate) fn load(&self) -> TaskStatus {
TaskStatus::from_usize(self.value.load())
}
#[inline]
pub(crate) fn try_start(&self) -> bool {
self.try_transition(TaskStatusEvent::Start)
}
#[inline]
pub(crate) fn try_cancel_pending(&self) -> bool {
self.try_transition(TaskStatusEvent::CancelPending)
}
#[inline]
pub(crate) fn try_complete(&self, status: TaskStatus) -> bool {
let Some(event) = TaskStatusEvent::from_completion_status(status) else {
return false;
};
self.try_transition(event)
}
#[inline]
pub(crate) fn try_drop_unfinished(&self) -> bool {
self.try_transition(TaskStatusEvent::DropUnfinished)
}
#[inline]
fn try_transition(&self, event: TaskStatusEvent) -> bool {
TASK_STATUS_MACHINE.try_trigger(&self.value, event.as_usize())
}
}
#[cfg(test)]
mod task_status_event_encoding_tests {
    use super::{TASK_STATUS_EVENT_COUNT, TaskStatus, TaskStatusEvent};

    /// Every event in declaration order.
    ///
    /// The array length is the `TASK_STATUS_EVENT_COUNT` constant, so this list and the
    /// count cannot drift apart without a compile error.
    const ALL_EVENTS: [TaskStatusEvent; TASK_STATUS_EVENT_COUNT] = [
        TaskStatusEvent::Start,
        TaskStatusEvent::CancelPending,
        TaskStatusEvent::CompleteSucceeded,
        TaskStatusEvent::CompleteFailed,
        TaskStatusEvent::CompletePanicked,
        TaskStatusEvent::DropUnfinished,
    ];

    #[test]
    fn task_status_event_as_usize_matches_stable_discriminants() {
        assert_eq!(TaskStatusEvent::Start.as_usize(), 0);
        assert_eq!(TaskStatusEvent::CancelPending.as_usize(), 1);
        assert_eq!(TaskStatusEvent::CompleteSucceeded.as_usize(), 2);
        assert_eq!(TaskStatusEvent::CompleteFailed.as_usize(), 3);
        assert_eq!(TaskStatusEvent::CompletePanicked.as_usize(), 4);
        assert_eq!(TaskStatusEvent::DropUnfinished.as_usize(), 5);
    }

    // There are six events, so the codes run 0 through 5.
    #[test]
    fn task_status_event_codes_are_zero_through_five_in_declaration_order() {
        for (i, event) in ALL_EVENTS.iter().enumerate() {
            assert_eq!(event.as_usize(), i, "event index {i}");
        }
    }

    #[test]
    fn task_status_event_count_matches_variants() {
        assert_eq!(
            TaskStatusEvent::DropUnfinished as usize + 1,
            super::TASK_STATUS_EVENT_COUNT,
            "last event discriminant + 1 must equal TASK_STATUS_EVENT_COUNT"
        );
    }

    #[test]
    fn completion_statuses_map_to_completion_events() {
        assert_eq!(
            TaskStatusEvent::from_completion_status(TaskStatus::Succeeded),
            Some(TaskStatusEvent::CompleteSucceeded)
        );
        assert_eq!(
            TaskStatusEvent::from_completion_status(TaskStatus::Failed),
            Some(TaskStatusEvent::CompleteFailed)
        );
        assert_eq!(
            TaskStatusEvent::from_completion_status(TaskStatus::Panicked),
            Some(TaskStatusEvent::CompletePanicked)
        );
    }

    #[test]
    fn non_completion_statuses_have_no_completion_event() {
        for status in [
            TaskStatus::Pending,
            TaskStatus::Running,
            TaskStatus::Cancelled,
            TaskStatus::Dropped,
        ] {
            assert_eq!(TaskStatusEvent::from_completion_status(status), None);
        }
    }
}