1use std::{
4 fmt::Debug,
5 marker::PhantomData,
6 net::{Ipv4Addr, SocketAddr},
7 time::Duration,
8};
9
10use mecomp_storage::db::schemas::RecordId;
11use object_pool::Pool;
12use serde::{Deserialize, Serialize, de::DeserializeOwned};
13use tokio::{net::UdpSocket, sync::RwLock};
14
15use crate::{
16 errors::UdpError,
17 state::{RepeatMode, Status},
18};
19
20pub type Result<T> = std::result::Result<T, UdpError>;
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
23pub enum Event {
24 LibraryRescanFinished,
25 LibraryAnalysisFinished,
26 LibraryReclusterFinished,
27 DaemonShutdown,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub enum StateChange {
32 Muted,
34 Unmuted,
36 VolumeChanged(f32),
38 TrackChanged(Option<RecordId>),
41 RepeatModeChanged(RepeatMode),
43 Seeked(Duration),
45 StatusChanged(Status),
47 QueueChanged,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52pub enum Message {
53 Event(Event),
54 StateChange(StateChange),
55}
56
57impl From<Event> for Message {
58 #[inline]
59 fn from(val: Event) -> Self {
60 Self::Event(val)
61 }
62}
63
/// Size in bytes (1 KiB) of the receive buffer used by [`Listener::new`], and the
/// initial capacity of the sender's pooled encoding buffer.
const MAX_MESSAGE_SIZE: usize = 1 << 10;
65
66#[derive(Debug)]
67pub struct Listener<T, const BUF_SIZE: usize> {
68 socket: UdpSocket,
69 buffer: [u8; BUF_SIZE],
70 message_type: PhantomData<T>,
71}
72
73impl<T: DeserializeOwned + Send + Sync> Listener<T, MAX_MESSAGE_SIZE> {
74 #[inline]
80 pub async fn new() -> Result<Self> {
81 Self::with_buffer_size().await
82 }
83}
84
85impl<T: DeserializeOwned + Send + Sync, const B: usize> Listener<T, B> {
86 #[inline]
93 pub async fn with_buffer_size() -> Result<Self> {
94 let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await?;
95
96 Ok(Self {
97 socket,
98 buffer: [0; B],
99 message_type: PhantomData,
100 })
101 }
102
103 #[inline]
109 pub fn local_addr(&self) -> Result<SocketAddr> {
110 Ok(self.socket.local_addr()?)
111 }
112
113 #[inline]
120 pub async fn recv(&mut self) -> Result<T> {
121 let (size, _) = self.socket.recv_from(&mut self.buffer).await?;
122 let message = ciborium::from_reader(&self.buffer[..size])?;
123
124 Ok(message)
125 }
126}
127
128pub struct Sender<T> {
129 socket: UdpSocket,
130 buffer_pool: Pool<Vec<u8>>,
131 subscribers: RwLock<Vec<SocketAddr>>,
133 message_type: PhantomData<T>,
134}
135
136impl<T> std::fmt::Debug for Sender<T> {
137 #[inline]
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 f.debug_struct("Sender")
140 .field("socket", &self.socket)
141 .field("subscribers", &self.subscribers)
142 .field("message_type", &self.message_type)
143 .field("buffer_pool.len", &self.buffer_pool.len())
144 .finish()
145 }
146}
147
148impl<T: Serialize + Send + Sync> Sender<T> {
149 #[inline]
155 pub async fn new() -> Result<Self> {
156 let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await?;
157
158 Ok(Self {
159 socket,
160 buffer_pool: Pool::new(1, || Vec::with_capacity(MAX_MESSAGE_SIZE)),
161 subscribers: RwLock::new(Vec::new()),
162 message_type: PhantomData,
163 })
164 }
165
166 #[inline]
172 pub async fn add_subscriber(&self, subscriber: SocketAddr) {
173 self.subscribers.write().await.push(subscriber);
174 }
175
176 #[inline]
187 pub async fn send(&self, message: impl Into<T> + Send + Sync + Debug) -> Result<()> {
188 let subscribers = self.subscribers.read().await;
189 log::info!(
190 "Forwarding state change: {message:?} to {} subscribers",
191 subscribers.len()
192 );
193
194 let (pool, mut buffer) = self.buffer_pool.pull(Vec::new).detach();
195 buffer.clear();
196
197 ciborium::into_writer(&message.into(), &mut buffer)?;
198
199 for subscriber in subscribers.iter() {
200 self.socket.send_to(&buffer, subscriber).await?;
201 }
202 drop(subscribers);
203
204 pool.attach(buffer);
205
206 Ok(())
207 }
208}
209
#[cfg(test)]
mod test {
    use super::*;

    // Sends each message from one `Sender` to `num_listeners` listeners and checks that
    // every listener decodes exactly the message that was sent. Covers both `Event`
    // and `StateChange` payloads.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[case(Message::StateChange(StateChange::Muted))]
    #[case(Message::StateChange(StateChange::VolumeChanged(1. / 3.)))]
    #[case(Message::StateChange(StateChange::Seeked(Duration::from_secs(3))))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Playing)))]
    #[case(Message::StateChange(StateChange::QueueChanged))]
    #[tokio::test]
    #[timeout(std::time::Duration::from_secs(1))]
    async fn test_udp(#[case] message: Message, #[values(1, 2, 3)] num_listeners: usize) {
        let sender = Sender::<Message>::new().await.unwrap();

        let mut listeners = Vec::new();

        for _ in 0..num_listeners {
            let listener = Listener::new().await.unwrap();
            sender.add_subscriber(listener.local_addr().unwrap()).await;
            listeners.push(listener);
        }

        sender.send(message.clone()).await.unwrap();

        for (i, listener) in listeners.iter_mut().enumerate() {
            let received_message: Message = listener.recv().await.unwrap();
            assert_eq!(received_message, message, "Listener {i}");
        }
    }

    // Every message must encode within `MAX_MESSAGE_SIZE` bytes, otherwise a
    // default-sized `Listener` could not receive it intact.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[case(Message::StateChange(StateChange::Muted))]
    #[case(Message::StateChange(StateChange::Unmuted))]
    #[case(Message::StateChange(StateChange::VolumeChanged(1. / 3.)))]
    #[case(Message::StateChange(StateChange::TrackChanged(None)))]
    #[case(Message::StateChange(StateChange::RepeatModeChanged(RepeatMode::None)))]
    #[case(Message::StateChange(StateChange::Seeked(Duration::from_secs(3))))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Paused)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Playing)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Stopped)))]
    #[case(Message::StateChange(StateChange::QueueChanged))]
    #[case(Message::StateChange(StateChange::TrackChanged(Some(
        mecomp_storage::db::schemas::song::Song::generate_id().into()
    ))))]
    fn test_message_encoding_length(#[case] message: Message) {
        let mut buffer = Vec::new();
        ciborium::into_writer(&message, &mut buffer).unwrap();

        assert!(
            buffer.len() <= MAX_MESSAGE_SIZE,
            "{message:?} encodes to {} bytes (limit {MAX_MESSAGE_SIZE})",
            buffer.len()
        );
    }
}