use crate::provider_validation::{ExecutionMetadata, start_item};
use crate::provider_validations::ProviderFactory;
use crate::providers::WorkItem;
use std::sync::Arc;
use std::time::Duration;
/// Verifies that fetching an orchestration item takes an exclusive lock on its instance.
///
/// While the first lock is held, a second fetch must return nothing. After waiting
/// slightly longer than `lock_timeout`, the instance must be fetchable again, and the
/// new fetch must carry a lock token different from the first one.
pub async fn test_exclusive_instance_lock<F: ProviderFactory>(factory: &F) {
    tracing::info!("→ Testing instance locking: exclusive lock acquisition");
    let provider = factory.create_provider().await;
    let lock_timeout = factory.lock_timeout();

    provider
        .enqueue_for_orchestrator(start_item("instance-A"), None)
        .await
        .unwrap();

    let first_fetch = provider
        .fetch_orchestration_item(lock_timeout, Duration::ZERO, None)
        .await
        .unwrap();
    let (_first_item, first_token, _first_attempts) = first_fetch.unwrap();

    // The instance is locked by the fetch above, so this one must come back empty.
    let blocked_fetch = provider
        .fetch_orchestration_item(lock_timeout, Duration::ZERO, None)
        .await
        .unwrap();
    assert!(blocked_fetch.is_none());

    // Give the first lock time to lapse (timeout plus a small safety margin).
    let expiry_wait = lock_timeout + Duration::from_millis(100);
    tokio::time::sleep(expiry_wait).await;

    let second_fetch = provider
        .fetch_orchestration_item(lock_timeout, Duration::ZERO, None)
        .await
        .unwrap();
    let (_second_item, second_token, _second_attempts) = second_fetch.unwrap();
    assert_ne!(second_token, first_token);
    tracing::info!("✓ Test passed: exclusive lock verified");
}
/// Verifies that every successful fetch hands out a distinct lock token.
///
/// Five independent instances are enqueued and then fetched one after another.
/// The five resulting lock tokens must all be different.
pub async fn test_lock_token_uniqueness<F: ProviderFactory>(factory: &F) {
    tracing::info!("→ Testing instance locking: lock token uniqueness");
    let provider = factory.create_provider().await;

    for idx in 0..5 {
        let instance = format!("inst-{idx}");
        provider
            .enqueue_for_orchestrator(start_item(&instance), None)
            .await
            .unwrap();
    }

    let mut tokens = Vec::with_capacity(5);
    while tokens.len() < 5 {
        let fetched = provider
            .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
            .await
            .unwrap();
        let (_item, token, _attempts) = fetched.unwrap();
        tokens.push(token);
    }

    let distinct = tokens.iter().collect::<std::collections::HashSet<_>>();
    assert_eq!(distinct.len(), 5, "All lock tokens should be unique");
    tracing::info!("✓ Test passed: lock token uniqueness verified");
}
/// Verifies that ack and abandon both reject a lock token the provider never issued.
///
/// A genuine lock is taken on `instance-A` first, so both calls have a real lock to
/// be compared against. After the bogus calls are rejected, a further fetch must still
/// return nothing.
pub async fn test_invalid_lock_token_rejection<F: ProviderFactory>(factory: &F) {
    const BOGUS_TOKEN: &str = "invalid-token";

    tracing::info!("→ Testing instance locking: invalid lock token rejection");
    let provider = factory.create_provider().await;

    provider
        .enqueue_for_orchestrator(start_item("instance-A"), None)
        .await
        .unwrap();
    let _locked = provider
        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
        .await
        .unwrap()
        .unwrap();

    let ack_outcome = provider
        .ack_orchestration_item(
            BOGUS_TOKEN,
            1,
            vec![],
            vec![],
            vec![],
            ExecutionMetadata::default(),
            vec![],
        )
        .await;
    assert!(ack_outcome.is_err(), "Should reject invalid lock token");

    let abandon_outcome = provider
        .abandon_orchestration_item(BOGUS_TOKEN, None, false)
        .await;
    assert!(abandon_outcome.is_err(), "Should reject invalid lock token for abandon");

    // instance-A should still be held by the genuine fetch above, so nothing is fetchable.
    let refetch = provider
        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
        .await
        .unwrap();
    assert!(refetch.is_none());
    tracing::info!("✓ Test passed: invalid lock token rejection verified");
}
/// Verifies that concurrent fetchers never receive the same instance twice.
///
/// Ten instances are enqueued. Ten spawned tasks each perform one fetch, staggered by
/// 30 ms per task. The set of fetched instance names must contain all ten instances.
pub async fn test_concurrent_instance_fetching<F: ProviderFactory>(factory: &F) {
    tracing::info!("→ Testing instance locking: concurrent fetch attempts");
    let provider = Arc::new(factory.create_provider().await);
    for n in 0..10 {
        provider
            .enqueue_for_orchestrator(start_item(&format!("inst-{n}")), None)
            .await
            .unwrap();
    }

    let mut handles = Vec::new();
    for task_idx in 0..10u64 {
        let worker = Arc::clone(&provider);
        handles.push(tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(task_idx * 30)).await;
            worker
                .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
                .await
                .unwrap()
        }));
    }

    let mut instances = std::collections::HashSet::new();
    for joined in futures::future::join_all(handles).await {
        if let Some((item, _lock_token, _attempt_count)) = joined.unwrap() {
            instances.insert(item.instance);
        }
    }
    assert_eq!(instances.len(), 10, "Each instance should be fetched exactly once");
    tracing::info!("✓ Test passed: concurrent fetching verified");
}
/// Verifies that messages enqueued while an instance is locked wait for the lock.
///
/// After `instance-A` is fetched, three `ActivityCompleted` messages arrive for it.
/// They must not make the instance fetchable while the lock is held. Once the lock
/// expires, a single fetch must deliver four messages. Exactly three of them are
/// completions, and their ids appear in enqueue order (1, 2, 3).
pub async fn test_completions_arriving_during_lock_blocked<F: ProviderFactory>(factory: &F) {
    tracing::info!("→ Testing instance locking: completions arriving during lock blocked");
    let provider = Arc::new(factory.create_provider().await);
    let lock_timeout = factory.lock_timeout();

    provider
        .enqueue_for_orchestrator(start_item("instance-A"), None)
        .await
        .unwrap();
    let (first_item, _first_token, _first_attempts) = provider
        .fetch_orchestration_item(lock_timeout, Duration::ZERO, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(first_item.instance, "instance-A");

    // These completions arrive while instance-A is still locked.
    for completion_id in 1..=3 {
        let completion = WorkItem::ActivityCompleted {
            instance: "instance-A".to_string(),
            execution_id: 1,
            id: completion_id,
            result: format!("result-{completion_id}"),
        };
        provider
            .enqueue_for_orchestrator(completion, None)
            .await
            .unwrap();
    }

    let while_locked = provider
        .fetch_orchestration_item(lock_timeout, Duration::ZERO, None)
        .await
        .unwrap();
    assert!(while_locked.is_none(), "Instance still locked, no fetch possible");

    // Let the first lock expire before fetching again.
    tokio::time::sleep(lock_timeout + Duration::from_millis(100)).await;

    let (redelivered, _second_token, _second_attempts) = provider
        .fetch_orchestration_item(lock_timeout, Duration::ZERO, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(redelivered.instance, "instance-A");
    assert_eq!(
        redelivered.messages.len(),
        4,
        "Should have StartOrchestration + 3 completions"
    );

    let mut completion_ids = Vec::new();
    for message in redelivered.messages.iter() {
        if let WorkItem::ActivityCompleted { id, .. } = message {
            completion_ids.push(*id);
        }
    }
    assert_eq!(
        completion_ids.len(),
        3,
        "Should have 3 ActivityCompleted messages"
    );
    assert_eq!(completion_ids, vec![1, 2, 3]);
    tracing::info!("✓ Test passed: completions during lock blocked verified");
}
/// Verifies that a lock on one instance does not interfere with another instance.
///
/// `instance-A` and `instance-B` are both fetched, which locks both. Instance B is
/// then acked and receives a new `ActivityCompleted` message. It must be fetchable
/// again while instance A's lock is still held.
pub async fn test_cross_instance_lock_isolation<F: ProviderFactory>(factory: &F) {
    tracing::info!("→ Testing instance locking: cross-instance lock isolation");
    let provider = Arc::new(factory.create_provider().await);

    for name in ["instance-A", "instance-B"] {
        crate::provider_validation::create_instance((*provider).as_ref(), name)
            .await
            .unwrap();
    }
    for name in ["instance-A", "instance-B"] {
        provider
            .enqueue_for_orchestrator(start_item(name), None)
            .await
            .unwrap();
    }

    let (item_a, _lock_token_a, _attempt_count_a) = provider
        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(item_a.instance, "instance-A");

    let (item_b, lock_token_b, _attempt_count_b) = provider
        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(item_b.instance, "instance-B");

    // Release instance B; instance A stays locked throughout.
    let metadata = ExecutionMetadata {
        orchestration_name: Some("TestOrch".to_string()),
        orchestration_version: Some("1.0.0".to_string()),
        ..Default::default()
    };
    provider
        .ack_orchestration_item(&lock_token_b, 1, vec![], vec![], vec![], metadata, vec![])
        .await
        .unwrap();

    let follow_up = WorkItem::ActivityCompleted {
        instance: "instance-B".to_string(),
        execution_id: 1,
        id: 1,
        result: "done".to_string(),
    };
    provider
        .enqueue_for_orchestrator(follow_up, None)
        .await
        .unwrap();

    let (item_b2, _lock_token_b2, _attempt_count_b2) = provider
        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(item_b2.instance, "instance-B");
    assert_ne!(item_a.instance, item_b.instance);
    tracing::info!("✓ Test passed: cross-instance lock isolation verified");
}
/// Verifies that a message arriving while an instance is locked is not consumed by the ack.
///
/// Two completions are fetched together under one lock. A third arrives while that
/// lock is held. Acking the lock must consume only the two delivered messages, so the
/// next fetch returns just the third one (id 3).
pub async fn test_message_tagging_during_lock<F: ProviderFactory>(factory: &F) {
    tracing::info!("→ Testing instance locking: message tagging during lock");
    let provider = Arc::new(factory.create_provider().await);
    crate::provider_validation::create_instance((*provider).as_ref(), "instance-A")
        .await
        .unwrap();

    // Builds an ActivityCompleted message for instance-A, execution 1.
    let completion = |id: u64, result: &str| WorkItem::ActivityCompleted {
        instance: "instance-A".to_string(),
        execution_id: 1,
        id,
        result: result.to_string(),
    };

    for (id, payload) in [(1, "msg1"), (2, "msg2")] {
        provider
            .enqueue_for_orchestrator(completion(id, payload), None)
            .await
            .unwrap();
    }

    let (batch, lock_token, _attempt_count) = provider
        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(batch.instance, "instance-A");
    assert_eq!(batch.messages.len(), 2);

    // Arrives while the lock is still held; it must still be pending after the ack below.
    provider
        .enqueue_for_orchestrator(completion(3, "msg3"), None)
        .await
        .unwrap();

    provider
        .ack_orchestration_item(
            &lock_token,
            1,
            vec![],
            vec![],
            vec![],
            ExecutionMetadata::default(),
            vec![],
        )
        .await
        .unwrap();

    let (leftover, _lock_token2, _attempt_count2) = provider
        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(leftover.instance, "instance-A");
    assert_eq!(leftover.messages.len(), 1);
    assert!(matches!(&leftover.messages[0], WorkItem::ActivityCompleted { id: 3, .. }));
    tracing::info!("✓ Test passed: message tagging during lock verified");
}
/// Verifies that acking a lock removes only the messages delivered under that lock.
///
/// Message 1 is fetched, which locks `instance-A`. Messages 2 and 3 arrive afterwards
/// and must not be fetchable while the lock is held. After the ack, the next fetch
/// must return exactly messages 2 and 3, in that order.
pub async fn test_ack_only_affects_locked_messages<F: ProviderFactory>(factory: &F) {
    tracing::info!("→ Testing instance locking: ack only affects locked messages");
    let provider = Arc::new(factory.create_provider().await);
    crate::provider_validation::create_instance((*provider).as_ref(), "instance-A")
        .await
        .unwrap();

    // Builds an ActivityCompleted message for instance-A, execution 1.
    let completion = |id: u64, result: &str| WorkItem::ActivityCompleted {
        instance: "instance-A".to_string(),
        execution_id: 1,
        id,
        result: result.to_string(),
    };

    provider
        .enqueue_for_orchestrator(completion(1, "msg1"), None)
        .await
        .unwrap();
    let (first, lock_token, _attempt_count) = provider
        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(first.messages.len(), 1);

    // Both of these arrive after the lock was taken.
    for (id, payload) in [(2, "msg2"), (3, "msg3")] {
        provider
            .enqueue_for_orchestrator(completion(id, payload), None)
            .await
            .unwrap();
    }
    let blocked = provider
        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
        .await
        .unwrap();
    assert!(blocked.is_none());

    provider
        .ack_orchestration_item(
            &lock_token,
            1,
            vec![],
            vec![],
            vec![],
            ExecutionMetadata::default(),
            vec![],
        )
        .await
        .unwrap();

    let (second, _lock_token2, _attempt_count2) = provider
        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(second.instance, "instance-A");
    assert_eq!(second.messages.len(), 2, "Should have messages 2 and 3");

    let ids: Vec<u64> = second
        .messages
        .iter()
        .filter_map(|msg| {
            if let WorkItem::ActivityCompleted { id, .. } = msg {
                Some(*id)
            } else {
                None
            }
        })
        .collect();
    assert_eq!(ids, vec![2, 3], "Should only have messages 2 and 3");
    tracing::info!("✓ Test passed: ack only affects locked messages verified");
}
/// Verifies that when many tasks race for a single instance, exactly one wins.
///
/// Ten spawned tasks each attempt one fetch of `contention-instance`, staggered by
/// 5 ms per task. Exactly one fetch may succeed. After the race, a follow-up fetch
/// must still return nothing, because the winner holds the lock.
pub async fn test_multi_threaded_lock_contention<F: ProviderFactory>(factory: &F) {
    tracing::info!("→ Testing instance locking: multi-threaded lock contention");
    let provider = Arc::new(factory.create_provider().await);
    provider
        .enqueue_for_orchestrator(start_item("contention-instance"), None)
        .await
        .unwrap();

    let num_threads: u64 = 10;
    let mut handles = Vec::new();
    for thread_id in 0..num_threads {
        let worker = Arc::clone(&provider);
        handles.push(tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(thread_id * 5)).await;
            let fetched = worker
                .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
                .await
                .unwrap();
            (thread_id, fetched)
        }));
    }

    // Keep (thread id, instance, token) for every task whose fetch returned an item.
    let mut winners = Vec::new();
    for joined in futures::future::join_all(handles).await {
        let (thread_id, fetched) = joined.unwrap();
        if let Some((item, lock_token, _attempt_count)) = fetched {
            winners.push((thread_id, item.instance, lock_token));
        }
    }

    assert_eq!(
        winners.len(),
        1,
        "Only one thread should successfully acquire lock for the same instance"
    );
    let (winner_thread, winner_instance, _winner_token) = &winners[0];
    assert_eq!(winner_instance, "contention-instance");

    let after_race = provider
        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
        .await
        .unwrap();
    assert!(after_race.is_none());
    tracing::info!(
        "✓ Test passed: multi-threaded lock contention verified (thread {} won)",
        winner_thread
    );
}
/// Verifies that many concurrent fetchers never process the same instance twice.
///
/// Twenty instances are enqueued, and twice as many spawned tasks each try one fetch.
/// A fetch is retried up to two more times on errors the provider marks `retryable`.
/// Errors that survive the retries are tolerated, because this test is about
/// duplicates, not availability. The successful fetches must:
/// * be non-empty, so the duplicate check cannot pass vacuously;
/// * contain no instance more than once;
/// * number no more than the enqueued instances;
/// * name only instances that were actually enqueued.
pub async fn test_multi_threaded_no_duplicate_processing<F: ProviderFactory>(factory: &F) {
    tracing::info!("→ Testing instance locking: multi-threaded no duplicate processing");
    let provider = Arc::new(factory.create_provider().await);
    let num_instances: usize = 20;

    // Enqueue in a fixed order; also keep the names for the membership check below.
    let enqueued: Vec<String> = (0..num_instances).map(|i| format!("dup-test-{i}")).collect();
    for name in &enqueued {
        provider
            .enqueue_for_orchestrator(start_item(name), None)
            .await
            .unwrap();
    }
    let expected_instances: std::collections::HashSet<String> = enqueued.into_iter().collect();

    let num_threads = num_instances * 2;
    let handles: Vec<_> = (0..num_threads)
        .map(|i| {
            let p = Arc::clone(&provider);
            tokio::spawn(async move {
                // Stagger start times (0..50 ms) so fetches overlap without firing in lockstep.
                let delay = (i * 3) % 50;
                tokio::time::sleep(Duration::from_millis(delay as u64)).await;
                for attempt in 0..3 {
                    match p
                        .fetch_orchestration_item(Duration::from_secs(30), Duration::ZERO, None)
                        .await
                    {
                        Ok(item) => {
                            return Ok(item.map(|(item, _lock_token, _attempt_count)| item.instance));
                        }
                        Err(e) if e.retryable && attempt < 2 => {
                            // Short linear backoff before retrying a transient failure.
                            tokio::time::sleep(Duration::from_millis(10 * (attempt + 1) as u64)).await;
                        }
                        Err(e) => return Err(e),
                    }
                }
                unreachable!("the final attempt always returns")
            })
        })
        .collect();

    let mut results = Vec::new();
    let mut failed_fetches = 0usize;
    for outcome in futures::future::join_all(handles).await {
        match outcome.unwrap() {
            Ok(Some(instance)) => results.push(instance),
            Ok(None) => {}
            // Tolerated under contention; counted only for the diagnostic message below.
            Err(_) => failed_fetches += 1,
        }
    }

    assert!(
        !results.is_empty(),
        "At least one instance should be fetched ({failed_fetches} fetch attempts failed)"
    );
    let fetched_instances: std::collections::HashSet<_> = results.iter().collect();
    assert_eq!(
        fetched_instances.len(),
        results.len(),
        "No duplicate instances should be fetched"
    );
    assert!(results.len() <= num_instances, "Cannot fetch more instances than exist");
    assert!(
        results.iter().all(|name| expected_instances.contains(name)),
        "All fetched instances should be ones that were enqueued"
    );
    tracing::info!(
        "✓ Test passed: multi-threaded no duplicate processing verified ({} instances fetched by {} threads)",
        fetched_instances.len(),
        num_threads
    );
}
/// Verifies that an expired lock can be re-acquired by another worker with a fresh token.
///
/// Three tasks are released together by a barrier:
/// 1. Fetches (locks) `expiration-instance`, then sleeps past `lock_timeout` without
///    acking or abandoning.
/// 2. Fetches 200 ms later and must find the instance still locked. This assumes
///    `lock_timeout` comfortably exceeds 200 ms.
/// 3. Fetches after `lock_timeout + 100 ms`, when the first lock should have expired.
///    It must obtain the instance with a lock token different from task 1's token.
pub async fn test_multi_threaded_lock_expiration_recovery<F: ProviderFactory>(factory: &F) {
    tracing::info!("→ Testing instance locking: multi-threaded lock expiration recovery");
    let provider = Arc::new(factory.create_provider().await);
    provider
        .enqueue_for_orchestrator(start_item("expiration-instance"), None)
        .await
        .unwrap();
    let lock_timeout = factory.lock_timeout();
    let barrier = Arc::new(tokio::sync::Barrier::new(3));

    // Task 1: take the lock, then stall past its expiry.
    let provider1 = provider.clone();
    let barrier1 = barrier.clone();
    let handle1 = tokio::spawn(async move {
        barrier1.wait().await;
        let (item, lock_token, _attempt_count) = provider1
            .fetch_orchestration_item(lock_timeout, Duration::ZERO, None)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(item.instance, "expiration-instance");
        tokio::time::sleep(lock_timeout + Duration::from_millis(200)).await;
        lock_token
    });

    // Task 2: try while task 1's lock is still live.
    let provider2 = provider.clone();
    let barrier2 = barrier.clone();
    let handle2 = tokio::spawn(async move {
        barrier2.wait().await;
        tokio::time::sleep(Duration::from_millis(200)).await;
        let result = provider2
            .fetch_orchestration_item(lock_timeout, Duration::ZERO, None)
            .await
            .unwrap();
        assert!(result.is_none(), "Instance should be locked");
        result
    });

    // Task 3: try after task 1's lock has expired.
    let provider3 = provider.clone();
    let barrier3 = barrier.clone();
    let handle3 = tokio::spawn(async move {
        barrier3.wait().await;
        tokio::time::sleep(lock_timeout + Duration::from_millis(100)).await;
        provider3
            .fetch_orchestration_item(lock_timeout, Duration::ZERO, None)
            .await
            .unwrap()
    });

    let (lock_token1, result2, result3) = futures::future::join3(handle1, handle2, handle3).await;
    let lock_token1 = lock_token1.unwrap();
    assert!(result2.unwrap().is_none());
    let (item3, lock_token3, _attempt_count3) = result3.unwrap().unwrap();
    assert_eq!(item3.instance, "expiration-instance");
    // Compare against the token that actually expired. The previous comparison with a
    // fixed literal ("expired-token") could never fail.
    assert_ne!(
        lock_token3, lock_token1,
        "Re-acquired lock must carry a fresh token, not the expired one"
    );
    tracing::info!("✓ Test passed: multi-threaded lock expiration recovery verified");
}