use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crabdance_activity::{
ActivityContext, ActivityInfo, WorkflowExecution as ActivityWorkflowExecution,
};
use crabdance_core::{
ActivityOptions, ChildWorkflowOptions, WorkflowExecution, WorkflowInfo, WorkflowType,
};
use crabdance_workflow::context::WorkflowError;
use crabdance_workflow::future::{ActivityFailureInfo, ActivityFailureType};
use serde::{Deserialize, Serialize};
/// Type-erased workflow stored by `TestWorkflowEnvironment`.
///
/// Called with the workflow context and the JSON-encoded input; resolves to the context handed
/// back to the harness together with the JSON-encoded output.
type WorkflowFn = Box<
dyn Fn(
TestWorkflowContext,
Vec<u8>,
) -> Pin<
Box<dyn Future<Output = Result<(TestWorkflowContext, Vec<u8>), WorkflowError>> + Send>,
> + Send
+ Sync,
>;
/// Type-erased activity stored by `TestWorkflowEnvironment`.
///
/// The `Arc` lets the same function be shared with every `TestWorkflowContext` the environment
/// creates. Called with the activity context and the JSON-encoded input; resolves to the
/// JSON-encoded output.
type ActivityFn = Arc<
dyn Fn(
&ActivityContext,
Vec<u8>,
) -> Pin<Box<dyn Future<Output = Result<Vec<u8>, ActivityError>> + Send>>
+ Send
+ Sync,
>;
/// Error returned by activities registered with `TestWorkflowEnvironment`.
#[derive(Debug, thiserror::Error)]
pub enum ActivityError {
    /// Any activity failure; the payload is a human-readable message.
    #[error("Activity failed: {0}")]
    Generic(String),
}

impl ActivityError {
    /// Builds an [`ActivityError::Generic`] from a message.
    pub fn execution_failed(message: impl Into<String>) -> Self {
        Self::Generic(message.into())
    }

    /// Builds an [`ActivityError::Generic`] from anything displayable, typically an error value.
    ///
    /// Only the `Display` text of `error` is kept (its source chain is discarded), so any
    /// `Display` type is accepted. That covers every `std::error::Error` as before, plus
    /// error types such as `anyhow::Error` that do not implement the `Error` trait.
    pub fn execution_failed_error<E: std::fmt::Display>(error: E) -> Self {
        Self::Generic(error.to_string())
    }
}
/// Placeholder test-suite type; it currently carries no state and no behaviour.
#[derive(Default)]
pub struct TestSuite;

