snapcast_client/connection/
mod.rs1#[cfg(feature = "websocket")]
9pub mod ws;
10#[cfg(feature = "tls")]
11pub mod wss;
12
13use std::collections::HashMap;
14use std::time::Duration;
15
16use anyhow::{Context, Result};
17use snapcast_proto::MessageType;
18use snapcast_proto::message::base::BaseMessage;
19use snapcast_proto::message::factory::{self, MessagePayload, TypedMessage};
20use snapcast_proto::types::Timeval;
21use tokio::io::{AsyncReadExt, AsyncWriteExt};
22use tokio::net::TcpStream;
23use tokio::sync::oneshot;
24
25async fn read_frame<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<TypedMessage> {
27 let mut header_buf = [0u8; BaseMessage::HEADER_SIZE];
29 reader
30 .read_exact(&mut header_buf)
31 .await
32 .context("reading base message header")?;
33
34 let mut base = BaseMessage::read_from(&mut &header_buf[..])
35 .map_err(|e| anyhow::anyhow!("parsing header: {e}"))?;
36
37 base.received = steady_time_of_day();
39 ensure_payload_size(base.size)?;
40
41 let mut payload_buf = vec![0u8; base.size as usize];
43 if !payload_buf.is_empty() {
44 reader
45 .read_exact(&mut payload_buf)
46 .await
47 .context("reading payload")?;
48 }
49
50 factory::deserialize(base, &payload_buf).map_err(|e| anyhow::anyhow!("deserializing: {e}"))
51}
52
53pub(crate) fn ensure_payload_size(size: u32) -> Result<()> {
54 anyhow::ensure!(
55 size <= snapcast_proto::DEFAULT_MAX_PAYLOAD_SIZE,
56 "payload too large: {size} bytes"
57 );
58 Ok(())
59}
60
61async fn write_frame<W: AsyncWriteExt + Unpin>(
63 writer: &mut W,
64 base: &mut BaseMessage,
65 payload: &MessagePayload,
66) -> Result<()> {
67 let frame =
68 factory::serialize(base, payload).map_err(|e| anyhow::anyhow!("serializing: {e}"))?;
69 writer.write_all(&frame).await.context("writing frame")?;
70 Ok(())
71}
72
73struct PendingRequest {
75 tx: oneshot::Sender<TypedMessage>,
76}
77
78pub struct TcpConnection {
80 stream: Option<TcpStream>,
81 host: String,
82 port: u16,
83 pending: HashMap<u16, PendingRequest>,
84 next_id: u16,
85}
86
87pub enum SnapConnection {
89 Tcp(TcpConnection),
91 #[cfg(feature = "websocket")]
92 Ws(ws::WsConnection),
94 #[cfg(feature = "tls")]
95 Wss(wss::WssConnection),
97}
98
99impl SnapConnection {
100 pub fn new(scheme: &str, host: &str, port: u16) -> Result<Self> {
102 match scheme {
103 snapcast_proto::SCHEME_TCP => Ok(Self::Tcp(TcpConnection::new(host, port))),
104 snapcast_proto::SCHEME_WS | snapcast_proto::SCHEME_WSS => anyhow::bail!(
105 "websocket audio transport is not supported yet; use tcp:// for Snapcast audio"
106 ),
107 _ => anyhow::bail!("unsupported scheme: {scheme}"),
108 }
109 }
110
111 pub async fn connect(&mut self) -> Result<()> {
113 match self {
114 Self::Tcp(c) => c.connect().await,
115 #[cfg(feature = "websocket")]
116 Self::Ws(c) => c.connect().await,
117 #[cfg(feature = "tls")]
118 Self::Wss(c) => c.connect().await,
119 }
120 }
121
122 pub fn disconnect(&mut self) {
124 match self {
125 Self::Tcp(c) => c.disconnect(),
126 #[cfg(feature = "websocket")]
127 Self::Ws(c) => c.disconnect(),
128 #[cfg(feature = "tls")]
129 Self::Wss(c) => c.disconnect(),
130 }
131 }
132
133 pub async fn send(&mut self, msg_type: MessageType, payload: &MessagePayload) -> Result<()> {
135 match self {
136 Self::Tcp(c) => c.send(msg_type, payload).await,
137 #[cfg(feature = "websocket")]
138 Self::Ws(c) => c.send(msg_type, payload).await,
139 #[cfg(feature = "tls")]
140 Self::Wss(c) => c.send(msg_type, payload).await,
141 }
142 }
143
144 pub async fn recv(&mut self) -> Result<TypedMessage> {
146 match self {
147 Self::Tcp(c) => c.recv().await,
148 #[cfg(feature = "websocket")]
149 Self::Ws(c) => c.recv().await,
150 #[cfg(feature = "tls")]
151 Self::Wss(c) => c.recv().await,
152 }
153 }
154}
155
156impl TcpConnection {
157 pub fn new(host: &str, port: u16) -> Self {
159 Self {
160 stream: None,
161 host: host.to_string(),
162 port,
163 pending: HashMap::new(),
164 next_id: 1,
165 }
166 }
167
168 pub async fn connect(&mut self) -> Result<()> {
170 let addr = format!("{}:{}", self.host, self.port);
171 let stream = TcpStream::connect(&addr)
172 .await
173 .with_context(|| format!("connecting to {addr}"))?;
174 self.stream = Some(stream);
175 self.pending.clear();
176 self.next_id = 1;
177 Ok(())
178 }
179
180 pub fn disconnect(&mut self) {
182 self.stream = None;
183 self.pending.clear();
184 }
185
186 fn stream_mut(&mut self) -> Result<&mut TcpStream> {
187 self.stream.as_mut().context("not connected")
188 }
189
190 pub async fn send(&mut self, msg_type: MessageType, payload: &MessagePayload) -> Result<()> {
192 let stream = self.stream_mut()?;
193 let mut base = BaseMessage {
194 msg_type,
195 id: 0,
196 refers_to: 0,
197 sent: Timeval::default(),
198 received: Timeval::default(),
199 size: 0,
200 };
201 stamp_sent(&mut base);
202 write_frame(stream, &mut base, payload).await
203 }
204
205 pub async fn send_request(
207 &mut self,
208 msg_type: MessageType,
209 payload: &MessagePayload,
210 timeout: Duration,
211 ) -> Result<TypedMessage> {
212 let id = self.next_id;
213 self.next_id = self.next_id.wrapping_add(1);
214
215 let (tx, rx) = oneshot::channel();
216 self.pending.insert(id, PendingRequest { tx });
217
218 let stream = self.stream_mut()?;
219 let mut base = BaseMessage {
220 msg_type,
221 id,
222 refers_to: 0,
223 sent: Timeval::default(),
224 received: Timeval::default(),
225 size: 0,
226 };
227 stamp_sent(&mut base);
228 write_frame(stream, &mut base, payload).await?;
229
230 tokio::time::timeout(timeout, rx)
231 .await
232 .context("request timed out")?
233 .context("response channel closed")
234 }
235
236 pub async fn recv(&mut self) -> Result<TypedMessage> {
239 loop {
240 let stream = self.stream_mut()?;
241 let msg = read_frame(stream).await?;
242
243 if msg.base.refers_to != 0
244 && let Some(pending) = self.pending.remove(&msg.base.refers_to)
245 {
246 let _ = pending.tx.send(msg);
247 continue;
248 }
249 return Ok(msg);
250 }
251 }
252}
253
254pub(super) fn stamp_sent(base: &mut BaseMessage) {
255 let tv = steady_time_of_day();
256 base.sent = tv;
257}
258
259pub(super) fn steady_time_of_day() -> Timeval {
263 let usec = monotonic_usec();
268 Timeval {
269 sec: (usec / 1_000_000) as i32,
270 usec: (usec % 1_000_000) as i32,
271 }
272}
273
274#[allow(unsafe_code)] fn monotonic_usec() -> i64 {
278 #[cfg(target_os = "macos")]
279 {
280 unsafe extern "C" {
283 fn mach_continuous_time() -> u64;
284 fn mach_timebase_info(info: *mut MachTimebaseInfo) -> i32;
285 }
286 #[repr(C)]
287 struct MachTimebaseInfo {
288 numer: u32,
289 denom: u32,
290 }
291 static TIMEBASE: std::sync::OnceLock<(u32, u32)> = std::sync::OnceLock::new();
292 let (numer, denom) = *TIMEBASE.get_or_init(|| {
293 let mut info = MachTimebaseInfo { numer: 0, denom: 0 };
294 unsafe {
295 mach_timebase_info(&mut info);
296 }
297 (info.numer, info.denom)
298 });
299 let ticks = unsafe { mach_continuous_time() };
300 let nanos = ticks as i128 * numer as i128 / denom as i128;
301 (nanos / 1_000) as i64
302 }
303 #[cfg(all(unix, not(target_os = "macos")))]
304 {
305 let mut ts = libc::timespec {
306 tv_sec: 0,
307 tv_nsec: 0,
308 };
309 unsafe {
311 libc::clock_gettime(libc::CLOCK_MONOTONIC, &mut ts);
312 }
313 ts.tv_sec * 1_000_000 + ts.tv_nsec / 1_000
314 }
315 #[cfg(not(unix))]
316 {
317 let now = std::time::SystemTime::now()
318 .duration_since(std::time::UNIX_EPOCH)
319 .unwrap_or_default();
320 now.as_micros() as i64
321 }
322}
323
324pub fn now_usec() -> i64 {
326 monotonic_usec()
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use snapcast_proto::message::time::Time;

    /// Builds a header with everything but the type, id and `refers_to` zeroed.
    fn header(msg_type: MessageType, id: u16, refers_to: u16) -> BaseMessage {
        BaseMessage {
            msg_type,
            id,
            refers_to,
            sent: Timeval::default(),
            received: Timeval::default(),
            size: 0,
        }
    }

    #[tokio::test]
    async fn write_and_read_frame() {
        let payload = MessagePayload::Time(Time {
            latency: Timeval { sec: 0, usec: 1234 },
        });
        let mut base = header(MessageType::Time, 42, 0);
        base.sent = Timeval { sec: 1, usec: 0 };

        let mut buf = Vec::new();
        write_frame(&mut buf, &mut base, &payload).await.unwrap();

        assert_eq!(buf.len(), BaseMessage::HEADER_SIZE + Time::SIZE as usize);

        let mut cursor = std::io::Cursor::new(&buf);
        let msg = read_frame(&mut cursor).await.unwrap();
        assert_eq!(msg.base.msg_type, MessageType::Time);
        assert_eq!(msg.base.id, 42);
        match msg.payload {
            MessagePayload::Time(t) => assert_eq!(t.latency.usec, 1234),
            _ => panic!("expected Time"),
        }
    }

    #[tokio::test]
    async fn write_and_read_error_frame() {
        use snapcast_proto::message::error::Error;

        let payload = MessagePayload::Error(Error {
            code: 401,
            error: "Unauthorized".into(),
            message: "bad auth".into(),
        });
        let mut base = header(MessageType::Error, 0, 7);

        let mut buf = Vec::new();
        write_frame(&mut buf, &mut base, &payload).await.unwrap();

        let mut cursor = std::io::Cursor::new(&buf);
        let msg = read_frame(&mut cursor).await.unwrap();
        assert_eq!(msg.base.refers_to, 7);
        match msg.payload {
            MessagePayload::Error(e) => {
                assert_eq!(e.code, 401);
                assert_eq!(e.error, "Unauthorized");
            }
            _ => panic!("expected Error"),
        }
    }

    #[tokio::test]
    async fn write_and_read_multiple_frames() {
        let frames: Vec<(MessageType, MessagePayload)> = vec![
            (MessageType::Time, MessagePayload::Time(Time::default())),
            (
                MessageType::ClientInfo,
                MessagePayload::ClientInfo(snapcast_proto::message::client_info::ClientInfo {
                    volume: 80,
                    muted: false,
                }),
            ),
        ];

        let mut buf = Vec::new();
        for (mt, payload) in &frames {
            let mut base = header(*mt, 0, 0);
            write_frame(&mut buf, &mut base, payload).await.unwrap();
        }

        let mut cursor = std::io::Cursor::new(&buf);
        let msg1 = read_frame(&mut cursor).await.unwrap();
        assert_eq!(msg1.base.msg_type, MessageType::Time);
        let msg2 = read_frame(&mut cursor).await.unwrap();
        assert_eq!(msg2.base.msg_type, MessageType::ClientInfo);
    }

    #[test]
    fn tcp_connection_new() {
        let conn = TcpConnection::new("localhost", 1704);
        assert!(conn.stream.is_none());
        assert_eq!(conn.host, "localhost");
        assert_eq!(conn.port, 1704);
    }

    #[test]
    fn rejects_oversized_payload() {
        let too_large = snapcast_proto::DEFAULT_MAX_PAYLOAD_SIZE + 1;
        assert!(ensure_payload_size(too_large).is_err());
    }

    #[test]
    fn accepts_payload_up_to_limit() {
        assert!(ensure_payload_size(0).is_ok());
        assert!(ensure_payload_size(snapcast_proto::DEFAULT_MAX_PAYLOAD_SIZE).is_ok());
    }

    #[test]
    fn rejects_websocket_audio_scheme() {
        assert!(SnapConnection::new("ws", "localhost", 1780).is_err());
        assert!(SnapConnection::new("wss", "localhost", 1788).is_err());
    }

    // Only Unix uses a monotonic source; the fallback clock may legitimately step backwards.
    #[cfg(unix)]
    #[test]
    fn clock_does_not_go_backwards() {
        let first = now_usec();
        let second = now_usec();
        assert!(second >= first);
    }

    #[test]
    fn steady_time_is_normalized() {
        let tv = steady_time_of_day();
        assert!(tv.sec >= 0);
        assert!((0..1_000_000).contains(&tv.usec));
    }
}