Skip to main content

fs_share_utils/broadcast/
receiver.rs

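//! # UDP Broadcast Receiver
//!
//! This module provides a lightweight UDP-based broadcast receiver
//! used for service discovery in local networks. Incoming datagrams are
//! filtered by a byte prefix and their payloads are decoded with [`PayloadReader`].
//!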
6use std::{
7    collections::HashMap,
8    net::{IpAddr, Ipv4Addr},
9    net::{SocketAddr, UdpSocket},
10    num::NonZero,
11    sync::mpsc::{self, Receiver},
12    thread::{self, JoinHandle},
13    time::Duration,
14};
15
16use anyhow::Context;
17
/// UDP listener that only accepts datagrams beginning with a fixed prefix.
///
/// Obtain one through [`BroadcastReceiver::builder`], then call
/// [`BroadcastReceiver::start`] to run the receive loop on a background thread.
pub struct BroadcastReceiver {
    /// Socket bound to the local address the receive loop reads from.
    socket: UdpSocket,

    /// Leading bytes that identify a datagram as one of ours.
    prefix: Vec<u8>,

    /// Scratch space handed to every `recv_from` call.
    buffer: Box<[u8]>,
}
32
33impl BroadcastReceiver {
34    /// Create a new builder for configuring [`BroadcastReceiver`]
35    pub fn builder() -> BroadcastReceiverBuilder {
36        BroadcastReceiverBuilder::default()
37    }
38}
39
/// Iterator over the length-prefixed fields of a structured payload.
///
/// Payload format:
/// ```text
/// :<len:u16><bytes>
/// :<len:u16><bytes>
/// ...
/// ```
///
/// Each field starts with a `:` marker, followed by a 2-byte big-endian
/// length and then exactly that many bytes of data.
///
/// Iteration ends at the end of the buffer or at the first malformed field
/// (missing `:` marker, truncated length, or data running past the buffer).
/// Once `next` has returned `None` it keeps returning `None`, so bytes that
/// follow a corrupted field are never misread as a new field.
#[derive(Debug, Clone)]
pub struct PayloadReader<'a> {
    /// Raw payload bytes being decoded.
    buf: &'a [u8],
    /// Offset of the next unread byte; always `<= buf.len()`.
    pos: usize,
}

impl<'a> PayloadReader<'a> {
    /// Create a new payload reader positioned at the start of `buf`.
    pub fn new(buf: &'a [u8]) -> Self {
        Self { buf, pos: 0 }
    }

    /// Decode the field starting at `self.pos` without advancing.
    ///
    /// Returns the field's data and the offset just past it, or `None` when
    /// the reader is at the end of the buffer or the bytes there are malformed.
    fn parse_field(&self) -> Option<(&'a [u8], usize)> {
        let buf: &'a [u8] = self.buf;

        // Expect field marker ':'
        if buf.get(self.pos) != Some(&b':') {
            return None;
        }

        // 2-byte big-endian length right after the marker; `get` yields a
        // slice of exactly 2 bytes or `None`, so the indexing below is safe.
        let len_start = self.pos + 1;
        let len_bytes = buf.get(len_start..len_start + 2)?;
        let len = usize::from(u16::from_be_bytes([len_bytes[0], len_bytes[1]]));

        // Field data; `get` rejects a length that runs past the buffer.
        let data_start = len_start + 2;
        let data_end = data_start + len;
        let data = buf.get(data_start..data_end)?;

        Some((data, data_end))
    }
}

impl<'a> Iterator for PayloadReader<'a> {
    type Item = &'a [u8];

    fn next(&mut self) -> Option<Self::Item> {
        let Some((field, next_pos)) = self.parse_field() else {
            // End of buffer or malformed field: park at the end so the
            // iterator stays finished instead of resynchronising on bytes
            // from inside a corrupted field.
            self.pos = self.buf.len();
            return None;
        };
        self.pos = next_pos;
        Some(field)
    }
}

// `next` moves `pos` to the end whenever it returns `None`, so every later
// call returns `None` too.
impl std::iter::FusedIterator for PayloadReader<'_> {}
99
100impl BroadcastReceiver {
101    /// Start receiving broadcast packets in a background thread.
102    ///
103    /// Returns:
104    /// - Stop function
105    /// - Data receiver channel
106    /// - Thread handle
107    ///
108    /// ## Type Parameter
109    ///
110    /// `U` is a user-defined type that can be constructed from:
111    /// ```text
112    /// (SocketAddr, PayloadReader)
113    /// ```
114    ///
115    /// This allows flexible decoding of incoming packets.
116    ///
117    /// ## Behavior
118    ///
119    /// - Deduplicates data per sender (`SocketAddr`)
120    /// - Sends only new or changed data
121    /// - Ignores invalid payloads silently
122    pub fn start<U>(
123        self,
124    ) -> (
125        Box<dyn FnOnce() + Send>,
126        Receiver<(SocketAddr, U)>,
127        JoinHandle<()>,
128    )
129    where
130        U: for<'a> TryFrom<(SocketAddr, PayloadReader<'a>)>,
131        U: Clone + PartialEq + Send + 'static,
132    {
133        let (data_tx, data_rx) = mpsc::channel();
134        let (stop_tx, stop_rx) = mpsc::channel();
135
136        let handle = thread::spawn(move || {
137            let mut this = self;
138
139            // Track last seen data per sender
140            let mut seen: HashMap<SocketAddr, U> = HashMap::new();
141
142            loop {
143                // Check stop signal
144                if stop_rx.try_recv().is_ok() {
145                    break;
146                }
147
148                match this.socket.recv_from(&mut this.buffer) {
149                    Ok((size, addr)) => {
150                        // Validate prefix
151                        if this.buffer.starts_with(&this.prefix) {
152                            let payload = &this.buffer[this.prefix.len()..size];
153                            let reader = PayloadReader::new(payload);
154
155                            match U::try_from((addr, reader)) {
156                                Ok(data) => {
157                                    // Deduplicate
158                                    let is_new_or_changed = match seen.get(&addr) {
159                                        Some(old) => old != &data,
160                                        None => true,
161                                    };
162
163                                    if is_new_or_changed {
164                                        seen.insert(addr, data.clone());
165                                        let _ = data_tx.send((addr, data));
166                                    }
167                                }
168                                Err(_) => continue, // Ignore invalid payload
169                            }
170                        }
171                    }
172                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
173                    Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => continue,
174                    Err(e) => {
175                        eprintln!("Receive error: {}", e);
176                        break;
177                    }
178                }
179            }
180        });
181
182        let stop = Box::new(move || {
183            let _ = stop_tx.send(());
184        });
185
186        (stop, data_rx, handle)
187    }
188}
189
/// Builder for [`BroadcastReceiver`].
///
/// Defaults: bind to `0.0.0.0:7755`, empty prefix (every datagram matches),
/// 8 KiB payload buffer and a 300 ms socket read timeout.
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct BroadcastReceiverBuilder {
    /// Bytes every accepted datagram must start with.
    prefix: Vec<u8>,
    /// Socket read timeout; `None` means reads block indefinitely.
    timeout: Option<Duration>,
    /// Payload buffer size in bytes; `None` makes `build` fail.
    buffer_size: Option<NonZero<usize>>,
    /// Local address the UDP socket is bound to.
    bind_addr: SocketAddr,
}

impl Default for BroadcastReceiverBuilder {
    fn default() -> Self {
        Self {
            prefix: Vec::new(),
            timeout: Some(Duration::from_millis(300)),
            buffer_size: NonZero::new(8 * 1024), // 8 KiB
            bind_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 7755),
        }
    }
}
209
210impl BroadcastReceiverBuilder {
211    /// Set packet prefix for filtering
212    pub fn prefix<T: Into<Vec<u8>>>(mut self, value: T) -> Self {
213        self.prefix = value.into();
214        self
215    }
216
217    /// Set internal buffer size
218    pub fn buffer_size(mut self, size: usize) -> Self {
219        self.buffer_size = NonZero::new(size);
220        self
221    }
222
223    /// Set socket bind address
224    pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
225        self.bind_addr = addr;
226        self
227    }
228
229    /// Build [`BroadcastReceiver`]
230    pub fn build(self) -> anyhow::Result<BroadcastReceiver> {
231        let buffer_size = self.buffer_size.context("Buffer size is not set")?.get();
232
233        let buffer = vec![0u8; buffer_size + self.prefix.len()].into_boxed_slice();
234
235        let socket = UdpSocket::bind(self.bind_addr)
236            .with_context(|| format!("Failed to bind UDP socket on {}", self.bind_addr))?;
237
238        socket.set_read_timeout(self.timeout).with_context(|| {
239            format!(
240                "Failed to set read timeout {:?} on {}",
241                self.timeout, self.bind_addr
242            )
243        })?;
244
245        Ok(BroadcastReceiver {
246            prefix: self.prefix,
247            buffer,
248            socket,
249        })
250    }
251}