impl TestSuite {
    /// Creates a suite; equivalent to [`TestSuite::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
pub struct TestWorkflowEnvironment {
workflow_id: String,
run_id: String,
test_time: Arc<Mutex<TestTime>>,
registered_workflows: HashMap<String, WorkflowFn>,
registered_activities: HashMap<String, ActivityFn>,
pending_signals: HashMap<String, Vec<Vec<u8>>>,
executed_activities: Vec<String>,
}
impl TestWorkflowEnvironment {
pub fn new() -> Self {
Self {
workflow_id: format!("test-workflow-{}", uuid::Uuid::new_v4()),
run_id: format!("test-run-{}", uuid::Uuid::new_v4()),
test_time: Arc::new(Mutex::new(TestTime::new())),
registered_workflows: HashMap::new(),
registered_activities: HashMap::new(),
pending_signals: HashMap::new(),
executed_activities: Vec::new(),
}
}
pub fn register_workflow<F, Fut, I, O>(&mut self, name: &str, workflow: F)
where
F: Fn(TestWorkflowContext, I) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(TestWorkflowContext, O), WorkflowError>> + Send + 'static,
I: for<'de> Deserialize<'de> + Send + 'static,
O: Serialize + Send + 'static,
{
#[expect(clippy::type_complexity)]
let boxed = Box::new(
move |ctx: TestWorkflowContext,
input_bytes: Vec<u8>|
-> Pin<
Box<
dyn Future<Output = Result<(TestWorkflowContext, Vec<u8>), WorkflowError>>
+ Send,
>,
> {
let input: I = match serde_json::from_slice(&input_bytes) {
Ok(i) => i,
Err(e) => {
return Box::pin(async move {
Err(WorkflowError::execution_failed(format!(
"Input deserialization failed: {}",
e
)))
})
}
};
let future = workflow(ctx, input);
Box::pin(async move {
let (ctx, output) = future.await?;
let output_bytes = serde_json::to_vec(&output)
.map_err(WorkflowError::execution_failed_error)?;
Ok((ctx, output_bytes))
})
},
)
as Box<
dyn Fn(
TestWorkflowContext,
Vec<u8>,
) -> Pin<
Box<
dyn Future<
Output = Result<(TestWorkflowContext, Vec<u8>), WorkflowError>,
> + Send,
>,
> + Send
+ Sync,
>;
self.registered_workflows.insert(name.to_string(), boxed);
}
pub fn register_activity<F, Fut, I, O>(&mut self, name: &str, activity: F)
where
F: Fn(&ActivityContext, I) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<O, ActivityError>> + Send + 'static,
I: for<'de> Deserialize<'de> + Send + 'static,
O: Serialize + Send + 'static,
{
let arc = Arc::new(move |ctx: &ActivityContext, input_bytes: Vec<u8>| {
let input: I = match serde_json::from_slice(&input_bytes) {
Ok(i) => i,
Err(e) => {
return Box::pin(async move {
Err(ActivityError::execution_failed(format!(
"Input deserialization failed: {}",
e
)))
}) as Pin<Box<dyn Future<Output = _> + Send>>
}
};
let result = activity(ctx, input);
Box::pin(async move {
let output = result.await?;
serde_json::to_vec(&output).map_err(ActivityError::execution_failed_error)
}) as Pin<Box<dyn Future<Output = _> + Send>>
});
self.registered_activities.insert(name.to_string(), arc);
}
pub async fn execute_workflow<I, O>(&mut self, name: &str, input: I) -> Result<O, WorkflowError>
where
I: Serialize,
O: for<'de> Deserialize<'de>,
{
let input_bytes =
serde_json::to_vec(&input).map_err(WorkflowError::execution_failed_error)?;
let workflow = self.registered_workflows.get(name).ok_or_else(|| {
WorkflowError::execution_failed(format!("Workflow '{}' not registered", name))
})?;
let activities = self.registered_activities.clone();
let signals = self.pending_signals.clone();
let test_time = self.test_time.clone();
let ctx = TestWorkflowContext {
workflow_id: self.workflow_id.clone(),
run_id: self.run_id.clone(),
workflow_type: name.to_string(),
task_list: "test-task-list".to_string(),
activities,
signals,
queries: HashMap::new(),
test_time,
is_cancelled: false,
};
self.pending_signals.clear();
let (_ctx, result_bytes) = workflow(ctx, input_bytes).await?;
serde_json::from_slice(&result_bytes).map_err(WorkflowError::execution_failed_error)
}
pub async fn execute_activity<I, O>(&self, name: &str, input: I) -> Result<O, ActivityError>
where
I: Serialize,
O: for<'de> Deserialize<'de>,
{
let input_bytes =
serde_json::to_vec(&input).map_err(ActivityError::execution_failed_error)?;
let activity = self.registered_activities.get(name).ok_or_else(|| {
ActivityError::execution_failed(format!("Activity '{}' not registered", name))
})?;
let activity_info = ActivityInfo {
activity_id: format!("test-activity-{}", uuid::Uuid::new_v4()),
activity_type: name.to_string(),
task_token: vec![],
workflow_execution: ActivityWorkflowExecution::new(&self.workflow_id, &self.run_id),
attempt: 1,
scheduled_time: chrono::Utc::now(),
started_time: chrono::Utc::now(),
deadline: None,
heartbeat_timeout: Duration::from_secs(0),
heartbeat_details: None,
};
let ctx = ActivityContext::new(activity_info, None);
let result_bytes = activity(&ctx, input_bytes).await?;
serde_json::from_slice(&result_bytes).map_err(ActivityError::execution_failed_error)
}
pub fn signal_workflow(&mut self, signal_name: &str, data: Vec<u8>) {
self.pending_signals
.entry(signal_name.to_string())
.or_default()
.push(data);
}
pub fn set_workflow_time(&mut self, time: chrono::DateTime<chrono::Utc>) {
if let Ok(mut test_time) = self.test_time.lock() {
test_time.set_time(time);
}
}
pub fn advance_workflow_time(&mut self, duration: Duration) {
if let Ok(mut test_time) = self.test_time.lock() {
test_time.advance(duration);
}
}
pub fn get_executed_activities(&self) -> &[String] {
&self.executed_activities
}
pub fn was_activity_executed(&self, name: &str) -> bool {
self.executed_activities.contains(&name.to_string())
}
#[expect(dead_code)]
fn track_activity_execution(&mut self, name: &str) {
self.executed_activities.push(name.to_string());
}
}
impl Default for TestWorkflowEnvironment {
fn default() -> Self {
Self::new()
}
}
pub struct TestWorkflowContext {
workflow_id: String,
run_id: String,
workflow_type: String,
task_list: String,
activities: HashMap<String, ActivityFn>,
signals: HashMap<String, Vec<Vec<u8>>>,
#[expect(clippy::type_complexity)]
queries: HashMap<String, Box<dyn Fn(Vec<u8>) -> Vec<u8> + Send + Sync>>,
test_time: Arc<Mutex<TestTime>>,
is_cancelled: bool,
}
impl TestWorkflowContext {
pub fn workflow_info(&self) -> WorkflowInfo {
WorkflowInfo {
workflow_execution: WorkflowExecution::new(&self.workflow_id, &self.run_id),
workflow_type: WorkflowType {
name: self.workflow_type.clone(),
},
task_list: self.task_list.clone(),
start_time: chrono::Utc::now(),
execution_start_to_close_timeout: Duration::from_secs(60),
task_start_to_close_timeout: Duration::from_secs(10),
attempt: 1,
cron_schedule: None,
continued_execution_run_id: None,
parent_workflow_execution: None,
memo: None,
search_attributes: None,
}
}
pub async fn execute_activity(
&mut self,
activity_type: &str,
args: Option<Vec<u8>>,
_options: ActivityOptions,
) -> Result<Vec<u8>, WorkflowError> {
let activity = self.activities.get(activity_type).ok_or_else(|| {
WorkflowError::ActivityFailed(ActivityFailureInfo {
failure_type: ActivityFailureType::ExecutionFailed,
message: format!("Activity '{}' not registered", activity_type),
details: None,
retryable: false,
})
})?;
let activity_info = ActivityInfo {
activity_id: format!("activity-{}", uuid::Uuid::new_v4()),
activity_type: activity_type.to_string(),
task_token: vec![],
workflow_execution: ActivityWorkflowExecution::new(&self.workflow_id, &self.run_id),
attempt: 1,
scheduled_time: chrono::Utc::now(),
started_time: chrono::Utc::now(),
deadline: None,
heartbeat_timeout: Duration::from_secs(0),
heartbeat_details: None,
};
let ctx = ActivityContext::new(activity_info, None);
let input = args.unwrap_or_default();
activity(&ctx, input).await.map_err(|e| {
WorkflowError::ActivityFailed(ActivityFailureInfo {
failure_type: ActivityFailureType::ExecutionFailed,
message: format!("{:?}", e),
details: None,
retryable: false,
})
})
}
pub fn get_signal_channel(&self, signal_name: &str) -> TestSignalChannel {
let signals = self.signals.get(signal_name).cloned().unwrap_or_default();
TestSignalChannel::new(signals)
}
pub async fn sleep(&self, duration: Duration) {
if let Ok(mut time) = self.test_time.lock() {
time.advance(duration);
}
}
pub fn now(&self) -> chrono::DateTime<chrono::Utc> {
self.test_time
.lock()
.map(|t| t.current_time)
.unwrap_or_else(|_| chrono::Utc::now())
}
pub fn is_cancelled(&self) -> bool {
self.is_cancelled
}
pub async fn execute_child_workflow(
&mut self,
_workflow_type: &str,
_args: Option<Vec<u8>>,
_options: ChildWorkflowOptions,
) -> Result<Vec<u8>, WorkflowError> {
Err(WorkflowError::execution_failed(
"Child workflow execution not yet implemented in test environment",
))
}
pub async fn signal_external_workflow(
&self,
_workflow_id: &str,
_run_id: Option<&str>,
_signal_name: &str,
_args: Option<Vec<u8>>,
) -> Result<(), WorkflowError> {
Ok(())
}
pub async fn request_cancel_external_workflow(
&self,
_workflow_id: &str,
_run_id: Option<&str>,
) -> Result<(), WorkflowError> {
Ok(())
}
pub async fn side_effect<F, R>(&self, f: F) -> R
where
F: FnOnce() -> R,
{
f()
}
pub async fn mutable_side_effect<F, R>(&self, _id: &str, f: F) -> R
where
F: FnOnce() -> R,
R: Clone,
{
f()
}
pub fn get_version(&self, _change_id: &str, min_supported: i32, _max_supported: i32) -> i32 {
min_supported
}
pub fn set_query_handler<F>(&mut self, query_type: &str, handler: F)
where
F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
{
self.queries
.insert(query_type.to_string(), Box::new(handler));
}
pub fn upsert_search_attributes(&self, _search_attributes: Vec<(String, Vec<u8>)>) {
}
pub fn get_cancellation_channel(&self) -> TestCancellationChannel {
TestCancellationChannel::new()
}
}
/// Hands out a fixed list of signal payloads, one per call, in list order.
pub struct TestSignalChannel {
    /// Payloads captured when the channel was created.
    signals: Vec<Vec<u8>>,
    /// Position of the next payload to hand out.
    current_index: usize,
}

impl TestSignalChannel {
    /// Wraps `signals`, starting from the first payload.
    fn new(signals: Vec<Vec<u8>>) -> Self {
        TestSignalChannel {
            current_index: 0,
            signals,
        }
    }

