Skip to main content

beetry_plugin/
channel.rs

1use anyhow::{Result, anyhow};
2use beetry_channel::{AnyBoxReceiver, AnyBoxSender, BoxReceiver, BoxSender};
3use beetry_editor_types::{
4    output::channel::{ChannelConfig, ChannelKind, TokioChannelKind},
5    spec::channel::ChannelSpec,
6};
7use bon::Builder;
8
9use crate::{BoxPlugin, ConstructPlugin, Named, PluginConstructor, PluginError, unique_plugins};
10
11#[derive(Builder)]
12pub struct TypeErasedChannel {
13    pub senders: Vec<AnyBoxSender>,
14    pub receivers: Vec<AnyBoxReceiver>,
15}
16
17impl TypeErasedChannel {
18    pub fn try_take_sender(&mut self) -> Result<AnyBoxSender> {
19        self.senders
20            .pop()
21            .ok_or_else(|| anyhow!("no free sender in the channel"))
22    }
23
24    pub fn try_take_receiver(&mut self) -> Result<AnyBoxReceiver> {
25        self.receivers
26            .pop()
27            .ok_or_else(|| anyhow!("no free receiver in the channel"))
28    }
29}
30
31impl Named for ChannelSpec {
32    fn name(&self) -> &str {
33        self.msg_type_name().as_str()
34    }
35}
36
37pub struct Factory {
38    func: Box<dyn Fn(ChannelConfig) -> TypeErasedChannel>,
39}
40
41impl std::fmt::Debug for Factory {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.debug_struct("Factory")
44            .field("func", &"<function>")
45            .finish()
46    }
47}
48
49impl Factory {
50    #[must_use]
51    pub fn from_msg<M: Clone + Default + 'static>() -> Self {
52        Self {
53            func: (Box::new(|config| {
54                let capacity = config.capacity();
55                let count = config.count();
56                // client is responsible for providing correct and valid number of senders and
57                // receivers
58                let n_senders = count.sender();
59                let n_receivers = count.receiver();
60
61                let (senders, receivers) = match config.kind() {
62                    ChannelKind::Tokio(TokioChannelKind::Broadcast) => {
63                        let (sender, receiver) =
64                            beetry_channel::tokio::broadcast::channel::<M>(capacity);
65
66                        let receivers: Vec<_> =
67                            std::iter::once(Box::new(receiver) as BoxReceiver<M>)
68                                .chain(
69                                    std::iter::repeat_with(|| {
70                                        Box::new(sender.subscribe()) as BoxReceiver<M>
71                                    })
72                                    .take(n_receivers - 1),
73                                )
74                                .collect();
75                        let senders: Vec<_> =
76                            std::iter::repeat_with(|| Box::new(sender.clone()) as BoxSender<M>)
77                                .take(n_senders)
78                                .collect();
79
80                        (senders, receivers)
81                    }
82                    ChannelKind::Tokio(TokioChannelKind::Mpsc) => {
83                        let (sender, receiver) =
84                            beetry_channel::tokio::mpsc::channel::<M>(capacity);
85
86                        let senders: Vec<_> =
87                            std::iter::repeat_with(|| Box::new(sender.clone()) as BoxSender<M>)
88                                .take(n_senders)
89                                .collect();
90                        let receivers = vec![Box::new(receiver) as BoxReceiver<M>];
91
92                        (senders, receivers)
93                    }
94                    ChannelKind::Tokio(TokioChannelKind::Watch) => {
95                        let (sender, receiver) = beetry_channel::tokio::watch::channel::<M>();
96
97                        let receivers: Vec<_> =
98                            std::iter::once(Box::new(receiver) as BoxReceiver<M>)
99                                .chain(
100                                    std::iter::repeat_with(|| {
101                                        Box::new(sender.subscribe()) as BoxReceiver<M>
102                                    })
103                                    .take(n_receivers - 1),
104                                )
105                                .collect();
106                        let senders: Vec<_> =
107                            std::iter::repeat_with(|| Box::new(sender.clone()) as BoxSender<M>)
108                                .take(n_senders)
109                                .collect();
110
111                        (senders, receivers)
112                    }
113                };
114
115                TypeErasedChannel::builder()
116                    .senders(senders.into_iter().map(Into::into).collect())
117                    .receivers(receivers.into_iter().map(Into::into).collect())
118                    .build()
119            })),
120        }
121    }
122
123    #[must_use]
124    pub fn create(&self, config: ChannelConfig) -> TypeErasedChannel {
125        (self.func)(config)
126    }
127}
128
129pub type BoxChannelPlugin = BoxPlugin<ChannelSpec, Factory>;
130pub type ChannelPluginConstructor = PluginConstructor<ChannelSpec, Factory>;
131
132impl ChannelPluginConstructor {
133    pub fn plugins() -> Result<Vec<BoxChannelPlugin>, PluginError> {
134        unique_plugins::<Self, <Self as ConstructPlugin>::Spec, <Self as ConstructPlugin>::Factory>(
135        )
136    }
137}
138
139inventory::collect!(ChannelPluginConstructor);