1use super::{ActiveList, Cons, Term};
2
3use odem_rs_core::{
4 ExitStatus, Puck,
5 config::Config,
6 continuation::Puck as TaskPuck,
7 fsm::Brand,
8 job::{Job, Settle},
9 simulator::Sim,
10};
11
12use core::{
13 cell::Cell,
14 future::{Future, IntoFuture},
15 pin::Pin,
16 ptr::NonNull,
17 task::{Context, Poll},
18};
19
20#[pin_project::pin_project]
32pub struct AllOf<'s, C: ?Sized + Config, J: Joinable<C>> {
33 sim: &'s Sim<C>,
35 #[pin]
37 fut: J,
38 #[pin]
40 state: JoinState<C>,
41}
42
43impl<'s, C: ?Sized + Config, J: Joinable<C>> AllOf<'s, C, J> {
44 pub(super) const fn new(sim: &'s Sim<C>, fut: J) -> Self {
46 AllOf {
47 sim,
48 fut,
49 state: JoinState::new(),
50 }
51 }
52
53 #[track_caller]
55 pub fn and<F: IntoFuture>(
56 self,
57 fut: F,
58 ) -> AllOf<'s, C, impl Joinable<C, Output = (J::Output, F::Output)> + use<F, C, J>> {
59 AllOf {
60 sim: self.sim,
61 fut: Cons(
62 self.fut,
63 Job::build()
64 .with_actions(fut)
65 .with_finalizer(NotifyQuorum::<C>::new())
66 .finish()
67 .into_inner(),
68 ),
69 state: self.state,
70 }
71 }
72}
73
74impl<C: ?Sized + Config, J: Joinable<C>> Future for AllOf<'_, C, J>
75where
76 J::Output: Flatten,
77{
78 type Output = <J::Output as Flatten>::Flattened;
79
80 fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
81 let this = self.project();
82
83 let mut continuations = this.fut;
85
86 match this.state.counter.get() {
88 usize::MAX => {
90 let state = this.state.into_ref();
91 state.counter.set(J::LEN);
93 state.parent.set(Some(this.sim.active()));
95
96 unsafe {
102 continuations.as_mut().bind(state);
103 }
104
105 continuations.activate(this.sim);
107 }
108 0 => return Poll::Ready(continuations.collect().flatten()),
110 _ => unreachable!("unexpected spurious reactivation"),
112 }
113
114 Poll::Pending
115 }
116}
117
118pub struct JoinState<C: ?Sized + Config> {
120 counter: Cell<usize>,
123 parent: Cell<Option<TaskPuck<C>>>,
128}
129
130impl<C: ?Sized + Config> JoinState<C> {
131 const fn new() -> Self {
133 Self {
134 counter: Cell::new(usize::MAX),
135 parent: Cell::new(None),
136 }
137 }
138}
139
140pub(super) struct NotifyQuorum<C: ?Sized + Config> {
145 shared: Option<NonNull<JoinState<C>>>,
147}
148
149impl<C: ?Sized + Config> NotifyQuorum<C> {
150 pub(super) const fn new() -> Self {
155 NotifyQuorum { shared: None }
156 }
157
158 unsafe fn bind(&mut self, state: Pin<&JoinState<C>>) {
164 self.shared = Some(NonNull::from(&*state));
165 }
166}
167
168impl<C: ?Sized + Config, R> Settle<R> for NotifyQuorum<C> {
169 fn settle(self, _: &mut R) -> ExitStatus {
170 let state = unsafe { self.shared.unwrap().as_ref() };
173
174 let n = state.counter.get() - 1;
176
177 if n == 0 {
179 if let Some(mut puck) = state.parent.take() {
180 puck.wake().ok();
181 }
182 }
183
184 state.counter.set(n);
186
187 Ok(odem_rs_core::Success)
189 }
190}
191
192pub trait Joinable<C: ?Sized + Config>: ActiveList<C> {
196 type Output;
198
199 unsafe fn bind(self: Pin<&mut Self>, state: Pin<&JoinState<C>>);
206
207 fn collect(self: Pin<&mut Self>) -> Self::Output;
211}
212
213impl<C, F> Joinable<C> for Term<Job<C, F, NotifyQuorum<C>>>
214where
215 C: ?Sized + Config,
216 F: Future,
217{
218 type Output = (F::Output,);
219
220 unsafe fn bind(self: Pin<&mut Self>, state: Pin<&JoinState<C>>) {
221 let this = self.project();
222 this.0.brand(|job, once| {
223 let born = job.token(once).into_born().unwrap();
224 unsafe {
225 job.finalizer(&born).bind(state);
226 }
227 });
228 }
229
230 fn collect(self: Pin<&mut Self>) -> Self::Output {
231 let this = self.project();
232 (this.0.result().expect("activation without result"),)
233 }
234}
235
236impl<C, L, F> Joinable<C> for Cons<L, Job<C, F, NotifyQuorum<C>>>
237where
238 C: ?Sized + Config,
239 L: Joinable<C>,
240 F: Future,
241{
242 type Output = (L::Output, F::Output);
243
244 unsafe fn bind(self: Pin<&mut Self>, state: Pin<&JoinState<C>>) {
245 let this = self.project();
246
247 unsafe {
248 this.0.bind(state);
249 }
250 this.1.brand(|job, once| {
251 let born = job.token(once).into_born().unwrap();
252 unsafe {
253 job.finalizer(&born).bind(state);
254 }
255 });
256 }
257
258 fn collect(self: Pin<&mut Self>) -> Self::Output {
259 let this = self.project();
260
261 (
262 this.0.collect(),
263 this.1.result().expect("activation without result"),
264 )
265 }
266}
/// Converts the left-nested tuples produced by joining jobs into flat tuples.
///
/// A join list with results `a`, `b`, `c` produces `(((a,), b), c)`;
/// flattening turns that into `(a, b, c)`. A single result `(a,)` flattens
/// to the bare value `a`.
pub trait Flatten {
    /// The flat representation of `Self`.
    type Flattened;