    /// Returns the next payload, or `None` once all payloads have been handed out.
    ///
    /// Never waits: it resolves immediately with the same result as [`try_recv`](Self::try_recv).
    pub async fn recv(&mut self) -> Option<Vec<u8>> {
        self.try_recv()
    }

    /// Returns the next payload without waiting, or `None` when none remain.
    pub fn try_recv(&mut self) -> Option<Vec<u8>> {
        let payload = self.signals.get(self.current_index).cloned()?;
        self.current_index += 1;
        Some(payload)
    }
}
/// Cancellation channel that never delivers a cancellation.
pub struct TestCancellationChannel;

impl TestCancellationChannel {
    /// Creates the (stateless) channel.
    fn new() -> Self {
        TestCancellationChannel
    }

    /// Waits for a cancellation request; the returned future stays pending forever.
    pub async fn recv(&mut self) {
        let never = std::future::pending::<()>();
        never.await
    }
}
/// Harness for unit-testing a single activity body against a [`TestActivityContext`].
pub struct TestActivityEnvironment;

impl TestActivityEnvironment {
    /// Creates the (stateless) environment.
    pub fn new() -> Self {
        TestActivityEnvironment
    }

    /// Calls `activity` once with a fresh [`TestActivityContext`] and returns its result.
    ///
    /// The context describes the first attempt of an activity of type `"test"`. It belongs to
    /// workflow `"test-workflow"` / run `"test-run"`, with no deadline and no prior heartbeat
    /// details.
    pub async fn execute_activity<F, R>(&self, activity: F) -> R
    where
        F: FnOnce(&mut TestActivityContext) -> R,
    {
        let info = ActivityInfo {
            activity_id: format!("test-activity-{}", uuid::Uuid::new_v4()),
            activity_type: "test".to_string(),
            task_token: vec![],
            workflow_execution: ActivityWorkflowExecution::new("test-workflow", "test-run"),
            attempt: 1,
            scheduled_time: chrono::Utc::now(),
            started_time: chrono::Utc::now(),
            deadline: None,
            heartbeat_timeout: Duration::from_secs(0),
            heartbeat_details: None,
        };
        let mut context = TestActivityContext::new(info);
        activity(&mut context)
    }
}

impl Default for TestActivityEnvironment {
    fn default() -> Self {
        TestActivityEnvironment::new()
    }
}
/// Activity context used by [`TestActivityEnvironment`] that records heartbeats for inspection.
pub struct TestActivityContext {
    /// Static description of the activity being tested.
    activity_info: ActivityInfo,
    /// Every heartbeat recorded so far, in call order.
    recorded_heartbeats: Vec<Option<Vec<u8>>>,
}

impl TestActivityContext {
    /// Wraps `activity_info` with an empty heartbeat log.
    fn new(activity_info: ActivityInfo) -> Self {
        let recorded_heartbeats = Vec::new();
        TestActivityContext {
            activity_info,
            recorded_heartbeats,
        }
    }

