use anyhow::bail;
use parking_lot::RwLock;
use rand::seq::SliceRandom;
use std::{
collections::{
HashMap,
hash_map::Entry::{Occupied, Vacant},
},
sync::Arc,
};
use temporalio_common::{
protos::temporal::api::{
worker::v1::WorkerHeartbeat, workflowservice::v1::PollWorkflowTaskQueueResponse,
},
worker::{WorkerDeploymentOptions, WorkerTaskTypes},
};
use uuid::Uuid;
/// A workflow task slot reserved on a worker (see [`ClientWorker::try_reserve_wft_slot`]).
#[cfg_attr(test, mockall::automock)]
pub trait Slot {
    /// Schedules the polled workflow task `task` into this slot.
    ///
    /// Takes `self: Box<Self>`, so the slot is consumed by the call and can be used to
    /// schedule at most one task.
    ///
    /// # Errors
    /// Returns an error if the task could not be scheduled; the exact conditions are up to the
    /// implementation.
    fn schedule_wft(
        self: Box<Self>,
        task: PollWorkflowTaskQueueResponse,
    ) -> Result<(), anyhow::Error>;
}
/// A successfully reserved workflow task slot, as returned by
/// `ClientWorkerSet::try_reserve_wft_slot`.
pub(crate) struct SlotReservation {
    /// The reserved slot; consume it with [`Slot::schedule_wft`].
    pub slot: Box<dyn Slot + Send>,
    /// Deployment options of the worker the slot was reserved from (`None` if that worker
    /// reports none).
    pub deployment_options: Option<WorkerDeploymentOptions>,
}
/// Registry key: workers are grouped by the (namespace, task queue) pair they are registered on.
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
struct SlotKey {
    /// Namespace half of the key.
    namespace: String,
    /// Task-queue half of the key.
    task_queue: String,
}
impl SlotKey {
    /// Builds a key from an owned namespace and task queue name.
    fn new(namespace: String, task_queue: String) -> SlotKey {
        Self {
            task_queue,
            namespace,
        }
    }
}
/// What the registry remembers about each worker listed under a [`SlotKey`].
#[derive(Debug, Clone)]
struct RegisteredWorkerInfo {
    /// The worker's instance key (also its key in `ClientWorkerSetImpl::all_workers`).
    worker_id: Uuid,
    /// Deployment build ID, if the worker reported deployment options.
    build_id: Option<String>,
    /// Task types the worker handles; used for WFT routing and overlap checks.
    task_types: WorkerTaskTypes,
}
impl RegisteredWorkerInfo {
    /// Bundles the three registry fields into a new entry.
    fn new(worker_id: Uuid, build_id: Option<String>, task_types: WorkerTaskTypes) -> Self {
        RegisteredWorkerInfo {
            task_types,
            build_id,
            worker_id,
        }
    }
}
/// Inner state of `ClientWorkerSet`, which keeps it behind a `RwLock`.
struct ClientWorkerSetImpl {
    /// Registered workers grouped by (namespace, task queue). Consulted when reserving workflow
    /// task slots and when checking new registrations for overlapping task types.
    slot_providers: HashMap<SlotKey, Vec<RegisteredWorkerInfo>>,
    /// Every registered worker, keyed by its worker instance key.
    all_workers: HashMap<Uuid, Arc<dyn ClientWorker + Send + Sync>>,
    /// One shared namespace worker per namespace name, holding the heartbeat callbacks of that
    /// namespace's heartbeat-enabled workers.
    shared_worker: HashMap<String, Box<dyn SharedNamespaceWorkerTrait + Send + Sync>>,
}
impl ClientWorkerSetImpl {
    /// Creates an empty registry with no slot providers, workers, or shared namespace workers.
    fn new() -> Self {
        Self {
            slot_providers: Default::default(),
            all_workers: Default::default(),
            shared_worker: Default::default(),
        }
    }

    /// Attempts to reserve a workflow task slot from a worker registered on
    /// `namespace`/`task_queue`.
    ///
    /// Only workers whose task types enable workflows are considered. Candidates are tried in
    /// [`Self::worker_ids_in_selection_order`] order. The first worker that hands out a slot wins,
    /// and its deployment options are returned alongside the slot. Returns `None` when no
    /// eligible worker is registered or none of them currently has a slot.
    fn try_reserve_wft_slot(
        &self,
        namespace: String,
        task_queue: String,
    ) -> Option<SlotReservation> {
        let key = SlotKey::new(namespace, task_queue);
        let worker_list = self.slot_providers.get(&key)?;
        let workflow_workers: Vec<&RegisteredWorkerInfo> = worker_list
            .iter()
            .filter(|info| info.task_types.enable_workflows)
            .collect();
        Self::worker_ids_in_selection_order(&workflow_workers)
            .into_iter()
            .find_map(|worker_id| {
                // A missing worker or a worker without a free slot just moves on to the next one.
                let worker = self.all_workers.get(&worker_id)?;
                let slot = worker.try_reserve_wft_slot()?;
                Some(SlotReservation {
                    slot,
                    deployment_options: worker.deployment_options(),
                })
            })
    }

    /// Returns the ids of `worker_list` in the order slots should be requested from them.
    ///
    /// The order is randomized to spread load across workers. Under `cfg(test)` the original
    /// order is kept, so tests are deterministic.
    fn worker_ids_in_selection_order(worker_list: &[&RegisteredWorkerInfo]) -> Vec<Uuid> {
        let mut ids: Vec<Uuid> = worker_list.iter().map(|info| info.worker_id).collect();
        if !cfg!(test) {
            ids.shuffle(&mut rand::rng());
        }
        ids
    }

