// odem_rs_sync/fork/all_of.rs

1use super::{ActiveList, Cons, Term};
2
3use odem_rs_core::{
4	ExitStatus, Puck,
5	config::Config,
6	continuation::Puck as TaskPuck,
7	fsm::Brand,
8	job::{Job, Settle},
9	simulator::Sim,
10};
11
12use core::{
13	cell::Cell,
14	future::{Future, IntoFuture},
15	pin::Pin,
16	ptr::NonNull,
17	task::{Context, Poll},
18};
19
20/* *************************************************************** AllOf Type */
21
22/// A combinator for awaiting the completion of multiple futures.
23///
24/// This utility allows you to combine multiple [`IntoFuture`] futures into a
25/// single future that resolves only after all input futures have completed.
26///
27/// # Safety Warning
28/// - The ordering of fields in [`AllOf`] must not be changed. Specifically, the
29///   list of [`Joinable`] futures must precede the `JoinState`. Incorrect
30///   ordering can lead to undefined behavior due to dangling pointers.
31#[pin_project::pin_project]
32pub struct AllOf<'s, C: ?Sized + Config, J: Joinable<C>> {
33	/// Reference to the stored simulation context.
34	sim: &'s Sim<C>,
35	/// List of futures to join.
36	#[pin]
37	fut: J,
38	/// State used to manage completed futures and signal completion.
39	#[pin]
40	state: JoinState<C>,
41}
42
43impl<'s, C: ?Sized + Config, J: Joinable<C>> AllOf<'s, C, J> {
44	/// Creates a new future combinator containing a set of [Joinable] futures.
45	pub(super) const fn new(sim: &'s Sim<C>, fut: J) -> Self {
46		AllOf {
47			sim,
48			fut,
49			state: JoinState::new(),
50		}
51	}
52
53	/// Adds another future to await to the list.
54	#[track_caller]
55	pub fn and<F: IntoFuture>(
56		self,
57		fut: F,
58	) -> AllOf<'s, C, impl Joinable<C, Output = (J::Output, F::Output)> + use<F, C, J>> {
59		AllOf {
60			sim: self.sim,
61			fut: Cons(
62				self.fut,
63				Job::build()
64					.with_actions(fut)
65					.with_finalizer(NotifyQuorum::<C>::new())
66					.finish()
67					.into_inner(),
68			),
69			state: self.state,
70		}
71	}
72}
73
74impl<C: ?Sized + Config, J: Joinable<C>> Future for AllOf<'_, C, J>
75where
76	J::Output: Flatten,
77{
78	type Output = <J::Output as Flatten>::Flattened;
79
80	fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
81		let this = self.project();
82
83		// project the pin onto the continuations
84		let mut continuations = this.fut;
85
86		// match over the state counter
87		match this.state.counter.get() {
88			// has it been initialized already?
89			usize::MAX => {
90				let state = this.state.into_ref();
91				// initialize with the final number of continuations to complete
92				state.counter.set(J::LEN);
93				// initialize the calling continuation pin
94				state.parent.set(Some(this.sim.active()));
95
96				// SAFETY: the `JoinState` outlives the list of jobs activated
97				// here by the ordering in the `AllOf` structure in combination
98				// with drop order guarantees.
99				// Since the type is pinned, we also know that it will be
100				// dropped eventually or the program ends, whichever is first.
101				unsafe {
102					continuations.as_mut().bind(state);
103				}
104
105				// activate all continuations
106				continuations.activate(this.sim);
107			}
108			// have all the sub-continuations completed successfully?
109			0 => return Poll::Ready(continuations.collect().flatten()),
110			// have we been reawakened spuriously?
111			_ => unreachable!("unexpected spurious reactivation"),
112		}
113
114		Poll::Pending
115	}
116}
117
118/// The state shared with the continuations generated by [AllOf].
119pub struct JoinState<C: ?Sized + Config> {
120	/// Number of continuations that have to be completed before the [AllOf]-future
121	/// can return.
122	counter: Cell<usize>,
123	/// Puck of the calling [AllOf]-future. Is reactivated only by the continuation
124	/// reducing the [counter] to 0.
125	///
126	/// [counter]: #structfield.counter
127	parent: Cell<Option<TaskPuck<C>>>,
128}
129
130impl<C: ?Sized + Config> JoinState<C> {
131	/// Creates a new, default `JoinState`.
132	const fn new() -> Self {
133		Self {
134			counter: Cell::new(usize::MAX),
135			parent: Cell::new(None),
136		}
137	}
138}
139
140/* ************************************************************* NotifyQuorum */
141
142/// Future adapter that reactivates a given continuation when a quorum in regard to
143/// the number of completed sub-continuations is reached.
144pub(super) struct NotifyQuorum<C: ?Sized + Config> {
145	/// Reference to the shared [`JoinState`] of the caller.
146	shared: Option<NonNull<JoinState<C>>>,
147}
148
149impl<C: ?Sized + Config> NotifyQuorum<C> {
150	/// Initializes the notifier from a future. Parent continuation and quorum are
151	/// initialized later during binding to ensure the reactivation of the
152	/// correct continuation, even if the job has been moved to a different continuation
153	/// between creation and activation.
154	pub(super) const fn new() -> Self {
155		NotifyQuorum { shared: None }
156	}
157
158	/// Method initializing the parent continuation and the shared quorum countdown.
159	///
160	/// # Safety
161	/// The caller is responsible that the shared [`JoinState`] outlives the
162	/// [`Job`] using this finisher.
163	unsafe fn bind(&mut self, state: Pin<&JoinState<C>>) {
164		self.shared = Some(NonNull::from(&*state));
165	}
166}
167
168impl<C: ?Sized + Config, R> Settle<R> for NotifyQuorum<C> {
169	fn settle(self, _: &mut R) -> ExitStatus {
170		// SAFETY: the bind-method guaranteed that the `JoinState` outlives the
171		// job that just completed.
172		let state = unsafe { self.shared.unwrap().as_ref() };
173
174		// decrease the number of expected terminations
175		let n = state.counter.get() - 1;
176
177		// if it is reduced to 0, activate the parent
178		if n == 0 {
179			if let Some(mut puck) = state.parent.take() {
180				puck.wake().ok();
181			}
182		}
183
184		// update the counter
185		state.counter.set(n);
186
187		// return success
188		Ok(odem_rs_core::Success)
189	}
190}
191
192/* *************************************************************** Join Trait */
193
194/// Represents a heterogeneous list of [`Job`] to be awaited.
195pub trait Joinable<C: ?Sized + Config>: ActiveList<C> {
196	/// The result tuple type once all jobs complete.
197	type Output;
198
199	/// Completes initializing the `NotifyQuorum` finisher with the shared
200	/// `JoinState`.
201	///
202	/// # Safety
203	/// The caller has to guarantee that the shared `JoinState` outlives all
204	/// `Job`s contained in this list.
205	unsafe fn bind(self: Pin<&mut Self>, state: Pin<&JoinState<C>>);
206
207	/// Collects the result list containing the outputs of the futures.
208	///
209	/// This method panics if not all jobs have terminated.
210	fn collect(self: Pin<&mut Self>) -> Self::Output;
211}
212
213impl<C, F> Joinable<C> for Term<Job<C, F, NotifyQuorum<C>>>
214where
215	C: ?Sized + Config,
216	F: Future,
217{
218	type Output = (F::Output,);
219
220	unsafe fn bind(self: Pin<&mut Self>, state: Pin<&JoinState<C>>) {
221		let this = self.project();
222		this.0.brand(|job, once| {
223			let born = job.token(once).into_born().unwrap();
224			unsafe {
225				job.finalizer(&born).bind(state);
226			}
227		});
228	}
229
230	fn collect(self: Pin<&mut Self>) -> Self::Output {
231		let this = self.project();
232		(this.0.result().expect("activation without result"),)
233	}
234}
235
236impl<C, L, F> Joinable<C> for Cons<L, Job<C, F, NotifyQuorum<C>>>
237where
238	C: ?Sized + Config,
239	L: Joinable<C>,
240	F: Future,
241{
242	type Output = (L::Output, F::Output);
243
244	unsafe fn bind(self: Pin<&mut Self>, state: Pin<&JoinState<C>>) {
245		let this = self.project();
246
247		unsafe {
248			this.0.bind(state);
249		}
250		this.1.brand(|job, once| {
251			let born = job.token(once).into_born().unwrap();
252			unsafe {
253				job.finalizer(&born).bind(state);
254			}
255		});
256	}
257
258	fn collect(self: Pin<&mut Self>) -> Self::Output {
259		let this = self.project();
260
261		(
262			this.0.collect(),
263			this.1.result().expect("activation without result"),
264		)
265	}
266}
267/* ****************************************************************** Flatten */
268
/// Turns the left-nested tuples produced by [`Joinable::Output`] into flat
/// tuples.
///
/// Implementations exist only for the nesting shapes that [`AllOf`] produces
/// in this module. This is not a general-purpose flattening of arbitrarily
/// nested tuples.
pub trait Flatten {
	/// The flat representation of `Self`.
	type Flattened;