    /// The activity's description.
    pub fn get_info(&self) -> &ActivityInfo {
        &self.activity_info
    }

    /// Appends a copy of `details` (or `None`) to the heartbeat log.
    pub fn record_heartbeat(&mut self, details: Option<&[u8]>) {
        let owned = details.map(<[u8]>::to_vec);
        self.recorded_heartbeats.push(owned);
    }

    /// Whether the activity info carries heartbeat details (not the locally recorded ones).
    pub fn has_heartbeat_details(&self) -> bool {
        self.get_heartbeat_details().is_some()
    }

    /// Heartbeat details carried by the activity info, if any.
    pub fn get_heartbeat_details(&self) -> Option<&[u8]> {
        self.activity_info.heartbeat_details.as_deref()
    }

    /// Always `false`: this test context is never cancelled.
    pub fn is_cancelled(&self) -> bool {
        false
    }

    /// The activity's deadline, if it has one.
    pub fn get_deadline(&self) -> Option<Instant> {
        self.activity_info.deadline
    }

    /// Time left until the deadline, clamped to zero once it has passed; `None` without a
    /// deadline.
    pub fn get_remaining_time(&self) -> Option<Duration> {
        self.get_deadline()
            .map(|deadline| deadline.saturating_duration_since(Instant::now()))
    }

