use core::{
    future::Future,
    ops::Range,
    pin::Pin,
    task::{Context, Waker},
};

use alloc::{boxed::Box, collections::VecDeque, vec::Vec};
use codas::types::TryAsFormat;

use crate::{async_support, Error, Flow, FlowSubscriber, Flows};

19pub struct Stage<T: Flows> {
21 subscriber: FlowSubscriber<T>,
22
23 #[allow(clippy::type_complexity)]
25 processors: Vec<Box<dyn FnMut(&mut Proc, &T) + Send + 'static>>,
26
27 context: Proc,
30
31 max_procs_per_batch: usize,
34}
35
36impl<T: Flows> Stage<T> {
37 pub fn flow(&self) -> Flow<T> {
39 Flow {
40 state: self.subscriber.flow_state.clone(),
41 }
42 }
43
44 pub fn add_proc<D>(&mut self, mut proc: impl Procs<D>)
51 where
52 T: TryAsFormat<D>,
53 {
54 let proc = move |context: &mut Proc, data: &T| {
55 if let Ok(data) = data.try_as_format() {
56 proc.proc(context, data);
57 }
58
59 if context.remaining() == 0 {
60 proc.end_of_procs();
61 }
62 };
63
64 self.processors.push(Box::new(proc));
65 }
66
67 pub fn proc(&mut self) -> Result<u64, Error> {
71 let receivable_seqs = self.subscriber.receivable_seqs();
73 assert_eq!(receivable_seqs.start, self.context.receivable_seqs.start);
74 self.context.receivable_seqs = receivable_seqs;
75
76 let first_receivable = self.context.receivable_seqs.start;
78 let last_receivable = first_receivable + self.max_procs_per_batch as u64;
79 let mut last_received = None;
80 while let Some(next) = self.context.receivable_seqs.next() {
81 last_received = Some(next);
82
83 let data = unsafe { self.subscriber.flow_state.get(next) };
85
86 for proc in &mut self.processors {
88 (proc)(&mut self.context, data)
89 }
90
91 if next >= last_receivable {
94 break;
95 }
96 }
97
98 self.context.poll_tasks();
100
101 if let Some(last) = last_received {
103 self.subscriber.receive_up_to(last);
104 Ok(self.context.receivable_seqs.start - first_receivable)
105 } else {
106 Err(Error::Ahead)
107 }
108 }
109
110 pub async fn proc_loop(mut self) {
117 loop {
118 if self.proc().is_err() {
119 async_support::yield_now().await;
120 }
121 }
122 }
123
124 pub async fn proc_loop_with_waiter<W, Fut>(mut self, waiter: W)
134 where
135 W: Fn() -> Fut,
136 Fut: Future<Output = ()>,
137 {
138 loop {
139 if self.proc().is_err() {
140 waiter().await;
141 }
142 }
143 }
144}
145
146impl<T: Flows> From<FlowSubscriber<T>> for Stage<T> {
147 fn from(value: FlowSubscriber<T>) -> Self {
148 let max_procs_per_batch = value.flow_state.buffer.len() / 4;
149
150 Self {
151 subscriber: value,
152 context: Proc::default(),
153 processors: Default::default(),
154 max_procs_per_batch,
155 }
156 }
157}
158
159pub trait Procs<D>: Send + 'static {
161 fn proc(&mut self, context: &mut Proc, data: &D);
163
164 #[inline(always)]
171 fn end_of_procs(&mut self) {}
172}
173
174impl<T, D> Procs<D> for T
175where
176 T: FnMut(&mut Proc, &D) + Send + 'static,
177{
178 fn proc(&mut self, context: &mut Proc, data: &D) {
179 (self)(context, data)
180 }
181}
182
/// Context shared by every processor of a [`Stage`].
///
/// It reports how many sequences are still receivable and lets processors
/// spawn futures that are re-polled each time the stage calls `poll_tasks`.
pub struct Proc {
    /// Waker used to build the [`Context`] for polling spawned tasks.
    waker: Waker,

    /// Spawned tasks that returned `Poll::Pending` and await another poll.
    pending_tasks: VecDeque<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,

    /// Sequences that are still receivable; the owning stage advances `start`
    /// as it hands each sequence to the processors.
    receivable_seqs: Range<u64>,
}

impl Proc {
    /// Returns how many sequences remain in the receivable range.
    ///
    /// While a processor runs, the stage has already advanced past the
    /// current sequence, so the current datum is not counted. A reversed
    /// (empty) range reports 0 rather than overflowing.
    pub fn remaining(&self) -> u64 {
        self.receivable_seqs
            .end
            .saturating_sub(self.receivable_seqs.start)
    }

    /// Polls `task` once immediately and queues it only if it is still
    /// pending. Queued tasks are polled again by `poll_tasks`.
    pub fn spawn(&mut self, task: impl Future<Output = ()> + Send + 'static) {
        let mut task = Box::pin(task);
        let mut context = Context::from_waker(&self.waker);
        if task.as_mut().poll(&mut context).is_pending() {
            // `Pin<Box<F>>` unsizes straight into the boxed trait object.
            // Wrapping it in another `Box::pin` would add a second, redundant
            // heap allocation and indirection.
            self.pending_tasks.push_back(task);
        }
    }

    /// Polls every pending task once and drops those that completed.
    fn poll_tasks(&mut self) {
        if self.pending_tasks.is_empty() {
            return;
        }

        let mut context = Context::from_waker(&self.waker);
        self.pending_tasks
            .retain_mut(|task| task.as_mut().poll(&mut context).is_pending());
    }
}

221impl Default for Proc {
222 fn default() -> Self {
223 Self {
224 waker: async_support::noop_waker(),
225 pending_tasks: VecDeque::new(),
226 receivable_seqs: 0..0,
227 }
228 }
229}
230
#[cfg(test)]
mod tests {
    use core::sync::atomic::Ordering;

    use portable_atomic::AtomicU64;
    use portable_atomic_util::Arc;

    use crate::Flow;

    use super::*;

    /// Builds a processor that spawns a task and then checks that no further
    /// sequences remain receivable. The task asserts the received value equals
    /// `expected` and bumps `counter`.
    fn counting_proc(
        expected: u32,
        counter: Arc<AtomicU64>,
    ) -> impl FnMut(&mut Proc, &u32) + Send + 'static {
        move |ctx: &mut Proc, data: &u32| {
            let received = *data;
            let task_counter = counter.clone();
            ctx.spawn(async move {
                assert_eq!(expected, received);
                task_counter.add(1, Ordering::SeqCst);
            });
            assert_eq!(0, ctx.remaining());
        }
    }

    #[test]
    fn dynamic_subscribers() {
        let (mut flow, [subscriber]) = Flow::<u32>::new(32);
        let test_data = 1337;
        let invocations = Arc::new(AtomicU64::new(0));

        // Two identical processors share one invocation counter.
        let mut stage = Stage::from(subscriber);
        for _ in 0..2 {
            stage.add_proc(counting_proc(test_data, invocations.clone()));
        }

        // Nothing has been published yet.
        assert_eq!(Err(Error::Ahead), stage.proc());

        flow.try_next().unwrap().publish(test_data);
        assert_eq!(Ok(1), stage.proc());
        assert_eq!(2, invocations.load(Ordering::SeqCst));
    }
}