1use std::sync::Arc;
2
3use tokio::task::JoinSet;
4use tokio_stream::StreamExt;
5use tonic::transport::Channel;
6
7use tokio::sync::Semaphore;
8
9use crate::api::DurableTaskError;
10use crate::internal::validate_identifier;
11use crate::proto;
12use crate::proto::history_event::EventType;
13use crate::proto::task_hub_sidecar_service_client::TaskHubSidecarServiceClient;
14use crate::proto::work_item::Request;
15
16use super::activity_executor::ActivityExecutor;
17use super::options::WorkerOptions;
18use super::orchestration_executor::OrchestrationExecutor;
19use super::reconnect_policy::BackoffIter;
20use super::registry::Registry;
21
22pub struct TaskHubGrpcWorker {
49 host_address: String,
50 registry: Arc<Registry>,
51 options: Arc<WorkerOptions>,
52}
53
54impl TaskHubGrpcWorker {
55 pub fn new(host_address: &str) -> Self {
57 Self {
58 host_address: host_address.to_string(),
59 registry: Arc::new(Registry::new()),
60 options: Arc::new(WorkerOptions::default()),
61 }
62 }
63
64 pub fn with_options(host_address: &str, options: WorkerOptions) -> Self {
66 Self {
67 host_address: host_address.to_string(),
68 registry: Arc::new(Registry::new()),
69 options: Arc::new(options),
70 }
71 }
72
73 pub fn registry_mut(&mut self) -> &mut Registry {
80 Arc::get_mut(&mut self.registry).expect("Cannot modify registry after worker has started")
81 }
82
83 pub async fn start(
101 &self,
102 shutdown: tokio_util::sync::CancellationToken,
103 ) -> crate::api::Result<()> {
104 let mut backoff = BackoffIter::new(&self.options.reconnect_policy);
105
106 loop {
107 if shutdown.is_cancelled() {
108 tracing::info!("Worker shutdown before connecting");
109 return Ok(());
110 }
111
112 tracing::info!(address = %self.host_address, "Worker connecting to sidecar");
113
114 match Self::connect(&self.host_address).await {
115 Ok(channel) => {
116 tracing::info!(address = %self.host_address, "Worker connected, starting work loop");
117 backoff.reset();
118
119 let mut client = TaskHubSidecarServiceClient::new(channel);
120
121 match Self::run_work_loop(&mut client, &self.registry, &self.options, &shutdown)
122 .await
123 {
124 Ok(()) => {
125 if shutdown.is_cancelled() {
126 tracing::info!(
127 "Worker shut down gracefully after draining in-flight tasks"
128 );
129 } else {
130 tracing::info!("Work item stream closed cleanly; shutting down");
131 }
132 return Ok(());
133 }
134 Err(e) => {
135 tracing::warn!(error = %e, "Work loop error");
136 }
137 }
138 }
139 Err(e) => {
140 tracing::warn!(error = %e, "Connection to sidecar failed");
141 }
142 }
143
144 match backoff.next_delay() {
146 None => {
147 let msg = format!(
148 "Worker exceeded maximum reconnect attempts ({}); giving up",
149 self.options.reconnect_policy.max_attempts.unwrap_or(0)
150 );
151 tracing::error!("{}", msg);
152 return Err(DurableTaskError::ConnectionFailed(msg));
153 }
154 Some(delay) => {
155 tracing::info!(
156 delay_ms = delay.as_millis(),
157 address = %self.host_address,
158 "Waiting before reconnect"
159 );
160 tokio::select! {
161 _ = shutdown.cancelled() => {
162 tracing::info!("Worker shutdown during reconnect wait");
163 return Ok(());
164 }
165 _ = tokio::time::sleep(delay) => {}
166 }
167 }
168 }
169 }
170 }
171
172 async fn connect(address: &str) -> crate::api::Result<Channel> {
173 const USER_AGENT: &str = concat!("dapr-durabletask/rust/", env!("CARGO_PKG_VERSION"));
174
175 Channel::from_shared(address.to_string())
176 .map_err(|e| DurableTaskError::InvalidAddress(format!("Invalid address: {e}")))?
177 .user_agent(USER_AGENT)
178 .map_err(|e| DurableTaskError::InvalidAddress(format!("Invalid user agent: {e}")))?
179 .connect()
180 .await
181 .map_err(|e| DurableTaskError::ConnectionFailed(format!("Connection failed: {e}")))
182 }
183
184 async fn run_work_loop(
185 client: &mut TaskHubSidecarServiceClient<Channel>,
186 registry: &Arc<Registry>,
187 options: &Arc<WorkerOptions>,
188 shutdown: &tokio_util::sync::CancellationToken,
189 ) -> crate::api::Result<()> {
190 let request = proto::GetWorkItemsRequest {};
191 let mut stream = client.get_work_items(request).await?.into_inner();
192 let semaphore = Arc::new(Semaphore::new(options.max_concurrent_work_items));
193 let mut tasks: JoinSet<()> = JoinSet::new();
194 tracing::info!("Work item stream established");
195
196 let shutdown_triggered = loop {
199 tokio::select! {
200 biased; _ = shutdown.cancelled() => {
202 tracing::info!(
203 in_flight = tasks.len(),
204 "Shutdown: stopping intake, draining in-flight work items"
205 );
206 break true;
207 }
208 item = stream.next() => {
209 match item {
210 None => {
211 tracing::info!("Work item stream closed by sidecar");
214 break false;
215 }
216 Some(Err(e)) => {
217 return Err(DurableTaskError::ConnectionFailed(format!("Stream error: {e}")));
218 }
219 Some(Ok(work_item)) => {
220 Self::dispatch_work_item(
221 work_item,
222 client.clone(),
223 registry,
224 options,
225 &semaphore,
226 &mut tasks,
227 ).await?;
228 }
229 }
230 }
231 }
232 };
233
234 if !tasks.is_empty() {
236 tracing::info!(count = tasks.len(), "Draining in-flight work items");
237 while let Some(outcome) = tasks.join_next().await {
238 if let Err(e) = outcome {
239 tracing::error!(error = ?e, "In-flight task panicked during drain");
240 }
241 }
242 tracing::info!("All in-flight work items drained");
243 }
244
245 if shutdown_triggered {
246 Ok(())
248 } else {
249 Err(DurableTaskError::ConnectionFailed(
251 "Work item stream closed by sidecar".into(),
252 ))
253 }
254 }
255
256 async fn dispatch_work_item(
258 work_item: proto::WorkItem,
259 client: TaskHubSidecarServiceClient<Channel>,
260 registry: &Arc<Registry>,
261 options: &Arc<WorkerOptions>,
262 semaphore: &Arc<Semaphore>,
263 tasks: &mut JoinSet<()>,
264 ) -> crate::api::Result<()> {
265 match work_item.request {
266 Some(Request::WorkflowRequest(req)) => {
267 let instance_id = req.instance_id.clone();
268 if let Err(e) =
269 validate_identifier(&instance_id, "instance ID", options.max_identifier_length)
270 {
271 tracing::warn!(
272 instance_id = %instance_id,
273 error = %e,
274 "Rejected work item: invalid instance ID"
275 );
276 return Ok(());
277 }
278 tracing::debug!(
279 instance_id = %instance_id,
280 past_events = req.past_events.len(),
281 new_events = req.new_events.len(),
282 "Received orchestrator work item"
283 );
284
285 let registry = registry.clone();
286 let options = options.clone();
287 let mut stub = client;
288 let completion_token = work_item.completion_token.clone();
289 let permit = semaphore
290 .clone()
291 .acquire_owned()
292 .await
293 .map_err(|_| DurableTaskError::Internal("Semaphore closed".to_string()))?;
294
295 tasks.spawn(async move {
296 let _permit = permit;
297 let response = Self::handle_orchestrator_request(
298 ®istry,
299 req,
300 completion_token,
301 &options,
302 )
303 .await;
304 #[allow(deprecated)]
306 if let Err(e) = stub.complete_orchestrator_task(response).await {
307 tracing::error!(
308 instance_id = %instance_id,
309 error = %e,
310 "Failed to complete orchestrator task"
311 );
312 }
313 });
314 }
315 Some(Request::ActivityRequest(req)) => {
316 let instance_id = req
317 .workflow_instance
318 .as_ref()
319 .map(|i| i.instance_id.clone())
320 .unwrap_or_default();
321 tracing::debug!(
322 instance_id = %instance_id,
323 activity = %req.name,
324 task_id = req.task_id,
325 "Received activity work item"
326 );
327
328 let registry = registry.clone();
329 let options = options.clone();
330 let mut stub = client;
331 let completion_token = work_item.completion_token.clone();
332 let activity_name = req.name.clone();
333 let permit = semaphore
334 .clone()
335 .acquire_owned()
336 .await
337 .map_err(|_| DurableTaskError::Internal("Semaphore closed".to_string()))?;
338
339 tasks.spawn(async move {
340 let _permit = permit;
341 let response =
342 Self::handle_activity_request(®istry, req, completion_token, &options)
343 .await;
344 if let Err(e) = stub.complete_activity_task(response).await {
345 tracing::error!(
346 instance_id = %instance_id,
347 activity = %activity_name,
348 error = %e,
349 "Failed to complete activity task"
350 );
351 }
352 });
353 }
354 None => {
355 tracing::warn!("Received work item with no request payload");
356 }
357 }
358 Ok(())
359 }
360
361 async fn handle_orchestrator_request(
362 registry: &Registry,
363 request: proto::WorkflowRequest,
364 completion_token: String,
365 options: &WorkerOptions,
366 ) -> proto::WorkflowResponse {
367 let instance_id = request.instance_id.clone();
368
369 let (name, version) = request
371 .past_events
372 .iter()
373 .chain(request.new_events.iter())
374 .find_map(|e| {
375 if let Some(EventType::ExecutionStarted(es)) = &e.event_type {
376 Some((es.name.clone(), es.version.clone()))
377 } else {
378 None
379 }
380 })
381 .unwrap_or_default();
382
383 if let Err(e) =
384 validate_identifier(&name, "orchestrator name", options.max_identifier_length)
385 {
386 tracing::warn!(
387 instance_id = %instance_id,
388 orchestrator = %name,
389 error = %e,
390 "Rejected orchestrator request: invalid name"
391 );
392 return build_error_response(&instance_id, &e.to_string(), completion_token);
393 }
394
395 let orchestrator_fn = match registry.get_orchestrator_version(&name, version.as_deref()) {
396 Some(f) => f,
397 None => {
398 tracing::warn!(
399 instance_id = %instance_id,
400 orchestrator = %name,
401 "Unregistered orchestrator requested"
402 );
403 return build_error_response(
404 &instance_id,
405 &format!("Orchestrator '{name}' not registered"),
406 completion_token,
407 );
408 }
409 };
410
411 match OrchestrationExecutor::execute(
412 orchestrator_fn,
413 &instance_id,
414 request.past_events,
415 request.new_events,
416 completion_token.clone(),
417 options,
418 request
419 .propagated_history
420 .and_then(crate::api::PropagatedHistory::from_proto),
421 )
422 .await
423 {
424 Ok(response) => response,
425 Err(e) => {
426 tracing::error!(
427 instance_id = %instance_id,
428 orchestrator = %name,
429 error = %e,
430 "Orchestrator execution failed"
431 );
432 build_error_response(&instance_id, &e.to_string(), completion_token)
433 }
434 }
435 }
436
437 async fn handle_activity_request(
438 registry: &Registry,
439 request: proto::ActivityRequest,
440 completion_token: String,
441 options: &WorkerOptions,
442 ) -> proto::ActivityResponse {
443 let instance_id = request
444 .workflow_instance
445 .as_ref()
446 .map(|i| i.instance_id.as_str())
447 .unwrap_or("");
448
449 let build_activity_error =
450 |error_type: &str, error_message: String| proto::ActivityResponse {
451 instance_id: instance_id.to_string(),
452 task_id: request.task_id,
453 result: None,
454 failure_details: Some(proto::TaskFailureDetails {
455 error_type: error_type.to_string(),
456 error_message,
457 stack_trace: None,
458 inner_failure: None,
459 is_non_retriable: true,
460 }),
461 completion_token: completion_token.clone(),
462 };
463
464 if let Err(e) = validate_identifier(
465 &request.name,
466 "activity name",
467 options.max_identifier_length,
468 ) {
469 tracing::warn!(
470 instance_id = %instance_id,
471 activity = %request.name,
472 error = %e,
473 "Rejected activity request: invalid name"
474 );
475 return build_activity_error("InvalidActivityName", e.to_string());
476 }
477
478 let activity_fn = match registry.get_activity(&request.name) {
479 Some(f) => f,
480 None => {
481 tracing::warn!(
482 instance_id = %instance_id,
483 activity = %request.name,
484 "Unregistered activity requested"
485 );
486 return build_activity_error(
487 "ActivityNotRegistered",
488 format!("Activity '{}' not registered", request.name),
489 );
490 }
491 };
492
493 ActivityExecutor::execute(
494 activity_fn,
495 &request.name,
496 instance_id,
497 request.task_id,
498 request.task_execution_id,
499 request.input,
500 request.parent_trace_context.as_ref(),
501 completion_token,
502 request
503 .propagated_history
504 .and_then(crate::api::PropagatedHistory::from_proto),
505 )
506 .await
507 }
508}
509
510fn build_error_response(
511 instance_id: &str,
512 message: &str,
513 completion_token: String,
514) -> proto::WorkflowResponse {
515 proto::WorkflowResponse {
516 instance_id: instance_id.to_string(),
517 actions: vec![proto::WorkflowAction {
518 id: -1,
519 router: None,
520 workflow_action_type: Some(
521 proto::workflow_action::WorkflowActionType::CompleteWorkflow(
522 proto::CompleteWorkflowAction {
523 workflow_status: proto::OrchestrationStatus::Failed as i32,
524 result: None,
525 details: None,
526 new_version: None,
527 carryover_events: vec![],
528 failure_details: Some(proto::TaskFailureDetails {
529 error_type: "WorkerError".to_string(),
530 error_message: message.to_string(),
531 stack_trace: None,
532 inner_failure: None,
533 is_non_retriable: false,
534 }),
535 },
536 ),
537 ),
538 }],
539 custom_status: None,
540 completion_token,
541 num_events_processed: None,
542 version: None,
543 }
544}