use std::{
cell::Cell,
fmt::{self, Debug, Formatter},
io::{Error, ErrorKind},
};
use std::sync::{Arc, RwLock};
use crate::{
channels::shared_channel::{self, *},
GlommioError,
Result,
};
/// Sending side of a channel mesh as seen from a single peer.
#[derive(Debug)]
pub struct Senders<T: Send> {
/// Index of this peer in the mesh's peer list.
peer_id: usize,
/// Number of earlier peers that joined with the same role as this peer;
/// `None` when this peer's role is `Role::Consumer`.
producer_id: Option<usize>,
/// One entry per destination index; `None` marks a slot with no connected
/// sender (for example this peer's own slot in a full mesh).
senders: Vec<Option<ConnectedSender<T>>>,
}
impl<T: Send> Senders<T> {
    /// Returns the index of the current peer in the mesh.
    pub fn peer_id(&self) -> usize {
        self.peer_id
    }

    /// Returns the producer index of the current peer, or `None` when the
    /// peer did not join with a producing role.
    pub fn producer_id(&self) -> Option<usize> {
        self.producer_id
    }

    /// Returns the number of sender slots held by this peer (including
    /// slots that have no connected sender).
    pub fn nr_consumers(&self) -> usize {
        self.senders.len()
    }

    /// Sends `msg` to the consumer at `idx`, waiting on the underlying
    /// connected sender.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` I/O error when `idx` is out of range or
    /// refers to a slot without a connected sender; otherwise forwards the
    /// result of the underlying `send`.
    pub async fn send_to(&self, idx: usize, msg: T) -> Result<(), T> {
        match self.senders.get(idx) {
            Some(Some(consumer)) => consumer.send(msg).await,
            _ => self.invalid_destination(idx),
        }
    }

    /// Non-waiting variant of [`Senders::send_to`]: forwards to `try_send`
    /// on the underlying connected sender.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` I/O error when `idx` is out of range or
    /// refers to a slot without a connected sender; otherwise forwards the
    /// result of the underlying `try_send`.
    pub fn try_send_to(&self, idx: usize, msg: T) -> Result<(), T> {
        match self.senders.get(idx) {
            Some(Some(consumer)) => consumer.try_send(msg),
            _ => self.invalid_destination(idx),
        }
    }

    /// Closes every connected sender held by this peer.
    pub fn close(&self) {
        for sender in self.senders.iter().flatten() {
            sender.close();
        }
    }

    /// Builds the error returned when `idx` cannot be used as a destination.
    ///
    /// An in-range index with no sender gets the "local message" wording,
    /// while an out-of-range index is reported as an invalid shard, so both
    /// send paths give the same diagnostics.
    fn invalid_destination(&self, idx: usize) -> Result<(), T> {
        let reason = if idx < self.senders.len() {
            "Local message should not be sent via channel mesh".to_owned()
        } else {
            format!("Shard {} is invalid in the channel mesh", idx)
        };
        Err(GlommioError::IoError(Error::new(
            ErrorKind::InvalidInput,
            reason,
        )))
    }
}
/// Receiving side of a channel mesh as seen from a single peer.
#[derive(Debug)]
pub struct Receivers<T: Send> {
/// Index of this peer in the mesh's peer list.
peer_id: usize,
/// Number of earlier peers that joined with the same role as this peer;
/// `None` when this peer's role is `Role::Producer`.
consumer_id: Option<usize>,
/// One entry per source index; `None` marks a slot with no connected
/// receiver (for example this peer's own slot in a full mesh, or a slot
/// already handed out by `streams`).
receivers: Vec<Option<ConnectedReceiver<T>>>,
}
impl<T: Send> Receivers<T> {
pub fn peer_id(&self) -> usize {
self.peer_id
}
pub fn consumer_id(&self) -> Option<usize> {
self.consumer_id
}
pub fn nr_producers(&self) -> usize {
self.receivers.len()
}
pub async fn recv_from(&self, idx: usize) -> Result<Option<T>, ()> {
match self.receivers.get(idx) {
Some(Some(producer)) => Ok(producer.recv().await),
_ => Err(GlommioError::IoError(Error::new(
ErrorKind::InvalidInput,
"Local message should not be received from channel mesh",
))),
}
}
pub fn streams(&mut self) -> Vec<(usize, ConnectedReceiver<T>)> {
self.receivers
.iter_mut()
.enumerate()
.flat_map(|(idx, recv)| recv.take().map(|recv| (idx, recv)))
.collect()
}
}
/// Bookkeeping entry for one executor that has registered with a mesh.
struct Peer {
    /// Id of the executor on which this entry was created.
    executor_id: usize,
    /// Sender used to notify this peer; `None` for a peer created without one.
    notifier: Option<SharedSender<bool>>,
    /// Role the peer registered with.
    role: Role,
}

