use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use aion_core::{Payload, RunId, WorkflowError, WorkflowId, WorkflowStatus};
use aion_package::ContentHash;
use tokio::sync::{Mutex, watch};
use crate::durability::Recorder;
/// Whether a workflow handle is currently resident or suspended.
///
/// NOTE(review): presumably `Resident` means the workflow is loaded in memory
/// and `Suspended` means it has been unloaded — confirm against the registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Residency {
    /// The workflow is resident.
    Resident,
    /// The workflow is suspended.
    Suspended,
}

/// Alternate name for [`Residency`] as stored on a workflow handle.
pub type HandleResidency = Residency;
/// Final outcome of a workflow run, published through [`CompletionNotifier`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TerminalOutcome {
    /// The run completed, carrying a payload (its result).
    Completed(Payload),
    /// The run failed with the given error.
    Failed(WorkflowError),
    /// The run was cancelled; NOTE(review): the string presumably holds a reason — confirm.
    Cancelled(String),
    /// The run timed out; NOTE(review): the string presumably holds detail — confirm.
    TimedOut(String),
    /// The run ended by continuing as a new run.
    ContinuedAsNew {
        /// Input payload handed to the new run.
        input: Payload,
        /// Workflow type for the new run.
        /// NOTE(review): `None` presumably keeps the current type — confirm.
        workflow_type: Option<String>,
        /// Run id recorded as the parent of the new run.
        parent_run_id: RunId,
    },
}
/// Publishes a workflow run's [`TerminalOutcome`] over a `tokio` watch channel.
///
/// The channel retains its latest value. A receiver obtained through
/// [`CompletionNotifier::subscribe`] after [`CompletionNotifier::notify`]
/// therefore still observes the outcome. Clones share the same channel.
#[derive(Clone, Debug)]
pub struct CompletionNotifier {
    sender: watch::Sender<Option<TerminalOutcome>>,
}

impl CompletionNotifier {
    /// Creates a notifier with no outcome recorded yet.
    #[must_use]
    pub fn new() -> Self {
        // The initial receiver is not needed: the sender keeps the value and
        // can hand out fresh receivers on demand.
        let (sender, _) = watch::channel(None);
        Self { sender }
    }

    /// Returns a receiver that observes the current and any future outcome.
    #[must_use]
    pub fn subscribe(&self) -> watch::Receiver<Option<TerminalOutcome>> {
        self.sender.subscribe()
    }

    /// Stores `outcome` and wakes receivers, replacing any previously stored outcome.
    ///
    /// Uses `send_replace`, which stores the value even when no receiver exists.
    pub fn notify(&self, outcome: TerminalOutcome) {
        let previous = self.sender.send_replace(Some(outcome));
        drop(previous);
    }

    /// Returns `true` once an outcome has been stored via [`Self::notify`].
    #[must_use]
    pub fn is_completed(&self) -> bool {
        let current = self.sender.borrow();
        current.is_some()
    }
}

