Skip to main content

pure_stage/
tokio.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! This module contains the Tokio-based [`StageGraph`] implementation, to be used in production.
//!
//! It is good practice to perform the stage construction and wiring in a function that takes an
//! `&mut impl StageGraph` so that it can be reused between the Tokio and simulation implementations.
19
20use std::{
21    any::Any,
22    collections::BTreeMap,
23    future::{Future, poll_fn},
24    marker::PhantomData,
25    sync::Arc,
26    task::{Context, Poll, Waker},
27    time::Duration,
28};
29
30use amaru_observability::{amaru::stage, trace_span};
31use either::Either::{Left, Right};
32use futures_util::{FutureExt, StreamExt, stream::FuturesUnordered};
33use parking_lot::Mutex;
34use tokio::{
35    runtime::Handle,
36    sync::{
37        mpsc::{self, Receiver},
38        oneshot, watch,
39    },
40    task::JoinHandle,
41};
42
43use crate::{
44    BoxFuture, EPOCH, Effects, Instant, Name, ScheduleId, ScheduleIds, SendData, Sender, StageBuildRef, StageGraph,
45    StageRef,
46    adapter::{Adapter, StageOrAdapter, find_recipient},
47    drop_guard::DropGuard,
48    effect::{CallExtra, CallTimeout, CanSupervise, StageEffect, StageResponse, TransitionFactory},
49    effect_box::EffectBox,
50    resources::Resources,
51    serde::NoDebug,
52    simulation::Transition,
53    stage_name,
54    stage_ref::StageStateRef,
55    stagegraph::StageGraphRunning,
56    time::Clock,
57    trace_buffer::TraceBuffer,
58};
59
60#[derive(Debug, thiserror::Error)]
61#[error("message send failed to stage `{target}`")]
62pub struct SendError {
63    target: Name,
64}
65
66struct TokioInner {
67    senders: Mutex<BTreeMap<Name, StageOrAdapter<mpsc::Sender<Box<dyn SendData>>>>>,
68    handles: Mutex<Vec<JoinHandle<()>>>,
69    clock: Arc<dyn Clock + Send + Sync>,
70    resources: Resources,
71    schedule_ids: ScheduleIds,
72    mailbox_size: usize,
73    stage_counter: Mutex<usize>,
74    trace_buffer: Arc<Mutex<TraceBuffer>>,
75}
76
77impl TokioInner {
78    fn new() -> Self {
79        Self {
80            senders: Default::default(),
81            handles: Default::default(),
82            clock: Arc::new(TokioClock),
83            resources: Resources::default(),
84            schedule_ids: ScheduleIds::default(),
85            mailbox_size: 10,
86            stage_counter: Mutex::new(0usize),
87            trace_buffer: TraceBuffer::new_shared(0, 0),
88        }
89    }
90}
91
92struct TokioClock;
93impl Clock for TokioClock {
94    fn now(&self) -> Instant {
95        Instant::now()
96    }
97    fn advance_to(&self, _instant: Instant) {}
98}
99
100/// A [`StageGraph`] implementation that dispatches each stage as a task on the Tokio global pool.
101///
102/// *This is currently only a minimal sketch that will likely not fit the intended design.
103/// It is more likely that the effect handling will be done like in the [`SimulationBuilder`](crate::simulation::SimulationBuilder)
104/// implementation.*
105pub struct TokioBuilder {
106    tasks: Vec<Box<dyn FnOnce(Arc<TokioInner>) -> BoxFuture<'static, ()>>>,
107    inner: TokioInner,
108    termination: watch::Receiver<bool>,
109    termination_tx: watch::Sender<bool>,
110}
111
112impl Default for TokioBuilder {
113    fn default() -> Self {
114        let (termination_tx, termination_rx) = watch::channel(false);
115        Self { tasks: Default::default(), inner: TokioInner::new(), termination_tx, termination: termination_rx }
116    }
117}
118
119impl TokioBuilder {
120    pub fn run(self, rt: Handle) -> TokioRunning {
121        let Self {
122            tasks,
123            inner,
124            termination,
125            termination_tx: _, // only statically spawned stages can terminate the network
126        } = self;
127        let inner = Arc::new(inner);
128        let handles = tasks.into_iter().map(|t| rt.spawn(t(inner.clone()))).collect::<Vec<_>>();
129        inner.handles.lock().extend(handles);
130
131        // abort all tasks as soon as the termination signal is received
132        let mut termination2 = termination.clone();
133        let inner2 = inner.clone();
134        rt.spawn(async move {
135            termination2.wait_for(|x| *x).await.ok();
136            let handles = inner2.handles.lock();
137            tracing::info!(stages = handles.len(), "termination signal received, shutting down stages");
138            for handle in handles.iter() {
139                handle.abort();
140            }
141        });
142
143        TokioRunning { inner, termination }
144    }
145
146    pub fn with_trace_buffer(mut self, trace_buffer: Arc<Mutex<TraceBuffer>>) -> Self {
147        self.inner.trace_buffer = trace_buffer;
148        self
149    }
150
151    pub fn with_schedule_ids(mut self, schedule_ids: ScheduleIds) -> Self {
152        self.inner.schedule_ids = schedule_ids;
153        self
154    }
155
156    pub fn with_epoch_clock(mut self) -> Self {
157        self.inner.clock = Arc::new(EpochClock::new());
158        self
159    }
160}
161
162struct EpochClock {
163    offset: Mutex<Option<Duration>>,
164}
165
166impl EpochClock {
167    fn new() -> Self {
168        Self { offset: Mutex::new(None) }
169    }
170}
171
172impl Clock for EpochClock {
173    fn now(&self) -> Instant {
174        let mut offset = self.offset.lock();
175        if let Some(offset) = *offset {
176            Instant::now() - offset
177        } else {
178            let now = Instant::now();
179            let since_epoch = now.saturating_since(*EPOCH);
180            *offset = Some(since_epoch);
181            now - since_epoch
182        }
183    }
184
185    fn advance_to(&self, _instant: Instant) {}
186}
187
188type RefAux = (Receiver<Box<dyn SendData>>, TransitionFactory);
189
190impl StageGraph for TokioBuilder {
191    #[expect(clippy::expect_used)]
192    fn stage<Msg: SendData, St: SendData, F, Fut>(
193        &mut self,
194        name: impl AsRef<str>,
195        mut f: F,
196    ) -> StageBuildRef<Msg, St, Box<dyn Any + Send>>
197    where
198        F: FnMut(St, Msg, Effects<Msg>) -> Fut + 'static + Send,
199        Fut: Future<Output = St> + 'static + Send,
200    {
201        // THIS MUST MATCH THE SIMULATION BUILDER
202        let name = stage_name(&mut self.inner.stage_counter.lock(), name.as_ref());
203        let (tx, rx) = mpsc::channel(self.inner.mailbox_size);
204        self.inner.senders.lock().insert(name.clone(), StageOrAdapter::Stage(tx));
205
206        let me = StageRef::new(name.clone());
207        let clock = self.inner.clock.clone();
208        let resources = self.inner.resources.clone();
209        let schedule_ids = self.inner.schedule_ids.clone();
210        let trace_buffer = self.inner.trace_buffer.clone();
211        let ff = Box::new(move |effect| {
212            let eff = Effects::new(me, effect, clock, resources, schedule_ids, trace_buffer);
213            Box::new(move |state: Box<dyn SendData>, msg: Box<dyn SendData>| {
214                let state = state.cast::<St>().expect("internal state type error");
215                let msg = msg.cast::<Msg>().expect("internal message type error");
216                let state = f(*state, *msg, eff.clone());
217                Box::pin(async move { Box::new(state.await) as Box<dyn SendData> })
218                    as BoxFuture<'static, Box<dyn SendData>>
219            }) as Transition
220        });
221        let network: RefAux = (rx, ff);
222
223        StageBuildRef { name, network: Box::new(network), _ph: PhantomData }
224    }
225
226    #[expect(clippy::expect_used)]
227    fn wire_up<Msg: SendData, St: SendData>(
228        &mut self,
229        stage: StageBuildRef<Msg, St, Box<dyn Any + Send>>,
230        state: St,
231    ) -> StageStateRef<Msg, St> {
232        let StageBuildRef { name, network, _ph } = stage;
233        let (rx, ff) = *network.downcast::<RefAux>().expect("internal network type error");
234        let stage_name = name.clone();
235        let state = Box::new(state);
236        let termination_tx = self.termination_tx.clone();
237        self.tasks.push(Box::new(move |inner| {
238            let stage = run_stage_boxed(state, rx, ff, stage_name, inner);
239            Box::pin(async move {
240                stage.await;
241                termination_tx.send_replace(true);
242            })
243        }));
244        StageStateRef::new(name)
245    }
246
247    fn contramap<Original: SendData, Mapped: SendData>(
248        &mut self,
249        stage_ref: impl AsRef<StageRef<Original>>,
250        new_name: impl AsRef<str>,
251        transform: impl Fn(Mapped) -> Original + 'static + Send,
252    ) -> StageRef<Mapped> {
253        let target = stage_ref.as_ref();
254        let new_name = stage_name(&mut self.inner.stage_counter.lock(), new_name.as_ref());
255        let adapter = Adapter::new(new_name.clone(), target.name().clone(), transform);
256        self.inner.senders.lock().insert(new_name.clone(), StageOrAdapter::Adapter(adapter));
257        StageRef::new(new_name)
258    }
259
260    fn preload<Msg: SendData>(
261        &mut self,
262        stage: impl AsRef<StageRef<Msg>>,
263        messages: impl IntoIterator<Item = Msg>,
264    ) -> Result<(), Box<dyn SendData>> {
265        let stage = stage.as_ref();
266        let mut senders = self.inner.senders.lock();
267        for msg in messages {
268            if let Some((tx, msg)) = find_recipient(&mut senders, stage.name().clone(), Some(Box::new(msg)))
269                && let Err(err) = tx.try_send(msg)
270            {
271                tracing::warn!("message preload failed to stage `{}`", stage.name());
272                return Err(err.into_inner());
273            }
274        }
275        Ok(())
276    }
277
278    fn input<Msg: SendData>(&mut self, stage: impl AsRef<StageRef<Msg>>) -> Sender<Msg> {
279        mk_sender(stage.as_ref().name(), &self.inner)
280    }
281
282    fn resources(&self) -> &Resources {
283        &self.inner.resources
284    }
285}
286
287enum PriorityMessage {
288    Scheduled(Box<dyn SendData>, ScheduleId, watch::Receiver<bool>),
289    TimerCancelled(ScheduleId),
290    Tombstone(Box<dyn SendData>),
291}
292
/// Runs one stage until it ends.
///
/// It waits for mailbox messages and fired timers and feeds each message to the stage's
/// transition function. The effects awaited by that transition are executed by `interpreter`.
///
/// Every input and every resulting state is recorded in the trace buffer. The stage ends when:
/// - its transition terminates (`interpreter` returns `None`);
/// - it receives a `CanSupervise` message, i.e. a child ended without being supervised;
/// - both its mailbox and its timers are exhausted.
///
/// If the returned future is dropped before that (e.g. its task is aborted), the drop guard
/// records an aborted termination instead.
async fn run_stage_boxed(
    mut state: Box<dyn SendData>,
    mut rx: Receiver<Box<dyn SendData + 'static>>,
    transition: TransitionFactory,
    stage_name: Name,
    inner: Arc<TokioInner>,
) {
    tracing::debug!("running stage `{stage_name}`");

    // Slot through which the stage's effects and `interpreter` exchange the pending effect
    // (Left) and its response (Right).
    let effect = Arc::new(Mutex::new(None));
    let mut transition = transition(effect.clone());

    // this also contains tasks tracking the termination of spawned stages, which when dropped
    // will terminate those spawned stages
    let mut timers = FuturesUnordered::<BoxFuture<'static, PriorityMessage>>::new();
    // cancellation flags of scheduled timers that have not been delivered yet, keyed by id
    let mut cancel_senders = BTreeMap::<ScheduleId, watch::Sender<bool>>::new();

    let tb = DropGuard::new(inner.trace_buffer.clone(), |tb| {
        // ensure that Aborted is traced when this Future is dropped
        tb.lock().push_terminated_aborted(&stage_name)
    });

    // Messages to process in the current iteration. Each one is optionally tagged with its
    // schedule id and cancellation flag.
    let mut msgs = Vec::new();

    inner.trace_buffer.lock().push_state(&stage_name, &state);

    'outer: loop {
        // An empty FuturesUnordered yields `None` right away, so only poll it when it holds
        // futures.
        let poll_timers = !timers.is_empty();
        // if multiple timers have fired since the last poll, we need them all so that we can deliver them in order
        let mut timer_chunks = (&mut timers).ready_chunks(1000);

        // `biased` polls the branches in order, so fired timers take priority over the mailbox
        tokio::select! { biased;
            Some(res) = timer_chunks.next(), if poll_timers => {
                let mut scheduled = Vec::new();
                for msg in res {
                    match msg {
                        PriorityMessage::Scheduled(msg, id, cancelation) => {
                            scheduled.push((id, msg, cancelation));
                        }
                        PriorityMessage::TimerCancelled(_id) => {}
                        PriorityMessage::Tombstone(msg) => msgs.push((msg, None)),
                    }
                }
                // ensure that earliest timer is delivered first
                // (assumes ScheduleId orders by firing time — TODO confirm)
                scheduled.sort_by_key(|(id, _, _)| *id);
                for (id, msg, cancelation) in scheduled {
                    msgs.push((msg, Some((id, cancelation))));
                }
            }
            Some(msg) = rx.recv() => msgs.push((msg, None)),
            // every branch is disabled: the mailbox is closed and no timer is pending
            else => {
                tracing::error!(%stage_name, "stage sender dropped");
                break;
            }
        }

        for (msg, cancelation) in msgs.drain(..) {
            if let Some((id, canceled)) = cancelation {
                cancel_senders.remove(&id);
                if *canceled.borrow() {
                    // cancellation happened after the timer fired but before the message was delivered
                    continue;
                }
            }

            if let Ok(CanSupervise(child)) = msg.cast_ref::<CanSupervise>() {
                tracing::debug!("stage `{stage_name}` terminates because of an unsupervised child termination");
                tb.lock().push_terminated_supervision(&stage_name, child);
                break 'outer;
            }

            inner.trace_buffer.lock().push_input(&stage_name, &msg);

            let f = (transition)(state, msg);
            let result = interpreter(&inner, &effect, &stage_name, &mut timers, &mut cancel_senders, f).await;
            match result {
                Some(st) => state = st,
                None => {
                    tracing::info!(%stage_name, "terminated");
                    tb.lock().push_terminated_voluntary(&stage_name);
                    break 'outer;
                }
            }

            inner.trace_buffer.lock().push_state(&stage_name, &state);
        }
    }

    // Defuse the guard: the loop ended on its own, so this is not an abort.
    // NOTE(review): the `else` branch above exits without pushing any termination trace entry.
    DropGuard::into_inner(tb);
}
383
384#[expect(clippy::expect_used, clippy::panic)]
385fn mk_sender<Msg: SendData>(stage_name: &Name, inner: &TokioInner) -> Sender<Msg> {
386    let senders = inner.senders.lock();
387    let StageOrAdapter::Stage(tx) = senders.get(stage_name).expect("stage ref contained unknown name") else {
388        panic!("cannot obtain input for adapter");
389    };
390    let tx = tx.clone();
391    Sender::new(Arc::new(move |msg: Msg| {
392        let tx = tx.clone();
393        Box::pin(async move {
394            tx.send(Box::new(msg)).await.map_err(|msg| *msg.0.cast::<Msg>().expect("message was just boxed"))
395        })
396    }))
397}
398
399type StageRefExtra = Mutex<Option<oneshot::Sender<Box<dyn SendData>>>>;
400
/// Drives the transition future `stage` to completion, executing each stage effect it awaits.
///
/// `stage` is polled with a no-op waker. Whenever it returns `Pending`, the stage is expected to
/// have placed a [`StageEffect`] into `effect`. That effect is executed here, its
/// [`StageResponse`] is written back into `effect`, and `stage` is polled again. Suspensions and
/// resumptions are recorded in the trace buffer.
///
/// Returns `Some(new_state)` when the transition completes, or `None` when the stage requested
/// `StageEffect::Terminate`.
///
/// `timers` and `cancel_senders` belong to the calling stage loop:
/// - `Schedule` and `WireStage` add futures to `timers`;
/// - `Schedule` and `CancelSchedule` maintain `cancel_senders`.
///
/// # Panics
///
/// Panics if:
/// - the stage awaited something other than a stage effect;
/// - it awaits `StageEffect::Receive`;
/// - a `Call` carries anything but `CallExtra::CallFn`;
/// - a reply's call extra is not a `StageRefExtra`;
/// - a message is addressed to an unregistered stage name.
// clippy is lying, changing to async fn does not work.
#[expect(clippy::manual_async_fn)]
fn interpreter(
    inner: &Arc<TokioInner>,
    effect: &EffectBox,
    name: &Name,
    timers: &mut FuturesUnordered<BoxFuture<'static, PriorityMessage>>,
    cancel_senders: &mut BTreeMap<ScheduleId, watch::Sender<bool>>,
    mut stage: BoxFuture<'static, Box<dyn SendData>>,
) -> impl Future<Output = Option<Box<dyn SendData>>> + Send {
    // trying to write this as an async fn fails with inscrutable compile errors, it seems
    // that rustc has some issue with this particular pattern
    async move {
        // shorthand for locking the trace buffer; each guard lives only for its statement
        let tb = || inner.trace_buffer.lock();
        // the start of the transition is traced as a resume with a unit response
        tb().push_resume(name, &StageResponse::Unit);
        loop {
            let poll = {
                let _span = trace_span!(stage::tokio::POLL, stage = %name).entered();
                stage.as_mut().poll(&mut Context::from_waker(Waker::noop()))
            };
            if let Poll::Ready(state) = poll {
                return Some(state);
            }
            // NOTE(review): explicit drop of the Pending value. Presumably this keeps it out of
            // the future's state across the awaits below; confirm before removing.
            drop(poll);

            #[expect(clippy::panic)]
            let Some(Left(eff)) = effect.lock().take() else {
                panic!("stage `{name}` used .await on something that was not a stage effect");
            };
            // this does not push the Call effect because getting the message consumes it
            tb().push_suspend_ref(name, &eff);

            let resp = match eff {
                StageEffect::Receive => {
                    #[expect(clippy::panic)]
                    {
                        panic!("effect Receive cannot be explicitly awaited (stage `{name}`)")
                    }
                }
                // messages addressed to the empty stage name are dropped, not delivered
                StageEffect::Send(target, ..) if target.is_empty() => {
                    tracing::warn!(stage = %name, "message send to blackhole stage dropped");
                    StageResponse::Unit
                }
                // A reply to a call. Deliver it on the oneshot channel created by the caller's
                // `Call` effect. The sender is taken out, so only the first reply gets through.
                StageEffect::Send(_target, Some(call), msg) => {
                    #[expect(clippy::expect_used)]
                    let sender = call.downcast_ref::<StageRefExtra>().expect("expected CallExtra");
                    if let Some(sender) = sender.lock().take() {
                        sender.send(msg).ok();
                    }
                    StageResponse::Unit
                }
                StageEffect::Send(target, None, msg) => {
                    // Resolve the recipient (possibly through adapters) under the lock. The lock
                    // is released before awaiting the send.
                    let (tx, msg) = {
                        let mut senders = inner.senders.lock();
                        #[expect(clippy::expect_used)]
                        let (tx, msg) = find_recipient(&mut senders, target.clone(), Some(msg))
                            .expect("stage ref contained unknown name");
                        (tx.clone(), msg)
                    };
                    // NOTE(review): a failed send is silently ignored here. That happens when the
                    // target's mailbox receiver is gone, e.g. because that stage terminated.
                    tx.send(msg).await.ok();
                    StageResponse::Unit
                }
                StageEffect::Call(target, duration, msg) => {
                    #[expect(clippy::panic)]
                    let CallExtra::CallFn(NoDebug(msg)) = msg else {
                        panic!("expected CallFn, got {:?}", msg);
                    };
                    let (tx_response, rx) = oneshot::channel();
                    // it is important to use the type alias StageRefExtra here, otherwise the
                    // compiler would accept any type that implements Send + Sync + 'static
                    let sender = StageRefExtra::new(Some(tx_response));
                    let msg = (msg)(name.clone(), Arc::new(sender));

                    tb().push_suspend_call(name, &target, duration, &*msg);

                    let (tx_call, msg) = {
                        let mut senders = inner.senders.lock();
                        #[expect(clippy::expect_used)]
                        let (tx, msg) = find_recipient(&mut senders, target.clone(), Some(msg))
                            .expect("stage ref contained unknown name");
                        (tx.clone(), msg)
                    };
                    tx_call.send(msg).await.ok();
                    // The timeout only starts once the request has been sent. An elapsed timeout
                    // and a closed reply channel are both answered with `CallTimeout`. A failed
                    // send presumably drops the reply sender with the message, so it also ends
                    // up here.
                    match tokio::time::timeout(duration, rx).await {
                        Ok(Ok(msg)) => StageResponse::CallResponse(msg),
                        _ => StageResponse::CallResponse(Box::new(CallTimeout)),
                    }
                }
                // NOTE(review): Clock and Wait use Tokio's time via `now()` / `tokio::time::sleep`,
                // not `inner.clock`. Confirm this is intended when `with_epoch_clock` is used.
                StageEffect::Clock => StageResponse::ClockResponse(now()),
                StageEffect::Wait(duration) => {
                    tokio::time::sleep(duration).await;
                    StageResponse::WaitResponse(now())
                }
                StageEffect::External(effect) => {
                    tracing::debug!("stage `{name}` external effect: {:?}", effect);
                    StageResponse::ExternalResponse(effect.run(inner.resources.clone()).await)
                }
                // ends the transition without a new state; the caller treats this as voluntary termination
                StageEffect::Terminate => {
                    tracing::warn!("stage `{name}` terminated");
                    return None;
                }
                // Only derives a unique name via `stage_name`; nothing is spawned here.
                // (`name` shadows the calling stage's name in this arm.)
                StageEffect::AddStage(name) => {
                    tracing::debug!("stage `{name}` added");
                    let name = stage_name(&mut inner.stage_counter.lock(), name.as_str());
                    StageResponse::AddStageResponse(name)
                }
                // The child is spawned on the ambient runtime via `tokio::spawn`. Its handle is
                // tracked in `timers`: dropping the parent aborts the child through the
                // DropGuard, and the child's end delivers `tombstone` to the parent.
                StageEffect::WireStage(name, transition, initial_state, tombstone) => {
                    tracing::debug!("stage `{name}` wired");
                    let (tx, rx) = mpsc::channel(inner.mailbox_size);
                    inner.senders.lock().insert(name.clone(), StageOrAdapter::Stage(tx));
                    let stage =
                        run_stage_boxed(initial_state, rx, transition.into_inner(), name.clone(), inner.clone());
                    let handle = tokio::spawn(stage);
                    // need to construct DropGuard before pushing into the FuturesUnordered to avoid Future being dropped
                    // before the guard is established
                    let mut handle = DropGuard::new(handle, |handle| handle.abort());
                    timers.push(Box::pin(async move {
                        if let Err(err) = (&mut *handle).await {
                            tracing::error!("stage `{name}` failed: {}", err);
                        }
                        PriorityMessage::Tombstone(tombstone)
                    }));
                    StageResponse::Unit
                }
                // registers an adapter under a freshly derived unique name
                StageEffect::Contramap { original, new_name, transform } => {
                    tracing::debug!("contramap {original} -> {new_name}");
                    let name = stage_name(&mut inner.stage_counter.lock(), new_name.as_str());
                    inner.senders.lock().insert(
                        name.clone(),
                        StageOrAdapter::Adapter(Adapter {
                            name: name.clone(),
                            target: original,
                            transform: transform.into_inner(),
                        }),
                    );
                    StageResponse::ContramapResponse(name)
                }
                // The timer future races the cancellation flag against the deadline `id.time()`.
                // Dropping `tx` also ends `wait_for` (with an error), which is reported as
                // `TimerCancelled`; this happens e.g. when a duplicate id replaces it in
                // `cancel_senders`.
                StageEffect::Schedule(msg, id) => {
                    let when = id.time();
                    let sleep = tokio::time::sleep_until(when.to_tokio());
                    let (tx, mut rx) = watch::channel(false);
                    cancel_senders.insert(id, tx);
                    timers.push(Box::pin(async move {
                        let rx2 = rx.clone();
                        tokio::select! { biased;
                            _ = rx.wait_for(|x| *x) => PriorityMessage::TimerCancelled(id),
                            _ = sleep => PriorityMessage::Scheduled(msg, id, rx2),
                        }
                    }));
                    StageResponse::Unit
                }
                // answers whether a not-yet-delivered timer with this id existed
                StageEffect::CancelSchedule(id) => {
                    if let Some(tx) = cancel_senders.remove(&id) {
                        tx.send_replace(true);
                        StageResponse::CancelScheduleResponse(true)
                    } else {
                        StageResponse::CancelScheduleResponse(false)
                    }
                }
            };
            tb().push_resume(name, &resp);
            *effect.lock() = Some(Right(resp));
        }
    }
}
566
567fn now() -> Instant {
568    Instant::from_tokio(tokio::time::Instant::now())
569}
570
571/// Handle to the running stages.
572#[derive(Clone)]
573#[must_use = "this handle needs to be either joined or aborted"]
574pub struct TokioRunning {
575    inner: Arc<TokioInner>,
576    termination: watch::Receiver<bool>,
577}
578
579impl TokioRunning {
580    /// Abort all stage tasks of this network.
581    pub fn abort(self) {
582        for handle in self.inner.handles.lock().iter() {
583            handle.abort();
584        }
585    }
586
587    pub async fn join(self) {
588        poll_fn(move |cx| {
589            let mut handles = self.inner.handles.lock();
590            handles.retain_mut(|h| {
591                if let Poll::Ready(res) = h.poll_unpin(cx) {
592                    match res {
593                        Ok(_) => tracing::info!("stage task completed"),
594                        Err(err) if err.is_cancelled() => tracing::info!("stage task cancelled"),
595                        Err(err) => tracing::error!("stage task failed: {:?}", err),
596                    }
597                    false
598                } else {
599                    true
600                }
601            });
602            if handles.is_empty() { Poll::Ready(()) } else { Poll::Pending }
603        })
604        .await;
605    }
606
607    pub fn trace_buffer(&self) -> &Arc<Mutex<TraceBuffer>> {
608        &self.inner.trace_buffer
609    }
610
611    pub fn resources(&self) -> &Resources {
612        &self.inner.resources
613    }
614}
615
616impl StageGraphRunning for TokioRunning {
617    fn is_terminated(&self) -> bool {
618        *self.termination.borrow()
619    }
620
621    fn termination(&self) -> BoxFuture<'static, ()> {
622        let mut rx = self.termination.clone();
623        Box::pin(async move {
624            rx.wait_for(|x| *x).await.ok();
625        })
626    }
627}