impl Peer {
    /// Creates an entry for the executor running the current task.
    fn new(sender: Option<SharedSender<bool>>, role: Role) -> Self {
        let executor_id = crate::executor().id();
        Peer {
            executor_id,
            notifier: sender,
            role,
        }
    }
}
/// The two unconnected ends of one directed channel, each held in a `Cell`
/// so it can be set and later taken through a shared reference.
type SharedChannel<T> = (
Cell<Option<SharedSender<T>>>,
Cell<Option<SharedReceiver<T>>>,
);
/// Square table of channel slots, indexed as `[from peer][to peer]`.
type SharedChannels<T> = Vec<Vec<SharedChannel<T>>>;
/// The role a peer plays in a channel mesh.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Role {
    /// The peer only sends.
    Producer,
    /// The peer only receives.
    Consumer,
    /// The peer sends and receives.
    Both,
}

impl Role {
    /// `true` for every role except `Consumer`.
    fn is_producer(&self) -> bool {
        !matches!(self, Self::Consumer)
    }

    /// `true` for every role except `Producer`.
    fn is_consumer(&self) -> bool {
        !matches!(self, Self::Producer)
    }
}
/// Strategy that decides which pairs of peers get a channel between them.
pub trait MeshAdapter: Clone {
/// Returns `true` when a channel should be created from a peer with role
/// `from` to a peer with role `to`.
fn connect(&self, from: &Role, to: &Role) -> bool;
/// Returns `true` for a full mesh. `MeshBuilder::join_with` uses this to
/// decide whether a missing channel is recorded as a `None` placeholder.
fn is_full(&self) -> bool;
}
/// Adapter that connects every peer to every other peer.
#[derive(Clone, Debug)]
pub struct Full;

impl MeshAdapter for Full {
    /// Always connects, whatever the two roles are.
    fn connect(&self, _from: &Role, _to: &Role) -> bool {
        true
    }

    /// A `Full` adapter always describes a full mesh.
    fn is_full(&self) -> bool {
        true
    }
}

/// Mesh builder that uses the [`Full`] adapter.
pub type FullMesh<T> = MeshBuilder<T, Full>;
/// Adapter that connects producers to consumers only.
#[derive(Clone, Debug)]
pub struct Partial;

impl MeshAdapter for Partial {
    /// Connects only when `from` is exactly `Producer` and `to` is exactly
    /// `Consumer`; `Both` on either side does not connect.
    fn connect(&self, from: &Role, to: &Role) -> bool {
        *from == Role::Producer && *to == Role::Consumer
    }

    /// A `Partial` adapter never describes a full mesh.
    fn is_full(&self) -> bool {
        false
    }
}

