Skip to main content

snapcast_control/
communication.rs

1use std::sync::Arc;
2
3use stubborn_io::{ReconnectOptions, StubbornTcpStream};
4use uuid::Uuid;
5
6use crate::{
7  Message, Method, ValidMessage, errors,
8  protocol::{
9    self, ConnectionStatus, Request, RequestMethod, SentRequests, SnapcastDeserializer, client, group, server, stream,
10  },
11  state::WrappedState,
12};
13
14type Sender =
15  futures::stream::SplitSink<tokio_util::codec::Framed<StubbornTcpStream<std::net::SocketAddr>, Communication>, Method>;
16type Receiver =
17  futures::stream::SplitStream<tokio_util::codec::Framed<StubbornTcpStream<std::net::SocketAddr>, Communication>>;
18
19/// callback function type for connection status changes
20pub type ConnectionCallback = Arc<dyn Fn(ConnectionStatus) + Send + Sync>;
21
/// builder for creating a [SnapcastConnection] with optional connection status callbacks
///
/// every callback is optional; a callback that is left unset is simply not registered
/// with the reconnecting TCP stream.
#[derive(Default)]
pub struct SnapcastConnectionBuilder {
  /// invoked when a connection is established (initial connection and reconnections)
  on_connect: Option<Arc<dyn Fn() + Send + Sync>>,
  /// invoked when the connection is lost
  on_disconnect: Option<Arc<dyn Fn() + Send + Sync>>,
  /// invoked each time a reconnection attempt fails
  on_reconnect_failed: Option<Arc<dyn Fn() + Send + Sync>>,
}

