1use std::collections::{BTreeSet, HashMap};
4use std::sync::Arc;
5
6use aion_core::{ActivityError, ActivityId, Payload, WorkflowId};
7use async_trait::async_trait;
8use futures::StreamExt;
9use futures::future;
10use tokio::sync::{Semaphore, mpsc};
11use tracing::{debug, info};
12
13use crate::config::WorkerConfig;
14use crate::context::{ActivityContext, HeartbeatRequest};
15use crate::error::WorkerError;
16use crate::protocol::reconnect::UnackedResultTracker;
17use crate::protocol::{
18 ActivityExecutionKey, ActivityTask, HeartbeatBookkeeper, WorkerSession, WorkerSessionEvent,
19};
20use crate::runtime::report::{
21 DispatchFinished, InFlightActivity, RuntimeChannels, drain_remaining, record_first_error,
22 report_finished,
23};
24
/// Executes activity tasks on behalf of the worker loop.
///
/// The worker loop shares a dispatcher behind an [`Arc`] and calls it from
/// spawned Tokio tasks, which is why implementors must be
/// `Send + Sync + 'static`.
#[async_trait]
pub trait ActivityDispatcher: Send + Sync + 'static {
    /// Runs a single activity task using the supplied execution context.
    ///
    /// # Errors
    ///
    /// Returns a [`WorkerError`] when the dispatch itself could not produce a
    /// [`DispatchOutcome`]. An activity that ran and failed is expressed as
    /// `Ok(DispatchOutcome::Failed { .. })` instead.
    async fn dispatch(
        &self,
        task: ActivityTask,
        context: ActivityContext,
    ) -> Result<DispatchOutcome, WorkerError>;

    /// Returns an ordered set of activity type names.
    ///
    /// NOTE(review): presumably the activity types this dispatcher is able to
    /// execute; confirm against the call sites outside this file.
    fn activity_types(&self) -> BTreeSet<String>;
}
38
/// Result of dispatching one activity task.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DispatchOutcome {
    /// The activity finished and produced an output payload.
    Completed {
        /// Payload produced by the activity.
        output: Payload,
    },
    /// The activity finished with a failure.
    Failed {
        /// Error describing why the activity failed.
        failure: ActivityError,
    },
}
53
/// Shutdown future type that never resolves.
///
/// This is the type of the `future::pending()` value that
/// [`serve_activity_tasks`] hands to [`serve_activity_tasks_until`], i.e. the
/// "no external shutdown signal" case.
pub type NoShutdown = future::Pending<()>;
56
/// Reason the serve loop stopped without an error.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ServeEnd {
    /// The caller-supplied shutdown future resolved.
    Shutdown,
    /// The session event stream ended. This is also the value the loop starts
    /// with, so it is reported whenever no other end reason was recorded.
    StreamClosed,
    /// The server sent a drain event.
    Drained,
}
71
/// Observations about one serve session, filled in by
/// [`serve_activity_tasks_until`].
#[derive(Debug, Default)]
pub struct SessionHealth {
    /// Counter handed to the result-reporting helpers.
    ///
    /// NOTE(review): presumably the number of activity results reported
    /// during the session; the increment happens outside this file.
    pub tasks_reported: usize,
    /// Instant captured right after the serve loop exits, before remaining
    /// work is drained. `None` until the loop has exited.
    pub stream_ended_at: Option<tokio::time::Instant>,
    /// Set to `true` when a drain event is received from the server.
    pub drain_received: bool,
}
89
90pub async fn serve_activity_tasks<S, D>(
106 config: &WorkerConfig,
107 session: &mut S,
108 dispatcher: Arc<D>,
109 tracker: &mut UnackedResultTracker,
110) -> Result<ServeEnd, WorkerError>
111where
112 S: WorkerSession,
113 D: ActivityDispatcher,
114{
115 let mut health = SessionHealth::default();
116 serve_activity_tasks_until(
117 config,
118 session,
119 dispatcher,
120 tracker,
121 &mut health,
122 future::pending(),
123 )
124 .await
125}
126
127pub async fn serve_activity_tasks_until<S, D, Shutdown>(
158 config: &WorkerConfig,
159 session: &mut S,
160 dispatcher: Arc<D>,
161 tracker: &mut UnackedResultTracker,
162 health: &mut SessionHealth,
163 shutdown: Shutdown,
164) -> Result<ServeEnd, WorkerError>
165where
166 S: WorkerSession,
167 D: ActivityDispatcher,
168 Shutdown: Future<Output = ()> + Send,
169{
170 ensure_max_concurrency(config)?;
171 let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
172 let (result_sender, heartbeat_sender, mut channels) = runtime_channels();
173 let heartbeat_bookkeeper = HeartbeatBookkeeper::default();
174 let mut stream = session.receive_tasks();
175 let mut in_flight = HashMap::<ActivityExecutionKey, InFlightActivity>::new();
176 let mut pending_error = None;
177 let mut end = ServeEnd::StreamClosed;
180 tokio::pin!(shutdown);
181
182 while pending_error.is_none() {
185 tokio::select! {
186 biased;
187 () = &mut shutdown => {
188 cancel_all_in_flight(&in_flight);
189 end = ServeEnd::Shutdown;
190 break;
191 }
192 finished = channels.results.recv() => {
198 if let Some(finished) = finished {
199 report_finished(
200 session,
201 &heartbeat_bookkeeper,
202 finished,
203 &mut in_flight,
204 tracker,
205 &mut health.tasks_reported,
206 &mut pending_error,
207 )
208 .await;
209 }
210 }
211 request = channels.heartbeats.recv() => {
215 if let Some(request) = request {
216 forward_heartbeat(session, &heartbeat_bookkeeper, request, &mut pending_error)
217 .await;
218 }
219 }
220 event = stream.next() => {
221 let Some(event) = event else { break; };
222 match event {
223 Ok(WorkerSessionEvent::Cancel { workflow_id, activity_id }) => {
224 deliver_cancellation(workflow_id, &activity_id, &in_flight);
225 }
226 Ok(WorkerSessionEvent::ResultAck { workflow_id, activity_id }) => {
229 acknowledge_result(&workflow_id, &activity_id, tracker);
230 }
231 Ok(WorkerSessionEvent::Drain) => {
232 info!("server drain received; finishing in-flight work before reconnect");
233 health.drain_received = true;
234 end = ServeEnd::Drained;
235 break;
236 }
237 Err(error) => {
238 pending_error = Some(error);
239 break;
240 }
241 Ok(WorkerSessionEvent::Task(proto_task)) => {
242 let Some(permit) =
243 acquire_permit_or_shutdown(shutdown.as_mut(), &semaphore).await?
244 else {
245 cancel_all_in_flight(&in_flight);
246 end = ServeEnd::Shutdown;
247 break;
248 };
249 if !handle_task(
250 proto_task,
251 SessionEventContext {
252 permit,
253 dispatcher: Arc::clone(&dispatcher),
254 result_sender: &result_sender,
255 heartbeat_sender: &heartbeat_sender,
256 heartbeat_bookkeeper: &heartbeat_bookkeeper,
257 in_flight: &mut in_flight,
258 pending_error: &mut pending_error,
259 },
260 )? {
261 break;
262 }
263 }
264 }
265 }
266 }
267 }
268
269 health.stream_ended_at = Some(tokio::time::Instant::now());
273
274 drop((result_sender, heartbeat_sender));
275 drain_remaining(
276 session,
277 &heartbeat_bookkeeper,
278 &mut channels,
279 &mut in_flight,
280 tracker,
281 &mut health.tasks_reported,
282 &mut pending_error,
283 )
284 .await;
285
286 pending_error.map_or(Ok(end), Err)
287}
288
289fn runtime_channels() -> (
291 mpsc::UnboundedSender<DispatchFinished>,
292 mpsc::UnboundedSender<HeartbeatRequest>,
293 RuntimeChannels,
294) {
295 let (result_sender, result_receiver) = mpsc::unbounded_channel();
296 let (heartbeat_sender, heartbeat_receiver) = mpsc::unbounded_channel();
297 let channels = RuntimeChannels {
298 heartbeats: heartbeat_receiver,
299 results: result_receiver,
300 };
301 (result_sender, heartbeat_sender, channels)
302}
303
/// Everything [`handle_task`] needs from the serve loop to start one activity.
struct SessionEventContext<'a, D> {
    /// Concurrency permit already acquired for the task about to be spawned.
    permit: tokio::sync::OwnedSemaphorePermit,
    /// Dispatcher that will execute the task.
    dispatcher: Arc<D>,
    /// Sender the spawned activity uses to deliver its outcome to the loop.
    result_sender: &'a mpsc::UnboundedSender<DispatchFinished>,
    /// Sender handed to the activity context for heartbeat requests.
    heartbeat_sender: &'a mpsc::UnboundedSender<HeartbeatRequest>,
    /// Bookkeeper with which the activity's execution key is registered.
    heartbeat_bookkeeper: &'a HeartbeatBookkeeper,
    /// Map of activities currently running, keyed by execution key.
    in_flight: &'a mut HashMap<ActivityExecutionKey, InFlightActivity>,
    /// Slot where a task-conversion error is stored for the serve loop.
    pending_error: &'a mut Option<WorkerError>,
}
313
314fn handle_task<D>(
315 proto_task: aion_proto::ProtoActivityTask,
316 ctx: SessionEventContext<'_, D>,
317) -> Result<bool, WorkerError>
318where
319 D: ActivityDispatcher,
320{
321 let task = match ActivityTask::try_from(proto_task) {
322 Ok(task) => task,
323 Err(error) => {
324 drop(ctx.permit);
325 *ctx.pending_error = Some(error);
326 return Ok(false);
327 }
328 };
329 spawn_activity(
330 task,
331 ctx.permit,
332 ctx.dispatcher,
333 ctx.result_sender.clone(),
334 ctx.heartbeat_sender.clone(),
335 ctx.heartbeat_bookkeeper,
336 ctx.in_flight,
337 )?;
338 Ok(true)
339}
340
341fn ensure_max_concurrency(config: &WorkerConfig) -> Result<(), WorkerError> {
343 if config.max_concurrency == 0 {
344 return Err(WorkerError::registration(InvalidMaxConcurrency));
345 }
346 Ok(())
347}
348
349async fn acquire_permit_or_shutdown<F>(
352 shutdown: std::pin::Pin<&mut F>,
353 semaphore: &Arc<Semaphore>,
354) -> Result<Option<tokio::sync::OwnedSemaphorePermit>, WorkerError>
355where
356 F: Future<Output = ()> + Send,
357{
358 tokio::select! {
359 biased;
360 () = shutdown => Ok(None),
361 permit = Arc::clone(semaphore).acquire_owned() => {
362 permit.map(Some).map_err(WorkerError::registration)
363 }
364 }
365}
366
367async fn forward_heartbeat<S>(
369 session: &mut S,
370 heartbeat_bookkeeper: &HeartbeatBookkeeper,
371 request: HeartbeatRequest,
372 pending_error: &mut Option<WorkerError>,
373) where
374 S: WorkerSession,
375{
376 record_first_error(
377 pending_error,
378 crate::protocol::send_heartbeat(session, heartbeat_bookkeeper, request).await,
379 );
380}
381
382fn acknowledge_result(
385 workflow_id: &WorkflowId,
386 activity_id: &ActivityId,
387 tracker: &mut UnackedResultTracker,
388) {
389 if tracker.acknowledge(workflow_id, activity_id).is_some() {
390 debug!(
391 workflow_id = %workflow_id,
392 activity_id = activity_id.sequence_position(),
393 "server acknowledged activity result; tracker entry cleared"
394 );
395 } else {
396 debug!(
397 workflow_id = %workflow_id,
398 activity_id = activity_id.sequence_position(),
399 "result ack for unknown tracker entry ignored"
400 );
401 }
402}
403
404fn spawn_activity<D>(
405 task: ActivityTask,
406 permit: tokio::sync::OwnedSemaphorePermit,
407 dispatcher: Arc<D>,
408 result_sender: mpsc::UnboundedSender<DispatchFinished>,
409 heartbeat_sender: mpsc::UnboundedSender<HeartbeatRequest>,
410 heartbeat_bookkeeper: &HeartbeatBookkeeper,
411 in_flight: &mut HashMap<ActivityExecutionKey, InFlightActivity>,
412) -> Result<(), WorkerError>
413where
414 D: ActivityDispatcher,
415{
416 info!(
417 activity_type = %task.activity_type,
418 activity_id = task.activity_id.sequence_position(),
419 workflow_id = %task.workflow_id,
420 attempt = task.attempt,
421 "received activity task"
422 );
423 let key = ActivityExecutionKey::new(task.workflow_id.clone(), task.activity_id.clone());
424 heartbeat_bookkeeper.register(key.clone())?;
425 let (context, cancellation_handle) = ActivityContext::for_workflow(
426 Some(task.workflow_id.clone()),
427 task.activity_id.clone(),
428 task.attempt,
429 Some(heartbeat_sender),
430 );
431 let finished_key = key.clone();
432 let join_handle = tokio::spawn(async move {
433 let outcome = dispatcher.dispatch(task, context).await;
434 if result_sender
435 .send(DispatchFinished {
436 key: finished_key,
437 outcome,
438 })
439 .is_err()
440 {
441 debug!("worker loop stopped before dispatch outcome could be delivered");
442 }
443 drop(permit);
444 });
445 in_flight.insert(
446 key,
447 InFlightActivity {
448 cancellation_handle,
449 join_handle,
450 },
451 );
452 Ok(())
453}
454
455fn deliver_cancellation(
456 workflow_id: WorkflowId,
457 activity_id: &ActivityId,
458 in_flight: &HashMap<ActivityExecutionKey, InFlightActivity>,
459) {
460 let key = ActivityExecutionKey::new(workflow_id, activity_id.clone());
461 if let Some(in_flight_activity) = in_flight.get(&key) {
462 in_flight_activity.cancellation_handle.cancel();
463 info!(
464 activity_id = activity_id.sequence_position(),
465 "delivered cooperative activity cancellation"
466 );
467 }
468}
469
470fn cancel_all_in_flight(in_flight: &HashMap<ActivityExecutionKey, InFlightActivity>) {
471 for (key, in_flight_activity) in in_flight {
472 in_flight_activity.cancellation_handle.cancel();
473 info!(
474 activity_id = key.activity_id.sequence_position(),
475 workflow_id = %key.workflow_id,
476 "delivered cooperative activity cancellation during worker shutdown"
477 );
478 }
479}
480
/// Error source used when the worker configuration has a `max_concurrency`
/// of zero; wrapped into a registration [`WorkerError`] by
/// [`ensure_max_concurrency`].
#[derive(Debug, thiserror::Error)]
#[error("worker max_concurrency must be greater than zero")]
struct InvalidMaxConcurrency;
484
// Unit tests for this module live in the sibling file `loop_tests.rs` and are
// only compiled for test builds.
#[cfg(test)]
#[path = "loop_tests.rs"]
mod tests;