mod async_iter;
mod builder;
mod chan;
mod error;
mod handler;
mod message;
mod rand;
mod reorder_queue;
mod task;
use std::{
any::{Any, TypeId},
collections::HashMap,
marker::PhantomData,
sync::{
atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering},
Arc,
},
};
use builder::Config;
use chan::{channel, AbstractSender, BusSender, BusSenderClose, Receiver, Sender};
use dashmap::DashMap;
use futures::Future;
use message::Msg;
use rand::RndGen;
use reorder_queue::{QueueItem, ReorderQueueInner};
use task::{TaskCounter, TaskSpawnerWrapper};
use tokio::sync::{Notify, RwLock};
pub use async_iter::*;
pub use builder::{Builder, DefaultBuilder, SharedBuilder};
pub use error::{Error, VoidError};
pub use handler::{Context, Handler};
pub use message::{async_iter, ErrorMessage, IntoMessages, Message};
/// Stream id used by [`Bus::send`] when the caller does not pick a stream explicitly.
pub const DEFAULT_STREAM_ID: u32 = u32::MAX;
/// Task id used when messages are not spread across per-task queues
/// (no `queue_per_task`, or a single task).
pub const DEFAULT_TASK_ID: u32 = 0;
/// Misspelled alias of [`DEFAULT_STREAM_ID`], kept so existing callers keep compiling.
pub const DEFAUL_STREAM_ID: u32 = DEFAULT_STREAM_ID;
/// Misspelled alias of [`DEFAULT_TASK_ID`], kept so existing callers keep compiling.
pub const DEFAUL_TASK_ID: u32 = DEFAULT_TASK_ID;
/// Shared state behind every `Bus` handle.
///
/// NOTE(review): fields are dropped in declaration order; this may matter during
/// shutdown, so review before reordering them.
#[derive(Default)]
struct BusInner {
// Worker-task senders keyed by `(stream_id, task_id, TypeId of the message)`.
senders: DashMap<(u32, u32, TypeId), Arc<dyn BusSenderClose>>,
// Registered task spawners, keyed by message `TypeId`; each value is a boxed
// `TaskSpawnerWrapper<M>` that is recovered with `downcast_ref`.
spawners: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
// Per-`(stream_id, message TypeId)` input channel of the reordering task spawned by
// `send_ordered`.
reordering: DashMap<(u32, TypeId), Arc<dyn AbstractSender>>,
// Per-`(stream_id, message TypeId)` counters used to assign message indices.
counters: DashMap<(u32, TypeId), Arc<AtomicU64>>,
// Source of ids handed out by `next_stream_id`.
stream_id_seq: AtomicU32,
// Notified on forced close; a clone is passed to every spawned task.
abort_notify: Arc<Notify>,
// Counters passed to spawned tasks and awaited by `wait`.
task_counter: Arc<TaskCounter>,
spawn_counter: Arc<TaskCounter>,
// Set by `close`; `wait` blocks until it becomes true.
stopping: AtomicBool,
// Wakes `wait` after `stopping` has been set.
stopping_notify: Arc<Notify>,
// Random source used by `get_task_id` to pick candidate tasks.
rng: RndGen,
}
impl BusInner {
    /// Returns a fresh stream id from the `stream_id_seq` counter.
    #[inline]
    pub fn next_stream_id(&self) -> u32 {
        self.stream_id_seq.fetch_add(1, Ordering::Relaxed)
    }

    /// Picks the task that should receive the next message of type `M` on `stream_id`.
    ///
    /// Without per-task queues (or with a single task) every message goes to
    /// `DEFAUL_TASK_ID`. Otherwise two candidate ids are drawn from `rng`, and the one whose
    /// sender reports the smaller first `load()` component wins. A candidate that has no
    /// sender yet is returned immediately, so that `send_inner` spawns it.
    fn get_task_id<M: Message>(&self, stream_id: u32, config: &Config) -> u32 {
        if !config.queue_per_task || config.task_count == 1 {
            return DEFAUL_TASK_ID;
        }
        let type_id = TypeId::of::<M>();
        let (id1, id2) = self.rng.next_u32_pair(config.task_count);
        let Some(l1) = self
            .senders
            .get(&(stream_id, id1, type_id))
            .as_deref()
            .map(|x| x.load())
        else {
            return id1;
        };
        let Some(l2) = self
            .senders
            .get(&(stream_id, id2, type_id))
            .as_deref()
            .map(|x| x.load())
        else {
            return id2;
        };
        if l1.0 < l2.0 {
            id1
        } else {
            id2
        }
    }

