use core::{
future::Future,
ops::Range,
pin::Pin,
task::{Context, Waker},
};
use alloc::{boxed::Box, collections::VecDeque, vec::Vec};
use codas::types::TryAsFormat;
use crate::{async_support, Error, Flow, FlowSubscriber, Flows};
/// Boxed, type-erased processor closure run by a [`Stage`] on each datum it processes.
type ErasedProc<T> = Box<dyn FnMut(&mut Proc, &T) + Send + 'static>;

/// A processing stage that reads data from a flow subscription and runs every registered
/// processor on each datum.
pub struct Stage<T: Flows> {
    /// Subscription the stage reads data from.
    subscriber: FlowSubscriber<T>,
    /// Processors registered through `add_proc`, stored in registration order.
    processors: Vec<ErasedProc<T>>,
    /// Context handed to every processor (receivable range and spawned tasks).
    context: Proc,
    /// Batch-size limit consulted by `proc` on each call.
    max_procs_per_batch: usize,
}
impl<T: Flows> Stage<T> {
pub fn flow(&self) -> Flow<T> {
Flow {
state: self.subscriber.flow_state.clone(),
}
}
pub fn add_proc<D>(&mut self, mut proc: impl Procs<D>)
where
T: TryAsFormat<D>,
{
let proc = move |context: &mut Proc, data: &T| {
if let Ok(data) = data.try_as_format() {
proc.proc(context, data);
}
if context.remaining() == 0 {
proc.end_of_procs();
}
};
self.processors.push(Box::new(proc));
}
pub fn proc(&mut self) -> Result<u64, Error> {
let receivable_seqs = self.subscriber.receivable_seqs();
assert_eq!(receivable_seqs.start, self.context.receivable_seqs.start);
self.context.receivable_seqs = receivable_seqs;
let first_receivable = self.context.receivable_seqs.start;
let last_receivable = first_receivable + self.max_procs_per_batch as u64;
let mut last_received = None;
while let Some(next) = self.context.receivable_seqs.next() {
last_received = Some(next);
let data = unsafe { self.subscriber.flow_state.get(next) };
for proc in &mut self.processors {
(proc)(&mut self.context, data)
}
if next >= last_receivable {
break;
}
}
self.context.poll_tasks();
if let Some(last) = last_received {
self.subscriber.receive_up_to(last);
Ok(self.context.receivable_seqs.start - first_receivable)
} else {
Err(Error::Ahead)
}
}
pub async fn proc_loop(mut self) {
loop {
if self.proc().is_err() {
async_support::yield_now().await;
}
}
}
pub async fn proc_loop_with_waiter<W, Fut>(mut self, waiter: W)
where
W: Fn() -> Fut,
Fut: Future<Output = ()>,
{
loop {
if self.proc().is_err() {
waiter().await;
}
}
}
}
impl<T: Flows> From<FlowSubscriber<T>> for Stage<T> {
    /// Wraps `subscriber` in a stage with no processors and a default [`Proc`] context.
    ///
    /// The batch limit is one quarter of the flow's buffer length (integer division).
    fn from(subscriber: FlowSubscriber<T>) -> Self {
        let quarter_of_buffer = subscriber.flow_state.buffer.len() / 4;
        Self {
            max_procs_per_batch: quarter_of_buffer,
            processors: Vec::new(),
            context: Proc::default(),
            subscriber,
        }
    }
}
/// A processor of data of type `D`, registered on a [`Stage`] via [`Stage::add_proc`].
pub trait Procs<D>
where
    Self: Send + 'static,
{
    /// Processes one datum.
    ///
    /// `context` exposes the current batch state (e.g. [`Proc::remaining`]) and task
    /// spawning ([`Proc::spawn`]).
    fn proc(&mut self, context: &mut Proc, data: &D);

    /// Hook the stage calls when no receivable data remain after a datum. Does nothing by
    /// default.
    #[inline(always)]
    fn end_of_procs(&mut self) {}
}
/// Lets any `FnMut(&mut Proc, &D)` closure be used directly as a [`Procs`] processor.
///
/// `end_of_procs` keeps the trait's default no-op behavior.
impl<D, F: FnMut(&mut Proc, &D) + Send + 'static> Procs<D> for F {
    fn proc(&mut self, context: &mut Proc, data: &D) {
        (*self)(context, data)
    }
}
/// A task spawned via [`Proc::spawn`] that has not completed yet.
type PendingTask = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;

/// Context handed to every processor while a [`Stage`] processes data.
pub struct Proc {
    /// Waker placed in the `Context` used to poll spawned tasks.
    waker: Waker,
    /// Spawned tasks that returned `Pending` and must be polled again.
    pending_tasks: VecDeque<PendingTask>,
    /// Sequences in the current receivable range that have not been handed to processors yet.
    receivable_seqs: Range<u64>,
}
impl Proc {
pub fn remaining(&self) -> u64 {
self.receivable_seqs.end - self.receivable_seqs.start
}
pub fn spawn(&mut self, task: impl Future<Output = ()> + Send + 'static) {
let mut context = Context::from_waker(&self.waker);
let mut pinned = Box::pin(task);
if pinned.as_mut().poll(&mut context).is_pending() {
self.pending_tasks.push_back(Box::pin(pinned));
}
}
fn poll_tasks(&mut self) {
if !self.pending_tasks.is_empty() {
let mut context = Context::from_waker(&self.waker);
self.pending_tasks
.retain_mut(|future| future.as_mut().poll(&mut context).is_pending());
}
}
}
impl Default for Proc {
fn default() -> Self {
Self {
waker: async_support::noop_waker(),
pending_tasks: VecDeque::new(),
receivable_seqs: 0..0,
}
}
}
#[cfg(test)]
mod tests {
    use core::sync::atomic::Ordering;
    use portable_atomic::AtomicU64;
    use portable_atomic_util::Arc;
    use crate::Flow;
    use super::*;

    /// Value published into the flow and expected by every processor.
    const TEST_DATA: u32 = 1337;

    /// Builds a processor that, for each datum, does two things.
    ///
    /// It spawns a task that checks the datum equals [`TEST_DATA`] and increments `counter`.
    /// It then asserts that no further data remain in the current batch.
    fn counting_proc(counter: Arc<AtomicU64>) -> impl FnMut(&mut Proc, &u32) + Send + 'static {
        move |proc: &mut Proc, data: &u32| {
            let seen = *data;
            let counter = counter.clone();
            proc.spawn(async move {
                assert_eq!(TEST_DATA, seen);
                counter.add(1, Ordering::SeqCst);
            });
            assert_eq!(0, proc.remaining());
        }
    }

    #[test]
    fn dynamic_subscribers() {
        let (mut flow, [subscriber]) = Flow::<u32>::new(32);
        let invocations = Arc::new(AtomicU64::new(0));

        let mut stage = Stage::from(subscriber);
        stage.add_proc(counting_proc(invocations.clone()));
        stage.add_proc(counting_proc(invocations.clone()));

        // Nothing has been published yet, so there is nothing to process.
        assert_eq!(Err(Error::Ahead), stage.proc());

        flow.try_next().unwrap().publish(TEST_DATA);

        // One datum is processed, and both processors' tasks ran to completion.
        assert_eq!(Ok(1), stage.proc());
        assert_eq!(2, invocations.load(Ordering::SeqCst));
    }
}