mecomp_core/
udp.rs

1//! Implementation for the UDP stack used by the server to broadcast events to clients
2
3use std::{
4    fmt::Debug,
5    marker::PhantomData,
6    net::{Ipv4Addr, SocketAddr},
7    time::Duration,
8};
9
10use mecomp_storage::db::schemas::RecordId;
11use object_pool::Pool;
12use serde::{Deserialize, Serialize, de::DeserializeOwned};
13use tokio::net::UdpSocket;
14
15use crate::{
16    errors::UdpError,
17    state::{RepeatMode, Status},
18};
19
20pub type Result<T> = std::result::Result<T, UdpError>;
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
23pub enum Event {
24    LibraryRescanFinished,
25    LibraryAnalysisFinished,
26    LibraryReclusterFinished,
27    DaemonShutdown,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub enum StateChange {
32    /// The player has been muted
33    Muted,
34    /// The player has been unmuted
35    Unmuted,
36    /// The player volume has changed
37    VolumeChanged(f32),
38    /// The current track has changed
39    TrackChanged(Option<RecordId>),
40    /// The repeat mode has changed
41    RepeatModeChanged(RepeatMode),
42    /// Seeked to a new position in the track
43    Seeked(Duration),
44    /// Playback Status has changes
45    StatusChanged(Status),
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49pub enum Message {
50    Event(Event),
51    StateChange(StateChange),
52}
53
54impl From<Event> for Message {
55    #[inline]
56    fn from(val: Event) -> Self {
57        Self::Event(val)
58    }
59}
60
/// Default size, in bytes, of a [`Listener`]'s receive buffer and the initial
/// capacity of the buffers pooled by a [`Sender`].
const MAX_MESSAGE_SIZE: usize = 1024;
62
63#[derive(Debug)]
64pub struct Listener<T, const BUF_SIZE: usize> {
65    socket: UdpSocket,
66    buffer: [u8; BUF_SIZE],
67    message_type: PhantomData<T>,
68}
69
70impl<T: DeserializeOwned + Send + Sync> Listener<T, MAX_MESSAGE_SIZE> {
71    /// Create a new UDP listener bound to the given socket address.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the socket cannot be bound.
76    #[inline]
77    pub async fn new() -> Result<Self> {
78        Self::with_buffer_size().await
79    }
80}
81
82impl<T: DeserializeOwned + Send + Sync, const B: usize> Listener<T, B> {
83    /// Create a new UDP listener bound to the given socket address.
84    /// With a custom buffer size (set with const generics).
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if the socket cannot be bound.
89    #[inline]
90    pub async fn with_buffer_size() -> Result<Self> {
91        let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await?;
92
93        Ok(Self {
94            socket,
95            buffer: [0; B],
96            message_type: PhantomData,
97        })
98    }
99
100    /// Get the socket address of the listener
101    ///
102    /// # Errors
103    ///
104    /// Returns an error if the socket address cannot be retrieved.
105    #[inline]
106    pub fn local_addr(&self) -> Result<SocketAddr> {
107        Ok(self.socket.local_addr()?)
108    }
109
110    /// Receive a message from the UDP socket.
111    /// Cancel safe.
112    ///
113    /// # Errors
114    ///
115    /// Returns an error if the message cannot be deserialized or received.
116    #[inline]
117    pub async fn recv(&mut self) -> Result<T> {
118        let (size, _) = self.socket.recv_from(&mut self.buffer).await?;
119        let message = ciborium::from_reader(&self.buffer[..size])?;
120
121        Ok(message)
122    }
123}
124
125pub struct Sender<T> {
126    socket: UdpSocket,
127    buffer_pool: Pool<Vec<u8>>,
128    /// List of subscribers to send messages to
129    subscribers: Vec<SocketAddr>,
130    message_type: PhantomData<T>,
131}
132
133impl<T> std::fmt::Debug for Sender<T> {
134    #[inline]
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        f.debug_struct("Sender")
137            .field("socket", &self.socket)
138            .field("subscribers", &self.subscribers)
139            .field("message_type", &self.message_type)
140            .field("buffer_pool.len", &self.buffer_pool.len())
141            .finish()
142    }
143}
144
145impl<T: Serialize + Send + Sync> Sender<T> {
146    /// Create a new UDP sender bound to an ephemeral port.
147    ///
148    /// # Errors
149    ///
150    /// Returns an error if the socket cannot be bound or connected.
151    #[inline]
152    pub async fn new() -> Result<Self> {
153        let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await?;
154
155        Ok(Self {
156            socket,
157            buffer_pool: Pool::new(1, || Vec::with_capacity(MAX_MESSAGE_SIZE)),
158            subscribers: Vec::new(),
159            message_type: PhantomData,
160        })
161    }
162
163    /// Add a subscriber to the list of subscribers.
164    #[inline]
165    pub fn add_subscriber(&mut self, subscriber: SocketAddr) {
166        self.subscribers.push(subscriber);
167    }
168
169    /// Send a message to the UDP socket.
170    /// Cancel safe.
171    ///
172    /// # Errors
173    ///
174    /// Returns an error if the message cannot be serialized or sent.
175    #[inline]
176    pub async fn send(&self, message: impl Into<T> + Send + Sync + Debug) -> Result<()> {
177        log::info!(
178            "Forwarding state change: {message:?} to {} subscribers",
179            self.subscribers.len()
180        );
181
182        let (pool, mut buffer) = self.buffer_pool.pull(Vec::new).detach();
183        buffer.clear();
184
185        ciborium::into_writer(&message.into(), &mut buffer)?;
186
187        for subscriber in &self.subscribers {
188            self.socket.send_to(&buffer, subscriber).await?;
189        }
190
191        pool.attach(buffer);
192
193        Ok(())
194    }
195}
196
#[cfg(test)]
mod test {
    use super::*;

    // Round-trip: one sender broadcasts a message to 1, 2 or 3 listeners and
    // every listener must decode exactly the message that was sent.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[tokio::test]
    #[timeout(std::time::Duration::from_secs(1))]
    async fn test_udp(#[case] message: Message, #[values(1, 2, 3)] num_listeners: usize) {
        let mut sender = Sender::<Message>::new().await.unwrap();

        let mut listeners = Vec::with_capacity(num_listeners);

        for _ in 0..num_listeners {
            let listener = Listener::new().await.unwrap();
            let address = listener.local_addr().unwrap();
            sender.add_subscriber(address);
            listeners.push(listener);
        }

        sender.send(message.clone()).await.unwrap();

        for (index, listener) in listeners.iter_mut().enumerate() {
            let received: Message = listener.recv().await.unwrap();
            assert_eq!(received, message, "Listener {index}");
        }
    }

    // Every message kind must serialize to at most `MAX_MESSAGE_SIZE` bytes,
    // the size of the default `Listener` receive buffer.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[case(Message::StateChange(StateChange::Muted))]
    #[case(Message::StateChange(StateChange::Unmuted))]
    #[case(Message::StateChange(StateChange::VolumeChanged(1. / 3.)))]
    #[case(Message::StateChange(StateChange::TrackChanged(None)))]
    #[case(Message::StateChange(StateChange::RepeatModeChanged(RepeatMode::None)))]
    #[case(Message::StateChange(StateChange::Seeked(Duration::from_secs(3))))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Paused)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Playing)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Stopped)))]
    #[case(Message::StateChange(StateChange::TrackChanged(Some(
        mecomp_storage::db::schemas::song::Song::generate_id().into()
    ))))]
    fn test_message_encoding_length(#[case] message: Message) {
        let mut encoded = Vec::new();
        ciborium::into_writer(&message, &mut encoded).unwrap();

        assert!(encoded.len() <= MAX_MESSAGE_SIZE);
    }
}