mecomp_core/
udp.rs

1//! Implementation for the UDP stack used by the server to broadcast events to clients
2
3use std::{
4    fmt::Debug,
5    marker::PhantomData,
6    net::{Ipv4Addr, SocketAddr},
7    time::Duration,
8};
9
10use mecomp_storage::db::schemas::RecordId;
11use object_pool::Pool;
12use serde::{Deserialize, Serialize, de::DeserializeOwned};
13use tokio::{net::UdpSocket, sync::RwLock};
14
15use crate::{
16    errors::UdpError,
17    state::{RepeatMode, Status},
18};
19
20pub type Result<T> = std::result::Result<T, UdpError>;
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
23pub enum Event {
24    LibraryRescanFinished,
25    LibraryAnalysisFinished,
26    LibraryReclusterFinished,
27    DaemonShutdown,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub enum StateChange {
32    /// The player has been muted
33    Muted,
34    /// The player has been unmuted
35    Unmuted,
36    /// The player volume has changed
37    VolumeChanged(f32),
38    /// The current track has changed
39    /// If something causes both the queue and current track to change, then this event is sent after `QueueChanged`
40    TrackChanged(Option<RecordId>),
41    /// The repeat mode has changed
42    RepeatModeChanged(RepeatMode),
43    /// Seeked to a new position in the track
44    Seeked(Duration),
45    /// Playback Status has changes
46    StatusChanged(Status),
47    /// The queue has changed, e.g. songs added/removed/reordered
48    QueueChanged,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52pub enum Message {
53    Event(Event),
54    StateChange(StateChange),
55}
56
57impl From<Event> for Message {
58    #[inline]
59    fn from(val: Event) -> Self {
60        Self::Event(val)
61    }
62}
63
/// Default size in bytes (1 KiB) of a listener's receive buffer and of the sender's
/// pre-allocated serialization buffers.
const MAX_MESSAGE_SIZE: usize = 1 << 10;
65
66#[derive(Debug)]
67pub struct Listener<T, const BUF_SIZE: usize> {
68    socket: UdpSocket,
69    buffer: [u8; BUF_SIZE],
70    message_type: PhantomData<T>,
71}
72
73impl<T: DeserializeOwned + Send + Sync> Listener<T, MAX_MESSAGE_SIZE> {
74    /// Create a new UDP listener bound to the given socket address.
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if the socket cannot be bound.
79    #[inline]
80    pub async fn new() -> Result<Self> {
81        Self::with_buffer_size().await
82    }
83}
84
85impl<T: DeserializeOwned + Send + Sync, const B: usize> Listener<T, B> {
86    /// Create a new UDP listener bound to the given socket address.
87    /// With a custom buffer size (set with const generics).
88    ///
89    /// # Errors
90    ///
91    /// Returns an error if the socket cannot be bound.
92    #[inline]
93    pub async fn with_buffer_size() -> Result<Self> {
94        let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await?;
95
96        Ok(Self {
97            socket,
98            buffer: [0; B],
99            message_type: PhantomData,
100        })
101    }
102
103    /// Get the socket address of the listener
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the socket address cannot be retrieved.
108    #[inline]
109    pub fn local_addr(&self) -> Result<SocketAddr> {
110        Ok(self.socket.local_addr()?)
111    }
112
113    /// Receive a message from the UDP socket.
114    /// Cancel safe.
115    ///
116    /// # Errors
117    ///
118    /// Returns an error if the message cannot be deserialized or received.
119    #[inline]
120    pub async fn recv(&mut self) -> Result<T> {
121        let (size, _) = self.socket.recv_from(&mut self.buffer).await?;
122        let message = ciborium::from_reader(&self.buffer[..size])?;
123
124        Ok(message)
125    }
126}
127
128pub struct Sender<T> {
129    socket: UdpSocket,
130    buffer_pool: Pool<Vec<u8>>,
131    /// List of subscribers to send messages to
132    subscribers: RwLock<Vec<SocketAddr>>,
133    message_type: PhantomData<T>,
134}
135
136impl<T> std::fmt::Debug for Sender<T> {
137    #[inline]
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        f.debug_struct("Sender")
140            .field("socket", &self.socket)
141            .field("subscribers", &self.subscribers)
142            .field("message_type", &self.message_type)
143            .field("buffer_pool.len", &self.buffer_pool.len())
144            .finish()
145    }
146}
147
148impl<T: Serialize + Send + Sync> Sender<T> {
149    /// Create a new UDP sender bound to an ephemeral port.
150    ///
151    /// # Errors
152    ///
153    /// Returns an error if the socket cannot be bound or connected.
154    #[inline]
155    pub async fn new() -> Result<Self> {
156        let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await?;
157
158        Ok(Self {
159            socket,
160            buffer_pool: Pool::new(1, || Vec::with_capacity(MAX_MESSAGE_SIZE)),
161            subscribers: RwLock::new(Vec::new()),
162            message_type: PhantomData,
163        })
164    }
165
166    /// Add a subscriber to the list of subscribers.
167    ///
168    /// # Concurrency
169    ///
170    /// Acquires a write lock on the subscribers list to do this.
171    #[inline]
172    pub async fn add_subscriber(&self, subscriber: SocketAddr) {
173        self.subscribers.write().await.push(subscriber);
174    }
175
176    /// Send a message to the UDP socket.
177    /// Cancel safe.
178    ///
179    /// # Concurrency
180    ///
181    /// Acquires a read lock on the subscribers list
182    ///
183    /// # Errors
184    ///
185    /// Returns an error if the message cannot be serialized or sent.
186    #[inline]
187    pub async fn send(&self, message: impl Into<T> + Send + Sync + Debug) -> Result<()> {
188        let subscribers = self.subscribers.read().await;
189        log::info!(
190            "Forwarding state change: {message:?} to {} subscribers",
191            subscribers.len()
192        );
193
194        let (pool, mut buffer) = self.buffer_pool.pull(Vec::new).detach();
195        buffer.clear();
196
197        ciborium::into_writer(&message.into(), &mut buffer)?;
198
199        for subscriber in subscribers.iter() {
200            self.socket.send_to(&buffer, subscriber).await?;
201        }
202        drop(subscribers);
203
204        pool.attach(buffer);
205
206        Ok(())
207    }
208}
209
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips each message through a sender to 1..=3 listeners; covers both
    /// `Event` and `StateChange` payloads.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[case(Message::StateChange(StateChange::Muted))]
    #[case(Message::StateChange(StateChange::VolumeChanged(0.5)))]
    #[case(Message::StateChange(StateChange::QueueChanged))]
    #[tokio::test]
    #[timeout(std::time::Duration::from_secs(1))]
    async fn test_udp(#[case] message: Message, #[values(1, 2, 3)] num_listeners: usize) {
        let sender = Sender::<Message>::new().await.unwrap();

        let mut listeners = Vec::with_capacity(num_listeners);

        for _ in 0..num_listeners {
            let listener = Listener::new().await.unwrap();
            sender.add_subscriber(listener.local_addr().unwrap()).await;
            listeners.push(listener);
        }

        sender.send(message.clone()).await.unwrap();

        for (i, listener) in listeners.iter_mut().enumerate() {
            let received_message: Message = listener.recv().await.unwrap();
            assert_eq!(received_message, message, "Listener {i}");
        }
    }

    /// Every message variant must encode to at most `MAX_MESSAGE_SIZE` bytes, otherwise the
    /// default listener buffer could not hold it.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[case(Message::StateChange(StateChange::Muted))]
    #[case(Message::StateChange(StateChange::Unmuted))]
    #[case(Message::StateChange(StateChange::VolumeChanged(1. / 3.)))]
    #[case(Message::StateChange(StateChange::TrackChanged(None)))]
    #[case(Message::StateChange(StateChange::RepeatModeChanged(RepeatMode::None)))]
    #[case(Message::StateChange(StateChange::Seeked(Duration::from_secs(3))))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Paused)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Playing)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Stopped)))]
    #[case(Message::StateChange(StateChange::QueueChanged))]
    #[case(Message::StateChange(StateChange::TrackChanged(Some(
        mecomp_storage::db::schemas::song::Song::generate_id().into()
    ))))]
    fn test_message_encoding_length(#[case] message: Message) {
        let mut buffer = Vec::new();
        ciborium::into_writer(&message, &mut buffer).unwrap();

        assert!(
            buffer.len() <= MAX_MESSAGE_SIZE,
            "{message:?} encoded to {} bytes (limit {MAX_MESSAGE_SIZE})",
            buffer.len()
        );
    }
}