	/// Consumes `self` and returns its flattened form.
	fn flatten(self) -> Self::Flattened;
}
281
/// Helper macro to implement the Flatten-trait generically.
///
/// Invoked as `impl_flatten!(A, B, C, ...)`, it generates one exact impl for
/// every left-nested chain `(((A,), B), C)...` of length 1 up to the number
/// of identifiers given. It also generates one catch-all impl whose innermost
/// element is itself `Flatten` and is flattened recursively.
macro_rules! impl_flatten {
	// `@name` builds the left-nested tuple (used as both type and pattern):
	// `@name A B C =>` yields `(((A,), B), C)`, and
	// `@name A B => H` yields `((H, A), B)`.
	(@name => $T:tt) => { $T };

	(@name $HID:ident $($TID:ident)* =>) => {
		impl_flatten!(@name $($TID)* => ($HID,))
	};

	(@name $HID:ident $($TID:ident)* => $T:tt) => {
		impl_flatten!(@name $($TID)* => ($T, $HID))
	};

	// A single-element chain `(X,)` flattens to the bare `X`.
	(@leaf $HID:ident) => {
		impl<$HID> Flatten for ($HID,) {
			type Flattened = $HID;

			fn flatten(self) -> Self::Flattened { self.0 }
		}
	};

	// A chain of two or more elements whose innermost tuple is `(A,)` is fully
	// flattened. The type identifiers double as the names of the bindings in
	// the destructuring pattern, hence the lint allowance.
	(@leaf $($FID:ident)+) => {
		impl<$($FID,)*> Flatten for impl_flatten!(@name $($FID)* =>) {
			type Flattened = ($($FID,)*);

			fn flatten(self) -> Self::Flattened {
				#[allow(non_snake_case)]
				let impl_flatten!(@name $($FID)* =>) = self;
				($($FID,)*)
			}
		}
	};

	// Terminal step: the catch-all impl for longer chains. The innermost
	// element `H` is flattened through its own `Flatten` impl, and the listed
	// elements are appended. Its innermost tuple is a pair `(H, _)` instead of
	// the 1-tuple used by the leaves, so it does not overlap with them.
	(@node $($FID:ident)* .) => {
		impl<H: Flatten, $($FID),*> Flatten for impl_flatten!(@name $($FID)* => H) {
			type Flattened = (H::Flattened, $($FID),*);

			fn flatten(self) -> Self::Flattened {
				#[allow(non_snake_case)]
				let impl_flatten!(@name $($FID)* => H) = self;
				(H.flatten(), $($FID,)*)
			}
		}
	};

	// Recursive step: emit the leaf impl for the identifiers consumed so far
	// (left of the `.`) plus the next one, then move that identifier across
	// the `.` and continue.
	(@node $($HID:ident)* . $MID:ident $($TID:ident)*) => {
		impl_flatten!(@leaf $($HID)* $MID);
		impl_flatten!(@node $($HID)* $MID . $($TID)*);
	};

	// Public entry point: a comma-separated list of fresh type identifiers.
	($($FID:ident),* $(,)?) => {
		impl_flatten!(@node . $($FID)*);
	};
}
335
336// specialize tuple-flattening for up to 12-tuples before recursively returning
337// the remaining tuples partially flattened; this should be enough for most
338// cases and also allows debug printing the tuple's contents, which is only
339// implemented up to tuple-size 12
340impl_flatten!(T1, T2, T3, T4, T5, T6, T7, T8, T9, TA, TB);