use core::marker::PhantomData;
use nodo::{
channels::{Rx, RxBundle, SyncResult},
codelet::Context,
core::{Topic, WithTopic},
prelude::*,
};
/// Codelet that joins several per-topic inputs into a single output stream,
/// tagging each forwarded value with the topic it arrived on.
///
/// Holds no runtime state; `marker` only ties the codelet to its payload type `T`.
pub struct TopicJoin<T> {
    marker: PhantomData<T>,
}

/// Configuration for [`TopicJoin`]. It carries no settings.
#[derive(Config, Default)]
pub struct TopicJoinConfig;

// Written by hand rather than derived: `#[derive(Default)]` would require
// `T: Default`, which the codelet does not need.
impl<T> Default for TopicJoin<T> {
    fn default() -> Self {
        Self {
            marker: PhantomData,
        }
    }
}
/// Merges every per-topic input into one output.
///
/// Each step drains all receivers held in [`TopicJoinRx`], in the order they
/// were added, and forwards their messages with the payload wrapped in a
/// [`WithTopic`] naming the channel's topic.
impl<T> Codelet for TopicJoin<T>
where
    T: Clone + Send + Sync,
{
    type Status = DefaultStatus;
    type Config = TopicJoinConfig;
    type Rx = TopicJoinRx<Message<T>>;
    type Tx = DoubleBufferTx<Message<WithTopic<T>>>;
    type Signals = ();

    /// Creates an empty input bundle and a single output.
    ///
    /// Input channels are added afterwards via `TopicJoinRx::add`.
    fn build_bundles(_: &Self::Config) -> (Self::Rx, Self::Tx) {
        let inputs = TopicJoinRx::default();
        let output = DoubleBufferTx::new_auto_size();
        (inputs, output)
    }

    /// Drains every inbox and pushes its messages, tagged with the inbox's
    /// topic, to the output.
    ///
    /// If `push_many` fails, its error is returned immediately. Channels after
    /// the failing one are then not drained during this step.
    fn step(&mut self, _: Context<Self>, rx: &mut Self::Rx, tx: &mut Self::Tx) -> Outcome {
        for (topic, inbox, _) in &mut rx.channels {
            // `Message::map` transforms the payload (presumably leaving the rest
            // of the message untouched; confirm against nodo). The topic is
            // cloned per message because each `WithTopic` owns its own copy.
            let tagged = inbox.drain(..).map(|message| {
                message.map(|value| WithTopic {
                    topic: topic.clone(),
                    value,
                })
            });
            tx.push_many(tagged)?;
        }
        SUCCESS
    }
}
/// Receive bundle with one double-buffered receiver per registered topic.
///
/// Each entry stores the topic, its receiver, and the channel name derived
/// from the topic when the entry was added.
pub struct TopicJoinRx<T> {
    channels: Vec<(Topic, DoubleBufferRx<T>, String)>,
}

// Hand-written so that no `T: Default` bound is imposed (a derive would add one).
impl<T> Default for TopicJoinRx<T> {
    fn default() -> Self {
        let channels = Vec::new();
        Self { channels }
    }
}
impl<T> TopicJoinRx<T> {
    /// Returns the receiver for the first entry whose topic equals `needle`,
    /// or `None` when no such entry exists.
    pub fn find_by_topic(&mut self, needle: &Topic) -> Option<&mut DoubleBufferRx<T>> {
        for (topic, inbox, _) in &mut self.channels {
            if *topic == *needle {
                return Some(inbox);
            }
        }
        None
    }

    /// Appends a new receiver for `topic` and returns a handle to it.
    ///
    /// The channel name is derived from the topic via its `Into<String>`
    /// conversion. Duplicate topics are not rejected; each call adds a new entry.
    pub fn add(&mut self, topic: Topic) -> &mut DoubleBufferRx<T> {
        let name: String = (&topic).into();
        let slot = self.channels.len();
        self.channels
            .push((topic, DoubleBufferRx::new_auto_size(), name));
        // `slot` is the index of the entry pushed just above.
        &mut self.channels[slot].1
    }
}
/// Exposes every per-topic receiver as one channel of the bundle, indexed in
/// the order the channels were added.
impl<T: Send + Sync> RxBundle for TopicJoinRx<T> {
    /// Number of registered topic channels.
    fn channel_count(&self) -> usize {
        self.channels.len()
    }

    /// Name stored for channel `index`. Panics if `index` is out of range.
    fn name(&self, index: usize) -> &str {
        let (_, _, name) = &self.channels[index];
        name
    }

    /// Value of `len()` on the receiver of channel `index`. Panics if `index`
    /// is out of range.
    fn inbox_message_count(&self, index: usize) -> usize {
        let (_, inbox, _) = &self.channels[index];
        inbox.len()
    }

    /// Calls `sync()` on every receiver and stores each result at the same
    /// index in `results`.
    ///
    /// Panics if `results` has fewer slots than there are channels.
    fn sync_all(&mut self, results: &mut [SyncResult]) {
        for (slot, (_, inbox, _)) in self.channels.iter_mut().enumerate() {
            results[slot] = inbox.sync();
        }
    }

    /// Reports, per channel index, whether its receiver is connected.
    fn check_connection(&self) -> nodo::channels::ConnectionCheck {
        let mut check = nodo::channels::ConnectionCheck::new(self.channels.len());
        for (slot, (_, inbox, _)) in self.channels.iter().enumerate() {
            check.mark(slot, inbox.is_connected());
        }
        check
    }
}