1use std::{
21 any::Any,
22 collections::BTreeMap,
23 future::{Future, poll_fn},
24 marker::PhantomData,
25 sync::Arc,
26 task::{Context, Poll, Waker},
27 time::Duration,
28};
29
30use amaru_observability::{amaru::stage, trace_span};
31use either::Either::{Left, Right};
32use futures_util::{FutureExt, StreamExt, stream::FuturesUnordered};
33use parking_lot::Mutex;
34use tokio::{
35 runtime::Handle,
36 sync::{
37 mpsc::{self, Receiver},
38 oneshot, watch,
39 },
40 task::JoinHandle,
41};
42
43use crate::{
44 BoxFuture, EPOCH, Effects, Instant, Name, ScheduleId, ScheduleIds, SendData, Sender, StageBuildRef, StageGraph,
45 StageRef,
46 adapter::{Adapter, StageOrAdapter, find_recipient},
47 drop_guard::DropGuard,
48 effect::{CallExtra, CallTimeout, CanSupervise, StageEffect, StageResponse, TransitionFactory},
49 effect_box::EffectBox,
50 resources::Resources,
51 serde::NoDebug,
52 simulation::Transition,
53 stage_name,
54 stage_ref::StageStateRef,
55 stagegraph::StageGraphRunning,
56 time::Clock,
57 trace_buffer::TraceBuffer,
58};
59
/// Error reporting that a message could not be sent to a stage.
///
/// The rendered message is ``message send failed to stage `<target>` ``.
#[derive(Debug, thiserror::Error)]
#[error("message send failed to stage `{target}`")]
pub struct SendError {
    /// Name of the stage the message was addressed to (interpolated into the error text).
    target: Name,
}
65
66struct TokioInner {
67 senders: Mutex<BTreeMap<Name, StageOrAdapter<mpsc::Sender<Box<dyn SendData>>>>>,
68 handles: Mutex<Vec<JoinHandle<()>>>,
69 clock: Arc<dyn Clock + Send + Sync>,
70 resources: Resources,
71 schedule_ids: ScheduleIds,
72 mailbox_size: usize,
73 stage_counter: Mutex<usize>,
74 trace_buffer: Arc<Mutex<TraceBuffer>>,
75}
76
77impl TokioInner {
78 fn new() -> Self {
79 Self {
80 senders: Default::default(),
81 handles: Default::default(),
82 clock: Arc::new(TokioClock),
83 resources: Resources::default(),
84 schedule_ids: ScheduleIds::default(),
85 mailbox_size: 10,
86 stage_counter: Mutex::new(0usize),
87 trace_buffer: TraceBuffer::new_shared(0, 0),
88 }
89 }
90}
91
92struct TokioClock;
93impl Clock for TokioClock {
94 fn now(&self) -> Instant {
95 Instant::now()
96 }
97 fn advance_to(&self, _instant: Instant) {}
98}
99
100pub struct TokioBuilder {
106 tasks: Vec<Box<dyn FnOnce(Arc<TokioInner>) -> BoxFuture<'static, ()>>>,
107 inner: TokioInner,
108 termination: watch::Receiver<bool>,
109 termination_tx: watch::Sender<bool>,
110}
111
112impl Default for TokioBuilder {
113 fn default() -> Self {
114 let (termination_tx, termination_rx) = watch::channel(false);
115 Self { tasks: Default::default(), inner: TokioInner::new(), termination_tx, termination: termination_rx }
116 }
117}
118
119impl TokioBuilder {
120 pub fn run(self, rt: Handle) -> TokioRunning {
121 let Self {
122 tasks,
123 inner,
124 termination,
125 termination_tx: _, } = self;
127 let inner = Arc::new(inner);
128 let handles = tasks.into_iter().map(|t| rt.spawn(t(inner.clone()))).collect::<Vec<_>>();
129 inner.handles.lock().extend(handles);
130
131 let mut termination2 = termination.clone();
133 let inner2 = inner.clone();
134 rt.spawn(async move {
135 termination2.wait_for(|x| *x).await.ok();
136 let handles = inner2.handles.lock();
137 tracing::info!(stages = handles.len(), "termination signal received, shutting down stages");
138 for handle in handles.iter() {
139 handle.abort();
140 }
141 });
142
143 TokioRunning { inner, termination }
144 }
145
146 pub fn with_trace_buffer(mut self, trace_buffer: Arc<Mutex<TraceBuffer>>) -> Self {
147 self.inner.trace_buffer = trace_buffer;
148 self
149 }
150
151 pub fn with_schedule_ids(mut self, schedule_ids: ScheduleIds) -> Self {
152 self.inner.schedule_ids = schedule_ids;
153 self
154 }
155
156 pub fn with_epoch_clock(mut self) -> Self {
157 self.inner.clock = Arc::new(EpochClock::new());
158 self
159 }
160}
161
162struct EpochClock {
163 offset: Mutex<Option<Duration>>,
164}
165
166impl EpochClock {
167 fn new() -> Self {
168 Self { offset: Mutex::new(None) }
169 }
170}
171
172impl Clock for EpochClock {
173 fn now(&self) -> Instant {
174 let mut offset = self.offset.lock();
175 if let Some(offset) = *offset {
176 Instant::now() - offset
177 } else {
178 let now = Instant::now();
179 let since_epoch = now.saturating_since(*EPOCH);
180 *offset = Some(since_epoch);
181 now - since_epoch
182 }
183 }
184
185 fn advance_to(&self, _instant: Instant) {}
186}
187
/// Payload stored (boxed as `dyn Any`) in `StageBuildRef::network` by `stage` and
/// recovered by `wire_up`: the stage's mailbox receiver and its transition factory.
type RefAux = (Receiver<Box<dyn SendData>>, TransitionFactory);
189
190impl StageGraph for TokioBuilder {
191 #[expect(clippy::expect_used)]
192 fn stage<Msg: SendData, St: SendData, F, Fut>(
193 &mut self,
194 name: impl AsRef<str>,
195 mut f: F,
196 ) -> StageBuildRef<Msg, St, Box<dyn Any + Send>>
197 where
198 F: FnMut(St, Msg, Effects<Msg>) -> Fut + 'static + Send,
199 Fut: Future<Output = St> + 'static + Send,
200 {
201 let name = stage_name(&mut self.inner.stage_counter.lock(), name.as_ref());
203 let (tx, rx) = mpsc::channel(self.inner.mailbox_size);
204 self.inner.senders.lock().insert(name.clone(), StageOrAdapter::Stage(tx));
205
206 let me = StageRef::new(name.clone());
207 let clock = self.inner.clock.clone();
208 let resources = self.inner.resources.clone();
209 let schedule_ids = self.inner.schedule_ids.clone();
210 let trace_buffer = self.inner.trace_buffer.clone();
211 let ff = Box::new(move |effect| {
212 let eff = Effects::new(me, effect, clock, resources, schedule_ids, trace_buffer);
213 Box::new(move |state: Box<dyn SendData>, msg: Box<dyn SendData>| {
214 let state = state.cast::<St>().expect("internal state type error");
215 let msg = msg.cast::<Msg>().expect("internal message type error");
216 let state = f(*state, *msg, eff.clone());
217 Box::pin(async move { Box::new(state.await) as Box<dyn SendData> })
218 as BoxFuture<'static, Box<dyn SendData>>
219 }) as Transition
220 });
221 let network: RefAux = (rx, ff);
222
223 StageBuildRef { name, network: Box::new(network), _ph: PhantomData }
224 }
225
226 #[expect(clippy::expect_used)]
227 fn wire_up<Msg: SendData, St: SendData>(
228 &mut self,
229 stage: StageBuildRef<Msg, St, Box<dyn Any + Send>>,
230 state: St,
231 ) -> StageStateRef<Msg, St> {
232 let StageBuildRef { name, network, _ph } = stage;
233 let (rx, ff) = *network.downcast::<RefAux>().expect("internal network type error");
234 let stage_name = name.clone();
235 let state = Box::new(state);
236 let termination_tx = self.termination_tx.clone();
237 self.tasks.push(Box::new(move |inner| {
238 let stage = run_stage_boxed(state, rx, ff, stage_name, inner);
239 Box::pin(async move {
240 stage.await;
241 termination_tx.send_replace(true);
242 })
243 }));
244 StageStateRef::new(name)
245 }
246
247 fn contramap<Original: SendData, Mapped: SendData>(
248 &mut self,
249 stage_ref: impl AsRef<StageRef<Original>>,
250 new_name: impl AsRef<str>,
251 transform: impl Fn(Mapped) -> Original + 'static + Send,
252 ) -> StageRef<Mapped> {
253 let target = stage_ref.as_ref();
254 let new_name = stage_name(&mut self.inner.stage_counter.lock(), new_name.as_ref());
255 let adapter = Adapter::new(new_name.clone(), target.name().clone(), transform);
256 self.inner.senders.lock().insert(new_name.clone(), StageOrAdapter::Adapter(adapter));
257 StageRef::new(new_name)
258 }
259
260 fn preload<Msg: SendData>(
261 &mut self,
262 stage: impl AsRef<StageRef<Msg>>,
263 messages: impl IntoIterator<Item = Msg>,
264 ) -> Result<(), Box<dyn SendData>> {
265 let stage = stage.as_ref();
266 let mut senders = self.inner.senders.lock();
267 for msg in messages {
268 if let Some((tx, msg)) = find_recipient(&mut senders, stage.name().clone(), Some(Box::new(msg)))
269 && let Err(err) = tx.try_send(msg)
270 {
271 tracing::warn!("message preload failed to stage `{}`", stage.name());
272 return Err(err.into_inner());
273 }
274 }
275 Ok(())
276 }
277
278 fn input<Msg: SendData>(&mut self, stage: impl AsRef<StageRef<Msg>>) -> Sender<Msg> {
279 mk_sender(stage.as_ref().name(), &self.inner)
280 }
281
282 fn resources(&self) -> &Resources {
283 &self.inner.resources
284 }
285}
286
/// Output of the futures kept in a stage's `timers` set; `run_stage_boxed` polls
/// these ahead of the regular mailbox (biased `select!`).
enum PriorityMessage {
    /// A scheduled message whose sleep elapsed: the message, its schedule id, and the
    /// cancellation receiver that is re-checked before the message is delivered.
    Scheduled(Box<dyn SendData>, ScheduleId, watch::Receiver<bool>),
    /// A scheduled message whose cancellation signal fired before the sleep elapsed.
    TimerCancelled(ScheduleId),
    /// Message delivered to a stage once the task of a child stage it wired has ended.
    Tombstone(Box<dyn SendData>),
}
292
/// Runs a single stage until it terminates.
///
/// The loop receives messages from the mailbox `rx` and from the stage's own
/// priority sources (`timers`: scheduled messages and child tombstones), feeds each
/// message together with the current state into the transition, and lets
/// `interpreter` drive the resulting future. State, input and termination records
/// are pushed to the shared trace buffer along the way.
///
/// The function returns when:
/// - the `select!` falls through to `else` (no branch can produce a message),
/// - a `CanSupervise` message is received,
/// - `interpreter` returns `None` (the stage terminated itself).
///
/// * `state` - initial, type-erased stage state.
/// * `rx` - mailbox receiver of this stage.
/// * `transition` - factory producing the transition function from the effect box.
/// * `stage_name` - name used in logs and trace records.
/// * `inner` - state shared with the rest of the graph.
async fn run_stage_boxed(
    mut state: Box<dyn SendData>,
    mut rx: Receiver<Box<dyn SendData + 'static>>,
    transition: TransitionFactory,
    stage_name: Name,
    inner: Arc<TokioInner>,
) {
    tracing::debug!("running stage `{stage_name}`");

    // Shared slot through which the stage future and `interpreter` exchange
    // effects and responses.
    let effect = Arc::new(Mutex::new(None));
    let mut transition = transition(effect.clone());

    // Futures producing `PriorityMessage`s for this stage (scheduled sends, child tombstones).
    let mut timers = FuturesUnordered::<BoxFuture<'static, PriorityMessage>>::new();
    // Cancellation handles of the currently scheduled messages, by schedule id.
    let mut cancel_senders = BTreeMap::<ScheduleId, watch::Sender<bool>>::new();

    // Records an "aborted" termination unless the guard is defused via `into_inner`
    // at the end of this function.
    // NOTE(review): presumably the closure runs when the guard is dropped, i.e. when
    // this future is dropped early (task abort) — confirm against `DropGuard`.
    let tb = DropGuard::new(inner.trace_buffer.clone(), |tb| {
        tb.lock().push_terminated_aborted(&stage_name)
    });

    // Batch of (message, optional schedule info) to process in the current iteration.
    let mut msgs = Vec::new();

    inner.trace_buffer.lock().push_state(&stage_name, &state);

    'outer: loop {
        // Only poll the timer set when it holds at least one future.
        let poll_timers = !timers.is_empty();
        // Collect up to 1000 ready timer results per wake-up.
        let mut timer_chunks = (&mut timers).ready_chunks(1000);

        // `biased`: branches are tried in order, so timer results win over the mailbox.
        tokio::select! { biased;
            Some(res) = timer_chunks.next(), if poll_timers => {
                let mut scheduled = Vec::new();
                for msg in res {
                    match msg {
                        PriorityMessage::Scheduled(msg, id, cancelation) => {
                            scheduled.push((id, msg, cancelation));
                        }
                        // Nothing to deliver for a cancelled timer.
                        PriorityMessage::TimerCancelled(_id) => {}
                        PriorityMessage::Tombstone(msg) => msgs.push((msg, None)),
                    }
                }
                // Deliver scheduled messages of one batch in schedule id order;
                // tombstones of the same batch were queued ahead of them above.
                scheduled.sort_by_key(|(id, _, _)| *id);
                for (id, msg, cancelation) in scheduled {
                    msgs.push((msg, Some((id, cancelation))));
                }
            }
            Some(msg) = rx.recv() => msgs.push((msg, None)),
            // No branch can make progress any more (mailbox closed, no timer result).
            else => {
                tracing::error!(%stage_name, "stage sender dropped");
                break;
            }
        }

        for (msg, cancelation) in msgs.drain(..) {
            if let Some((id, canceled)) = cancelation {
                // The timer has fired, so its cancellation handle is no longer needed.
                cancel_senders.remove(&id);
                // Skip messages that were cancelled after firing but before delivery.
                if *canceled.borrow() {
                    continue;
                }
            }

            // A `CanSupervise` message is not passed to the transition: it ends this stage.
            if let Ok(CanSupervise(child)) = msg.cast_ref::<CanSupervise>() {
                tracing::debug!("stage `{stage_name}` terminates because of an unsupervised child termination");
                tb.lock().push_terminated_supervision(&stage_name, child);
                break 'outer;
            }

            inner.trace_buffer.lock().push_input(&stage_name, &msg);

            // Run one transition to completion, executing its effects as they are requested.
            let f = (transition)(state, msg);
            let result = interpreter(&inner, &effect, &stage_name, &mut timers, &mut cancel_senders, f).await;
            match result {
                Some(st) => state = st,
                // The stage requested its own termination.
                None => {
                    tracing::info!(%stage_name, "terminated");
                    tb.lock().push_terminated_voluntary(&stage_name);
                    break 'outer;
                }
            }

            inner.trace_buffer.lock().push_state(&stage_name, &state);
        }
    }

    // Regular exit: take the value back out of the guard.
    // NOTE(review): presumably this prevents the "aborted" record — confirm against `DropGuard`.
    DropGuard::into_inner(tb);
}
383
384#[expect(clippy::expect_used, clippy::panic)]
385fn mk_sender<Msg: SendData>(stage_name: &Name, inner: &TokioInner) -> Sender<Msg> {
386 let senders = inner.senders.lock();
387 let StageOrAdapter::Stage(tx) = senders.get(stage_name).expect("stage ref contained unknown name") else {
388 panic!("cannot obtain input for adapter");
389 };
390 let tx = tx.clone();
391 Sender::new(Arc::new(move |msg: Msg| {
392 let tx = tx.clone();
393 Box::pin(async move {
394 tx.send(Box::new(msg)).await.map_err(|msg| *msg.0.cast::<Msg>().expect("message was just boxed"))
395 })
396 }))
397}
398
/// Reply slot attached to a call: holds the oneshot sender for the response until
/// the first reply takes it out.
type StageRefExtra = Mutex<Option<oneshot::Sender<Box<dyn SendData>>>>;
400
/// Drives one transition future of a stage to completion, executing the effects it
/// requests.
///
/// The stage future is polled with a no-op waker. Whenever it is pending, the
/// requested effect is taken out of `effect`, executed, and its response is stored
/// back into `effect` before the future is polled again. Suspend and resume
/// records are pushed to the trace buffer.
///
/// * `inner` - state shared with the rest of the graph.
/// * `effect` - slot shared with the stage future (`Left` = effect, `Right` = response).
/// * `name` - name of the stage being run.
/// * `timers` - the stage's priority futures; scheduled messages and child tombstones
///   are added here.
/// * `cancel_senders` - cancellation handles of scheduled messages, by schedule id.
/// * `stage` - the transition future to drive.
///
/// Returns `Some(new_state)` when the transition completes, or `None` when the
/// stage requested termination.
///
/// # Panics
///
/// Panics if the future is pending without having stored an effect, if it awaits
/// `Receive` explicitly, if a send/call target is unknown, or if the extra data of
/// a call/reply has an unexpected type.
#[expect(clippy::manual_async_fn)]
fn interpreter(
    inner: &Arc<TokioInner>,
    effect: &EffectBox,
    name: &Name,
    timers: &mut FuturesUnordered<BoxFuture<'static, PriorityMessage>>,
    cancel_senders: &mut BTreeMap<ScheduleId, watch::Sender<bool>>,
    mut stage: BoxFuture<'static, Box<dyn SendData>>,
) -> impl Future<Output = Option<Box<dyn SendData>>> + Send {
    async move {
        // Short-lived access to the trace buffer; each guard is dropped at the end of
        // the statement that uses it.
        let tb = || inner.trace_buffer.lock();
        tb().push_resume(name, &StageResponse::Unit);
        loop {
            // Poll with a no-op waker: progress is made by re-polling after each effect
            // has been answered below, not by wake-ups.
            let poll = {
                let _span = trace_span!(stage::tokio::POLL, stage = %name).entered();
                stage.as_mut().poll(&mut Context::from_waker(Waker::noop()))
            };
            if let Poll::Ready(state) = poll {
                return Some(state);
            }
            // NOTE(review): presumably dropped explicitly so the poll result is not
            // held across the awaits below — confirm.
            drop(poll);

            // A pending stage future must have left an effect in the shared slot.
            #[expect(clippy::panic)]
            let Some(Left(eff)) = effect.lock().take() else {
                panic!("stage `{name}` used .await on something that was not a stage effect");
            };
            tb().push_suspend_ref(name, &eff);

            let resp = match eff {
                StageEffect::Receive => {
                    #[expect(clippy::panic)]
                    {
                        panic!("effect Receive cannot be explicitly awaited (stage `{name}`)")
                    }
                }
                // Sends to an empty target name are dropped.
                StageEffect::Send(target, ..) if target.is_empty() => {
                    tracing::warn!(stage = %name, "message send to blackhole stage dropped");
                    StageResponse::Unit
                }
                // Send carrying call data: deliver through the call's oneshot reply slot
                // instead of a mailbox. Only the first reply finds a sender in the slot.
                StageEffect::Send(_target, Some(call), msg) => {
                    #[expect(clippy::expect_used)]
                    let sender = call.downcast_ref::<StageRefExtra>().expect("expected CallExtra");
                    if let Some(sender) = sender.lock().take() {
                        sender.send(msg).ok();
                    }
                    StageResponse::Unit
                }
                // Plain send: resolve the recipient, then wait for mailbox capacity.
                StageEffect::Send(target, None, msg) => {
                    // Clone the sender inside this block so the registry lock is released
                    // before awaiting.
                    let (tx, msg) = {
                        let mut senders = inner.senders.lock();
                        #[expect(clippy::expect_used)]
                        let (tx, msg) = find_recipient(&mut senders, target.clone(), Some(msg))
                            .expect("stage ref contained unknown name");
                        (tx.clone(), msg)
                    };
                    // A closed mailbox is ignored.
                    tx.send(msg).await.ok();
                    StageResponse::Unit
                }
                StageEffect::Call(target, duration, msg) => {
                    #[expect(clippy::panic)]
                    let CallExtra::CallFn(NoDebug(msg)) = msg else {
                        panic!("expected CallFn, got {:?}", msg);
                    };
                    // The callee replies through this oneshot channel.
                    let (tx_response, rx) = oneshot::channel();
                    let sender = StageRefExtra::new(Some(tx_response));
                    // Build the request message, handing it this stage's name and the reply slot.
                    let msg = (msg)(name.clone(), Arc::new(sender));

                    tb().push_suspend_call(name, &target, duration, &*msg);

                    let (tx_call, msg) = {
                        let mut senders = inner.senders.lock();
                        #[expect(clippy::expect_used)]
                        let (tx, msg) = find_recipient(&mut senders, target.clone(), Some(msg))
                            .expect("stage ref contained unknown name");
                        (tx.clone(), msg)
                    };
                    tx_call.send(msg).await.ok();
                    // The timeout covers only the wait for the reply, not the send above.
                    // A dropped reply sender is reported like a timeout.
                    match tokio::time::timeout(duration, rx).await {
                        Ok(Ok(msg)) => StageResponse::CallResponse(msg),
                        _ => StageResponse::CallResponse(Box::new(CallTimeout)),
                    }
                }
                // NOTE(review): answered from the free `now()` helper, not from
                // `inner.clock` — confirm this is intended when a non-default clock is set.
                StageEffect::Clock => StageResponse::ClockResponse(now()),
                StageEffect::Wait(duration) => {
                    tokio::time::sleep(duration).await;
                    StageResponse::WaitResponse(now())
                }
                StageEffect::External(effect) => {
                    tracing::debug!("stage `{name}` external effect: {:?}", effect);
                    StageResponse::ExternalResponse(effect.run(inner.resources.clone()).await)
                }
                StageEffect::Terminate => {
                    tracing::warn!("stage `{name}` terminated");
                    return None;
                }
                // Only generates the name for a new stage; the stage itself is created by
                // `WireStage`.
                StageEffect::AddStage(name) => {
                    tracing::debug!("stage `{name}` added");
                    let name = stage_name(&mut inner.stage_counter.lock(), name.as_str());
                    StageResponse::AddStageResponse(name)
                }
                StageEffect::WireStage(name, transition, initial_state, tombstone) => {
                    tracing::debug!("stage `{name}` wired");
                    let (tx, rx) = mpsc::channel(inner.mailbox_size);
                    inner.senders.lock().insert(name.clone(), StageOrAdapter::Stage(tx));
                    let stage =
                        run_stage_boxed(initial_state, rx, transition.into_inner(), name.clone(), inner.clone());
                    let handle = tokio::spawn(stage);
                    // The guard's closure aborts the child task.
                    // NOTE(review): presumably it fires when this future is dropped with
                    // the parent's `timers` — confirm against `DropGuard`.
                    let mut handle = DropGuard::new(handle, |handle| handle.abort());
                    // When the child task ends, its tombstone is delivered to this stage
                    // as a priority message.
                    timers.push(Box::pin(async move {
                        if let Err(err) = (&mut *handle).await {
                            tracing::error!("stage `{name}` failed: {}", err);
                        }
                        PriorityMessage::Tombstone(tombstone)
                    }));
                    StageResponse::Unit
                }
                // Registers an adapter under a freshly generated name that forwards to `original`.
                StageEffect::Contramap { original, new_name, transform } => {
                    tracing::debug!("contramap {original} -> {new_name}");
                    let name = stage_name(&mut inner.stage_counter.lock(), new_name.as_str());
                    inner.senders.lock().insert(
                        name.clone(),
                        StageOrAdapter::Adapter(Adapter {
                            name: name.clone(),
                            target: original,
                            transform: transform.into_inner(),
                        }),
                    );
                    StageResponse::ContramapResponse(name)
                }
                StageEffect::Schedule(msg, id) => {
                    // The delivery time is taken from the schedule id.
                    let when = id.time();
                    let sleep = tokio::time::sleep_until(when.to_tokio());
                    // Cancellation flag of this timer; the sender is kept in `cancel_senders`.
                    let (tx, mut rx) = watch::channel(false);
                    cancel_senders.insert(id, tx);
                    timers.push(Box::pin(async move {
                        // The second receiver travels with the fired message so that
                        // cancellation can be re-checked right before delivery.
                        let rx2 = rx.clone();
                        // `biased`: a cancellation wins over a sleep that is ready at the same time.
                        tokio::select! { biased;
                            _ = rx.wait_for(|x| *x) => PriorityMessage::TimerCancelled(id),
                            _ = sleep => PriorityMessage::Scheduled(msg, id, rx2),
                        }
                    }));
                    StageResponse::Unit
                }
                // Reports `true` only if a cancellation handle for `id` was still registered.
                StageEffect::CancelSchedule(id) => {
                    if let Some(tx) = cancel_senders.remove(&id) {
                        tx.send_replace(true);
                        StageResponse::CancelScheduleResponse(true)
                    } else {
                        StageResponse::CancelScheduleResponse(false)
                    }
                }
            };
            tb().push_resume(name, &resp);
            // Hand the response to the stage future, which is polled again next iteration.
            *effect.lock() = Some(Right(resp));
        }
    }
}
566
567fn now() -> Instant {
568 Instant::from_tokio(tokio::time::Instant::now())
569}
570
571#[derive(Clone)]
573#[must_use = "this handle needs to be either joined or aborted"]
574pub struct TokioRunning {
575 inner: Arc<TokioInner>,
576 termination: watch::Receiver<bool>,
577}
578
579impl TokioRunning {
580 pub fn abort(self) {
582 for handle in self.inner.handles.lock().iter() {
583 handle.abort();
584 }
585 }
586
587 pub async fn join(self) {
588 poll_fn(move |cx| {
589 let mut handles = self.inner.handles.lock();
590 handles.retain_mut(|h| {
591 if let Poll::Ready(res) = h.poll_unpin(cx) {
592 match res {
593 Ok(_) => tracing::info!("stage task completed"),
594 Err(err) if err.is_cancelled() => tracing::info!("stage task cancelled"),
595 Err(err) => tracing::error!("stage task failed: {:?}", err),
596 }
597 false
598 } else {
599 true
600 }
601 });
602 if handles.is_empty() { Poll::Ready(()) } else { Poll::Pending }
603 })
604 .await;
605 }
606
607 pub fn trace_buffer(&self) -> &Arc<Mutex<TraceBuffer>> {
608 &self.inner.trace_buffer
609 }
610
611 pub fn resources(&self) -> &Resources {
612 &self.inner.resources
613 }
614}
615
616impl StageGraphRunning for TokioRunning {
617 fn is_terminated(&self) -> bool {
618 *self.termination.borrow()
619 }
620
621 fn termination(&self) -> BoxFuture<'static, ()> {
622 let mut rx = self.termination.clone();
623 Box::pin(async move {
624 rx.wait_for(|x| *x).await.ok();
625 })
626 }
627}