Skip to main content

aion_worker/runtime/
loop_.rs

1//! receive->dispatch->report worker loop + bounded concurrency
2
3use std::collections::{BTreeSet, HashMap};
4use std::sync::Arc;
5
6use aion_core::{ActivityError, ActivityId, Payload, WorkflowId};
7use async_trait::async_trait;
8use futures::StreamExt;
9use futures::future;
10use tokio::sync::{Semaphore, mpsc};
11use tracing::{debug, info};
12
13use crate::config::WorkerConfig;
14use crate::context::{ActivityContext, HeartbeatRequest};
15use crate::error::WorkerError;
16use crate::protocol::reconnect::UnackedResultTracker;
17use crate::protocol::{
18    ActivityExecutionKey, ActivityTask, HeartbeatBookkeeper, WorkerSession, WorkerSessionEvent,
19};
20use crate::runtime::report::{
21    DispatchFinished, InFlightActivity, RuntimeChannels, drain_remaining, drain_runtime_events,
22};
23
24/// Dispatch seam used by the receive loop to execute decoded activity tasks.
25#[async_trait]
26pub trait ActivityDispatcher: Send + Sync + 'static {
27    /// Executes one decoded activity task with the provided handler context.
28    async fn dispatch(
29        &self,
30        task: ActivityTask,
31        context: ActivityContext,
32    ) -> Result<DispatchOutcome, WorkerError>;
33
34    /// Activity type names this dispatcher can serve.
35    fn activity_types(&self) -> BTreeSet<String>;
36}
37
38/// Activity execution outcome returned by the dispatch seam.
39#[derive(Clone, Debug, PartialEq, Eq)]
40pub enum DispatchOutcome {
41    /// Activity completed with an output payload.
42    Completed {
43        /// Opaque output payload.
44        output: Payload,
45    },
46    /// Activity failed with explicit classification.
47    Failed {
48        /// Classified activity failure.
49        failure: ActivityError,
50    },
51}
52
53/// Future that never resolves, used by the default serve entrypoint.
54pub type NoShutdown = future::Pending<()>;
55
/// Reason the serve loop finished when it did not finish with an error.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ServeEnd {
    /// The shutdown future supplied by the caller resolved, and in-flight
    /// work was drained.
    Shutdown,
    /// The task stream was ended cleanly by the server, either by
    /// end-of-stream or by a drain frame. The reconnect-aware run loop
    /// handles this as a retryable session drop and never as the end of the
    /// run, which lets workers ride through graceful server closes such as
    /// deploys.
    StreamClosed,
}
67
68/// Per-session health accounting written by the serve loop for the
69/// reconnect-aware caller's drop-budget reset decision.
70#[derive(Debug, Default)]
71pub struct SessionHealth {
72    /// Activity tasks whose outcome report was sent on this session.
73    pub tasks_reported: usize,
74    /// When the receive stream ended or dropped, captured before in-flight
75    /// handlers are drained — so post-drop draining never extends the
76    /// session's measured connected lifetime.
77    pub stream_ended_at: Option<tokio::time::Instant>,
78}
79
80/// Runs the worker receive loop until the session's task stream completes.
81///
82/// The loop only forwards explicit handler heartbeats and cancellation flags. It
83/// never emits automatic heartbeats, never enforces heartbeat timeouts, and never
84/// aborts running handler tasks on cancellation.
85///
86/// Every computed dispatch outcome is recorded in `tracker` before its report
87/// is sent, so a caller that reconnects after a transport drop can re-report
88/// the backlog; the engine ingests reports idempotently by `ActivityId`.
89///
90/// # Errors
91///
92/// Returns [`WorkerError`] when task decode, dispatch, heartbeat send, or result
93/// reporting fails.
94pub async fn serve_activity_tasks<S, D>(
95    config: &WorkerConfig,
96    session: &mut S,
97    dispatcher: Arc<D>,
98    tracker: &mut UnackedResultTracker,
99) -> Result<ServeEnd, WorkerError>
100where
101    S: WorkerSession,
102    D: ActivityDispatcher,
103{
104    let mut health = SessionHealth::default();
105    serve_activity_tasks_until(
106        config,
107        session,
108        dispatcher,
109        tracker,
110        &mut health,
111        future::pending(),
112    )
113    .await
114}
115
116/// Runs the worker receive loop until the session's task stream completes.
117///
118/// The loop only forwards explicit handler heartbeats and cancellation flags. It
119/// never emits automatic heartbeats, never enforces heartbeat timeouts, and never
120/// aborts running handler tasks on cancellation.
121///
122/// Every computed dispatch outcome is recorded in `tracker` before its report
123/// is sent, so a caller that reconnects after a transport drop can re-report
124/// the backlog; the engine ingests reports idempotently by `ActivityId`. Only
125/// an explicit engine acknowledgement clears tracker entries, so successful
126/// sends leave their entries in place.
127///
128/// `health` accumulates session-health accounting: the activity tasks whose
129/// outcome report was sent on this session, and the instant the receive
130/// stream ended (captured before in-flight handlers are drained). It is an
131/// out-parameter (rather than part of the return value) so the accounting
132/// survives an error return: the reconnect-aware caller uses it for the
133/// drop-budget reset decision — a session that served at least one task, or
134/// that stayed connected longer than the maximum backoff delay measured to
135/// the recorded stream end (never to the end of the post-drop drain), resets
136/// the cumulative drop budget even when it later drops.
137///
138/// On a clean end this returns [`ServeEnd`] distinguishing a caller-driven
139/// shutdown from a server-side stream close, so the caller can treat the
140/// latter as a retryable drop.
141///
142/// # Errors
143///
144/// Returns [`WorkerError`] when task decode, dispatch, heartbeat send, or result
145/// reporting fails.
146pub async fn serve_activity_tasks_until<S, D, Shutdown>(
147    config: &WorkerConfig,
148    session: &mut S,
149    dispatcher: Arc<D>,
150    tracker: &mut UnackedResultTracker,
151    health: &mut SessionHealth,
152    shutdown: Shutdown,
153) -> Result<ServeEnd, WorkerError>
154where
155    S: WorkerSession,
156    D: ActivityDispatcher,
157    Shutdown: Future<Output = ()> + Send,
158{
159    if config.max_concurrency == 0 {
160        return Err(WorkerError::registration(InvalidMaxConcurrency));
161    }
162
163    let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
164    let (result_sender, result_receiver) = mpsc::unbounded_channel();
165    let (heartbeat_sender, heartbeat_receiver) = mpsc::unbounded_channel();
166    let mut channels = RuntimeChannels {
167        heartbeats: heartbeat_receiver,
168        results: result_receiver,
169    };
170    let heartbeat_bookkeeper = HeartbeatBookkeeper::default();
171    let mut stream = session.receive_tasks();
172    let mut in_flight = HashMap::<ActivityExecutionKey, InFlightActivity>::new();
173    let mut pending_error = None;
174    // Overridden at the shutdown break sites; every other clean exit is the
175    // server ending the stream.
176    let mut end = ServeEnd::StreamClosed;
177    tokio::pin!(shutdown);
178
179    while pending_error.is_none() {
180        drain_runtime_events(
181            session,
182            &heartbeat_bookkeeper,
183            &mut channels,
184            &mut in_flight,
185            tracker,
186            &mut health.tasks_reported,
187            &mut pending_error,
188        )
189        .await;
190        if pending_error.is_some() {
191            break;
192        }
193
194        tokio::select! {
195            biased;
196            () = &mut shutdown => {
197                cancel_all_in_flight(&in_flight);
198                end = ServeEnd::Shutdown;
199                break;
200            }
201            event = stream.next() => {
202                let Some(event) = event else { break; };
203                match event {
204                    Ok(WorkerSessionEvent::Cancel { workflow_id, activity_id }) => {
205                        deliver_cancellation(workflow_id, &activity_id, &in_flight);
206                    }
207                    Ok(WorkerSessionEvent::Drain) => {
208                        break;
209                    }
210                    other => {
211                        let permit = tokio::select! {
212                            biased;
213                            () = &mut shutdown => {
214                                cancel_all_in_flight(&in_flight);
215                                end = ServeEnd::Shutdown;
216                                break;
217                            }
218                            permit = semaphore.clone().acquire_owned() => {
219                                permit.map_err(WorkerError::registration)?
220                            }
221                        };
222                        if !handle_session_event(
223                            other,
224                            SessionEventContext {
225                                permit,
226                                dispatcher: Arc::clone(&dispatcher),
227                                result_sender: &result_sender,
228                                heartbeat_sender: &heartbeat_sender,
229                                heartbeat_bookkeeper: &heartbeat_bookkeeper,
230                                in_flight: &mut in_flight,
231                                pending_error: &mut pending_error,
232                            },
233                        )? {
234                            break;
235                        }
236                    }
237                }
238            }
239        }
240    }
241
242    // The stream just ended — cleanly, by error, or by shutdown. Capture the
243    // moment before draining in-flight handlers so the caller's drop-budget
244    // reset decision measures connected time, never drain time.
245    health.stream_ended_at = Some(tokio::time::Instant::now());
246
247    drop(result_sender);
248    drop(heartbeat_sender);
249    drain_remaining(
250        session,
251        &heartbeat_bookkeeper,
252        &mut channels,
253        &mut in_flight,
254        tracker,
255        &mut health.tasks_reported,
256        &mut pending_error,
257    )
258    .await;
259
260    if let Some(error) = pending_error {
261        return Err(error);
262    }
263    Ok(end)
264}
265
266struct SessionEventContext<'a, D> {
267    permit: tokio::sync::OwnedSemaphorePermit,
268    dispatcher: Arc<D>,
269    result_sender: &'a mpsc::UnboundedSender<DispatchFinished>,
270    heartbeat_sender: &'a mpsc::UnboundedSender<HeartbeatRequest>,
271    heartbeat_bookkeeper: &'a HeartbeatBookkeeper,
272    in_flight: &'a mut HashMap<ActivityExecutionKey, InFlightActivity>,
273    pending_error: &'a mut Option<WorkerError>,
274}
275
276fn handle_session_event<D>(
277    event: Result<WorkerSessionEvent, WorkerError>,
278    ctx: SessionEventContext<'_, D>,
279) -> Result<bool, WorkerError>
280where
281    D: ActivityDispatcher,
282{
283    match event {
284        Ok(WorkerSessionEvent::Task(proto_task)) => {
285            let task = match ActivityTask::try_from(proto_task) {
286                Ok(task) => task,
287                Err(error) => {
288                    drop(ctx.permit);
289                    *ctx.pending_error = Some(error);
290                    return Ok(false);
291                }
292            };
293            spawn_activity(
294                task,
295                ctx.permit,
296                ctx.dispatcher,
297                ctx.result_sender.clone(),
298                ctx.heartbeat_sender.clone(),
299                ctx.heartbeat_bookkeeper,
300                ctx.in_flight,
301            )?;
302            Ok(true)
303        }
304        Ok(WorkerSessionEvent::Cancel { .. } | WorkerSessionEvent::Drain) => {
305            drop(ctx.permit);
306            Ok(true)
307        }
308        Err(error) => {
309            drop(ctx.permit);
310            *ctx.pending_error = Some(error);
311            Ok(false)
312        }
313    }
314}
315
316fn spawn_activity<D>(
317    task: ActivityTask,
318    permit: tokio::sync::OwnedSemaphorePermit,
319    dispatcher: Arc<D>,
320    result_sender: mpsc::UnboundedSender<DispatchFinished>,
321    heartbeat_sender: mpsc::UnboundedSender<HeartbeatRequest>,
322    heartbeat_bookkeeper: &HeartbeatBookkeeper,
323    in_flight: &mut HashMap<ActivityExecutionKey, InFlightActivity>,
324) -> Result<(), WorkerError>
325where
326    D: ActivityDispatcher,
327{
328    info!(
329        activity_type = %task.activity_type,
330        activity_id = task.activity_id.sequence_position(),
331        workflow_id = %task.workflow_id,
332        attempt = task.attempt,
333        "received activity task"
334    );
335    let key = ActivityExecutionKey::new(task.workflow_id.clone(), task.activity_id.clone());
336    heartbeat_bookkeeper.register(key.clone())?;
337    let (context, cancellation_handle) = ActivityContext::for_workflow(
338        Some(task.workflow_id.clone()),
339        task.activity_id.clone(),
340        task.attempt,
341        Some(heartbeat_sender),
342    );
343    let finished_key = key.clone();
344    let join_handle = tokio::spawn(async move {
345        let outcome = dispatcher.dispatch(task, context).await;
346        if result_sender
347            .send(DispatchFinished {
348                key: finished_key,
349                outcome,
350            })
351            .is_err()
352        {
353            debug!("worker loop stopped before dispatch outcome could be delivered");
354        }
355        drop(permit);
356    });
357    in_flight.insert(
358        key,
359        InFlightActivity {
360            cancellation_handle,
361            join_handle,
362        },
363    );
364    Ok(())
365}
366
367fn deliver_cancellation(
368    workflow_id: WorkflowId,
369    activity_id: &ActivityId,
370    in_flight: &HashMap<ActivityExecutionKey, InFlightActivity>,
371) {
372    let key = ActivityExecutionKey::new(workflow_id, activity_id.clone());
373    if let Some(in_flight_activity) = in_flight.get(&key) {
374        in_flight_activity.cancellation_handle.cancel();
375        info!(
376            activity_id = activity_id.sequence_position(),
377            "delivered cooperative activity cancellation"
378        );
379    }
380}
381
382fn cancel_all_in_flight(in_flight: &HashMap<ActivityExecutionKey, InFlightActivity>) {
383    for (key, in_flight_activity) in in_flight {
384        in_flight_activity.cancellation_handle.cancel();
385        info!(
386            activity_id = key.activity_id.sequence_position(),
387            workflow_id = %key.workflow_id,
388            "delivered cooperative activity cancellation during worker shutdown"
389        );
390    }
391}
392
393#[derive(Debug, thiserror::Error)]
394#[error("worker max_concurrency must be greater than zero")]
395struct InvalidMaxConcurrency;
396
397#[cfg(test)]
398#[path = "loop_tests.rs"]
399mod tests;