/// Mesh builder that uses the [`Partial`] adapter.
pub type PartialMesh<T> = MeshBuilder<T, Partial>;
/// Shared description of a channel mesh that peers clone and then join.
pub struct MeshBuilder<T: Send, A: MeshAdapter> {
/// Number of peers the mesh is built for.
nr_peers: usize,
/// Capacity passed to `shared_channel::new_bounded` for every data channel.
channel_size: usize,
/// Peers registered so far, kept sorted by executor id.
peers: Arc<RwLock<Vec<Peer>>>,
/// `nr_peers` x `nr_peers` table of channel ends, indexed `[from][to]`.
channels: Arc<SharedChannels<T>>,
/// Decides which peer pairs are connected.
adapter: A,
}
// SAFETY: NOTE(review) — this argument is reconstructed from this file only; please confirm.
// `Cell` is not `Sync`, so `Arc<SharedChannels<T>>` is not automatically `Send`, hence
// this manual impl. The cells are written only in `register`, while the `peers` write
// lock is held, and are emptied with `take()` in `join_with`, where each peer takes
// the sender cells of its own row and the receiver cells of its own column. This
// presumably relies on the notification channel to order those accesses.
unsafe impl<T: Send, A: MeshAdapter> Send for MeshBuilder<T, A> {}
impl<T: Send, A: MeshAdapter> Clone for MeshBuilder<T, A> {
fn clone(&self) -> Self {
Self {
nr_peers: self.nr_peers,
channel_size: self.channel_size,
peers: self.peers.clone(),
channels: self.channels.clone(),
adapter: self.adapter.clone(),
}
}
}
impl<T: Send, A: MeshAdapter> Debug for MeshBuilder<T, A> {
    /// Prints only the peer count; the other fields are omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "MeshBuilder {{ nr_peers: {} }}", self.nr_peers)
    }
}
impl<T: 'static + Send> MeshBuilder<T, Full> {
    /// Creates a builder for a full mesh of `nr_peers` peers whose data
    /// channels are created with capacity `channel_size`.
    pub fn full(nr_peers: usize, channel_size: usize) -> Self {
        let adapter = Full;
        Self::new(nr_peers, channel_size, adapter)
    }

    /// Joins the mesh from the current executor with role `Role::Both`.
    pub async fn join(self) -> Result<(Senders<T>, Receivers<T>), ()> {
        Self::join_with(self, Role::Both).await
    }
}
impl<T: 'static + Send> MeshBuilder<T, Partial> {
    /// Creates a builder for a partial mesh of `nr_peers` peers whose data
    /// channels are created with capacity `channel_size`.
    pub fn partial(nr_peers: usize, channel_size: usize) -> Self {
        let adapter = Partial;
        Self::new(nr_peers, channel_size, adapter)
    }

    /// Joins the mesh from the current executor with the given `role`.
    pub async fn join(self, role: Role) -> Result<(Senders<T>, Receivers<T>), ()> {
        Self::join_with(self, role).await
    }
}
impl<T: 'static + Send, A: MeshAdapter> MeshBuilder<T, A> {
    /// Creates a builder with an empty peer list and an `nr_peers` x
    /// `nr_peers` table of empty channel slots.
    fn new(nr_peers: usize, channel_size: usize, adapter: A) -> Self {
        MeshBuilder {
            nr_peers,
            channel_size,
            peers: Arc::new(RwLock::new(Vec::new())),
            channels: Arc::new(Self::placeholder(nr_peers)),
            adapter,
        }
    }

    /// Returns the number of peers the mesh is built for.
    pub fn nr_peers(&self) -> usize {
        self.nr_peers
    }

    /// Builds the `nr_peers` x `nr_peers` table with every slot empty.
    fn placeholder(nr_peers: usize) -> SharedChannels<T> {
        (0..nr_peers)
            .map(|_| {
                (0..nr_peers)
                    .map(|_| (Cell::new(None), Cell::new(None)))
                    .collect()
            })
            .collect()
    }

    /// Adds the current executor to the peer list.
    ///
    /// The last peer to register creates all data channels and gets back the
    /// notification senders of the earlier peers. Every other peer gets back
    /// the receiving end of its own notification channel.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when `nr_peers` peers are already registered.
    ///
    /// # Panics
    ///
    /// Panics when the current executor is already registered, or when the
    /// peer-list lock is poisoned.
    fn register(&self, role: Role) -> Result<RegisterResult<bool>, ()> {
        // Looked up once rather than on every comparison of the binary search.
        let executor_id = crate::executor().id();

        let mut peers = self.peers.write().unwrap();
        if peers.len() == self.nr_peers {
            return Err(GlommioError::IoError(Error::new(
                ErrorKind::Other,
                "The channel mesh is full.",
            )));
        }

        // The list is kept sorted by executor id; `Err(index)` is the
        // insertion point for an executor that is not registered yet.
        let index = peers
            .binary_search_by(|n| n.executor_id.cmp(&executor_id))
            .expect_err("Should not join a mesh more than once.");

        // Written as `len + 1` so the comparison cannot underflow.
        if peers.len() + 1 == self.nr_peers {
            // Last peer: it needs no notifier of its own.
            peers.insert(index, Peer::new(None, role));

            // All peers are known now, so create every channel the adapter
            // asks for. A peer is never connected to itself.
            for (idx_from, from) in peers.iter().enumerate() {
                for (idx_to, to) in peers.iter().enumerate() {
                    let channel = &self.channels[idx_from][idx_to];
                    if idx_from != idx_to && self.adapter.connect(&from.role, &to.role) {
                        let (sender, receiver) = shared_channel::new_bounded(self.channel_size);
                        channel.0.set(Some(sender));
                        channel.1.set(Some(receiver));
                    }
                }
            }

            let notifiers: Vec<_> = peers
                .iter_mut()
                .filter_map(|peer| peer.notifier.take())
                .collect();
            Ok(RegisterResult::NotificationSenders(notifiers))
        } else {
            let (sender, receiver) = shared_channel::new_bounded(1);
            peers.insert(index, Peer::new(Some(sender), role));
            Ok(RegisterResult::NotificationReceiver(receiver))
        }
    }

    /// Registers the current executor with `role`, waits until the mesh is
    /// complete, then connects this peer's channel ends.
    ///
    /// For a full mesh the returned vectors hold one slot per peer, with
    /// `None` where no channel exists. Otherwise they hold only the
    /// connected ends.
    ///
    /// # Errors
    ///
    /// Propagates the error from `register` when the mesh is already full.
    ///
    /// # Panics
    ///
    /// Panics on the same conditions as `register`, and when a notification
    /// or connection step fails.
    async fn join_with(self, role: Role) -> Result<(Senders<T>, Receivers<T>), ()> {
        match self.register(role)? {
            RegisterResult::NotificationReceiver(receiver) => {
                // Not the last peer: wait for the last one to signal.
                receiver.connect().await.recv().await.unwrap();
            }
            RegisterResult::NotificationSenders(peers) => {
                // Last peer: signal every earlier peer.
                for peer in peers {
                    peer.connect().await.send(true).await.unwrap();
                }
            }
        }

        let (peer_id, role_id) = {
            let executor_id = crate::executor().id();
            // The read guard is confined to this block so it is released
            // before the awaits below.
            let peers = self.peers.read().unwrap();
            let peer_id = peers
                .binary_search_by(|n| n.executor_id.cmp(&executor_id))
                .unwrap();
            // Number of peers that sort before this one and share its role.
            let role_id = peers
                .iter()
                .take(peer_id)
                .filter(|r| r.role.eq(&role))
                .count();
            (peer_id, role_id)
        };

        let producer_id = if role.is_producer() {
            Some(role_id)
        } else {
            None
        };
        let consumer_id = if role.is_consumer() {
            Some(role_id)
        } else {
            None
        };

        let mut senders = Vec::with_capacity(self.nr_peers);
        let mut receivers = Vec::with_capacity(self.nr_peers);
        for i in 0..self.nr_peers {
            // This peer takes the sender ends of its own row and the
            // receiver ends of its own column.
            let sender = self.channels[peer_id][i].0.take();
            let receiver = self.channels[i][peer_id].1.take();

            // Both connections are spawned before either is awaited.
            let sender = sender.map(|sender| crate::spawn_local(sender.connect()).detach());
            let receiver = receiver.map(|receiver| crate::spawn_local(receiver.connect()).detach());

            match sender {
                None => {
                    if self.adapter.is_full() {
                        senders.push(None)
                    }
                }
                Some(sender) => senders.push(Some(sender.await.unwrap())),
            }
            match receiver {
                None => {
                    if self.adapter.is_full() {
                        receivers.push(None)
                    }
                }
                Some(receiver) => receivers.push(Some(receiver.await.unwrap())),
            }
        }

        Ok((
            Senders {
                peer_id,
                producer_id,
                senders,
            },
            Receivers {
                peer_id,
                consumer_id,
                receivers,
            },
        ))
    }
}
/// Outcome of registering a peer with a mesh.
enum RegisterResult<T: Send> {
/// Returned to every peer except the last one to register: the receiving
/// end of that peer's notification channel.
NotificationReceiver(SharedReceiver<T>),
/// Returned to the last peer to register: the notification senders taken
/// from the peers that registered earlier.
NotificationSenders(Vec<SharedSender<T>>),
}
#[cfg(test)]
mod tests {
use futures::future;
use crate::{enclose, prelude::*};
use super::*;
// Full meshes of 0, 1 and 10 peers, each joined by as many executors as peers.
#[test]
fn test_channel_mesh() {
do_test_channel_mesh(0, 0, 0);
do_test_channel_mesh(1, 1, 1);
do_test_channel_mesh(10, 10, 10);
}
// The same executors join two independent meshes; each executor must get the
// same peer id in both.
#[test]
fn test_peer_id_invariance() {
let nr_peers = 10;
let channel_size = 100;
let builder1 = MeshBuilder::<i32, _>::full(nr_peers, channel_size);
let builder2 = MeshBuilder::<usize, _>::full(nr_peers, channel_size);
let executors = (0..nr_peers).map(|_| {
LocalExecutorBuilder::default().spawn(
enclose!((builder1, builder2) move || async move {
let (sender1, receiver1) = builder1.join().await.unwrap();
assert_eq!(sender1.peer_id(), receiver1.peer_id());
let (sender2, receiver2) = builder2.join().await.unwrap();
assert_eq!(sender2.peer_id(), receiver2.peer_id());
assert_eq!(sender1.peer_id(), sender2.peer_id());
assert_eq!(receiver1.peer_id(), receiver2.peer_id());
}),
)
});
// `map` is lazy: collecting first spawns every executor before any of
// them is joined.
for ex in executors.collect::<Vec<_>>() {
ex.unwrap().join().unwrap();
}
}
// Joining the same mesh twice from one executor is expected to panic.
#[test]
#[should_panic]
#[allow(unused_must_use)]
fn test_join_more_than_once() {
let mesh_builder = MeshBuilder::<u32, _>::full(2, 1);
LocalExecutor::default().run(enclose!((mesh_builder) async move {
future::join(mesh_builder.clone().join(), mesh_builder.join()).await;
}));
}
// Two executors try to join a mesh sized for one peer; expected to panic.
#[test]
#[should_panic]
#[allow(unused_must_use)]
fn test_join_more_than_capacity() {
do_test_channel_mesh(1, 1, 2);
}
// Helper: builds a full mesh for `nr_peers` peers, spawns `nr_executors`
// executors that join it, and has every peer exchange one message with every
// other peer.
fn do_test_channel_mesh(nr_peers: usize, channel_size: usize, nr_executors: usize) {
let mesh_builder = MeshBuilder::full(nr_peers, channel_size);
let executors = (0..nr_executors).map(|_| {
LocalExecutorBuilder::default().spawn(enclose!((mesh_builder) move || async move {
let (sender, receiver) = mesh_builder.join().await.unwrap();
assert_eq!(nr_peers, sender.nr_consumers());
assert_eq!(sender.peer_id, sender.producer_id.unwrap());
assert_eq!(nr_peers, receiver.nr_producers());
assert_eq!(receiver.peer_id, receiver.consumer_id.unwrap());
// Sending runs in a detached task while this task does the receiving.
crate::spawn_local(async move {
for peer in 0..sender.nr_consumers() {
if peer != sender.peer_id() {
sender.send_to(peer, (sender.peer_id(), peer)).await.unwrap();
}
}
}).detach();
for peer in 0..receiver.nr_producers() {
if peer != receiver.peer_id() {
assert_eq!((peer, receiver.peer_id()), receiver.recv_from(peer).await.unwrap().unwrap());
}
}
}))
});
// Collect first so every executor is spawned before any of them is joined.
for ex in executors.collect::<Vec<_>>() {
ex.unwrap().join().unwrap();
}
}
// Partial mesh with 2 producers and 3 consumers: each producer sends one
// message to each consumer.
#[test]
fn test_partial_mesh() {
let nr_producers = 2;
let nr_consumers = 3;
let mesh_builder = MeshBuilder::partial(5, 100);
let producers = (0..nr_producers).map(|i| {
LocalExecutorBuilder::default().spawn(enclose!((mesh_builder) move || async move {
let (sender, receiver) = mesh_builder.join(Role::Producer).await.unwrap();
assert_eq!(nr_consumers, sender.nr_consumers());
assert_eq!(Some(i), sender.producer_id());
assert_eq!(0, receiver.nr_producers());
assert_eq!(None, receiver.consumer_id());
for consumer_id in 0..sender.nr_consumers() {
sender.send_to(consumer_id, (sender.producer_id().unwrap(), consumer_id)).await.unwrap();
}
}))
});
let consumers = (0..nr_consumers).map(|i| {
LocalExecutorBuilder::default().spawn(enclose!((mesh_builder) move || async move {
let (sender, receiver) = mesh_builder.join(Role::Consumer).await.unwrap();
assert_eq!(0, sender.nr_consumers());
assert_eq!(None, sender.producer_id());
assert_eq!(nr_producers, receiver.nr_producers());
assert_eq!(Some(i), receiver.consumer_id());
for producer_id in 0..receiver.nr_producers() {
assert_eq!((producer_id, receiver.consumer_id().unwrap()), receiver.recv_from(producer_id).await.unwrap().unwrap());
}
}))
});
// Collect first so all producers and consumers are spawned before any join.
for ex in producers.chain(consumers).collect::<Vec<_>>() {
ex.unwrap().join().unwrap();
}
}
}