    /// Spawns one worker task for `M` through `spawner` on the given channel.
    ///
    /// The resulting sender is recorded under `(stream_id, task_id, TypeId::of::<M>())`,
    /// replacing any previous entry.
    ///
    /// # Errors
    /// Propagates any error returned by `TaskSpawnerWrapper::spawn_task`.
    async fn spawn_task<M: Message>(
        self: &Arc<Self>,
        spawner: &TaskSpawnerWrapper<M>,
        chan: (Sender<Msg<M>>, Receiver<Msg<M>>),
        stream_id: u32,
        task_id: u32,
    ) -> Result<(), Error> {
        let type_id = TypeId::of::<M>();
        let sender = spawner
            .spawn_task(
                chan,
                stream_id,
                task_id,
                self.abort_notify.clone(),
                self.task_counter.clone(),
                self.spawn_counter.clone(),
                self.counters
                    .entry((stream_id, type_id))
                    .or_default()
                    .value()
                    .clone(),
                self.clone(),
            )
            .await?;
        self.senders
            .insert((stream_id, task_id, type_id), Arc::new(sender) as _);
        Ok(())
    }

    /// Routes a message to the handler registered for `M`.
    ///
    /// Messages of type `()` are dropped silently. When `reorder_buff > 1` the message goes
    /// through a per-stream reordering task (`send_ordered`); otherwise it is delivered
    /// directly (`send_inner`). When no spawner is registered, `Config::default()` is used
    /// and `send_inner` reports `Error::HandlerIsNotRegistered`.
    pub async fn send<M: Message>(
        self: &Arc<Self>,
        msg: Option<Result<M, Error>>,
        index: u64,
        stream_id: u32,
        reorder_buff: u32,
    ) -> Result<(), Error> {
        let type_id = TypeId::of::<M>();
        if type_id == TypeId::of::<()>() {
            return Ok(());
        }
        let config = if let Some(spawner) = self
            .spawners
            .read()
            .await
            .get(&type_id)
            .and_then(|x| x.downcast_ref::<TaskSpawnerWrapper<M>>())
        {
            spawner.config(stream_id)
        } else {
            Config::default()
        };
        if reorder_buff > 1 {
            self.send_ordered(msg, index, stream_id, config, reorder_buff)
                .await
        } else {
            self.send_inner(msg, index, stream_id, config).await
        }
    }

