1use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::{Context, Result};
7use snapcast_proto::MessageType;
8use snapcast_proto::message::base::BaseMessage;
9use snapcast_proto::message::codec_header::CodecHeader;
10use snapcast_proto::message::factory::{self, MessagePayload, TypedMessage};
11use snapcast_proto::message::server_settings::ServerSettings;
12use snapcast_proto::message::time::Time;
13use snapcast_proto::message::wire_chunk::WireChunk;
14use snapcast_proto::types::Timeval;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::{Mutex, broadcast, mpsc, watch};
18
19use crate::ClientSettingsUpdate;
20use crate::ServerEvent;
21use crate::WireChunkData;
22use crate::time::now_usec;
23
/// Routing snapshot for one client session: which stream it should receive
/// and whether it is currently muted.
///
/// Each session holds a `watch` receiver of this type; the server pushes a
/// fresh value whenever one of the `update_routing_*` methods is called.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionRouting {
    /// Id of the stream whose chunks this client should receive.
    pub stream_id: String,
    /// Whether the client itself is muted.
    pub client_muted: bool,
    /// Whether the client's group is muted.
    pub group_muted: bool,
}
36
/// Codec information registered for one stream; sent to clients as a
/// `CodecHeader` message when they connect and when they switch streams.
#[derive(Debug, Clone)]
pub struct StreamCodecInfo {
    /// Codec name, copied into `CodecHeader::codec`.
    pub codec: String,
    /// Raw codec header bytes, copied into `CodecHeader::payload`.
    pub header: Vec<u8>,
}
45
/// State shared (behind an `Arc`) by the listener and every client session task.
struct SessionContext {
    /// Value sent as `buffer_ms` in each client's `ServerSettings`.
    buffer_ms: i32,
    /// Optional credential validator; when `Some`, every client must authenticate.
    auth: Option<Arc<dyn crate::auth::AuthValidator>>,
    /// When `true`, audio chunks are forwarded even to muted clients/groups.
    send_audio_to_muted: bool,
    /// Per-session queues for pushing `ServerSettings` updates, keyed by client id.
    settings_senders: Mutex<HashMap<String, mpsc::Sender<ClientSettingsUpdate>>>,
    /// Per-session queues for outbound custom messages, keyed by client id.
    #[cfg(feature = "custom-protocol")]
    custom_senders: Mutex<HashMap<String, mpsc::Sender<CustomOutbound>>>,
    /// Per-session routing watch channels, keyed by client id.
    routing_senders: Mutex<HashMap<String, watch::Sender<SessionRouting>>>,
    /// Registered codec name + header, keyed by stream id.
    codec_headers: Mutex<HashMap<String, StreamCodecInfo>>,
    /// Client/group state; the same `Arc` that was handed to `SessionServer::new`.
    shared_state: Arc<tokio::sync::Mutex<crate::state::ServerState>>,
    /// Stream passed to `group_for_client` on connect and used as fallback
    /// routing when a client is in no group.
    default_stream: String,
}
77
78impl SessionContext {
79 fn build_routing(state: &crate::state::ServerState, client_id: &str) -> Option<SessionRouting> {
82 let group = state
83 .groups
84 .iter()
85 .find(|g| g.clients.contains(&client_id.to_string()))?;
86 let client_muted = state
87 .clients
88 .get(client_id)
89 .map(|c| c.config.volume.muted)
90 .unwrap_or(false);
91 Some(SessionRouting {
92 stream_id: group.stream_id.clone(),
93 client_muted,
94 group_muted: group.muted,
95 })
96 }
97
98 async fn push_routing(&self, client_id: &str) {
100 let s = self.shared_state.lock().await;
101 if let Some(routing) = Self::build_routing(&s, client_id) {
102 let senders = self.routing_senders.lock().await;
103 if let Some(tx) = senders.get(client_id) {
104 let _ = tx.send(routing);
105 }
106 }
107 }
108
109 async fn push_routing_for_group(&self, group_id: &str) {
111 let s = self.shared_state.lock().await;
112 let senders = self.routing_senders.lock().await;
113 let Some(group) = s.groups.iter().find(|g| g.id == group_id) else {
114 return;
115 };
116 for client_id in &group.clients {
117 if let Some(routing) = Self::build_routing(&s, client_id)
118 && let Some(tx) = senders.get(client_id)
119 {
120 let _ = tx.send(routing);
121 }
122 }
123 }
124
125 async fn push_routing_all(&self) {
127 let s = self.shared_state.lock().await;
128 let senders = self.routing_senders.lock().await;
129 for group in &s.groups {
130 for client_id in &group.clients {
131 if let Some(routing) = Self::build_routing(&s, client_id)
132 && let Some(tx) = senders.get(client_id)
133 {
134 let _ = tx.send(routing);
135 }
136 }
137 }
138 }
139
140 async fn codec_header_for(&self, stream_id: &str) -> Option<StreamCodecInfo> {
142 self.codec_headers.lock().await.get(stream_id).cloned()
143 }
144}
145
/// TCP server speaking the binary stream protocol to audio clients.
///
/// Holds the listen port and the context shared with every per-client task.
pub struct SessionServer {
    /// Port bound on `0.0.0.0` by `run`.
    port: u16,
    /// State shared with each spawned client session.
    ctx: Arc<SessionContext>,
}
153
#[cfg(feature = "custom-protocol")]
/// An application-defined message queued for delivery to one client.
#[derive(Debug, Clone)]
pub struct CustomOutbound {
    /// Custom message type id, sent as `MessageType::Custom(type_id)`.
    pub type_id: u16,
    /// Opaque payload bytes, sent as `MessagePayload::Custom`.
    pub payload: Vec<u8>,
}
163
164impl SessionServer {
165 pub fn new(
167 port: u16,
168 buffer_ms: i32,
169 auth: Option<Arc<dyn crate::auth::AuthValidator>>,
170 shared_state: Arc<tokio::sync::Mutex<crate::state::ServerState>>,
171 default_stream: String,
172 send_audio_to_muted: bool,
173 ) -> Self {
174 Self {
175 port,
176 ctx: Arc::new(SessionContext {
177 buffer_ms,
178 auth,
179 send_audio_to_muted,
180 settings_senders: Mutex::new(HashMap::new()),
181 #[cfg(feature = "custom-protocol")]
182 custom_senders: Mutex::new(HashMap::new()),
183 routing_senders: Mutex::new(HashMap::new()),
184 codec_headers: Mutex::new(HashMap::new()),
185 shared_state,
186 default_stream,
187 }),
188 }
189 }
190
191 pub async fn register_stream_codec(&self, stream_id: &str, codec: &str, header: &[u8]) {
193 self.ctx.codec_headers.lock().await.insert(
194 stream_id.to_string(),
195 StreamCodecInfo {
196 codec: codec.to_string(),
197 header: header.to_vec(),
198 },
199 );
200 }
201
202 pub async fn push_settings(&self, update: ClientSettingsUpdate) {
204 let senders = self.ctx.settings_senders.lock().await;
205 if let Some(tx) = senders.get(&update.client_id) {
206 let _ = tx.send(update).await;
207 }
208 }
209
210 pub async fn update_routing_for_client(&self, client_id: &str) {
212 self.ctx.push_routing(client_id).await;
213 }
214
215 pub async fn update_routing_for_group(&self, group_id: &str) {
217 self.ctx.push_routing_for_group(group_id).await;
218 }
219
220 pub async fn update_routing_all(&self) {
222 self.ctx.push_routing_all().await;
223 }
224
225 pub async fn run(
227 &self,
228 chunk_rx: broadcast::Sender<WireChunkData>,
229 event_tx: mpsc::Sender<ServerEvent>,
230 ) -> Result<()> {
231 let listener = TcpListener::bind(format!("0.0.0.0:{}", self.port)).await?;
232 tracing::info!(port = self.port, "Stream server listening");
233
234 loop {
235 let (stream, peer) = listener.accept().await?;
236 stream.set_nodelay(true).ok();
237 let ka = socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(10));
238 let sock = socket2::SockRef::from(&stream);
239 sock.set_tcp_keepalive(&ka).ok();
240 tracing::info!(%peer, "Client connecting");
241
242 let chunk_sub = chunk_rx.subscribe();
243 let ctx = Arc::clone(&self.ctx);
244 let event_tx = event_tx.clone();
245
246 tokio::spawn(async move {
247 let result = handle_client(stream, chunk_sub, &ctx, event_tx).await;
248 if let Err(e) = result {
249 tracing::debug!(%peer, error = %e, "Client session ended");
250 }
251 });
252 }
253 }
254
255 #[cfg(feature = "custom-protocol")]
257 pub async fn send_custom(&self, client_id: &str, type_id: u16, payload: Vec<u8>) {
258 let senders = self.ctx.custom_senders.lock().await;
259 if let Some(tx) = senders.get(client_id) {
260 let _ = tx.send(CustomOutbound { type_id, payload }).await;
261 }
262 }
263}
264
265async fn handle_client(
268 mut stream: TcpStream,
269 chunk_rx: broadcast::Receiver<WireChunkData>,
270 ctx: &SessionContext,
271 event_tx: mpsc::Sender<ServerEvent>,
272) -> Result<()> {
273 let hello_msg = read_frame_from(&mut stream).await?;
274 let hello_id = hello_msg.base.id;
275 let hello = match hello_msg.payload {
276 MessagePayload::Hello(h) => h,
277 _ => anyhow::bail!("expected Hello, got {:?}", hello_msg.base.msg_type),
278 };
279
280 let client_id = hello.id.clone();
281 tracing::info!(id = %client_id, name = %hello.host_name, mac = %hello.mac, "Client hello");
282
283 if let Some(validator) = &ctx.auth {
284 validate_auth(validator.as_ref(), &hello, &mut stream, &client_id).await?;
285 }
286
287 let (settings_tx, settings_rx) = mpsc::channel(16);
289 #[cfg(feature = "custom-protocol")]
290 let (custom_tx, custom_rx) = mpsc::channel(64);
291
292 ctx.settings_senders
293 .lock()
294 .await
295 .insert(client_id.clone(), settings_tx);
296 #[cfg(feature = "custom-protocol")]
297 ctx.custom_senders
298 .lock()
299 .await
300 .insert(client_id.clone(), custom_tx);
301
302 let initial_stream_id;
304 let initial_routing;
305 let client_settings;
306 {
307 let mut s = ctx.shared_state.lock().await;
308 let c = s.get_or_create_client(&client_id, &hello.host_name, &hello.mac);
309 c.connected = true;
310 client_settings = ServerSettings {
311 buffer_ms: ctx.buffer_ms,
312 latency: c.config.latency,
313 volume: c.config.volume.percent,
314 muted: c.config.volume.muted,
315 };
316 s.group_for_client(&client_id, &ctx.default_stream);
317
318 initial_routing =
319 SessionContext::build_routing(&s, &client_id).unwrap_or_else(|| SessionRouting {
320 stream_id: ctx.default_stream.clone(),
321 client_muted: false,
322 group_muted: false,
323 });
324 initial_stream_id = initial_routing.stream_id.clone();
325 }
326
327 let (routing_tx, routing_rx) = watch::channel(initial_routing);
328 ctx.routing_senders
329 .lock()
330 .await
331 .insert(client_id.clone(), routing_tx);
332
333 let _ = event_tx
334 .send(ServerEvent::ClientConnected {
335 id: client_id.clone(),
336 hello: hello.clone(),
337 })
338 .await;
339
340 let ss_frame = serialize_msg(
342 MessageType::ServerSettings,
343 &MessagePayload::ServerSettings(client_settings),
344 hello_id,
345 )?;
346 stream
347 .write_all(&ss_frame)
348 .await
349 .context("write server settings")?;
350
351 match ctx.codec_header_for(&initial_stream_id).await {
353 Some(info) => {
354 send_msg(
355 &mut stream,
356 MessageType::CodecHeader,
357 &MessagePayload::CodecHeader(CodecHeader {
358 codec: info.codec,
359 payload: info.header,
360 }),
361 )
362 .await?;
363 }
364 None => {
365 tracing::warn!(stream = %initial_stream_id, client = %client_id, "No codec header registered for stream");
366 }
367 }
368
369 let result = session_loop(
371 &mut stream,
372 chunk_rx,
373 settings_rx,
374 routing_rx,
375 #[cfg(feature = "custom-protocol")]
376 custom_rx,
377 event_tx.clone(),
378 client_id.clone(),
379 ctx,
380 )
381 .await;
382
383 ctx.settings_senders.lock().await.remove(&client_id);
385 ctx.routing_senders.lock().await.remove(&client_id);
386 #[cfg(feature = "custom-protocol")]
387 ctx.custom_senders.lock().await.remove(&client_id);
388 {
389 let mut s = ctx.shared_state.lock().await;
390 if let Some(c) = s.clients.get_mut(&client_id) {
391 c.connected = false;
392 }
393 }
394 let _ = event_tx
395 .send(ServerEvent::ClientDisconnected { id: client_id })
396 .await;
397
398 result
399}
400
401#[allow(clippy::too_many_arguments)]
409async fn session_loop(
410 stream: &mut TcpStream,
411 mut chunk_rx: broadcast::Receiver<WireChunkData>,
412 mut settings_rx: mpsc::Receiver<ClientSettingsUpdate>,
413 mut routing_rx: watch::Receiver<SessionRouting>,
414 #[cfg(feature = "custom-protocol")] mut custom_rx: mpsc::Receiver<CustomOutbound>,
415 event_tx: mpsc::Sender<ServerEvent>,
416 client_id: String,
417 ctx: &SessionContext,
418) -> Result<()> {
419 let (mut reader, mut writer) = stream.split();
420 let mut routing = routing_rx.borrow().clone();
421
422 loop {
423 #[cfg(feature = "custom-protocol")]
428 while let Ok(msg) = custom_rx.try_recv() {
429 let frame = serialize_msg(
430 MessageType::Custom(msg.type_id),
431 &MessagePayload::Custom(msg.payload),
432 0,
433 )?;
434 writer.write_all(&frame).await.context("write custom")?;
435 }
436
437 tokio::select! {
438 chunk = chunk_rx.recv() => {
439 let chunk = match chunk {
440 Ok(c) => c,
441 Err(broadcast::error::RecvError::Lagged(n)) => {
442 tracing::warn!(skipped = n, "Broadcast lagged");
443 continue;
444 }
445 Err(broadcast::error::RecvError::Closed) => {
446 tracing::warn!("Broadcast closed");
447 anyhow::bail!("broadcast closed");
448 }
449 };
450 if !should_send_chunk(&chunk, &routing, ctx.send_audio_to_muted) {
451 continue;
452 }
453 write_chunk(&mut writer, chunk).await?;
454 }
455 Ok(()) = routing_rx.changed() => {
456 let new = routing_rx.borrow().clone();
457 if new.stream_id != routing.stream_id {
458 tracing::debug!(old = %routing.stream_id, new = %new.stream_id, "Stream switch");
459 if let Some(info) = ctx.codec_header_for(&new.stream_id).await {
460 let frame = serialize_msg(
461 MessageType::CodecHeader,
462 &MessagePayload::CodecHeader(CodecHeader {
463 codec: info.codec,
464 payload: info.header,
465 }),
466 0,
467 )?;
468 writer.write_all(&frame).await.context("write codec header")?;
469 }
470 }
471 routing = new;
472 }
473 msg = read_frame_from(&mut reader) => {
474 let msg = msg?;
475 match msg.payload {
476 MessagePayload::Time(_t) => {
477 let latency = msg.base.received - msg.base.sent;
479 let frame = serialize_msg(
480 MessageType::Time,
481 &MessagePayload::Time(Time { latency }),
482 msg.base.id,
483 )?;
484 writer.write_all(&frame).await.context("write time")?;
485 }
486 MessagePayload::ClientInfo(info) => {
487 {
488 let mut s = ctx.shared_state.lock().await;
489 if let Some(c) = s.clients.get_mut(&client_id) {
490 c.config.volume.percent = info.volume;
491 c.config.volume.muted = info.muted;
492 }
493 }
494 let _ = event_tx.send(ServerEvent::ClientVolumeChanged {
495 client_id: client_id.clone(),
496 volume: info.volume,
497 muted: info.muted,
498 }).await;
499 }
500 #[cfg(feature = "custom-protocol")]
501 MessagePayload::Custom(payload) => {
502 if let MessageType::Custom(type_id) = msg.base.msg_type {
503 let _ = event_tx.send(ServerEvent::CustomMessage {
504 client_id: client_id.clone(),
505 message: snapcast_proto::CustomMessage::new(type_id, payload),
506 }).await;
507 }
508 }
509 _ => {}
510 }
511 }
512 update = settings_rx.recv() => {
513 let Some(update) = update else { continue };
514 write_settings(&mut writer, update).await?;
515 }
516 }
517 }
518}
519
520#[inline]
524fn should_send_chunk(
525 chunk: &WireChunkData,
526 routing: &SessionRouting,
527 send_audio_to_muted: bool,
528) -> bool {
529 if chunk.stream_id != routing.stream_id {
530 return false;
531 }
532 if !send_audio_to_muted && (routing.client_muted || routing.group_muted) {
533 return false;
534 }
535 true
536}
537
538async fn write_chunk<W: AsyncWriteExt + Unpin>(writer: &mut W, chunk: WireChunkData) -> Result<()> {
539 let wc = WireChunk {
540 timestamp: Timeval::from_usec(chunk.timestamp_usec),
541 payload: chunk.data,
542 };
543 let frame = serialize_msg(MessageType::WireChunk, &MessagePayload::WireChunk(wc), 0)?;
544 writer.write_all(&frame).await.context("write chunk")
545}
546
547async fn write_settings<W: AsyncWriteExt + Unpin>(
548 writer: &mut W,
549 update: ClientSettingsUpdate,
550) -> Result<()> {
551 let ss = ServerSettings {
552 buffer_ms: update.buffer_ms,
553 latency: update.latency,
554 volume: update.volume,
555 muted: update.muted,
556 };
557 let frame = serialize_msg(
558 MessageType::ServerSettings,
559 &MessagePayload::ServerSettings(ss),
560 0,
561 )?;
562 writer.write_all(&frame).await.context("write settings")?;
563 tracing::debug!(
564 volume = update.volume,
565 latency = update.latency,
566 "Pushed settings"
567 );
568 Ok(())
569}
570
571async fn validate_auth(
572 validator: &dyn crate::auth::AuthValidator,
573 hello: &snapcast_proto::message::hello::Hello,
574 stream: &mut TcpStream,
575 client_id: &str,
576) -> Result<()> {
577 let auth_result = match &hello.auth {
578 Some(a) => validator.validate(&a.scheme, &a.param),
579 None => Err(crate::auth::AuthError::Unauthorized(
580 "Authentication required".into(),
581 )),
582 };
583 match auth_result {
584 Ok(result) => {
585 if !result
586 .permissions
587 .iter()
588 .any(|p| p == crate::auth::PERM_STREAMING)
589 {
590 let err = snapcast_proto::message::error::Error {
591 code: 403,
592 message: "Forbidden".into(),
593 error: "Permission 'Streaming' missing".into(),
594 };
595 send_msg(stream, MessageType::Error, &MessagePayload::Error(err)).await?;
596 anyhow::bail!("Client {client_id}: missing Streaming permission");
597 }
598 tracing::info!(id = %client_id, user = %result.username, "Authenticated");
599 Ok(())
600 }
601 Err(e) => {
602 let err = snapcast_proto::message::error::Error {
603 code: e.code() as u32,
604 message: e.message().to_string(),
605 error: e.message().to_string(),
606 };
607 send_msg(stream, MessageType::Error, &MessagePayload::Error(err)).await?;
608 anyhow::bail!("Client {client_id}: {e}");
609 }
610 }
611}
612
613fn serialize_msg(
614 msg_type: MessageType,
615 payload: &MessagePayload,
616 refers_to: u16,
617) -> Result<Vec<u8>> {
618 let mut base = BaseMessage {
619 msg_type,
620 id: 0,
621 refers_to,
622 sent: now_timeval(),
623 received: Timeval::default(),
624 size: 0,
625 };
626 factory::serialize(&mut base, payload).map_err(|e| anyhow::anyhow!("serialize: {e}"))
627}
628
629async fn send_msg(
630 stream: &mut TcpStream,
631 msg_type: MessageType,
632 payload: &MessagePayload,
633) -> Result<()> {
634 let frame = serialize_msg(msg_type, payload, 0)?;
635 stream.write_all(&frame).await.context("write message")
636}
637
638async fn read_frame_from<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<TypedMessage> {
639 const MAX_PAYLOAD_SIZE: u32 = 2 * 1024 * 1024; let mut header_buf = [0u8; BaseMessage::HEADER_SIZE];
642 reader
643 .read_exact(&mut header_buf)
644 .await
645 .context("read header")?;
646 let mut base =
647 BaseMessage::read_from(&mut &header_buf[..]).map_err(|e| anyhow::anyhow!("parse: {e}"))?;
648 base.received = now_timeval();
649 anyhow::ensure!(
650 base.size <= MAX_PAYLOAD_SIZE,
651 "payload too large: {} bytes",
652 base.size
653 );
654 let mut payload_buf = vec![0u8; base.size as usize];
655 if !payload_buf.is_empty() {
656 reader
657 .read_exact(&mut payload_buf)
658 .await
659 .context("read payload")?;
660 }
661 factory::deserialize(base, &payload_buf).map_err(|e| anyhow::anyhow!("deserialize: {e}"))
662}
663
664fn now_timeval() -> Timeval {
665 Timeval::from_usec(now_usec())
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    /// A 64-byte silent chunk tagged with `stream_id`.
    fn chunk(stream_id: &str) -> WireChunkData {
        WireChunkData {
            stream_id: stream_id.into(),
            timestamp_usec: 0,
            data: vec![0u8; 64],
        }
    }

    /// Routing snapshot for `stream_id` with the given mute flags.
    fn routing(stream_id: &str, client_muted: bool, group_muted: bool) -> SessionRouting {
        SessionRouting {
            stream_id: stream_id.into(),
            client_muted,
            group_muted,
        }
    }

    /// Would a chunk on `chunk_stream` be forwarded under routing `r`?
    fn sends(chunk_stream: &str, r: &SessionRouting, send_audio_to_muted: bool) -> bool {
        should_send_chunk(&chunk(chunk_stream), r, send_audio_to_muted)
    }

    #[test]
    fn matching_stream_unmuted_sends() {
        assert!(sends("z1", &routing("z1", false, false), false));
    }

    #[test]
    fn wrong_stream_skips() {
        assert!(!sends("z2", &routing("z1", false, false), false));
    }

    #[test]
    fn client_muted_skips() {
        assert!(!sends("z1", &routing("z1", true, false), false));
    }

    #[test]
    fn group_muted_skips() {
        assert!(!sends("z1", &routing("z1", false, true), false));
    }

    #[test]
    fn send_audio_to_muted_overrides() {
        assert!(sends("z1", &routing("z1", true, true), true));
    }

    #[test]
    fn wrong_stream_ignores_send_audio_to_muted() {
        assert!(!sends("z2", &routing("z1", false, false), true));
    }

    #[test]
    fn build_routing_finds_client_in_group() {
        let mut state = crate::state::ServerState::default();
        state.get_or_create_client("c1", "host", "mac");
        state.group_for_client("c1", "stream1");

        let r = SessionContext::build_routing(&state, "c1").expect("c1 was grouped");
        assert_eq!(r.stream_id, "stream1");
        assert_eq!((r.client_muted, r.group_muted), (false, false));
    }

    #[test]
    fn build_routing_reflects_mute() {
        let mut state = crate::state::ServerState::default();
        state
            .get_or_create_client("c1", "host", "mac")
            .config
            .volume
            .muted = true;
        state.group_for_client("c1", "stream1");
        let id = "c1".to_string();
        if let Some(group) = state.groups.iter_mut().find(|g| g.clients.contains(&id)) {
            group.muted = true;
        }

        let r = SessionContext::build_routing(&state, "c1").expect("c1 was grouped");
        assert_eq!((r.client_muted, r.group_muted), (true, true));
    }

    #[test]
    fn build_routing_returns_none_for_unknown_client() {
        let state = crate::state::ServerState::default();
        assert_eq!(SessionContext::build_routing(&state, "unknown"), None);
    }

    #[test]
    fn routing_watch_delivers_updates() {
        let (tx, rx) = watch::channel(routing("z1", false, false));
        assert_eq!(rx.borrow().stream_id, "z1");
        tx.send(routing("z2", true, false)).unwrap();
        let latest = rx.borrow().clone();
        assert_eq!(latest.stream_id, "z2");
        assert!(latest.client_muted);
    }

    #[test]
    fn unmute_cycle() {
        assert!(!sends("z1", &routing("z1", true, false), false));
        assert!(sends("z1", &routing("z1", false, false), false));
    }

    #[test]
    fn stream_switch_changes_filter() {
        let (on_z1, on_z2) = (routing("z1", false, false), routing("z2", false, false));
        assert!(sends("z1", &on_z1, false));
        assert!(!sends("z1", &on_z2, false));
        assert!(sends("z2", &on_z2, false));
    }
}