use std::num::NonZeroU32;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use pureflow_types::{ExecutionId, NodeId, WorkflowId};
/// One-based attempt number of an execution; zero is unrepresentable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ExecutionAttempt(NonZeroU32);
impl ExecutionAttempt {
    /// Wraps an attempt number that is already known to be non-zero.
    #[must_use]
    pub const fn new(value: NonZeroU32) -> Self {
        Self(value)
    }
    /// The first attempt, numbered `1`.
    #[must_use]
    pub const fn first() -> Self {
        Self::new(NonZeroU32::MIN)
    }
    /// Returns the attempt number as a plain `u32` (always at least `1`).
    #[must_use]
    pub const fn get(self) -> u32 {
        let Self(number) = self;
        number.get()
    }
}
/// Pairs an execution identifier with the attempt number of that execution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionMetadata {
    execution_id: ExecutionId,
    attempt: ExecutionAttempt,
}
impl ExecutionMetadata {
    /// Builds metadata for an explicit `attempt` of `execution_id`.
    #[must_use]
    pub const fn new(execution_id: ExecutionId, attempt: ExecutionAttempt) -> Self {
        Self {
            execution_id,
            attempt,
        }
    }
    /// Builds metadata for attempt number `1` of `execution_id`.
    #[must_use]
    pub const fn first_attempt(execution_id: ExecutionId) -> Self {
        let attempt: ExecutionAttempt = ExecutionAttempt::first();
        Self {
            execution_id,
            attempt,
        }
    }
    /// Borrows the execution identifier.
    #[must_use]
    pub const fn execution_id(&self) -> &ExecutionId {
        &self.execution_id
    }
    /// Returns the attempt number (a `Copy` value).
    #[must_use]
    pub const fn attempt(&self) -> ExecutionAttempt {
        let attempt: ExecutionAttempt = self.attempt;
        attempt
    }
}
/// A request to cancel, carrying a free-form reason string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CancellationRequest {
    reason: String,
}
impl CancellationRequest {
    /// Creates a request whose reason is `reason` converted into an owned `String`.
    #[must_use]
    pub fn new(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        Self { reason }
    }
    /// Borrows the reason given to [`CancellationRequest::new`].
    #[must_use]
    pub fn reason(&self) -> &str {
        self.reason.as_str()
    }
}
/// Point-in-time snapshot of whether cancellation has been requested.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CancellationState {
    /// No cancellation has been requested.
    Active,
    /// Cancellation was requested; holds the recorded request.
    Requested(CancellationRequest),
}
impl CancellationState {
    /// Returns `true` for [`CancellationState::Requested`], `false` for `Active`.
    #[must_use]
    pub const fn is_requested(&self) -> bool {
        match self {
            Self::Active => false,
            Self::Requested(_) => true,
        }
    }
}
/// Storage for the first cancellation request, shared by every clone of a
/// [`CancellationToken`] through an [`Arc`].
#[derive(Debug, Default)]
struct CancellationSignal {
    request: Mutex<Option<CancellationRequest>>,
}
impl CancellationSignal {
    /// Locks the stored request, recovering the guard if the mutex was poisoned.
    ///
    /// Recovering is sound here: the guarded value is an `Option` that is only
    /// ever replaced by a single assignment, so a panicking holder cannot leave
    /// it half-updated.
    fn lock(&self) -> MutexGuard<'_, Option<CancellationRequest>> {
        self.request.lock().unwrap_or_else(PoisonError::into_inner)
    }
}
/// Cloneable view of a shared cancellation signal.
///
/// Clones point at the same signal. A request recorded through any of them is
/// visible to all of them, including one recorded via [`CancellationHandle::cancel`].
#[derive(Debug, Clone)]
pub struct CancellationToken {
    signal: Arc<CancellationSignal>,
}
impl CancellationToken {
    /// Creates a token whose signal has not been cancelled.
    #[must_use]
    pub fn active() -> Self {
        Self {
            signal: Arc::new(CancellationSignal::default()),
        }
    }
    /// Creates a token that is already cancelled with `request`.
    ///
    /// The token owns a fresh signal, so it is not linked to any other token.
    #[must_use]
    pub fn cancelled(request: CancellationRequest) -> Self {
        // Seed the signal directly rather than locking a fresh mutex and
        // discarding the "first request" flag.
        Self {
            signal: Arc::new(CancellationSignal {
                request: Mutex::new(Some(request)),
            }),
        }
    }
    /// Returns a copy of the recorded request, or `None` while still active.
    #[must_use]
    pub fn request(&self) -> Option<CancellationRequest> {
        self.signal.lock().clone()
    }
    /// Returns the current state as a [`CancellationState`] snapshot.
    #[must_use]
    pub fn state(&self) -> CancellationState {
        self.request()
            .map_or(CancellationState::Active, CancellationState::Requested)
    }
    /// Returns `true` once a cancellation request has been recorded.
    ///
    /// Inspects the stored request in place, so polling does not clone the
    /// request's reason string.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.signal.lock().is_some()
    }
    /// Records `request` unless a request was already recorded.
    ///
    /// Returns `true` if this call stored `request`. Returns `false` if an
    /// earlier request was already present; in that case the earlier request
    /// is left unchanged.
    fn request_cancellation(&self, request: CancellationRequest) -> bool {
        let mut slot: MutexGuard<'_, Option<CancellationRequest>> = self.signal.lock();
        if slot.is_some() {
            return false;
        }
        *slot = Some(request);
        true
    }
}
impl Default for CancellationToken {
fn default() -> Self {
Self::active()
}
}
impl PartialEq for CancellationToken {
fn eq(&self, other: &Self) -> bool {
self.request() == other.request()
}
}
impl Eq for CancellationToken {}
/// Controller side of a cancellation signal.
///
/// Tokens obtained from [`CancellationHandle::token`] share this handle's
/// signal, so [`CancellationHandle::cancel`] is observed by all of them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CancellationHandle {
    token: CancellationToken,
}
impl CancellationHandle {
    /// Creates a handle whose signal has not been cancelled.
    #[must_use]
    pub fn new() -> Self {
        let token: CancellationToken = CancellationToken::active();
        Self { token }
    }
    /// Returns a token sharing this handle's signal.
    #[must_use]
    pub fn token(&self) -> CancellationToken {
        CancellationToken::clone(&self.token)
    }
    /// Requests cancellation; only the first request is kept.
    ///
    /// Returns `true` if this call recorded `request`, `false` if a request
    /// had already been recorded.
    #[must_use]
    pub fn cancel(&self, request: CancellationRequest) -> bool {
        let token: &CancellationToken = &self.token;
        token.request_cancellation(request)
    }
    /// Returns `true` once cancellation has been requested.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        let token: &CancellationToken = &self.token;
        token.is_cancelled()
    }
    /// Returns a copy of the recorded request, if any.
    #[must_use]
    pub fn request(&self) -> Option<CancellationRequest> {
        let token: &CancellationToken = &self.token;
        token.request()
    }
}
impl Default for CancellationHandle {
fn default() -> Self {
Self::new()
}
}
/// Context handed to a node.
///
/// Holds the workflow id, node id and execution metadata, plus the
/// cancellation token the node can check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NodeContext {
    workflow_id: WorkflowId,
    node_id: NodeId,
    execution: ExecutionMetadata,
    cancellation: CancellationToken,
}
impl NodeContext {
    /// Creates a context with its own fresh, uncancelled token.
    #[must_use]
    pub fn new(workflow_id: WorkflowId, node_id: NodeId, execution: ExecutionMetadata) -> Self {
        let cancellation: CancellationToken = CancellationToken::active();
        Self {
            workflow_id,
            node_id,
            execution,
            cancellation,
        }
    }
    /// Replaces the token with a new, already-cancelled token carrying `request`.
    ///
    /// The replacement token has its own signal, so holders of the previous
    /// token do not observe this cancellation.
    #[must_use]
    pub fn with_cancellation(self, request: CancellationRequest) -> Self {
        self.with_cancellation_token(CancellationToken::cancelled(request))
    }
    /// Replaces the token with `token`.
    ///
    /// For example, `token` may come from a [`CancellationHandle`] or a parent context.
    #[must_use]
    pub fn with_cancellation_token(self, token: CancellationToken) -> Self {
        Self {
            cancellation: token,
            ..self
        }
    }
    /// Borrows the workflow identifier.
    #[must_use]
    pub const fn workflow_id(&self) -> &WorkflowId {
        &self.workflow_id
    }
    /// Borrows the node identifier.
    #[must_use]
    pub const fn node_id(&self) -> &NodeId {
        &self.node_id
    }
    /// Borrows the execution metadata.
    #[must_use]
    pub const fn execution(&self) -> &ExecutionMetadata {
        &self.execution
    }
    /// Returns a snapshot of the token's cancellation state.
    #[must_use]
    pub fn cancellation(&self) -> CancellationState {
        let token: &CancellationToken = &self.cancellation;
        token.state()
    }
    /// Returns a clone of the token; it shares this context's signal.
    #[must_use]
    pub fn cancellation_token(&self) -> CancellationToken {
        CancellationToken::clone(&self.cancellation)
    }
    /// Returns `true` once cancellation has been requested on the token.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        let token: &CancellationToken = &self.cancellation;
        token.is_cancelled()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn execution_id(value: &str) -> ExecutionId {
        ExecutionId::new(value).expect("valid execution id")
    }

