1use std::sync::Arc;
2
3use tokio::task::JoinSet;
4use tokio_stream::StreamExt;
5use tonic::transport::Channel;
6
7use tokio::sync::Semaphore;
8
9use crate::api::DurableTaskError;
10use crate::internal::validate_identifier;
11use crate::proto;
12use crate::proto::history_event::EventType;
13use crate::proto::task_hub_sidecar_service_client::TaskHubSidecarServiceClient;
14use crate::proto::work_item::Request;
15
16use super::activity_executor::ActivityExecutor;
17use super::options::WorkerOptions;
18use super::orchestration_executor::OrchestrationExecutor;
19use super::reconnect_policy::BackoffIter;
20use super::registry::Registry;
21
22pub struct TaskHubGrpcWorker {
49 host_address: String,
50 registry: Arc<Registry>,
51 options: Arc<WorkerOptions>,
52}
53
54impl TaskHubGrpcWorker {
55 pub fn new(host_address: &str) -> Self {
57 Self {
58 host_address: host_address.to_string(),
59 registry: Arc::new(Registry::new()),
60 options: Arc::new(WorkerOptions::default()),
61 }
62 }
63
64 pub fn with_options(host_address: &str, options: WorkerOptions) -> Self {
66 Self {
67 host_address: host_address.to_string(),
68 registry: Arc::new(Registry::new()),
69 options: Arc::new(options),
70 }
71 }
72
73 pub fn registry_mut(&mut self) -> &mut Registry {
80 Arc::get_mut(&mut self.registry).expect("Cannot modify registry after worker has started")
81 }
82
83 pub async fn start(
101 &self,
102 shutdown: tokio_util::sync::CancellationToken,
103 ) -> crate::api::Result<()> {
104 let mut backoff = BackoffIter::new(&self.options.reconnect_policy);
105
106 loop {
107 if shutdown.is_cancelled() {
108 tracing::info!("Worker shutdown before connecting");
109 return Ok(());
110 }
111
112 tracing::info!(address = %self.host_address, "Worker connecting to sidecar");
113
114 match Self::connect(&self.host_address).await {
115 Ok(channel) => {
116 tracing::info!(address = %self.host_address, "Worker connected, starting work loop");
117 backoff.reset();
118
119 let mut client = TaskHubSidecarServiceClient::new(channel);
120
121 match Self::run_work_loop(&mut client, &self.registry, &self.options, &shutdown)
122 .await
123 {
124 Ok(()) => {
125 if shutdown.is_cancelled() {
126 tracing::info!(
127 "Worker shut down gracefully after draining in-flight tasks"
128 );
129 } else {
130 tracing::info!("Work item stream closed cleanly; shutting down");
131 }
132 return Ok(());
133 }
134 Err(e) => {
135 tracing::warn!(error = %e, "Work loop error");
136 }
137 }
138 }
139 Err(e) => {
140 tracing::warn!(error = %e, "Connection to sidecar failed");
141 }
142 }
143
144 match backoff.next_delay() {
146 None => {
147 let msg = format!(
148 "Worker exceeded maximum reconnect attempts ({}); giving up",
149 self.options.reconnect_policy.max_attempts.unwrap_or(0)
150 );
151 tracing::error!("{}", msg);
152 return Err(DurableTaskError::Other(msg));
153 }
154 Some(delay) => {
155 tracing::info!(
156 delay_ms = delay.as_millis(),
157 address = %self.host_address,
158 "Waiting before reconnect"
159 );
160 tokio::select! {
161 _ = shutdown.cancelled() => {
162 tracing::info!("Worker shutdown during reconnect wait");
163 return Ok(());
164 }
165 _ = tokio::time::sleep(delay) => {}
166 }
167 }
168 }
169 }
170 }
171
172 async fn connect(address: &str) -> crate::api::Result<Channel> {
173 Channel::from_shared(address.to_string())
174 .map_err(|e| DurableTaskError::Other(format!("Invalid address: {}", e)))?
175 .connect()
176 .await
177 .map_err(|e| DurableTaskError::Other(format!("Connection failed: {}", e)))
178 }
179
180 async fn run_work_loop(
181 client: &mut TaskHubSidecarServiceClient<Channel>,
182 registry: &Arc<Registry>,
183 options: &Arc<WorkerOptions>,
184 shutdown: &tokio_util::sync::CancellationToken,
185 ) -> crate::api::Result<()> {
186 let request = proto::GetWorkItemsRequest {};
187 let mut stream = client.get_work_items(request).await?.into_inner();
188 let semaphore = Arc::new(Semaphore::new(options.max_concurrent_work_items));
189 let mut tasks: JoinSet<()> = JoinSet::new();
190 tracing::info!("Work item stream established");
191
192 let shutdown_triggered = loop {
195 tokio::select! {
196 biased; _ = shutdown.cancelled() => {
198 tracing::info!(
199 in_flight = tasks.len(),
200 "Shutdown: stopping intake, draining in-flight work items"
201 );
202 break true;
203 }
204 item = stream.next() => {
205 match item {
206 None => {
207 tracing::info!("Work item stream closed by sidecar");
210 break false;
211 }
212 Some(Err(e)) => {
213 return Err(DurableTaskError::Other(format!("Stream error: {e}")));
214 }
215 Some(Ok(work_item)) => {
216 Self::dispatch_work_item(
217 work_item,
218 client,
219 registry,
220 options,
221 &semaphore,
222 &mut tasks,
223 ).await?;
224 }
225 }
226 }
227 }
228 };
229
230 if !tasks.is_empty() {
232 tracing::info!(count = tasks.len(), "Draining in-flight work items");
233 while let Some(outcome) = tasks.join_next().await {
234 if let Err(e) = outcome {
235 tracing::error!(error = ?e, "In-flight task panicked during drain");
236 }
237 }
238 tracing::info!("All in-flight work items drained");
239 }
240
241 if shutdown_triggered {
242 Ok(())
244 } else {
245 Err(DurableTaskError::Other(
247 "Work item stream closed by sidecar".into(),
248 ))
249 }
250 }
251
252 async fn dispatch_work_item(
254 work_item: proto::WorkItem,
255 client: &TaskHubSidecarServiceClient<Channel>,
256 registry: &Arc<Registry>,
257 options: &Arc<WorkerOptions>,
258 semaphore: &Arc<Semaphore>,
259 tasks: &mut JoinSet<()>,
260 ) -> crate::api::Result<()> {
261 match work_item.request {
262 Some(Request::WorkflowRequest(req)) => {
263 let instance_id = req.instance_id.clone();
264 if let Err(e) =
265 validate_identifier(&instance_id, "instance ID", options.max_identifier_length)
266 {
267 tracing::warn!(
268 instance_id = %instance_id,
269 error = %e,
270 "Rejected work item: invalid instance ID"
271 );
272 return Ok(());
273 }
274 tracing::debug!(
275 instance_id = %instance_id,
276 past_events = req.past_events.len(),
277 new_events = req.new_events.len(),
278 "Received orchestrator work item"
279 );
280
281 let registry = registry.clone();
282 let options = options.clone();
283 let mut stub = client.clone();
284 let completion_token = work_item.completion_token.clone();
285 let permit = semaphore
286 .clone()
287 .acquire_owned()
288 .await
289 .map_err(|_| DurableTaskError::Other("Semaphore closed".to_string()))?;
290
291 tasks.spawn(async move {
292 let _permit = permit;
293 let response = Self::handle_orchestrator_request(
294 ®istry,
295 req,
296 completion_token,
297 &options,
298 )
299 .await;
300 #[allow(deprecated)]
301 if let Err(e) = stub.complete_orchestrator_task(response).await {
302 tracing::error!(
303 instance_id = %instance_id,
304 error = %e,
305 "Failed to complete orchestrator task"
306 );
307 }
308 });
309 }
310 Some(Request::ActivityRequest(req)) => {
311 let instance_id = req
312 .workflow_instance
313 .as_ref()
314 .map(|i| i.instance_id.clone())
315 .unwrap_or_default();
316 tracing::debug!(
317 instance_id = %instance_id,
318 activity = %req.name,
319 task_id = req.task_id,
320 "Received activity work item"
321 );
322
323 let registry = registry.clone();
324 let options = options.clone();
325 let mut stub = client.clone();
326 let completion_token = work_item.completion_token.clone();
327 let activity_name = req.name.clone();
328 let permit = semaphore
329 .clone()
330 .acquire_owned()
331 .await
332 .map_err(|_| DurableTaskError::Other("Semaphore closed".to_string()))?;
333
334 tasks.spawn(async move {
335 let _permit = permit;
336 let response =
337 Self::handle_activity_request(®istry, req, completion_token, &options)
338 .await;
339 if let Err(e) = stub.complete_activity_task(response).await {
340 tracing::error!(
341 instance_id = %instance_id,
342 activity = %activity_name,
343 error = %e,
344 "Failed to complete activity task"
345 );
346 }
347 });
348 }
349 None => {
350 tracing::warn!("Received work item with no request payload");
351 }
352 }
353 Ok(())
354 }
355
356 async fn handle_orchestrator_request(
357 registry: &Registry,
358 request: proto::WorkflowRequest,
359 completion_token: String,
360 options: &WorkerOptions,
361 ) -> proto::WorkflowResponse {
362 let instance_id = request.instance_id.clone();
363
364 let (name, version) = request
366 .past_events
367 .iter()
368 .chain(request.new_events.iter())
369 .find_map(|e| {
370 if let Some(EventType::ExecutionStarted(es)) = &e.event_type {
371 Some((es.name.clone(), es.version.clone()))
372 } else {
373 None
374 }
375 })
376 .unwrap_or_default();
377
378 if let Err(e) =
379 validate_identifier(&name, "orchestrator name", options.max_identifier_length)
380 {
381 tracing::warn!(
382 instance_id = %instance_id,
383 orchestrator = %name,
384 error = %e,
385 "Rejected orchestrator request: invalid name"
386 );
387 return build_error_response(&instance_id, &e.to_string(), completion_token);
388 }
389
390 let orchestrator_fn = match registry.get_orchestrator_version(&name, version.as_deref()) {
391 Some(f) => f,
392 None => {
393 tracing::warn!(
394 instance_id = %instance_id,
395 orchestrator = %name,
396 "Unregistered orchestrator requested"
397 );
398 return build_error_response(
399 &instance_id,
400 &format!("Orchestrator '{}' not registered", name),
401 completion_token,
402 );
403 }
404 };
405
406 match OrchestrationExecutor::execute(
407 orchestrator_fn,
408 &instance_id,
409 request.past_events,
410 request.new_events,
411 completion_token.clone(),
412 options,
413 request
414 .propagated_history
415 .and_then(crate::api::PropagatedHistory::from_proto),
416 )
417 .await
418 {
419 Ok(response) => response,
420 Err(e) => {
421 tracing::error!(
422 instance_id = %instance_id,
423 orchestrator = %name,
424 error = %e,
425 "Orchestrator execution failed"
426 );
427 build_error_response(&instance_id, &e.to_string(), completion_token)
428 }
429 }
430 }
431
432 async fn handle_activity_request(
433 registry: &Registry,
434 request: proto::ActivityRequest,
435 completion_token: String,
436 options: &WorkerOptions,
437 ) -> proto::ActivityResponse {
438 let instance_id = request
439 .workflow_instance
440 .as_ref()
441 .map(|i| i.instance_id.as_str())
442 .unwrap_or("");
443
444 let build_activity_error =
445 |error_type: &str, error_message: String| proto::ActivityResponse {
446 instance_id: instance_id.to_string(),
447 task_id: request.task_id,
448 result: None,
449 failure_details: Some(proto::TaskFailureDetails {
450 error_type: error_type.to_string(),
451 error_message,
452 stack_trace: None,
453 inner_failure: None,
454 is_non_retriable: true,
455 }),
456 completion_token: completion_token.clone(),
457 };
458
459 if let Err(e) = validate_identifier(
460 &request.name,
461 "activity name",
462 options.max_identifier_length,
463 ) {
464 tracing::warn!(
465 instance_id = %instance_id,
466 activity = %request.name,
467 error = %e,
468 "Rejected activity request: invalid name"
469 );
470 return build_activity_error("InvalidActivityName", e.to_string());
471 }
472
473 let activity_fn = match registry.get_activity(&request.name) {
474 Some(f) => f,
475 None => {
476 tracing::warn!(
477 instance_id = %instance_id,
478 activity = %request.name,
479 "Unregistered activity requested"
480 );
481 return build_activity_error(
482 "ActivityNotRegistered",
483 format!("Activity '{}' not registered", request.name),
484 );
485 }
486 };
487
488 ActivityExecutor::execute(
489 activity_fn,
490 &request.name,
491 instance_id,
492 request.task_id,
493 request.task_execution_id,
494 request.input,
495 request.parent_trace_context.as_ref(),
496 completion_token,
497 request
498 .propagated_history
499 .and_then(crate::api::PropagatedHistory::from_proto),
500 )
501 .await
502 }
503}
504
505fn build_error_response(
506 instance_id: &str,
507 message: &str,
508 completion_token: String,
509) -> proto::WorkflowResponse {
510 proto::WorkflowResponse {
511 instance_id: instance_id.to_string(),
512 actions: vec![proto::WorkflowAction {
513 id: -1,
514 router: None,
515 workflow_action_type: Some(
516 proto::workflow_action::WorkflowActionType::CompleteWorkflow(
517 proto::CompleteWorkflowAction {
518 workflow_status: proto::OrchestrationStatus::Failed as i32,
519 result: None,
520 details: None,
521 new_version: None,
522 carryover_events: vec![],
523 failure_details: Some(proto::TaskFailureDetails {
524 error_type: "WorkerError".to_string(),
525 error_message: message.to_string(),
526 stack_trace: None,
527 inner_failure: None,
528 is_non_retriable: false,
529 }),
530 },
531 ),
532 ),
533 }],
534 custom_status: None,
535 completion_token,
536 num_events_processed: None,
537 version: None,
538 }
539}