1use std::{
4 fmt::Debug,
5 marker::PhantomData,
6 net::{Ipv4Addr, SocketAddr},
7 time::Duration,
8};
9
10use mecomp_storage::db::schemas::Thing;
11use object_pool::Pool;
12use serde::{de::DeserializeOwned, Deserialize, Serialize};
13use tokio::net::UdpSocket;
14
15use crate::{
16 errors::UdpError,
17 state::{RepeatMode, Status},
18};
19
20pub type Result<T> = std::result::Result<T, UdpError>;
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
23pub enum Event {
24 LibraryRescanFinished,
25 LibraryAnalysisFinished,
26 LibraryReclusterFinished,
27 DaemonShutdown,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub enum StateChange {
32 Muted,
34 Unmuted,
36 VolumeChanged(f32),
38 TrackChanged(Option<Thing>),
40 RepeatModeChanged(RepeatMode),
42 Seeked(Duration),
44 StatusChanged(Status),
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49pub enum Message {
50 Event(Event),
51 StateChange(StateChange),
52}
53
54impl From<Event> for Message {
55 fn from(val: Event) -> Self {
56 Self::Event(val)
57 }
58}
59
/// Size budget, in bytes, for one encoded message.
///
/// It serves as the receive-buffer size of `Listener::new` and as the initial
/// capacity of the `Sender`'s pooled encode buffer. The tests check that every
/// message variant encodes within this limit.
const MAX_MESSAGE_SIZE: usize = 1 << 10;
61
62#[derive(Debug)]
63pub struct Listener<T, const BUF_SIZE: usize> {
64 socket: UdpSocket,
65 buffer: [u8; BUF_SIZE],
66 message_type: PhantomData<T>,
67}
68
69impl<T: DeserializeOwned + Send + Sync> Listener<T, MAX_MESSAGE_SIZE> {
70 pub async fn new() -> Result<Self> {
76 Self::with_buffer_size().await
77 }
78}
79
80impl<T: DeserializeOwned + Send + Sync, const B: usize> Listener<T, B> {
81 pub async fn with_buffer_size() -> Result<Self> {
88 let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await?;
89
90 Ok(Self {
91 socket,
92 buffer: [0; B],
93 message_type: PhantomData,
94 })
95 }
96
97 pub fn local_addr(&self) -> Result<SocketAddr> {
103 Ok(self.socket.local_addr()?)
104 }
105
106 pub async fn recv(&mut self) -> Result<T> {
113 let (size, _) = self.socket.recv_from(&mut self.buffer).await?;
114 let message = ciborium::from_reader(&self.buffer[..size])?;
115
116 Ok(message)
117 }
118}
119
120pub struct Sender<T> {
121 socket: UdpSocket,
122 buffer_pool: Pool<Vec<u8>>,
123 subscribers: Vec<SocketAddr>,
125 message_type: PhantomData<T>,
126}
127
128impl<T> std::fmt::Debug for Sender<T> {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_struct("Sender")
131 .field("socket", &self.socket)
132 .field("subscribers", &self.subscribers)
133 .field("message_type", &self.message_type)
134 .field("buffer_pool.len", &self.buffer_pool.len())
135 .finish()
136 }
137}
138
139impl<T: Serialize + Send + Sync> Sender<T> {
140 pub async fn new() -> Result<Self> {
146 let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await?;
147
148 Ok(Self {
149 socket,
150 buffer_pool: Pool::new(1, || Vec::with_capacity(MAX_MESSAGE_SIZE)),
151 subscribers: Vec::new(),
152 message_type: PhantomData,
153 })
154 }
155
156 pub fn add_subscriber(&mut self, subscriber: SocketAddr) {
158 self.subscribers.push(subscriber);
159 }
160
161 pub async fn send(&self, message: impl Into<T> + Send + Sync + Debug) -> Result<()> {
168 log::info!(
169 "Forwarding state change: {message:?} to {} subscribers",
170 self.subscribers.len()
171 );
172
173 let (pool, mut buffer) = self.buffer_pool.pull(Vec::new).detach();
174 buffer.clear();
175
176 ciborium::into_writer(&message.into(), &mut buffer)?;
177
178 for subscriber in &self.subscribers {
179 self.socket.send_to(&buffer, subscriber).await?;
180 }
181
182 pool.attach(buffer);
183
184 Ok(())
185 }
186}
187
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips each message through a real UDP `Sender` to 1–3 `Listener`s
    /// and checks every listener decodes exactly what was sent.
    ///
    /// Covers both `Event` and `StateChange` payloads, including the `f32`,
    /// `Duration` and `Option` carrying variants.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[case(Message::StateChange(StateChange::Muted))]
    #[case(Message::StateChange(StateChange::VolumeChanged(1. / 3.)))]
    #[case(Message::StateChange(StateChange::TrackChanged(None)))]
    #[case(Message::StateChange(StateChange::Seeked(Duration::from_secs(3))))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Playing)))]
    #[tokio::test]
    #[timeout(std::time::Duration::from_secs(1))]
    async fn test_udp(#[case] message: Message, #[values(1, 2, 3)] num_listeners: usize) {
        let mut sender = Sender::<Message>::new().await.unwrap();

        let mut listeners = Vec::new();

        for _ in 0..num_listeners {
            let listener = Listener::new().await.unwrap();
            sender.add_subscriber(listener.local_addr().unwrap());
            listeners.push(listener);
        }

        sender.send(message.clone()).await.unwrap();

        for (i, listener) in listeners.iter_mut().enumerate() {
            let received_message: Message = listener.recv().await.unwrap();
            assert_eq!(received_message, message, "Listener {i}");
        }
    }

    /// Every message variant must encode within `MAX_MESSAGE_SIZE` bytes, the
    /// default `Listener` receive-buffer size.
    #[rstest::rstest]
    #[case(Message::Event(Event::LibraryRescanFinished))]
    #[case(Message::Event(Event::LibraryAnalysisFinished))]
    #[case(Message::Event(Event::LibraryReclusterFinished))]
    #[case(Message::Event(Event::DaemonShutdown))]
    #[case(Message::StateChange(StateChange::Muted))]
    #[case(Message::StateChange(StateChange::Unmuted))]
    #[case(Message::StateChange(StateChange::VolumeChanged(1. / 3.)))]
    #[case(Message::StateChange(StateChange::TrackChanged(None)))]
    #[case(Message::StateChange(StateChange::RepeatModeChanged(RepeatMode::None)))]
    #[case(Message::StateChange(StateChange::Seeked(Duration::from_secs(3))))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Paused)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Playing)))]
    #[case(Message::StateChange(StateChange::StatusChanged(Status::Stopped)))]
    #[case(Message::StateChange(StateChange::TrackChanged(Some(
        mecomp_storage::db::schemas::song::Song::generate_id().into()
    ))))]
    fn test_message_encoding_length(#[case] message: Message) {
        let mut buffer = Vec::new();
        ciborium::into_writer(&message, &mut buffer).unwrap();

        assert!(
            buffer.len() <= MAX_MESSAGE_SIZE,
            "{message:?} encoded to {} bytes, exceeding {MAX_MESSAGE_SIZE}",
            buffer.len()
        );
    }
}