use crate::events::SharedEventBus;
use crate::expression::ExpressionEngineRegistry;
use crate::handler::HandlerRegistry;
use crate::listener::{WorkflowEvent, WorkflowExecutionListener};
use crate::secret::SecretManager;
use crate::status::{StatusPhase, StatusPhaseLog};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use swf_core::models::task::TaskDefinition;
use swf_core::models::workflow::WorkflowDefinition;
use tokio::sync::Notify;
/// Generates the `set_*` / `get_*` / `clone_*` accessor trio for a field of type
/// `Option<Arc<T>>`, where `T` may be unsized (e.g. a `dyn Trait`).
///
/// Arguments, in order: the field name, the setter name, the borrowing getter name, the
/// `Arc`-cloning getter name, and finally the pointee type `T`.
macro_rules! arc_accessors {
    ($slot:ident, $set:ident, $get:ident, $clone_out:ident, $inner:ty) => {
        #[doc = concat!("Stores `value` in `", stringify!($slot), "`, replacing any previous value.")]
        pub fn $set(&mut self, value: Arc<$inner>) {
            self.$slot = Some(value);
        }

        #[doc = concat!("Borrows the value held in `", stringify!($slot), "`, if any.")]
        pub fn $get(&self) -> Option<&$inner> {
            self.$slot.as_ref().map(|shared| &**shared)
        }

        #[doc = concat!("Returns a new `Arc` handle to the value in `", stringify!($slot), "`, if set.")]
        pub fn $clone_out(&self) -> Option<Arc<$inner>> {
            self.$slot.as_ref().map(Arc::clone)
        }
    };
}
/// Generates the `set_*` / `get_*` / `clone_*` accessor trio for a field of type `Option<T>`.
///
/// Arguments, in order: the field name, the setter name, the borrowing getter name, the
/// cloning getter name, and finally the value type `T` (which must be `Clone`).
macro_rules! option_accessors {
    ($slot:ident, $set:ident, $get:ident, $clone_out:ident, $inner:ty) => {
        #[doc = concat!("Stores `value` in `", stringify!($slot), "`, replacing any previous value.")]
        pub fn $set(&mut self, value: $inner) {
            self.$slot = Some(value);
        }

        #[doc = concat!("Borrows the value held in `", stringify!($slot), "`, if any.")]
        pub fn $get(&self) -> Option<&$inner> {
            Option::as_ref(&self.$slot)
        }

        #[doc = concat!("Returns a clone of the value held in `", stringify!($slot), "`, if any.")]
        pub fn $clone_out(&self) -> Option<$inner> {
            self.$slot.as_ref().cloned()
        }
    };
}
/// Suspend/resume flag for a workflow execution.
///
/// Both fields live behind `Arc`, so every clone is a handle onto the same flag and the same
/// notifier: a suspension requested through one clone is observed by all of them.
#[derive(Clone)]
pub(crate) struct SuspendState {
    /// `true` while execution is suspended.
    suspended: Arc<AtomicBool>,
    /// Signalled with `notify_waiters` each time a suspension is lifted.
    resume_notify: Arc<Notify>,
}

impl SuspendState {
    /// Creates a state that starts out running (not suspended).
    pub(crate) fn new() -> Self {
        let suspended = Arc::new(AtomicBool::new(false));
        let resume_notify = Arc::new(Notify::new());
        Self {
            suspended,
            resume_notify,
        }
    }

    /// Marks execution as suspended.
    ///
    /// Returns `true` if this call performed the running -> suspended transition and `false` if
    /// the state was already suspended.
    pub fn suspend(&self) -> bool {
        let was_suspended = self.suspended.swap(true, Ordering::SeqCst);
        !was_suspended
    }

    /// Clears the suspended flag and wakes the tasks waiting on the resume notifier.
    ///
    /// Returns `true` if this call performed the suspended -> running transition. When the state
    /// was not suspended it returns `false` and sends no wake-up.
    pub fn resume(&self) -> bool {
        let was_suspended = self.suspended.swap(false, Ordering::SeqCst);
        if was_suspended {
            self.resume_notify.notify_waiters();
        }
        was_suspended
    }

    /// Reports whether execution is currently suspended.
    pub fn is_suspended(&self) -> bool {
        self.suspended.load(Ordering::SeqCst)
    }

