Skip to main content

snapcast_client/connection/
mod.rs

1//! Connection layer — TCP, WebSocket, and WSS implementations.
2
3#[cfg(feature = "websocket")]
4pub mod ws;
5#[cfg(feature = "tls")]
6pub mod wss;
7
8use std::collections::HashMap;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use snapcast_proto::MessageType;
13use snapcast_proto::message::base::BaseMessage;
14use snapcast_proto::message::factory::{self, MessagePayload, TypedMessage};
15use snapcast_proto::types::Timeval;
16use tokio::io::{AsyncReadExt, AsyncWriteExt};
17use tokio::net::TcpStream;
18use tokio::sync::oneshot;
19
20/// Read a complete frame (header + payload) from an async reader.
21async fn read_frame<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<TypedMessage> {
22    // Read 26-byte header
23    let mut header_buf = [0u8; BaseMessage::HEADER_SIZE];
24    reader
25        .read_exact(&mut header_buf)
26        .await
27        .context("reading base message header")?;
28
29    let mut base = BaseMessage::read_from(&mut &header_buf[..])
30        .map_err(|e| anyhow::anyhow!("parsing header: {e}"))?;
31
32    // Stamp received time using steady clock (matching C++ steadytimeofday)
33    base.received = steady_time_of_day();
34
35    // Read payload
36    let mut payload_buf = vec![0u8; base.size as usize];
37    if !payload_buf.is_empty() {
38        reader
39            .read_exact(&mut payload_buf)
40            .await
41            .context("reading payload")?;
42    }
43
44    factory::deserialize(base, &payload_buf).map_err(|e| anyhow::anyhow!("deserializing: {e}"))
45}
46
47/// Write a complete frame (header + payload) to an async writer.
48async fn write_frame<W: AsyncWriteExt + Unpin>(
49    writer: &mut W,
50    base: &mut BaseMessage,
51    payload: &MessagePayload,
52) -> Result<()> {
53    let frame =
54        factory::serialize(base, payload).map_err(|e| anyhow::anyhow!("serializing: {e}"))?;
55    writer.write_all(&frame).await.context("writing frame")?;
56    Ok(())
57}
58
59/// Pending request waiting for a response.
60struct PendingRequest {
61    tx: oneshot::Sender<TypedMessage>,
62}
63
64/// TCP connection to a snapserver.
65pub struct TcpConnection {
66    stream: Option<TcpStream>,
67    host: String,
68    port: u16,
69    pending: HashMap<u16, PendingRequest>,
70    next_id: u16,
71}
72
73impl TcpConnection {
74    /// Create a new connection to the given host and port.
75    pub fn new(host: &str, port: u16) -> Self {
76        Self {
77            stream: None,
78            host: host.to_string(),
79            port,
80            pending: HashMap::new(),
81            next_id: 1,
82        }
83    }
84
85    /// Establish the TCP connection.
86    pub async fn connect(&mut self) -> Result<()> {
87        let addr = format!("{}:{}", self.host, self.port);
88        let stream = TcpStream::connect(&addr)
89            .await
90            .with_context(|| format!("connecting to {addr}"))?;
91        self.stream = Some(stream);
92        self.pending.clear();
93        self.next_id = 1;
94        Ok(())
95    }
96
97    /// Close the connection.
98    pub fn disconnect(&mut self) {
99        self.stream = None;
100        self.pending.clear();
101    }
102
103    fn stream_mut(&mut self) -> Result<&mut TcpStream> {
104        self.stream.as_mut().context("not connected")
105    }
106
107    /// Send a message without waiting for a response.
108    pub async fn send(&mut self, msg_type: MessageType, payload: &MessagePayload) -> Result<()> {
109        let stream = self.stream_mut()?;
110        let mut base = BaseMessage {
111            msg_type,
112            id: 0,
113            refers_to: 0,
114            sent: Timeval::default(),
115            received: Timeval::default(),
116            size: 0,
117        };
118        stamp_sent(&mut base);
119        write_frame(stream, &mut base, payload).await
120    }
121
122    /// Send a request and wait for the response (matched by `refersTo`).
123    pub async fn send_request(
124        &mut self,
125        msg_type: MessageType,
126        payload: &MessagePayload,
127        timeout: Duration,
128    ) -> Result<TypedMessage> {
129        let id = self.next_id;
130        self.next_id = self.next_id.wrapping_add(1);
131
132        let (tx, rx) = oneshot::channel();
133        self.pending.insert(id, PendingRequest { tx });
134
135        let stream = self.stream_mut()?;
136        let mut base = BaseMessage {
137            msg_type,
138            id,
139            refers_to: 0,
140            sent: Timeval::default(),
141            received: Timeval::default(),
142            size: 0,
143        };
144        stamp_sent(&mut base);
145        write_frame(stream, &mut base, payload).await?;
146
147        tokio::time::timeout(timeout, rx)
148            .await
149            .context("request timed out")?
150            .context("response channel closed")
151    }
152
153    /// Receive the next message. If it's a response to a pending request,
154    /// deliver it to the waiting caller and receive again.
155    pub async fn recv(&mut self) -> Result<TypedMessage> {
156        loop {
157            let stream = self.stream_mut()?;
158            let msg = read_frame(stream).await?;
159
160            if msg.base.refers_to != 0
161                && let Some(pending) = self.pending.remove(&msg.base.refers_to)
162            {
163                let _ = pending.tx.send(msg);
164                continue;
165            }
166            return Ok(msg);
167        }
168    }
169}
170
171fn stamp_sent(base: &mut BaseMessage) {
172    let tv = steady_time_of_day();
173    base.sent = tv;
174}
175
176/// Matches the C++ `chronos::steadytimeofday` — monotonic clock time.
177/// On macOS/Linux, `Instant` is based on `CLOCK_MONOTONIC` which counts
178/// seconds since boot, matching the C++ snapserver's clock domain.
179fn steady_time_of_day() -> Timeval {
180    // Instant::now().duration_since(EPOCH) gives time since first call.
181    // We need time since boot. On Unix, Instant uses CLOCK_MONOTONIC
182    // which starts at boot. We can get this via the elapsed time from
183    // a known-early Instant.
184    let usec = monotonic_usec();
185    Timeval {
186        sec: (usec / 1_000_000) as i32,
187        usec: (usec % 1_000_000) as i32,
188    }
189}
190
191/// Microseconds since boot (monotonic clock).
192/// Uses the same clock source as C++ std::chrono::steady_clock.
193#[allow(unsafe_code)] // FFI: mach_continuous_time (macOS), clock_gettime (Linux)
194fn monotonic_usec() -> i64 {
195    #[cfg(target_os = "macos")]
196    {
197        // macOS: C++ steady_clock uses mach_continuous_time, not CLOCK_MONOTONIC.
198        // These differ by ~2s on macOS. We must match the server's clock exactly.
199        unsafe extern "C" {
200            fn mach_continuous_time() -> u64;
201            fn mach_timebase_info(info: *mut MachTimebaseInfo) -> i32;
202        }
203        #[repr(C)]
204        struct MachTimebaseInfo {
205            numer: u32,
206            denom: u32,
207        }
208        static TIMEBASE: std::sync::OnceLock<(u32, u32)> = std::sync::OnceLock::new();
209        let (numer, denom) = *TIMEBASE.get_or_init(|| {
210            let mut info = MachTimebaseInfo { numer: 0, denom: 0 };
211            unsafe {
212                mach_timebase_info(&mut info);
213            }
214            (info.numer, info.denom)
215        });
216        let ticks = unsafe { mach_continuous_time() };
217        let nanos = ticks as i128 * numer as i128 / denom as i128;
218        (nanos / 1_000) as i64
219    }
220    #[cfg(all(unix, not(target_os = "macos")))]
221    {
222        let mut ts = libc::timespec {
223            tv_sec: 0,
224            tv_nsec: 0,
225        };
226        // SAFETY: clock_gettime with CLOCK_MONOTONIC is always safe
227        unsafe {
228            libc::clock_gettime(libc::CLOCK_MONOTONIC, &mut ts);
229        }
230        ts.tv_sec * 1_000_000 + ts.tv_nsec / 1_000
231    }
232    #[cfg(not(unix))]
233    {
234        let now = std::time::SystemTime::now()
235            .duration_since(std::time::UNIX_EPOCH)
236            .unwrap_or_default();
237        now.as_micros() as i64
238    }
239}
240
241/// Current time in microseconds using the steady clock.
242pub fn now_usec() -> i64 {
243    monotonic_usec()
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use snapcast_proto::message::time::Time;

    /// Header with the given routing fields; `received` and `size` start zeroed.
    fn header(msg_type: MessageType, id: u16, refers_to: u16, sent: Timeval) -> BaseMessage {
        BaseMessage {
            msg_type,
            id,
            refers_to,
            sent,
            received: Timeval::default(),
            size: 0,
        }
    }

    /// Round-trips a single `Time` frame through an in-memory buffer (no network needed).
    #[tokio::test]
    async fn write_and_read_frame() {
        let mut base = header(MessageType::Time, 42, 0, Timeval { sec: 1, usec: 0 });
        let payload = MessagePayload::Time(Time {
            latency: Timeval { sec: 0, usec: 1234 },
        });

        let mut wire = Vec::new();
        write_frame(&mut wire, &mut base, &payload).await.unwrap();
        // Exactly one header followed by one Time payload.
        assert_eq!(wire.len(), BaseMessage::HEADER_SIZE + Time::SIZE as usize);

        let msg = read_frame(&mut std::io::Cursor::new(&wire)).await.unwrap();
        assert_eq!(msg.base.msg_type, MessageType::Time);
        assert_eq!(msg.base.id, 42);
        let MessagePayload::Time(t) = msg.payload else {
            panic!("expected Time")
        };
        assert_eq!(t.latency.usec, 1234);
    }

    /// An `Error` frame keeps its `refersTo` and its payload fields across a round trip.
    #[tokio::test]
    async fn write_and_read_error_frame() {
        use snapcast_proto::message::error::Error;

        let mut base = header(MessageType::Error, 0, 7, Timeval::default());
        let payload = MessagePayload::Error(Error {
            code: 401,
            error: "Unauthorized".into(),
            message: "bad auth".into(),
        });

        let mut wire = Vec::new();
        write_frame(&mut wire, &mut base, &payload).await.unwrap();

        let msg = read_frame(&mut std::io::Cursor::new(&wire)).await.unwrap();
        assert_eq!(msg.base.refers_to, 7);
        let MessagePayload::Error(e) = msg.payload else {
            panic!("expected Error")
        };
        assert_eq!(e.code, 401);
        assert_eq!(e.error, "Unauthorized");
    }

    /// Back-to-back frames in one buffer are read back in order.
    #[tokio::test]
    async fn write_and_read_multiple_frames() {
        let frames: Vec<(MessageType, MessagePayload)> = vec![
            (MessageType::Time, MessagePayload::Time(Time::default())),
            (
                MessageType::ClientInfo,
                MessagePayload::ClientInfo(snapcast_proto::message::client_info::ClientInfo {
                    volume: 80,
                    muted: false,
                }),
            ),
        ];

        let mut wire = Vec::new();
        for (msg_type, payload) in &frames {
            let mut base = header(*msg_type, 0, 0, Timeval::default());
            write_frame(&mut wire, &mut base, payload).await.unwrap();
        }

        let mut cursor = std::io::Cursor::new(&wire);
        for (expected, _) in &frames {
            let msg = read_frame(&mut cursor).await.unwrap();
            assert_eq!(msg.base.msg_type, *expected);
        }
    }

    #[test]
    fn tcp_connection_new() {
        let conn = TcpConnection::new("localhost", 1704);
        assert!(conn.stream.is_none());
        assert_eq!(conn.host, "localhost");
        assert_eq!(conn.port, 1704);
    }
}