impl Default for CompletionNotifier {
    /// Equivalent to [`CompletionNotifier::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Two notifiers are equal exactly when they publish into the same channel,
/// for example when one is a clone of the other.
impl PartialEq for CompletionNotifier {
    fn eq(&self, other: &Self) -> bool {
        other.sender.same_channel(&self.sender)
    }
}

impl Eq for CompletionNotifier {}
/// Constructor arguments consumed by [`WorkflowHandle::new`].
pub struct WorkflowHandleParts {
    /// Identifier of the workflow.
    pub workflow_id: WorkflowId,
    /// Identifier of this particular run.
    pub run_id: RunId,
    /// Numeric process id for the run.
    /// NOTE(review): presumably the runtime process hosting the workflow — confirm.
    pub pid: u64,
    /// Name of the workflow type.
    pub workflow_type: String,
    /// Content hash of the loaded workflow package version.
    pub loaded_version: ContentHash,
    /// Initial value of the handle's cached status.
    pub cached_status: WorkflowStatus,
    /// Initial residency of the handle.
    pub residency: Residency,
    /// Durability recorder; `WorkflowHandle::new` moves it behind a shared mutex.
    pub recorder: Recorder,
    /// Notifier used to publish the run's terminal outcome.
    pub completion: CompletionNotifier,
}
/// Registry-side handle for one workflow run.
///
/// Descriptive fields are copied on clone. The recorder, the completion
/// channel, the ordinal counters and the signal tallies sit behind shared
/// pointers, so every clone observes and mutates the same state.
#[derive(Clone)]
pub struct WorkflowHandle {
    workflow_id: WorkflowId,
    run_id: RunId,
    pid: u64,
    workflow_type: String,
    loaded_version: ContentHash,
    // Snapshot of the status; updated only through `replace_projected_status`.
    cached_status: WorkflowStatus,
    residency: Residency,
    recorder: Arc<Mutex<Recorder>>,
    completion: CompletionNotifier,
    // Monotonic counters, all starting at zero in `WorkflowHandle::new`.
    deterministic_nif_sequence: Arc<AtomicU64>,
    activity_ordinal_sequence: Arc<AtomicU64>,
    timer_ordinal_sequence: Arc<AtomicU64>,
    child_ordinal_sequence: Arc<AtomicU64>,
    // Per-signal-name counts of consumed receives / completed sends.
    signal_receive_counts: Arc<dashmap::DashMap<String, u64>>,
    signal_send_counts: Arc<dashmap::DashMap<String, u64>>,
}
impl WorkflowHandle {
#[must_use]
pub fn new(parts: WorkflowHandleParts) -> Self {
Self {
workflow_id: parts.workflow_id,
run_id: parts.run_id,
pid: parts.pid,
workflow_type: parts.workflow_type,
loaded_version: parts.loaded_version,
cached_status: parts.cached_status,
residency: parts.residency,
recorder: Arc::new(Mutex::new(parts.recorder)),
completion: parts.completion,
deterministic_nif_sequence: Arc::new(AtomicU64::new(0)),
activity_ordinal_sequence: Arc::new(AtomicU64::new(0)),
timer_ordinal_sequence: Arc::new(AtomicU64::new(0)),
child_ordinal_sequence: Arc::new(AtomicU64::new(0)),
signal_receive_counts: Arc::new(dashmap::DashMap::new()),
signal_send_counts: Arc::new(dashmap::DashMap::new()),
}
}
#[must_use]
pub fn allocate_activity_ordinals(&self, count: u64) -> u64 {
self.activity_ordinal_sequence
.fetch_add(count, std::sync::atomic::Ordering::SeqCst)
}
#[must_use]
pub fn allocate_child_ordinals(&self, count: u64) -> u64 {
self.child_ordinal_sequence
.fetch_add(count, std::sync::atomic::Ordering::SeqCst)
}
#[must_use]
pub fn allocate_timer_ordinals(&self, count: u64) -> u64 {
self.timer_ordinal_sequence
.fetch_add(count, std::sync::atomic::Ordering::SeqCst)
}
#[must_use]
pub fn activity_ordinals_allocated(&self) -> u64 {
self.activity_ordinal_sequence
.load(std::sync::atomic::Ordering::SeqCst)
}
#[must_use]
pub fn timer_ordinals_allocated(&self) -> u64 {
self.timer_ordinal_sequence
.load(std::sync::atomic::Ordering::SeqCst)
}
#[must_use]
pub fn child_ordinals_allocated(&self) -> u64 {
self.child_ordinal_sequence
.load(std::sync::atomic::Ordering::SeqCst)
}
#[must_use]
pub fn signal_receives_consumed(&self, name: &str) -> u64 {
self.signal_receive_counts
.get(name)
.map_or(0, |entry| *entry)
}
pub fn mark_signal_receive_consumed(&self, name: &str) {
*self
.signal_receive_counts
.entry(name.to_owned())
.or_insert(0) += 1;
}
#[must_use]
pub fn signal_sends_completed(&self, name: &str) -> u64 {
self.signal_send_counts.get(name).map_or(0, |entry| *entry)
}
pub fn mark_signal_send_completed(&self, name: &str) {
*self.signal_send_counts.entry(name.to_owned()).or_insert(0) += 1;
}
#[must_use]
pub const fn workflow_id(&self) -> &WorkflowId {
&self.workflow_id
}
#[must_use]
pub const fn run_id(&self) -> &RunId {
&self.run_id
}
#[must_use]
pub const fn pid(&self) -> u64 {
self.pid
}
#[must_use]
pub fn workflow_type(&self) -> &str {
&self.workflow_type
}
#[must_use]
pub const fn loaded_version(&self) -> &ContentHash {
&self.loaded_version
}
#[must_use]
pub const fn cached_status(&self) -> WorkflowStatus {
self.cached_status
}
#[must_use]
pub const fn residency(&self) -> Residency {
self.residency
}
#[must_use]
pub fn recorder(&self) -> Arc<Mutex<Recorder>> {
Arc::clone(&self.recorder)
}
#[must_use]
pub const fn completion(&self) -> &CompletionNotifier {
&self.completion
}
#[must_use]
pub fn next_deterministic_nif_sequence(&self) -> u64 {
self.deterministic_nif_sequence
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
pub(in crate::registry) const fn replace_projected_status(&mut self, status: WorkflowStatus) {
self.cached_status = status;
}
pub(in crate::registry) const fn replace_residency(&mut self, residency: Residency) {
self.residency = residency;
}
}
/// Debug output lists the descriptive fields and the completion notifier.
/// The recorder, counters and signal tallies are omitted, so the output ends
/// with `..` via `finish_non_exhaustive`.
impl std::fmt::Debug for WorkflowHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WorkflowHandle");
        builder.field("workflow_id", &self.workflow_id);
        builder.field("run_id", &self.run_id);
        builder.field("pid", &self.pid);
        builder.field("workflow_type", &self.workflow_type);
        builder.field("loaded_version", &self.loaded_version);
        builder.field("cached_status", &self.cached_status);
        builder.field("residency", &self.residency);
        builder.field("completion", &self.completion);
        builder.finish_non_exhaustive()
    }
}
/// Identity-style equality for workflow handles.
///
/// Two handles are equal when their descriptive fields match AND they share
/// every piece of per-handle state by pointer. The shared state is:
/// - the recorder mutex,
/// - the completion channel,
/// - the deterministic-NIF counter,
/// - the activity, timer and child ordinal counters,
/// - both signal tallies.
///
/// In practice this means one handle is a clone of the other. Comparing all
/// shared pointers, rather than only the NIF counter, ensures equal handles
/// can never allocate ordinals or count signals from different storage.
impl PartialEq for WorkflowHandle {
    fn eq(&self, other: &Self) -> bool {
        self.workflow_id == other.workflow_id
            && self.run_id == other.run_id
            && self.pid == other.pid
            && self.workflow_type == other.workflow_type
            && self.loaded_version == other.loaded_version
            && self.cached_status == other.cached_status
            && self.residency == other.residency
            && Arc::ptr_eq(&self.recorder, &other.recorder)
            && self.completion == other.completion
            && Arc::ptr_eq(
                &self.deterministic_nif_sequence,
                &other.deterministic_nif_sequence,
            )
            && Arc::ptr_eq(
                &self.activity_ordinal_sequence,
                &other.activity_ordinal_sequence,
            )
            && Arc::ptr_eq(&self.timer_ordinal_sequence, &other.timer_ordinal_sequence)
            && Arc::ptr_eq(&self.child_ordinal_sequence, &other.child_ordinal_sequence)
            && Arc::ptr_eq(&self.signal_receive_counts, &other.signal_receive_counts)
            && Arc::ptr_eq(&self.signal_send_counts, &other.signal_send_counts)
    }
}

/// Equality above is reflexive: a handle shares all of its state with itself.
impl Eq for WorkflowHandle {}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{CompletionNotifier, TerminalOutcome};

    /// Builds a small JSON payload tagged with `label`.
    fn payload(label: &str) -> Result<aion_core::Payload, aion_core::PayloadError> {
        aion_core::Payload::from_json(&json!({ "label": label }))
    }

    #[test]
    fn completion_notifier_stores_outcome_without_active_receiver()
    -> Result<(), aion_core::PayloadError> {
        let notifier = CompletionNotifier::new();
        // Drop the only receiver so the notification is sent with nobody listening.
        drop(notifier.subscribe());

        let outcome = TerminalOutcome::Completed(payload("completed")?);
        notifier.notify(outcome.clone());

        // A subscriber created after the fact must still see the stored outcome.
        let late_receiver = notifier.subscribe();
        assert_eq!(*late_receiver.borrow(), Some(outcome));
        assert!(notifier.is_completed());
        Ok(())
    }
}