    /// Exposes the notifier signalled by [`SuspendState::resume`].
    pub(crate) fn resume_notify(&self) -> &Arc<Notify> {
        &self.resume_notify
    }
}
use tokio_util::sync::CancellationToken;
/// Names of the built-in variables made available to runtime expressions.
///
/// These are the keys `WorkflowContext::get_vars` uses for the built-in entries.
pub mod vars {
    /// The current state input.
    pub const INPUT: &str = "$input";
    /// The current state output.
    pub const OUTPUT: &str = "$output";
    /// The workflow instance context.
    pub const CONTEXT: &str = "$context";
    /// The descriptor of the task being executed.
    pub const TASK: &str = "$task";
    /// The descriptor of the running workflow (instance id, definition, raw input).
    pub const WORKFLOW: &str = "$workflow";
    /// Information about this runtime (name and version).
    pub const RUNTIME: &str = "$runtime";
    /// Secrets provided by the configured secret manager.
    pub const SECRET: &str = "$secret";
    /// The active authorization (`scheme` and `parameter`).
    pub const AUTHORIZATION: &str = "$authorization";
}
/// Identification of this runtime, exposed to expressions as `$runtime`.
pub mod runtime_info {
    /// Human-readable name of the runtime.
    pub const NAME: &str = "CNCF Serverless Workflow Specification Rust SDK";
    /// Version of this crate, captured from `CARGO_PKG_VERSION` at compile time.
    pub const VERSION: &str = env!("CARGO_PKG_VERSION");

    /// Returns the `{"name": NAME, "version": VERSION}` object.
    ///
    /// The value is built lazily on first use and the same instance is returned afterwards.
    pub fn runtime_info_value() -> &'static serde_json::Value {
        static INFO: std::sync::LazyLock<serde_json::Value> = std::sync::LazyLock::new(|| {
            serde_json::json!({
                "name": NAME,
                "version": VERSION,
            })
        });
        &INFO
    }
}
/// Mutable state of a single workflow execution.
///
/// Holds the data exposed to runtime expressions (see `get_vars`), the optional collaborators
/// (secret manager, listener, event bus), the registries used while executing tasks, and the
/// status history of the workflow and of its tasks.
pub struct WorkflowContext {
    // Current state input; exposed as `$input` (null when unset).
    input: Option<Value>,
    // Current state output; exposed as `$output` (null when unset).
    output: Option<Value>,
    // Instance context; exposed as `$context` (null when unset).
    instance_ctx: Option<Value>,
    // `$workflow` descriptor: `id` and `definition` are set in `new`, `input` by
    // `set_raw_input`. Behind `Arc` so clones of the context share it.
    workflow_descriptor: Arc<Value>,
    // `$task` descriptor: a JSON object filled by the `set_task_*` helpers and reset by
    // `clear_task_context`.
    task_descriptor: Value,
    // Extra expression variables; merged into `get_vars` last, so they override built-ins.
    local_expr_vars: HashMap<String, Value>,
    // `$authorization` as `{"scheme", "parameter"}`; deliberately left out of the Debug output.
    authorization: Option<Value>,
    // Source of `$secret` (via `get_all_secrets`).
    secret_manager: Option<Arc<dyn SecretManager>>,
    // Receives every emitted event synchronously in `emit_event`.
    listener: Option<Arc<dyn WorkflowExecutionListener>>,
    // Receives emitted events as CloudEvents, published from a spawned task.
    event_bus: Option<SharedEventBus>,
    // Sub-workflow definitions keyed by "namespace/name/version".
    sub_workflows: HashMap<String, WorkflowDefinition>,
    // Cancellation signal for this execution.
    cancellation_token: CancellationToken,
    // Suspend flag and resume notifier.
    suspend_state: SuspendState,
    // Task handlers; NOTE(review): semantics defined in crate::handler.
    handler_registry: HandlerRegistry,
    // Expression engines; NOTE(review): semantics defined in crate::expression.
    expression_engines: ExpressionEngineRegistry,
    // Reusable task definitions looked up by name.
    functions: HashMap<String, TaskDefinition>,
    // Workflow status history; the last entry is the current status.
    status_log: Vec<StatusPhaseLog>,
    // Per-task status history keyed by task name.
    task_status: HashMap<String, Vec<StatusPhaseLog>>,
    // Iteration counters keyed by position string (see `inc_iteration`).
    iterations: HashMap<String, u32>,
    // Memoized result of `get_vars`; only trusted while `vars_dirty` is false.
    vars_cache: Mutex<Option<HashMap<String, Value>>>,
    // Set by every mutation that affects expression variables; cleared once the cache is rebuilt.
    vars_dirty: AtomicBool,
}
impl Clone for WorkflowContext {
    /// Produces a copy of the context.
    ///
    /// Owned data (input/output, descriptors, maps, status logs) is cloned. The `Arc`-held
    /// workflow descriptor and collaborators are shared with the original, and the suspend state
    /// (whose fields are `Arc`s) points at the same flag. The expression-variable cache is
    /// snapshotted together with its dirty flag.
    fn clone(&self) -> Self {
        // Read the cached map and the dirty flag while holding the cache lock, so the pair stays
        // consistent with what `get_vars` publishes (it stores the map before clearing the flag).
        let (vars_cache, vars_dirty) = {
            let guard = self.vars_cache.lock().unwrap();
            let dirty = self.vars_dirty.load(Ordering::Acquire);
            (Mutex::new(guard.clone()), AtomicBool::new(dirty))
        };

        Self {
            input: self.input.clone(),
            output: self.output.clone(),
            instance_ctx: self.instance_ctx.clone(),
            workflow_descriptor: Arc::clone(&self.workflow_descriptor),
            task_descriptor: self.task_descriptor.clone(),
            local_expr_vars: self.local_expr_vars.clone(),
            authorization: self.authorization.clone(),
            secret_manager: self.secret_manager.clone(),
            listener: self.listener.clone(),
            event_bus: self.event_bus.clone(),
            sub_workflows: self.sub_workflows.clone(),
            cancellation_token: self.cancellation_token.clone(),
            suspend_state: self.suspend_state.clone(),
            handler_registry: self.handler_registry.clone(),
            expression_engines: self.expression_engines.clone(),
            functions: self.functions.clone(),
            status_log: self.status_log.clone(),
            task_status: self.task_status.clone(),
            iterations: self.iterations.clone(),
            vars_cache,
            vars_dirty,
        }
    }
}
impl std::fmt::Debug for WorkflowContext {
    /// Formats the data-carrying fields of the context.
    ///
    /// Trait-object collaborators are shown only as `Some("...")` / `None`. The authorization,
    /// registries and sub-workflow/function maps are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let placeholder = |present: bool| present.then_some("...");

