instance_chart/
chart.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::net::IpAddr;
4use std::net::Ipv4Addr;
5use std::net::SocketAddr;
6use std::sync::Arc;
7use std::time::Duration;
8
9use serde::de::DeserializeOwned;
10use serde::Deserialize;
11use serde::Serialize;
12use serde_big_array::BigArray;
13use tokio::net::UdpSocket;
14use tokio::sync::broadcast;
15
16mod interval;
17use interval::Interval;
18use tracing::trace;
19
20mod notify;
21pub use notify::Notify;
22
23use crate::Id;
24mod builder;
25use builder::Port;
26
27pub use builder::ChartBuilder;
28
29pub mod get;
30pub mod to_vec;
31
32use self::interval::Until;
33
34#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
35pub struct DiscoveryMsg<const N: usize, T>
36where
37    T: Serialize + DeserializeOwned,
38{
39    header: u64,
40    id: Id,
41    #[serde(with = "BigArray")]
42    msg: [T; N],
43}
44
/// A chart entry representing a discovered node. The msg is an array of
/// ports or a custom struct if you used [`custom_msg`](ChartBuilder::custom_msg()).
///
/// In most cases you will want one of the [iterator methods](iter) instead of
/// working with `Entry` directly.
#[derive(Debug, Clone)]
pub struct Entry<Msg: Debug + Clone> {
    /// Address the node's discovery message was received from.
    pub ip: IpAddr,
    /// The ports or custom message the node announced.
    pub msg: Msg,
}
54
55/// The chart keeping track of the discoverd nodes. That a node appears in the
56/// chart is no guarentee that it is reachable at this moment.
57#[derive(Debug, Clone)]
58pub struct Chart<const N: usize, T: Debug + Clone + Serialize> {
59    header: u64,
60    service_id: Id,
61    msg: [T; N],
62    sock: Arc<UdpSocket>,
63    interval: Interval,
64    map: Arc<std::sync::Mutex<HashMap<Id, Entry<[T; N]>>>>,
65    broadcast: broadcast::Sender<(Id, Entry<[T; N]>)>,
66}
67
68impl<const N: usize, T: Serialize + Debug + Clone> Chart<N, T> {
69    fn insert(&self, id: Id, entry: Entry<[T; N]>) -> bool {
70        let old_key = {
71            let mut map = self.map.lock().unwrap();
72            map.insert(id, entry.clone())
73        };
74        if old_key.is_none() {
75            // errors if there are no active recievers which is
76            // the default and not a problem
77            let _ig_err = self.broadcast.send((id, entry));
78            true
79        } else {
80            false
81        }
82    }
83
84    #[tracing::instrument(skip(self, buf))]
85    fn process_buf<'de>(&self, buf: &'de [u8], addr: SocketAddr) -> bool
86    where
87        T: Serialize + DeserializeOwned + Debug,
88    {
89        let DiscoveryMsg::<N, T> { header, id, msg } = bincode::deserialize(buf).unwrap();
90        if header != self.header {
91            return false;
92        }
93        if id == self.service_id {
94            return false;
95        }
96        self.insert(id, Entry { ip: addr.ip(), msg })
97    }
98}
99
100/// The array of ports set for this chart instance, set in `ChartBuilder::with_service_ports`.
101impl<const N: usize> Chart<N, Port> {
102    #[must_use]
103    pub fn our_service_ports(&self) -> &[u16] {
104        &self.msg
105    }
106}
107
108/// The port set for this chart instance, set in `ChartBuilder::with_service_port`.
109impl Chart<1, Port> {
110    #[must_use]
111    pub fn our_service_port(&self) -> u16 {
112        self.msg[0]
113    }
114}
115
116/// The msg struct for this chart instance, set in `ChartBuilder::custom_msg`.
117impl<T: Debug + Clone + Serialize> Chart<1, T> {
118    #[must_use]
119    pub fn our_msg(&self) -> &T {
120        &self.msg[0]
121    }
122}
123
124impl<const N: usize, T: Debug + Clone + Serialize + DeserializeOwned> Chart<N, T> {
125    /// Wait for new discoveries. Use one of the methods on the [`notify object`](notify::Notify)
126    /// to _await_ a new discovery and get the data.
127    /// # Examples
128    /// ```rust
129    /// # use std::error::Error;
130    /// # use instance_chart::{discovery, ChartBuilder};
131    /// #
132    /// # #[tokio::main]
133    /// # async fn main() -> Result<(), Box<dyn Error>> {
134    /// # let full_size = 4u16;
135    /// # let handles: Vec<_> = (1..=full_size)
136    /// #     .into_iter()
137    /// #     .map(|id|
138    /// #         ChartBuilder::new()
139    /// #             .with_id(id.into())
140    /// #             .with_service_port(8042+id)
141    /// #             .with_discovery_port(8080)
142    /// #             .local_discovery(true)
143    /// #             .finish()
144    /// #             .unwrap()
145    /// #     )
146    /// #     .map(discovery::maintain)
147    /// #     .map(tokio::spawn)
148    /// #     .collect();
149    /// #
150    /// let chart = ChartBuilder::new()
151    ///     .with_id(1)
152    ///     .with_service_port(8042)
153    /// #   .with_discovery_port(8080)
154    ///     .local_discovery(true)
155    ///     .finish()?;
156    /// let mut node_discoverd = chart.notify();
157    /// let maintain = discovery::maintain(chart.clone());
158    /// let _ = tokio::spawn(maintain); // maintain task will run forever
159    ///
160    /// while chart.size() < full_size as usize {
161    ///     let new = node_discoverd.recv().await.unwrap();
162    ///     println!("discoverd new node: {:?}", new);
163    /// }
164    ///
165    /// #   Ok(())
166    /// # }
167    /// ```
168    #[must_use]
169    pub fn notify(&self) -> Notify<N, T> {
170        Notify(self.broadcast.subscribe())
171    }
172
173    /// forget a node removing it from the map. If it is discovered again notify 
174    /// subscribers will get a notification (again)
175    ///
176    /// # Note
177    /// This has no effect if the node has not yet been discoverd
178    #[allow(clippy::missing_panics_doc)] // ignore lock poisoning
179    pub fn forget(&self, id: Id) {
180        self.map.lock().unwrap().remove(&id);
181    }
182
183    /// number of instances discoverd including self
184    // lock poisoning happens only on crash in another thread, in which
185    // case panicing here is expected
186    #[allow(clippy::missing_panics_doc)] // ignore lock poisoning
187    #[must_use]
188    pub fn size(&self) -> usize {
189        self.map.lock().unwrap().len() + 1
190    }
191
192    /// The id set for this chart instance
193    #[must_use]
194    pub fn our_id(&self) -> Id {
195        self.service_id
196    }
197
198    /// The port this instance is using for discovery
199    #[allow(clippy::missing_panics_doc)] // socket is set during building
200    #[must_use]
201    pub fn discovery_port(&self) -> u16 {
202        self.sock.local_addr().unwrap().port()
203    }
204
205    #[must_use]
206    fn discovery_msg(&self) -> DiscoveryMsg<N, T> {
207        DiscoveryMsg {
208            header: self.header,
209            id: self.service_id,
210            msg: self.msg.clone(),
211        }
212    }
213
214    #[must_use]
215    fn discovery_buf(&self) -> Vec<u8> {
216        let msg = self.discovery_msg();
217        bincode::serialize(&msg).unwrap()
218    }
219
220    #[must_use]
221    fn broadcast_soon(&mut self) -> bool {
222        let next = self.interval.next();
223        next.until() < Duration::from_millis(100)
224    }
225}
226
227#[tracing::instrument]
228pub(crate) async fn handle_incoming<const N: usize, T>(mut chart: Chart<N, T>)
229where
230    T: Debug + Clone + Serialize + DeserializeOwned,
231{
232    loop {
233        let mut buf = [0; 1024];
234        let (_len, addr) = chart.sock.recv_from(&mut buf).await.unwrap();
235        trace!("got msg from: {addr:?}");
236        let was_uncharted = chart.process_buf(&buf, addr);
237        if was_uncharted && !chart.broadcast_soon() {
238            chart
239                .sock
240                .send_to(&chart.discovery_buf(), addr)
241                .await
242                .unwrap();
243        }
244    }
245}
246
247#[tracing::instrument]
248pub(crate) async fn broadcast_periodically<const N: usize, T>(
249    mut chart: Chart<N, T>,
250) where
251    T: Debug + Serialize + DeserializeOwned + Clone,
252{
253    loop {
254        trace!("sending discovery msg");
255        broadcast(&chart.sock, chart.discovery_port(), &chart.discovery_buf()).await;
256        chart.interval.sleep_till_next().await;
257    }
258}
259
260#[tracing::instrument]
261async fn broadcast(sock: &Arc<UdpSocket>, port: u16, msg: &[u8]) {
262    let multiaddr = Ipv4Addr::from([224, 0, 0, 251]);
263    let _len = sock
264        .send_to(msg, (multiaddr, port))
265        .await
266        .unwrap_or_else(|e| panic!("broadcast failed with port: {port}, error: {e:?}"));
267}