    fn node_id(value: &str) -> NodeId {
        NodeId::new(value).expect("valid node id")
    }

    fn workflow_id(value: &str) -> WorkflowId {
        WorkflowId::new(value).expect("valid workflow id")
    }

    fn execution() -> ExecutionMetadata {
        ExecutionMetadata::first_attempt(execution_id("run-1"))
    }

    // Fresh context for `node` in workflow "flow", first attempt of "run-1".
    fn context_for(node: &str) -> NodeContext {
        NodeContext::new(workflow_id("flow"), node_id(node), execution())
    }

    #[test]
    fn first_execution_attempt_is_one_based() {
        let first: ExecutionAttempt = ExecutionAttempt::first();
        assert_eq!(first.get(), 1);
    }

    #[test]
    fn node_context_starts_active_and_can_carry_cancellation() {
        let ctx: NodeContext = context_for("node");
        assert!(!ctx.is_cancelled());
        assert_eq!(ctx.cancellation(), CancellationState::Active);

        let reason: &str = "shutdown requested";
        let cancelled: NodeContext = ctx.with_cancellation(CancellationRequest::new(reason));
        assert!(cancelled.is_cancelled());
        assert_eq!(
            cancelled.cancellation(),
            CancellationState::Requested(CancellationRequest::new(reason))
        );
    }

    #[test]
    fn shared_cancellation_handle_reaches_parent_and_child_contexts() {
        let handle: CancellationHandle = CancellationHandle::new();
        let parent: NodeContext = context_for("parent").with_cancellation_token(handle.token());
        let child: NodeContext =
            context_for("child").with_cancellation_token(parent.cancellation_token());

        for ctx in [&parent, &child] {
            assert!(!ctx.is_cancelled());
        }

        // Only the first request is recorded; the duplicate is rejected.
        assert!(handle.cancel(CancellationRequest::new("supervisor shutdown")));
        assert!(!handle.cancel(CancellationRequest::new("ignored duplicate")));

        for ctx in [&parent, &child] {
            assert!(ctx.is_cancelled());
        }
        assert_eq!(
            child.cancellation(),
            CancellationState::Requested(CancellationRequest::new("supervisor shutdown"))
        );
    }
}