    /// Delivers a message directly to a worker task, spawning the tasks on first use.
    ///
    /// # Errors
    /// `Error::HandlerIsNotRegistered` if no spawner exists for `M`, or any error from
    /// spawning the worker tasks.
    pub async fn send_inner<M: Message>(
        self: &Arc<Self>,
        msg: Option<Result<M, Error>>,
        index: u64,
        stream_id: u32,
        config: Config,
    ) -> Result<(), Error> {
        let type_id = TypeId::of::<M>();
        let task_id = self.get_task_id::<M>(stream_id, &config);
        // NOTE(review): this check and the spawn below are not atomic. Two concurrent first
        // sends for the same key can both spawn, and the later `insert` replaces the earlier
        // sender.
        if !self.senders.contains_key(&(stream_id, task_id, type_id)) {
            let spawner = if let Some(spawner) = self.spawners.read().await.get(&type_id) {
                // Invariant: `register` stores a `TaskSpawnerWrapper<M>` under `TypeId::of::<M>()`.
                spawner
                    .downcast_ref::<TaskSpawnerWrapper<M>>()
                    .unwrap()
                    .clone()
            } else {
                return Err(Error::HandlerIsNotRegistered);
            };
            if config.queue_per_task {
                if config.lazy_task_creation {
                    // Only the task chosen by `get_task_id` is created now.
                    self.spawn_task(
                        &spawner,
                        channel::<Msg<M>>(config.queue_size),
                        stream_id,
                        task_id,
                    )
                    .await?;
                } else {
                    // Every task gets its own queue, all created eagerly.
                    for tid in 0..config.task_count {
                        self.spawn_task(
                            &spawner,
                            channel::<Msg<M>>(config.queue_size),
                            stream_id,
                            tid,
                        )
                        .await?;
                    }
                }
            } else {
                // All tasks share one channel; the last spawn takes ownership of the originals.
                let (tx, rx) = channel::<Msg<M>>(config.queue_size);
                for tid in 0..config.task_count - 1 {
                    self.spawn_task(&spawner, (tx.clone(), rx.clone()), stream_id, tid)
                        .await?;
                }
                self.spawn_task(&spawner, (tx, rx), stream_id, config.task_count - 1)
                    .await?;
            }
        }
        let senders = self
            .senders
            .get(&(stream_id, task_id, type_id))
            .unwrap()
            .clone();
        // NOTE(review): a failed send (e.g. the worker already stopped) panics here. The
        // `BusSender` error type is not visible from this file, so it is not mapped to `Error`.
        senders
            .upcast()
            .downcast_ref::<BusSender<M>>()
            .unwrap()
            .send(Msg {
                inner: msg,
                index,
                stream_id,
            })
            .await
            .unwrap();
        Ok(())
    }

    /// Sends a message through the per-`(stream_id, M)` reordering task, creating it on
    /// first use.
    ///
    /// The reordering task buffers up to `reorder_buff` messages in a `ReorderQueueInner`.
    /// It forwards them to `send_inner` in index order and reports skipped indices as
    /// `Error::ReorderingDropMessage`. Forwarding failures are logged, so one failure does
    /// not kill the task (and with it ordering for the whole stream).
    ///
    /// # Errors
    /// `Error::SendError` if the reordering task's channel is closed.
    pub async fn send_ordered<M: Message>(
        self: &Arc<Self>,
        msg: Option<Result<M, Error>>,
        index: u64,
        stream_id: u32,
        config: Config,
        reorder_buff: u32,
    ) -> Result<(), Error> {
        let type_id = TypeId::of::<M>();
        let sender = self
            .reordering
            .entry((stream_id, type_id))
            .or_insert_with(|| {
                let (tx, rx) = kanal::bounded_async::<Msg<M>>(4);
                let bus = self.clone();
                tokio::spawn(async move {
                    let mut queue = ReorderQueueInner::new(reorder_buff as _, 0);
                    while let Ok(msg) = rx.recv().await {
                        if let Some(inner) = msg.inner {
                            if let Some(item) = queue.push(msg.index, inner) {
                                let (index, payload) = match item {
                                    QueueItem::Item(index, msg) => (index, msg),
                                    QueueItem::Drop(index, _, expected) => (
                                        index,
                                        Err::<M, _>(Error::ReorderingDropMessage(
                                            index, expected,
                                        )),
                                    ),
                                };
                                if let Err(err) = bus
                                    .send_inner(Some(payload), index, stream_id, config)
                                    .await
                                {
                                    println!("Err: {}", err);
                                }
                            }
                        }
                        while let Some((index, msg)) = queue.try_pop() {
                            if let Err(err) =
                                bus.send_inner(Some(msg), index, stream_id, config).await
                            {
                                println!("Err: {}", err);
                            }
                        }
                    }
                    // Input channel closed: flush whatever is still buffered.
                    while let Some((index, msg)) = queue.force_pop() {
                        if let Err(err) = bus.send_inner(Some(msg), index, stream_id, config).await
                        {
                            println!("Err: {}", err);
                        }
                    }
                });
                Arc::new(tx) as _
            })
            .downgrade()
            .clone();
        sender
            .upcast()
            .downcast_ref::<kanal::AsyncSender<Msg<M>>>()
            .unwrap()
            .send(Msg {
                inner: msg,
                index,
                stream_id,
            })
            .await
            .map_err(Error::SendError)
    }

