1use std::{
4 fmt::Debug,
5 marker::PhantomData,
6 net::{Ipv4Addr, SocketAddr},
7 time::Duration,
8};
9
10use mecomp_storage::db::schemas::RecordId;
11use object_pool::Pool;
12use serde::{Deserialize, Serialize, de::DeserializeOwned};
13use tokio::net::UdpSocket;
14
15use crate::{
16 errors::UdpError,
17 state::{RepeatMode, Status},
18};
19
20pub type Result<T> = std::result::Result<T, UdpError>;
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
23pub enum Event {
24 LibraryRescanFinished,
25 LibraryAnalysisFinished,
26 LibraryReclusterFinished,
27 DaemonShutdown,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub enum StateChange {
32 Muted,
34 Unmuted,
36 VolumeChanged(f32),
38 TrackChanged(Option<RecordId>),
40 RepeatModeChanged(RepeatMode),
42 Seeked(Duration),
44 StatusChanged(Status),
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49pub enum Message {
50 Event(Event),
51 StateChange(StateChange),
52}
53
54impl From<Event> for Message {
55 #[inline]
56 fn from(val: Event) -> Self {
57 Self::Event(val)
58 }
59}
60
/// Default receive-buffer size, in bytes, and the pre-allocated capacity of the sender's encode buffer (1 KiB).
const MAX_MESSAGE_SIZE: usize = 1 << 10;
62
63#[derive(Debug)]
64pub struct Listener<T, const BUF_SIZE: usize> {
65 socket: UdpSocket,
66 buffer: [u8; BUF_SIZE],
67 message_type: PhantomData<T>,
68}
69
70impl<T: DeserializeOwned + Send + Sync> Listener<T, MAX_MESSAGE_SIZE> {
71 #[inline]
77 pub async fn new() -> Result<Self> {
78 Self::with_buffer_size().await
79 }
80}
81
82impl<T: DeserializeOwned + Send + Sync, const B: usize> Listener<T, B> {
83 #[inline]
90 pub async fn with_buffer_size() -> Result<Self> {
91 let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await?;
92
93 Ok(Self {
94 socket,
95 buffer: [0; B],
96 message_type: PhantomData,
97 })
98 }
99
100 #[inline]
106 pub fn local_addr(&self) -> Result<SocketAddr> {
107 Ok(self.socket.local_addr()?)
108 }
109
110 #[inline]
117 pub async fn recv(&mut self) -> Result<T> {
118 let (size, _) = self.socket.recv_from(&mut self.buffer).await?;
119 let message = ciborium::from_reader(&self.buffer[..size])?;
120
121 Ok(message)
122 }
123}
124
125pub struct Sender<T> {
126 socket: UdpSocket,
127 buffer_pool: Pool<Vec<u8>>,
128 subscribers: Vec<SocketAddr>,
130 message_type: PhantomData<T>,
131}
132
133impl<T> std::fmt::Debug for Sender<T> {
134 #[inline]
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 f.debug_struct("Sender")
137 .field("socket", &self.socket)
138 .field("subscribers", &self.subscribers)
139 .field("message_type", &self.message_type)
140 .field("buffer_pool.len", &self.buffer_pool.len())
141 .finish()
142 }
143}
144
145impl<T: Serialize + Send + Sync> Sender<T> {
146 #[inline]
152 pub async fn new() -> Result<Self> {
153 let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await?;
154
155 Ok(Self {
156 socket,
157 buffer_pool: Pool::new(1, || Vec::with_capacity(MAX_MESSAGE_SIZE)),
158 subscribers: Vec::new(),
159 message_type: PhantomData,
160 })
161 }
162
163 #[inline]
165 pub fn add_subscriber(&mut self, subscriber: SocketAddr) {
166 self.subscribers.push(subscriber);
167 }
168
169 #[inline]
176 pub async fn send(&self, message: impl Into<T> + Send + Sync + Debug) -> Result<()> {
177 log::info!(
178 "Forwarding state change: {message:?} to {} subscribers",
179 self.subscribers.len()
180 );
181
182 let (pool, mut buffer) = self.buffer_pool.pull(Vec::new).detach();
183 buffer.clear();
184
185 ciborium::into_writer(&message.into(), &mut buffer)?;
186
187 for subscriber in &self.subscribers {
188 self.socket.send_to(&buffer, subscriber).await?;
189 }
190
191 pool.attach(buffer);
192
193 Ok(())
194 }
195}
196
#[cfg(test)]
mod test {
    use super::*;

    // Round trip over real sockets: one sender broadcasts a single `Message` to
    // `num_listeners` freshly bound listeners, and each listener must decode
    // exactly the message that was sent.
    // Runs once per `#[case]` x `#[values]` combination, each under a 1 s timeout.
    // NOTE(review): only `Event` messages are exercised here; `StateChange`
    // variants and `Event::DaemonShutdown` are not sent over UDP in this test.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[tokio::test]
    #[timeout(std::time::Duration::from_secs(1))]
    async fn test_udp(#[case] message: Message, #[values(1, 2, 3)] num_listeners: usize) {
        let mut sender = Sender::<Message>::new().await.unwrap();

        let mut listeners = Vec::new();

        // Register every listener's bound address before anything is sent.
        for _ in 0..num_listeners {
            let listener = Listener::new().await.unwrap();
            sender.add_subscriber(listener.local_addr().unwrap());
            listeners.push(listener);
        }

        sender.send(message.clone()).await.unwrap();

        // The `Message` annotation on `received_message` is what fixes each listener's `T`.
        for (i, listener) in listeners.iter_mut().enumerate() {
            let received_message: Message = listener.recv().await.unwrap();
            assert_eq!(received_message, message, "Listener {i}");
        }
    }

    // Every listed message must CBOR-encode to at most `MAX_MESSAGE_SIZE` bytes,
    // the buffer size used by `Listener::new`.
    // NOTE(review): `Event::DaemonShutdown` and the other `RepeatMode` values are not listed as cases.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::StateChange(StateChange::Muted))]
    #[case(Message::StateChange(StateChange::Unmuted))]
    #[case(Message::StateChange(StateChange::VolumeChanged(1. / 3.)))]
    #[case(Message::StateChange(StateChange::TrackChanged(None)))]
    #[case(Message::StateChange(StateChange::RepeatModeChanged(RepeatMode::None)))]
    #[case(Message::StateChange(StateChange::Seeked(Duration::from_secs(3))))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Paused)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Playing)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Stopped)))]
    // A populated `TrackChanged` carrying a freshly generated song id.
    #[case(Message::StateChange(StateChange::TrackChanged(Some(
        mecomp_storage::db::schemas::song::Song::generate_id().into()
    ))))]
    fn test_message_encoding_length(#[case] message: Message) {
        let mut buffer = Vec::new();
        ciborium::into_writer(&message, &mut buffer).unwrap();

        assert!(buffer.len() <= MAX_MESSAGE_SIZE);
    }
}