    /// Consumes the nested tuple and returns its flat counterpart.
    fn flatten(self) -> Self::Flattened;
}

/// Generates the [`Flatten`] implementations for nested tuples.
///
/// Invoked with `n` identifiers, it emits one impl per nesting depth
/// `1..=n`, plus a recursive impl for deeper nestings whose innermost part
/// is itself `Flatten`.
macro_rules! impl_flatten {
    // `@name`: builds the left-nested tuple from a list of identifiers. The
    // same tokens serve as a type and as a destructuring pattern.
    // Finished: emit the accumulated nesting.
    (@name => $acc:tt) => { $acc };

    // No accumulator yet: start with a one-element tuple.
    (@name $first:ident $($rest:ident)* =>) => {
        impl_flatten!(@name $($rest)* => ($first,))
    };

    // Wrap the accumulator together with the next identifier.
    (@name $next:ident $($rest:ident)* => $acc:tt) => {
        impl_flatten!(@name $($rest)* => ($acc, $next))
    };

    // `@leaf` with one identifier: `(T,)` flattens to `T` itself.
    (@leaf $only:ident) => {
        impl<$only> Flatten for ($only,) {
            type Flattened = $only;

            fn flatten(self) -> Self::Flattened {
                let (inner,) = self;
                inner
            }
        }
    };

    // `@leaf` with several identifiers: the fully nested tuple flattens to
    // the flat tuple of the same elements.
    (@leaf $($elem:ident)+) => {
        impl<$($elem,)*> Flatten for impl_flatten!(@name $($elem)* =>) {
            type Flattened = ($($elem,)*);

            fn flatten(self) -> Self::Flattened {
                // The type parameter names double as variable names here.
                #[allow(non_snake_case)]
                let impl_flatten!(@name $($elem)* =>) = self;
                ($($elem,)*)
            }
        }
    };

    // `@node`, all identifiers consumed: a nesting deeper than the identifier
    // list is flattened recursively through its innermost part `Head`.
    (@node $($elem:ident)* .) => {
        impl<Head: Flatten, $($elem),*> Flatten for impl_flatten!(@name $($elem)* => Head) {
            type Flattened = (Head::Flattened, $($elem),*);

            fn flatten(self) -> Self::Flattened {
                // The type parameter names double as variable names here.
                #[allow(non_snake_case)]
                let impl_flatten!(@name $($elem)* => Head) = self;
                (Head.flatten(), $($elem,)*)
            }
        }
    };

    // `@node`, identifiers left after the dot: emit the leaf impl for the
    // prefix extended by one identifier, then continue with the remainder.
    (@node $($done:ident)* . $current:ident $($todo:ident)*) => {
        impl_flatten!(@leaf $($done)* $current);
        impl_flatten!(@node $($done)* $current . $($todo)*);
    };

    // Entry point: a comma-separated list of type parameter names.
    ($($elem:ident),* $(,)?) => {
        impl_flatten!(@node . $($elem)*);
    };
}

// Direct impls for one to eleven results; longer lists use the recursive impl.
impl_flatten!(T1, T2, T3, T4, T5, T6, T7, T8, T9, TA, TB);