Skip to main content

aion_worker/runtime/
loop_.rs

1//! receive->dispatch->report worker loop + bounded concurrency
2
3use std::collections::{BTreeSet, HashMap};
4use std::sync::Arc;
5
6use aion_core::{ActivityError, ActivityId, Payload, WorkflowId};
7use async_trait::async_trait;
8use futures::StreamExt;
9use futures::future;
10use tokio::sync::{Semaphore, mpsc};
11use tracing::{debug, info};
12
13use crate::config::WorkerConfig;
14use crate::context::{ActivityContext, HeartbeatRequest};
15use crate::error::WorkerError;
16use crate::protocol::reconnect::UnackedResultTracker;
17use crate::protocol::{
18    ActivityExecutionKey, ActivityTask, HeartbeatBookkeeper, WorkerSession, WorkerSessionEvent,
19};
20use crate::runtime::report::{
21    DispatchFinished, InFlightActivity, RuntimeChannels, drain_remaining, record_first_error,
22    report_finished,
23};
24
25/// Dispatch seam used by the receive loop to execute decoded activity tasks.
26#[async_trait]
27pub trait ActivityDispatcher: Send + Sync + 'static {
28    /// Executes one decoded activity task with the provided handler context.
29    async fn dispatch(
30        &self,
31        task: ActivityTask,
32        context: ActivityContext,
33    ) -> Result<DispatchOutcome, WorkerError>;
34
35    /// Activity type names this dispatcher can serve.
36    fn activity_types(&self) -> BTreeSet<String>;
37}
38
39/// Activity execution outcome returned by the dispatch seam.
40#[derive(Clone, Debug, PartialEq, Eq)]
41pub enum DispatchOutcome {
42    /// Activity completed with an output payload.
43    Completed {
44        /// Opaque output payload.
45        output: Payload,
46    },
47    /// Activity failed with explicit classification.
48    Failed {
49        /// Classified activity failure.
50        failure: ActivityError,
51    },
52}
53
54/// Future that never resolves, used by the default serve entrypoint.
55pub type NoShutdown = future::Pending<()>;
56
/// Reason the serve loop returned `Ok`.
///
/// Failures are reported through `Err(WorkerError)`; this enum only
/// classifies the clean exits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServeEnd {
    /// The caller-supplied shutdown future resolved. In-flight handlers were
    /// flagged for cooperative cancellation and then drained.
    Shutdown,
    /// The server closed the task stream without sending a drain frame.
    /// The reconnect-aware run loop treats this unannounced close as a
    /// budgeted, retryable session drop — it never ends the run.
    StreamClosed,
    /// The server sent a drain frame. In-flight work was finished and
    /// reported, and the run loop reconnects after the schedule's initial
    /// backoff without spending any drop budget.
    Drained,
}
71
72/// Per-session health accounting written by the serve loop for the
73/// reconnect-aware caller's drop-budget reset decision.
74#[derive(Debug, Default)]
75pub struct SessionHealth {
76    /// Activity tasks whose outcome report was sent on this session.
77    pub tasks_reported: usize,
78    /// When the receive stream ended or dropped, captured before in-flight
79    /// handlers are drained — so post-drop draining never extends the
80    /// session's measured connected lifetime.
81    pub stream_ended_at: Option<tokio::time::Instant>,
82    /// Latched when a drain frame is observed on this session: the eventual
83    /// stream end — clean OR abrupt — is then drain-class (the server
84    /// announced it was going away), so the drop consumes no budget even if
85    /// the post-drain reporting fails. Survives an error return because this
86    /// is an out-parameter.
87    pub drain_received: bool,
88}
89
90/// Runs the worker receive loop until the session's task stream completes.
91///
92/// The loop only forwards explicit handler heartbeats and cancellation flags. It
93/// never emits automatic heartbeats, never enforces heartbeat timeouts, and never
94/// aborts running handler tasks on cancellation.
95///
96/// Every computed dispatch outcome is recorded in `tracker` before its report
97/// is sent, so a caller that reconnects after a transport drop can re-report
98/// the backlog; the server acks each consumed report (`ResultAck`), and only
99/// that explicit acknowledgement clears a tracker entry.
100///
101/// # Errors
102///
103/// Returns [`WorkerError`] when task decode, dispatch, heartbeat send, or result
104/// reporting fails.
105pub async fn serve_activity_tasks<S, D>(
106    config: &WorkerConfig,
107    session: &mut S,
108    dispatcher: Arc<D>,
109    tracker: &mut UnackedResultTracker,
110) -> Result<ServeEnd, WorkerError>
111where
112    S: WorkerSession,
113    D: ActivityDispatcher,
114{
115    let mut health = SessionHealth::default();
116    serve_activity_tasks_until(
117        config,
118        session,
119        dispatcher,
120        tracker,
121        &mut health,
122        future::pending(),
123    )
124    .await
125}
126
127/// Runs the worker receive loop until the session's task stream completes.
128///
129/// The loop only forwards explicit handler heartbeats and cancellation flags. It
130/// never emits automatic heartbeats, never enforces heartbeat timeouts, and never
131/// aborts running handler tasks on cancellation.
132///
133/// Every computed dispatch outcome is recorded in `tracker` before its report
134/// is sent, so a caller that reconnects after a transport drop can re-report
135/// the backlog; the server ingests reports idempotently and acks each one
136/// with a `ResultAck` frame. Only that explicit acknowledgement clears a
137/// tracker entry — a successful send proves nothing on its own.
138///
139/// `health` accumulates session-health accounting: the activity tasks whose
140/// outcome report was sent on this session, and the instant the receive
141/// stream ended (captured before in-flight handlers are drained). It is an
142/// out-parameter (rather than part of the return value) so the accounting
143/// survives an error return: the reconnect-aware caller uses it for the
144/// drop-budget reset decision — a session that served at least one task, or
145/// that stayed connected longer than the maximum backoff delay measured to
146/// the recorded stream end (never to the end of the post-drop drain), resets
147/// the cumulative drop budget even when it later drops.
148///
149/// On a clean end this returns [`ServeEnd`] distinguishing a caller-driven
150/// shutdown from a server-side stream close, so the caller can treat the
151/// latter as a retryable drop.
152///
153/// # Errors
154///
155/// Returns [`WorkerError`] when task decode, dispatch, heartbeat send, or result
156/// reporting fails.
157pub async fn serve_activity_tasks_until<S, D, Shutdown>(
158    config: &WorkerConfig,
159    session: &mut S,
160    dispatcher: Arc<D>,
161    tracker: &mut UnackedResultTracker,
162    health: &mut SessionHealth,
163    shutdown: Shutdown,
164) -> Result<ServeEnd, WorkerError>
165where
166    S: WorkerSession,
167    D: ActivityDispatcher,
168    Shutdown: Future<Output = ()> + Send,
169{
170    ensure_max_concurrency(config)?;
171    let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
172    let (result_sender, heartbeat_sender, mut channels) = runtime_channels();
173    let heartbeat_bookkeeper = HeartbeatBookkeeper::default();
174    let mut stream = session.receive_tasks();
175    let mut in_flight = HashMap::<ActivityExecutionKey, InFlightActivity>::new();
176    let mut pending_error = None;
177    // Overridden at the shutdown break sites; every other clean exit is the
178    // server ending the stream.
179    let mut end = ServeEnd::StreamClosed;
180    tokio::pin!(shutdown);
181
182    // No batching preamble: the select arms below consume queued dispatch
183    // outcomes and heartbeats directly, so nothing waits for a stream event.
184    while pending_error.is_none() {
185        tokio::select! {
186            biased;
187            () = &mut shutdown => {
188                cancel_all_in_flight(&in_flight);
189                end = ServeEnd::Shutdown;
190                break;
191            }
192            // Dispatch outcomes are reported the moment they complete — the
193            // loop must not sit in `stream.next()` while a finished result
194            // waits, or a single dispatched task on an otherwise idle stream
195            // is only reported when the stream ends (the server-side dispatch
196            // would time out against a healthy worker).
197            finished = channels.results.recv() => {
198                if let Some(finished) = finished {
199                    report_finished(
200                        session,
201                        &heartbeat_bookkeeper,
202                        finished,
203                        &mut in_flight,
204                        tracker,
205                        &mut health.tasks_reported,
206                        &mut pending_error,
207                    )
208                    .await;
209                }
210            }
211            // Handler heartbeats are forwarded as they arrive for the same
212            // reason: the server's liveness window must be beatable while the
213            // stream is idle.
214            request = channels.heartbeats.recv() => {
215                if let Some(request) = request {
216                    forward_heartbeat(session, &heartbeat_bookkeeper, request, &mut pending_error)
217                        .await;
218                }
219            }
220            event = stream.next() => {
221                let Some(event) = event else { break; };
222                match event {
223                    Ok(WorkerSessionEvent::Cancel { workflow_id, activity_id }) => {
224                        deliver_cancellation(workflow_id, &activity_id, &in_flight);
225                    }
226                    // Acks are bookkeeping, not work: consumed without a
227                    // concurrency permit, like cancellation delivery.
228                    Ok(WorkerSessionEvent::ResultAck { workflow_id, activity_id }) => {
229                        acknowledge_result(&workflow_id, &activity_id, tracker);
230                    }
231                    Ok(WorkerSessionEvent::Drain) => {
232                        info!("server drain received; finishing in-flight work before reconnect");
233                        health.drain_received = true;
234                        end = ServeEnd::Drained;
235                        break;
236                    }
237                    Err(error) => {
238                        pending_error = Some(error);
239                        break;
240                    }
241                    Ok(WorkerSessionEvent::Task(proto_task)) => {
242                        let Some(permit) =
243                            acquire_permit_or_shutdown(shutdown.as_mut(), &semaphore).await?
244                        else {
245                            cancel_all_in_flight(&in_flight);
246                            end = ServeEnd::Shutdown;
247                            break;
248                        };
249                        if !handle_task(
250                            proto_task,
251                            SessionEventContext {
252                                permit,
253                                dispatcher: Arc::clone(&dispatcher),
254                                result_sender: &result_sender,
255                                heartbeat_sender: &heartbeat_sender,
256                                heartbeat_bookkeeper: &heartbeat_bookkeeper,
257                                in_flight: &mut in_flight,
258                                pending_error: &mut pending_error,
259                            },
260                        )? {
261                            break;
262                        }
263                    }
264                }
265            }
266        }
267    }
268
269    // The stream just ended — cleanly, by error, or by shutdown. Capture the
270    // moment before draining in-flight handlers so the caller's drop-budget
271    // reset decision measures connected time, never drain time.
272    health.stream_ended_at = Some(tokio::time::Instant::now());
273
274    drop((result_sender, heartbeat_sender));
275    drain_remaining(
276        session,
277        &heartbeat_bookkeeper,
278        &mut channels,
279        &mut in_flight,
280        tracker,
281        &mut health.tasks_reported,
282        &mut pending_error,
283    )
284    .await;
285
286    pending_error.map_or(Ok(end), Err)
287}
288
289/// Builds the runtime's dispatch-outcome and heartbeat channels.
290fn runtime_channels() -> (
291    mpsc::UnboundedSender<DispatchFinished>,
292    mpsc::UnboundedSender<HeartbeatRequest>,
293    RuntimeChannels,
294) {
295    let (result_sender, result_receiver) = mpsc::unbounded_channel();
296    let (heartbeat_sender, heartbeat_receiver) = mpsc::unbounded_channel();
297    let channels = RuntimeChannels {
298        heartbeats: heartbeat_receiver,
299        results: result_receiver,
300    };
301    (result_sender, heartbeat_sender, channels)
302}
303
304struct SessionEventContext<'a, D> {
305    permit: tokio::sync::OwnedSemaphorePermit,
306    dispatcher: Arc<D>,
307    result_sender: &'a mpsc::UnboundedSender<DispatchFinished>,
308    heartbeat_sender: &'a mpsc::UnboundedSender<HeartbeatRequest>,
309    heartbeat_bookkeeper: &'a HeartbeatBookkeeper,
310    in_flight: &'a mut HashMap<ActivityExecutionKey, InFlightActivity>,
311    pending_error: &'a mut Option<WorkerError>,
312}
313
314fn handle_task<D>(
315    proto_task: aion_proto::ProtoActivityTask,
316    ctx: SessionEventContext<'_, D>,
317) -> Result<bool, WorkerError>
318where
319    D: ActivityDispatcher,
320{
321    let task = match ActivityTask::try_from(proto_task) {
322        Ok(task) => task,
323        Err(error) => {
324            drop(ctx.permit);
325            *ctx.pending_error = Some(error);
326            return Ok(false);
327        }
328    };
329    spawn_activity(
330        task,
331        ctx.permit,
332        ctx.dispatcher,
333        ctx.result_sender.clone(),
334        ctx.heartbeat_sender.clone(),
335        ctx.heartbeat_bookkeeper,
336        ctx.in_flight,
337    )?;
338    Ok(true)
339}
340
341/// Rejects a zero `max_concurrency` before the serve loop starts.
342fn ensure_max_concurrency(config: &WorkerConfig) -> Result<(), WorkerError> {
343    if config.max_concurrency == 0 {
344        return Err(WorkerError::registration(InvalidMaxConcurrency));
345    }
346    Ok(())
347}
348
349/// Waits for a dispatch permit, racing the caller's shutdown future; returns
350/// `None` when shutdown won.
351async fn acquire_permit_or_shutdown<F>(
352    shutdown: std::pin::Pin<&mut F>,
353    semaphore: &Arc<Semaphore>,
354) -> Result<Option<tokio::sync::OwnedSemaphorePermit>, WorkerError>
355where
356    F: Future<Output = ()> + Send,
357{
358    tokio::select! {
359        biased;
360        () = shutdown => Ok(None),
361        permit = Arc::clone(semaphore).acquire_owned() => {
362            permit.map(Some).map_err(WorkerError::registration)
363        }
364    }
365}
366
367/// Forwards one handler heartbeat to the session, recording the first error.
368async fn forward_heartbeat<S>(
369    session: &mut S,
370    heartbeat_bookkeeper: &HeartbeatBookkeeper,
371    request: HeartbeatRequest,
372    pending_error: &mut Option<WorkerError>,
373) where
374    S: WorkerSession,
375{
376    record_first_error(
377        pending_error,
378        crate::protocol::send_heartbeat(session, heartbeat_bookkeeper, request).await,
379    );
380}
381
382/// Clears the acknowledged tracker entry; an unknown ack (already cleared on
383/// a previous session, or replaced by a re-record) is a logged no-op.
384fn acknowledge_result(
385    workflow_id: &WorkflowId,
386    activity_id: &ActivityId,
387    tracker: &mut UnackedResultTracker,
388) {
389    if tracker.acknowledge(workflow_id, activity_id).is_some() {
390        debug!(
391            workflow_id = %workflow_id,
392            activity_id = activity_id.sequence_position(),
393            "server acknowledged activity result; tracker entry cleared"
394        );
395    } else {
396        debug!(
397            workflow_id = %workflow_id,
398            activity_id = activity_id.sequence_position(),
399            "result ack for unknown tracker entry ignored"
400        );
401    }
402}
403
404fn spawn_activity<D>(
405    task: ActivityTask,
406    permit: tokio::sync::OwnedSemaphorePermit,
407    dispatcher: Arc<D>,
408    result_sender: mpsc::UnboundedSender<DispatchFinished>,
409    heartbeat_sender: mpsc::UnboundedSender<HeartbeatRequest>,
410    heartbeat_bookkeeper: &HeartbeatBookkeeper,
411    in_flight: &mut HashMap<ActivityExecutionKey, InFlightActivity>,
412) -> Result<(), WorkerError>
413where
414    D: ActivityDispatcher,
415{
416    info!(
417        activity_type = %task.activity_type,
418        activity_id = task.activity_id.sequence_position(),
419        workflow_id = %task.workflow_id,
420        attempt = task.attempt,
421        "received activity task"
422    );
423    let key = ActivityExecutionKey::new(task.workflow_id.clone(), task.activity_id.clone());
424    heartbeat_bookkeeper.register(key.clone())?;
425    let (context, cancellation_handle) = ActivityContext::for_workflow(
426        Some(task.workflow_id.clone()),
427        task.activity_id.clone(),
428        task.attempt,
429        Some(heartbeat_sender),
430    );
431    let finished_key = key.clone();
432    let join_handle = tokio::spawn(async move {
433        let outcome = dispatcher.dispatch(task, context).await;
434        if result_sender
435            .send(DispatchFinished {
436                key: finished_key,
437                outcome,
438            })
439            .is_err()
440        {
441            debug!("worker loop stopped before dispatch outcome could be delivered");
442        }
443        drop(permit);
444    });
445    in_flight.insert(
446        key,
447        InFlightActivity {
448            cancellation_handle,
449            join_handle,
450        },
451    );
452    Ok(())
453}
454
455fn deliver_cancellation(
456    workflow_id: WorkflowId,
457    activity_id: &ActivityId,
458    in_flight: &HashMap<ActivityExecutionKey, InFlightActivity>,
459) {
460    let key = ActivityExecutionKey::new(workflow_id, activity_id.clone());
461    if let Some(in_flight_activity) = in_flight.get(&key) {
462        in_flight_activity.cancellation_handle.cancel();
463        info!(
464            activity_id = activity_id.sequence_position(),
465            "delivered cooperative activity cancellation"
466        );
467    }
468}
469
470fn cancel_all_in_flight(in_flight: &HashMap<ActivityExecutionKey, InFlightActivity>) {
471    for (key, in_flight_activity) in in_flight {
472        in_flight_activity.cancellation_handle.cancel();
473        info!(
474            activity_id = key.activity_id.sequence_position(),
475            workflow_id = %key.workflow_id,
476            "delivered cooperative activity cancellation during worker shutdown"
477        );
478    }
479}
480
481#[derive(Debug, thiserror::Error)]
482#[error("worker max_concurrency must be greater than zero")]
483struct InvalidMaxConcurrency;
484
// Unit tests for this module live in the sibling file `loop_tests.rs`.
#[cfg(test)]
#[path = "loop_tests.rs"]
mod tests;