use crate::backend::{BackendError, SignalStore, SnapshotStore, TaskClaimStore};
use chrono::{Duration, Utc};
use sayiir_core::snapshot::{
ExecutionPosition, PauseRequest, SignalKind, SignalRequest, WorkflowSnapshot,
WorkflowSnapshotState,
};
use sayiir_core::task_claim::{AvailableTask, TaskClaim};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, RwLock};
/// Storage backend that keeps all workflow state in process memory.
///
/// Every field is an `Arc<RwLock<..>>`, so cloning the backend yields another
/// handle to the same underlying maps rather than an independent copy.
#[derive(Clone, Default)]
pub struct InMemoryBackend {
    /// Workflow snapshots keyed by instance id.
    snapshots: Arc<RwLock<HashMap<String, WorkflowSnapshot>>>,
    /// Task claims keyed by the string built by `claim_key`.
    claims: Arc<RwLock<HashMap<String, TaskClaim>>>,
    /// Pending signal requests per instance id; at most one request per
    /// `SignalKind` is kept for an instance.
    signals: Arc<RwLock<HashMap<String, HashMap<SignalKind, SignalRequest>>>>,
    /// FIFO queues of event payloads keyed by `(instance_id, signal_name)`.
    #[allow(clippy::type_complexity)]
    events: Arc<RwLock<HashMap<(String, String), VecDeque<bytes::Bytes>>>>,
}
impl InMemoryBackend {
pub fn new() -> Self {
Default::default()
}
/// Builds the map key for a task claim: `"<instance_id>:<task_id>"`.
fn claim_key(instance_id: &str, task_id: &str) -> String {
    format!("{instance_id}:{task_id}")
}
/// Converts a failed lock acquisition into `BackendError::Backend`, prefixing
/// the displayed error with `"Lock error: "`.
fn lock_error<E>(e: E) -> BackendError
where
    E: std::fmt::Display,
{
    let message = format!("Lock error: {e}");
    BackendError::Backend(message)
}
}
impl SnapshotStore for InMemoryBackend {
/// Stores a copy of `snapshot` under its `instance_id`, replacing any
/// snapshot previously stored for that id.
///
/// # Errors
/// Returns `BackendError::Backend` if the snapshot lock cannot be acquired.
async fn save_snapshot(&self, snapshot: &WorkflowSnapshot) -> Result<(), BackendError> {
    let instance_id = snapshot.instance_id.clone();
    let stored = snapshot.clone();
    self.snapshots
        .write()
        .map_err(Self::lock_error)?
        .insert(instance_id, stored);
    Ok(())
}
/// Records `output` as the result of `task_id` on the stored snapshot of
/// `instance_id`.
///
/// # Errors
/// * `BackendError::NotFound` if no snapshot exists for `instance_id`.
/// * `BackendError::Backend` if the snapshot lock cannot be acquired.
async fn save_task_result(
    &self,
    instance_id: &str,
    task_id: &str,
    output: bytes::Bytes,
) -> Result<(), BackendError> {
    let mut guard = self.snapshots.write().map_err(Self::lock_error)?;
    match guard.get_mut(instance_id) {
        Some(snapshot) => {
            snapshot.mark_task_completed(task_id.to_string(), output);
            Ok(())
        }
        None => Err(BackendError::NotFound(instance_id.to_string())),
    }
}
/// Returns a clone of the snapshot stored for `instance_id`.
///
/// # Errors
/// * `BackendError::NotFound` if no snapshot exists for `instance_id`.
/// * `BackendError::Backend` if the snapshot lock cannot be acquired.
async fn load_snapshot(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
    let guard = self.snapshots.read().map_err(Self::lock_error)?;
    match guard.get(instance_id) {
        Some(found) => Ok(found.clone()),
        None => Err(BackendError::NotFound(instance_id.to_string())),
    }
}
/// Removes the snapshot stored for `instance_id`.
///
/// Only the snapshot map is touched; claims, signals and events recorded for
/// the same instance id are left in place.
///
/// # Errors
/// * `BackendError::NotFound` if no snapshot exists for `instance_id`.
/// * `BackendError::Backend` if the snapshot lock cannot be acquired.
async fn delete_snapshot(&self, instance_id: &str) -> Result<(), BackendError> {
    let removed = self
        .snapshots
        .write()
        .map_err(Self::lock_error)?
        .remove(instance_id);
    match removed {
        Some(_) => Ok(()),
        None => Err(BackendError::NotFound(instance_id.to_string())),
    }
}
/// Returns the instance ids of all stored snapshots, in the map's iteration
/// order (which is unspecified for a `HashMap`).
///
/// # Errors
/// Returns `BackendError::Backend` if the snapshot lock cannot be acquired.
async fn list_snapshots(&self) -> Result<Vec<String>, BackendError> {
    let guard = self.snapshots.read().map_err(Self::lock_error)?;
    let mut ids = Vec::with_capacity(guard.len());
    ids.extend(guard.keys().cloned());
    Ok(ids)
}
}
impl SignalStore for InMemoryBackend {
/// Records a pending cancel or pause request for a workflow instance.
///
/// The current snapshot state is validated first:
/// * `Cancel` is rejected for completed or failed workflows and is a no-op
///   for an already-cancelled one.
/// * `Pause` is rejected for completed, failed or cancelled workflows and is
///   a no-op for an already-paused one.
///
/// In the no-op cases `Ok(())` is returned without storing the request.
/// Otherwise the request replaces any request of the same kind already
/// pending for the instance.
///
/// # Errors
/// * `BackendError::NotFound` if no snapshot exists for `instance_id`.
/// * `BackendError::CannotCancel` / `BackendError::CannotPause` carrying the
///   name of the state that blocks the request.
/// * `BackendError::Backend` if a lock cannot be acquired.
async fn store_signal(
    &self,
    instance_id: &str,
    kind: SignalKind,
    request: SignalRequest,
) -> Result<(), BackendError> {
    // Validate under the snapshot read lock; the guard is released at the end
    // of this block, before the signal map is locked for writing.
    let already_in_target_state = {
        let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
        let Some(snapshot) = snapshots.get(instance_id) else {
            return Err(BackendError::NotFound(instance_id.to_string()));
        };
        let state = &snapshot.state;
        match kind {
            SignalKind::Cancel => {
                if state.is_completed() {
                    return Err(BackendError::CannotCancel("Completed".to_string()));
                }
                if state.is_failed() {
                    return Err(BackendError::CannotCancel("Failed".to_string()));
                }
                state.is_cancelled()
            }
            SignalKind::Pause => {
                if state.is_completed() {
                    return Err(BackendError::CannotPause("Completed".to_string()));
                }
                if state.is_failed() {
                    return Err(BackendError::CannotPause("Failed".to_string()));
                }
                if state.is_cancelled() {
                    return Err(BackendError::CannotPause("Cancelled".to_string()));
                }
                state.is_paused()
            }
        }
    };

    // Requesting the state the workflow is already in is idempotent.
    if already_in_target_state {
        return Ok(());
    }

    self.signals
        .write()
        .map_err(Self::lock_error)?
        .entry(instance_id.to_string())
        .or_default()
        .insert(kind, request);
    Ok(())
}
/// Returns a clone of the pending request of the given `kind` for
/// `instance_id`, or `None` when no such request is stored.
///
/// # Errors
/// Returns `BackendError::Backend` if the signal lock cannot be acquired.
async fn get_signal(
    &self,
    instance_id: &str,
    kind: SignalKind,
) -> Result<Option<SignalRequest>, BackendError> {
    let guard = self.signals.read().map_err(Self::lock_error)?;
    let pending = match guard.get(instance_id) {
        Some(per_instance) => per_instance.get(&kind).cloned(),
        None => None,
    };
    Ok(pending)
}
/// Removes the pending request of the given `kind` for `instance_id`.
///
/// Succeeds even when nothing was stored. When the instance has no other
/// pending requests left, its entry is dropped from the signal map.
///
/// # Errors
/// Returns `BackendError::Backend` if the signal lock cannot be acquired.
async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
    let mut guard = self.signals.write().map_err(Self::lock_error)?;
    let Some(per_instance) = guard.get_mut(instance_id) else {
        return Ok(());
    };
    per_instance.remove(&kind);
    let now_empty = per_instance.is_empty();
    if now_empty {
        guard.remove(instance_id);
    }
    Ok(())
}
/// Appends `payload` to the back of the event queue for
/// `(instance_id, signal_name)`, creating the queue if it does not exist.
///
/// No check is made that a snapshot exists for `instance_id`.
///
/// # Errors
/// Returns `BackendError::Backend` if the event lock cannot be acquired.
async fn send_event(
    &self,
    instance_id: &str,
    signal_name: &str,
    payload: bytes::Bytes,
) -> Result<(), BackendError> {
    let key = (instance_id.to_string(), signal_name.to_string());
    let mut guard = self.events.write().map_err(Self::lock_error)?;
    let queue = guard.entry(key).or_default();
    queue.push_back(payload);
    Ok(())
}
/// Pops the oldest queued payload for `(instance_id, signal_name)`.
///
/// Returns `None` when no payload is queued. A queue that is empty after the
/// pop is removed from the event map.
///
/// # Errors
/// Returns `BackendError::Backend` if the event lock cannot be acquired.
async fn consume_event(
    &self,
    instance_id: &str,
    signal_name: &str,
) -> Result<Option<bytes::Bytes>, BackendError> {
    let mut guard = self.events.write().map_err(Self::lock_error)?;
    let key = (instance_id.to_string(), signal_name.to_string());
    let Some(queue) = guard.get_mut(&key) else {
        return Ok(None);
    };
    let payload = queue.pop_front();
    if queue.is_empty() {
        guard.remove(&key);
    }
    Ok(payload)
}
/// Applies a pending cancel request to an in-progress workflow.
///
/// Returns `Ok(true)` when a cancel request was pending and the workflow was
/// marked cancelled; the request is then removed. Returns `Ok(false)` when no
/// cancel request is pending, or when the workflow is not in progress (in
/// that case the pending request is left untouched).
///
/// The signal map and the snapshot map are locked one after the other, never
/// at the same time.
///
/// # Arguments
/// * `interrupted_at_task` - optional task id forwarded to `mark_cancelled`.
///
/// # Errors
/// * `BackendError::NotFound` if a request is pending but no snapshot exists
///   for `instance_id`.
/// * `BackendError::Backend` if a lock cannot be acquired.
async fn check_and_cancel(
    &self,
    instance_id: &str,
    interrupted_at_task: Option<&str>,
) -> Result<bool, BackendError> {
    // Step 1: copy the pending request out so the signal lock can be released.
    let pending = {
        let signals = self.signals.read().map_err(Self::lock_error)?;
        signals
            .get(instance_id)
            .and_then(|per_instance| per_instance.get(&SignalKind::Cancel))
            .cloned()
    };
    let Some(request) = pending else {
        return Ok(false);
    };

    // Step 2: transition the snapshot, but only if it is still in progress.
    {
        let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
        let Some(snapshot) = snapshots.get_mut(instance_id) else {
            return Err(BackendError::NotFound(instance_id.to_string()));
        };
        if !snapshot.state.is_in_progress() {
            return Ok(false);
        }
        snapshot.mark_cancelled(
            request.reason,
            request.requested_by,
            interrupted_at_task.map(str::to_owned),
        );
    }

    // Step 3: the request has been honoured, so forget it.
    let mut signals = self.signals.write().map_err(Self::lock_error)?;
    if let Some(per_instance) = signals.get_mut(instance_id) {
        per_instance.remove(&SignalKind::Cancel);
        if per_instance.is_empty() {
            signals.remove(instance_id);
        }
    }
    Ok(true)
}
/// Applies a pending pause request to an in-progress workflow.
///
/// Returns `Ok(true)` when a pause request was pending and the workflow was
/// marked paused; the request is then removed. Returns `Ok(false)` when no
/// pause request is pending, or when the workflow is not in progress (in that
/// case the pending request is left untouched).
///
/// The signal map and the snapshot map are locked one after the other, never
/// at the same time.
///
/// # Errors
/// * `BackendError::NotFound` if a request is pending but no snapshot exists
///   for `instance_id`.
/// * `BackendError::Backend` if a lock cannot be acquired.
async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
    // Step 1: copy the pending request out so the signal lock can be released.
    let pending = {
        let signals = self.signals.read().map_err(Self::lock_error)?;
        signals
            .get(instance_id)
            .and_then(|per_instance| per_instance.get(&SignalKind::Pause))
            .cloned()
    };
    let Some(request) = pending else {
        return Ok(false);
    };

    // Step 2: transition the snapshot, but only if it is still in progress.
    {
        let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
        let Some(snapshot) = snapshots.get_mut(instance_id) else {
            return Err(BackendError::NotFound(instance_id.to_string()));
        };
        if !snapshot.state.is_in_progress() {
            return Ok(false);
        }
        let pause_request: PauseRequest = request.into();
        snapshot.mark_paused(&pause_request);
    }

    // Step 3: the request has been honoured, so forget it.
    let mut signals = self.signals.write().map_err(Self::lock_error)?;
    if let Some(per_instance) = signals.get_mut(instance_id) {
        per_instance.remove(&SignalKind::Pause);
        if per_instance.is_empty() {
            signals.remove(instance_id);
        }
    }
    Ok(true)
}
/// Resumes a paused workflow and returns a clone of the updated snapshot.
///
/// # Errors
/// * `BackendError::NotFound` if no snapshot exists for `instance_id`.
/// * `BackendError::CannotPause` if the workflow is not paused; the message
///   names the state it was found in.
/// * `BackendError::Backend` if the snapshot lock cannot be acquired.
async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
    let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
    let Some(snapshot) = snapshots.get_mut(instance_id) else {
        return Err(BackendError::NotFound(instance_id.to_string()));
    };

    if snapshot.state.is_paused() {
        snapshot.mark_unpaused();
        return Ok(snapshot.clone());
    }

    // Not paused: work out a readable name for the state we found instead.
    let state = &snapshot.state;
    let state_name = if state.is_in_progress() {
        "InProgress"
    } else if state.is_completed() {
        "Completed"
    } else if state.is_failed() {
        "Failed"
    } else if state.is_cancelled() {
        "Cancelled"
    } else {
        "Unknown"
    };
    // Debug formatting of a `&str` renders the name inside double quotes.
    Err(BackendError::CannotPause(format!(
        "Workflow is not paused (current state: {:?})",
        state_name
    )))
}
}
impl TaskClaimStore for InMemoryBackend {
/// Tries to claim a task for `worker_id`.
///
/// Returns `Ok(None)` when an unexpired claim already exists for the task.
/// Otherwise a new claim is created with the given `ttl`, stored (replacing
/// an expired claim if one was present) and returned.
///
/// # Errors
/// Returns `BackendError::Backend` if the claim lock cannot be acquired.
async fn claim_task(
    &self,
    instance_id: &str,
    task_id: &str,
    worker_id: &str,
    ttl: Option<Duration>,
) -> Result<Option<TaskClaim>, BackendError> {
    let key = Self::claim_key(instance_id, task_id);
    let mut claims = self.claims.write().map_err(Self::lock_error)?;

    let held_by_live_claim = claims
        .get(&key)
        .is_some_and(|existing| !existing.is_expired());
    if held_by_live_claim {
        return Ok(None);
    }

    // Any claim still stored under `key` is expired; `insert` overwrites it.
    let claim = TaskClaim::new(
        instance_id.to_string(),
        task_id.to_string(),
        worker_id.to_string(),
        ttl,
    );
    claims.insert(key, claim.clone());
    Ok(Some(claim))
}
/// Releases the claim on a task, provided it is held by `worker_id`.
///
/// # Errors
/// * `BackendError::NotFound` (message `"<instance_id>:<task_id>"`) if no
///   claim is stored for the task.
/// * `BackendError::Backend` if the claim belongs to a different worker, or
///   if the claim lock cannot be acquired.
async fn release_task_claim(
    &self,
    instance_id: &str,
    task_id: &str,
    worker_id: &str,
) -> Result<(), BackendError> {
    let key = Self::claim_key(instance_id, task_id);
    let mut claims = self.claims.write().map_err(Self::lock_error)?;

    let Some(claim) = claims.get(&key) else {
        return Err(BackendError::NotFound(format!("{instance_id}:{task_id}")));
    };
    if claim.worker_id != worker_id {
        let owner = &claim.worker_id;
        return Err(BackendError::Backend(format!(
            "Claim owned by different worker: {owner}"
        )));
    }

    claims.remove(&key);
    Ok(())
}
/// Pushes the expiry of a claim held by `worker_id` out by
/// `additional_duration`.
///
/// A claim without an expiry (`expires_at == None`) is left unchanged and the
/// call still succeeds.
///
/// # Errors
/// * `BackendError::NotFound` (message `"<instance_id>:<task_id>"`) if no
///   claim is stored for the task.
/// * `BackendError::Backend` if the claim belongs to a different worker, if
///   the stored expiry is not a representable timestamp, if the new expiry
///   overflows or falls before the Unix epoch, or if the claim lock cannot be
///   acquired.
async fn extend_task_claim(
    &self,
    instance_id: &str,
    task_id: &str,
    worker_id: &str,
    additional_duration: Duration,
) -> Result<(), BackendError> {
    let key = Self::claim_key(instance_id, task_id);
    let mut claims = self.claims.write().map_err(Self::lock_error)?;

    let Some(claim) = claims.get_mut(&key) else {
        return Err(BackendError::NotFound(format!(
            "{}:{}",
            instance_id, task_id
        )));
    };
    if claim.worker_id != worker_id {
        return Err(BackendError::Backend(format!(
            "Claim owned by different worker: {}",
            claim.worker_id
        )));
    }

    if let Some(expires_at) = claim.expires_at {
        // Checked conversions: an `as` cast would silently wrap values that
        // do not fit, e.g. turning a negative timestamp into a huge `u64`
        // that effectively never expires.
        let expires_secs = i64::try_from(expires_at)
            .map_err(|_| BackendError::Backend("Invalid timestamp".to_string()))?;
        let expires_datetime = chrono::DateTime::from_timestamp(expires_secs, 0)
            .ok_or_else(|| BackendError::Backend("Invalid timestamp".to_string()))?;
        let new_expiry = expires_datetime
            .checked_add_signed(additional_duration)
            .ok_or_else(|| BackendError::Backend("Time overflow".to_string()))?;
        let new_expiry_secs = u64::try_from(new_expiry.timestamp())
            .map_err(|_| BackendError::Backend("Time overflow".to_string()))?;
        claim.expires_at = Some(new_expiry_secs);
    }
    Ok(())
}
async fn find_available_tasks(
&self,
worker_id: &str,
limit: usize,
) -> Result<Vec<AvailableTask>, BackendError> {
{
let mut claims = self.claims.write().map_err(Self::lock_error)?;
claims.retain(|_, claim| !claim.is_expired());
}
let mut delay_advances: Vec<(String, String)> = Vec::new();
let mut delay_completions: Vec<(String, String)> = Vec::new();
let mut signal_advances: Vec<(String, String, Option<String>)> = Vec::new();
let mut signal_timeout_advances: Vec<(String, String, Option<String>)> = Vec::new();
{
let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
let signals = self.signals.read().map_err(Self::lock_error)?;
let events = self.events.read().map_err(Self::lock_error)?;
for (instance_id, snapshot) in snapshots.iter() {
if !snapshot.state.is_in_progress() {
continue;
}
if signals
.get(instance_id.as_str())
.is_some_and(|m| m.contains_key(&SignalKind::Cancel))
{
continue;
}
if signals
.get(instance_id.as_str())
.is_some_and(|m| m.contains_key(&SignalKind::Pause))
{
continue;
}
match &snapshot.state {
WorkflowSnapshotState::InProgress {
position:
ExecutionPosition::AtDelay {
wake_at,
delay_id,
next_task_id,
..
},
..
} if Utc::now() >= *wake_at => {
if let Some(next_id) = next_task_id {
delay_advances.push((instance_id.clone(), next_id.clone()));
} else {
delay_completions.push((instance_id.clone(), delay_id.clone()));
}
}
WorkflowSnapshotState::InProgress {
position:
ExecutionPosition::AtSignal {
signal_name,
wake_at,
next_task_id,
..
},
..
} => {
let key = (instance_id.clone(), signal_name.clone());
if events.get(&key).is_some_and(|q| !q.is_empty()) {
signal_advances.push((
instance_id.clone(),
signal_name.clone(),
next_task_id.clone(),
));
} else if wake_at.is_some_and(|wt| Utc::now() >= wt) {
signal_timeout_advances.push((
instance_id.clone(),
signal_name.clone(),
next_task_id.clone(),
));
}
}
_ => {}
}
}
}
if !delay_advances.is_empty()
|| !delay_completions.is_empty()
|| !signal_advances.is_empty()
|| !signal_timeout_advances.is_empty()
{
let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
for (instance_id, next_task_id) in &delay_advances {
if let Some(snapshot) = snapshots.get_mut(instance_id) {
snapshot.update_position(ExecutionPosition::AtTask {
task_id: next_task_id.clone(),
});
}
}
for (instance_id, delay_id) in &delay_completions {
if let Some(snapshot) = snapshots.get_mut(instance_id) {
let output = snapshot.get_task_result_bytes(delay_id).unwrap_or_default();
snapshot.mark_completed(output);
}
}
{
let mut events = self.events.write().map_err(Self::lock_error)?;
for (instance_id, signal_name, next_task_id) in &signal_advances {
let key = (instance_id.clone(), signal_name.clone());
let payload = events
.get_mut(&key)
.and_then(VecDeque::pop_front)
.unwrap_or_default();
if events.get(&key).is_some_and(VecDeque::is_empty) {
events.remove(&key);
}
if let Some(snapshot) = snapshots.get_mut(instance_id) {
snapshot.mark_task_completed(signal_name.clone(), payload);
if let Some(next_id) = next_task_id {
snapshot.update_position(ExecutionPosition::AtTask {
task_id: next_id.clone(),
});
} else {
let output = snapshot
.get_task_result_bytes(signal_name)
.unwrap_or_default();
snapshot.mark_completed(output);
}
}
}
}
for (instance_id, signal_name, next_task_id) in &signal_timeout_advances {
if let Some(snapshot) = snapshots.get_mut(instance_id) {
snapshot.mark_task_completed(signal_name.clone(), bytes::Bytes::new());
if let Some(next_id) = next_task_id {
snapshot.update_position(ExecutionPosition::AtTask {
task_id: next_id.clone(),
});
} else {
snapshot.mark_completed(bytes::Bytes::new());
}
}
}
}
let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
let claims = self.claims.read().map_err(Self::lock_error)?;
let signals = self.signals.read().map_err(Self::lock_error)?;
let mut available = Vec::new();
for (instance_id, snapshot) in snapshots.iter() {
if !snapshot.state.is_in_progress() {
continue;
}
if let Some(instance_signals) = signals.get(instance_id.as_str())
&& (instance_signals.contains_key(&SignalKind::Cancel)
|| instance_signals.contains_key(&SignalKind::Pause))
{
continue;
}
if let WorkflowSnapshotState::InProgress {
completed_tasks,
position: ExecutionPosition::AtTask { task_id },
..
} = &snapshot.state
{
let claim_key = Self::claim_key(instance_id, task_id);
let is_claimed = claims.contains_key(&claim_key);
let is_completed = completed_tasks.contains_key(task_id);
if !is_completed && !is_claimed {
if let Some(rs) = snapshot.task_retries.get(task_id)
&& Utc::now() < rs.next_retry_at
{
continue;
}
let input = if completed_tasks.is_empty() {
snapshot.initial_input_bytes()
} else {
snapshot.get_last_task_output()
};
if let Some(input_bytes) = input {
available.push(AvailableTask {
instance_id: instance_id.clone(),
task_id: task_id.clone(),
input: input_bytes,
workflow_definition_hash: snapshot.definition_hash.clone(),
});
if available.len() >= limit {
break;
}
}
}
}
}
available.sort_by_key(|t| {
snapshots
.get(&t.instance_id)
.and_then(|s| s.task_retries.get(&t.task_id))
.and_then(|rs| rs.last_failed_worker.as_deref())
.is_some_and(|w| w == worker_id)
});
Ok(available)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::backend::SignalStore;
use sayiir_core::snapshot::SignalKind;
#[tokio::test]
async fn test_save_and_load() {
let backend = InMemoryBackend::new();
let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
backend.save_snapshot(&snapshot).await.unwrap();
let loaded = backend.load_snapshot("test-123").await.unwrap();
assert_eq!(snapshot.instance_id, loaded.instance_id);
assert_eq!(snapshot.definition_hash, loaded.definition_hash);
}
#[tokio::test]
async fn test_not_found() {
let backend = InMemoryBackend::new();
let result = backend.load_snapshot("nonexistent").await;
assert!(matches!(result, Err(BackendError::NotFound(_))));
}
#[tokio::test]
async fn test_delete() {
let backend = InMemoryBackend::new();
let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
backend.save_snapshot(&snapshot).await.unwrap();
backend.delete_snapshot("test-123").await.unwrap();
let result = backend.load_snapshot("test-123").await;
assert!(matches!(result, Err(BackendError::NotFound(_))));
}
#[tokio::test]
async fn test_list_snapshots() {
let backend = InMemoryBackend::new();
backend
.save_snapshot(&WorkflowSnapshot::new(
"test-1".to_string(),
"hash-1".to_string(),
))
.await
.unwrap();
backend
.save_snapshot(&WorkflowSnapshot::new(
"test-2".to_string(),
"hash-2".to_string(),
))
.await
.unwrap();
let list = backend.list_snapshots().await.unwrap();
assert_eq!(list.len(), 2);
assert!(list.contains(&"test-1".to_string()));
assert!(list.contains(&"test-2".to_string()));
}
#[tokio::test]
async fn test_claim_task_success() {
let backend = InMemoryBackend::new();
let claim = backend
.claim_task(
"workflow-1",
"task-1",
"worker-1",
Some(Duration::seconds(300)),
)
.await
.unwrap();
assert!(claim.is_some());
let claim = claim.unwrap();
assert_eq!(claim.instance_id, "workflow-1");
assert_eq!(claim.task_id, "task-1");
assert_eq!(claim.worker_id, "worker-1");
assert!(claim.expires_at.is_some());
}
#[tokio::test]
async fn test_claim_task_already_claimed() {
let backend = InMemoryBackend::new();
let claim1 = backend
.claim_task(
"workflow-1",
"task-1",
"worker-1",
Some(Duration::seconds(300)),
)
.await
.unwrap();
assert!(claim1.is_some());
let claim2 = backend
.claim_task(
"workflow-1",
"task-1",
"worker-2",
Some(Duration::seconds(300)),
)
.await
.unwrap();
assert!(claim2.is_none());
}
#[tokio::test]
async fn test_claim_task_expired_claim_replaced() {
let backend = InMemoryBackend::new();
let claim1 = backend
.claim_task(
"workflow-1",
"task-1",
"worker-1",
Some(Duration::seconds(0)),
)
.await
.unwrap();
assert!(claim1.is_some());
let claim2 = backend
.claim_task(
"workflow-1",
"task-1",
"worker-2",
Some(Duration::seconds(300)),
)
.await
.unwrap();
assert!(claim2.is_some());
let claim2 = claim2.unwrap();
assert_eq!(claim2.worker_id, "worker-2");
}
#[tokio::test]
async fn test_claim_task_no_ttl() {
let backend = InMemoryBackend::new();
let claim = backend
.claim_task("workflow-1", "task-1", "worker-1", None)
.await
.unwrap();
assert!(claim.is_some());
let claim = claim.unwrap();
assert!(claim.expires_at.is_none());
assert!(!claim.is_expired()); }
#[tokio::test]
async fn test_release_task_claim_success() {
let backend = InMemoryBackend::new();
backend
.claim_task(
"workflow-1",
"task-1",
"worker-1",
Some(Duration::seconds(300)),
)
.await
.unwrap();
let result = backend
.release_task_claim("workflow-1", "task-1", "worker-1")
.await;
assert!(result.is_ok());
let claim = backend
.claim_task(
"workflow-1",
"task-1",
"worker-2",
Some(Duration::seconds(300)),
)
.await
.unwrap();
assert!(claim.is_some());
}
#[tokio::test]
async fn test_release_task_claim_wrong_worker() {
let backend = InMemoryBackend::new();
backend
.claim_task(
"workflow-1",
"task-1",
"worker-1",
Some(Duration::seconds(300)),
)
.await
.unwrap();
let result = backend
.release_task_claim("workflow-1", "task-1", "worker-2")
.await;
assert!(matches!(result, Err(BackendError::Backend(_))));
}
#[tokio::test]
async fn test_release_task_claim_not_found() {
let backend = InMemoryBackend::new();
let result = backend
.release_task_claim("workflow-1", "task-1", "worker-1")
.await;
assert!(matches!(result, Err(BackendError::NotFound(_))));
}
#[tokio::test]
async fn test_extend_task_claim_success() {
let backend = InMemoryBackend::new();
let claim = backend
.claim_task(
"workflow-1",
"task-1",
"worker-1",
Some(Duration::seconds(10)),
)
.await
.unwrap()
.unwrap();
let original_expiry = claim.expires_at.unwrap();
backend
.extend_task_claim("workflow-1", "task-1", "worker-1", Duration::seconds(300))
.await
.unwrap();
let claims = backend.claims.read().unwrap();
let key = InMemoryBackend::claim_key("workflow-1", "task-1");
let extended_claim = claims.get(&key).unwrap();
assert!(extended_claim.expires_at.unwrap() > original_expiry);
}
#[tokio::test]
async fn test_extend_task_claim_wrong_worker() {
let backend = InMemoryBackend::new();
backend
.claim_task(
"workflow-1",
"task-1",
"worker-1",
Some(Duration::seconds(300)),
)
.await
.unwrap();
let result = backend
.extend_task_claim("workflow-1", "task-1", "worker-2", Duration::seconds(300))
.await;
assert!(matches!(result, Err(BackendError::Backend(_))));
}
#[tokio::test]
async fn test_extend_task_claim_not_found() {
let backend = InMemoryBackend::new();
let result = backend
.extend_task_claim("workflow-1", "task-1", "worker-1", Duration::seconds(300))
.await;
assert!(matches!(result, Err(BackendError::NotFound(_))));
}
#[tokio::test]
async fn test_extend_task_claim_no_expiry() {
let backend = InMemoryBackend::new();
backend
.claim_task("workflow-1", "task-1", "worker-1", None)
.await
.unwrap();
backend
.extend_task_claim("workflow-1", "task-1", "worker-1", Duration::seconds(300))
.await
.unwrap();
let claims = backend.claims.read().unwrap();
let key = InMemoryBackend::claim_key("workflow-1", "task-1");
let claim = claims.get(&key).unwrap();
assert!(claim.expires_at.is_none());
}
#[tokio::test]
async fn test_store_signal_cancel_success() {
let backend = InMemoryBackend::new();
let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
backend.save_snapshot(&snapshot).await.unwrap();
let result = backend
.store_signal(
"test-123",
SignalKind::Cancel,
SignalRequest::new(
Some("User requested".to_string()),
Some("admin".to_string()),
),
)
.await;
assert!(result.is_ok(), "store_signal should succeed");
let stored = backend
.get_signal("test-123", SignalKind::Cancel)
.await
.unwrap();
assert!(stored.is_some(), "cancel signal should be stored");
let stored = stored.unwrap();
assert_eq!(stored.reason, Some("User requested".to_string()));
assert_eq!(stored.requested_by, Some("admin".to_string()));
}
#[tokio::test]
async fn test_store_signal_cancel_not_found() {
let backend = InMemoryBackend::new();
let result = backend
.store_signal(
"nonexistent",
SignalKind::Cancel,
SignalRequest::new(None, None),
)
.await;
assert!(
matches!(result, Err(BackendError::NotFound(_))),
"should return NotFound for non-existent workflow"
);
}
#[tokio::test]
async fn test_store_signal_cancel_completed_workflow() {
let backend = InMemoryBackend::new();
let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
snapshot.mark_completed(bytes::Bytes::from("result"));
backend.save_snapshot(&snapshot).await.unwrap();
let result = backend
.store_signal(
"test-123",
SignalKind::Cancel,
SignalRequest::new(None, None),
)
.await;
assert!(
matches!(result, Err(BackendError::CannotCancel(_))),
"should return CannotCancel for completed workflow"
);
}
#[tokio::test]
async fn test_store_signal_cancel_failed_workflow() {
let backend = InMemoryBackend::new();
let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
snapshot.mark_failed("Some error".to_string());
backend.save_snapshot(&snapshot).await.unwrap();
let result = backend
.store_signal(
"test-123",
SignalKind::Cancel,
SignalRequest::new(None, None),
)
.await;
assert!(
matches!(result, Err(BackendError::CannotCancel(_))),
"should return CannotCancel for failed workflow"
);
}
#[tokio::test]
async fn test_store_signal_cancel_already_cancelled_idempotent() {
let backend = InMemoryBackend::new();
let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
snapshot.mark_cancelled(Some("First cancel".to_string()), None, None);
backend.save_snapshot(&snapshot).await.unwrap();
let result = backend
.store_signal(
"test-123",
SignalKind::Cancel,
SignalRequest::new(Some("Second cancel".to_string()), None),
)
.await;
assert!(
result.is_ok(),
"cancelling already-cancelled workflow should be idempotent"
);
}
#[tokio::test]
async fn test_get_signal_cancel_none() {
let backend = InMemoryBackend::new();
let result = backend
.get_signal("test-123", SignalKind::Cancel)
.await
.unwrap();
assert!(
result.is_none(),
"should return None when no cancellation signal exists"
);
}
#[tokio::test]
async fn test_clear_signal_cancel() {
let backend = InMemoryBackend::new();
let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
backend.save_snapshot(&snapshot).await.unwrap();
backend
.store_signal(
"test-123",
SignalKind::Cancel,
SignalRequest::new(Some("Test".to_string()), None),
)
.await
.unwrap();
assert!(
backend
.get_signal("test-123", SignalKind::Cancel)
.await
.unwrap()
.is_some(),
"cancel signal should exist before clearing"
);
backend
.clear_signal("test-123", SignalKind::Cancel)
.await
.unwrap();
assert!(
backend
.get_signal("test-123", SignalKind::Cancel)
.await
.unwrap()
.is_none(),
"cancel signal should be gone after clearing"
);
}
#[tokio::test]
async fn test_store_signal_pause_completed_workflow() {
let backend = InMemoryBackend::new();
let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
snapshot.mark_completed(bytes::Bytes::from("result"));
backend.save_snapshot(&snapshot).await.unwrap();
let result = backend
.store_signal(
"test-123",
SignalKind::Pause,
SignalRequest::new(None, None),
)
.await;
assert!(
matches!(result, Err(BackendError::CannotPause(_))),
"should return CannotPause for completed workflow"
);
}
#[tokio::test]
async fn test_store_signal_pause_failed_workflow() {
let backend = InMemoryBackend::new();
let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
snapshot.mark_failed("Some error".to_string());
backend.save_snapshot(&snapshot).await.unwrap();
let result = backend
.store_signal(
"test-123",
SignalKind::Pause,
SignalRequest::new(None, None),
)
.await;
assert!(
matches!(result, Err(BackendError::CannotPause(_))),
"should return CannotPause for failed workflow"
);
}
#[tokio::test]
async fn test_store_signal_pause_cancelled_workflow() {
let backend = InMemoryBackend::new();
let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
snapshot.mark_cancelled(Some("done".to_string()), None, None);
backend.save_snapshot(&snapshot).await.unwrap();
let result = backend
.store_signal(
"test-123",
SignalKind::Pause,
SignalRequest::new(None, None),
)
.await;
assert!(
matches!(result, Err(BackendError::CannotPause(_))),
"should return CannotPause for cancelled workflow"
);
}
#[tokio::test]
async fn test_store_signal_pause_already_paused_idempotent() {
let backend = InMemoryBackend::new();
let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
snapshot.mark_paused(&PauseRequest::new(Some("first".to_string()), None));
backend.save_snapshot(&snapshot).await.unwrap();
let result = backend
.store_signal(
"test-123",
SignalKind::Pause,
SignalRequest::new(Some("second".to_string()), None),
)
.await;
assert!(
result.is_ok(),
"pausing already-paused workflow should be idempotent"
);
}
#[tokio::test]
async fn test_store_signal_pause_not_found() {
let backend = InMemoryBackend::new();
let result = backend
.store_signal(
"nonexistent",
SignalKind::Pause,
SignalRequest::new(None, None),
)
.await;
assert!(
matches!(result, Err(BackendError::NotFound(_))),
"should return NotFound for non-existent workflow"
);
}
#[tokio::test]
async fn test_check_and_cancel_success() {
let backend = InMemoryBackend::new();
let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
backend.save_snapshot(&snapshot).await.unwrap();
backend
.store_signal(
"test-123",
SignalKind::Cancel,
SignalRequest::new(Some("Timeout".to_string()), Some("system".to_string())),
)
.await
.unwrap();
let result = backend
.check_and_cancel("test-123", Some("task-1"))
.await
.unwrap();
assert!(
result,
"check_and_cancel should return true when cancellation pending"
);
let snapshot = backend.load_snapshot("test-123").await.unwrap();
assert!(
snapshot.state.is_cancelled(),
"workflow should be in cancelled state"
);
let WorkflowSnapshotState::Cancelled {
reason,
cancelled_by,
interrupted_at_task,
..
} = &snapshot.state
else {
panic!("Expected Cancelled state");
};
assert_eq!(reason, &Some("Timeout".to_string()));
assert_eq!(cancelled_by, &Some("system".to_string()));
assert_eq!(interrupted_at_task, &Some("task-1".to_string()));
assert!(
backend
.get_signal("test-123", SignalKind::Cancel)
.await
.unwrap()
.is_none(),
"cancel signal should be cleared after check_and_cancel"
);
}
#[tokio::test]
async fn test_check_and_cancel_no_request() {
let backend = InMemoryBackend::new();
let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
backend.save_snapshot(&snapshot).await.unwrap();
let result = backend.check_and_cancel("test-123", None).await.unwrap();
assert!(
!result,
"check_and_cancel should return false when no cancellation pending"
);
let snapshot = backend.load_snapshot("test-123").await.unwrap();
assert!(
snapshot.state.is_in_progress(),
"workflow should still be in progress"
);
}
#[tokio::test]
async fn test_check_and_cancel_not_in_progress() {
let backend = InMemoryBackend::new();
let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
snapshot.mark_completed(bytes::Bytes::from("done"));
backend.save_snapshot(&snapshot).await.unwrap();
{
let mut signals = backend.signals.write().unwrap();
signals
.entry("test-123".to_string())
.or_default()
.insert(SignalKind::Cancel, SignalRequest::new(None, None));
}
let result = backend.check_and_cancel("test-123", None).await.unwrap();
assert!(
!result,
"check_and_cancel should return false for non-in-progress workflow"
);
let snapshot = backend.load_snapshot("test-123").await.unwrap();
assert!(
snapshot.state.is_completed(),
"workflow should still be completed"
);
}
#[tokio::test]
async fn test_find_available_tasks_skips_cancelled_workflows() {
let backend = InMemoryBackend::new();
let mut snapshot1 = WorkflowSnapshot::new("workflow-1".to_string(), "hash-abc".to_string());
snapshot1.update_position(ExecutionPosition::AtTask {
task_id: "task-1".to_string(),
});
backend.save_snapshot(&snapshot1).await.unwrap();
let mut snapshot2 = WorkflowSnapshot::new("workflow-2".to_string(), "hash-abc".to_string());
snapshot2.update_position(ExecutionPosition::AtTask {
task_id: "task-2".to_string(),
});
backend.save_snapshot(&snapshot2).await.unwrap();
backend
.store_signal(
"workflow-1",
SignalKind::Cancel,
SignalRequest::new(None, None),
)
.await
.unwrap();
let tasks = backend.find_available_tasks("worker-1", 10).await.unwrap();
assert!(
!tasks.iter().any(|t| t.instance_id == "workflow-1"),
"workflow with pending cancellation should be skipped"
);
}
#[tokio::test]
async fn test_check_and_pause_success() {
let backend = InMemoryBackend::new();
let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
backend.save_snapshot(&snapshot).await.unwrap();
backend
.store_signal(
"test-123",
SignalKind::Pause,
SignalRequest::new(Some("maintenance".to_string()), Some("ops".to_string())),
)
.await
.unwrap();
let result = backend.check_and_pause("test-123").await.unwrap();
assert!(
result,
"check_and_pause should return true when pause pending"
);
let snapshot = backend.load_snapshot("test-123").await.unwrap();
assert!(snapshot.state.is_paused(), "workflow should be paused");
let WorkflowSnapshotState::Paused {
reason, paused_by, ..
} = &snapshot.state
else {
panic!("Expected Paused state");
};
assert_eq!(reason, &Some("maintenance".to_string()));
assert_eq!(paused_by, &Some("ops".to_string()));
assert!(
backend
.get_signal("test-123", SignalKind::Pause)
.await
.unwrap()
.is_none(),
"pause signal should be cleared after check_and_pause"
);
}
#[tokio::test]
async fn test_check_and_pause_no_request() {
let backend = InMemoryBackend::new();
let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
backend.save_snapshot(&snapshot).await.unwrap();
let result = backend.check_and_pause("test-123").await.unwrap();
assert!(
!result,
"check_and_pause should return false when no pause pending"
);
let snapshot = backend.load_snapshot("test-123").await.unwrap();
assert!(
snapshot.state.is_in_progress(),
"workflow should still be in progress"
);
}
/// A pause signal that is pending for an already-completed workflow must be
/// ignored: `check_and_pause` returns `false` and the state stays completed.
#[tokio::test]
async fn test_check_and_pause_not_in_progress() {
    const INSTANCE_ID: &str = "test-123";

    let backend = InMemoryBackend::new();
    let mut finished = WorkflowSnapshot::new(INSTANCE_ID.to_string(), "hash-abc".to_string());
    finished.mark_completed(bytes::Bytes::from("done"));
    backend.save_snapshot(&finished).await.unwrap();

    // Plant the pause signal directly in the signal map instead of going
    // through `store_signal`. The write guard is a temporary, so it is
    // released at the end of this statement, before the next await.
    backend
        .signals
        .write()
        .unwrap()
        .entry(INSTANCE_ID.to_string())
        .or_default()
        .insert(SignalKind::Pause, SignalRequest::new(None, None));

    let paused = backend.check_and_pause(INSTANCE_ID).await.unwrap();
    assert!(
        !paused,
        "check_and_pause should return false for non-in-progress workflow"
    );

    let reloaded = backend.load_snapshot(INSTANCE_ID).await.unwrap();
    assert!(
        reloaded.state.is_completed(),
        "workflow should still be completed"
    );
}
/// Pausing must carry over the execution position, the completed-task map
/// and the last completed task id into the `Paused` state.
#[tokio::test]
async fn test_check_and_pause_preserves_position() {
    const INSTANCE_ID: &str = "test-123";

    let backend = InMemoryBackend::new();
    let mut running = WorkflowSnapshot::new(INSTANCE_ID.to_string(), "hash-abc".to_string());
    running.update_position(ExecutionPosition::AtTask {
        task_id: "task-3".to_string(),
    });
    // Record the two finished tasks in order; the second one is expected to
    // be reported as the last completed task below.
    for (task, output) in [("task-1", "out1"), ("task-2", "out2")] {
        running.mark_task_completed(task.to_string(), bytes::Bytes::from(output));
    }
    backend.save_snapshot(&running).await.unwrap();

    let pause_request = SignalRequest::new(None, None);
    backend
        .store_signal(INSTANCE_ID, SignalKind::Pause, pause_request)
        .await
        .unwrap();
    backend.check_and_pause(INSTANCE_ID).await.unwrap();

    let reloaded = backend.load_snapshot(INSTANCE_ID).await.unwrap();
    match &reloaded.state {
        WorkflowSnapshotState::Paused {
            completed_tasks,
            position,
            last_completed_task_id,
            ..
        } => {
            assert_eq!(completed_tasks.len(), 2);
            assert!(completed_tasks.contains_key("task-1"));
            assert!(completed_tasks.contains_key("task-2"));
            assert!(matches!(
                position,
                ExecutionPosition::AtTask { task_id } if task_id == "task-3"
            ));
            assert_eq!(last_completed_task_id, &Some("task-2".to_string()));
        }
        _ => panic!("Expected Paused state"),
    }
}
/// Unpausing a paused workflow must return an in-progress snapshot that
/// keeps its position and completed tasks, and must persist that state.
#[tokio::test]
async fn test_unpause_success() {
    const INSTANCE_ID: &str = "test-123";

    let backend = InMemoryBackend::new();
    let mut paused = WorkflowSnapshot::new(INSTANCE_ID.to_string(), "hash-abc".to_string());
    paused.update_position(ExecutionPosition::AtTask {
        task_id: "task-2".to_string(),
    });
    paused.mark_task_completed("task-1".to_string(), bytes::Bytes::from("out1"));
    let pause_request =
        PauseRequest::new(Some("maintenance".to_string()), Some("ops".to_string()));
    paused.mark_paused(&pause_request);
    backend.save_snapshot(&paused).await.unwrap();

    let resumed = backend.unpause(INSTANCE_ID).await.unwrap();
    assert!(
        resumed.state.is_in_progress(),
        "unpaused workflow should be in progress"
    );

    match &resumed.state {
        WorkflowSnapshotState::InProgress {
            position,
            completed_tasks,
            last_completed_task_id,
        } => {
            assert!(matches!(
                position,
                ExecutionPosition::AtTask { task_id } if task_id == "task-2"
            ));
            assert!(completed_tasks.contains_key("task-1"));
            assert_eq!(last_completed_task_id, &Some("task-1".to_string()));
        }
        _ => panic!("Expected InProgress state"),
    }

    // The stored snapshot must reflect the resumed state as well.
    let stored = backend.load_snapshot(INSTANCE_ID).await.unwrap();
    assert!(stored.state.is_in_progress());
}
/// Unpausing a freshly created (never paused) workflow must fail with
/// `BackendError::CannotPause`.
#[tokio::test]
async fn test_unpause_not_paused_errors() {
    const INSTANCE_ID: &str = "test-123";

    let backend = InMemoryBackend::new();
    let fresh = WorkflowSnapshot::new(INSTANCE_ID.to_string(), "hash-abc".to_string());
    backend.save_snapshot(&fresh).await.unwrap();

    let outcome = backend.unpause(INSTANCE_ID).await;
    assert!(
        matches!(outcome, Err(BackendError::CannotPause(_))),
        "unpause on in-progress workflow should error"
    );
}
/// Unpausing a completed workflow must fail with `BackendError::CannotPause`.
#[tokio::test]
async fn test_unpause_completed_errors() {
    const INSTANCE_ID: &str = "test-123";

    let backend = InMemoryBackend::new();
    let mut finished = WorkflowSnapshot::new(INSTANCE_ID.to_string(), "hash-abc".to_string());
    finished.mark_completed(bytes::Bytes::from("done"));
    backend.save_snapshot(&finished).await.unwrap();

    let outcome = backend.unpause(INSTANCE_ID).await;
    assert!(
        matches!(outcome, Err(BackendError::CannotPause(_))),
        "unpause on completed workflow should error"
    );
}
/// Unpausing an instance id that was never saved must fail with
/// `BackendError::NotFound`.
#[tokio::test]
async fn test_unpause_not_found() {
    let empty_backend = InMemoryBackend::new();

    let outcome = empty_backend.unpause("nonexistent").await;

    assert!(matches!(outcome, Err(BackendError::NotFound(_))));
}
/// When both a cancel and a pause signal are pending and the cancel is
/// processed first, the subsequent pause check must report `false` and the
/// workflow must end up cancelled.
#[tokio::test]
async fn test_cancel_and_pause_simultaneously_cancel_wins() {
    const INSTANCE_ID: &str = "test-123";

    let backend = InMemoryBackend::new();
    let fresh = WorkflowSnapshot::new(INSTANCE_ID.to_string(), "hash-abc".to_string());
    backend.save_snapshot(&fresh).await.unwrap();

    // Store the cancel signal first, then the pause signal.
    for (kind, reason) in [
        (SignalKind::Cancel, "cancel reason"),
        (SignalKind::Pause, "pause reason"),
    ] {
        let request = SignalRequest::new(Some(reason.to_string()), None);
        backend
            .store_signal(INSTANCE_ID, kind, request)
            .await
            .unwrap();
    }

    let cancelled = backend
        .check_and_cancel(INSTANCE_ID, Some("task-1"))
        .await
        .unwrap();
    assert!(cancelled, "cancel should succeed");

    let paused = backend.check_and_pause(INSTANCE_ID).await.unwrap();
    assert!(
        !paused,
        "pause should return false since workflow is already cancelled"
    );

    let reloaded = backend.load_snapshot(INSTANCE_ID).await.unwrap();
    assert!(reloaded.state.is_cancelled());
}
/// Clearing the cancel signal must remove only that signal; a pause signal
/// stored for the same instance must remain retrievable.
#[tokio::test]
async fn test_cancel_signal_independent_of_pause_signal() {
    const INSTANCE_ID: &str = "test-123";

    let backend = InMemoryBackend::new();
    let fresh = WorkflowSnapshot::new(INSTANCE_ID.to_string(), "hash-abc".to_string());
    backend.save_snapshot(&fresh).await.unwrap();

    for (kind, reason) in [(SignalKind::Cancel, "cancel"), (SignalKind::Pause, "pause")] {
        let request = SignalRequest::new(Some(reason.to_string()), None);
        backend
            .store_signal(INSTANCE_ID, kind, request)
            .await
            .unwrap();
    }

    backend
        .clear_signal(INSTANCE_ID, SignalKind::Cancel)
        .await
        .unwrap();

    let cancel_signal = backend
        .get_signal(INSTANCE_ID, SignalKind::Cancel)
        .await
        .unwrap();
    assert!(cancel_signal.is_none());

    let pause_signal = backend
        .get_signal(INSTANCE_ID, SignalKind::Pause)
        .await
        .unwrap();
    assert!(pause_signal.is_some());
}
/// A workflow with a pending pause signal must be left out of the available
/// task list, while a workflow without signals must still be offered.
#[tokio::test]
async fn test_find_available_tasks_skips_paused_workflows() {
    let backend = InMemoryBackend::new();

    // Two workflows, each positioned at its own task, saved in this order.
    for (instance, task, input_byte) in [
        ("workflow-1", "task-1", 1u8),
        ("workflow-2", "task-2", 2u8),
    ] {
        let mut workflow = WorkflowSnapshot::with_initial_input(
            instance.to_string(),
            "hash-abc".to_string(),
            bytes::Bytes::from(vec![input_byte]),
        );
        workflow.update_position(ExecutionPosition::AtTask {
            task_id: task.to_string(),
        });
        backend.save_snapshot(&workflow).await.unwrap();
    }

    // Only the first workflow receives a pause signal.
    let pause_request = SignalRequest::new(None, None);
    backend
        .store_signal("workflow-1", SignalKind::Pause, pause_request)
        .await
        .unwrap();

    let available = backend.find_available_tasks("worker-1", 10).await.unwrap();
    assert!(
        available.iter().all(|t| t.instance_id != "workflow-1"),
        "workflow with pending pause should be skipped"
    );
    assert!(
        available.iter().any(|t| t.instance_id == "workflow-2"),
        "workflow without signals should be available"
    );
}
/// Pins the current behaviour that deleting a snapshot does not remove
/// signals stored for the same instance id.
#[tokio::test]
async fn test_delete_snapshot_leaves_orphaned_signals() {
    const INSTANCE_ID: &str = "test-123";

    let backend = InMemoryBackend::new();
    let fresh = WorkflowSnapshot::new(INSTANCE_ID.to_string(), "hash-abc".to_string());
    backend.save_snapshot(&fresh).await.unwrap();

    let cancel_request = SignalRequest::new(Some("reason".to_string()), None);
    backend
        .store_signal(INSTANCE_ID, SignalKind::Cancel, cancel_request)
        .await
        .unwrap();

    backend.delete_snapshot(INSTANCE_ID).await.unwrap();

    let leftover = backend
        .get_signal(INSTANCE_ID, SignalKind::Cancel)
        .await
        .unwrap();
    assert!(
        leftover.is_some(),
        "signal persists after snapshot deletion (orphaned)"
    );
}
/// Storing a second signal of the same kind must replace the first one.
#[tokio::test]
async fn test_store_signal_overwrites_previous() {
    const INSTANCE_ID: &str = "test-123";

    let backend = InMemoryBackend::new();
    let fresh = WorkflowSnapshot::new(INSTANCE_ID.to_string(), "hash-abc".to_string());
    backend.save_snapshot(&fresh).await.unwrap();

    // Two cancel signals in a row; only the later reason should survive.
    for reason in ["first", "second"] {
        let request = SignalRequest::new(Some(reason.to_string()), None);
        backend
            .store_signal(INSTANCE_ID, SignalKind::Cancel, request)
            .await
            .unwrap();
    }

    let stored = backend
        .get_signal(INSTANCE_ID, SignalKind::Cancel)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(
        stored.reason,
        Some("second".to_string()),
        "latest signal should overwrite previous"
    );
}
/// A workflow waiting at a delay whose wake time is one hour in the future
/// must not show up in the available task list.
#[tokio::test]
async fn test_find_available_tasks_skips_unexpired_delay() {
    let backend = InMemoryBackend::new();
    let mut waiting = WorkflowSnapshot::with_initial_input(
        "workflow-1".to_string(),
        "hash-abc".to_string(),
        bytes::Bytes::from(vec![42]),
    );

    let wake_time = Utc::now() + chrono::Duration::hours(1);
    let delay_position = ExecutionPosition::AtDelay {
        delay_id: "wait_1h".to_string(),
        entered_at: Utc::now(),
        wake_at: wake_time,
        next_task_id: Some("next_step".to_string()),
    };
    waiting.update_position(delay_position);
    waiting.mark_task_completed("wait_1h".to_string(), bytes::Bytes::from(vec![42]));
    backend.save_snapshot(&waiting).await.unwrap();

    let available = backend.find_available_tasks("worker-1", 10).await.unwrap();
    assert!(
        available.is_empty(),
        "workflow at unexpired delay should not appear in available tasks"
    );
}
/// When a delay has already expired and names a next task,
/// `find_available_tasks` must offer that task and the stored snapshot must
/// be positioned at it.
#[tokio::test]
async fn test_find_available_tasks_advances_expired_delay() {
    let backend = InMemoryBackend::new();
    let mut waiting = WorkflowSnapshot::with_initial_input(
        "workflow-1".to_string(),
        "hash-abc".to_string(),
        bytes::Bytes::from(vec![42]),
    );

    // Wake time one second in the past, so the delay counts as expired.
    let wake_time = Utc::now() - chrono::Duration::seconds(1);
    let delay_position = ExecutionPosition::AtDelay {
        delay_id: "wait_done".to_string(),
        entered_at: Utc::now() - chrono::Duration::seconds(2),
        wake_at: wake_time,
        next_task_id: Some("process".to_string()),
    };
    waiting.update_position(delay_position);
    waiting.mark_task_completed("wait_done".to_string(), bytes::Bytes::from(vec![42]));
    backend.save_snapshot(&waiting).await.unwrap();

    let available = backend.find_available_tasks("worker-1", 10).await.unwrap();
    assert_eq!(available.len(), 1);
    let offered = &available[0];
    assert_eq!(offered.instance_id, "workflow-1");
    assert_eq!(offered.task_id, "process");

    let loaded = backend.load_snapshot("workflow-1").await.unwrap();
    let other = &loaded.state;
    let WorkflowSnapshotState::InProgress {
        position: ExecutionPosition::AtTask { task_id },
        ..
    } = other
    else {
        panic!("Expected AtTask position, got {other:?}");
    };
    assert_eq!(task_id, "process");
}
/// When an expired delay has no next task, `find_available_tasks` must
/// return nothing and the stored workflow must be marked completed.
#[tokio::test]
async fn test_find_available_tasks_completes_expired_delay_last_node() {
    let backend = InMemoryBackend::new();
    let mut waiting = WorkflowSnapshot::with_initial_input(
        "workflow-1".to_string(),
        "hash-abc".to_string(),
        bytes::Bytes::from(vec![42]),
    );

    // Wake time one second in the past and no follow-up task.
    let wake_time = Utc::now() - chrono::Duration::seconds(1);
    let delay_position = ExecutionPosition::AtDelay {
        delay_id: "final_wait".to_string(),
        entered_at: Utc::now() - chrono::Duration::seconds(2),
        wake_at: wake_time,
        next_task_id: None,
    };
    waiting.update_position(delay_position);
    waiting.mark_task_completed("final_wait".to_string(), bytes::Bytes::from(vec![42]));
    backend.save_snapshot(&waiting).await.unwrap();

    let available = backend.find_available_tasks("worker-1", 10).await.unwrap();
    assert!(
        available.is_empty(),
        "completed workflow should not appear in available tasks"
    );

    let reloaded = backend.load_snapshot("workflow-1").await.unwrap();
    assert!(
        reloaded.state.is_completed(),
        "workflow should be completed when delay is last node and expired"
    );
}
/// Events sent to the same workflow and signal name must be consumed in the
/// order they were sent, and consuming past the end must yield `None`.
#[tokio::test]
async fn test_send_and_consume_event_fifo() {
    const PAYLOADS: [&str; 2] = ["first", "second"];

    let backend = InMemoryBackend::new();
    for payload in PAYLOADS {
        backend
            .send_event("wf-1", "approval", bytes::Bytes::from(payload))
            .await
            .unwrap();
    }

    // Consumption order must mirror the send order.
    for expected in PAYLOADS {
        let received = backend.consume_event("wf-1", "approval").await.unwrap();
        assert_eq!(received.as_deref(), Some(expected.as_bytes()));
    }

    let drained = backend.consume_event("wf-1", "approval").await.unwrap();
    assert!(drained.is_none());
}
/// Consuming an event that was never sent must succeed and yield `None`.
#[tokio::test]
async fn test_consume_event_empty_returns_none() {
    let empty_backend = InMemoryBackend::new();

    let consumed = empty_backend
        .consume_event("wf-1", "nonexistent")
        .await
        .unwrap();

    assert!(consumed.is_none());
}
/// Events sent under different signal names for the same workflow must be
/// kept apart: each name only ever yields its own payload.
#[tokio::test]
async fn test_events_are_isolated_by_signal_name() {
    let backend = InMemoryBackend::new();
    backend
        .send_event("wf-1", "sig_a", bytes::Bytes::from("a_payload"))
        .await
        .unwrap();
    backend
        .send_event("wf-1", "sig_b", bytes::Bytes::from("b_payload"))
        .await
        .unwrap();

    // Consume in the opposite order from sending. If both names shared one
    // FIFO per workflow, consuming in send order would still pass; reading
    // `sig_b` first would then surface "a_payload" and fail here.
    let b = backend.consume_event("wf-1", "sig_b").await.unwrap();
    assert_eq!(b.as_deref(), Some(b"b_payload".as_slice()));
    let a = backend.consume_event("wf-1", "sig_a").await.unwrap();
    assert_eq!(a.as_deref(), Some(b"a_payload".as_slice()));

    // Each name held exactly one event, so both must now be drained.
    let a_again = backend.consume_event("wf-1", "sig_a").await.unwrap();
    assert!(
        a_again.is_none(),
        "sig_a should be empty after its single event is consumed"
    );
    let b_again = backend.consume_event("wf-1", "sig_b").await.unwrap();
    assert!(
        b_again.is_none(),
        "sig_b should be empty after its single event is consumed"
    );
}
}