Skip to main content

dapr_durabletask/worker/
grpc_worker.rs

1use std::sync::Arc;
2
3use tokio::task::JoinSet;
4use tokio_stream::StreamExt;
5use tonic::transport::Channel;
6
7use tokio::sync::Semaphore;
8
9use crate::api::DurableTaskError;
10use crate::internal::validate_identifier;
11use crate::proto;
12use crate::proto::history_event::EventType;
13use crate::proto::task_hub_sidecar_service_client::TaskHubSidecarServiceClient;
14use crate::proto::work_item::Request;
15
16use super::activity_executor::ActivityExecutor;
17use super::options::WorkerOptions;
18use super::orchestration_executor::OrchestrationExecutor;
19use super::reconnect_policy::BackoffIter;
20use super::registry::Registry;
21
22/// Worker that connects to a Durable Task sidecar and processes work items.
23///
24/// The worker opens a streaming gRPC connection to receive orchestrator and
25/// activity work items, dispatches them to registered handler functions, and
26/// returns results to the sidecar.
27///
28/// # Example
29///
30/// ```rust,no_run
31/// use dapr_durabletask::worker::TaskHubGrpcWorker;
32/// use dapr_durabletask::task::OrchestrationContext;
33///
34/// # async fn example() {
35/// let mut worker = TaskHubGrpcWorker::new("http://localhost:4001");
36/// worker.registry_mut().add_named_orchestrator("my_orch", |ctx: OrchestrationContext| async move {
37///     let result = ctx.call_activity("greet", "world").await?;
38///     Ok(result)
39/// });
40/// worker.registry_mut().add_named_activity("greet", |_ctx, input| async move {
41///     Ok(input)
42/// });
43///
44/// let shutdown = tokio_util::sync::CancellationToken::new();
45/// // worker.start(shutdown).await.unwrap();
46/// # }
47/// ```
48pub struct TaskHubGrpcWorker {
49    host_address: String,
50    registry: Arc<Registry>,
51    options: Arc<WorkerOptions>,
52}
53
54impl TaskHubGrpcWorker {
55    /// Create a new worker that will connect to the given sidecar address.
56    pub fn new(host_address: &str) -> Self {
57        Self {
58            host_address: host_address.to_string(),
59            registry: Arc::new(Registry::new()),
60            options: Arc::new(WorkerOptions::default()),
61        }
62    }
63
64    /// Create a new worker with custom options.
65    pub fn with_options(host_address: &str, options: WorkerOptions) -> Self {
66        Self {
67            host_address: host_address.to_string(),
68            registry: Arc::new(Registry::new()),
69            options: Arc::new(options),
70        }
71    }
72
73    /// Get a mutable reference to the registry for adding orchestrators and activities.
74    ///
75    /// # Panics
76    ///
77    /// Panics if called after the registry has been shared (i.e., after `start()`
78    /// has begun processing).
79    pub fn registry_mut(&mut self) -> &mut Registry {
80        Arc::get_mut(&mut self.registry).expect("Cannot modify registry after worker has started")
81    }
82
83    /// Start the worker. Runs until the cancellation token is triggered or the
84    /// reconnect policy's `max_attempts` is exhausted.
85    ///
86    /// ## Shutdown behaviour
87    ///
88    /// When the cancellation token is fired:
89    /// 1. The worker **stops reading new work items** from the sidecar stream.
90    /// 2. It **waits for all in-flight tasks** (orchestrations and activities
91    ///    already dispatched) to complete and send their results to the sidecar
92    ///    before returning.
93    ///
94    /// This guarantees that no already-accepted work item is abandoned at
95    /// shutdown. Work items still queued inside the sidecar but not yet
96    /// dispatched to this worker are unaffected — the sidecar will re-dispatch
97    /// them to the next available worker.
98    ///
99    /// [`ReconnectPolicy`]: super::reconnect_policy::ReconnectPolicy
100    pub async fn start(
101        &self,
102        shutdown: tokio_util::sync::CancellationToken,
103    ) -> crate::api::Result<()> {
104        let mut backoff = BackoffIter::new(&self.options.reconnect_policy);
105
106        loop {
107            if shutdown.is_cancelled() {
108                tracing::info!("Worker shutdown before connecting");
109                return Ok(());
110            }
111
112            tracing::info!(address = %self.host_address, "Worker connecting to sidecar");
113
114            match Self::connect(&self.host_address).await {
115                Ok(channel) => {
116                    tracing::info!(address = %self.host_address, "Worker connected, starting work loop");
117                    backoff.reset();
118
119                    let mut client = TaskHubSidecarServiceClient::new(channel);
120
121                    match Self::run_work_loop(&mut client, &self.registry, &self.options, &shutdown)
122                        .await
123                    {
124                        Ok(()) => {
125                            if shutdown.is_cancelled() {
126                                tracing::info!(
127                                    "Worker shut down gracefully after draining in-flight tasks"
128                                );
129                            } else {
130                                tracing::info!("Work item stream closed cleanly; shutting down");
131                            }
132                            return Ok(());
133                        }
134                        Err(e) => {
135                            tracing::warn!(error = %e, "Work loop error");
136                        }
137                    }
138                }
139                Err(e) => {
140                    tracing::warn!(error = %e, "Connection to sidecar failed");
141                }
142            }
143
144            // Connection failed or stream dropped — apply backoff.
145            match backoff.next_delay() {
146                None => {
147                    let msg = format!(
148                        "Worker exceeded maximum reconnect attempts ({}); giving up",
149                        self.options.reconnect_policy.max_attempts.unwrap_or(0)
150                    );
151                    tracing::error!("{}", msg);
152                    return Err(DurableTaskError::Other(msg));
153                }
154                Some(delay) => {
155                    tracing::info!(
156                        delay_ms = delay.as_millis(),
157                        address = %self.host_address,
158                        "Waiting before reconnect"
159                    );
160                    tokio::select! {
161                        _ = shutdown.cancelled() => {
162                            tracing::info!("Worker shutdown during reconnect wait");
163                            return Ok(());
164                        }
165                        _ = tokio::time::sleep(delay) => {}
166                    }
167                }
168            }
169        }
170    }
171
172    async fn connect(address: &str) -> crate::api::Result<Channel> {
173        Channel::from_shared(address.to_string())
174            .map_err(|e| DurableTaskError::Other(format!("Invalid address: {}", e)))?
175            .connect()
176            .await
177            .map_err(|e| DurableTaskError::Other(format!("Connection failed: {}", e)))
178    }
179
180    async fn run_work_loop(
181        client: &mut TaskHubSidecarServiceClient<Channel>,
182        registry: &Arc<Registry>,
183        options: &Arc<WorkerOptions>,
184        shutdown: &tokio_util::sync::CancellationToken,
185    ) -> crate::api::Result<()> {
186        let request = proto::GetWorkItemsRequest {};
187        let mut stream = client.get_work_items(request).await?.into_inner();
188        let semaphore = Arc::new(Semaphore::new(options.max_concurrent_work_items));
189        let mut tasks: JoinSet<()> = JoinSet::new();
190        tracing::info!("Work item stream established");
191
192        // `shutdown_triggered` tracks whether we exited the intake loop because
193        // of a cancellation (true) or because the stream closed (false/error).
194        let shutdown_triggered = loop {
195            tokio::select! {
196                biased; // check shutdown first so we don't accept more items
197                _ = shutdown.cancelled() => {
198                    tracing::info!(
199                        in_flight = tasks.len(),
200                        "Shutdown: stopping intake, draining in-flight work items"
201                    );
202                    break true;
203                }
204                item = stream.next() => {
205                    match item {
206                        None => {
207                            // Sidecar closed the stream — treat as a transient
208                            // error so the caller will reconnect.
209                            tracing::info!("Work item stream closed by sidecar");
210                            break false;
211                        }
212                        Some(Err(e)) => {
213                            return Err(DurableTaskError::Other(format!("Stream error: {e}")));
214                        }
215                        Some(Ok(work_item)) => {
216                            Self::dispatch_work_item(
217                                work_item,
218                                client,
219                                registry,
220                                options,
221                                &semaphore,
222                                &mut tasks,
223                            ).await?;
224                        }
225                    }
226                }
227            }
228        };
229
230        // Drain all in-flight tasks before returning, regardless of why we stopped.
231        if !tasks.is_empty() {
232            tracing::info!(count = tasks.len(), "Draining in-flight work items");
233            while let Some(outcome) = tasks.join_next().await {
234                if let Err(e) = outcome {
235                    tracing::error!(error = ?e, "In-flight task panicked during drain");
236                }
237            }
238            tracing::info!("All in-flight work items drained");
239        }
240
241        if shutdown_triggered {
242            // Caller checks shutdown.is_cancelled() to know this was intentional.
243            Ok(())
244        } else {
245            // Stream closed by sidecar — signal the caller to reconnect.
246            Err(DurableTaskError::Other(
247                "Work item stream closed by sidecar".into(),
248            ))
249        }
250    }
251
252    /// Validate and dispatch a single work item into the `JoinSet`.
253    async fn dispatch_work_item(
254        work_item: proto::WorkItem,
255        client: &TaskHubSidecarServiceClient<Channel>,
256        registry: &Arc<Registry>,
257        options: &Arc<WorkerOptions>,
258        semaphore: &Arc<Semaphore>,
259        tasks: &mut JoinSet<()>,
260    ) -> crate::api::Result<()> {
261        match work_item.request {
262            Some(Request::WorkflowRequest(req)) => {
263                let instance_id = req.instance_id.clone();
264                if let Err(e) =
265                    validate_identifier(&instance_id, "instance ID", options.max_identifier_length)
266                {
267                    tracing::warn!(
268                        instance_id = %instance_id,
269                        error = %e,
270                        "Rejected work item: invalid instance ID"
271                    );
272                    return Ok(());
273                }
274                tracing::debug!(
275                    instance_id = %instance_id,
276                    past_events = req.past_events.len(),
277                    new_events = req.new_events.len(),
278                    "Received orchestrator work item"
279                );
280
281                let registry = registry.clone();
282                let options = options.clone();
283                let mut stub = client.clone();
284                let completion_token = work_item.completion_token.clone();
285                let permit = semaphore
286                    .clone()
287                    .acquire_owned()
288                    .await
289                    .map_err(|_| DurableTaskError::Other("Semaphore closed".to_string()))?;
290
291                tasks.spawn(async move {
292                    let _permit = permit;
293                    let response = Self::handle_orchestrator_request(
294                        &registry,
295                        req,
296                        completion_token,
297                        &options,
298                    )
299                    .await;
300                    #[allow(deprecated)]
301                    if let Err(e) = stub.complete_orchestrator_task(response).await {
302                        tracing::error!(
303                            instance_id = %instance_id,
304                            error = %e,
305                            "Failed to complete orchestrator task"
306                        );
307                    }
308                });
309            }
310            Some(Request::ActivityRequest(req)) => {
311                let instance_id = req
312                    .workflow_instance
313                    .as_ref()
314                    .map(|i| i.instance_id.clone())
315                    .unwrap_or_default();
316                tracing::debug!(
317                    instance_id = %instance_id,
318                    activity = %req.name,
319                    task_id = req.task_id,
320                    "Received activity work item"
321                );
322
323                let registry = registry.clone();
324                let options = options.clone();
325                let mut stub = client.clone();
326                let completion_token = work_item.completion_token.clone();
327                let activity_name = req.name.clone();
328                let permit = semaphore
329                    .clone()
330                    .acquire_owned()
331                    .await
332                    .map_err(|_| DurableTaskError::Other("Semaphore closed".to_string()))?;
333
334                tasks.spawn(async move {
335                    let _permit = permit;
336                    let response =
337                        Self::handle_activity_request(&registry, req, completion_token, &options)
338                            .await;
339                    if let Err(e) = stub.complete_activity_task(response).await {
340                        tracing::error!(
341                            instance_id = %instance_id,
342                            activity = %activity_name,
343                            error = %e,
344                            "Failed to complete activity task"
345                        );
346                    }
347                });
348            }
349            None => {
350                tracing::warn!("Received work item with no request payload");
351            }
352        }
353        Ok(())
354    }
355
356    async fn handle_orchestrator_request(
357        registry: &Registry,
358        request: proto::WorkflowRequest,
359        completion_token: String,
360        options: &WorkerOptions,
361    ) -> proto::WorkflowResponse {
362        let instance_id = request.instance_id.clone();
363
364        // Single-pass extraction of orchestrator name and version from history.
365        let (name, version) = request
366            .past_events
367            .iter()
368            .chain(request.new_events.iter())
369            .find_map(|e| {
370                if let Some(EventType::ExecutionStarted(es)) = &e.event_type {
371                    Some((es.name.clone(), es.version.clone()))
372                } else {
373                    None
374                }
375            })
376            .unwrap_or_default();
377
378        if let Err(e) =
379            validate_identifier(&name, "orchestrator name", options.max_identifier_length)
380        {
381            tracing::warn!(
382                instance_id = %instance_id,
383                orchestrator = %name,
384                error = %e,
385                "Rejected orchestrator request: invalid name"
386            );
387            return build_error_response(&instance_id, &e.to_string(), completion_token);
388        }
389
390        let orchestrator_fn = match registry.get_orchestrator_version(&name, version.as_deref()) {
391            Some(f) => f,
392            None => {
393                tracing::warn!(
394                    instance_id = %instance_id,
395                    orchestrator = %name,
396                    "Unregistered orchestrator requested"
397                );
398                return build_error_response(
399                    &instance_id,
400                    &format!("Orchestrator '{}' not registered", name),
401                    completion_token,
402                );
403            }
404        };
405
406        match OrchestrationExecutor::execute(
407            orchestrator_fn,
408            &instance_id,
409            request.past_events,
410            request.new_events,
411            completion_token.clone(),
412            options,
413            request
414                .propagated_history
415                .and_then(crate::api::PropagatedHistory::from_proto),
416        )
417        .await
418        {
419            Ok(response) => response,
420            Err(e) => {
421                tracing::error!(
422                    instance_id = %instance_id,
423                    orchestrator = %name,
424                    error = %e,
425                    "Orchestrator execution failed"
426                );
427                build_error_response(&instance_id, &e.to_string(), completion_token)
428            }
429        }
430    }
431
432    async fn handle_activity_request(
433        registry: &Registry,
434        request: proto::ActivityRequest,
435        completion_token: String,
436        options: &WorkerOptions,
437    ) -> proto::ActivityResponse {
438        let instance_id = request
439            .workflow_instance
440            .as_ref()
441            .map(|i| i.instance_id.as_str())
442            .unwrap_or("");
443
444        let build_activity_error =
445            |error_type: &str, error_message: String| proto::ActivityResponse {
446                instance_id: instance_id.to_string(),
447                task_id: request.task_id,
448                result: None,
449                failure_details: Some(proto::TaskFailureDetails {
450                    error_type: error_type.to_string(),
451                    error_message,
452                    stack_trace: None,
453                    inner_failure: None,
454                    is_non_retriable: true,
455                }),
456                completion_token: completion_token.clone(),
457            };
458
459        if let Err(e) = validate_identifier(
460            &request.name,
461            "activity name",
462            options.max_identifier_length,
463        ) {
464            tracing::warn!(
465                instance_id = %instance_id,
466                activity = %request.name,
467                error = %e,
468                "Rejected activity request: invalid name"
469            );
470            return build_activity_error("InvalidActivityName", e.to_string());
471        }
472
473        let activity_fn = match registry.get_activity(&request.name) {
474            Some(f) => f,
475            None => {
476                tracing::warn!(
477                    instance_id = %instance_id,
478                    activity = %request.name,
479                    "Unregistered activity requested"
480                );
481                return build_activity_error(
482                    "ActivityNotRegistered",
483                    format!("Activity '{}' not registered", request.name),
484                );
485            }
486        };
487
488        ActivityExecutor::execute(
489            activity_fn,
490            &request.name,
491            instance_id,
492            request.task_id,
493            request.task_execution_id,
494            request.input,
495            request.parent_trace_context.as_ref(),
496            completion_token,
497            request
498                .propagated_history
499                .and_then(crate::api::PropagatedHistory::from_proto),
500        )
501        .await
502    }
503}
504
505fn build_error_response(
506    instance_id: &str,
507    message: &str,
508    completion_token: String,
509) -> proto::WorkflowResponse {
510    proto::WorkflowResponse {
511        instance_id: instance_id.to_string(),
512        actions: vec![proto::WorkflowAction {
513            id: -1,
514            router: None,
515            workflow_action_type: Some(
516                proto::workflow_action::WorkflowActionType::CompleteWorkflow(
517                    proto::CompleteWorkflowAction {
518                        workflow_status: proto::OrchestrationStatus::Failed as i32,
519                        result: None,
520                        details: None,
521                        new_version: None,
522                        carryover_events: vec![],
523                        failure_details: Some(proto::TaskFailureDetails {
524                            error_type: "WorkerError".to_string(),
525                            error_message: message.to_string(),
526                            stack_trace: None,
527                            inner_failure: None,
528                            is_non_retriable: false,
529                        }),
530                    },
531                ),
532            ),
533        }],
534        custom_status: None,
535        completion_token,
536        num_events_processed: None,
537        version: None,
538    }
539}