    /// Registers `builder` as the task spawner for messages of type `M`, replacing any
    /// previous registration for `M`.
    pub async fn register<M: Message, B: Builder<M>>(self: Arc<Self>, builder: B)
    where
        B::Context: Handler<M>,
    {
        let type_id = TypeId::of::<M>();
        self.spawners.write().await.insert(
            type_id,
            Box::new(TaskSpawnerWrapper::from_handler(builder)) as _,
        );
    }

    /// Marks the bus as stopping and wakes `wait`.
    ///
    /// With `force`, `abort_notify` waiters (the spawned tasks) are notified first.
    #[inline]
    pub async fn close(&self, force: bool) {
        if force {
            self.abort_notify.notify_waiters();
        }
        // SeqCst pairs with the load in `wait`.
        self.stopping.store(true, Ordering::SeqCst);
        self.stopping_notify.notify_waiters();
    }

    /// Waits until `close` has been called and the bus has drained.
    ///
    /// The bus counts as drained when `task_counter` settles and no sender reports pending
    /// load. Reordering channels are closed along the way. Every worker is then stopped, and
    /// the method waits for `spawn_counter`.
    #[inline]
    pub async fn wait(&self) {
        // Create the `Notified` future *before* checking the flag. It then observes any
        // `notify_waiters` issued after the check; checking first could miss a `close()`
        // that lands in between and wait forever.
        loop {
            let notified = self.stopping_notify.notified();
            if self.stopping.load(Ordering::SeqCst) {
                break;
            }
            notified.await;
        }
        loop {
            self.task_counter.wait().await;
            for queue in self.reordering.iter() {
                queue.value().close();
            }
            if self.check_stopped() {
                break;
            }
        }
        // Snapshot the senders so no DashMap shard lock is held across `.await`. Otherwise a
        // concurrent `senders.insert` (e.g. from a finalizing handler) could deadlock.
        let senders: Vec<_> = self
            .senders
            .iter()
            .map(|entry| Arc::clone(entry.value()))
            .collect();
        for sender in senders {
            let _ = sender.stop().await;
        }
        self.spawn_counter.wait().await;
    }

