1use antenna_protocol::{PeerID, UserMsgPayload};
2use anyhow::Result;
3use std::collections::{HashMap, HashSet};
4
5#[derive(Clone)]
6pub enum EventType<Msg: UserMsgPayload> {
7 Connected,
8 UserMessage(PeerID, Msg),
9 Disconnected,
10 PeerConnected(PeerID),
11 PeerDisconnected(PeerID),
12 PeerDropped(PeerID),
13 Available,
14 Unavailable,
15}
16
17pub struct NoArgCallback(Box<dyn Fn() -> Result<()> + Send + Sync>);
18
19impl NoArgCallback {
20 pub fn from_fn<F>(f: F) -> Self
21 where
22 F: Fn() -> Result<()> + Send + Sync + 'static,
23 {
24 Self(Box::new(f))
25 }
26
27 pub fn call(&self) -> Result<()> {
28 (self.0)()
29 }
30}
31
32impl From<fn()> for NoArgCallback {
33 fn from(f: fn()) -> Self {
34 Self(Box::new(move || {
35 f();
36 Ok(())
37 }))
38 }
39}
40
41type PeerCallbackFn = dyn Fn(&PeerID) -> Result<()> + Send + Sync;
42pub struct PeerCallback(Box<PeerCallbackFn>);
43
44impl PeerCallback {
45 pub fn from_fn<F>(f: F) -> Self
46 where
47 F: Fn(&PeerID) -> Result<()> + Send + Sync + 'static,
48 {
49 Self(Box::new(f))
50 }
51
52 pub fn call(&self, peer: &PeerID) -> Result<()> {
53 (self.0)(peer)
54 }
55}
56
57impl From<fn(PeerID)> for PeerCallback {
58 fn from(f: fn(PeerID)) -> Self {
59 Self(Box::new(move |peer| {
60 f(peer.clone());
61 Ok(())
62 }))
63 }
64}
65
66type MessageCallbackFn<Msg> = dyn Fn(&PeerID, &Msg) -> Result<()> + Send + Sync;
67pub struct MessageCallback<Msg: UserMsgPayload>(Box<MessageCallbackFn<Msg>>);
68
69impl<Msg: UserMsgPayload> MessageCallback<Msg> {
70 pub fn from_fn<F>(f: F) -> Self
71 where
72 F: Fn(&PeerID, &Msg) -> Result<()> + Send + Sync + 'static,
73 {
74 Self(Box::new(f))
75 }
76
77 pub fn call(&self, peer: &PeerID, data: &Msg) -> Result<()> {
78 (self.0)(peer, data)
79 }
80}
81
82impl<Msg: UserMsgPayload + 'static> From<fn(PeerID, Msg)> for MessageCallback<Msg> {
83 fn from(f: fn(PeerID, Msg)) -> Self {
84 Self(Box::new(move |peer, data| {
85 f(peer.clone(), data.clone());
86 Ok(())
87 }))
88 }
89}
90
91pub enum Event<Msg: UserMsgPayload> {
92 Connected(NoArgCallback),
93 UserMessage(MessageCallback<Msg>),
94 Disconnected(NoArgCallback),
95 PeerConnected(PeerCallback),
96 PeerDisconnected(PeerCallback),
97 PeerLost(PeerCallback),
98 Available(NoArgCallback),
99 Unavailable(NoArgCallback),
100}
101
/// Payload-free discriminant shared by [`Event`] (subscriptions) and
/// [`EventType`] (emitted events); used as the key of the per-kind
/// subscription index.
///
/// `Debug` is derived so the kind can appear in assertions and diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum SubscriptionKind {
    Connected,
    UserMessage,
    Disconnected,
    PeerConnected,
    PeerDisconnected,
    /// Shared by `Event::PeerLost` and `EventType::PeerDropped`.
    PeerLost,
    Available,
    Unavailable,
}
113
114impl<Msg: UserMsgPayload> Event<Msg> {
115 fn kind(&self) -> SubscriptionKind {
116 match self {
117 Self::Connected(_) => SubscriptionKind::Connected,
118 Self::UserMessage(_) => SubscriptionKind::UserMessage,
119 Self::Disconnected(_) => SubscriptionKind::Disconnected,
120 Self::PeerConnected(_) => SubscriptionKind::PeerConnected,
121 Self::PeerDisconnected(_) => SubscriptionKind::PeerDisconnected,
122 Self::PeerLost(_) => SubscriptionKind::PeerLost,
123 Self::Available(_) => SubscriptionKind::Available,
124 Self::Unavailable(_) => SubscriptionKind::Unavailable,
125 }
126 }
127}
128
129fn event_kind<Msg: UserMsgPayload>(event: &EventType<Msg>) -> SubscriptionKind {
130 match event {
131 EventType::Connected => SubscriptionKind::Connected,
132 EventType::UserMessage(_, _) => SubscriptionKind::UserMessage,
133 EventType::Disconnected => SubscriptionKind::Disconnected,
134 EventType::PeerConnected(_) => SubscriptionKind::PeerConnected,
135 EventType::PeerDisconnected(_) => SubscriptionKind::PeerDisconnected,
136 EventType::PeerDropped(_) => SubscriptionKind::PeerLost,
137 EventType::Available => SubscriptionKind::Available,
138 EventType::Unavailable => SubscriptionKind::Unavailable,
139 }
140}
141
142pub struct RtcCallbacks<Msg: UserMsgPayload> {
143 next_callback_id: u64,
144 subscriptions: HashMap<u64, Event<Msg>>,
145 subscriptions_by_kind: HashMap<SubscriptionKind, HashSet<u64>>,
146}
147
148impl<Msg: UserMsgPayload> Default for RtcCallbacks<Msg> {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154impl<Msg: UserMsgPayload> RtcCallbacks<Msg> {
155 pub fn new() -> Self {
156 Self {
157 next_callback_id: 1,
158 subscriptions: HashMap::new(),
159 subscriptions_by_kind: HashMap::new(),
160 }
161 }
162
163 fn next_id(&mut self) -> u64 {
164 let id = self.next_callback_id;
165 self.next_callback_id += 1;
166 id
167 }
168
169 pub fn subscribe(&mut self, subscription: Event<Msg>) -> u64 {
170 let id = self.next_id();
171 let kind = subscription.kind();
172
173 self.subscriptions.insert(id, subscription);
174 self.subscriptions_by_kind
175 .entry(kind)
176 .or_default()
177 .insert(id);
178
179 id
180 }
181
182 pub fn unsubscribe(&mut self, id: u64) -> bool {
183 let Some(subscription) = self.subscriptions.remove(&id) else {
184 return false;
185 };
186
187 let kind = subscription.kind();
188 if let Some(ids) = self.subscriptions_by_kind.get_mut(&kind) {
189 ids.remove(&id);
190 if ids.is_empty() {
191 self.subscriptions_by_kind.remove(&kind);
192 }
193 }
194 true
195 }
196}
197
198impl<Msg: UserMsgPayload> RtcCallbacks<Msg> {
199 pub fn emit(&self, event: EventType<Msg>) -> Result<()> {
200 let kind = event_kind(&event);
201 let Some(ids) = self.subscriptions_by_kind.get(&kind) else {
202 return Ok(());
203 };
204
205 let ids: Vec<u64> = ids.iter().copied().collect();
206
207 for id in ids {
208 let Some(subscription) = self.subscriptions.get(&id) else {
209 continue;
210 };
211 match (subscription, &event) {
212 (Event::Connected(cb), EventType::Connected) => cb.call()?,
213 (Event::UserMessage(cb), EventType::UserMessage(peer, data)) => {
214 cb.call(peer, data)?
215 }
216 (Event::Disconnected(cb), EventType::Disconnected) => cb.call()?,
217 (Event::PeerConnected(cb), EventType::PeerConnected(peer)) => cb.call(peer)?,
218 (Event::PeerDisconnected(cb), EventType::PeerDisconnected(peer)) => {
219 cb.call(peer)?
220 }
221 (Event::PeerLost(cb), EventType::PeerDropped(peer)) => cb.call(peer)?,
222 (Event::Available(cb), EventType::Available) => cb.call()?,
223 (Event::Unavailable(cb), EventType::Unavailable) => cb.call()?,
224 _ => {}
225 }
226 }
227
228 Ok(())
229 }
230}