mpd_utils/
persistent_client.rs

1use crate::socket::try_get_connection;
2use mpd_client::client::{CommandError, ConnectionEvent};
3use mpd_client::commands::Command;
4use mpd_client::responses::{SongInQueue, Status};
5use mpd_client::{commands, Client};
6use std::future::Future;
7use std::sync::{Arc, RwLock};
8use std::time::Duration;
9use tokio::spawn;
10use tokio::sync::broadcast;
11use tokio::sync::broadcast::error::RecvError;
12use tokio::time::sleep;
13use tracing::{debug, error, info};
14
15#[derive(Debug, Clone)]
16enum State {
17    Disconnected,
18    Connected(Arc<Client>),
19}
20
21type Channel<T> = (broadcast::Sender<T>, broadcast::Receiver<T>);
22
23/// MPD client which automatically attempts to reconnect
24/// if the connection cannot be established or is lost.
25///
26/// Commands sent to a disconnected client are queued.
27#[derive(Debug)]
28pub struct PersistentClient {
29    host: String,
30    retry_interval: Duration,
31    state: Arc<RwLock<State>>,
32    channel: Channel<Arc<ConnectionEvent>>,
33    connection_channel: Channel<Arc<Client>>,
34}
35
36impl PersistentClient {
37    pub fn new(host: String, retry_interval: Duration) -> Self {
38        let channel = broadcast::channel(1024);
39        let connection_channel = broadcast::channel(8);
40
41        Self {
42            host,
43            retry_interval,
44            state: Arc::new(RwLock::new(State::Disconnected)),
45            channel,
46            connection_channel,
47        }
48    }
49
50    /// Attempts to connect to the MPD host
51    /// and begins listening to server events.
52    pub fn init(&self) {
53        let host = self.host.clone();
54        let retry_interval = self.retry_interval;
55        let state = self.state.clone();
56        let tx = self.channel.0.clone();
57        let conn_tx = self.connection_channel.0.clone();
58
59        spawn(async move {
60            loop {
61                debug!("Attempting to connect to {host}");
62                let connection = try_get_connection(&host).await;
63
64                match connection {
65                    Ok(connection) => {
66                        info!("Connected to '{host}'");
67
68                        let client = Arc::new(connection.0);
69
70                        {
71                            *state.write().expect("Failed to get lock on state") =
72                                State::Connected(client.clone());
73                            conn_tx.send(client).expect("Failed to send event");
74                        }
75
76                        let mut events = connection.1;
77
78                        while let Some(event) = events.next().await {
79                            if let ConnectionEvent::ConnectionClosed(err) = event {
80                                error!("Lost connection to '{host}': {err:?}");
81                                *state.write().expect("Failed to get lock on state") =
82                                    State::Disconnected;
83
84                                break;
85                            }
86
87                            debug!("Sending event: {event:?}");
88
89                            // Wrap in `Arc` because `ConnectionEvent` isn't `Clone`.
90                            tx.send(Arc::new(event)).expect("Failed to send event");
91                        }
92                    }
93                    Err(err) => {
94                        error!("Failed to connect to '{host}': {err:?}");
95                        *state.write().expect("Failed to get lock on state") = State::Disconnected;
96                    }
97                }
98
99                sleep(retry_interval).await;
100            }
101        });
102    }
103
104    /// Gets the client host address or path
105    pub fn host(&self) -> &str {
106        &self.host
107    }
108
109    /// Gets whether there is a valid connection to the server
110    pub fn is_connected(&self) -> bool {
111        matches!(
112            *self.state.read().expect("Failed to get lock on state"),
113            State::Connected(_)
114        )
115    }
116
117    /// Waits for a valid connection to the server to be established.
118    /// If already connected, resolves immediately.
119    pub async fn wait_for_client(&self) -> Arc<Client> {
120        {
121            let state = self.state.read().expect("Failed to get lock on state");
122
123            if let State::Connected(client) = &*state {
124                return client.clone();
125            }
126        }
127
128        let mut rx = self.connection_channel.0.subscribe();
129        rx.recv().await.unwrap()
130    }
131
132    /// Runs the provided callback as soon as the connected client is available.
133    pub async fn with_client<F, Fut, T>(&self, f: F) -> T
134    where
135        F: FnOnce(Arc<Client>) -> Fut,
136        Fut: Future<Output = T>,
137    {
138        let client = self.wait_for_client().await;
139        f(client).await
140    }
141
142    /// Receives an event from the MPD server.
143    pub async fn recv(&mut self) -> Result<Arc<ConnectionEvent>, RecvError> {
144        let rx = &mut self.channel.1;
145        rx.recv().await
146    }
147
148    /// Creates a new receiver to be able to receive events
149    /// outside the context of `&self`.
150    ///
151    /// When you have access to the client instance, prefer` recv()` instead.
152    pub fn subscribe(&self) -> broadcast::Receiver<Arc<ConnectionEvent>> {
153        self.channel.0.subscribe()
154    }
155
156    /// Runs the provided command on the MPD server.
157    ///
158    /// Waits for a valid connection and response before the future is completed.
159    pub async fn command<C: Command>(&self, cmd: C) -> Result<C::Response, CommandError> {
160        self.with_client(|client| async move { client.command(cmd).await })
161            .await
162    }
163
164    /// Runs the `status` command on the MPD server.
165    ///
166    /// Waits for a valid connection and response before the future is completed.
167    pub async fn status(&self) -> Result<Status, CommandError> {
168        self.command(commands::Status).await
169    }
170
171    /// Runs the `currentsong` command on the MPD server.
172    ///
173    /// Waits for a valid connection and response before the future is completed.
174    pub async fn current_song(&self) -> Result<Option<SongInQueue>, CommandError> {
175        self.command(commands::CurrentSong).await
176    }
177}
178
179/// Creates a new client on the default localhost TCP address
180/// with a connection retry of 5 seconds.
181impl Default for PersistentClient {
182    fn default() -> Self {
183        PersistentClient::new("localhost:6600".to_string(), Duration::from_secs(5))
184    }
185}
186
#[cfg(test)]
mod tests {
    use crate::*;
    use mpd_client::commands;

    /// Smoke test: connects using the default settings and prints the
    /// server status. Waits until a connection is available, so it needs
    /// a reachable MPD server at the default address to finish.
    #[tokio::test]
    async fn test() {
        let persistent = PersistentClient::default();
        persistent.init();

        let response = persistent
            .with_client(|mpd| async move { mpd.command(commands::Status).await })
            .await;
        let status = response.unwrap();

        println!("{status:?}");
    }
}