1use std::sync::Arc;
2
3use tokio::task::JoinSet;
4use tokio_stream::StreamExt;
5use tonic::transport::Channel;
6
7use tokio::sync::Semaphore;
8
9use crate::api::DurableTaskError;
10use crate::internal::validate_identifier;
11use crate::proto;
12use crate::proto::history_event::EventType;
13use crate::proto::task_hub_sidecar_service_client::TaskHubSidecarServiceClient;
14use crate::proto::work_item::Request;
15
16use super::activity_executor::ActivityExecutor;
17use super::options::WorkerOptions;
18use super::orchestration_executor::OrchestrationExecutor;
19use super::reconnect_policy::BackoffIter;
20use super::registry::Registry;
21
22pub struct TaskHubGrpcWorker {
49 host_address: String,
50 registry: Arc<Registry>,
51 options: Arc<WorkerOptions>,
52}
53
54impl TaskHubGrpcWorker {
55 pub fn new(host_address: &str) -> Self {
57 Self {
58 host_address: host_address.to_string(),
59 registry: Arc::new(Registry::new()),
60 options: Arc::new(WorkerOptions::default()),
61 }
62 }
63
64 pub fn with_options(host_address: &str, options: WorkerOptions) -> Self {
66 Self {
67 host_address: host_address.to_string(),
68 registry: Arc::new(Registry::new()),
69 options: Arc::new(options),
70 }
71 }
72
73 pub fn registry_mut(&mut self) -> &mut Registry {
80 Arc::get_mut(&mut self.registry).expect("Cannot modify registry after worker has started")
81 }
82
83 pub async fn start(
101 &self,
102 shutdown: tokio_util::sync::CancellationToken,
103 ) -> crate::api::Result<()> {
104 let mut backoff = BackoffIter::new(&self.options.reconnect_policy);
105
106 loop {
107 if shutdown.is_cancelled() {
108 tracing::info!("Worker shutdown before connecting");
109 return Ok(());
110 }
111
112 tracing::info!(address = %self.host_address, "Worker connecting to sidecar");
113
114 match Self::connect(&self.host_address).await {
115 Ok(channel) => {
116 tracing::info!(address = %self.host_address, "Worker connected, starting work loop");
117 backoff.reset();
118
119 let mut client = TaskHubSidecarServiceClient::new(channel);
120
121 match Self::run_work_loop(&mut client, &self.registry, &self.options, &shutdown)
122 .await
123 {
124 Ok(()) => {
125 if shutdown.is_cancelled() {
126 tracing::info!(
127 "Worker shut down gracefully after draining in-flight tasks"
128 );
129 } else {
130 tracing::info!("Work item stream closed cleanly; shutting down");
131 }
132 return Ok(());
133 }
134 Err(e) => {
135 tracing::warn!(error = %e, "Work loop error");
136 }
137 }
138 }
139 Err(e) => {
140 tracing::warn!(error = %e, "Connection to sidecar failed");
141 }
142 }
143
144 match backoff.next_delay() {
146 None => {
147 let msg = format!(
148 "Worker exceeded maximum reconnect attempts ({}); giving up",
149 self.options.reconnect_policy.max_attempts.unwrap_or(0)
150 );
151 tracing::error!("{}", msg);
152 return Err(DurableTaskError::ConnectionFailed(msg));
153 }
154 Some(delay) => {
155 tracing::info!(
156 delay_ms = delay.as_millis(),
157 address = %self.host_address,
158 "Waiting before reconnect"
159 );
160 tokio::select! {
161 _ = shutdown.cancelled() => {
162 tracing::info!("Worker shutdown during reconnect wait");
163 return Ok(());
164 }
165 _ = tokio::time::sleep(delay) => {}
166 }
167 }
168 }
169 }
170 }
171
172 async fn connect(address: &str) -> crate::api::Result<Channel> {
173 const USER_AGENT: &str = concat!("dapr-durabletask/rust/", env!("CARGO_PKG_VERSION"));
174
175 Channel::from_shared(address.to_string())
176 .map_err(|e| DurableTaskError::InvalidAddress(format!("Invalid address: {e}")))?
177 .user_agent(USER_AGENT)
178 .map_err(|e| DurableTaskError::InvalidAddress(format!("Invalid user agent: {e}")))?
179 .connect()
180 .await
181 .map_err(|e| DurableTaskError::ConnectionFailed(format!("Connection failed: {e}")))
182 }
183
184 async fn run_work_loop(
185 client: &mut TaskHubSidecarServiceClient<Channel>,
186 registry: &Arc<Registry>,
187 options: &Arc<WorkerOptions>,
188 shutdown: &tokio_util::sync::CancellationToken,
189 ) -> crate::api::Result<()> {
190 let request = proto::GetWorkItemsRequest {};
191 let mut stream = client.get_work_items(request).await?.into_inner();
192 let semaphore = Arc::new(Semaphore::new(options.max_concurrent_work_items));
193 let mut tasks: JoinSet<()> = JoinSet::new();
194 tracing::info!("Work item stream established");
195
196 let shutdown_triggered = loop {
199 prune_finished_tasks(&mut tasks);
200 tokio::select! {
201 biased; _ = shutdown.cancelled() => {
203 tracing::info!(
204 in_flight = tasks.len(),
205 "Shutdown: stopping intake, draining in-flight work items"
206 );
207 break true;
208 }
209 Some(outcome) = tasks.join_next(), if !tasks.is_empty() => {
210 if let Err(e) = outcome {
211 tracing::error!(error = ?e, "Work item task panicked");
212 }
213 }
214 item = stream.next() => {
215 match item {
216 None => {
217 tracing::info!("Work item stream closed by sidecar");
220 break false;
221 }
222 Some(Err(e)) => {
223 return Err(DurableTaskError::ConnectionFailed(format!("Stream error: {e}")));
224 }
225 Some(Ok(work_item)) => {
226 Self::dispatch_work_item(
227 work_item,
228 client.clone(),
229 registry,
230 options,
231 &semaphore,
232 &mut tasks,
233 ).await?;
234 }
235 }
236 }
237 }
238 };
239
240 if !tasks.is_empty() {
242 tracing::info!(count = tasks.len(), "Draining in-flight work items");
243 while let Some(outcome) = tasks.join_next().await {
244 if let Err(e) = outcome {
245 tracing::error!(error = ?e, "In-flight task panicked during drain");
246 }
247 }
248 tracing::info!("All in-flight work items drained");
249 }
250
251 if shutdown_triggered {
252 Ok(())
254 } else {
255 Err(DurableTaskError::ConnectionFailed(
257 "Work item stream closed by sidecar".into(),
258 ))
259 }
260 }
261
262 async fn dispatch_work_item(
264 work_item: proto::WorkItem,
265 client: TaskHubSidecarServiceClient<Channel>,
266 registry: &Arc<Registry>,
267 options: &Arc<WorkerOptions>,
268 semaphore: &Arc<Semaphore>,
269 tasks: &mut JoinSet<()>,
270 ) -> crate::api::Result<()> {
271 match work_item.request {
272 Some(Request::WorkflowRequest(req)) => {
273 let instance_id = req.instance_id.clone();
274 if let Err(e) =
275 validate_identifier(&instance_id, "instance ID", options.max_identifier_length)
276 {
277 tracing::warn!(
278 instance_id = %instance_id,
279 error = %e,
280 "Rejected work item: invalid instance ID"
281 );
282 return Ok(());
283 }
284 tracing::debug!(
285 instance_id = %instance_id,
286 past_events = req.past_events.len(),
287 new_events = req.new_events.len(),
288 "Received orchestrator work item"
289 );
290
291 let registry = registry.clone();
292 let options = options.clone();
293 let mut stub = client;
294 let completion_token = work_item.completion_token.clone();
295 let permit = semaphore
296 .clone()
297 .acquire_owned()
298 .await
299 .map_err(|_| DurableTaskError::Internal("Semaphore closed".to_string()))?;
300
301 tasks.spawn(async move {
302 let _permit = permit;
303 let response = Self::handle_orchestrator_request(
304 ®istry,
305 req,
306 completion_token,
307 &options,
308 )
309 .await;
310 #[allow(deprecated)]
312 if let Err(e) = stub.complete_orchestrator_task(response).await {
313 tracing::error!(
314 instance_id = %instance_id,
315 error = %e,
316 "Failed to complete orchestrator task"
317 );
318 }
319 });
320 }
321 Some(Request::ActivityRequest(req)) => {
322 let instance_id = req
323 .workflow_instance
324 .as_ref()
325 .map(|i| i.instance_id.clone())
326 .unwrap_or_default();
327 tracing::debug!(
328 instance_id = %instance_id,
329 activity = %req.name,
330 task_id = req.task_id,
331 "Received activity work item"
332 );
333
334 let registry = registry.clone();
335 let options = options.clone();
336 let mut stub = client;
337 let completion_token = work_item.completion_token.clone();
338 let activity_name = req.name.clone();
339 let permit = semaphore
340 .clone()
341 .acquire_owned()
342 .await
343 .map_err(|_| DurableTaskError::Internal("Semaphore closed".to_string()))?;
344
345 tasks.spawn(async move {
346 let _permit = permit;
347 let response =
348 Self::handle_activity_request(®istry, req, completion_token, &options)
349 .await;
350 if let Err(e) = stub.complete_activity_task(response).await {
351 tracing::error!(
352 instance_id = %instance_id,
353 activity = %activity_name,
354 error = %e,
355 "Failed to complete activity task"
356 );
357 }
358 });
359 }
360 None => {
361 tracing::warn!("Received work item with no request payload");
362 }
363 }
364 Ok(())
365 }
366
367 async fn handle_orchestrator_request(
368 registry: &Registry,
369 request: proto::WorkflowRequest,
370 completion_token: String,
371 options: &WorkerOptions,
372 ) -> proto::WorkflowResponse {
373 let instance_id = request.instance_id.clone();
374
375 let (name, version) = request
377 .past_events
378 .iter()
379 .chain(request.new_events.iter())
380 .find_map(|e| {
381 if let Some(EventType::ExecutionStarted(es)) = &e.event_type {
382 Some((es.name.clone(), es.version.clone()))
383 } else {
384 None
385 }
386 })
387 .unwrap_or_default();
388
389 if let Err(e) =
390 validate_identifier(&name, "orchestrator name", options.max_identifier_length)
391 {
392 tracing::warn!(
393 instance_id = %instance_id,
394 orchestrator = %name,
395 error = %e,
396 "Rejected orchestrator request: invalid name"
397 );
398 return build_error_response(&instance_id, &e.to_string(), completion_token);
399 }
400
401 let orchestrator_fn = match registry.get_orchestrator_version(&name, version.as_deref()) {
402 Some(f) => f,
403 None => {
404 tracing::warn!(
405 instance_id = %instance_id,
406 orchestrator = %name,
407 "Unregistered orchestrator requested"
408 );
409 return build_error_response(
410 &instance_id,
411 &format!("Orchestrator '{name}' not registered"),
412 completion_token,
413 );
414 }
415 };
416
417 match OrchestrationExecutor::execute(
418 orchestrator_fn,
419 &instance_id,
420 request.past_events,
421 request.new_events,
422 completion_token.clone(),
423 options,
424 request
425 .propagated_history
426 .and_then(crate::api::PropagatedHistory::from_proto),
427 )
428 .await
429 {
430 Ok(response) => response,
431 Err(e) => {
432 tracing::error!(
433 instance_id = %instance_id,
434 orchestrator = %name,
435 error = %e,
436 "Orchestrator execution failed"
437 );
438 build_error_response(&instance_id, &e.to_string(), completion_token)
439 }
440 }
441 }
442
443 async fn handle_activity_request(
444 registry: &Registry,
445 request: proto::ActivityRequest,
446 completion_token: String,
447 options: &WorkerOptions,
448 ) -> proto::ActivityResponse {
449 let instance_id = request
450 .workflow_instance
451 .as_ref()
452 .map(|i| i.instance_id.as_str())
453 .unwrap_or("");
454
455 let build_activity_error =
456 |error_type: &str, error_message: String| proto::ActivityResponse {
457 instance_id: instance_id.to_string(),
458 task_id: request.task_id,
459 result: None,
460 failure_details: Some(proto::TaskFailureDetails {
461 error_type: error_type.to_string(),
462 error_message,
463 stack_trace: None,
464 inner_failure: None,
465 is_non_retriable: true,
466 }),
467 completion_token: completion_token.clone(),
468 };
469
470 if let Err(e) = validate_identifier(
471 &request.name,
472 "activity name",
473 options.max_identifier_length,
474 ) {
475 tracing::warn!(
476 instance_id = %instance_id,
477 activity = %request.name,
478 error = %e,
479 "Rejected activity request: invalid name"
480 );
481 return build_activity_error("InvalidActivityName", e.to_string());
482 }
483
484 let activity_fn = match registry.get_activity(&request.name) {
485 Some(f) => f,
486 None => {
487 tracing::warn!(
488 instance_id = %instance_id,
489 activity = %request.name,
490 "Unregistered activity requested"
491 );
492 return build_activity_error(
493 "ActivityNotRegistered",
494 format!("Activity '{}' not registered", request.name),
495 );
496 }
497 };
498
499 ActivityExecutor::execute(
500 activity_fn,
501 &request.name,
502 instance_id,
503 request.task_id,
504 request.task_execution_id,
505 request.input,
506 request.parent_trace_context.as_ref(),
507 completion_token,
508 request
509 .propagated_history
510 .and_then(crate::api::PropagatedHistory::from_proto),
511 )
512 .await
513 }
514}
515
516fn prune_finished_tasks(tasks: &mut JoinSet<()>) {
517 while let Some(outcome) = tasks.try_join_next() {
518 if let Err(e) = outcome {
519 tracing::error!(error = ?e, "Work item task panicked");
520 }
521 }
522}
523
524fn build_error_response(
525 instance_id: &str,
526 message: &str,
527 completion_token: String,
528) -> proto::WorkflowResponse {
529 proto::WorkflowResponse {
530 instance_id: instance_id.to_string(),
531 actions: vec![proto::WorkflowAction {
532 id: -1,
533 router: None,
534 workflow_action_type: Some(
535 proto::workflow_action::WorkflowActionType::CompleteWorkflow(
536 proto::CompleteWorkflowAction {
537 workflow_status: proto::OrchestrationStatus::Failed as i32,
538 result: None,
539 details: None,
540 new_version: None,
541 carryover_events: vec![],
542 failure_details: Some(proto::TaskFailureDetails {
543 error_type: "WorkerError".to_string(),
544 error_message: message.to_string(),
545 stack_trace: None,
546 inner_failure: None,
547 is_non_retriable: false,
548 }),
549 },
550 ),
551 ),
552 }],
553 custom_status: None,
554 completion_token,
555 num_events_processed: None,
556 version: None,
557 }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::Duration;

    use tokio::sync::oneshot;
    use tokio::time::timeout;

    /// Upper bound on how long a test waits for spawned tasks to be reaped.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Calls `prune_finished_tasks` repeatedly, yielding to the runtime between
    /// calls, until at most `keep` tasks remain.
    ///
    /// Panics with `context` if that does not happen within [`WAIT_TIMEOUT`].
    async fn prune_down_to(tasks: &mut JoinSet<()>, keep: usize, context: &str) {
        let pruning = async {
            while tasks.len() > keep {
                prune_finished_tasks(tasks);
                tokio::task::yield_now().await;
            }
        };
        timeout(WAIT_TIMEOUT, pruning).await.expect(context);
    }

    /// Prunes until `tasks` is empty, panicking after [`WAIT_TIMEOUT`].
    async fn prune_until_empty(tasks: &mut JoinSet<()>) {
        prune_down_to(
            tasks,
            0,
            "timed out waiting for prune_finished_tasks to drain the JoinSet",
        )
        .await;
    }

    #[tokio::test]
    async fn prune_finished_tasks_drains_all_completed_tasks() {
        let mut tasks: JoinSet<()> = JoinSet::new();
        (0..16).for_each(|_| {
            tasks.spawn(async {});
        });

        prune_until_empty(&mut tasks).await;

        assert_eq!(tasks.len(), 0);
        assert!(tasks.is_empty());
    }

    #[tokio::test]
    async fn prune_finished_tasks_keeps_in_flight_tasks() {
        const FINISHED: usize = 8;
        const BLOCKED: usize = 4;

        let mut tasks: JoinSet<()> = JoinSet::new();
        for _ in 0..FINISHED {
            tasks.spawn(async {});
        }

        // Each blocked task waits until its sender fires (or is dropped).
        let releasers: Vec<oneshot::Sender<()>> = (0..BLOCKED)
            .map(|_| {
                let (tx, rx) = oneshot::channel::<()>();
                tasks.spawn(async move {
                    let _ = rx.await;
                });
                tx
            })
            .collect();

        prune_down_to(
            &mut tasks,
            BLOCKED,
            "timed out waiting for completed tasks to be pruned",
        )
        .await;
        assert_eq!(tasks.len(), BLOCKED);

        releasers.into_iter().for_each(|tx| {
            let _ = tx.send(());
        });
        prune_until_empty(&mut tasks).await;
        assert!(tasks.is_empty());
    }

    #[tokio::test]
    async fn prune_finished_tasks_handles_panicked_tasks() {
        let mut tasks: JoinSet<()> = JoinSet::new();
        tasks.spawn(async {
            panic!("intentional test panic");
        });

        prune_until_empty(&mut tasks).await;

        assert!(tasks.is_empty());
    }
}