    /// Heartbeats recorded via [`record_heartbeat`](Self::record_heartbeat), in call order.
    pub fn get_recorded_heartbeats(&self) -> &[Option<Vec<u8>>] {
        &self.recorded_heartbeats
    }
}
/// Manually controlled clock used in place of wall-clock time by the test harness.
#[derive(Clone, Debug)]
pub struct TestTime {
    /// The instant the clock currently reads.
    current_time: chrono::DateTime<chrono::Utc>,
}

impl TestTime {
    /// Starts the clock at the current wall-clock time.
    fn new() -> Self {
        Self {
            current_time: chrono::Utc::now(),
        }
    }

    /// Jumps the clock to `time`, forwards or backwards.
    fn set_time(&mut self, time: chrono::DateTime<chrono::Utc>) {
        self.current_time = time;
    }

    /// Moves the clock forward by `duration`.
    ///
    /// Saturates at the latest representable instant rather than panicking when `duration` is
    /// out of `chrono`'s range or the sum overflows. A panic here would happen while the shared
    /// clock's mutex is held and would poison it.
    fn advance(&mut self, duration: Duration) {
        self.current_time = chrono::Duration::from_std(duration)
            .ok()
            .and_then(|delta| self.current_time.checked_add_signed(delta))
            .unwrap_or(chrono::DateTime::<chrono::Utc>::MAX_UTC);
    }
}
/// Entry point intended for replaying recorded workflow histories.
///
/// NOTE(review): every method is currently a stub that ignores its arguments and returns
/// `Ok(())`, so a successful "replay" does not yet verify anything.
#[derive(Default)]
pub struct WorkflowReplayer;

impl WorkflowReplayer {
    /// Creates a replayer; equivalent to [`WorkflowReplayer::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Stub: ignores `_history` and returns `Ok(())`.
    pub async fn replay_workflow_history(
        &self,
        _history: WorkflowHistory,
    ) -> Result<(), ReplayError> {
        Ok(())
    }

    /// Stub: ignores `_json` and returns `Ok(())`.
    pub async fn replay_workflow_history_from_json(&self, _json: &str) -> Result<(), ReplayError> {
        Ok(())
    }

    /// Stub: ignores both arguments and returns `Ok(())`.
    pub async fn replay_partial_workflow_history(
        &self,
        _history: WorkflowHistory,
        _last_event_id: i64,
    ) -> Result<(), ReplayError> {
        Ok(())
    }
}
/// A recorded workflow history, as accepted by `WorkflowReplayer`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WorkflowHistory {
    /// History events in the order they were recorded.
    pub events: Vec<HistoryEvent>,
}

