mecomp_core/
udp.rs

1//! Implementation for the UDP stack used by the server to broadcast events to clients
2
3use std::{
4    fmt::Debug,
5    marker::PhantomData,
6    net::{Ipv4Addr, SocketAddr},
7    time::Duration,
8};
9
10use mecomp_storage::db::schemas::Thing;
11use object_pool::Pool;
12use serde::{de::DeserializeOwned, Deserialize, Serialize};
13use tokio::net::UdpSocket;
14
15use crate::{
16    errors::UdpError,
17    state::{RepeatMode, Status},
18};
19
/// Result alias used throughout this module; the error type is always [`UdpError`].
pub type Result<T> = std::result::Result<T, UdpError>;
21
/// One-off notifications that carry no payload, transported inside [`Message::Event`].
///
/// NOTE(review): the per-variant descriptions are inferred from the variant names;
/// confirm against the code that emits them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum Event {
    /// A library rescan has finished.
    LibraryRescanFinished,
    /// A library analysis has finished.
    LibraryAnalysisFinished,
    /// A library recluster has finished.
    LibraryReclusterFinished,
    /// The daemon is shutting down.
    DaemonShutdown,
}
29
/// Notifications about a change in player state, transported inside [`Message::StateChange`].
///
/// NOTE(review): payloads presumably carry the *new* value (new volume, new track, ...);
/// confirm against the code that emits them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum StateChange {
    /// The player has been muted
    Muted,
    /// The player has been unmuted
    Unmuted,
    /// The player volume has changed
    VolumeChanged(f32),
    /// The current track has changed
    TrackChanged(Option<Thing>),
    /// The repeat mode has changed
    RepeatModeChanged(RepeatMode),
    /// Seeked to a new position in the track
    Seeked(Duration),
    /// The playback status has changed
    StatusChanged(Status),
}
47
/// Top-level message type: either an [`Event`] or a [`StateChange`].
///
/// The tests in this module send and receive it through `Sender<Message>` and `Listener`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Message {
    /// A payload-free notification.
    Event(Event),
    /// A change in player state.
    StateChange(StateChange),
}
53
54impl From<Event> for Message {
55    fn from(val: Event) -> Self {
56        Self::Event(val)
57    }
58}
59
/// Size budget, in bytes, for one serialized message.
///
/// Used as the receive-buffer size of the default [`Listener`] and as the initial
/// capacity of the buffer pooled by [`Sender`]; the tests assert that every encoded
/// `Message` case fits within it.
const MAX_MESSAGE_SIZE: usize = 1024;
61
/// Receiving end of the UDP stack: owns a socket and decodes incoming datagrams into `T`.
///
/// `BUF_SIZE` is the size in bytes of the receive buffer.
#[derive(Debug)]
pub struct Listener<T, const BUF_SIZE: usize> {
    /// Socket the datagrams are received on.
    socket: UdpSocket,
    /// Scratch buffer each datagram is received into before it is deserialized.
    buffer: [u8; BUF_SIZE],
    /// Marker tying the listener to the message type it decodes; stores no data.
    message_type: PhantomData<T>,
}
68
69impl<T: DeserializeOwned + Send + Sync> Listener<T, MAX_MESSAGE_SIZE> {
70    /// Create a new UDP listener bound to the given socket address.
71    ///
72    /// # Errors
73    ///
74    /// Returns an error if the socket cannot be bound.
75    pub async fn new() -> Result<Self> {
76        Self::with_buffer_size().await
77    }
78}
79
80impl<T: DeserializeOwned + Send + Sync, const B: usize> Listener<T, B> {
81    /// Create a new UDP listener bound to the given socket address.
82    /// With a custom buffer size (set with const generics).
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if the socket cannot be bound.
87    pub async fn with_buffer_size() -> Result<Self> {
88        let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await?;
89
90        Ok(Self {
91            socket,
92            buffer: [0; B],
93            message_type: PhantomData,
94        })
95    }
96
97    /// Get the socket address of the listener
98    ///
99    /// # Errors
100    ///
101    /// Returns an error if the socket address cannot be retrieved.
102    pub fn local_addr(&self) -> Result<SocketAddr> {
103        Ok(self.socket.local_addr()?)
104    }
105
106    /// Receive a message from the UDP socket.
107    /// Cancel safe.
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if the message cannot be deserialized or received.
112    pub async fn recv(&mut self) -> Result<T> {
113        let (size, _) = self.socket.recv_from(&mut self.buffer).await?;
114        let message = ciborium::from_reader(&self.buffer[..size])?;
115
116        Ok(message)
117    }
118}
119
/// Sending end of the UDP stack: serializes messages of type `T` and sends them to
/// every registered subscriber.
pub struct Sender<T> {
    /// Socket the datagrams are sent from.
    socket: UdpSocket,
    /// Pool of reusable serialization buffers.
    buffer_pool: Pool<Vec<u8>>,
    /// List of subscribers to send messages to
    subscribers: Vec<SocketAddr>,
    /// Marker tying the sender to the message type it encodes; stores no data.
    message_type: PhantomData<T>,
}
127
128impl<T> std::fmt::Debug for Sender<T> {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        f.debug_struct("Sender")
131            .field("socket", &self.socket)
132            .field("subscribers", &self.subscribers)
133            .field("message_type", &self.message_type)
134            .field("buffer_pool.len", &self.buffer_pool.len())
135            .finish()
136    }
137}
138
139impl<T: Serialize + Send + Sync> Sender<T> {
140    /// Create a new UDP sender bound to an ephemeral port.
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if the socket cannot be bound or connected.
145    pub async fn new() -> Result<Self> {
146        let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await?;
147
148        Ok(Self {
149            socket,
150            buffer_pool: Pool::new(1, || Vec::with_capacity(MAX_MESSAGE_SIZE)),
151            subscribers: Vec::new(),
152            message_type: PhantomData,
153        })
154    }
155
156    /// Add a subscriber to the list of subscribers.
157    pub fn add_subscriber(&mut self, subscriber: SocketAddr) {
158        self.subscribers.push(subscriber);
159    }
160
161    /// Send a message to the UDP socket.
162    /// Cancel safe.
163    ///
164    /// # Errors
165    ///
166    /// Returns an error if the message cannot be serialized or sent.
167    pub async fn send(&self, message: impl Into<T> + Send + Sync + Debug) -> Result<()> {
168        log::info!(
169            "Forwarding state change: {message:?} to {} subscribers",
170            self.subscribers.len()
171        );
172
173        let (pool, mut buffer) = self.buffer_pool.pull(Vec::new).detach();
174        buffer.clear();
175
176        ciborium::into_writer(&message.into(), &mut buffer)?;
177
178        for subscriber in &self.subscribers {
179            self.socket.send_to(&buffer, subscriber).await?;
180        }
181
182        pool.attach(buffer);
183
184        Ok(())
185    }
186}
187
#[cfg(test)]
mod test {
    use super::*;

