1use core::{
6 future::Future,
7 ops::Range,
8 pin::Pin,
9 task::{Context, Waker},
10};
11
12use alloc::{boxed::Box, collections::VecDeque, vec::Vec};
13use codas::types::TryAsFormat;
14
15use crate::{async_support, Error, Flow, FlowSubscriber, Flows};
16
17pub struct Stage<T: Flows> {
19 subscriber: FlowSubscriber<T>,
20
21 #[allow(clippy::type_complexity)]
23 processors: Vec<Box<dyn FnMut(&mut Proc, &T) + Send + 'static>>,
24
25 context: Proc,
28
29 max_procs_per_batch: usize,
32}
33
34impl<T: Flows> Stage<T> {
35 pub fn flow(&self) -> Flow<T> {
37 Flow {
38 state: self.subscriber.flow_state.clone(),
39 }
40 }
41
42 pub fn add_proc<D>(&mut self, mut proc: impl Procs<D>)
49 where
50 T: TryAsFormat<D>,
51 {
52 let proc = move |context: &mut Proc, data: &T| {
53 if let Ok(data) = data.try_as_format() {
54 proc.proc(context, data);
55 }
56
57 if context.remaining() == 0 {
58 proc.end_of_procs();
59 }
60 };
61
62 self.processors.push(Box::new(proc));
63 }
64
65 pub fn proc(&mut self) -> Result<u64, Error> {
69 let receivable_seqs = self.subscriber.receivable_seqs();
71 assert_eq!(receivable_seqs.start, self.context.receivable_seqs.start);
72 self.context.receivable_seqs = receivable_seqs;
73
74 let first_receivable = self.context.receivable_seqs.start;
76 let last_receivable = first_receivable + self.max_procs_per_batch as u64;
77 let mut last_received = None;
78 while let Some(next) = self.context.receivable_seqs.next() {
79 last_received = Some(next);
80
81 let data = unsafe { self.subscriber.flow_state.get(next) };
83
84 for proc in &mut self.processors {
86 (proc)(&mut self.context, data)
87 }
88
89 if next >= last_receivable {
92 break;
93 }
94 }
95
96 self.context.poll_tasks();
98
99 if let Some(last) = last_received {
101 self.subscriber.receive_up_to(last);
102 Ok(self.context.receivable_seqs.start - first_receivable)
103 } else {
104 Err(Error::Ahead)
105 }
106 }
107
108 pub async fn proc_loop(mut self) {
115 loop {
116 if self.proc().is_err() {
117 async_support::yield_now().await;
118 }
119 }
120 }
121
122 pub async fn proc_loop_with_waiter<W, Fut>(mut self, waiter: W)
132 where
133 W: Fn() -> Fut,
134 Fut: Future<Output = ()>,
135 {
136 loop {
137 if self.proc().is_err() {
138 waiter().await;
139 }
140 }
141 }
142}
143
144impl<T: Flows> From<FlowSubscriber<T>> for Stage<T> {
145 fn from(value: FlowSubscriber<T>) -> Self {
146 let max_procs_per_batch = value.flow_state.buffer.len() / 4;
147
148 Self {
149 subscriber: value,
150 context: Proc {
151 waker: async_support::noop_waker(),
152 pending_tasks: VecDeque::new(),
153 receivable_seqs: 0..0,
154 },
155 processors: alloc::vec![],
156 max_procs_per_batch,
157 }
158 }
159}
160
161pub trait Procs<D>: Send + 'static {
163 fn proc(&mut self, context: &mut Proc, data: &D);
165
166 #[inline(always)]
173 fn end_of_procs(&mut self) {}
174}
175
176impl<T, D> Procs<D> for T
177where
178 T: FnMut(&mut Proc, &D) + Send + 'static,
179{
180 fn proc(&mut self, context: &mut Proc, data: &D) {
181 (self)(context, data)
182 }
183}
184
185pub struct Proc {
187 waker: Waker,
189
190 pending_tasks: VecDeque<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
192
193 receivable_seqs: Range<u64>,
195}
196
197impl Proc {
198 pub fn remaining(&self) -> u64 {
201 self.receivable_seqs.end - self.receivable_seqs.start
202 }
203
204 pub fn spawn(&mut self, task: impl Future<Output = ()> + Send + 'static) {
206 let mut context = Context::from_waker(&self.waker);
207 let mut pinned = Box::pin(task);
208 if pinned.as_mut().poll(&mut context).is_pending() {
209 self.pending_tasks.push_back(Box::pin(pinned));
210 }
211 }
212
213 fn poll_tasks(&mut self) {
215 if !self.pending_tasks.is_empty() {
216 let mut context = Context::from_waker(&self.waker);
217 self.pending_tasks
218 .retain_mut(|future| future.as_mut().poll(&mut context).is_pending());
219 }
220 }
221}
222
#[cfg(test)]
mod tests {

    use core::sync::atomic::Ordering;

    use portable_atomic::AtomicU64;
    use portable_atomic_util::Arc;

    use crate::Flow;

    use super::*;

    /// Two processors registered on one stage both run (via spawned tasks) on a
    /// single published item.
    #[test]
    fn dynamic_subscribers() {
        let (mut flow, [subscriber]) = Flow::<u32>::new(32);
        let test_data = 1337;
        let invocations = Arc::new(AtomicU64::new(0));

        let mut stage = Stage::from(subscriber);

        // Register two identical counting processors.
        for _ in 0..2 {
            let counter = invocations.clone();
            stage.add_proc(move |proc: &mut Proc, data: &u32| {
                let data = *data;
                let counter = counter.clone();
                proc.spawn(async move {
                    assert_eq!(test_data, data);
                    counter.add(1, Ordering::SeqCst);
                });
                assert_eq!(0, proc.remaining());
            });
        }

        // Nothing has been published yet.
        assert_eq!(Err(Error::Ahead), stage.proc());

        flow.try_next().unwrap().publish(test_data);
        assert_eq!(Ok(1), stage.proc());
        assert_eq!(2, invocations.load(Ordering::SeqCst));
    }
}