Skip to main content

dapr_durabletask/worker/
grpc_worker.rs

1use std::sync::Arc;
2
3use tokio::task::JoinSet;
4use tokio_stream::StreamExt;
5use tonic::transport::Channel;
6
7use tokio::sync::Semaphore;
8
9use crate::api::DurableTaskError;
10use crate::internal::validate_identifier;
11use crate::proto;
12use crate::proto::history_event::EventType;
13use crate::proto::task_hub_sidecar_service_client::TaskHubSidecarServiceClient;
14use crate::proto::work_item::Request;
15
16use super::activity_executor::ActivityExecutor;
17use super::options::WorkerOptions;
18use super::orchestration_executor::OrchestrationExecutor;
19use super::reconnect_policy::BackoffIter;
20use super::registry::Registry;
21
22/// Worker that connects to a Durable Task sidecar and processes work items.
23///
24/// The worker opens a streaming gRPC connection to receive orchestrator and
25/// activity work items, dispatches them to registered handler functions, and
26/// returns results to the sidecar.
27///
28/// # Example
29///
30/// ```rust,no_run
31/// use dapr_durabletask::worker::TaskHubGrpcWorker;
32/// use dapr_durabletask::task::OrchestrationContext;
33///
34/// # async fn example() {
35/// let mut worker = TaskHubGrpcWorker::new("http://localhost:4001");
36/// worker.registry_mut().add_named_orchestrator("my_orch", |ctx: OrchestrationContext| async move {
37///     let result = ctx.call_activity("greet", "world").await?;
38///     Ok(result)
39/// });
40/// worker.registry_mut().add_named_activity("greet", |_ctx, input| async move {
41///     Ok(input)
42/// });
43///
44/// let shutdown = tokio_util::sync::CancellationToken::new();
45/// // worker.start(shutdown).await.unwrap();
46/// # }
47/// ```
48pub struct TaskHubGrpcWorker {
49    host_address: String,
50    registry: Arc<Registry>,
51    options: Arc<WorkerOptions>,
52}
53
54impl TaskHubGrpcWorker {
55    /// Create a new worker that will connect to the given sidecar address.
56    pub fn new(host_address: &str) -> Self {
57        Self {
58            host_address: host_address.to_string(),
59            registry: Arc::new(Registry::new()),
60            options: Arc::new(WorkerOptions::default()),
61        }
62    }
63
64    /// Create a new worker with custom options.
65    pub fn with_options(host_address: &str, options: WorkerOptions) -> Self {
66        Self {
67            host_address: host_address.to_string(),
68            registry: Arc::new(Registry::new()),
69            options: Arc::new(options),
70        }
71    }
72
73    /// Get a mutable reference to the registry for adding orchestrators and activities.
74    ///
75    /// # Panics
76    ///
77    /// Panics if called after the registry has been shared (i.e., after `start()`
78    /// has begun processing).
79    pub fn registry_mut(&mut self) -> &mut Registry {
80        Arc::get_mut(&mut self.registry).expect("Cannot modify registry after worker has started")
81    }
82
83    /// Start the worker. Runs until the cancellation token is triggered or the
84    /// reconnect policy's `max_attempts` is exhausted.
85    ///
86    /// ## Shutdown behaviour
87    ///
88    /// When the cancellation token is fired:
89    /// 1. The worker **stops reading new work items** from the sidecar stream.
90    /// 2. It **waits for all in-flight tasks** (orchestrations and activities
91    ///    already dispatched) to complete and send their results to the sidecar
92    ///    before returning.
93    ///
94    /// This guarantees that no already-accepted work item is abandoned at
95    /// shutdown. Work items still queued inside the sidecar but not yet
96    /// dispatched to this worker are unaffected — the sidecar will re-dispatch
97    /// them to the next available worker.
98    ///
99    /// [`ReconnectPolicy`]: super::reconnect_policy::ReconnectPolicy
100    pub async fn start(
101        &self,
102        shutdown: tokio_util::sync::CancellationToken,
103    ) -> crate::api::Result<()> {
104        let mut backoff = BackoffIter::new(&self.options.reconnect_policy);
105
106        loop {
107            if shutdown.is_cancelled() {
108                tracing::info!("Worker shutdown before connecting");
109                return Ok(());
110            }
111
112            tracing::info!(address = %self.host_address, "Worker connecting to sidecar");
113
114            match Self::connect(&self.host_address).await {
115                Ok(channel) => {
116                    tracing::info!(address = %self.host_address, "Worker connected, starting work loop");
117                    backoff.reset();
118
119                    let mut client = TaskHubSidecarServiceClient::new(channel);
120
121                    match Self::run_work_loop(&mut client, &self.registry, &self.options, &shutdown)
122                        .await
123                    {
124                        Ok(()) => {
125                            if shutdown.is_cancelled() {
126                                tracing::info!(
127                                    "Worker shut down gracefully after draining in-flight tasks"
128                                );
129                            } else {
130                                tracing::info!("Work item stream closed cleanly; shutting down");
131                            }
132                            return Ok(());
133                        }
134                        Err(e) => {
135                            tracing::warn!(error = %e, "Work loop error");
136                        }
137                    }
138                }
139                Err(e) => {
140                    tracing::warn!(error = %e, "Connection to sidecar failed");
141                }
142            }
143
144            // Connection failed or stream dropped — apply backoff.
145            match backoff.next_delay() {
146                None => {
147                    let msg = format!(
148                        "Worker exceeded maximum reconnect attempts ({}); giving up",
149                        self.options.reconnect_policy.max_attempts.unwrap_or(0)
150                    );
151                    tracing::error!("{}", msg);
152                    return Err(DurableTaskError::ConnectionFailed(msg));
153                }
154                Some(delay) => {
155                    tracing::info!(
156                        delay_ms = delay.as_millis(),
157                        address = %self.host_address,
158                        "Waiting before reconnect"
159                    );
160                    tokio::select! {
161                        _ = shutdown.cancelled() => {
162                            tracing::info!("Worker shutdown during reconnect wait");
163                            return Ok(());
164                        }
165                        _ = tokio::time::sleep(delay) => {}
166                    }
167                }
168            }
169        }
170    }
171
172    async fn connect(address: &str) -> crate::api::Result<Channel> {
173        const USER_AGENT: &str = concat!("dapr-durabletask/rust/", env!("CARGO_PKG_VERSION"));
174
175        Channel::from_shared(address.to_string())
176            .map_err(|e| DurableTaskError::InvalidAddress(format!("Invalid address: {e}")))?
177            .user_agent(USER_AGENT)
178            .map_err(|e| DurableTaskError::InvalidAddress(format!("Invalid user agent: {e}")))?
179            .connect()
180            .await
181            .map_err(|e| DurableTaskError::ConnectionFailed(format!("Connection failed: {e}")))
182    }
183
184    async fn run_work_loop(
185        client: &mut TaskHubSidecarServiceClient<Channel>,
186        registry: &Arc<Registry>,
187        options: &Arc<WorkerOptions>,
188        shutdown: &tokio_util::sync::CancellationToken,
189    ) -> crate::api::Result<()> {
190        let request = proto::GetWorkItemsRequest {};
191        let mut stream = client.get_work_items(request).await?.into_inner();
192        let semaphore = Arc::new(Semaphore::new(options.max_concurrent_work_items));
193        let mut tasks: JoinSet<()> = JoinSet::new();
194        tracing::info!("Work item stream established");
195
196        // `shutdown_triggered` tracks whether we exited the intake loop because
197        // of a cancellation (true) or because the stream closed (false/error).
198        let shutdown_triggered = loop {
199            tokio::select! {
200                biased; // check shutdown first so we don't accept more items
201                _ = shutdown.cancelled() => {
202                    tracing::info!(
203                        in_flight = tasks.len(),
204                        "Shutdown: stopping intake, draining in-flight work items"
205                    );
206                    break true;
207                }
208                item = stream.next() => {
209                    match item {
210                        None => {
211                            // Sidecar closed the stream — treat as a transient
212                            // error so the caller will reconnect.
213                            tracing::info!("Work item stream closed by sidecar");
214                            break false;
215                        }
216                        Some(Err(e)) => {
217                            return Err(DurableTaskError::ConnectionFailed(format!("Stream error: {e}")));
218                        }
219                        Some(Ok(work_item)) => {
220                            Self::dispatch_work_item(
221                                work_item,
222                                client.clone(),
223                                registry,
224                                options,
225                                &semaphore,
226                                &mut tasks,
227                            ).await?;
228                        }
229                    }
230                }
231            }
232        };
233
234        // Drain all in-flight tasks before returning, regardless of why we stopped.
235        if !tasks.is_empty() {
236            tracing::info!(count = tasks.len(), "Draining in-flight work items");
237            while let Some(outcome) = tasks.join_next().await {
238                if let Err(e) = outcome {
239                    tracing::error!(error = ?e, "In-flight task panicked during drain");
240                }
241            }
242            tracing::info!("All in-flight work items drained");
243        }
244
245        if shutdown_triggered {
246            // Caller checks shutdown.is_cancelled() to know this was intentional.
247            Ok(())
248        } else {
249            // Stream closed by sidecar — signal the caller to reconnect.
250            Err(DurableTaskError::ConnectionFailed(
251                "Work item stream closed by sidecar".into(),
252            ))
253        }
254    }
255
256    /// Validate and dispatch a single work item into the `JoinSet`.
257    async fn dispatch_work_item(
258        work_item: proto::WorkItem,
259        client: TaskHubSidecarServiceClient<Channel>,
260        registry: &Arc<Registry>,
261        options: &Arc<WorkerOptions>,
262        semaphore: &Arc<Semaphore>,
263        tasks: &mut JoinSet<()>,
264    ) -> crate::api::Result<()> {
265        match work_item.request {
266            Some(Request::WorkflowRequest(req)) => {
267                let instance_id = req.instance_id.clone();
268                if let Err(e) =
269                    validate_identifier(&instance_id, "instance ID", options.max_identifier_length)
270                {
271                    tracing::warn!(
272                        instance_id = %instance_id,
273                        error = %e,
274                        "Rejected work item: invalid instance ID"
275                    );
276                    return Ok(());
277                }
278                tracing::debug!(
279                    instance_id = %instance_id,
280                    past_events = req.past_events.len(),
281                    new_events = req.new_events.len(),
282                    "Received orchestrator work item"
283                );
284
285                let registry = registry.clone();
286                let options = options.clone();
287                let mut stub = client;
288                let completion_token = work_item.completion_token.clone();
289                let permit = semaphore
290                    .clone()
291                    .acquire_owned()
292                    .await
293                    .map_err(|_| DurableTaskError::Internal("Semaphore closed".to_string()))?;
294
295                tasks.spawn(async move {
296                    let _permit = permit;
297                    let response = Self::handle_orchestrator_request(
298                        &registry,
299                        req,
300                        completion_token,
301                        &options,
302                    )
303                    .await;
304                    // TODO: migrate to complete_work_item once sidecar supports it.
305                    #[allow(deprecated)]
306                    if let Err(e) = stub.complete_orchestrator_task(response).await {
307                        tracing::error!(
308                            instance_id = %instance_id,
309                            error = %e,
310                            "Failed to complete orchestrator task"
311                        );
312                    }
313                });
314            }
315            Some(Request::ActivityRequest(req)) => {
316                let instance_id = req
317                    .workflow_instance
318                    .as_ref()
319                    .map(|i| i.instance_id.clone())
320                    .unwrap_or_default();
321                tracing::debug!(
322                    instance_id = %instance_id,
323                    activity = %req.name,
324                    task_id = req.task_id,
325                    "Received activity work item"
326                );
327
328                let registry = registry.clone();
329                let options = options.clone();
330                let mut stub = client;
331                let completion_token = work_item.completion_token.clone();
332                let activity_name = req.name.clone();
333                let permit = semaphore
334                    .clone()
335                    .acquire_owned()
336                    .await
337                    .map_err(|_| DurableTaskError::Internal("Semaphore closed".to_string()))?;
338
339                tasks.spawn(async move {
340                    let _permit = permit;
341                    let response =
342                        Self::handle_activity_request(&registry, req, completion_token, &options)
343                            .await;
344                    if let Err(e) = stub.complete_activity_task(response).await {
345                        tracing::error!(
346                            instance_id = %instance_id,
347                            activity = %activity_name,
348                            error = %e,
349                            "Failed to complete activity task"
350                        );
351                    }
352                });
353            }
354            None => {
355                tracing::warn!("Received work item with no request payload");
356            }
357        }
358        Ok(())
359    }
360
361    async fn handle_orchestrator_request(
362        registry: &Registry,
363        request: proto::WorkflowRequest,
364        completion_token: String,
365        options: &WorkerOptions,
366    ) -> proto::WorkflowResponse {
367        let instance_id = request.instance_id.clone();
368
369        // Single-pass extraction of orchestrator name and version from history.
370        let (name, version) = request
371            .past_events
372            .iter()
373            .chain(request.new_events.iter())
374            .find_map(|e| {
375                if let Some(EventType::ExecutionStarted(es)) = &e.event_type {
376                    Some((es.name.clone(), es.version.clone()))
377                } else {
378                    None
379                }
380            })
381            .unwrap_or_default();
382
383        if let Err(e) =
384            validate_identifier(&name, "orchestrator name", options.max_identifier_length)
385        {
386            tracing::warn!(
387                instance_id = %instance_id,
388                orchestrator = %name,
389                error = %e,
390                "Rejected orchestrator request: invalid name"
391            );
392            return build_error_response(&instance_id, &e.to_string(), completion_token);
393        }
394
395        let orchestrator_fn = match registry.get_orchestrator_version(&name, version.as_deref()) {
396            Some(f) => f,
397            None => {
398                tracing::warn!(
399                    instance_id = %instance_id,
400                    orchestrator = %name,
401                    "Unregistered orchestrator requested"
402                );
403                return build_error_response(
404                    &instance_id,
405                    &format!("Orchestrator '{name}' not registered"),
406                    completion_token,
407                );
408            }
409        };
410
411        match OrchestrationExecutor::execute(
412            orchestrator_fn,
413            &instance_id,
414            request.past_events,
415            request.new_events,
416            completion_token.clone(),
417            options,
418            request
419                .propagated_history
420                .and_then(crate::api::PropagatedHistory::from_proto),
421        )
422        .await
423        {
424            Ok(response) => response,
425            Err(e) => {
426                tracing::error!(
427                    instance_id = %instance_id,
428                    orchestrator = %name,
429                    error = %e,
430                    "Orchestrator execution failed"
431                );
432                build_error_response(&instance_id, &e.to_string(), completion_token)
433            }
434        }
435    }
436
437    async fn handle_activity_request(
438        registry: &Registry,
439        request: proto::ActivityRequest,
440        completion_token: String,
441        options: &WorkerOptions,
442    ) -> proto::ActivityResponse {
443        let instance_id = request
444            .workflow_instance
445            .as_ref()
446            .map(|i| i.instance_id.as_str())
447            .unwrap_or("");
448
449        let build_activity_error =
450            |error_type: &str, error_message: String| proto::ActivityResponse {
451                instance_id: instance_id.to_string(),
452                task_id: request.task_id,
453                result: None,
454                failure_details: Some(proto::TaskFailureDetails {
455                    error_type: error_type.to_string(),
456                    error_message,
457                    stack_trace: None,
458                    inner_failure: None,
459                    is_non_retriable: true,
460                }),
461                completion_token: completion_token.clone(),
462            };
463
464        if let Err(e) = validate_identifier(
465            &request.name,
466            "activity name",
467            options.max_identifier_length,
468        ) {
469            tracing::warn!(
470                instance_id = %instance_id,
471                activity = %request.name,
472                error = %e,
473                "Rejected activity request: invalid name"
474            );
475            return build_activity_error("InvalidActivityName", e.to_string());
476        }
477
478        let activity_fn = match registry.get_activity(&request.name) {
479            Some(f) => f,
480            None => {
481                tracing::warn!(
482                    instance_id = %instance_id,
483                    activity = %request.name,
484                    "Unregistered activity requested"
485                );
486                return build_activity_error(
487                    "ActivityNotRegistered",
488                    format!("Activity '{}' not registered", request.name),
489                );
490            }
491        };
492
493        ActivityExecutor::execute(
494            activity_fn,
495            &request.name,
496            instance_id,
497            request.task_id,
498            request.task_execution_id,
499            request.input,
500            request.parent_trace_context.as_ref(),
501            completion_token,
502            request
503                .propagated_history
504                .and_then(crate::api::PropagatedHistory::from_proto),
505        )
506        .await
507    }
508}
509
510fn build_error_response(
511    instance_id: &str,
512    message: &str,
513    completion_token: String,
514) -> proto::WorkflowResponse {
515    proto::WorkflowResponse {
516        instance_id: instance_id.to_string(),
517        actions: vec![proto::WorkflowAction {
518            id: -1,
519            router: None,
520            workflow_action_type: Some(
521                proto::workflow_action::WorkflowActionType::CompleteWorkflow(
522                    proto::CompleteWorkflowAction {
523                        workflow_status: proto::OrchestrationStatus::Failed as i32,
524                        result: None,
525                        details: None,
526                        new_version: None,
527                        carryover_events: vec![],
528                        failure_details: Some(proto::TaskFailureDetails {
529                            error_type: "WorkerError".to_string(),
530                            error_message: message.to_string(),
531                            stack_trace: None,
532                            inner_failure: None,
533                            is_non_retriable: false,
534                        }),
535                    },
536                ),
537            ),
538        }],
539        custom_status: None,
540        completion_token,
541        num_events_processed: None,
542        version: None,
543    }
544}