1use core::marker::PhantomData;
4use nodo::{
5 channels::{Rx, RxBundle, SyncResult},
6 codelet::Context,
7 core::{Topic, WithTopic},
8 prelude::*,
9};
10
/// Codelet that merges several topic-keyed input channels into a single output stream.
///
/// Every value forwarded to the output is wrapped in a [`WithTopic`] carrying the topic
/// of the input channel it arrived on. The codelet has no runtime state of its own; the
/// type parameter `T` only fixes the payload type.
pub struct TopicJoin<T> {
    // Ties the codelet to its payload type without storing any `T`.
    marker: PhantomData<T>,
}
17
18#[derive(Config, Default)]
19pub struct TopicJoinConfig;
20
21impl<T> Default for TopicJoin<T> {
22 fn default() -> Self {
23 Self {
24 marker: PhantomData::default(),
25 }
26 }
27}
28
29impl<T> Codelet for TopicJoin<T>
30where
31 T: Clone + Send + Sync,
32{
33 type Status = DefaultStatus;
34 type Config = TopicJoinConfig;
35 type Rx = TopicJoinRx<Message<T>>;
36 type Tx = DoubleBufferTx<Message<WithTopic<T>>>;
37
38 fn build_bundles(_: &Self::Config) -> (Self::Rx, Self::Tx) {
39 (TopicJoinRx::default(), DoubleBufferTx::new_auto_size())
40 }
41
42 fn step(&mut self, _: &Context<Self>, rx: &mut Self::Rx, tx: &mut Self::Tx) -> Outcome {
43 for (topic, channel) in rx.channels.iter_mut() {
44 tx.push_many(channel.drain(..).map(|msg| {
45 msg.map(|value| WithTopic {
47 topic: topic.clone(),
48 value,
49 })
50 }))?;
51 }
52 SUCCESS
53 }
54}
55
56pub struct TopicJoinRx<T> {
57 channels: Vec<(Topic, DoubleBufferRx<T>)>,
58}
59
60impl<T> Default for TopicJoinRx<T> {
61 fn default() -> Self {
62 Self {
63 channels: Vec::new(),
64 }
65 }
66}
67
68impl<T> TopicJoinRx<T> {
69 pub fn find_by_topic(&mut self, needle: &Topic) -> Option<&mut DoubleBufferRx<T>> {
71 self.channels
72 .iter_mut()
73 .find(|(key, _)| key == needle)
74 .map(|(_, value)| value)
75 }
76
77 pub fn add(&mut self, topic: Topic) -> &mut DoubleBufferRx<T> {
79 self.channels.push((topic, DoubleBufferRx::new_auto_size()));
80 &mut self.channels.last_mut().unwrap().1
81 }
82}
83
84impl<T: Send + Sync> RxBundle for TopicJoinRx<T> {
85 fn len(&self) -> usize {
86 self.channels.len()
87 }
88
89 fn name(&self, index: usize) -> String {
90 if index < self.channels.len() {
91 format!("input_{index}")
92 } else {
93 panic!(
94 "invalid index '{index}': number of channels is {}",
95 self.channels.len()
96 )
97 }
98 }
99
100 fn sync_all(&mut self, results: &mut [SyncResult]) {
101 for (i, channel) in self.channels.iter_mut().enumerate() {
102 results[i] = channel.1.sync()
103 }
104 }
105
106 fn check_connection(&self) -> nodo::channels::ConnectionCheck {
107 let mut cc = nodo::channels::ConnectionCheck::new(self.channels.len());
108 for (i, channel) in self.channels.iter().enumerate() {
109 cc.mark(i, channel.1.is_connected());
110 }
111 cc
112 }
113}