    /// Registers `worker` with this set.
    ///
    /// The following checks run in order:
    /// * at least one of workflows / local activities / remote activities / nexus is enabled;
    /// * local activities are only enabled together with workflows;
    /// * unless `skip_client_worker_set_check` is set, no worker already registered on the same
    ///   namespace and task queue has the same deployment build ID and overlapping task types.
    ///
    /// On success, the worker's heartbeat callback is attached to the namespace's shared worker.
    /// This happens only if heartbeating is enabled and a callback is provided, and the shared
    /// worker is created on first use. The worker is then recorded as a slot provider and in
    /// `all_workers`.
    ///
    /// # Errors
    /// Returns an error if any check above fails or if creating the shared namespace worker fails.
    /// In both cases nothing is recorded in the slot-provider or worker maps.
    fn register(
        &mut self,
        worker: Arc<dyn ClientWorker + Send + Sync>,
        skip_client_worker_set_check: bool,
    ) -> Result<(), anyhow::Error> {
        let slot_key = SlotKey::new(
            worker.namespace().to_string(),
            worker.task_queue().to_string(),
        );
        let build_id = worker
            .deployment_options()
            .map(|opts| opts.version.build_id);
        let task_types = worker.worker_task_types();
        if !task_types.enable_workflows
            && !task_types.enable_local_activities
            && !task_types.enable_remote_activities
            && !task_types.enable_nexus
        {
            bail!(
                "Worker must have at least one capability enabled (workflows, activities, or nexus)"
            );
        }
        if !task_types.enable_workflows && task_types.enable_local_activities {
            bail!("Local activities cannot be enabled without workflows")
        }
        // Read once and reused for the error message, heartbeat registration, and registry entries.
        let worker_instance_key = worker.worker_instance_key();
        if !skip_client_worker_set_check
            && let Some(existing_workers) = self.slot_providers.get(&slot_key)
        {
            // Workers may share a namespace/task queue only if they differ in build ID or
            // handle disjoint task types.
            for existing_worker_info in existing_workers {
                if existing_worker_info.build_id == build_id
                    && task_types.overlaps_with(&existing_worker_info.task_types)
                {
                    // Named arguments keep each value next to its label. Positional `{:?}`
                    // arguments here previously printed the two values under each other's labels.
                    bail!(
                        "Registration of multiple workers with overlapping worker task types \
                         on the same namespace, task queue, and deployment build ID not allowed: \
                         {slot_key:?}, worker_instance_key: {worker_instance_key:?}, \
                         build_id: {build_id:?}, \
                         new task types: {task_types:?}, \
                         existing task types: {existing_task_types:?}.",
                        existing_task_types = existing_worker_info.task_types,
                    );
                }
            }
        }
        if worker.heartbeat_enabled()
            && let Some(heartbeat_callback) = worker.heartbeat_callback()
        {
            // The first heartbeating worker of a namespace creates that namespace's shared worker.
            let shared_worker = match self.shared_worker.entry(worker.namespace().to_string()) {
                Occupied(o) => o.into_mut(),
                Vacant(v) => v.insert(worker.new_shared_namespace_worker()?),
            };
            shared_worker.register_callback(worker_instance_key, heartbeat_callback);
        }
        self.slot_providers
            .entry(slot_key)
            .or_default()
            .push(RegisteredWorkerInfo::new(
                worker_instance_key,
                build_id,
                task_types,
            ));
        self.all_workers.insert(worker_instance_key, worker);
        Ok(())
    }

    /// First phase of unregistering: stops `worker_instance_key` from being offered slots.
    ///
    /// Removes the worker from its (namespace, task queue) provider list, dropping the list once
    /// empty. The worker stays in `all_workers` until [`Self::finalize_unregister`] is called.
    ///
    /// # Errors
    /// Returns an error if the worker is not in `all_workers`.
    fn unregister_slot_provider(&mut self, worker_instance_key: Uuid) -> Result<(), anyhow::Error> {
        let worker = self.all_workers.get(&worker_instance_key).ok_or_else(|| {
            anyhow::anyhow!("Worker not in all_workers during slot provider unregister")
        })?;
        let slot_key = SlotKey::new(
            worker.namespace().to_string(),
            worker.task_queue().to_string(),
        );
        if let Some(slot_vec) = self.slot_providers.get_mut(&slot_key) {
            slot_vec.retain(|info| info.worker_id != worker_instance_key);
            if slot_vec.is_empty() {
                self.slot_providers.remove(&slot_key);
            }
        }
        Ok(())
    }

    /// Second phase of unregistering: removes the worker from `all_workers` and returns it.
    ///
    /// The worker's heartbeat callback is also detached from its namespace's shared worker. That
    /// shared worker is dropped when a callback was removed and no callbacks remain.
    ///
    /// # Errors
    /// Returns an error if the worker is still listed as a slot provider (call
    /// [`Self::unregister_slot_provider`] first) or is not registered at all.
    fn finalize_unregister(
        &mut self,
        worker_instance_key: Uuid,
    ) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
        if let Some(worker) = self.all_workers.get(&worker_instance_key)
            && let Some(slot_vec) = self.slot_providers.get(&SlotKey::new(
                worker.namespace().to_string(),
                worker.task_queue().to_string(),
            ))
            && slot_vec
                .iter()
                .any(|info| info.worker_id == worker_instance_key)
        {
            return Err(anyhow::anyhow!(
                "Worker still in slot_providers during finalize"
            ));
        }
        let worker = self
            .all_workers
            .remove(&worker_instance_key)
            .ok_or_else(|| anyhow::anyhow!("Worker not found in all_workers"))?;
        if let Some(w) = self.shared_worker.get_mut(worker.namespace()) {
            let (callback, is_empty) = w.unregister_callback(worker.worker_instance_key());
            if callback.is_some() && is_empty {
                self.shared_worker.remove(worker.namespace());
            }
        }
        Ok(worker)
    }

    /// Total number of slot-provider entries across all (namespace, task queue) keys.
    #[cfg(test)]
    fn num_providers(&self) -> usize {
        self.slot_providers.values().map(|v| v.len()).sum()
    }

    /// Total number of heartbeat callbacks across all shared namespace workers.
    #[cfg(test)]
    fn num_heartbeat_workers(&self) -> usize {
        self.shared_worker.values().map(|v| v.num_workers()).sum()
    }
}
/// A per-namespace worker that collects the heartbeat callbacks of the workers registered on
/// that namespace.
///
/// The registration methods take `&self`, so implementations need interior mutability.
pub trait SharedNamespaceWorkerTrait {
    /// The namespace this shared worker serves.
    fn namespace(&self) -> String;
    /// Adds `heartbeat_callback` for the worker identified by `worker_instance_key`.
    fn register_callback(&self, worker_instance_key: Uuid, heartbeat_callback: HeartbeatCallback);
    /// Removes the callback registered for `worker_instance_key`.
    ///
    /// Returns the removed callback (if any) and whether no callbacks remain afterwards. The
    /// owning worker set drops this shared worker when a callback was removed and none remain.
    fn unregister_callback(&self, worker_instance_key: Uuid) -> (Option<HeartbeatCallback>, bool);
    /// Number of workers whose heartbeat callbacks are currently registered.
    fn num_workers(&self) -> usize;
}
/// Registry of [`ClientWorker`]s. It routes workflow task slot reservations to registered
/// workers and groups heartbeat callbacks per namespace.
///
/// All mutable state lives behind a `parking_lot::RwLock`, so every method takes `&self`.
pub struct ClientWorkerSet {
    /// Random key generated when the set is created; exposed via `worker_grouping_key()`.
    worker_grouping_key: Uuid,
    /// The registry state; reads take the read lock, (un)registration takes the write lock.
    worker_manager: RwLock<ClientWorkerSetImpl>,
}
impl Default for ClientWorkerSet {
fn default() -> Self {
Self::new()
}
}
impl ClientWorkerSet {
    /// Builds an empty worker set with a freshly generated random (v4) grouping key.
    pub fn new() -> Self {
        let worker_grouping_key = Uuid::new_v4();
        let worker_manager = RwLock::new(ClientWorkerSetImpl::new());
        Self {
            worker_grouping_key,
            worker_manager,
        }
    }

