use std::sync::{Arc, Mutex};
use std::time::Duration;
use anyhow::{Result, bail};
use snapcast_proto::message::codec_header::CodecHeader;
use snapcast_proto::message::factory::{MessagePayload, TypedMessage};
use snapcast_proto::message::hello::{Auth, Hello};
use snapcast_proto::message::server_settings::ServerSettings;
use snapcast_proto::message::time::Time;
use snapcast_proto::{MessageType, SampleFormat};
use tokio::sync::mpsc;
use crate::connection::TcpConnection;
use crate::decoder::{self, Decoder, PcmDecoder};
use crate::stream::{PcmChunk, Stream};
use crate::time_provider::TimeProvider;
use crate::{ClientCommand, ClientEvent};
/// Client version string reported to the server in the Hello message.
const VERSION: &str = env!("CARGO_PKG_VERSION");

/// Cap, in seconds, on the delay between reconnect attempts.
const MAX_RECONNECT_DELAY_SECS: u32 = 30;

/// Timeout handed to mDNS discovery when no concrete server host is configured.
#[cfg(feature = "mdns")]
const MDNS_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Maximum time to wait for the server's reply to Hello.
const HELLO_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Period of the steady-state time-sync requests.
const SYNC_INTERVAL: Duration = Duration::from_millis(1_000);