    // Sends one message from a `Sender` to 1, 2 or 3 listeners and checks that
    // every listener receives an identical copy.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[tokio::test]
    #[timeout(std::time::Duration::from_secs(1))]
    async fn test_udp(#[case] message: Message, #[values(1, 2, 3)] num_listeners: usize) {
        let mut sender = Sender::<Message>::new().await.unwrap();

        let mut listeners = Vec::new();

        for _ in 0..num_listeners {
            let listener = Listener::new().await.unwrap();
            sender.add_subscriber(listener.local_addr().unwrap());
            listeners.push(listener);
        }

        sender.send(message.clone()).await.unwrap();

        for (i, listener) in listeners.iter_mut().enumerate() {
            let received_message: Message = listener.recv().await.unwrap();
            assert_eq!(received_message, message, "Listener {i}");
        }
    }

    // Checks that the encoded form of each listed message fits into the
    // `MAX_MESSAGE_SIZE` receive buffer of the default listener.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[case(Message::StateChange(StateChange::Muted))]
    #[case(Message::StateChange(StateChange::Unmuted))]
    #[case(Message::StateChange(StateChange::VolumeChanged(1. / 3.)))]
    #[case(Message::StateChange(StateChange::TrackChanged(None)))]
    #[case(Message::StateChange(StateChange::RepeatModeChanged(RepeatMode::None)))]
    #[case(Message::StateChange(StateChange::Seeked(Duration::from_secs(3))))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Paused)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Playing)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Stopped)))]
    #[case(Message::StateChange(StateChange::TrackChanged(Some(
        mecomp_storage::db::schemas::song::Song::generate_id().into()
    ))))]
    fn test_message_encoding_length(#[case] message: Message) {
        let mut buffer = Vec::new();
        ciborium::into_writer(&message, &mut buffer).unwrap();

        assert!(
            buffer.len() <= MAX_MESSAGE_SIZE,
            "encoded message is {} bytes, limit is {MAX_MESSAGE_SIZE}",
            buffer.len()
        );
    }
}