/// One entry of a [`WorkflowHistory`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HistoryEvent {
    /// Event id within the history.
    pub event_id: i64,
    /// Event type name.
    pub event_type: String,
    /// Event timestamp. NOTE(review): unit/epoch not established here — confirm before relying
    /// on it.
    pub timestamp: i64,
}
/// Errors that workflow replay can report.
///
/// NOTE(review): the `WorkflowReplayer` methods in this file are stubs and never return these yet.
#[derive(Debug, thiserror::Error)]
pub enum ReplayError {
/// Non-deterministic workflow behaviour; the payload is appended to the error message.
#[error("Non-deterministic workflow detected: {0}")]
NonDeterministic(String),
/// The supplied history was rejected as invalid; the payload is appended to the message.
#[error("Invalid history: {0}")]
InvalidHistory(String),
/// Generic replay failure; the payload is appended to the message.
#[error("Replay failed: {0}")]
ReplayFailed(String),
}
/// Entry point intended for shadowing workflows.
///
/// NOTE(review): [`run`](Self::run) is currently a stub that always returns `Ok(())`.
#[derive(Default)]
pub struct WorkflowShadower;

impl WorkflowShadower {
    /// Creates a shadower; equivalent to [`WorkflowShadower::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Stub: does nothing and returns `Ok(())`.
    pub async fn run(&self) -> Result<(), ShadowerError> {
        Ok(())
    }
}
/// Errors that workflow shadowing can report.
///
/// NOTE(review): `WorkflowShadower::run` in this file is a stub and never returns this yet.
#[derive(Debug, thiserror::Error)]
pub enum ShadowerError {
/// Shadowing failed; the payload is appended to the error message.
#[error("Shadowing failed: {0}")]
ShadowingFailed(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_register_and_execute_workflow() {
        let mut env = TestWorkflowEnvironment::new();
        env.register_workflow("test_workflow", |ctx, input: String| async move {
            Ok((ctx, format!("Hello, {}!", input)))
        });
        let result: Result<String, _> = env
            .execute_workflow("test_workflow", "World".to_string())
            .await;
        assert_eq!(result.unwrap(), "Hello, World!");
    }

    #[tokio::test]
    async fn test_register_and_execute_activity() {
        let mut env = TestWorkflowEnvironment::new();
        env.register_activity(
            "test_activity",
            |_ctx, input: i32| async move { Ok(input * 2) },
        );
        let result: Result<i32, _> = env.execute_activity("test_activity", 21).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_signal_workflow() {
        let mut env = TestWorkflowEnvironment::new();
        env.register_workflow("signal_workflow", |ctx, _input: ()| async move {
            let mut channel = ctx.get_signal_channel("test_signal");
            let signal = channel.recv().await;
            assert!(signal.is_some());
            Ok((ctx, ()))
        });
        env.signal_workflow("test_signal", vec![1, 2, 3]);
        let result: Result<(), _> = env.execute_workflow("signal_workflow", ()).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_workflow_executes_activity() {
        let mut env = TestWorkflowEnvironment::new();
        env.register_activity("double", |_ctx, n: i32| async move { Ok(n * 2) });
        env.register_workflow("calc_workflow", |mut ctx, input: i32| async move {
            let result = ctx
                .execute_activity(
                    "double",
                    Some(serde_json::to_vec(&input).unwrap()),
                    ActivityOptions::default(),
                )
                .await?;
            let output: i32 = serde_json::from_slice(&result).unwrap();
            Ok((ctx, output))
        });
        let result: Result<i32, _> = env.execute_workflow("calc_workflow", 21).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[test]
    fn test_workflow_not_registered() {
        let mut env = TestWorkflowEnvironment::new();
        let result: Result<String, _> =
            futures::executor::block_on(env.execute_workflow("nonexistent", "input".to_string()));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not registered"));
    }

    #[tokio::test]
    async fn test_test_time() {
        let mut env = TestWorkflowEnvironment::new();
        let start_time = chrono::Utc::now();
        env.set_workflow_time(start_time);
        env.advance_workflow_time(Duration::from_secs(60));
        // The clock must reflect both the explicit set and the advance.
        let current_time = env.test_time.lock().unwrap().current_time;
        assert_eq!(current_time, start_time + chrono::Duration::seconds(60));
    }

    #[tokio::test]
    async fn test_workflow_sleep_advances_test_time() {
        let mut env = TestWorkflowEnvironment::new();
        env.set_workflow_time(chrono::Utc::now());
        env.register_workflow("clock_workflow", |ctx, _input: ()| async move {
            let before = ctx.now();
            ctx.sleep(Duration::from_secs(30)).await;
            let elapsed = (ctx.now() - before).num_seconds();
            Ok((ctx, elapsed))
        });
        let elapsed: i64 = env.execute_workflow("clock_workflow", ()).await.unwrap();
        assert_eq!(elapsed, 30);
    }
}