    /// Reserves a workflow task slot from some workflow-capable worker registered on
    /// `namespace`/`task_queue`, or returns `None` if none can provide one.
    ///
    /// The read lock is held while the candidate workers are asked for a slot.
    pub(crate) fn try_reserve_wft_slot(
        &self,
        namespace: String,
        task_queue: String,
    ) -> Option<SlotReservation> {
        let registry = self.worker_manager.read();
        registry.try_reserve_wft_slot(namespace, task_queue)
    }

    /// Adds `worker` to the set under the write lock.
    ///
    /// # Errors
    /// Fails on invalid task-type configuration or, unless `skip_client_worker_set_check` is set,
    /// on overlap with an existing worker that has the same namespace, task queue, and build ID.
    pub fn register_worker(
        &self,
        worker: Arc<dyn ClientWorker + Send + Sync>,
        skip_client_worker_set_check: bool,
    ) -> Result<(), anyhow::Error> {
        let mut registry = self.worker_manager.write();
        registry.register(worker, skip_client_worker_set_check)
    }

    /// Stops offering the worker's slots; the first step of unregistering.
    pub fn unregister_slot_provider(&self, worker_instance_key: Uuid) -> Result<(), anyhow::Error> {
        let mut registry = self.worker_manager.write();
        registry.unregister_slot_provider(worker_instance_key)
    }