// the callbacks are opaque closures that cannot be printed, so `Debug` only reports
// which of them have been configured
impl std::fmt::Debug for SnapcastConnectionBuilder {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("SnapcastConnectionBuilder")
      .field("on_connect", &self.on_connect.is_some())
      .field("on_disconnect", &self.on_disconnect.is_some())
      .field("on_reconnect_failed", &self.on_reconnect_failed.is_some())
      .finish()
  }
}
29
30impl SnapcastConnectionBuilder {
31  /// create a new builder with no callbacks configured
32  pub fn new() -> Self {
33    Self::default()
34  }
35
36  /// set a callback to be invoked when a connection is established
37  ///
38  /// this is called both on initial connection and after successful reconnection
39  pub fn on_connect<F>(mut self, callback: F) -> Self
40  where
41    F: Fn() + Send + Sync + 'static,
42  {
43    self.on_connect = Some(Arc::new(callback));
44    self
45  }
46
47  /// set a callback to be invoked when the connection is lost
48  ///
49  /// after this callback is invoked, the client will automatically attempt to reconnect
50  pub fn on_disconnect<F>(mut self, callback: F) -> Self
51  where
52    F: Fn() + Send + Sync + 'static,
53  {
54    self.on_disconnect = Some(Arc::new(callback));
55    self
56  }
57
58  /// set a callback to be invoked when a reconnection attempt fails
59  ///
60  /// this may be called multiple times as the client retries with exponential backoff
61  pub fn on_reconnect_failed<F>(mut self, callback: F) -> Self
62  where
63    F: Fn() + Send + Sync + 'static,
64  {
65    self.on_reconnect_failed = Some(Arc::new(callback));
66    self
67  }
68
69  /// set a single callback to handle all connection status changes
70  ///
71  /// this is a convenience method that sets all three callbacks to invoke the same
72  /// handler with the appropriate [ConnectionStatus] variant
73  pub fn on_status_change<F>(self, callback: F) -> Self
74  where
75    F: Fn(ConnectionStatus) + Send + Sync + 'static,
76  {
77    let callback = Arc::new(callback);
78    let connect_cb = callback.clone();
79    let disconnect_cb = callback.clone();
80    let fail_cb = callback;
81
82    Self {
83      on_connect: Some(Arc::new(move || connect_cb(ConnectionStatus::Connected))),
84      on_disconnect: Some(Arc::new(move || disconnect_cb(ConnectionStatus::Disconnected))),
85      on_reconnect_failed: Some(Arc::new(move || fail_cb(ConnectionStatus::ReconnectFailed))),
86    }
87  }
88
89  /// connect to the Snapcast server at the given address
90  ///
91  /// # args
92  /// `address`: [std::net::SocketAddr] - the address of the Snapcast server
93  ///
94  /// # returns
95  /// a new [SnapcastConnection] struct
96  pub async fn connect(self, address: std::net::SocketAddr) -> Result<SnapcastConnection, std::io::Error> {
97    let state = WrappedState::default();
98    let (sender, receiver) =
99      Communication::init(address, self.on_connect, self.on_disconnect, self.on_reconnect_failed).await?;
100
101    Ok(SnapcastConnection {
102      state,
103      sender,
104      receiver,
105    })
106  }
107}
108
109/// Struct representing a connection to a Snapcast server.
110/// Contains the current state of the server and methods to interact with it.
111///
112/// call `SnapcastConnection::open` to create a new connection, or use
113/// `SnapcastConnection::builder` to configure connection status callbacks.
114pub struct SnapcastConnection {
115  /// The current state of the server. The state is Send + Sync, so it can be shared between threads.
116  pub state: WrappedState,
117
118  // internal
119  sender: Sender,
120  receiver: Receiver,
121}
122
123impl SnapcastConnection {
124  /// create a builder for configuring connection callbacks
125  ///
126  /// # returns
127  /// a new [SnapcastConnectionBuilder] struct
128  pub fn builder() -> SnapcastConnectionBuilder {
129    SnapcastConnectionBuilder::new()
130  }
131
132  /// open a new connection to a Snapcast server
133  ///
134  /// for connection status notifications, use [SnapcastConnection::builder] instead
135  ///
136  /// # args
137  /// `address`: [std::net::SocketAddr] - the address of the Snapcast server
138  ///
139  /// # returns
140  /// a new [SnapcastConnection] struct
141  ///
142  /// # example
143  /// ```no_run
144  /// let mut client = SnapcastConnection::open("127.0.0.1:1705".parse().expect("could not parse socket address")).await.expect("could not connect to server");
145  /// ```
146  pub async fn open(address: std::net::SocketAddr) -> Result<Self, std::io::Error> {
147    SnapcastConnectionBuilder::new().connect(address).await
148  }
149
150  /// send a raw command to the Snapcast server
151  ///
152  /// # args
153  /// `command`: [Method] - the command to send
154  ///
155  /// # returns
156  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
157  ///
158  /// # example
159  /// ```no_run
160  /// client.send(Method::ServerGetStatus).await.expect("could not send command");
161  /// ```
162  pub async fn send(&mut self, command: Method) -> Result<(), ClientError> {
163    use futures::SinkExt;
164
165    self.sender.send(command).await
166  }
167
168  /// receive messages from the Snapcast server
169  ///
170  /// uses a [futures::stream::Next] under the hood, so: \
171  /// creates a future that resolves to the next batch of messages in the stream
172  ///
173  /// # returns
174  /// an [Option] containing a [Vec] of [Result]s, one for each message in the batch, \
175  /// or [None] if the stream has ended. Transport-level errors result in a single-element
176  /// vec containing the error.
177  ///
178  /// # example
179  /// ```ignore
180  /// if let Some(messages) = client.recv().await {
181  ///   for result in messages {
182  ///     match result {
183  ///       Ok(message) => { /* handle message */ }
184  ///       Err(err) => { /* handle error */ }
185  ///     }
186  ///   }
187  /// }
188  /// ```
189  pub async fn recv(&mut self) -> Option<Vec<Result<ValidMessage, ClientError>>> {
190    use futures::StreamExt;
191
192    let messages = self.receiver.next().await;
193
194    match messages {
195      Some(Ok(messages)) => {
196        let mut results = Vec::with_capacity(messages.len());
197
198        for message in messages {
199          match &message {
200            Message::Error { error, .. } => {
201              results.push(Err(error.clone().into()));
202            }
203            Message::Result { result, .. } => {
204              self.state.handle_result(*result.clone());
205              results.push(Ok(
206                message.try_into().expect("Result can always convert to ValidMessage"),
207              ));
208            }
209            Message::Notification { method, .. } => {
210              self.state.handle_notification(*method.clone());
211              results.push(Ok(
212                message
213                  .try_into()
214                  .expect("Notification can always convert to ValidMessage"),
215              ));
216            }
217          }
218        }
219
220        Some(results)
221      }
222      Some(Err(err)) => Some(vec![Err(err)]),
223      None => None,
224    }
225  }
226
227  // client methods
228  /// request the current status of a client from the Snapcast server
229  ///
230  /// wrapper for sending a [ClientGetStatus](Method::ClientGetStatus) command
231  ///
232  /// # args
233  /// `id`: [String] - the id of the client
234  ///
235  /// # returns
236  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
237  ///
238  /// # example
239  /// ```no_run
240  /// client.client_get_status("client_id".to_string()).await.expect("could not get client status");
241  /// ```
242  pub async fn client_get_status(&mut self, id: String) -> Result<(), ClientError> {
243    self
244      .send(Method::ClientGetStatus {
245        params: client::GetStatusParams { id },
246      })
247      .await
248  }
249
250  /// set the volume and mute status of a client
251  ///
252  /// wrapper for sending a [ClientSetVolume](Method::ClientSetVolume) command
253  ///
254  /// # args
255  /// `id`: [String] - the id of the client
256  /// `volume`: [client::ClientVolume] - the volume and mute status to set
257  ///
258  /// # returns
259  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
260  ///
261  /// # example
262  /// ```no_run
263  /// client.client_set_mute("client_id".to_string(), client::ClientVolume { mute: false, volume: 50 }).await.expect("could not set client mute");
264  /// ```
265  pub async fn client_set_volume(&mut self, id: String, volume: client::ClientVolume) -> Result<(), ClientError> {
266    self
267      .send(Method::ClientSetVolume {
268        params: client::SetVolumeParams { id, volume },
269      })
270      .await
271  }
272
273  /// set the latency of a client
274  ///
275  /// wrapper for sending a [ClientSetLatency](Method::ClientSetLatency) command
276  ///
277  /// # args
278  /// `id`: [String] - the id of the client
279  /// `latency`: [usize] - the latency to set
280  ///
281  /// # returns
282  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
283  ///
284  /// # example
285  /// ```no_run
286  /// client.client_set_latency("client_id".to_string(), 100).await.expect("could not set client latency");
287  /// ```
288  pub async fn client_set_latency(&mut self, id: String, latency: usize) -> Result<(), ClientError> {
289    self
290      .send(Method::ClientSetLatency {
291        params: client::SetLatencyParams { id, latency },
292      })
293      .await
294  }
295
296  /// set the name of a client
297  ///
298  /// wrapper for sending a [ClientSetName](Method::ClientSetName) command
299  ///
300  /// # args
301  /// `id`: [String] - the id of the client
302  /// `name`: [String] - the name to set
303  ///
304  /// # returns
305  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
306  ///
307  /// # example
308  /// ```no_run
309  /// client.client_set_name("client_id".to_string(), "new_name".to_string()).await.expect("could not set client name");
310  /// ```
311  pub async fn client_set_name(&mut self, id: String, name: String) -> Result<(), ClientError> {
312    self
313      .send(Method::ClientSetName {
314        params: client::SetNameParams { id, name },
315      })
316      .await
317  }
318
319  // group methods
320  /// request the current status of a group from the Snapcast server
321  ///
322  /// wrapper for sending a [GroupGetStatus](Method::GroupGetStatus) command
323  ///
324  /// # args
325  /// `id`: [String] - the id of the group
326  ///
327  /// # returns
328  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
329  ///
330  /// # example
331  /// ```no_run
332  /// client.group_get_status("group_id".to_string()).await.expect("could not get group status");
333  /// ```
334  pub async fn group_get_status(&mut self, id: String) -> Result<(), ClientError> {
335    self
336      .send(Method::GroupGetStatus {
337        params: group::GetStatusParams { id },
338      })
339      .await
340  }
341
342  /// set the mute status of a group
343  ///
344  /// wrapper for sending a [GroupSetMute](Method::GroupSetMute) command
345  ///
346  /// # args
347  /// `id`: [String] - the id of the group
348  /// `mute`: [bool] - the mute status to set
349  ///
350  /// # returns
351  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
352  ///
353  /// # example
354  /// ```no_run
355  /// client.group_set_mute("group_id".to_string(), true).await.expect("could not set group mute");
356  /// ```
357  pub async fn group_set_mute(&mut self, id: String, mute: bool) -> Result<(), ClientError> {
358    self
359      .send(Method::GroupSetMute {
360        params: group::SetMuteParams { id, mute },
361      })
362      .await
363  }
364
365  /// set the stream of a group
366  ///
367  /// wrapper for sending a [GroupSetStream](Method::GroupSetStream) command
368  ///
369  /// # args
370  /// `id`: [String] - the id of the group
371  /// `stream_id`: [String] - the id of the stream to set
372  ///
373  /// # returns
374  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
375  ///
376  /// # example
377  /// ```no_run
378  /// client.group_set_stream("group_id".to_string(), "stream_id".to_string()).await.expect("could not set group stream");
379  /// ```
380  pub async fn group_set_stream(&mut self, id: String, stream_id: String) -> Result<(), ClientError> {
381    self
382      .send(Method::GroupSetStream {
383        params: group::SetStreamParams { id, stream_id },
384      })
385      .await
386  }
387
388  /// set the clients of a group
389  ///
390  /// wrapper for sending a [GroupSetClients](Method::GroupSetClients) command
391  ///
392  /// # args
393  /// `id`: [String] - the id of the group
394  /// `clients`: [Vec]<[String]> - the ids of the clients to set
395  ///
396  /// # returns
397  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
398  ///
399  /// # example
400  /// ```no_run
401  /// client.group_set_clients("group_id".to_string(), vec!["client_id".to_string()]).await.expect("could not set group clients");
402  /// ```
403  pub async fn group_set_clients(&mut self, id: String, clients: Vec<String>) -> Result<(), ClientError> {
404    self
405      .send(Method::GroupSetClients {
406        params: group::SetClientsParams { id, clients },
407      })
408      .await
409  }
410
411  /// set the name of a group
412  ///
413  /// wrapper for sending a [GroupSetName](Method::GroupSetName) command
414  ///
415  /// # args
416  /// `id`: [String] - the id of the group
417  /// `name`: [String] - the name to set
418  ///
419  /// # returns
420  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
421  ///
422  /// # example
423  /// ```no_run
424  /// client.group_set_name("group_id".to_string(), "new_name".to_string()).await.expect("could not set group name");
425  /// ```
426  pub async fn group_set_name(&mut self, id: String, name: String) -> Result<(), ClientError> {
427    self
428      .send(Method::GroupSetName {
429        params: group::SetNameParams { id, name },
430      })
431      .await
432  }
433
434  // server methods
435  /// request the rpc version of the Snapcast server
436  ///
437  /// wrapper for sending a [ServerGetStatus](Method::ServerGetStatus) command
438  ///
439  /// # returns
440  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
441  ///
442  /// # example
443  /// ```no_run
444  /// client.server_get_rpc_version().await.expect("could not get server rpc version");
445  /// ```
446  pub async fn server_get_rpc_version(&mut self) -> Result<(), ClientError> {
447    self.send(Method::ServerGetRPCVersion).await
448  }
449
450  /// request the current status of the Snapcast server, this is a full refresh for state
451  ///
452  /// wrapper for sending a [ServerGetStatus](Method::ServerGetStatus) command
453  ///
454  /// # returns
455  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
456  ///
457  /// # example
458  /// ```no_run
459  /// client.server_get_status().await.expect("could not get server status");
460  /// ```
461  pub async fn server_get_status(&mut self) -> Result<(), ClientError> {
462    self.send(Method::ServerGetStatus).await
463  }
464
465  /// forcefully delete a client from the Snapcast server
466  ///
467  /// wrapper for sending a [ServerDeleteClient](Method::ServerDeleteClient) command
468  ///
469  /// # args
470  /// `id`: [String] - the id of the client to delete
471  ///
472  /// # returns
473  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
474  ///
475  /// # example
476  /// ```no_run
477  /// client.server_delete_client("client_id".to_string()).await.expect("could not delete client");
478  /// ```
479  pub async fn server_delete_client(&mut self, id: String) -> Result<(), ClientError> {
480    self
481      .send(Method::ServerDeleteClient {
482        params: server::DeleteClientParams { id },
483      })
484      .await
485  }
486
487  // stream methods
488  /// add a new stream to the Snapcast server
489  ///
490  /// wrapper for sending a [StreamAddStream](Method::StreamAddStream) command
491  ///
492  /// # args
493  /// `stream_uri`: [String] - the uri of the stream to add
494  ///
495  /// # returns
496  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
497  ///
498  /// # example
499  /// ```no_run
500  /// client.stream_add_stream("librespot:///usr/bin/librespot?name=Spotify&...".to_string()).await.expect("could not add stream");
501  /// ```
502  pub async fn stream_add_stream(&mut self, stream_uri: String) -> Result<(), ClientError> {
503    self
504      .send(Method::StreamAddStream {
505        params: stream::AddStreamParams { stream_uri },
506      })
507      .await
508  }
509
510  /// remove a stream from the Snapcast server
511  ///
512  /// wrapper for sending a [StreamRemoveStream](Method::StreamRemoveStream) command
513  ///
514  /// # args
515  /// `id`: [String] - the id of the stream to remove
516  ///
517  /// # returns
518  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
519  ///
520  /// # example
521  /// ```no_run
522  /// client.stream_remove_stream("stream_id".to_string()).await.expect("could not remove stream");
523  /// ```
524  pub async fn stream_remove_stream(&mut self, id: String) -> Result<(), ClientError> {
525    self
526      .send(Method::StreamRemoveStream {
527        params: stream::RemoveStreamParams { id },
528      })
529      .await
530  }
531
532  /// control a stream on the Snapcast server
533  ///
534  /// wrapper for sending a [StreamControl](Method::StreamControl) command
535  ///
536  /// # args
537  /// `id`: [String] - the id of the stream to control
538  /// `command`: [stream::ControlCommand] - the command to send to the stream
539  ///
540  /// # returns
541  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
542  ///
543  /// # example
544  /// ```no_run
545  /// client.stream_control("stream_id".to_string(), stream::ControlCommand::Pause).await.expect("could not control stream");
546  /// ```
547  pub async fn stream_control(&mut self, id: String, command: stream::ControlCommand) -> Result<(), ClientError> {
548    self
549      .send(Method::StreamControl {
550        params: stream::ControlParams { id, command },
551      })
552      .await
553  }
554
555  /// set the property of a stream on the Snapcast server
556  ///
557  /// wrapper for sending a [StreamSetProperty](Method::StreamSetProperty) command
558  ///
559  /// # args
560  /// `id`: [String] - the id of the stream to control
561  /// `properties`: [stream::SetPropertyProperties] - the properties to set on the stream
562  ///
563  /// # returns
564  /// an empty [Ok] if the command was sent successfully, or a [ClientError] if there was an error
565  ///
566  /// # example
567  /// ```no_run
568  /// client.stream_set_property("stream_id".to_string(), stream::SetPropertyProperties::Shuffle(true)).await.expect("could not set stream property");
569  /// ```
570  pub async fn stream_set_property(
571    &mut self,
572    id: String,
573    properties: stream::SetPropertyProperties,
574  ) -> Result<(), ClientError> {
575    self
576      .send(Method::StreamSetProperty {
577        params: stream::SetPropertyParams { id, properties },
578      })
579      .await
580  }
581}
582
583#[derive(Debug, Clone, Default)]
584struct Communication {
585  purgatory: SentRequests,
586}
587
588impl Communication {
589  async fn init(
590    address: std::net::SocketAddr,
591    on_connect: Option<Arc<dyn Fn() + Send + Sync>>,
592    on_disconnect: Option<Arc<dyn Fn() + Send + Sync>>,
593    on_reconnect_failed: Option<Arc<dyn Fn() + Send + Sync>>,
594  ) -> Result<(Sender, Receiver), std::io::Error> {
595    use futures::stream::StreamExt;
596    use tokio_util::codec::Decoder;
597
598    let client = Self::default();
599    let options = create_reconnect_options(on_connect, on_disconnect, on_reconnect_failed);
600
601    tracing::info!("connecting to snapcast server at {}", address);
602    let stream = StubbornTcpStream::connect_with_options(address, options).await?;
603    let (writer, reader) = client.framed(stream).split();
604
605    Ok((writer, reader))
606  }
607}
608
609fn create_reconnect_options(
610  on_connect: Option<Arc<dyn Fn() + Send + Sync>>,
611  on_disconnect: Option<Arc<dyn Fn() + Send + Sync>>,
612  on_reconnect_failed: Option<Arc<dyn Fn() + Send + Sync>>,
613) -> ReconnectOptions {
614  let mut options = ReconnectOptions::new();
615
616  if let Some(cb) = on_connect {
617    options = options.with_on_connect_callback(move || cb());
618  }
619
620  if let Some(cb) = on_disconnect {
621    options = options.with_on_disconnect_callback(move || cb());
622  }
623
624  if let Some(cb) = on_reconnect_failed {
625    options = options.with_on_connect_fail_callback(move || cb());
626  }
627
628  options
629}
630
631impl tokio_util::codec::Decoder for Communication {
632  type Item = Vec<Message>;
633  type Error = ClientError;
634
635  fn decode(&mut self, src: &mut tokio_util::bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
636    use tokio_util::bytes::Buf;
637
638    if src.is_empty() {
639      return Ok(None);
640    }
641
642    let lf_pos = src.as_ref().iter().position(|b| *b == b'\n');
643    if let Some(lf_pos) = lf_pos {
644      let data = src.split_to(lf_pos);
645      src.advance(1);
646
647      tracing::debug!("received complete message with length: {}", data.len());
648      let message = std::str::from_utf8(&data)?;
649      tracing::trace!("completed json message: {:?}", message);
650
651      let messages = SnapcastDeserializer::de(message, &self.purgatory)?;
652      tracing::trace!("completed deserialized messages: {:?}", messages);
653
654      if messages.is_empty() {
655        return Ok(None);
656      }
657
658      return Ok(Some(messages));
659    }
660
661    Ok(None)
662  }
663}
664
665impl tokio_util::codec::Encoder<Method> for Communication {
666  type Error = ClientError;
667
668  fn encode(&mut self, method: Method, dst: &mut tokio_util::bytes::BytesMut) -> Result<(), Self::Error> {
669    tracing::trace!("encoding: {:?}", method);
670
671    let id = Uuid::new_v4();
672    let command: RequestMethod = (&method).into();
673    tracing::debug!("sending command: {:?}", command);
674    self.purgatory.insert(id, command);
675
676    let data = Request {
677      id,
678      jsonrpc: "2.0".to_string(),
679      method,
680    };
681
682    let string: String = data.try_into()?;
683    let string = format!("{}\n", string);
684    tracing::trace!("sending: {:?}", string);
685
686    dst.extend_from_slice(string.as_bytes());
687
688    Ok(())
689  }
690}
691
692/// Error type for the Snapcast client
693#[derive(Debug, thiserror::Error)]
694pub enum ClientError {
695  /// An error returned by the Snapcast server
696  #[error("Snapcast error: {0}")]
697  Snapcast(#[from] errors::SnapcastError),
698  /// An error communicating with the Snapcast server
699  #[error("Communication error: {0}")]
700  Io(#[from] std::io::Error),
701  /// An error decoding a UTF-8 string from the Snapcast server
702  #[error("UTF-8 decoding error: {0}")]
703  Utf8(#[from] std::str::Utf8Error),
704  /// An error deserializing a message from the Snapcast server
705  #[error("Deserialization error: {0}")]
706  Deserialization(#[from] protocol::DeserializationError),
707  /// An error deserializing the json from the Snapcast server
708  #[error("JSON Deserialization error: {0}")]
709  JsonDeserialization(#[from] serde_json::Error),
710  /// An unknown error
711  #[error("Unknown error: {0}")]
712  Unknown(String),
713}