/// Period of the rapid time-sync requests sent right after connecting.
const QUICK_SYNC_INTERVAL: Duration = Duration::from_millis(100);
/// Drives a single snapcast client session loop.
///
/// It connects to the server, performs the Hello handshake, decodes incoming audio into the
/// shared [`Stream`], and feeds Time replies into the shared [`TimeProvider`].
pub struct Controller {
    /// Client configuration. `host`/`port` may be overwritten by mDNS discovery.
    settings: crate::ClientConfig,
    /// Connection to the server. It is connected at the start of each session and
    /// disconnected by `cleanup`.
    connection: TcpConnection,
    /// Clock-offset estimator updated from Time replies. Shared via `Arc` with other components.
    time_provider: Arc<Mutex<TimeProvider>>,
    /// Destination for decoded PCM chunks. `new` always sets this to `Some`.
    stream: Option<Arc<Mutex<Stream>>>,
    /// Decoder for the current codec. It is `None` until a CodecHeader arrives and again
    /// after `cleanup`.
    decoder: Option<Box<dyn Decoder>>,
    /// Sample format reported by the current decoder's `set_header`.
    sample_format: SampleFormat,
    /// Most recent ServerSettings received from the server, if any.
    server_settings: Option<ServerSettings>,
    /// Events for the application. Sent with `try_send`, so events can be dropped.
    event_tx: mpsc::Sender<ClientEvent>,
    /// Commands from the application (stop, volume, custom messages).
    command_rx: mpsc::Receiver<ClientCommand>,
}
impl Controller {
pub fn new(
settings: crate::ClientConfig,
event_tx: mpsc::Sender<ClientEvent>,
command_rx: mpsc::Receiver<ClientCommand>,
time_provider: Arc<Mutex<TimeProvider>>,
stream: Arc<Mutex<Stream>>,
) -> Self {
Self {
connection: TcpConnection::new(&settings.host, settings.port),
settings,
time_provider,
stream: Some(stream),
decoder: None,
sample_format: SampleFormat::default(),
server_settings: None,
event_tx,
command_rx,
}
}
pub async fn run(&mut self) -> Result<()> {
let mut attempts = 0u32;
loop {
match self.session().await {
Ok(()) => {
self.cleanup();
return Ok(());
}
Err(e) => {
if attempts == 0 {
tracing::warn!("Connection failed: {e}");
} else {
tracing::debug!("Reconnect attempt {attempts} failed: {e}");
}
self.emit(ClientEvent::Disconnected {
reason: e.to_string(),
});
attempts = attempts.saturating_add(1);
}
}
self.cleanup();
let delay = Duration::from_secs(attempts.min(MAX_RECONNECT_DELAY_SECS) as u64);
tokio::time::sleep(delay).await;
}
}
async fn session(&mut self) -> Result<()> {
if self.settings.host.is_empty() || self.settings.host.starts_with('_') {
#[cfg(feature = "mdns")]
{
tracing::info!(service = %self.settings.host, "Browsing mDNS...");
let endpoint =
crate::discovery::discover(MDNS_TIMEOUT, &self.settings.mdns_service_type)
.await?;
self.settings.host = endpoint.host;
self.settings.port = endpoint.port;
self.connection = TcpConnection::new(&self.settings.host, self.settings.port);
}
#[cfg(not(feature = "mdns"))]
bail!("mDNS not available — specify server URL");
}
self.connection.connect().await?;
tracing::info!(
host = %self.settings.host,
port = self.settings.port,
"Connected"
);
self.emit(ClientEvent::Connected {
host: self.settings.host.clone(),
port: self.settings.port,
});
self.send_hello().await?;
self.receive_loop().await
}
async fn send_hello(&mut self) -> Result<()> {
let mac = get_mac_address();
let host_id = if self.settings.host_id.is_empty() {
mac.clone()
} else {
self.settings.host_id.clone()
};
let auth = self.settings.auth.as_ref().map(|a| Auth {
scheme: a.scheme.clone(),
param: a.param.clone(),
});
let hello = Hello {
mac: mac.clone(),
host_name: hostname(),
version: VERSION.to_string(),
client_name: self.settings.client_name.clone(),
os: std::env::consts::OS.to_string(),
arch: std::env::consts::ARCH.to_string(),
instance: self.settings.instance,
id: host_id,
snap_stream_protocol_version: snapcast_proto::PROTOCOL_VERSION,
auth,
};
self.connection
.send(MessageType::Hello, &MessagePayload::Hello(hello))
.await?;
let response = self.recv_timeout(HELLO_TIMEOUT).await?;
match response.payload {
MessagePayload::ServerSettings(ss) => {
self.emit(ClientEvent::ServerSettings {
buffer_ms: ss.buffer_ms,
latency: ss.latency,
volume: ss.volume,
muted: ss.muted,
});
self.server_settings = Some(ss);
Ok(())
}
MessagePayload::Error(e) => {
bail!("Server rejected Hello: {} (code {})", e.error, e.code)
}
_ => bail!("Unexpected response to Hello: {:?}", response.base.msg_type),
}
}
async fn receive_loop(&mut self) -> Result<()> {
let mut sync_timer = tokio::time::interval(SYNC_INTERVAL);
const INITIAL_QUICK_SYNCS: u32 = 50;
let mut quick_syncs_remaining = INITIAL_QUICK_SYNCS;
let mut quick_sync_timer = tokio::time::interval(QUICK_SYNC_INTERVAL);
self.connection
.send(MessageType::Time, &MessagePayload::Time(Time::new()))
.await
.ok();
loop {
tokio::select! {
msg = self.connection.recv() => {
let msg = msg?;
self.handle_message(msg)?;
}
cmd = self.command_rx.recv() => {
match cmd {
Some(ClientCommand::Stop) | None => {
tracing::info!("Stop command received");
return Ok(());
}
Some(ClientCommand::SetVolume { volume, muted }) => {
tracing::debug!(volume, muted, "Volume change (applied by binary)");
}
#[cfg(feature = "custom-protocol")]
Some(ClientCommand::SendCustom(msg)) => {
self.connection
.send(
MessageType::Custom(msg.type_id),
&MessagePayload::Custom(msg.payload),
)
.await?;
}
}
}
_ = quick_sync_timer.tick(), if quick_syncs_remaining > 0 => {
quick_syncs_remaining -= 1;
self.connection
.send(MessageType::Time, &MessagePayload::Time(Time::new()))
.await
.ok();
if quick_syncs_remaining == 0 {
let diff = self.time_provider.lock().unwrap_or_else(|e| e.into_inner()).diff_to_server_usec();
let diff_ms = diff as f64 / 1000.0;
tracing::info!(diff_ms, "Time sync complete");
self.emit(ClientEvent::TimeSyncComplete { diff_ms });
}
}
_ = sync_timer.tick(), if quick_syncs_remaining == 0 => {
self.connection
.send(MessageType::Time, &MessagePayload::Time(Time::new()))
.await
.ok();
}
}
}
}
fn handle_message(&mut self, msg: TypedMessage) -> Result<()> {
match msg.payload {
MessagePayload::WireChunk(wc) => {
if let Some(ref mut dec) = self.decoder {
let mut data = wc.payload;
if dec.decode(&mut data)? {
let chunk = PcmChunk::new(wc.timestamp, data, self.sample_format);
if let Some(ref stream) = self.stream {
stream
.lock()
.unwrap_or_else(|e| e.into_inner())
.add_chunk(chunk);
}
}
}
}
MessagePayload::ServerSettings(ss) => {
self.emit(ClientEvent::ServerSettings {
buffer_ms: ss.buffer_ms,
latency: ss.latency,
volume: ss.volume,
muted: ss.muted,
});
self.emit(ClientEvent::VolumeChanged {
volume: ss.volume,
muted: ss.muted,
});
self.apply_server_settings(&ss);
self.server_settings = Some(ss);
}
MessagePayload::CodecHeader(ch) => {
self.init_audio_pipeline(&ch)?;
}
MessagePayload::Time(t) => {
let s2c = msg.base.received - msg.base.sent;
self.time_provider
.lock()
.unwrap()
.set_diff(&t.latency, &s2c);
}
MessagePayload::Error(e) => {
tracing::error!(code = e.code, error = %e.error, "Server error");
}
#[cfg(feature = "custom-protocol")]
MessagePayload::Custom(payload) => {
if let MessageType::Custom(type_id) = msg.base.msg_type {
self.emit(ClientEvent::CustomMessage(
snapcast_proto::CustomMessage::new(type_id, payload),
));
}
}
_ => {}
}
Ok(())
}
fn apply_server_settings(&mut self, ss: &ServerSettings) {
if let Some(ref stream) = self.stream {
let buf_ms = (ss.buffer_ms - ss.latency - self.settings.latency).max(0);
stream
.lock()
.unwrap_or_else(|e| e.into_inner())
.set_buffer_ms(buf_ms as i64);
}
}
fn init_audio_pipeline(&mut self, header: &CodecHeader) -> Result<()> {
let mut dec: Box<dyn Decoder> = match header.codec.as_str() {
"pcm" => Box::new(PcmDecoder::new()),
"flac" => Box::new(decoder::flac::create(header)?),
"ogg" => Box::new(decoder::vorbis::create(header)?),
"opus" => Box::new(decoder::opus::create(header)?),
#[cfg(all(feature = "f32lz4", feature = "encryption"))]
"f32lz4" => Box::new(decoder::f32lz4::create(
self.settings.encryption_psk.as_deref(),
)),
#[cfg(all(feature = "f32lz4", not(feature = "encryption")))]
"f32lz4" => Box::new(decoder::f32lz4::create()),
other => bail!("unsupported codec: {other}"),
};
self.sample_format = dec.set_header(header)?;
tracing::info!(codec = %header.codec, format = %self.sample_format, "Codec initialized");
self.emit(ClientEvent::StreamStarted {
codec: header.codec.clone(),
format: self.sample_format,
});
if let Some(ref stream) = self.stream {
let mut s = stream.lock().unwrap_or_else(|e| e.into_inner());
*s = Stream::new(self.sample_format);
if let Some(ref ss) = self.server_settings {
let buf_ms = (ss.buffer_ms - ss.latency - self.settings.latency).max(0);
s.set_buffer_ms(buf_ms as i64);
}
}
self.decoder = Some(dec);
Ok(())
}
async fn recv_timeout(&mut self, timeout: Duration) -> Result<TypedMessage> {
tokio::time::timeout(timeout, self.connection.recv())
.await
.map_err(|_| anyhow::anyhow!("receive timed out"))?
}
fn cleanup(&mut self) {
self.decoder = None;
self.connection.disconnect();
}
fn emit(&self, event: ClientEvent) {
let _ = self.event_tx.try_send(event);
}
}
/// Returns this machine's host name (lossily converted to UTF-8), or `"unknown"` on failure.
fn hostname() -> String {
    match hostname::get() {
        Ok(name) => name.to_string_lossy().into_owned(),
        Err(_) => String::from("unknown"),
    }
}
/// Returns the primary MAC address in lowercase textual form.
///
/// Falls back to `"00:00:00:00:00:00"` when the lookup fails or finds no address.
fn get_mac_address() -> String {
    match mac_address::get_mac_address() {
        Ok(Some(addr)) => addr.to_string().to_lowercase(),
        Ok(None) | Err(_) => String::from("00:00:00:00:00:00"),
    }
}