    /// Returns `true` when no registered sender reports pending load (`load().0 > 0`).
    fn check_stopped(&self) -> bool {
        !self
            .senders
            .iter()
            .any(|entry| entry.value().load().0 > 0)
    }
}
/// Cheaply clonable handle to a message bus.
///
/// Every clone shares the same inner state through an [`Arc`]. Handlers registered or
/// messages sent through one clone are therefore seen by all of them.
#[derive(Clone, Default)]
pub struct Bus {
    inner: Arc<BusInner>,
}
impl Bus {
pub fn new() -> Self {
Self {
inner: Arc::new(BusInner::default()),
}
}
#[inline]
pub async fn register<M: Message, B: Builder<M>>(&self, builder: B) -> &Self
where
B::Context: Handler<M>,
{
self.inner.clone().register(builder).await;
self
}
#[inline]
pub async fn register_mapper<
M: Message,
R: Message,
E: ErrorMessage,
C: Send + Clone + Sync + FnMut(u32, u32, M) -> Result<R, E> + 'static,
>(
&self,
cb: C,
) -> &Self {
let mapper = DefaultBuilder::new(0, move |_, _| {
let cb = cb.clone();
async move { Ok(Mapper { cb, m: PhantomData }) }
});
self.inner.clone().register(mapper).await;
self
}
#[inline]
pub async fn send<M: Message>(&self, inner: M) -> Result<(), Error> {
self.send_with_stream(DEFAUL_STREAM_ID, inner).await
}
#[inline]
pub async fn send_with_stream<M: Message>(
&self,
stream_id: u32,
inner: M,
) -> Result<(), Error> {
let type_id = TypeId::of::<M>();
let index = self
.inner
.counters
.entry((stream_id, type_id))
.or_default()
.value()
.fetch_add(1, Ordering::Relaxed);
self.inner.send(Some(Ok(inner)), index, stream_id, 0).await
}
#[inline]
pub async fn shutdown(&self) {
self.inner.close(true).await;
}
#[inline]
pub async fn close(&self) {
self.inner.close(false).await;
}
#[inline]
pub fn wait(&self) -> impl Future<Output = ()> + '_ {
self.inner.wait()
}
}
/// Handler created by `Bus::register_mapper`.
///
/// It wraps a synchronous callback that turns each incoming `M` into an `R`, or into an
/// `E` error.
struct Mapper<M, R, E, C> {
    // Ties the handler to its input, output and error types without storing any of them.
    m: PhantomData<(M, R, E)>,
    // The user callback, invoked as `cb(stream_id, task_id, msg)`.
    cb: C,
}
/// Adapts a `FnMut(stream_id, task_id, msg) -> Result<R, E>` callback to the `Handler` trait.
///
/// Each successful call yields exactly one `R`. Incoming errors are swallowed
/// (`handle_error` yields nothing), and finalizing is a no-op.
impl<M, R, E, C> Handler<M> for Mapper<M, R, E, C>
where
    M: Message,
    R: Message,
    E: ErrorMessage,
    C: Send + Sync + FnMut(u32, u32, M) -> Result<R, E> + 'static,
{
    type Result = R;
    type Error = E;

    async fn handle(
        &mut self,
        msg: M,
        ctx: Context,
        _bus: Bus,
    ) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
        // Wrap the single result in a one-element array so it is emitted as one message.
        (self.cb)(ctx.stream_id, ctx.task_id, msg).map(|x| [x])
    }

    async fn handle_error(
        &mut self,
        _err: Error,
        _ctx: Context,
        _bus: Bus,
    ) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
        Ok(None)
    }

    async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> {
        Ok(())
    }
}
// Tests for the bus. `test_streams` and `test_tasks_shared` are `#[ignore]`d and only print.
// `test_reordering` checks delivery order through `TestCollector::finalize`.
#[cfg(test)]
mod tests {
use std::{sync::Arc, time::Duration};
use async_stream::stream;
use rand::RngCore;
use crate::{
handler::Context, stream, Bus, DefaultBuilder, Error, Handler, IntoMessages, Message,
SharedBuilder,
};
// Mark the primitive types used by the tests below as bus messages.
impl Message for u64 {}
impl Message for u32 {}
impl Message for i16 {}
impl Message for u16 {}
// Answers every `u32` with a stream of the ten `u64` values 0..10.
#[derive(Default)]
struct TestProducer;
impl Handler<u32> for TestProducer {
type Result = u64;
type Error = anyhow::Error;
async fn handle(
&mut self,
_msg: u32,
_ctx: Context,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
Ok(stream(stream! {
for i in 0u64..10 {
yield Ok(i)
}
}))
}
async fn handle_error(
&mut self,
_err: Error,
_ctx: Context,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
Ok(None)
}
async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> {
println!("producer finalized");
Ok(())
}
}
// Consumer tagged with a random id (truncated to `u16`), so output from different
// instances can be told apart.
struct TestConsumer(u16);
impl Default for TestConsumer {
fn default() -> Self {
Self(rand::thread_rng().next_u32() as _)
}
}
// Shared (`Arc`) variant of the consumer: sleeps 1s per message, then prints it.
impl Handler<u64> for Arc<TestConsumer> {
type Result = ();
type Error = anyhow::Error;
async fn handle(
&mut self,
msg: u64,
ctx: Context,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
tokio::time::sleep(Duration::from_millis(1000)).await;
println!(
"[{}] shared consumer handle {}u64 ({}:{})",
self.0, msg, ctx.stream_id, ctx.task_id
);
Ok(())
}
async fn handle_error(
&mut self,
_err: Error,
_ctx: Context,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
Ok(None)
}
async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> {
println!("[{}] shared consumer finalized", self.0);
Ok(())
}
}
// Owned variant of the consumer: sleeps 100ms per message, then prints it.
impl Handler<u64> for TestConsumer {
type Result = ();
type Error = anyhow::Error;
async fn handle(
&mut self,
msg: u64,
ctx: Context,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
tokio::time::sleep(Duration::from_millis(100)).await;
println!(
"[{}] consumer handle {}u64 ({}:{})",
self.0, msg, ctx.stream_id, ctx.task_id
);
Ok(())
}
async fn handle_error(
&mut self,
_err: Error,
_ctx: Context,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
Ok(None)
}
async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> {
println!("[{}] consumer finalized", self.0);
Ok(())
}
}
// Converts each `i16` to a `u16`. It sleeps 13ms on even task ids and 22ms on odd ones, so
// results from different tasks finish out of order.
struct TestHandler {}
impl Handler<i16> for Arc<TestHandler> {
type Result = u16;
type Error = anyhow::Error;
async fn handle(
&mut self,
msg: i16,
ctx: Context,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
if ctx.task_id % 2 == 0 {
tokio::time::sleep(Duration::from_millis(13)).await;
} else {
tokio::time::sleep(Duration::from_millis(22)).await;
}
println!("handle {}", msg);
Ok([msg as u16])
}
async fn handle_error(
&mut self,
_err: Error,
_ctx: Context,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
Ok(None)
}
async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> {
Ok(())
}
}
// Records every `u16` in arrival order. `finalize` asserts they arrived as exactly
// 0..1024, i.e. that send order was preserved.
#[derive(Default)]
struct TestCollector {
inner: Vec<u16>,
}
impl Handler<u16> for TestCollector {
type Result = ();
type Error = anyhow::Error;
async fn handle(
&mut self,
msg: u16,
_ctx: Context,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
println!("{}", msg);
self.inner.push(msg);
Ok(None)
}
async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> {
println!("Checking");
assert_eq!(self.inner, (0u16..1024).collect::<Vec<_>>());
Ok(())
}
async fn handle_error(
&mut self,
err: Error,
_ctx: Context,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
println!("{:?}", err);
Ok(None)
}
}
// Sends one producer message on each of streams 0..10; the output is only printed.
#[tokio::test]
#[ignore]
async fn test_streams() {
let bus = Bus::default();
bus.register(DefaultBuilder::<u64, _, _, _>::new(2, |_, _| async move {
Ok(TestConsumer::default())
}))
.await;
bus.register(DefaultBuilder::<u32, _, _, _>::new(2, |_, _| async move {
Ok(TestProducer)
}))
.await;
for start in 0u32..10 {
bus.send_with_stream(start, start).await.unwrap();
}
bus.close().await;
bus.wait().await;
}
// Same traffic as `test_streams`, but the consumer is registered via `SharedBuilder`.
#[tokio::test]
#[ignore]
async fn test_tasks_shared() {
let bus = Bus::default();
bus.register(SharedBuilder::new(2, 5, |_sid, _tid| async move {
Ok(TestConsumer::default())
}))
.await;
bus.register(DefaultBuilder::<u32, _, _, _>::new(2, |_, _| async move {
Ok(TestProducer)
}))
.await;
for start in 0u32..10 {
bus.send_with_stream(start, start).await.unwrap();
}
bus.close().await;
bus.wait().await;
}
// Sends 0..1024 through an `.ordered(None)` shared handler with jittered latency. The
// collector then checks that its results arrive in order.
#[tokio::test]
async fn test_reordering() {
let bus = Bus::default();
bus.register(
SharedBuilder::new(4, 128, |sid, tid| async move {
println!("NEW HANDLER {}/{}", sid, tid);
Ok(TestHandler {})
})
.ordered(None),
)
.await;
bus.register(DefaultBuilder::<_, _, _, _>::new(4, |_, _| async move {
Ok(TestCollector::default())
}))
.await;
for i in 0i16..1024 {
bus.send(i).await.unwrap();
}
bus.close().await;
bus.wait().await;
}
}