1use std::{
4 fmt::Debug,
5 marker::PhantomData,
6 net::{Ipv4Addr, SocketAddr},
7 time::Duration,
8};
9
10use mecomp_storage::db::schemas::RecordId;
11use object_pool::Pool;
12use serde::{Deserialize, Serialize, de::DeserializeOwned};
13use tokio::{net::UdpSocket, sync::RwLock};
14
15use crate::{
16 errors::UdpError,
17 state::{RepeatMode, Status},
18};
19
20pub type Result<T> = std::result::Result<T, UdpError>;
21
/// Notifications that can be broadcast to subscribers, wrapped in [`Message::Event`].
///
/// The per-variant descriptions below are read off the variant names; the code that emits
/// these events is not part of this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum Event {
    /// A library rescan has finished.
    LibraryRescanFinished,
    /// A library analysis has finished.
    LibraryAnalysisFinished,
    /// A library recluster has finished.
    LibraryReclusterFinished,
    /// The daemon is shutting down.
    DaemonShutdown,
}
29
/// Changes in playback state that can be broadcast to subscribers, wrapped in
/// [`Message::StateChange`].
///
/// Variant meanings are read off their names; the emitting code is not in this file.
/// Only `PartialEq` (not `Eq`) is derived because [`StateChange::VolumeChanged`] carries an `f32`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum StateChange {
    /// Audio was muted.
    Muted,
    /// Audio was unmuted.
    Unmuted,
    /// The volume changed; carries the new volume.
    /// NOTE(review): the scale/range of the value is not established here — confirm.
    VolumeChanged(f32),
    /// The current track changed; carries the new track's id, or `None`
    /// (presumably when there is no current track — confirm).
    TrackChanged(Option<RecordId>),
    /// The repeat mode changed; carries the new mode.
    RepeatModeChanged(RepeatMode),
    /// A seek happened; carries a `Duration`.
    /// NOTE(review): presumably the new playback position rather than a relative offset — confirm.
    Seeked(Duration),
    /// The playback status changed; carries the new status.
    StatusChanged(Status),
}
47
/// A single datagram's payload: either an [`Event`] or a [`StateChange`].
///
/// This is the message type used with [`Sender`] and [`Listener`] in this file's tests.
/// Only `PartialEq` (not `Eq`) is derived because [`StateChange`] contains an `f32`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Message {
    /// A notification; see [`Event`].
    Event(Event),
    /// A change in playback state; see [`StateChange`].
    StateChange(StateChange),
}
53
54impl From<Event> for Message {
55 #[inline]
56 fn from(val: Event) -> Self {
57 Self::Event(val)
58 }
59}
60
/// Default size in bytes (1 KiB) of a [`Listener`]'s receive buffer, as used by
/// [`Listener::new`], and the initial capacity of a [`Sender`]'s serialization buffer.
///
/// The encoding-length test in this file checks that each message variant it covers fits
/// within this bound.
const MAX_MESSAGE_SIZE: usize = 1 << 10;
62
/// Receives CBOR-encoded values of type `T` over UDP, using a fixed receive buffer of
/// `BUF_SIZE` bytes.
#[derive(Debug)]
pub struct Listener<T, const BUF_SIZE: usize> {
    /// The bound UDP socket that datagrams arrive on.
    socket: UdpSocket,
    /// Scratch space each datagram is received into before decoding; reused across `recv` calls.
    buffer: [u8; BUF_SIZE],
    /// Marker for the decoded message type; no `T` is actually stored.
    message_type: PhantomData<T>,
}
69
70impl<T: DeserializeOwned + Send + Sync> Listener<T, MAX_MESSAGE_SIZE> {
71 #[inline]
77 pub async fn new() -> Result<Self> {
78 Self::with_buffer_size().await
79 }
80}
81
82impl<T: DeserializeOwned + Send + Sync, const B: usize> Listener<T, B> {
83 #[inline]
90 pub async fn with_buffer_size() -> Result<Self> {
91 let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await?;
92
93 Ok(Self {
94 socket,
95 buffer: [0; B],
96 message_type: PhantomData,
97 })
98 }
99
100 #[inline]
106 pub fn local_addr(&self) -> Result<SocketAddr> {
107 Ok(self.socket.local_addr()?)
108 }
109
110 #[inline]
117 pub async fn recv(&mut self) -> Result<T> {
118 let (size, _) = self.socket.recv_from(&mut self.buffer).await?;
119 let message = ciborium::from_reader(&self.buffer[..size])?;
120
121 Ok(message)
122 }
123}
124
/// Broadcasts CBOR-encoded messages of type `T` over UDP to a list of subscriber addresses.
pub struct Sender<T> {
    /// The UDP socket that messages are sent from.
    socket: UdpSocket,
    /// Pool of reusable serialization buffers, so each send can reuse an existing allocation.
    buffer_pool: Pool<Vec<u8>>,
    /// Addresses that every message is sent to. Kept behind an async `RwLock` so subscribers
    /// can be added through `&self`.
    subscribers: RwLock<Vec<SocketAddr>>,
    /// Marker for the message type; no `T` is actually stored.
    message_type: PhantomData<T>,
}
132
133impl<T> std::fmt::Debug for Sender<T> {
134 #[inline]
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 f.debug_struct("Sender")
137 .field("socket", &self.socket)
138 .field("subscribers", &self.subscribers)
139 .field("message_type", &self.message_type)
140 .field("buffer_pool.len", &self.buffer_pool.len())
141 .finish()
142 }
143}
144
145impl<T: Serialize + Send + Sync> Sender<T> {
146 #[inline]
152 pub async fn new() -> Result<Self> {
153 let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await?;
154
155 Ok(Self {
156 socket,
157 buffer_pool: Pool::new(1, || Vec::with_capacity(MAX_MESSAGE_SIZE)),
158 subscribers: RwLock::new(Vec::new()),
159 message_type: PhantomData,
160 })
161 }
162
163 #[inline]
169 pub async fn add_subscriber(&self, subscriber: SocketAddr) {
170 self.subscribers.write().await.push(subscriber);
171 }
172
173 #[inline]
184 pub async fn send(&self, message: impl Into<T> + Send + Sync + Debug) -> Result<()> {
185 let subscribers = self.subscribers.read().await;
186 log::info!(
187 "Forwarding state change: {message:?} to {} subscribers",
188 subscribers.len()
189 );
190
191 let (pool, mut buffer) = self.buffer_pool.pull(Vec::new).detach();
192 buffer.clear();
193
194 ciborium::into_writer(&message.into(), &mut buffer)?;
195
196 for subscriber in subscribers.iter() {
197 self.socket.send_to(&buffer, subscriber).await?;
198 }
199 drop(subscribers);
200
201 pool.attach(buffer);
202
203 Ok(())
204 }
205}
206
#[cfg(test)]
mod test {
    use super::*;

