Skip to main content

antenna_client_shared/
dispatcher.rs

1use antenna_protocol::{PeerID, UserMsgPayload};
2use anyhow::Result;
3use std::collections::{HashMap, HashSet};
4
5#[derive(Clone)]
6pub enum EventType<Msg: UserMsgPayload> {
7    Connected,
8    UserMessage(PeerID, Msg),
9    Disconnected,
10    PeerConnected(PeerID),
11    PeerDisconnected(PeerID),
12    PeerDropped(PeerID),
13    Available,
14    Unavailable,
15}
16
17pub struct NoArgCallback(Box<dyn Fn() -> Result<()> + Send + Sync>);
18
19impl NoArgCallback {
20    pub fn from_fn<F>(f: F) -> Self
21    where
22        F: Fn() -> Result<()> + Send + Sync + 'static,
23    {
24        Self(Box::new(f))
25    }
26
27    pub fn call(&self) -> Result<()> {
28        (self.0)()
29    }
30}
31
32impl From<fn()> for NoArgCallback {
33    fn from(f: fn()) -> Self {
34        Self(Box::new(move || {
35            f();
36            Ok(())
37        }))
38    }
39}
40
41type PeerCallbackFn = dyn Fn(&PeerID) -> Result<()> + Send + Sync;
42pub struct PeerCallback(Box<PeerCallbackFn>);
43
44impl PeerCallback {
45    pub fn from_fn<F>(f: F) -> Self
46    where
47        F: Fn(&PeerID) -> Result<()> + Send + Sync + 'static,
48    {
49        Self(Box::new(f))
50    }
51
52    pub fn call(&self, peer: &PeerID) -> Result<()> {
53        (self.0)(peer)
54    }
55}
56
57impl From<fn(PeerID)> for PeerCallback {
58    fn from(f: fn(PeerID)) -> Self {
59        Self(Box::new(move |peer| {
60            f(peer.clone());
61            Ok(())
62        }))
63    }
64}
65
66type MessageCallbackFn<Msg> = dyn Fn(&PeerID, &Msg) -> Result<()> + Send + Sync;
67pub struct MessageCallback<Msg: UserMsgPayload>(Box<MessageCallbackFn<Msg>>);
68
69impl<Msg: UserMsgPayload> MessageCallback<Msg> {
70    pub fn from_fn<F>(f: F) -> Self
71    where
72        F: Fn(&PeerID, &Msg) -> Result<()> + Send + Sync + 'static,
73    {
74        Self(Box::new(f))
75    }
76
77    pub fn call(&self, peer: &PeerID, data: &Msg) -> Result<()> {
78        (self.0)(peer, data)
79    }
80}
81
82impl<Msg: UserMsgPayload + 'static> From<fn(PeerID, Msg)> for MessageCallback<Msg> {
83    fn from(f: fn(PeerID, Msg)) -> Self {
84        Self(Box::new(move |peer, data| {
85            f(peer.clone(), data.clone());
86            Ok(())
87        }))
88    }
89}
90
91pub enum Event<Msg: UserMsgPayload> {
92    Connected(NoArgCallback),
93    UserMessage(MessageCallback<Msg>),
94    Disconnected(NoArgCallback),
95    PeerConnected(PeerCallback),
96    PeerDisconnected(PeerCallback),
97    PeerLost(PeerCallback),
98    Available(NoArgCallback),
99    Unavailable(NoArgCallback),
100}
101
/// Payload-free tag naming the kind of a subscription or event.
///
/// Used as the key of the by-kind subscription index, which is why it is
/// `Copy`, `Eq` and `Hash`. `Debug` is derived so the tag can appear in log
/// output and assertion messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum SubscriptionKind {
    Connected,
    UserMessage,
    Disconnected,
    PeerConnected,
    PeerDisconnected,
    /// Tag for `Event::PeerLost` subscriptions and `EventType::PeerDropped` events.
    PeerLost,
    Available,
    Unavailable,
}
113
114impl<Msg: UserMsgPayload> Event<Msg> {
115    fn kind(&self) -> SubscriptionKind {
116        match self {
117            Self::Connected(_) => SubscriptionKind::Connected,
118            Self::UserMessage(_) => SubscriptionKind::UserMessage,
119            Self::Disconnected(_) => SubscriptionKind::Disconnected,
120            Self::PeerConnected(_) => SubscriptionKind::PeerConnected,
121            Self::PeerDisconnected(_) => SubscriptionKind::PeerDisconnected,
122            Self::PeerLost(_) => SubscriptionKind::PeerLost,
123            Self::Available(_) => SubscriptionKind::Available,
124            Self::Unavailable(_) => SubscriptionKind::Unavailable,
125        }
126    }
127}
128
129fn event_kind<Msg: UserMsgPayload>(event: &EventType<Msg>) -> SubscriptionKind {
130    match event {
131        EventType::Connected => SubscriptionKind::Connected,
132        EventType::UserMessage(_, _) => SubscriptionKind::UserMessage,
133        EventType::Disconnected => SubscriptionKind::Disconnected,
134        EventType::PeerConnected(_) => SubscriptionKind::PeerConnected,
135        EventType::PeerDisconnected(_) => SubscriptionKind::PeerDisconnected,
136        EventType::PeerDropped(_) => SubscriptionKind::PeerLost,
137        EventType::Available => SubscriptionKind::Available,
138        EventType::Unavailable => SubscriptionKind::Unavailable,
139    }
140}
141
142pub struct RtcCallbacks<Msg: UserMsgPayload> {
143    next_callback_id: u64,
144    subscriptions: HashMap<u64, Event<Msg>>,
145    subscriptions_by_kind: HashMap<SubscriptionKind, HashSet<u64>>,
146}
147
148impl<Msg: UserMsgPayload> Default for RtcCallbacks<Msg> {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154impl<Msg: UserMsgPayload> RtcCallbacks<Msg> {
155    pub fn new() -> Self {
156        Self {
157            next_callback_id: 1,
158            subscriptions: HashMap::new(),
159            subscriptions_by_kind: HashMap::new(),
160        }
161    }
162
163    fn next_id(&mut self) -> u64 {
164        let id = self.next_callback_id;
165        self.next_callback_id += 1;
166        id
167    }
168
169    pub fn subscribe(&mut self, subscription: Event<Msg>) -> u64 {
170        let id = self.next_id();
171        let kind = subscription.kind();
172
173        self.subscriptions.insert(id, subscription);
174        self.subscriptions_by_kind
175            .entry(kind)
176            .or_default()
177            .insert(id);
178
179        id
180    }
181
182    pub fn unsubscribe(&mut self, id: u64) -> bool {
183        let Some(subscription) = self.subscriptions.remove(&id) else {
184            return false;
185        };
186
187        let kind = subscription.kind();
188        if let Some(ids) = self.subscriptions_by_kind.get_mut(&kind) {
189            ids.remove(&id);
190            if ids.is_empty() {
191                self.subscriptions_by_kind.remove(&kind);
192            }
193        }
194        true
195    }
196}
197
198impl<Msg: UserMsgPayload> RtcCallbacks<Msg> {
199    pub fn emit(&self, event: EventType<Msg>) -> Result<()> {
200        let kind = event_kind(&event);
201        let Some(ids) = self.subscriptions_by_kind.get(&kind) else {
202            return Ok(());
203        };
204
205        let ids: Vec<u64> = ids.iter().copied().collect();
206
207        for id in ids {
208            let Some(subscription) = self.subscriptions.get(&id) else {
209                continue;
210            };
211            match (subscription, &event) {
212                (Event::Connected(cb), EventType::Connected) => cb.call()?,
213                (Event::UserMessage(cb), EventType::UserMessage(peer, data)) => {
214                    cb.call(peer, data)?
215                }
216                (Event::Disconnected(cb), EventType::Disconnected) => cb.call()?,
217                (Event::PeerConnected(cb), EventType::PeerConnected(peer)) => cb.call(peer)?,
218                (Event::PeerDisconnected(cb), EventType::PeerDisconnected(peer)) => {
219                    cb.call(peer)?
220                }
221                (Event::PeerLost(cb), EventType::PeerDropped(peer)) => cb.call(peer)?,
222                (Event::Available(cb), EventType::Available) => cb.call()?,
223                (Event::Unavailable(cb), EventType::Unavailable) => cb.call()?,
224                _ => {}
225            }
226        }
227
228        Ok(())
229    }
230}