// dapr_durabletask/worker/grpc_worker.rs

1use std::sync::Arc;
2
3use tokio::task::JoinSet;
4use tokio_stream::StreamExt;
5use tonic::transport::Channel;
6
7use tokio::sync::Semaphore;
8
9use crate::api::DurableTaskError;
10use crate::internal::validate_identifier;
11use crate::proto;
12use crate::proto::history_event::EventType;
13use crate::proto::task_hub_sidecar_service_client::TaskHubSidecarServiceClient;
14use crate::proto::work_item::Request;
15
16use super::activity_executor::ActivityExecutor;
17use super::options::WorkerOptions;
18use super::orchestration_executor::OrchestrationExecutor;
19use super::reconnect_policy::BackoffIter;
20use super::registry::Registry;
21
/// Worker that connects to a Durable Task sidecar and processes work items.
///
/// The worker opens a streaming gRPC connection to receive orchestrator and
/// activity work items, dispatches them to registered handler functions, and
/// returns results to the sidecar.
///
/// Handlers are added through [`TaskHubGrpcWorker::registry_mut`]; processing
/// begins when [`TaskHubGrpcWorker::start`] is awaited.
///
/// # Example
///
/// ```rust,no_run
/// use dapr_durabletask::worker::TaskHubGrpcWorker;
/// use dapr_durabletask::task::OrchestrationContext;
///
/// # async fn example() {
/// let mut worker = TaskHubGrpcWorker::new("http://localhost:4001");
/// worker.registry_mut().add_named_orchestrator("my_orch", |ctx: OrchestrationContext| async move {
///     let result = ctx.call_activity("greet", "world").await?;
///     Ok(result)
/// });
/// worker.registry_mut().add_named_activity("greet", |_ctx, input| async move {
///     Ok(input)
/// });
///
/// let shutdown = tokio_util::sync::CancellationToken::new();
/// // worker.start(shutdown).await.unwrap();
/// # }
/// ```
pub struct TaskHubGrpcWorker {
    /// Sidecar address used to build the gRPC channel on every (re)connect.
    host_address: String,
    /// Registered orchestrators and activities; cloned (by `Arc`) into each
    /// spawned work-item task.
    registry: Arc<Registry>,
    /// Worker configuration: this file reads the reconnect policy, the
    /// concurrent work-item limit, and the maximum identifier length from it.
    options: Arc<WorkerOptions>,
}
53
54impl TaskHubGrpcWorker {
55    /// Create a new worker that will connect to the given sidecar address.
56    pub fn new(host_address: &str) -> Self {
57        Self {
58            host_address: host_address.to_string(),
59            registry: Arc::new(Registry::new()),
60            options: Arc::new(WorkerOptions::default()),
61        }
62    }
63
64    /// Create a new worker with custom options.
65    pub fn with_options(host_address: &str, options: WorkerOptions) -> Self {
66        Self {
67            host_address: host_address.to_string(),
68            registry: Arc::new(Registry::new()),
69            options: Arc::new(options),
70        }
71    }
72
73    /// Get a mutable reference to the registry for adding orchestrators and activities.
74    ///
75    /// # Panics
76    ///
77    /// Panics if called after the registry has been shared (i.e., after `start()`
78    /// has begun processing).
79    pub fn registry_mut(&mut self) -> &mut Registry {
80        Arc::get_mut(&mut self.registry).expect("Cannot modify registry after worker has started")
81    }
82
83    /// Start the worker. Runs until the cancellation token is triggered or the
84    /// reconnect policy's `max_attempts` is exhausted.
85    ///
86    /// ## Shutdown behaviour
87    ///
88    /// When the cancellation token is fired:
89    /// 1. The worker **stops reading new work items** from the sidecar stream.
90    /// 2. It **waits for all in-flight tasks** (orchestrations and activities
91    ///    already dispatched) to complete and send their results to the sidecar
92    ///    before returning.
93    ///
94    /// This guarantees that no already-accepted work item is abandoned at
95    /// shutdown. Work items still queued inside the sidecar but not yet
96    /// dispatched to this worker are unaffected — the sidecar will re-dispatch
97    /// them to the next available worker.
98    ///
99    /// [`ReconnectPolicy`]: super::reconnect_policy::ReconnectPolicy
100    pub async fn start(
101        &self,
102        shutdown: tokio_util::sync::CancellationToken,
103    ) -> crate::api::Result<()> {
104        let mut backoff = BackoffIter::new(&self.options.reconnect_policy);
105
106        loop {
107            if shutdown.is_cancelled() {
108                tracing::info!("Worker shutdown before connecting");
109                return Ok(());
110            }
111
112            tracing::info!(address = %self.host_address, "Worker connecting to sidecar");
113
114            match Self::connect(&self.host_address).await {
115                Ok(channel) => {
116                    tracing::info!(address = %self.host_address, "Worker connected, starting work loop");
117                    backoff.reset();
118
119                    let mut client = TaskHubSidecarServiceClient::new(channel);
120
121                    match Self::run_work_loop(&mut client, &self.registry, &self.options, &shutdown)
122                        .await
123                    {
124                        Ok(()) => {
125                            if shutdown.is_cancelled() {
126                                tracing::info!(
127                                    "Worker shut down gracefully after draining in-flight tasks"
128                                );
129                            } else {
130                                tracing::info!("Work item stream closed cleanly; shutting down");
131                            }
132                            return Ok(());
133                        }
134                        Err(e) => {
135                            tracing::warn!(error = %e, "Work loop error");
136                        }
137                    }
138                }
139                Err(e) => {
140                    tracing::warn!(error = %e, "Connection to sidecar failed");
141                }
142            }
143
144            // Connection failed or stream dropped — apply backoff.
145            match backoff.next_delay() {
146                None => {
147                    let msg = format!(
148                        "Worker exceeded maximum reconnect attempts ({}); giving up",
149                        self.options.reconnect_policy.max_attempts.unwrap_or(0)
150                    );
151                    tracing::error!("{}", msg);
152                    return Err(DurableTaskError::ConnectionFailed(msg));
153                }
154                Some(delay) => {
155                    tracing::info!(
156                        delay_ms = delay.as_millis(),
157                        address = %self.host_address,
158                        "Waiting before reconnect"
159                    );
160                    tokio::select! {
161                        _ = shutdown.cancelled() => {
162                            tracing::info!("Worker shutdown during reconnect wait");
163                            return Ok(());
164                        }
165                        _ = tokio::time::sleep(delay) => {}
166                    }
167                }
168            }
169        }
170    }
171
172    async fn connect(address: &str) -> crate::api::Result<Channel> {
173        const USER_AGENT: &str = concat!("dapr-durabletask/rust/", env!("CARGO_PKG_VERSION"));
174
175        Channel::from_shared(address.to_string())
176            .map_err(|e| DurableTaskError::InvalidAddress(format!("Invalid address: {e}")))?
177            .user_agent(USER_AGENT)
178            .map_err(|e| DurableTaskError::InvalidAddress(format!("Invalid user agent: {e}")))?
179            .connect()
180            .await
181            .map_err(|e| DurableTaskError::ConnectionFailed(format!("Connection failed: {e}")))
182    }
183
184    async fn run_work_loop(
185        client: &mut TaskHubSidecarServiceClient<Channel>,
186        registry: &Arc<Registry>,
187        options: &Arc<WorkerOptions>,
188        shutdown: &tokio_util::sync::CancellationToken,
189    ) -> crate::api::Result<()> {
190        let request = proto::GetWorkItemsRequest {};
191        let mut stream = client.get_work_items(request).await?.into_inner();
192        let semaphore = Arc::new(Semaphore::new(options.max_concurrent_work_items));
193        let mut tasks: JoinSet<()> = JoinSet::new();
194        tracing::info!("Work item stream established");
195
196        // `shutdown_triggered` tracks whether we exited the intake loop because
197        // of a cancellation (true) or because the stream closed (false/error).
198        let shutdown_triggered = loop {
199            prune_finished_tasks(&mut tasks);
200            tokio::select! {
201                biased; // check shutdown first so we don't accept more items
202                _ = shutdown.cancelled() => {
203                    tracing::info!(
204                        in_flight = tasks.len(),
205                        "Shutdown: stopping intake, draining in-flight work items"
206                    );
207                    break true;
208                }
209                Some(outcome) = tasks.join_next(), if !tasks.is_empty() => {
210                    if let Err(e) = outcome {
211                        tracing::error!(error = ?e, "Work item task panicked");
212                    }
213                }
214                item = stream.next() => {
215                    match item {
216                        None => {
217                            // Sidecar closed the stream — treat as a transient
218                            // error so the caller will reconnect.
219                            tracing::info!("Work item stream closed by sidecar");
220                            break false;
221                        }
222                        Some(Err(e)) => {
223                            return Err(DurableTaskError::ConnectionFailed(format!("Stream error: {e}")));
224                        }
225                        Some(Ok(work_item)) => {
226                            Self::dispatch_work_item(
227                                work_item,
228                                client.clone(),
229                                registry,
230                                options,
231                                &semaphore,
232                                &mut tasks,
233                            ).await?;
234                        }
235                    }
236                }
237            }
238        };
239
240        // Drain all in-flight tasks before returning, regardless of why we stopped.
241        if !tasks.is_empty() {
242            tracing::info!(count = tasks.len(), "Draining in-flight work items");
243            while let Some(outcome) = tasks.join_next().await {
244                if let Err(e) = outcome {
245                    tracing::error!(error = ?e, "In-flight task panicked during drain");
246                }
247            }
248            tracing::info!("All in-flight work items drained");
249        }
250
251        if shutdown_triggered {
252            // Caller checks shutdown.is_cancelled() to know this was intentional.
253            Ok(())
254        } else {
255            // Stream closed by sidecar — signal the caller to reconnect.
256            Err(DurableTaskError::ConnectionFailed(
257                "Work item stream closed by sidecar".into(),
258            ))
259        }
260    }
261
262    /// Validate and dispatch a single work item into the `JoinSet`.
263    async fn dispatch_work_item(
264        work_item: proto::WorkItem,
265        client: TaskHubSidecarServiceClient<Channel>,
266        registry: &Arc<Registry>,
267        options: &Arc<WorkerOptions>,
268        semaphore: &Arc<Semaphore>,
269        tasks: &mut JoinSet<()>,
270    ) -> crate::api::Result<()> {
271        match work_item.request {
272            Some(Request::WorkflowRequest(req)) => {
273                let instance_id = req.instance_id.clone();
274                if let Err(e) =
275                    validate_identifier(&instance_id, "instance ID", options.max_identifier_length)
276                {
277                    tracing::warn!(
278                        instance_id = %instance_id,
279                        error = %e,
280                        "Rejected work item: invalid instance ID"
281                    );
282                    return Ok(());
283                }
284                tracing::debug!(
285                    instance_id = %instance_id,
286                    past_events = req.past_events.len(),
287                    new_events = req.new_events.len(),
288                    "Received orchestrator work item"
289                );
290
291                let registry = registry.clone();
292                let options = options.clone();
293                let mut stub = client;
294                let completion_token = work_item.completion_token.clone();
295                let permit = semaphore
296                    .clone()
297                    .acquire_owned()
298                    .await
299                    .map_err(|_| DurableTaskError::Internal("Semaphore closed".to_string()))?;
300
301                tasks.spawn(async move {
302                    let _permit = permit;
303                    let response = Self::handle_orchestrator_request(
304                        &registry,
305                        req,
306                        completion_token,
307                        &options,
308                    )
309                    .await;
310                    // TODO: migrate to complete_work_item once sidecar supports it.
311                    #[allow(deprecated)]
312                    if let Err(e) = stub.complete_orchestrator_task(response).await {
313                        tracing::error!(
314                            instance_id = %instance_id,
315                            error = %e,
316                            "Failed to complete orchestrator task"
317                        );
318                    }
319                });
320            }
321            Some(Request::ActivityRequest(req)) => {
322                let instance_id = req
323                    .workflow_instance
324                    .as_ref()
325                    .map(|i| i.instance_id.clone())
326                    .unwrap_or_default();
327                tracing::debug!(
328                    instance_id = %instance_id,
329                    activity = %req.name,
330                    task_id = req.task_id,
331                    "Received activity work item"
332                );
333
334                let registry = registry.clone();
335                let options = options.clone();
336                let mut stub = client;
337                let completion_token = work_item.completion_token.clone();
338                let activity_name = req.name.clone();
339                let permit = semaphore
340                    .clone()
341                    .acquire_owned()
342                    .await
343                    .map_err(|_| DurableTaskError::Internal("Semaphore closed".to_string()))?;
344
345                tasks.spawn(async move {
346                    let _permit = permit;
347                    let response =
348                        Self::handle_activity_request(&registry, req, completion_token, &options)
349                            .await;
350                    if let Err(e) = stub.complete_activity_task(response).await {
351                        tracing::error!(
352                            instance_id = %instance_id,
353                            activity = %activity_name,
354                            error = %e,
355                            "Failed to complete activity task"
356                        );
357                    }
358                });
359            }
360            None => {
361                tracing::warn!("Received work item with no request payload");
362            }
363        }
364        Ok(())
365    }
366
367    async fn handle_orchestrator_request(
368        registry: &Registry,
369        request: proto::WorkflowRequest,
370        completion_token: String,
371        options: &WorkerOptions,
372    ) -> proto::WorkflowResponse {
373        let instance_id = request.instance_id.clone();
374
375        // Single-pass extraction of orchestrator name and version from history.
376        let (name, version) = request
377            .past_events
378            .iter()
379            .chain(request.new_events.iter())
380            .find_map(|e| {
381                if let Some(EventType::ExecutionStarted(es)) = &e.event_type {
382                    Some((es.name.clone(), es.version.clone()))
383                } else {
384                    None
385                }
386            })
387            .unwrap_or_default();
388
389        if let Err(e) =
390            validate_identifier(&name, "orchestrator name", options.max_identifier_length)
391        {
392            tracing::warn!(
393                instance_id = %instance_id,
394                orchestrator = %name,
395                error = %e,
396                "Rejected orchestrator request: invalid name"
397            );
398            return build_error_response(&instance_id, &e.to_string(), completion_token);
399        }
400
401        let orchestrator_fn = match registry.get_orchestrator_version(&name, version.as_deref()) {
402            Some(f) => f,
403            None => {
404                tracing::warn!(
405                    instance_id = %instance_id,
406                    orchestrator = %name,
407                    "Unregistered orchestrator requested"
408                );
409                return build_error_response(
410                    &instance_id,
411                    &format!("Orchestrator '{name}' not registered"),
412                    completion_token,
413                );
414            }
415        };
416
417        match OrchestrationExecutor::execute(
418            orchestrator_fn,
419            &instance_id,
420            request.past_events,
421            request.new_events,
422            completion_token.clone(),
423            options,
424            request
425                .propagated_history
426                .and_then(crate::api::PropagatedHistory::from_proto),
427        )
428        .await
429        {
430            Ok(response) => response,
431            Err(e) => {
432                tracing::error!(
433                    instance_id = %instance_id,
434                    orchestrator = %name,
435                    error = %e,
436                    "Orchestrator execution failed"
437                );
438                build_error_response(&instance_id, &e.to_string(), completion_token)
439            }
440        }
441    }
442
443    async fn handle_activity_request(
444        registry: &Registry,
445        request: proto::ActivityRequest,
446        completion_token: String,
447        options: &WorkerOptions,
448    ) -> proto::ActivityResponse {
449        let instance_id = request
450            .workflow_instance
451            .as_ref()
452            .map(|i| i.instance_id.as_str())
453            .unwrap_or("");
454
455        let build_activity_error =
456            |error_type: &str, error_message: String| proto::ActivityResponse {
457                instance_id: instance_id.to_string(),
458                task_id: request.task_id,
459                result: None,
460                failure_details: Some(proto::TaskFailureDetails {
461                    error_type: error_type.to_string(),
462                    error_message,
463                    stack_trace: None,
464                    inner_failure: None,
465                    is_non_retriable: true,
466                }),
467                completion_token: completion_token.clone(),
468            };
469
470        if let Err(e) = validate_identifier(
471            &request.name,
472            "activity name",
473            options.max_identifier_length,
474        ) {
475            tracing::warn!(
476                instance_id = %instance_id,
477                activity = %request.name,
478                error = %e,
479                "Rejected activity request: invalid name"
480            );
481            return build_activity_error("InvalidActivityName", e.to_string());
482        }
483
484        let activity_fn = match registry.get_activity(&request.name) {
485            Some(f) => f,
486            None => {
487                tracing::warn!(
488                    instance_id = %instance_id,
489                    activity = %request.name,
490                    "Unregistered activity requested"
491                );
492                return build_activity_error(
493                    "ActivityNotRegistered",
494                    format!("Activity '{}' not registered", request.name),
495                );
496            }
497        };
498
499        ActivityExecutor::execute(
500            activity_fn,
501            &request.name,
502            instance_id,
503            request.task_id,
504            request.task_execution_id,
505            request.input,
506            request.parent_trace_context.as_ref(),
507            completion_token,
508            request
509                .propagated_history
510                .and_then(crate::api::PropagatedHistory::from_proto),
511        )
512        .await
513    }
514}
515
516fn prune_finished_tasks(tasks: &mut JoinSet<()>) {
517    while let Some(outcome) = tasks.try_join_next() {
518        if let Err(e) = outcome {
519            tracing::error!(error = ?e, "Work item task panicked");
520        }
521    }
522}
523
524fn build_error_response(
525    instance_id: &str,
526    message: &str,
527    completion_token: String,
528) -> proto::WorkflowResponse {
529    proto::WorkflowResponse {
530        instance_id: instance_id.to_string(),
531        actions: vec![proto::WorkflowAction {
532            id: -1,
533            router: None,
534            workflow_action_type: Some(
535                proto::workflow_action::WorkflowActionType::CompleteWorkflow(
536                    proto::CompleteWorkflowAction {
537                        workflow_status: proto::OrchestrationStatus::Failed as i32,
538                        result: None,
539                        details: None,
540                        new_version: None,
541                        carryover_events: vec![],
542                        failure_details: Some(proto::TaskFailureDetails {
543                            error_type: "WorkerError".to_string(),
544                            error_message: message.to_string(),
545                            stack_trace: None,
546                            inner_failure: None,
547                            is_non_retriable: false,
548                        }),
549                    },
550                ),
551            ),
552        }],
553        custom_status: None,
554        completion_token,
555        num_events_processed: None,
556        version: None,
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::Duration;

    use tokio::sync::oneshot;
    use tokio::time::timeout;

    /// Upper bound on how long a test waits for spawned tasks to be pruned.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Prune `tasks` repeatedly, yielding to the runtime between attempts,
    /// until no more than `remaining` tasks are left. Panics with `failure`
    /// if that does not happen within `WAIT_TIMEOUT`.
    async fn prune_down_to(tasks: &mut JoinSet<()>, remaining: usize, failure: &str) {
        let pruning = async {
            while tasks.len() > remaining {
                prune_finished_tasks(tasks);
                tokio::task::yield_now().await;
            }
        };
        timeout(WAIT_TIMEOUT, pruning).await.expect(failure);
    }

    /// Prune `tasks` until the set is empty.
    async fn prune_until_empty(tasks: &mut JoinSet<()>) {
        prune_down_to(
            tasks,
            0,
            "timed out waiting for prune_finished_tasks to drain the JoinSet",
        )
        .await;
    }

    #[tokio::test]
    async fn prune_finished_tasks_drains_all_completed_tasks() {
        let mut tasks: JoinSet<()> = JoinSet::new();
        (0..16).for_each(|_| {
            tasks.spawn(async {});
        });

        prune_until_empty(&mut tasks).await;

        assert!(tasks.is_empty());
        assert_eq!(tasks.len(), 0);
    }

    #[tokio::test]
    async fn prune_finished_tasks_keeps_in_flight_tasks() {
        let mut tasks: JoinSet<()> = JoinSet::new();

        // Eight tasks that finish immediately...
        (0..8).for_each(|_| {
            tasks.spawn(async {});
        });

        // ...and four that stay parked until their sender fires.
        let senders: Vec<oneshot::Sender<()>> = (0..4)
            .map(|_| {
                let (tx, rx) = oneshot::channel::<()>();
                tasks.spawn(async move {
                    let _ = rx.await;
                });
                tx
            })
            .collect();

        prune_down_to(
            &mut tasks,
            4,
            "timed out waiting for completed tasks to be pruned",
        )
        .await;
        assert_eq!(tasks.len(), 4);

        senders.into_iter().for_each(|tx| {
            let _ = tx.send(());
        });
        prune_until_empty(&mut tasks).await;
        assert!(tasks.is_empty());
    }

    #[tokio::test]
    async fn prune_finished_tasks_handles_panicked_tasks() {
        let mut tasks: JoinSet<()> = JoinSet::new();
        tasks.spawn(async {
            panic!("intentional test panic");
        });

        prune_until_empty(&mut tasks).await;
        assert!(tasks.is_empty());
    }
}
639}