use std::{
mem,
pin::Pin,
task::{Context, Poll},
};
use crate::state::State;
use channels::{
io::{AsyncRead, AsyncWrite, IntoRead, IntoWrite},
serdes::Cbor,
};
use crossbeam_queue::{ArrayQueue, SegQueue};
use futures::future::{self, join_all, pending};
use tokio::{
select,
sync::broadcast,
task::{self, JoinHandle},
};
use tokio_util::sync::CancellationToken;
use triomphe::Arc;
use crate::{
error::{BusError, BusSendError, CallSubscribeError},
event::*,
util::CenterErrorIter,
};
/// Fan-in point of the bus: reads events from every external reader and
/// re-broadcasts each one to every in-process receiver created with
/// [`CenterTicker::new_receiver`].
pub struct CenterTicker<R> {
    /// Deserializing receivers, one per external reader.
    pub rx: Vec<channels::Receiver<EventData, R, Cbor>>,
    /// In-process fan-out senders, one per receiver handed out by `new_receiver`.
    pub tx: Vec<tokio::sync::mpsc::UnboundedSender<EventData>>,
}
impl<R> CenterTicker<R>
where
    R: AsyncRead + Unpin,
{
    /// Creates a ticker over `rx` with no fan-out targets yet.
    pub fn new(rx: Vec<channels::Receiver<EventData, R, Cbor>>) -> Self {
        Self { rx, tx: Vec::new() }
    }

    /// Waits for the next event from any reader and forwards it to every fan-out receiver.
    ///
    /// With no readers this future never completes. Parking is required because
    /// `select_ok` panics on an empty set of futures.
    ///
    /// The event is delivered to all receivers before this future resolves. On success
    /// the result is `CenterErrorIter::Left`, yielding one `SendError` per receiver
    /// whose channel has been closed. If every reader failed, the result is
    /// `CenterErrorIter::Right` carrying the converted error.
    pub async fn tick(
        &mut self,
    ) -> CenterErrorIter<impl Iterator<Item = tokio::sync::mpsc::error::SendError<EventData>>, R>
    {
        if self.rx.is_empty() {
            pending::<()>().await;
        }
        let recvs = self.rx.iter_mut().map(|a| Box::pin(a.recv()));
        match future::select_ok(recvs).await {
            Ok((event, remaining)) => {
                // Cancel the receives still in flight on the other readers.
                // NOTE(review): relies on `channels::Receiver::recv` being cancel-safe — confirm.
                drop(remaining);
                // Send eagerly. Delivery must not depend on the caller draining the
                // returned iterator, which only reports failures.
                let failures: Vec<_> = self
                    .tx
                    .iter()
                    .filter_map(|tx| tx.send(event.clone()).err())
                    .collect();
                CenterErrorIter::Left(failures.into_iter())
            }
            Err(e) => CenterErrorIter::Right(Some(e.into())),
        }
    }

    /// Registers a new fan-out target and returns its receiving end.
    pub fn new_receiver(&mut self) -> tokio::sync::mpsc::UnboundedReceiver<EventData> {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        self.tx.push(tx);
        rx
    }
}
/// Delivers events coming from the center ticker to the registered subscribers.
pub struct SubscribeTicker<T: 'static> {
    pub subs: Subscribers<T>,
    pub rx: tokio::sync::mpsc::UnboundedReceiver<EventData>,
    /// Errors produced by subscriber runs spawned from [`SubscribeTicker::try_tick`].
    pub err_queue: Arc<SegQueue<CallSubscribeError>>,
}
impl<T> SubscribeTicker<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Waits for the next event and runs the subscribers on it inline.
    ///
    /// Returns the subscriber errors. Returns nothing when the channel is closed.
    pub async fn tick(
        &mut self,
        state: &State<T>,
    ) -> impl Iterator<Item = CallSubscribeError> + Send + 'static {
        let event = self.rx.recv().await;
        if let Some(event) = event {
            let results = self.subs.emit(state, &event).await;
            Some(results)
        } else {
            None
        }
        .into_iter()
        .flatten()
    }

    /// Non-blocking variant of [`SubscribeTicker::tick`].
    ///
    /// If an event is already queued, the subscribers run on it in a spawned task
    /// whose errors are pushed to `err_queue`. The function then returns the errors
    /// already in `err_queue` when it is called, at most that many. Errors arriving
    /// concurrently are left for a later call.
    pub async fn try_tick(
        &mut self,
        state: &State<T>,
    ) -> impl Iterator<Item = CallSubscribeError> + Send + 'static {
        if let Ok(event) = self.rx.try_recv() {
            let subs = self.subs.clone();
            let state = state.clone();
            let err_queue = self.err_queue.clone();
            tokio::spawn(UnsafeSendFuture(async move {
                let results = subs.emit(&state, &event).await;
                for result in results {
                    err_queue.push(result);
                }
            }));
        }
        // `ArrayQueue::new` panics on a zero capacity. An empty error queue is the
        // common case, so clamp the capacity to at least one slot.
        let drained = ArrayQueue::new(self.err_queue.len().max(1));
        while !drained.is_full() {
            let Some(err) = self.err_queue.pop() else {
                break;
            };
            if let Err(err) = drained.push(err) {
                // Unreachable while this loop is the only producer. Keep the error
                // rather than silently dropping it.
                self.err_queue.push(err);
                break;
            }
        }
        drained.into_iter()
    }
}
/// Offers each incoming event to the shooters stored in `State::event_shooters`.
pub struct ShooterTicker {
    pub rx: tokio::sync::mpsc::UnboundedReceiver<EventData>,
}
impl ShooterTicker {
    /// Waits for the next event and passes it to every stored shooter's `try_dispatch`.
    ///
    /// Each `Some(..)` returned by `try_dispatch` is stored back, in the original
    /// order. A `None` removes that shooter. Returns immediately if the channel is closed.
    pub async fn tick<T>(&mut self, state: &State<T>) {
        let Some(event) = self.rx.recv().await else {
            return;
        };
        task::unconstrained(async {
            let mut shooters = state.event_shooters.lock().await;
            let capacity = shooters.len();
            // Move the current shooters out, leaving a pre-sized empty list behind.
            let previous = mem::replace(&mut *shooters, Vec::with_capacity(capacity));
            shooters.extend(
                previous
                    .into_iter()
                    .filter_map(|shooter| shooter.try_dispatch(&event)),
            );
        })
        .await;
    }
}
/// Fan-out point for locally emitted events: serializes each one to every writer.
pub struct EffectTicker<W> {
    /// Serializing senders, one per external writer.
    pub tx: Vec<channels::Sender<EventData, W, Cbor>>,
    /// Receives the events queued through [`EffectWright::emit`].
    pub state_rx: tokio::sync::mpsc::UnboundedReceiver<EventData>,
}
impl<W> EffectTicker<W>
where
    W: AsyncWrite + Unpin,
{
    /// Waits for the next locally emitted event and sends it to all writers,
    /// polling the sends together via `join_all`.
    ///
    /// Yields one converted error per failed send. Yields nothing when the
    /// channel is closed.
    pub async fn tick(&mut self) -> impl Iterator<Item = BusSendError<W::Error>> {
        let outcome = match self.state_rx.recv().await {
            Some(event) => {
                let senders = &mut self.tx;
                let errors = task::unconstrained(async {
                    let sends = senders.iter_mut().map(|sender| sender.send(event.clone()));
                    join_all(sends)
                        .await
                        .into_iter()
                        .filter_map(Result::err)
                        .map(Into::into)
                })
                .await;
                Some(errors)
            }
            None => None,
        };
        outcome.into_iter().flatten()
    }
}
#[derive(Clone)]
pub struct EffectWright {
pub state_tx: tokio::sync::mpsc::UnboundedSender<EventData>,
}
impl EffectWright {
pub fn emit<E>(&self, event: &E) -> Result<(), CallSubscribeError>
where
E: Event,
{
let event = event.upcast()?;
self.state_tx.send(event)?;
Ok(())
}
}
/// A fully wired bus, produced by [`BusBuilder::build`] and started with [`Bus::run`].
pub struct Bus<T, W, R>
where
    T: 'static + Send + Sync,
    W: AsyncWrite + Unpin,
    R: AsyncRead + Unpin,
{
    pub center_ticker: CenterTicker<R>,
    pub subscribe_ticker: SubscribeTicker<T>,
    pub effect_ticker: EffectTicker<W>,
    pub shooter_ticker: ShooterTicker,
}
impl<T, W, R> Bus<T, W, R>
where
    T: Clone + Send + Sync + 'static,
    W: AsyncWrite + Unpin + 'static,
    R: AsyncRead + Unpin + 'static,
{
    /// Starts the bus.
    ///
    /// The subscribe and shooter tickers are spawned as background tasks immediately.
    /// The center and effect tickers are driven by the future inside the returned
    /// [`CloseHandle`], which must be awaited or spawned for events to flow. Every
    /// error is passed to `handle_error`.
    ///
    /// The background tasks stop when the close signal fires. They also stop if the
    /// returned future is dropped without one.
    pub async fn run<F>(
        self,
        state: State<T>,
        handle_error: &'static F,
    ) -> CloseHandle<impl Future<Output = ()>>
    where
        F: Fn(BusError<W::Error, R::Error>) + Send + Sync + 'static,
    {
        let token = CancellationToken::new();
        let (close_signal, mut close_signal_receiver) = broadcast::channel::<()>(1);
        let Bus {
            mut center_ticker,
            mut subscribe_ticker,
            mut effect_ticker,
            mut shooter_ticker,
            ..
        } = self;

        // The subscriber loop is spawned through `UnsafeSendFuture`, so its
        // future is not checked for `Send`.
        let state_clone = state.clone();
        let token_clone = token.clone();
        let handle_subscribe_ticker = tokio::spawn(UnsafeSendFuture(async move {
            loop {
                if token_clone.is_cancelled() {
                    break;
                }
                let error = subscribe_ticker.tick(&state_clone).await;
                error.map(|e| e.into()).for_each(handle_error);
            }
        }));

        let state_clone = state.clone();
        let token_clone = token.clone();
        let handle_shooter_ticker = tokio::spawn(async move {
            loop {
                if token_clone.is_cancelled() {
                    break;
                }
                shooter_ticker.tick(&state_clone).await;
            }
        });

        // Cancel the token if the driving future is dropped without a close signal.
        // Dropping the future also drops the center ticker's senders, so the spawned
        // loops' `recv` calls return `None` at once. Without the cancellation they
        // would spin forever.
        let cancel_on_drop = token.clone().drop_guard();
        let future = async move {
            let _cancel_on_drop = cancel_on_drop;
            loop {
                select! {
                    errors = effect_ticker.tick() => {
                        errors.map(|e| e.into()).for_each(handle_error);
                    }
                    errors = center_ticker.tick() => {
                        errors.map(|e| e.into()).for_each(handle_error);
                    }
                    // Any result counts, including `Closed` when every `CloseSignal` was dropped.
                    _ = close_signal_receiver.recv() => {
                        token.cancel();
                        handle_subscribe_ticker.abort();
                        handle_shooter_ticker.abort();
                        break;
                    }
                }
            }
        };
        CloseHandle {
            close_signal: CloseSignal(close_signal),
            future,
        }
    }
}
/// Sending half of a bus's shutdown channel.
pub struct CloseSignal(broadcast::Sender<()>);
impl CloseSignal {
    /// Asks the bus to shut down.
    ///
    /// If the bus has already stopped (its future finished or was dropped, so
    /// nothing is listening), this is a no-op instead of a panic.
    pub fn close(self) {
        // `broadcast::Sender::send` fails only when there are no receivers left,
        // which means the bus is already gone.
        let _ = self.0.send(());
    }
}
/// A started bus: the future that drives it plus the signal that stops it.
pub struct CloseHandle<F>
where
    F: Future<Output = ()>,
{
    pub close_signal: CloseSignal,
    pub future: F,
}
impl<F> CloseHandle<F>
where
    F: Future<Output = ()> + Send + 'static,
{
    /// Sends the close signal, then drives the bus future to completion.
    pub async fn close(self) {
        let Self {
            close_signal,
            future,
        } = self;
        close_signal.close();
        future.await;
    }

    /// Drives the bus future to completion without sending a close signal.
    pub async fn join(self) {
        // Only `future` is moved out. `close_signal` stays alive until this returns,
        // because dropping it closes the channel that the bus listens on for shutdown.
        self.future.await;
    }

    /// Spawns the bus future onto the tokio runtime.
    ///
    /// Returns the task handle and the signal that stops the bus.
    pub fn spawn(self) -> (JoinHandle<()>, CloseSignal) {
        let Self {
            close_signal,
            future,
        } = self;
        let handle = tokio::spawn(future);
        (handle, close_signal)
    }
}
/// A reader/writer pair that can be attached to a [`BusBuilder`] in one call.
pub struct IoPair<IR, IW> {
    pub reader: IR,
    pub writer: IW,
}
impl IoPair<tokio::io::Stdin, tokio::io::Stdout> {
    /// Pairs this process's stdin (reader) with its stdout (writer).
    pub fn stdio() -> Self {
        let reader = tokio::io::stdin();
        let writer = tokio::io::stdout();
        Self { reader, writer }
    }
}
/// Takes a spawned child's stdout as the reader and its stdin as the writer.
///
/// Fails with `()` unless both pipes are present. The `Child` handle itself is
/// dropped once its pipes have been taken.
impl TryFrom<tokio::process::Child>
    for IoPair<tokio::process::ChildStdout, tokio::process::ChildStdin>
{
    type Error = ();
    fn try_from(mut value: tokio::process::Child) -> Result<Self, Self::Error> {
        let writer = value.stdin.take();
        let reader = value.stdout.take();
        match (reader, writer) {
            (Some(reader), Some(writer)) => Ok(Self { reader, writer }),
            _ => Err(()),
        }
    }
}
/// Spawns the command with piped stdin/stdout and returns those pipes as an [`IoPair`].
///
/// # Errors
/// Returns the spawn error. Also returns a descriptive `io::Error` if the spawned
/// child unexpectedly exposes no stdin/stdout pipe.
impl TryFrom<tokio::process::Command>
    for IoPair<tokio::process::ChildStdout, tokio::process::ChildStdin>
{
    type Error = std::io::Error;
    fn try_from(mut value: tokio::process::Command) -> Result<Self, Self::Error> {
        value
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped());
        let child = value.spawn()?;
        // A missing pipe is not an OS failure. `last_os_error()` would report an
        // unrelated or stale errno, so build an explicit error instead.
        child.try_into().map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::Other,
                "spawned child process is missing its piped stdin/stdout",
            )
        })
    }
}
/// Collects subscribers, readers and writers, then assembles a [`Bus`] plus the
/// [`EffectWright`] used to emit events into it.
pub struct BusBuilder<T, W, R>
where
    T: 'static + Send + Sync,
    W: AsyncWrite + Unpin,
    R: AsyncRead + Unpin,
{
    subs: Subscribers<T>,
    /// Deserializing receivers, one per added reader.
    rx: Vec<channels::Receiver<EventData, R, Cbor>>,
    /// Serializing senders, one per added writer.
    tx: Vec<channels::Sender<EventData, W, Cbor>>,
    /// Both ends of the local-emit channel. Split in `build` between the effect
    /// ticker and the `EffectWright`.
    state_rx: tokio::sync::mpsc::UnboundedReceiver<EventData>,
    state_tx: tokio::sync::mpsc::UnboundedSender<EventData>,
}
impl<T, W, R> BusBuilder<T, W, R>
where
    T: 'static + Send + Sync,
    W: AsyncWrite + Unpin,
    R: AsyncRead + Unpin,
{
    /// Starts a builder with the given subscribers and no readers or writers.
    pub fn new(subscribes: Subscribers<T>) -> Self {
        let (state_tx, state_rx) = tokio::sync::mpsc::unbounded_channel();
        Self {
            subs: subscribes,
            rx: Vec::new(),
            tx: Vec::new(),
            state_rx,
            state_tx,
        }
    }

    /// Adds a CBOR-decoding receiver over `reader`.
    pub fn add_reader<IR>(&mut self, reader: IR) -> &mut Self
    where
        IR: IntoRead<R>,
    {
        let receiver = channels::Receiver::<EventData, _, _>::builder()
            .reader(reader)
            .deserializer(Cbor::new())
            .build();
        self.rx.push(receiver);
        self
    }

    /// Adds a CBOR-encoding sender over `writer`.
    pub fn add_sender<IW>(&mut self, writer: IW) -> &mut Self
    where
        IW: IntoWrite<W>,
    {
        let sender = channels::Sender::<EventData, _, _>::builder()
            .writer(writer)
            .serializer(Cbor::new())
            .build();
        self.tx.push(sender);
        self
    }

    /// Adds both halves of `pair`: the reader first, then the writer.
    pub fn add_pair<IR, IW>(&mut self, pair: IoPair<IR, IW>) -> &mut Self
    where
        IR: IntoRead<R>,
        IW: IntoWrite<W>,
    {
        let IoPair { reader, writer } = pair;
        self.add_reader(reader).add_sender(writer)
    }

    /// Wires all tickers together.
    ///
    /// The center ticker gets two fan-out receivers: the first feeds the subscribe
    /// ticker and the second feeds the shooter ticker.
    pub fn build(self) -> (Bus<T, W, R>, EffectWright) {
        let Self {
            subs,
            rx,
            tx,
            state_rx,
            state_tx,
        } = self;
        let mut center_ticker = CenterTicker::new(rx);
        let subscribe_rx = center_ticker.new_receiver();
        let shooter_rx = center_ticker.new_receiver();
        let bus = Bus {
            center_ticker,
            shooter_ticker: ShooterTicker { rx: shooter_rx },
            subscribe_ticker: SubscribeTicker {
                subs,
                rx: subscribe_rx,
                err_queue: Arc::new(SegQueue::new()),
            },
            effect_ticker: EffectTicker { tx, state_rx },
        };
        (bus, EffectWright { state_tx })
    }
}
/// Wrapper that makes any future satisfy `Send`, so it can be passed to `tokio::spawn`.
///
/// NOTE(review): the `unsafe impl Send` below is unconditional. It applies to every
/// `F`, including futures that hold non-`Send` state (`Rc`, `RefCell` borrows, ...).
/// On a multi-threaded tokio runtime a spawned task may be polled or dropped on a
/// different worker thread, which is undefined behavior for such futures. This is
/// only sound if every wrapped future is actually `Send`, or is only ever polled and
/// dropped on a single thread. TODO confirm which of these the call sites rely on.
struct UnsafeSendFuture<F: Future>(F);
impl<F: Future> Future for UnsafeSendFuture<F> {
type Output = F::Output;
/// Forwards polling to the wrapped future.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// SAFETY: structural pin projection to the only field. The field is never moved
// out of a pinned `Self`, and this file gives `UnsafeSendFuture` no `Drop` or
// manual `Unpin` impl that could move it.
let inner_future = unsafe { self.map_unchecked_mut(|s| &mut s.0) };
inner_future.poll(cx)
}
}
// SAFETY: NOT sound in general. This asserts `Send` for every `F`, so see the
// type-level note above for the conditions callers must uphold.
unsafe impl<F: Future> Send for UnsafeSendFuture<F> {}