    /// Removes the worker entirely and hands it back; the second step of unregistering.
    pub fn finalize_unregister(
        &self,
        worker_instance_key: Uuid,
    ) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
        let mut registry = self.worker_manager.write();
        registry.finalize_unregister(worker_instance_key)
    }

    /// The random key assigned to this set at construction.
    pub fn worker_grouping_key(&self) -> Uuid {
        self.worker_grouping_key
    }

    /// Test helper: total slot-provider entries.
    #[cfg(test)]
    pub fn num_providers(&self) -> usize {
        let registry = self.worker_manager.read();
        registry.num_providers()
    }

    /// Test helper: total registered heartbeat callbacks.
    #[cfg(test)]
    pub fn num_heartbeat_workers(&self) -> usize {
        let registry = self.worker_manager.read();
        registry.num_heartbeat_workers()
    }
}
impl std::fmt::Debug for ClientWorkerSet {
    /// Prints only the grouping key. The registered workers are `ClientWorker` trait objects,
    /// which carry no `Debug` bound, so they are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ClientWorkerSet");
        builder.field("worker_grouping_key", &self.worker_grouping_key);
        builder.finish()
    }
}
/// Callback that builds a [`WorkerHeartbeat`] for a worker.
///
/// It is reference-counted (`Arc`) and `Send + Sync`, so clones can be handed to a
/// [`SharedNamespaceWorkerTrait`] and called from any thread.
pub type HeartbeatCallback = Arc<dyn Fn() -> WorkerHeartbeat + Send + Sync>;
/// A worker that can be registered with a [`ClientWorkerSet`].
#[cfg_attr(test, mockall::automock)]
pub trait ClientWorker: Send + Sync {
    /// Namespace the worker is registered under; half of the registry key.
    fn namespace(&self) -> &str;
    /// Task queue the worker is registered under; the other half of the registry key.
    fn task_queue(&self) -> &str;
    /// Tries to reserve a workflow task slot.
    ///
    /// `None` means this worker cannot provide one right now; the set then tries another worker.
    fn try_reserve_wft_slot(&self) -> Option<Box<dyn Slot + Send>>;
    /// The worker's deployment options, if any. `version.build_id` is used in the set's
    /// duplicate-registration check and is returned with slot reservations.
    fn deployment_options(&self) -> Option<WorkerDeploymentOptions>;
    /// Key identifying this worker instance; used as the key of the set's worker map.
    fn worker_instance_key(&self) -> Uuid;
    /// Whether this worker's heartbeat callback should be registered with its namespace's
    /// shared worker.
    fn heartbeat_enabled(&self) -> bool;
    /// The callback that produces this worker's heartbeat, if it has one.
    fn heartbeat_callback(&self) -> Option<HeartbeatCallback>;
    /// Creates the shared worker for this worker's namespace.
    ///
    /// The set calls this when a heartbeat-enabled worker with a callback registers and no
    /// shared worker exists yet for its namespace.
    fn new_shared_namespace_worker(
        &self,
    ) -> Result<Box<dyn SharedNamespaceWorkerTrait + Send + Sync>, anyhow::Error>;
    /// Which task types (workflows, local/remote activities, nexus) this worker handles.
    fn worker_task_types(&self) -> WorkerTaskTypes;
}
#[cfg(test)]
mod tests {
use super::*;
fn new_mock_slot(with_error: bool) -> Box<MockSlot> {
let mut mock_slot = MockSlot::new();
if with_error {
mock_slot
.expect_schedule_wft()
.returning(|_| Err(anyhow::anyhow!("Changed my mind")));
} else {
mock_slot.expect_schedule_wft().returning(|_| Ok(()));
}
Box::new(mock_slot)
}
fn new_mock_provider(
namespace: String,
task_queue: String,
with_error: bool,
no_slots: bool,
heartbeat_enabled: bool,
) -> MockClientWorker {
let mut mock_provider = MockClientWorker::new();
mock_provider
.expect_try_reserve_wft_slot()
.returning(move || {
if no_slots {
None
} else {
Some(new_mock_slot(with_error))
}
});
mock_provider.expect_namespace().return_const(namespace);
mock_provider.expect_task_queue().return_const(task_queue);
mock_provider.expect_deployment_options().return_const(None);
mock_provider
.expect_heartbeat_enabled()
.return_const(heartbeat_enabled);
mock_provider
.expect_worker_instance_key()
.return_const(Uuid::new_v4());
mock_provider
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: true,
enable_local_activities: true,
enable_remote_activities: true,
enable_nexus: true,
});
mock_provider
}
#[test]
fn reserve_wft_slot_retries_another_worker_when_first_has_no_slot() {
let mut manager = ClientWorkerSetImpl::new();
let namespace = "retry_namespace".to_string();
let task_queue = "retry_queue".to_string();
let failing_worker_id = Uuid::new_v4();
let mut failing_worker = MockClientWorker::new();
failing_worker
.expect_try_reserve_wft_slot()
.times(1)
.returning(|| None);
failing_worker
.expect_namespace()
.return_const(namespace.clone());
failing_worker
.expect_task_queue()
.return_const(task_queue.clone());
failing_worker
.expect_deployment_options()
.return_const(WorkerDeploymentOptions {
version: temporalio_common::worker::WorkerDeploymentVersion {
deployment_name: "test-deployment".to_string(),
build_id: "build-fail".to_string(),
},
use_worker_versioning: true,
default_versioning_behavior: None,
});
failing_worker
.expect_worker_instance_key()
.return_const(failing_worker_id);
failing_worker
.expect_heartbeat_enabled()
.return_const(false);
failing_worker
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: true,
enable_local_activities: true,
enable_remote_activities: true,
enable_nexus: true,
});
let succeeding_worker_id = Uuid::new_v4();
let mut succeeding_worker = MockClientWorker::new();
succeeding_worker
.expect_try_reserve_wft_slot()
.times(1)
.returning(|| Some(new_mock_slot(false)));
succeeding_worker
.expect_namespace()
.return_const(namespace.clone());
succeeding_worker
.expect_task_queue()
.return_const(task_queue.clone());
let success_deployment_options = WorkerDeploymentOptions {
version: temporalio_common::worker::WorkerDeploymentVersion {
deployment_name: "test-deployment".to_string(),
build_id: "build-success".to_string(),
},
use_worker_versioning: true,
default_versioning_behavior: None,
};
succeeding_worker
.expect_deployment_options()
.return_const(success_deployment_options.clone());
succeeding_worker
.expect_worker_instance_key()
.return_const(succeeding_worker_id);
succeeding_worker
.expect_heartbeat_enabled()
.return_const(false);
succeeding_worker
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: true,
enable_local_activities: true,
enable_remote_activities: true,
enable_nexus: true,
});
manager
.register(Arc::new(failing_worker), false)
.expect("failing worker registration succeeds");
manager
.register(Arc::new(succeeding_worker), false)
.expect("succeeding worker registration succeeds");
let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
let reservation_deployment_options = reservation
.expect("succeeding worker was used after failing worker failed")
.deployment_options
.unwrap();
assert_eq!(
reservation_deployment_options, success_deployment_options,
"deployment options bubble through from succeeding worker"
);
}
#[test]
fn reserve_wft_slot_retries_respects_slot_boundary() {
let mut manager = ClientWorkerSetImpl::new();
let namespace = "retry_namespace".to_string();
let task_queue = "retry_queue".to_string();
let failing_worker_id = Uuid::new_v4();
let mut failing_worker = MockClientWorker::new();
failing_worker
.expect_try_reserve_wft_slot()
.times(1)
.returning(|| None);
failing_worker
.expect_namespace()
.return_const(namespace.clone());
failing_worker
.expect_task_queue()
.return_const(task_queue.clone());
failing_worker
.expect_deployment_options()
.return_const(WorkerDeploymentOptions {
version: temporalio_common::worker::WorkerDeploymentVersion {
deployment_name: "test-deployment".to_string(),
build_id: "build-fail".to_string(),
},
use_worker_versioning: true,
default_versioning_behavior: None,
});
failing_worker
.expect_worker_instance_key()
.return_const(failing_worker_id);
failing_worker
.expect_heartbeat_enabled()
.return_const(false);
failing_worker
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: true,
enable_local_activities: true,
enable_remote_activities: true,
enable_nexus: true,
});
let succeeding_worker_id = Uuid::new_v4();
let mut succeeding_worker = MockClientWorker::new();
succeeding_worker.expect_try_reserve_wft_slot().times(0);
succeeding_worker
.expect_namespace()
.return_const(namespace.clone());
succeeding_worker
.expect_task_queue()
.return_const("other_task_queue".to_string());
succeeding_worker
.expect_deployment_options()
.return_const(None);
succeeding_worker
.expect_worker_instance_key()
.return_const(succeeding_worker_id);
succeeding_worker
.expect_heartbeat_enabled()
.return_const(false);
succeeding_worker
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: true,
enable_local_activities: true,
enable_remote_activities: true,
enable_nexus: true,
});
manager
.register(Arc::new(failing_worker), false)
.expect("failing worker registration succeeds");
manager
.register(Arc::new(succeeding_worker), false)
.expect("succeeding worker registration succeeds");
let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
assert!(
reservation.is_none(),
"succeeding_worker should not be picked due to it being on a separate task queue"
);
}
#[test]
fn registry_keeps_one_provider_per_namespace() {
let manager = ClientWorkerSet::new();
let mut worker_keys = vec![];
let mut successful_registrations = 0;
for i in 0..10 {
let namespace = format!("myId{}", i % 3);
let mock_provider =
new_mock_provider(namespace, "bar_q".to_string(), false, false, false);
let worker_instance_key = mock_provider.worker_instance_key();
let result = manager.register_worker(Arc::new(mock_provider), false);
if let Err(err) = result {
assert!(err.to_string().contains(
"Registration of multiple workers with overlapping worker task types"
));
} else {
successful_registrations += 1;
worker_keys.push(worker_instance_key);
}
}
assert_eq!(successful_registrations, 3);
assert_eq!(3, manager.num_providers());
let count = worker_keys.iter().fold(0, |count, key| {
manager.unregister_slot_provider(*key).unwrap();
manager.finalize_unregister(*key).unwrap();
let result = manager.unregister_slot_provider(*key);
assert!(result.is_err());
let result = manager.finalize_unregister(*key);
assert!(result.is_err());
count + 1
});
assert_eq!(3, count);
assert_eq!(0, manager.num_providers());
}
struct MockSharedNamespaceWorker {
namespace: String,
callbacks: Arc<RwLock<HashMap<Uuid, HeartbeatCallback>>>,
}
impl std::fmt::Debug for MockSharedNamespaceWorker {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MockSharedNamespaceWorker")
.field("namespace", &self.namespace)
.field("callbacks_count", &self.callbacks.read().len())
.finish()
}
}
impl MockSharedNamespaceWorker {
fn new(namespace: String) -> Self {
Self {
namespace,
callbacks: Arc::new(RwLock::new(HashMap::new())),
}
}
}
impl SharedNamespaceWorkerTrait for MockSharedNamespaceWorker {
fn namespace(&self) -> String {
self.namespace.clone()
}
fn register_callback(
&self,
worker_instance_key: Uuid,
heartbeat_callback: HeartbeatCallback,
) {
self.callbacks
.write()
.insert(worker_instance_key, heartbeat_callback);
}
fn unregister_callback(
&self,
worker_instance_key: Uuid,
) -> (Option<HeartbeatCallback>, bool) {
let mut callbacks = self.callbacks.write();
let callback = callbacks.remove(&worker_instance_key);
let is_empty = callbacks.is_empty();
(callback, is_empty)
}
fn num_workers(&self) -> usize {
self.callbacks.read().len()
}
}
fn new_mock_provider_with_heartbeat(
namespace: String,
task_queue: String,
heartbeat_enabled: bool,
build_id: Option<String>,
) -> MockClientWorker {
let mut mock_provider = MockClientWorker::new();
mock_provider
.expect_try_reserve_wft_slot()
.returning(|| Some(new_mock_slot(false)));
mock_provider
.expect_namespace()
.return_const(namespace.clone());
mock_provider.expect_task_queue().return_const(task_queue);
mock_provider
.expect_heartbeat_enabled()
.return_const(heartbeat_enabled);
mock_provider
.expect_worker_instance_key()
.return_const(Uuid::new_v4());
let deployment_name = "test-deployment".to_string();
let build_id_for_closure = build_id.clone();
mock_provider
.expect_deployment_options()
.returning(move || {
build_id_for_closure
.as_ref()
.map(|build_id| WorkerDeploymentOptions {
version: temporalio_common::worker::WorkerDeploymentVersion {
deployment_name: deployment_name.clone(),
build_id: build_id.clone(),
},
use_worker_versioning: true,
default_versioning_behavior: None,
})
});
if heartbeat_enabled {
mock_provider
.expect_heartbeat_callback()
.returning(|| Some(Arc::new(WorkerHeartbeat::default)));
let namespace_clone = namespace.clone();
mock_provider
.expect_new_shared_namespace_worker()
.returning(move || {
Ok(Box::new(MockSharedNamespaceWorker::new(
namespace_clone.clone(),
)))
});
}
mock_provider
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: true,
enable_local_activities: true,
enable_remote_activities: true,
enable_nexus: true,
});
mock_provider
}
#[test]
fn duplicate_namespace_task_queue_registration_fails() {
let manager = ClientWorkerSet::new();
let worker1 = new_mock_provider_with_heartbeat(
"test_namespace".to_string(),
"test_queue".to_string(),
true,
None,
);
let worker2 = new_mock_provider_with_heartbeat(
"test_namespace".to_string(),
"test_queue".to_string(),
true,
None,
);
manager.register_worker(Arc::new(worker1), false).unwrap();
let result = manager.register_worker(Arc::new(worker2), false);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Registration of multiple workers with overlapping worker task types")
);
assert_eq!(1, manager.num_providers());
assert_eq!(manager.num_heartbeat_workers(), 1);
let impl_ref = manager.worker_manager.read();
assert_eq!(impl_ref.shared_worker.len(), 1);
assert!(impl_ref.shared_worker.contains_key("test_namespace"));
}
#[test]
fn duplicate_namespace_with_different_build_ids_succeeds() {
let manager = ClientWorkerSet::new();
let namespace = "test_namespace".to_string();
let task_queue = "test_queue".to_string();
let worker1 =
new_mock_provider_with_heartbeat(namespace.clone(), task_queue.clone(), false, None);
let worker1_instance_key = worker1.worker_instance_key();
let worker2 = new_mock_provider_with_heartbeat(
namespace.clone(),
task_queue.clone(),
false,
Some("build-1".to_string()),
);
let worker2_instance_key = worker2.worker_instance_key();
let worker3 =
new_mock_provider_with_heartbeat(namespace.clone(), task_queue.clone(), false, None);
let worker4 = new_mock_provider_with_heartbeat(
namespace.clone(),
task_queue.clone(),
false,
Some("build-1".to_string()),
);
manager.register_worker(Arc::new(worker1), false).unwrap();
manager
.register_worker(Arc::new(worker2), false)
.expect("worker with new build ID should register");
assert_eq!(2, manager.num_providers());
assert!(
manager
.register_worker(Arc::new(worker3), false)
.unwrap_err()
.to_string()
.contains("Registration of multiple workers with overlapping worker task types")
);
assert!(
manager
.register_worker(Arc::new(worker4), false)
.unwrap_err()
.to_string()
.contains("Registration of multiple workers with overlapping worker task types")
);
assert_eq!(2, manager.num_providers());
{
let impl_ref = manager.worker_manager.read();
let slot_key = SlotKey::new(namespace.clone(), task_queue.clone());
let providers = impl_ref
.slot_providers
.get(&slot_key)
.expect("slot providers should exist for namespace/task queue");
assert_eq!(2, providers.len());
assert_eq!(providers[0].worker_id, worker1_instance_key);
assert_eq!(providers[0].build_id, None);
assert_eq!(providers[1].worker_id, worker2_instance_key);
assert_eq!(providers[1].build_id, Some("build-1".to_string()));
}
manager
.unregister_slot_provider(worker2_instance_key)
.unwrap();
manager.finalize_unregister(worker2_instance_key).unwrap();
{
let impl_ref = manager.worker_manager.read();
let slot_key = SlotKey::new(namespace.clone(), task_queue.clone());
let providers = impl_ref
.slot_providers
.get(&slot_key)
.expect("slot providers should exist for namespace/task queue");
assert_eq!(1, providers.len());
assert_eq!(providers[0].worker_id, worker1_instance_key);
assert_eq!(providers[0].build_id, None);
}
}
#[test]
fn multiple_workers_same_namespace_share_heartbeat_manager() {
let manager = ClientWorkerSet::new();
let worker1 = new_mock_provider_with_heartbeat(
"shared_namespace".to_string(),
"queue1".to_string(),
true,
None,
);
let worker2 = new_mock_provider_with_heartbeat(
"shared_namespace".to_string(),
"queue2".to_string(),
true,
None,
);
manager.register_worker(Arc::new(worker1), false).unwrap();
manager.register_worker(Arc::new(worker2), false).unwrap();
assert_eq!(2, manager.num_providers());
assert_eq!(manager.num_heartbeat_workers(), 2);
let impl_ref = manager.worker_manager.read();
assert_eq!(impl_ref.shared_worker.len(), 1);
assert!(impl_ref.shared_worker.contains_key("shared_namespace"));
let shared_worker = impl_ref.shared_worker.get("shared_namespace").unwrap();
assert_eq!(shared_worker.namespace(), "shared_namespace");
}
#[test]
fn different_namespaces_get_separate_heartbeat_managers() {
let manager = ClientWorkerSet::new();
let worker1 = new_mock_provider_with_heartbeat(
"namespace1".to_string(),
"queue1".to_string(),
true,
None,
);
let worker2 = new_mock_provider_with_heartbeat(
"namespace2".to_string(),
"queue1".to_string(),
true,
None,
);
manager.register_worker(Arc::new(worker1), false).unwrap();
manager.register_worker(Arc::new(worker2), false).unwrap();
assert_eq!(2, manager.num_providers());
assert_eq!(manager.num_heartbeat_workers(), 2);
let impl_ref = manager.worker_manager.read();
assert_eq!(impl_ref.num_heartbeat_workers(), 2);
assert!(impl_ref.shared_worker.contains_key("namespace1"));
assert!(impl_ref.shared_worker.contains_key("namespace2"));
}
#[test]
fn unregister_heartbeat_workers_cleans_up_shared_worker_when_last_removed() {
let manager = ClientWorkerSet::new();
let worker1 = new_mock_provider_with_heartbeat(
"test_namespace".to_string(),
"queue1".to_string(),
true,
None,
);
let worker2 = new_mock_provider_with_heartbeat(
"test_namespace".to_string(),
"queue2".to_string(),
true,
None,
);
let worker_instance_key1 = worker1.worker_instance_key();
let worker_instance_key2 = worker2.worker_instance_key();
assert_ne!(worker_instance_key1, worker_instance_key2);
manager.register_worker(Arc::new(worker1), false).unwrap();
manager.register_worker(Arc::new(worker2), false).unwrap();
assert_eq!(2, manager.num_providers());
assert_eq!(manager.num_heartbeat_workers(), 2);
let impl_ref = manager.worker_manager.read();
assert_eq!(impl_ref.shared_worker.len(), 1);
assert!(impl_ref.shared_worker.contains_key("test_namespace"));
assert_eq!(
impl_ref
.shared_worker
.get("test_namespace")
.unwrap()
.num_workers(),
2
);
drop(impl_ref);
manager
.unregister_slot_provider(worker_instance_key1)
.unwrap();
manager.finalize_unregister(worker_instance_key1).unwrap();
assert_eq!(1, manager.num_providers());
assert_eq!(manager.num_heartbeat_workers(), 1);
let impl_ref = manager.worker_manager.read();
assert_eq!(impl_ref.num_heartbeat_workers(), 1); assert!(impl_ref.shared_worker.contains_key("test_namespace"));
assert_eq!(
impl_ref
.shared_worker
.get("test_namespace")
.unwrap()
.num_workers(),
1
);
drop(impl_ref);
manager
.unregister_slot_provider(worker_instance_key2)
.unwrap();
manager.finalize_unregister(worker_instance_key2).unwrap();
assert_eq!(0, manager.num_providers());
assert_eq!(manager.num_heartbeat_workers(), 0);
let impl_ref = manager.worker_manager.read();
assert_eq!(impl_ref.shared_worker.len(), 0); assert!(!impl_ref.shared_worker.contains_key("test_namespace"));
}
#[test]
fn workflow_and_activity_only_workers_coexist() {
let manager = ClientWorkerSet::new();
let namespace = "test_namespace".to_string();
let task_queue = "test_queue".to_string();
let mut workflow_nexus_worker = MockClientWorker::new();
workflow_nexus_worker
.expect_namespace()
.return_const(namespace.clone());
workflow_nexus_worker
.expect_task_queue()
.return_const(task_queue.clone());
workflow_nexus_worker
.expect_deployment_options()
.return_const(None);
workflow_nexus_worker
.expect_worker_instance_key()
.return_const(Uuid::new_v4());
workflow_nexus_worker
.expect_heartbeat_enabled()
.return_const(false);
workflow_nexus_worker
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: true,
enable_local_activities: false,
enable_remote_activities: false,
enable_nexus: true,
});
let mut activity_worker = MockClientWorker::new();
activity_worker
.expect_namespace()
.return_const(namespace.clone());
activity_worker
.expect_task_queue()
.return_const(task_queue.clone());
activity_worker
.expect_deployment_options()
.return_const(None);
activity_worker
.expect_worker_instance_key()
.return_const(Uuid::new_v4());
activity_worker
.expect_heartbeat_enabled()
.return_const(false);
activity_worker
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: false,
enable_local_activities: false,
enable_remote_activities: true,
enable_nexus: false,
});
activity_worker.expect_try_reserve_wft_slot().times(0);
manager
.register_worker(Arc::new(workflow_nexus_worker), false)
.expect("workflow-nexus worker should register");
manager
.register_worker(Arc::new(activity_worker), false)
.expect("activity-only worker should register");
assert_eq!(2, manager.num_providers());
}
#[test]
fn overlapping_capabilities_rejected() {
let manager = ClientWorkerSet::new();
let namespace = "test_namespace".to_string();
let task_queue = "test_queue".to_string();
let mut worker1 = MockClientWorker::new();
worker1.expect_namespace().return_const(namespace.clone());
worker1.expect_task_queue().return_const(task_queue.clone());
worker1.expect_deployment_options().return_const(None);
worker1
.expect_worker_instance_key()
.return_const(Uuid::new_v4());
worker1.expect_heartbeat_enabled().return_const(false);
worker1
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: true,
enable_local_activities: true,
enable_remote_activities: true,
enable_nexus: false,
});
let mut worker2 = MockClientWorker::new();
worker2.expect_namespace().return_const(namespace.clone());
worker2.expect_task_queue().return_const(task_queue.clone());
worker2.expect_deployment_options().return_const(None);
worker2
.expect_worker_instance_key()
.return_const(Uuid::new_v4());
worker2.expect_heartbeat_enabled().return_const(false);
worker2
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: true,
enable_local_activities: true,
enable_remote_activities: true,
enable_nexus: false,
});
manager
.register_worker(Arc::new(worker1), false)
.expect("first worker should register");
let result = manager.register_worker(Arc::new(worker2), false);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("overlapping worker task types")
);
let mut worker3 = MockClientWorker::new();
worker3.expect_namespace().return_const(namespace.clone());
worker3.expect_task_queue().return_const(task_queue.clone());
worker3.expect_deployment_options().return_const(None);
worker3
.expect_worker_instance_key()
.return_const(Uuid::new_v4());
worker3.expect_heartbeat_enabled().return_const(false);
worker3
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: false,
enable_local_activities: false,
enable_remote_activities: true,
enable_nexus: false,
});
let result = manager.register_worker(Arc::new(worker3), false);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("overlapping worker task types")
);
}
#[test]
fn wft_slot_reservation_ignores_non_workflow_workers() {
let mut manager_impl = ClientWorkerSetImpl::new();
let namespace = "test_namespace".to_string();
let task_queue = "test_queue".to_string();
let mut activity_worker = MockClientWorker::new();
activity_worker
.expect_namespace()
.return_const(namespace.clone());
activity_worker
.expect_task_queue()
.return_const(task_queue.clone());
activity_worker
.expect_deployment_options()
.return_const(None);
activity_worker
.expect_worker_instance_key()
.return_const(Uuid::new_v4());
activity_worker
.expect_heartbeat_enabled()
.return_const(false);
activity_worker
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: false,
enable_local_activities: false,
enable_remote_activities: true,
enable_nexus: false,
});
let mut nexus_worker = MockClientWorker::new();
nexus_worker
.expect_namespace()
.return_const(namespace.clone());
nexus_worker
.expect_task_queue()
.return_const(task_queue.clone());
nexus_worker.expect_deployment_options().return_const(None);
nexus_worker
.expect_worker_instance_key()
.return_const(Uuid::new_v4());
nexus_worker.expect_heartbeat_enabled().return_const(false);
nexus_worker
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: false,
enable_local_activities: false,
enable_remote_activities: false,
enable_nexus: true,
});
manager_impl
.register(Arc::new(activity_worker), false)
.expect("activity worker should register");
manager_impl
.register(Arc::new(nexus_worker), false)
.expect("nexus worker should register");
let reservation = manager_impl.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
assert!(
reservation.is_none(),
"should not find workflow workers when only activity/nexus workers registered"
);
let mut workflow_worker = MockClientWorker::new();
workflow_worker
.expect_namespace()
.return_const(namespace.clone());
workflow_worker
.expect_task_queue()
.return_const(task_queue.clone());
workflow_worker
.expect_deployment_options()
.return_const(None);
workflow_worker
.expect_worker_instance_key()
.return_const(Uuid::new_v4());
workflow_worker
.expect_heartbeat_enabled()
.return_const(false);
workflow_worker
.expect_worker_task_types()
.return_const(WorkerTaskTypes {
enable_workflows: true,
enable_local_activities: true,
enable_remote_activities: false,
enable_nexus: false,
});
workflow_worker
.expect_try_reserve_wft_slot()
.times(1)
.returning(|| Some(new_mock_slot(false)));
manager_impl
.register(Arc::new(workflow_worker), false)
.expect("workflow worker should register");
let reservation = manager_impl.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
assert!(
reservation.is_some(),
"should find workflow worker after it's registered"
);
}
/// Registration must reject workers whose capability set is invalid:
/// - a worker with every capability disabled fails with an error mentioning
///   "must have at least one capability enabled";
/// - a worker enabling local activities without workflows fails with exactly
///   "Local activities cannot be enabled without workflows".
#[test]
fn worker_invalid_type_config_rejected() {
    /// Builds a mock worker on `test_namespace`/`test_queue` with a fresh instance key,
    /// no deployment options, heartbeating disabled, and the given capability set.
    fn mock_worker(task_types: WorkerTaskTypes) -> MockClientWorker {
        let mut worker = MockClientWorker::new();
        worker
            .expect_namespace()
            .return_const("test_namespace".to_string());
        worker
            .expect_task_queue()
            .return_const("test_queue".to_string());
        worker.expect_deployment_options().return_const(None);
        worker
            .expect_worker_instance_key()
            .return_const(Uuid::new_v4());
        worker.expect_heartbeat_enabled().return_const(false);
        worker.expect_worker_task_types().return_const(task_types);
        worker
    }

    let manager = ClientWorkerSet::new();

    // Every capability disabled.
    let no_capabilities = mock_worker(WorkerTaskTypes {
        enable_workflows: false,
        enable_local_activities: false,
        enable_remote_activities: false,
        enable_nexus: false,
    });
    // Match instead of `is_err()` + `unwrap_err()` so a failure reports what went wrong
    // (and does not require the success type to implement `Debug`).
    let err = match manager.register_worker(Arc::new(no_capabilities), false) {
        Ok(_) => panic!("worker with no capabilities enabled should be rejected"),
        Err(e) => e.to_string(),
    };
    assert!(
        err.contains("must have at least one capability enabled"),
        "unexpected error message: {err}"
    );

    // Local (and remote) activities enabled, but workflows disabled.
    let local_activities_only = mock_worker(WorkerTaskTypes {
        enable_workflows: false,
        enable_local_activities: true,
        enable_remote_activities: true,
        enable_nexus: false,
    });
    let err = match manager.register_worker(Arc::new(local_activities_only), false) {
        Ok(_) => panic!("local activities without workflows should be rejected"),
        Err(e) => e.to_string(),
    };
    assert_eq!(err, "Local activities cannot be enabled without workflows");
}
/// Registers a workflow-only worker and an activity-only worker on the same
/// namespace/task queue, then unregisters them one at a time.
///
/// Checks that a WFT slot can be reserved only while the workflow worker is registered,
/// and that `num_providers` goes 2 → 1 → 0 as each worker is removed via
/// `unregister_slot_provider` followed by `finalize_unregister`.
#[test]
fn unregister_with_multiple_workers() {
    /// Builds a mock worker bound to `namespace`/`task_queue` with a fixed instance key,
    /// no deployment options, heartbeating disabled, and the given capability set.
    fn mock_worker(
        namespace: &str,
        task_queue: &str,
        instance_key: Uuid,
        task_types: WorkerTaskTypes,
    ) -> MockClientWorker {
        let mut mock = MockClientWorker::new();
        mock.expect_namespace().return_const(namespace.to_string());
        mock.expect_task_queue().return_const(task_queue.to_string());
        mock.expect_deployment_options().return_const(None);
        mock.expect_worker_instance_key().return_const(instance_key);
        mock.expect_heartbeat_enabled().return_const(false);
        mock.expect_worker_task_types().return_const(task_types);
        mock
    }

    let manager = ClientWorkerSet::new();
    let ns = "test_namespace".to_string();
    let tq = "test_queue".to_string();

    let wf_key = Uuid::new_v4();
    let mut wf_mock = mock_worker(
        &ns,
        &tq,
        wf_key,
        WorkerTaskTypes {
            enable_workflows: true,
            enable_local_activities: true,
            enable_remote_activities: false,
            enable_nexus: false,
        },
    );
    // Only the workflow-capable worker is expected to hand out WFT slots.
    wf_mock
        .expect_try_reserve_wft_slot()
        .returning(|| Some(new_mock_slot(false)));

    let act_key = Uuid::new_v4();
    let act_mock = mock_worker(
        &ns,
        &tq,
        act_key,
        WorkerTaskTypes {
            enable_workflows: false,
            enable_local_activities: false,
            enable_remote_activities: true,
            enable_nexus: false,
        },
    );

    manager
        .register_worker(Arc::new(wf_mock), false)
        .expect("workflow worker should register");
    manager
        .register_worker(Arc::new(act_mock), false)
        .expect("activity worker should register");
    assert_eq!(2, manager.num_providers());

    let first_attempt = manager.try_reserve_wft_slot(ns.clone(), tq.clone());
    assert!(
        first_attempt.is_some(),
        "should be able to reserve slot from workflow worker"
    );

    // Remove the workflow worker; only the activity worker remains.
    manager
        .unregister_slot_provider(wf_key)
        .expect("should unregister slot provider for workflow worker");
    manager
        .finalize_unregister(wf_key)
        .expect("should finalize unregister for workflow worker");
    assert_eq!(1, manager.num_providers());

    let second_attempt = manager.try_reserve_wft_slot(ns.clone(), tq.clone());
    assert!(
        second_attempt.is_none(),
        "should not find workflow worker after unregistration"
    );

    // Remove the activity worker; nothing remains.
    manager
        .unregister_slot_provider(act_key)
        .expect("should unregister slot provider for activity worker");
    manager
        .finalize_unregister(act_key)
        .expect("should finalize unregister for activity worker");
    assert_eq!(0, manager.num_providers());
}
/// `finalize_unregister` must be rejected while the worker is still a slot provider
/// (error mentions "Worker still in slot_providers during finalize"), and must succeed
/// once `unregister_slot_provider` has been called for the same instance key.
#[test]
fn worker_unregister_order() {
    let manager = ClientWorkerSet::new();
    // NOTE(review): the `true` flag presumably enables heartbeating for this mock —
    // confirm against `new_mock_provider_with_heartbeat`.
    let mock = new_mock_provider_with_heartbeat(
        "namespace1".to_string(),
        "queue1".to_string(),
        true,
        None,
    );
    let instance_key = mock.worker_instance_key();
    manager.register_worker(Arc::new(mock), false).unwrap();

    // Wrong order: finalizing before the slot provider has been removed.
    let premature = manager.finalize_unregister(instance_key);
    assert!(premature.is_err());
    let message = match premature {
        Err(e) => e.to_string(),
        Ok(_) => unreachable!(),
    };
    assert!(message.contains("Worker still in slot_providers during finalize"));

    // Correct order: remove the slot provider first, then finalize.
    manager.unregister_slot_provider(instance_key).unwrap();
    manager.finalize_unregister(instance_key).unwrap();
}
}