    /// Every message sent should arrive, unchanged, at every subscribed listener.
    /// This covers both `Event` and `StateChange` payloads over the real UDP path.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[case(Message::StateChange(StateChange::Muted))]
    #[case(Message::StateChange(StateChange::Unmuted))]
    #[case(Message::StateChange(StateChange::VolumeChanged(1. / 3.)))]
    #[case(Message::StateChange(StateChange::TrackChanged(None)))]
    #[case(Message::StateChange(StateChange::RepeatModeChanged(RepeatMode::None)))]
    #[case(Message::StateChange(StateChange::Seeked(Duration::from_secs(3))))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Playing)))]
    #[tokio::test]
    #[timeout(std::time::Duration::from_secs(1))]
    async fn test_udp(#[case] message: Message, #[values(1, 2, 3)] num_listeners: usize) {
        let sender = Sender::<Message>::new().await.unwrap();

        let mut listeners = Vec::new();

        for _ in 0..num_listeners {
            let listener = Listener::new().await.unwrap();
            sender.add_subscriber(listener.local_addr().unwrap()).await;
            listeners.push(listener);
        }

        sender.send(message.clone()).await.unwrap();

        for (i, listener) in listeners.iter_mut().enumerate() {
            let received_message: Message = listener.recv().await.unwrap();
            assert_eq!(received_message, message, "Listener {i}");
        }
    }

    /// Every message variant must encode to at most `MAX_MESSAGE_SIZE` bytes, so that it
    /// fits in a default `Listener`'s receive buffer.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[case(Message::StateChange(StateChange::Muted))]
    #[case(Message::StateChange(StateChange::Unmuted))]
    #[case(Message::StateChange(StateChange::VolumeChanged(1. / 3.)))]
    #[case(Message::StateChange(StateChange::TrackChanged(None)))]
    #[case(Message::StateChange(StateChange::RepeatModeChanged(RepeatMode::None)))]
    #[case(Message::StateChange(StateChange::Seeked(Duration::from_secs(3))))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Paused)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Playing)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Stopped)))]
    #[case(Message::StateChange(StateChange::TrackChanged(Some(
        mecomp_storage::db::schemas::song::Song::generate_id().into()
    ))))]
    fn test_message_encoding_length(#[case] message: Message) {
        let mut buffer = Vec::new();
        ciborium::into_writer(&message, &mut buffer).unwrap();

        assert!(
            buffer.len() <= MAX_MESSAGE_SIZE,
            "{message:?} encodes to {} bytes, exceeding MAX_MESSAGE_SIZE ({MAX_MESSAGE_SIZE})",
            buffer.len()
        );
    }
}