1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::net::IpAddr;
4use std::net::Ipv4Addr;
5use std::net::SocketAddr;
6use std::sync::Arc;
7use std::time::Duration;
8
9use serde::de::DeserializeOwned;
10use serde::Deserialize;
11use serde::Serialize;
12use serde_big_array::BigArray;
13use tokio::net::UdpSocket;
14use tokio::sync::broadcast;
15
16mod interval;
17use interval::Interval;
18use tracing::trace;
19
20mod notify;
21pub use notify::Notify;
22
23use crate::Id;
24mod builder;
25use builder::Port;
26
27pub use builder::ChartBuilder;
28
29pub mod get;
30pub mod to_vec;
31
32use self::interval::Until;
33
/// A single discovery datagram, as encoded with `bincode` by `Chart::discovery_buf`
/// and decoded by `Chart::process_buf`.
///
/// NOTE: bincode writes fields in declaration order without field names, so
/// reordering or retyping fields changes the wire format and breaks
/// interoperability with peers running the old layout.
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub struct DiscoveryMsg<const N: usize, T>
where
    // Likely needed because serde does not infer `T` bounds for fields that use
    // `#[serde(with = ...)]`; the derived impls pick up this where clause instead.
    T: Serialize + DeserializeOwned,
{
    /// Must equal the receiving chart's `header`, otherwise the datagram is ignored.
    header: u64,
    /// Id of the sending service; a chart ignores datagrams carrying its own id.
    id: Id,
    /// The sender's payload (its service ports when `T` is `Port`).
    // serde's built-in array impls only cover small fixed lengths, hence `BigArray`
    // to support arbitrary `N`.
    #[serde(with = "BigArray")]
    msg: [T; N],
}
44
/// A remote service as recorded in a [`Chart`].
///
/// The stored entry is replaced each time a datagram from the same service id
/// is accepted, so it reflects the most recently received message.
#[derive(Debug, Clone)]
pub struct Entry<Msg: Debug + Clone> {
    /// Source IP address of the datagram this entry was built from.
    pub ip: IpAddr,
    /// The payload the remote service sent in its discovery message.
    pub msg: Msg,
}
54
/// Registry of services discovered over UDP multicast, together with the
/// discovery message this service announces about itself.
///
/// Clones share the socket, the map of known services (both behind `Arc`) and
/// the broadcast channel (the `Sender` is a handle to the same channel).
/// `msg` is copied on clone.
/// NOTE(review): `interval` is cloned by value; confirm clones stay on the same
/// broadcast schedule.
#[derive(Debug, Clone)]
pub struct Chart<const N: usize, T: Debug + Clone + Serialize> {
    /// Incoming messages whose `header` differs from this are ignored.
    header: u64,
    /// Id of this service; incoming messages carrying it are ignored.
    service_id: Id,
    /// This service's own payload, sent in every discovery message.
    msg: [T; N],
    /// Socket that receives discovery messages and sends replies and broadcasts.
    sock: Arc<UdpSocket>,
    /// Schedule of the periodic multicast broadcasts.
    interval: Interval,
    /// Other known services, keyed by id. `size` adds one for this service, so
    /// presumably this service itself is never stored here.
    map: Arc<std::sync::Mutex<HashMap<Id, Entry<[T; N]>>>>,
    /// Carries `(id, entry)` whenever a previously unknown service is inserted.
    broadcast: broadcast::Sender<(Id, Entry<[T; N]>)>,
}
67
68impl<const N: usize, T: Serialize + Debug + Clone> Chart<N, T> {
69 fn insert(&self, id: Id, entry: Entry<[T; N]>) -> bool {
70 let old_key = {
71 let mut map = self.map.lock().unwrap();
72 map.insert(id, entry.clone())
73 };
74 if old_key.is_none() {
75 let _ig_err = self.broadcast.send((id, entry));
78 true
79 } else {
80 false
81 }
82 }
83
84 #[tracing::instrument(skip(self, buf))]
85 fn process_buf<'de>(&self, buf: &'de [u8], addr: SocketAddr) -> bool
86 where
87 T: Serialize + DeserializeOwned + Debug,
88 {
89 let DiscoveryMsg::<N, T> { header, id, msg } = bincode::deserialize(buf).unwrap();
90 if header != self.header {
91 return false;
92 }
93 if id == self.service_id {
94 return false;
95 }
96 self.insert(id, Entry { ip: addr.ip(), msg })
97 }
98}
99
100impl<const N: usize> Chart<N, Port> {
102 #[must_use]
103 pub fn our_service_ports(&self) -> &[u16] {
104 &self.msg
105 }
106}
107
108impl Chart<1, Port> {
110 #[must_use]
111 pub fn our_service_port(&self) -> u16 {
112 self.msg[0]
113 }
114}
115
116impl<T: Debug + Clone + Serialize> Chart<1, T> {
118 #[must_use]
119 pub fn our_msg(&self) -> &T {
120 &self.msg[0]
121 }
122}
123
124impl<const N: usize, T: Debug + Clone + Serialize + DeserializeOwned> Chart<N, T> {
125 #[must_use]
169 pub fn notify(&self) -> Notify<N, T> {
170 Notify(self.broadcast.subscribe())
171 }
172
173 #[allow(clippy::missing_panics_doc)] pub fn forget(&self, id: Id) {
180 self.map.lock().unwrap().remove(&id);
181 }
182
183 #[allow(clippy::missing_panics_doc)] #[must_use]
188 pub fn size(&self) -> usize {
189 self.map.lock().unwrap().len() + 1
190 }
191
192 #[must_use]
194 pub fn our_id(&self) -> Id {
195 self.service_id
196 }
197
198 #[allow(clippy::missing_panics_doc)] #[must_use]
201 pub fn discovery_port(&self) -> u16 {
202 self.sock.local_addr().unwrap().port()
203 }
204
205 #[must_use]
206 fn discovery_msg(&self) -> DiscoveryMsg<N, T> {
207 DiscoveryMsg {
208 header: self.header,
209 id: self.service_id,
210 msg: self.msg.clone(),
211 }
212 }
213
214 #[must_use]
215 fn discovery_buf(&self) -> Vec<u8> {
216 let msg = self.discovery_msg();
217 bincode::serialize(&msg).unwrap()
218 }
219
220 #[must_use]
221 fn broadcast_soon(&mut self) -> bool {
222 let next = self.interval.next();
223 next.until() < Duration::from_millis(100)
224 }
225}
226
227#[tracing::instrument]
228pub(crate) async fn handle_incoming<const N: usize, T>(mut chart: Chart<N, T>)
229where
230 T: Debug + Clone + Serialize + DeserializeOwned,
231{
232 loop {
233 let mut buf = [0; 1024];
234 let (_len, addr) = chart.sock.recv_from(&mut buf).await.unwrap();
235 trace!("got msg from: {addr:?}");
236 let was_uncharted = chart.process_buf(&buf, addr);
237 if was_uncharted && !chart.broadcast_soon() {
238 chart
239 .sock
240 .send_to(&chart.discovery_buf(), addr)
241 .await
242 .unwrap();
243 }
244 }
245}
246
247#[tracing::instrument]
248pub(crate) async fn broadcast_periodically<const N: usize, T>(
249 mut chart: Chart<N, T>,
250) where
251 T: Debug + Serialize + DeserializeOwned + Clone,
252{
253 loop {
254 trace!("sending discovery msg");
255 broadcast(&chart.sock, chart.discovery_port(), &chart.discovery_buf()).await;
256 chart.interval.sleep_till_next().await;
257 }
258}
259
260#[tracing::instrument]
261async fn broadcast(sock: &Arc<UdpSocket>, port: u16, msg: &[u8]) {
262 let multiaddr = Ipv4Addr::from([224, 0, 0, 251]);
263 let _len = sock
264 .send_to(msg, (multiaddr, port))
265 .await
266 .unwrap_or_else(|e| panic!("broadcast failed with port: {port}, error: {e:?}"));
267}