1use std::collections::{BTreeSet, HashMap};
4use std::sync::Arc;
5
6use aion_core::{ActivityError, ActivityId, Payload, WorkflowId};
7use async_trait::async_trait;
8use futures::StreamExt;
9use futures::future;
10use tokio::sync::{Semaphore, mpsc};
11use tracing::{debug, info};
12
13use crate::config::WorkerConfig;
14use crate::context::{ActivityContext, HeartbeatRequest};
15use crate::error::WorkerError;
16use crate::protocol::reconnect::UnackedResultTracker;
17use crate::protocol::{
18 ActivityExecutionKey, ActivityTask, HeartbeatBookkeeper, WorkerSession, WorkerSessionEvent,
19};
20use crate::runtime::report::{
21 DispatchFinished, InFlightActivity, RuntimeChannels, drain_remaining, drain_runtime_events,
22};
23
#[async_trait]
/// Executes activity tasks on behalf of the worker loop.
///
/// Implementations must be `Send + Sync + 'static`: in this file the
/// dispatcher is shared through an `Arc` and called from spawned Tokio tasks.
pub trait ActivityDispatcher: Send + Sync + 'static {
    /// Runs a single activity `task` with its per-execution `context`.
    ///
    /// Resolves to `Ok` with the activity's [`DispatchOutcome`], or to `Err`
    /// with a [`WorkerError`].
    // NOTE(review): presumably `Err` is reserved for failures of the dispatch
    // machinery itself, while activity-level failures are reported through
    // `DispatchOutcome::Failed` — confirm with the implementations.
    async fn dispatch(
        &self,
        task: ActivityTask,
        context: ActivityContext,
    ) -> Result<DispatchOutcome, WorkerError>;

    /// Returns a set of activity type names.
    // NOTE(review): presumably the activity types this dispatcher can execute;
    // the method is not called in this file — confirm against its callers.
    fn activity_types(&self) -> BTreeSet<String>;
}
37
/// Result of dispatching one activity task, as produced by
/// [`ActivityDispatcher::dispatch`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DispatchOutcome {
    /// The activity finished and produced `output`.
    Completed {
        /// Payload returned by the activity.
        output: Payload,
    },
    /// The activity finished with an error.
    Failed {
        /// Error describing why the activity failed.
        failure: ActivityError,
    },
}
52
/// Type of a shutdown future that never resolves (the type returned by
/// `future::pending::<()>()`), for callers that serve tasks without a
/// shutdown signal.
pub type NoShutdown = future::Pending<()>;
55
/// Why a serve loop finished without an error.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ServeEnd {
    /// The caller's shutdown future resolved. Every in-flight activity was
    /// asked to cancel cooperatively before the loop exited.
    Shutdown,
    /// The loop ended without the shutdown future resolving: the task stream
    /// yielded `None`, or a `Drain` event was received.
    StreamClosed,
}
67
/// Bookkeeping about one serve session, filled in by
/// `serve_activity_tasks_until`.
#[derive(Debug, Default)]
pub struct SessionHealth {
    /// Counter handed by mutable reference to `drain_runtime_events` and
    /// `drain_remaining`.
    // NOTE(review): presumably the number of task results reported to the
    // server — confirm in `runtime::report`.
    pub tasks_reported: usize,
    /// Set to `Instant::now()` once the serve loop has exited, just before the
    /// final drain. Stays `None` if the serve function returns early.
    pub stream_ended_at: Option<tokio::time::Instant>,
}
79
80pub async fn serve_activity_tasks<S, D>(
95 config: &WorkerConfig,
96 session: &mut S,
97 dispatcher: Arc<D>,
98 tracker: &mut UnackedResultTracker,
99) -> Result<ServeEnd, WorkerError>
100where
101 S: WorkerSession,
102 D: ActivityDispatcher,
103{
104 let mut health = SessionHealth::default();
105 serve_activity_tasks_until(
106 config,
107 session,
108 dispatcher,
109 tracker,
110 &mut health,
111 future::pending(),
112 )
113 .await
114}
115
116pub async fn serve_activity_tasks_until<S, D, Shutdown>(
147 config: &WorkerConfig,
148 session: &mut S,
149 dispatcher: Arc<D>,
150 tracker: &mut UnackedResultTracker,
151 health: &mut SessionHealth,
152 shutdown: Shutdown,
153) -> Result<ServeEnd, WorkerError>
154where
155 S: WorkerSession,
156 D: ActivityDispatcher,
157 Shutdown: Future<Output = ()> + Send,
158{
159 if config.max_concurrency == 0 {
160 return Err(WorkerError::registration(InvalidMaxConcurrency));
161 }
162
163 let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
164 let (result_sender, result_receiver) = mpsc::unbounded_channel();
165 let (heartbeat_sender, heartbeat_receiver) = mpsc::unbounded_channel();
166 let mut channels = RuntimeChannels {
167 heartbeats: heartbeat_receiver,
168 results: result_receiver,
169 };
170 let heartbeat_bookkeeper = HeartbeatBookkeeper::default();
171 let mut stream = session.receive_tasks();
172 let mut in_flight = HashMap::<ActivityExecutionKey, InFlightActivity>::new();
173 let mut pending_error = None;
174 let mut end = ServeEnd::StreamClosed;
177 tokio::pin!(shutdown);
178
179 while pending_error.is_none() {
180 drain_runtime_events(
181 session,
182 &heartbeat_bookkeeper,
183 &mut channels,
184 &mut in_flight,
185 tracker,
186 &mut health.tasks_reported,
187 &mut pending_error,
188 )
189 .await;
190 if pending_error.is_some() {
191 break;
192 }
193
194 tokio::select! {
195 biased;
196 () = &mut shutdown => {
197 cancel_all_in_flight(&in_flight);
198 end = ServeEnd::Shutdown;
199 break;
200 }
201 event = stream.next() => {
202 let Some(event) = event else { break; };
203 match event {
204 Ok(WorkerSessionEvent::Cancel { workflow_id, activity_id }) => {
205 deliver_cancellation(workflow_id, &activity_id, &in_flight);
206 }
207 Ok(WorkerSessionEvent::Drain) => {
208 break;
209 }
210 other => {
211 let permit = tokio::select! {
212 biased;
213 () = &mut shutdown => {
214 cancel_all_in_flight(&in_flight);
215 end = ServeEnd::Shutdown;
216 break;
217 }
218 permit = semaphore.clone().acquire_owned() => {
219 permit.map_err(WorkerError::registration)?
220 }
221 };
222 if !handle_session_event(
223 other,
224 SessionEventContext {
225 permit,
226 dispatcher: Arc::clone(&dispatcher),
227 result_sender: &result_sender,
228 heartbeat_sender: &heartbeat_sender,
229 heartbeat_bookkeeper: &heartbeat_bookkeeper,
230 in_flight: &mut in_flight,
231 pending_error: &mut pending_error,
232 },
233 )? {
234 break;
235 }
236 }
237 }
238 }
239 }
240 }
241
242 health.stream_ended_at = Some(tokio::time::Instant::now());
246
247 drop(result_sender);
248 drop(heartbeat_sender);
249 drain_remaining(
250 session,
251 &heartbeat_bookkeeper,
252 &mut channels,
253 &mut in_flight,
254 tracker,
255 &mut health.tasks_reported,
256 &mut pending_error,
257 )
258 .await;
259
260 if let Some(error) = pending_error {
261 return Err(error);
262 }
263 Ok(end)
264}
265
/// Everything `handle_session_event` needs from the serve loop to process one
/// session event, bundled so the call site passes a single argument.
struct SessionEventContext<'a, D> {
    /// Concurrency slot already acquired for this event; moved into the
    /// spawned activity, or dropped if no activity is started.
    permit: tokio::sync::OwnedSemaphorePermit,
    /// Dispatcher that will execute the task.
    dispatcher: Arc<D>,
    /// Sender cloned into the spawned activity to deliver its outcome.
    result_sender: &'a mpsc::UnboundedSender<DispatchFinished>,
    /// Sender cloned into the activity's context for heartbeat requests.
    heartbeat_sender: &'a mpsc::UnboundedSender<HeartbeatRequest>,
    /// Bookkeeper with which each started activity is registered.
    heartbeat_bookkeeper: &'a HeartbeatBookkeeper,
    /// Map of currently running activities; a started activity is inserted.
    in_flight: &'a mut HashMap<ActivityExecutionKey, InFlightActivity>,
    /// Slot in which a stream or task-conversion error is recorded.
    pending_error: &'a mut Option<WorkerError>,
}
275
276fn handle_session_event<D>(
277 event: Result<WorkerSessionEvent, WorkerError>,
278 ctx: SessionEventContext<'_, D>,
279) -> Result<bool, WorkerError>
280where
281 D: ActivityDispatcher,
282{
283 match event {
284 Ok(WorkerSessionEvent::Task(proto_task)) => {
285 let task = match ActivityTask::try_from(proto_task) {
286 Ok(task) => task,
287 Err(error) => {
288 drop(ctx.permit);
289 *ctx.pending_error = Some(error);
290 return Ok(false);
291 }
292 };
293 spawn_activity(
294 task,
295 ctx.permit,
296 ctx.dispatcher,
297 ctx.result_sender.clone(),
298 ctx.heartbeat_sender.clone(),
299 ctx.heartbeat_bookkeeper,
300 ctx.in_flight,
301 )?;
302 Ok(true)
303 }
304 Ok(WorkerSessionEvent::Cancel { .. } | WorkerSessionEvent::Drain) => {
305 drop(ctx.permit);
306 Ok(true)
307 }
308 Err(error) => {
309 drop(ctx.permit);
310 *ctx.pending_error = Some(error);
311 Ok(false)
312 }
313 }
314}
315
316fn spawn_activity<D>(
317 task: ActivityTask,
318 permit: tokio::sync::OwnedSemaphorePermit,
319 dispatcher: Arc<D>,
320 result_sender: mpsc::UnboundedSender<DispatchFinished>,
321 heartbeat_sender: mpsc::UnboundedSender<HeartbeatRequest>,
322 heartbeat_bookkeeper: &HeartbeatBookkeeper,
323 in_flight: &mut HashMap<ActivityExecutionKey, InFlightActivity>,
324) -> Result<(), WorkerError>
325where
326 D: ActivityDispatcher,
327{
328 info!(
329 activity_type = %task.activity_type,
330 activity_id = task.activity_id.sequence_position(),
331 workflow_id = %task.workflow_id,
332 attempt = task.attempt,
333 "received activity task"
334 );
335 let key = ActivityExecutionKey::new(task.workflow_id.clone(), task.activity_id.clone());
336 heartbeat_bookkeeper.register(key.clone())?;
337 let (context, cancellation_handle) = ActivityContext::for_workflow(
338 Some(task.workflow_id.clone()),
339 task.activity_id.clone(),
340 task.attempt,
341 Some(heartbeat_sender),
342 );
343 let finished_key = key.clone();
344 let join_handle = tokio::spawn(async move {
345 let outcome = dispatcher.dispatch(task, context).await;
346 if result_sender
347 .send(DispatchFinished {
348 key: finished_key,
349 outcome,
350 })
351 .is_err()
352 {
353 debug!("worker loop stopped before dispatch outcome could be delivered");
354 }
355 drop(permit);
356 });
357 in_flight.insert(
358 key,
359 InFlightActivity {
360 cancellation_handle,
361 join_handle,
362 },
363 );
364 Ok(())
365}
366
367fn deliver_cancellation(
368 workflow_id: WorkflowId,
369 activity_id: &ActivityId,
370 in_flight: &HashMap<ActivityExecutionKey, InFlightActivity>,
371) {
372 let key = ActivityExecutionKey::new(workflow_id, activity_id.clone());
373 if let Some(in_flight_activity) = in_flight.get(&key) {
374 in_flight_activity.cancellation_handle.cancel();
375 info!(
376 activity_id = activity_id.sequence_position(),
377 "delivered cooperative activity cancellation"
378 );
379 }
380}
381
382fn cancel_all_in_flight(in_flight: &HashMap<ActivityExecutionKey, InFlightActivity>) {
383 for (key, in_flight_activity) in in_flight {
384 in_flight_activity.cancellation_handle.cancel();
385 info!(
386 activity_id = key.activity_id.sequence_position(),
387 workflow_id = %key.workflow_id,
388 "delivered cooperative activity cancellation during worker shutdown"
389 );
390 }
391}
392
/// Error source used when `WorkerConfig::max_concurrency` is zero; the serve
/// loop wraps it with `WorkerError::registration` before returning it.
#[derive(Debug, thiserror::Error)]
#[error("worker max_concurrency must be greater than zero")]
struct InvalidMaxConcurrency;
396
// Test-only module; its body lives in the sibling file `loop_tests.rs`.
#[cfg(test)]
#[path = "loop_tests.rs"]
mod tests;