        let mut out = f.debug_struct("WorkflowContext");
        out.field("input", &self.input);
        out.field("output", &self.output);
        out.field("instance_ctx", &self.instance_ctx);
        out.field("workflow_descriptor", &self.workflow_descriptor);
        out.field("task_descriptor", &self.task_descriptor);
        out.field("local_expr_vars", &self.local_expr_vars);
        out.field(
            "secret_manager",
            &placeholder(self.secret_manager.is_some()),
        );
        out.field("listener", &placeholder(self.listener.is_some()));
        out.field("event_bus", &placeholder(self.event_bus.is_some()));
        out.field("status_log", &self.status_log);
        out.field("task_status", &self.task_status);
        out.field("iterations", &self.iterations);
        out.finish()
    }
}
impl WorkflowContext {
pub fn new(
workflow: &swf_core::models::workflow::WorkflowDefinition,
) -> crate::error::WorkflowResult<Self> {
let workflow_json = serde_json::to_value(workflow).map_err(|e| {
crate::error::WorkflowError::runtime(
format!("failed to serialize workflow definition: {}", e),
"/",
"/",
)
})?;
let workflow_descriptor = Arc::new(serde_json::json!({
"id": uuid::Uuid::new_v4().to_string(),
"definition": workflow_json,
}));
let mut ctx = Self {
input: None,
output: None,
instance_ctx: None,
workflow_descriptor,
task_descriptor: Value::Object(Default::default()),
local_expr_vars: HashMap::new(),
authorization: None,
secret_manager: None,
listener: None,
event_bus: None,
sub_workflows: HashMap::new(),
cancellation_token: CancellationToken::new(),
suspend_state: SuspendState::new(),
handler_registry: HandlerRegistry::new(),
expression_engines: ExpressionEngineRegistry::new(),
functions: HashMap::new(),
status_log: Vec::new(),
task_status: HashMap::new(),
iterations: HashMap::new(),
vars_cache: Mutex::new(None),
vars_dirty: AtomicBool::new(true),
};
ctx.set_status(StatusPhase::Pending);
Ok(ctx)
}
pub fn set_status(&mut self, status: StatusPhase) {
self.status_log.push(StatusPhaseLog::new(status));
}
pub fn instance_id(&self) -> &str {
self.workflow_descriptor
.as_object()
.and_then(|obj| obj.get("id"))
.and_then(|id| id.as_str())
.unwrap_or("unknown")
}
pub fn get_status(&self) -> StatusPhase {
self.status_log
.last()
.map(|log| log.status)
.unwrap_or(StatusPhase::Pending)
}
pub fn set_task_status(&mut self, task: &str, status: StatusPhase) {
self.task_status
.entry(task.to_string())
.or_default()
.push(StatusPhaseLog::new(status));
}
pub fn get_task_status(&self, task: &str) -> Option<StatusPhase> {
self.task_status
.get(task)
.and_then(|logs| logs.last())
.map(|log| log.status)
}
pub fn set_input(&mut self, value: Value) {
self.input = Some(value);
self.invalidate_vars_cache();
}
pub fn get_input(&self) -> Option<&Value> {
self.input.as_ref()
}
pub fn set_output(&mut self, value: Value) {
self.output = Some(value);
self.invalidate_vars_cache();
}
pub fn get_output(&self) -> Option<&Value> {
self.output.as_ref()
}
pub fn set_instance_ctx(&mut self, value: Value) {
self.instance_ctx = Some(value);
self.invalidate_vars_cache();
}
pub fn get_instance_ctx(&self) -> Option<&Value> {
self.instance_ctx.as_ref()
}
pub fn set_raw_input(&mut self, input: &Value) {
let mut desc = (*self.workflow_descriptor).clone();
if let Some(obj) = desc.as_object_mut() {
obj.insert("input".to_string(), input.clone());
}
self.workflow_descriptor = Arc::new(desc);
self.invalidate_vars_cache();
}
fn task_descriptor_insert(&mut self, key: &str, value: Value) {
if let Some(obj) = self.task_descriptor.as_object_mut() {
obj.insert(key.to_string(), value);
}
self.invalidate_vars_cache();
}
pub fn set_task_name(&mut self, name: &str) {
self.task_descriptor_insert("name", Value::String(name.to_string()));
}
pub fn set_task_raw_input(&mut self, input: &Value) {
self.task_descriptor_insert("input", input.clone());
}
pub fn set_task_raw_output(&mut self, output: &Value) {
self.task_descriptor_insert("output", output.clone());
}
pub fn set_task_started_at(&mut self) {
let now = chrono::Utc::now();
let iso8601 = now.to_rfc3339();
let epoch_seconds = now.timestamp();
let epoch_millis = now.timestamp_millis();
self.task_descriptor_insert(
"startedAt",
serde_json::json!({
"iso8601": iso8601,
"epoch": {
"seconds": epoch_seconds,
"milliseconds": epoch_millis,
}
}),
);
}
pub fn set_task_reference(&mut self, reference: &str) {
self.task_descriptor_insert("reference", Value::String(reference.to_string()));
}
pub fn get_task_reference(&self) -> Option<&str> {
self.task_descriptor
.as_object()
.and_then(|obj| obj.get("reference"))
.and_then(|v| v.as_str())
}
pub fn get_workflow_json(&self) -> Option<&Value> {
self.workflow_descriptor
.as_object()
.and_then(|obj| obj.get("definition"))
}
pub fn set_task_def(&mut self, task: &Value) {
self.task_descriptor_insert("definition", task.clone());
}
pub fn inc_iteration(&mut self, position: &str) -> u32 {
let count = self.iterations.entry(position.to_string()).or_insert(0);
*count += 1;
let value = *count;
self.task_descriptor_insert("iteration", serde_json::json!(value));
value
}
pub fn set_retry_attempt(&mut self, attempt: u32) {
self.task_descriptor_insert("retryAttempt", serde_json::json!(attempt));
}
pub fn clear_task_context(&mut self) {
self.task_descriptor = Value::Object(Default::default());
}
arc_accessors!(
secret_manager,
set_secret_manager,
get_secret_manager,
clone_secret_manager,
dyn SecretManager
);
arc_accessors!(
listener,
set_listener,
get_listener,
clone_listener,
dyn WorkflowExecutionListener
);
pub fn emit_event(&self, event: WorkflowEvent) {
if let Some(ref listener) = self.listener {
listener.on_event(&event);
}
if let Some(ref event_bus) = self.event_bus {
let cloud_event = event.to_cloud_event();
let bus = event_bus.clone();
tokio::spawn(async move {
bus.publish(cloud_event).await;
});
}
}
option_accessors!(
event_bus,
set_event_bus,
get_event_bus,
clone_event_bus,
SharedEventBus
);
pub fn set_sub_workflows(&mut self, sub_workflows: HashMap<String, WorkflowDefinition>) {
self.sub_workflows = sub_workflows;
}
pub fn get_sub_workflow(
&self,
namespace: &str,
name: &str,
version: &str,
) -> Option<&WorkflowDefinition> {
let key = format!("{}/{}/{}", namespace, name, version);
self.sub_workflows.get(&key)
}
pub fn clone_sub_workflows(&self) -> HashMap<String, WorkflowDefinition> {
self.sub_workflows.clone()
}
pub fn set_handler_registry(&mut self, registry: HandlerRegistry) {
self.handler_registry = registry;
}
pub fn get_handler_registry(&self) -> &HandlerRegistry {
&self.handler_registry
}
pub fn clone_handler_registry(&self) -> HandlerRegistry {
self.handler_registry.clone()
}
pub(crate) fn set_expression_engines(&mut self, engines: ExpressionEngineRegistry) {
self.expression_engines = engines;
}
pub(crate) fn get_expression_engines(&self) -> &ExpressionEngineRegistry {
&self.expression_engines
}
pub(crate) fn clone_expression_engines(&self) -> ExpressionEngineRegistry {
self.expression_engines.clone()
}
pub fn set_functions(&mut self, functions: HashMap<String, TaskDefinition>) {
self.functions = functions;
}
pub fn get_function(&self, name: &str) -> Option<&TaskDefinition> {
self.functions.get(name)
}
pub fn cancellation_token(&self) -> CancellationToken {
self.cancellation_token.clone()
}
pub fn cancel(&self) {
self.cancellation_token.cancel();
}
pub fn is_cancelled(&self) -> bool {
self.cancellation_token.is_cancelled()
}
pub fn suspend(&self) -> bool {
self.suspend_state.suspend()
}
pub fn resume(&self) -> bool {
self.suspend_state.resume()
}
pub fn is_suspended(&self) -> bool {
self.suspend_state.is_suspended()
}
pub async fn wait_for_resume(&self) {
if self.is_suspended() {
tokio::select! {
_ = self.suspend_state.resume_notify().notified() => {}
_ = self.cancellation_token.cancelled() => {}
}
}
}
pub(crate) fn set_suspend_state(&mut self, state: SuspendState) {
self.suspend_state = state;
}
pub fn set_authorization(&mut self, scheme: &str, parameter: &str) {
self.authorization = Some(serde_json::json!({
"scheme": scheme,
"parameter": parameter,
}));
self.invalidate_vars_cache();
}
pub fn clear_authorization(&mut self) {
self.authorization = None;
self.invalidate_vars_cache();
}
pub fn set_local_expr_vars(&mut self, vars: HashMap<String, Value>) {
self.local_expr_vars = vars;
self.invalidate_vars_cache();
}
pub fn add_local_expr_vars(&mut self, vars: HashMap<String, Value>) {
for (k, v) in vars {
self.local_expr_vars.entry(k).or_insert(v);
}
self.invalidate_vars_cache();
}
pub fn remove_local_expr_vars(&mut self, keys: &[&str]) {
for key in keys {
self.local_expr_vars.remove(*key);
}
self.invalidate_vars_cache();
}
fn invalidate_vars_cache(&self) {
self.vars_dirty.store(true, Ordering::Release);
}
pub fn get_vars(&self) -> HashMap<String, Value> {
if self.vars_dirty.load(Ordering::Acquire) {
let mut vars = HashMap::new();
vars.insert(
vars::INPUT.to_string(),
self.input.clone().unwrap_or(Value::Null),
);
vars.insert(
vars::OUTPUT.to_string(),
self.output.clone().unwrap_or(Value::Null),
);
vars.insert(
vars::CONTEXT.to_string(),
self.instance_ctx.clone().unwrap_or(Value::Null),
);
vars.insert(vars::TASK.to_string(), self.task_descriptor.clone());
vars.insert(
vars::WORKFLOW.to_string(),
(*self.workflow_descriptor).clone(),
);
vars.insert(
vars::RUNTIME.to_string(),
runtime_info::runtime_info_value().clone(),
);
if let Some(ref mgr) = self.secret_manager {
vars.insert(vars::SECRET.to_string(), mgr.get_all_secrets());
}
if let Some(ref auth) = self.authorization {
vars.insert(vars::AUTHORIZATION.to_string(), auth.clone());
}
for (k, v) in &self.local_expr_vars {
vars.insert(k.clone(), v.clone());
}
*self.vars_cache.lock().unwrap() = Some(vars);
self.vars_dirty.store(false, Ordering::Release);
}
self.vars_cache.lock().unwrap().as_ref().unwrap().clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use swf_core::models::workflow::WorkflowDefinition;

    /// Builds a context around a default (empty) workflow definition.
    fn new_context() -> WorkflowContext {
        WorkflowContext::new(&WorkflowDefinition::default()).unwrap()
    }

    #[test]
    fn test_context_new() {
        let fresh = new_context();
        assert!(fresh.get_input().is_none());
        assert!(fresh.get_output().is_none());
        assert_eq!(fresh.get_status(), StatusPhase::Pending);
    }

    #[test]
    fn test_context_set_input_output() {
        let mut context = new_context();

        context.set_input(json!({"key": "value"}));
        assert_eq!(context.get_input(), Some(&json!({"key": "value"})));

        context.set_output(json!(42));
        assert_eq!(context.get_output(), Some(&json!(42)));
    }

    #[test]
    fn test_context_status_transitions() {
        let mut context = new_context();
        assert_eq!(context.get_status(), StatusPhase::Pending);
        for phase in [StatusPhase::Running, StatusPhase::Completed] {
            context.set_status(phase);
            assert_eq!(context.get_status(), phase);
        }
    }

    #[test]
    fn test_context_instance_ctx() {
        let mut context = new_context();
        assert!(context.get_instance_ctx().is_none());

        context.set_instance_ctx(json!({"exported": "data"}));
        assert_eq!(
            context.get_instance_ctx(),
            Some(&json!({"exported": "data"}))
        );
    }

    #[test]
    fn test_context_local_expr_vars() {
        let mut context = new_context();
        let loop_vars = HashMap::from([
            ("$item".to_string(), json!("hello")),
            ("$index".to_string(), json!(0)),
        ]);
        context.add_local_expr_vars(loop_vars);

        let visible = context.get_vars();
        assert_eq!(visible.get("$item"), Some(&json!("hello")));
        assert_eq!(visible.get("$index"), Some(&json!(0)));

        context.remove_local_expr_vars(&["$item", "$index"]);

        let visible = context.get_vars();
        assert!(!visible.contains_key("$item"));
        assert!(!visible.contains_key("$index"));
    }

    #[test]
    fn test_context_get_vars_includes_runtime() {
        let all = new_context().get_vars();
        for key in [vars::RUNTIME, vars::WORKFLOW, vars::TASK] {
            assert!(all.contains_key(key));
        }
    }

    #[test]
    fn test_context_task_status() {
        let mut context = new_context();
        context.set_task_status("task1", StatusPhase::Running);
        context.set_task_status("task1", StatusPhase::Completed);
        context.set_task_status("task2", StatusPhase::Pending);

        assert_eq!(
            context.get_task_status("task1"),
            Some(StatusPhase::Completed)
        );
    }

    #[test]
    fn test_context_authorization() {
        let mut context = new_context();
        assert!(!context.get_vars().contains_key("$authorization"));

        context.set_authorization("Bearer", "my-token-123");
        let snapshot = context.get_vars();
        let auth = snapshot
            .get("$authorization")
            .expect("$authorization should be set");
        assert_eq!(auth["scheme"], "Bearer");
        assert_eq!(auth["parameter"], "my-token-123");

        context.clear_authorization();
        assert!(!context.get_vars().contains_key("$authorization"));
    }
}