mod connection;
use std::{
fmt,
hash::{Hash, Hasher},
io,
sync::Arc,
};
use bytes::BytesMut;
use mpd_protocol::{
command::{Command as RawCommand, CommandList as RawCommandList},
response::{Error, Frame, Response as RawResponse},
AsyncConnection, MpdProtocolError,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::{
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
oneshot,
},
};
use tracing::{debug, error, span, trace, warn, Instrument, Level};
use crate::{
commands::{self as cmds, Command, CommandList},
responses::TypedResponseError,
};
/// Sender half through which the result of one submitted command list is delivered back to
/// the `Client::do_send` call that is awaiting it.
type CommandResponder = oneshot::Sender<Result<RawResponse, CommandError>>;
/// Result of a successful connect: the client handle paired with its event stream.
pub type Connection = (Client, ConnectionEvents);
/// Handle for sending commands to an MPD server.
///
/// Commands are forwarded over a channel to a background task that owns the connection.
/// Cloning the handle is cheap (a channel sender and an `Arc`), and all clones send over the
/// same connection.
#[derive(Clone)]
pub struct Client {
    /// Channel to the background run loop; each command list is paired with the responder
    /// through which its result is returned.
    commands_sender: UnboundedSender<(RawCommandList, CommandResponder)>,
    /// Protocol version announced by the server during the handshake.
    protocol_version: Arc<str>,
}
impl Client {
    /// Connect to an MPD server over `connection` without sending a password.
    ///
    /// Performs the protocol handshake, then spawns a Tokio task that drives the connection
    /// in the background. The returned [`Connection`] is the client handle together with its
    /// [`ConnectionEvents`] stream.
    ///
    /// # Errors
    ///
    /// Returns an error if the initial handshake fails.
    pub async fn connect<C>(connection: C) -> Result<Connection, MpdProtocolError>
    where
        C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        do_connect(connection, None).await.map_err(|e| match e {
            ConnectWithPasswordError::ProtocolError(e) => e,
            // No password is sent, so the server can never reject one.
            ConnectWithPasswordError::IncorrectPassword => unreachable!(),
        })
    }

    /// Connect to an MPD server over `connection` and authenticate with `password`.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectWithPasswordError::IncorrectPassword`] if the server answers the
    /// `password` command with an error, or [`ConnectWithPasswordError::ProtocolError`] if the
    /// handshake or the password exchange fails.
    pub async fn connect_with_password<C>(
        connection: C,
        password: &str,
    ) -> Result<Connection, ConnectWithPasswordError>
    where
        C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        do_connect(connection, Some(password)).await
    }

    /// Like [`Client::connect_with_password`], but only authenticates when `password` is
    /// `Some`.
    ///
    /// # Errors
    ///
    /// Same as [`Client::connect_with_password`].
    pub async fn connect_with_password_opt<C>(
        connection: C,
        password: Option<&str>,
    ) -> Result<Connection, ConnectWithPasswordError>
    where
        C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        do_connect(connection, password).await
    }

    /// Send a typed command and convert the server's reply into `C::Response`.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Client::raw_command`], or an error if the response frame
    /// cannot be converted into `C::Response`.
    pub async fn command<C>(&self, cmd: C) -> Result<C::Response, CommandError>
    where
        C: Command,
    {
        let command = cmd.command();
        let frame = self.raw_command(command).await?;
        let response = cmd.response(frame)?;
        Ok(response)
    }

    /// Send a typed command list and convert the replies into `L::Response`.
    ///
    /// If the list yields no raw commands, nothing is sent to the server and the conversion
    /// is performed on an empty set of frames.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Client::raw_command_list`], or the conversion error.
    pub async fn command_list<L>(&self, list: L) -> Result<L::Response, CommandError>
    where
        L: CommandList,
    {
        let frames = match list.command_list() {
            Some(cmds) => self.raw_command_list(cmds).await?,
            None => Vec::new(),
        };
        list.responses(frames).map_err(Into::into)
    }

    /// Send a single raw command and return its response frame.
    ///
    /// # Errors
    ///
    /// Returns [`CommandError::ErrorResponse`] (with no successful frames) if the server
    /// answers with an error, or [`CommandError::ConnectionClosed`] if the connection is gone.
    pub async fn raw_command(&self, command: RawCommand) -> Result<Frame, CommandError> {
        self.do_send(RawCommandList::new(command))
            .await?
            .into_single_frame()
            .map_err(|error| CommandError::ErrorResponse {
                error,
                succesful_frames: Vec::new(),
            })
    }

    /// Send a raw command list and return one frame per command.
    ///
    /// # Errors
    ///
    /// If one command fails, returns [`CommandError::ErrorResponse`] carrying the server error
    /// together with the frames of the commands that succeeded before it.
    pub async fn raw_command_list(
        &self,
        commands: RawCommandList,
    ) -> Result<Vec<Frame>, CommandError> {
        debug!(?commands, "sending command");

        let res = self.do_send(commands).await?;

        let mut frames = Vec::with_capacity(res.successful_frames());
        for frame in res {
            match frame {
                Ok(f) => frames.push(f),
                Err(error) => {
                    return Err(CommandError::ErrorResponse {
                        error,
                        succesful_frames: frames,
                    });
                }
            }
        }

        Ok(frames)
    }

    /// Load the cover art for the song at `uri`.
    ///
    /// Tries the embedded picture first (`readpicture`). It falls back to a separate image
    /// file (`albumart`) if there is no embedded picture or if the server rejects
    /// `readpicture` with error code 5. The remaining chunks are then requested at increasing
    /// offsets until the size announced by the first reply has been received.
    ///
    /// Returns the image bytes and, for embedded art, the MIME type reported by the server.
    /// Returns `Ok(None)` if no art was found or the transfer could not be completed, which
    /// happens when a chunk reply is missing or carries no data.
    ///
    /// # Errors
    ///
    /// Every command error is returned, except the code-5 error from `readpicture` that
    /// triggers the fallback.
    #[tracing::instrument(skip(self))]
    pub async fn album_art(
        &self,
        uri: &str,
    ) -> Result<Option<(BytesMut, Option<String>)>, CommandError> {
        debug!("loading album art");

        let mut out = BytesMut::new();
        let mut expected_size = 0;
        let mut embedded = false;
        let mut mime = None;

        match self.command(cmds::AlbumArtEmbedded::new(uri)).await {
            Ok(Some(resp)) => {
                out = resp.data;
                expected_size = resp.size;
                // `out` already holds the first chunk; reserve only what is still missing.
                out.reserve(expected_size.saturating_sub(out.len()));
                embedded = true;
                mime = resp.mime;
                debug!(length = resp.size, ?mime, "found embedded album art");
            }
            Ok(None) => {
                debug!("readpicture command gave no result, falling back");
            }
            Err(e) => match e {
                // Code 5 is what the server returns when it does not know `readpicture`.
                CommandError::ErrorResponse { error, .. } if error.code == 5 => {
                    debug!("readpicture command unsupported, falling back");
                }
                e => return Err(e),
            },
        }

        if !embedded {
            if let Some(resp) = self.command(cmds::AlbumArt::new(uri)).await? {
                out = resp.data;
                expected_size = resp.size;
                out.reserve(expected_size.saturating_sub(out.len()));
                debug!(length = expected_size, "found separate file album art");
            } else {
                debug!("no embedded or separate album art found");
                return Ok(None);
            }
        }

        while out.len() < expected_size {
            let resp = if embedded {
                self.command(cmds::AlbumArtEmbedded::new(uri).offset(out.len()))
                    .await?
            } else {
                self.command(cmds::AlbumArt::new(uri).offset(out.len()))
                    .await?
            };

            match resp {
                Some(resp) if !resp.data.is_empty() => {
                    trace!(received = resp.data.len(), progress = out.len());
                    out.extend_from_slice(&resp.data);
                }
                // An empty chunk would leave the offset unchanged and request the same data
                // forever, so treat it like a missing reply.
                _ => {
                    warn!(progress = out.len(), "incomplete cover art response");
                    return Ok(None);
                }
            }
        }

        debug!(length = expected_size, "finished loading");
        Ok(Some((out, mime)))
    }

    /// The protocol version announced by the server during the handshake.
    pub fn protocol_version(&self) -> &str {
        self.protocol_version.as_ref()
    }

    /// Returns `true` once the receiving end of the command channel has been dropped.
    ///
    /// The background run loop owns that receiver. After this returns `true`, every command
    /// fails with [`CommandError::ConnectionClosed`].
    pub fn is_connection_closed(&self) -> bool {
        self.commands_sender.is_closed()
    }

    /// Queue `commands` for the background run loop and wait for its reply.
    ///
    /// # Errors
    ///
    /// Returns [`CommandError::ConnectionClosed`] if the command channel is closed or the
    /// responder is dropped without an answer. Otherwise it returns whatever result the run
    /// loop sends back.
    async fn do_send(&self, commands: RawCommandList) -> Result<RawResponse, CommandError> {
        let (tx, rx) = oneshot::channel();

        self.commands_sender
            .send((commands, tx))
            .map_err(|_| CommandError::ConnectionClosed)?;

        rx.await.map_err(|_| CommandError::ConnectionClosed)?
    }
}
impl fmt::Debug for Client {
    /// Prints only the protocol version; the internal command channel is elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Client");
        builder.field("protocol_version", &self.protocol_version);
        builder.finish_non_exhaustive()
    }
}
async fn do_connect<IO: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
io: IO,
password: Option<&str>,
) -> Result<Connection, ConnectWithPasswordError> {
let span = span!(Level::DEBUG, "client connection");
let (state_changes_sender, state_changes) = unbounded_channel();
let (commands_sender, commands_receiver) = unbounded_channel();
let mut connection = match AsyncConnection::connect(io).instrument(span.clone()).await {
Ok(c) => c,
Err(e) => {
error!(error = ?e, "failed to perform initial handshake");
return Err(e.into());
}
};
let protocol_version = Arc::from(connection.protocol_version());
if let Some(password) = password {
trace!(parent: &span, "sending password");
if let Err(e) = connection
.send(RawCommand::new("password").argument(password.to_owned()))
.instrument(span.clone())
.await
{
error!(parent: &span, error = ?e, "failed to send password");
return Err(e.into());
}
match connection.receive().instrument(span.clone()).await {
Err(e) => {
error!(parent: &span, error = ?e, "failed to receive reply to password");
return Err(e.into());
}
Ok(None) => {
error!(
parent: &span,
"unexpected end of stream after sending password"
);
return Err(MpdProtocolError::Io(io::Error::new(
io::ErrorKind::UnexpectedEof,
"connection closed while waiting for reply to password",
))
.into());
}
Ok(Some(response)) if response.is_error() => {
error!(parent: &span, "incorrect password");
return Err(ConnectWithPasswordError::IncorrectPassword);
}
Ok(Some(_)) => {
trace!(parent: &span, "password accepted");
}
}
}
tokio::spawn(
connection::run_loop(connection, commands_receiver, state_changes_sender)
.instrument(span!(parent: &span, Level::TRACE, "run loop")),
);
let state_changes = ConnectionEvents(state_changes);
let client = Client {
commands_sender,
protocol_version,
};
Ok((client, state_changes))
}
/// Errors that can occur when sending a command through a [`Client`].
#[derive(Debug)]
pub enum CommandError {
    /// The connection is closed: either the command could not be queued for the background
    /// run loop, or the loop dropped the response channel without answering.
    ConnectionClosed,
    /// An error from the underlying protocol implementation.
    Protocol(MpdProtocolError),
    /// The server answered with an error response.
    ErrorResponse {
        /// The error returned by the server.
        error: Error,
        /// Frames of the command list that completed successfully before the error.
        /// Always empty for single commands.
        // NOTE: the misspelled field name is part of the public API; renaming it would break
        // callers.
        succesful_frames: Vec<Frame>,
    },
    /// The response could not be converted into the typed response of a command.
    InvalidTypedResponse(TypedResponseError),
}
impl fmt::Display for CommandError {
    /// Formats a short, human-readable description of the error.
    ///
    /// Wrapped protocol and typed-response errors are not included in the message; they are
    /// exposed through `std::error::Error::source` instead.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CommandError::ConnectionClosed => write!(f, "the connection is closed"),
            CommandError::Protocol(_) => write!(f, "protocol error"),
            CommandError::InvalidTypedResponse(_) => {
                write!(f, "response was invalid for typed command")
            }
            CommandError::ErrorResponse {
                error,
                succesful_frames,
            } => {
                write!(
                    f,
                    "command returned an error [code {}]: {}",
                    error.code, error.message,
                )?;

                // Mention partial progress only for command lists that got some way through.
                if !succesful_frames.is_empty() {
                    write!(f, " (after {} successful frames)", succesful_frames.len())?;
                }

                Ok(())
            }
        }
    }
}
impl std::error::Error for CommandError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
CommandError::Protocol(e) => Some(e),
CommandError::InvalidTypedResponse(e) => Some(e),
_ => None,
}
}
}
#[doc(hidden)]
impl From<MpdProtocolError> for CommandError {
fn from(e: MpdProtocolError) -> Self {
CommandError::Protocol(e)
}
}
#[doc(hidden)]
impl From<TypedResponseError> for CommandError {
fn from(e: TypedResponseError) -> Self {
CommandError::InvalidTypedResponse(e)
}
}
/// Errors that can occur while connecting, with or without a password.
#[derive(Debug)]
pub enum ConnectWithPasswordError {
    /// The server answered the `password` command with an error response.
    IncorrectPassword,
    /// The handshake or the password exchange failed at the protocol level.
    ProtocolError(MpdProtocolError),
}
impl fmt::Display for ConnectWithPasswordError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ConnectWithPasswordError::IncorrectPassword => write!(f, "incorrect password"),
ConnectWithPasswordError::ProtocolError(_) => write!(f, "protocol error"),
}
}
}
impl std::error::Error for ConnectWithPasswordError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
ConnectWithPasswordError::ProtocolError(e) => Some(e),
ConnectWithPasswordError::IncorrectPassword => None,
}
}
}
#[doc(hidden)]
impl From<MpdProtocolError> for ConnectWithPasswordError {
fn from(e: MpdProtocolError) -> Self {
ConnectWithPasswordError::ProtocolError(e)
}
}
/// Receiver for [`ConnectionEvent`]s emitted by the background connection task.
#[derive(Debug)]
pub struct ConnectionEvents(pub(crate) UnboundedReceiver<ConnectionEvent>);
impl ConnectionEvents {
pub async fn next(&mut self) -> Option<ConnectionEvent> {
self.0.recv().await
}
}
/// An event reported by the background connection task.
#[derive(Debug)]
pub enum ConnectionEvent {
    /// A server subsystem changed, as reported by a `changed:` line of an idle response.
    SubsystemChange(Subsystem),
    /// The connection was closed because of the contained error.
    ConnectionClosed(ConnectionError),
}
/// A subsystem of the MPD server, as named in `changed:` lines of idle responses.
///
/// Equality and hashing use the protocol name (see [`Subsystem::as_str`]), so a known variant
/// compares equal to [`Subsystem::Other`] carrying the same name.
// Most variants are self-describing protocol names; `as_str` gives the exact mapping, so
// per-variant docs are not required.
#[allow(missing_docs)]
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum Subsystem {
    Database,
    Message,
    Mixer,
    Options,
    Output,
    Partition,
    Player,
    /// The queue; MPD reports it under the name `playlist`.
    Queue,
    Sticker,
    StoredPlaylist,
    Subscription,
    Update,
    Neighbor,
    Mount,
    /// A subsystem name not known to this library, stored verbatim.
    Other(Box<str>),
}
impl Subsystem {
    /// Extract the subsystem named by the `changed` field of an idle response frame.
    ///
    /// Returns `None` if the frame has no `changed` field. Unrecognized names are kept
    /// verbatim in [`Subsystem::Other`].
    fn from_frame(mut r: Frame) -> Option<Subsystem> {
        let name = r.get("changed")?;

        let subsystem = match &*name {
            "database" => Subsystem::Database,
            "message" => Subsystem::Message,
            "mixer" => Subsystem::Mixer,
            "options" => Subsystem::Options,
            "output" => Subsystem::Output,
            "partition" => Subsystem::Partition,
            "player" => Subsystem::Player,
            // MPD calls the queue "playlist".
            "playlist" => Subsystem::Queue,
            "sticker" => Subsystem::Sticker,
            "stored_playlist" => Subsystem::StoredPlaylist,
            "subscription" => Subsystem::Subscription,
            "update" => Subsystem::Update,
            "neighbor" => Subsystem::Neighbor,
            "mount" => Subsystem::Mount,
            _ => Subsystem::Other(name.into()),
        };

        Some(subsystem)
    }

    /// The protocol name of this subsystem, as it appears in MPD responses.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Database => "database",
            Self::Message => "message",
            Self::Mixer => "mixer",
            Self::Options => "options",
            Self::Output => "output",
            Self::Partition => "partition",
            Self::Player => "player",
            Self::Queue => "playlist",
            Self::Sticker => "sticker",
            Self::StoredPlaylist => "stored_playlist",
            Self::Subscription => "subscription",
            Self::Update => "update",
            Self::Neighbor => "neighbor",
            Self::Mount => "mount",
            Self::Other(raw) => raw,
        }
    }
}
impl PartialEq for Subsystem {
    /// Compares by protocol name, so e.g. `Player` equals `Other("player")`.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.as_str(), other.as_str());
        lhs == rhs
    }
}

impl Eq for Subsystem {}
impl Hash for Subsystem {
    /// Hashes the protocol name, keeping `Hash` consistent with the name-based `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let name: &str = self.as_str();
        name.hash(state);
    }
}
/// Error that ended a connection, delivered via [`ConnectionEvent::ConnectionClosed`].
#[derive(Debug)]
pub enum ConnectionError {
    /// An error from the underlying protocol implementation.
    Protocol(MpdProtocolError),
    /// An invalid response was received.
    InvalidResponse,
}
impl fmt::Display for ConnectionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ConnectionError::Protocol(_) => write!(f, "protocol error"),
ConnectionError::InvalidResponse => write!(f, "invalid response"),
}
}
}
impl std::error::Error for ConnectionError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
ConnectionError::Protocol(e) => Some(e),
ConnectionError::InvalidResponse => None,
}
}
}
impl From<MpdProtocolError> for ConnectionError {
fn from(e: MpdProtocolError) -> Self {
ConnectionError::Protocol(e)
}
}
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;

    use assert_matches::assert_matches;
    use tokio_test::io::Builder as MockBuilder;

    use super::*;

    // Handshake line the mock server sends at the start of every connection.
    static GREETING: &[u8] = b"OK MPD 0.21.11\n";

    // A `changed:` reply to the initial `idle` is delivered as a subsystem event, and the
    // client re-enters idle afterwards.
    #[tokio::test]
    async fn single_state_change() {
        let io = MockBuilder::new()
            .read(GREETING)
            .write(b"idle\n")
            .read(b"changed: player\nOK\n")
            .write(b"idle\n")
            .build();
        let (_client, mut state_changes) = Client::connect(io).await.expect("connect failed");
        assert_matches!(
            state_changes.next().await,
            Some(ConnectionEvent::SubsystemChange(Subsystem::Player))
        );
    }

    // Sending a command interrupts the pending idle with `noidle`. A change reported in the
    // idle reply is still forwarded as an event, and the event stream ends once the mock is
    // exhausted.
    #[tokio::test]
    async fn command() {
        let io = MockBuilder::new()
            .read(GREETING)
            .write(b"idle\n")
            .write(b"noidle\n")
            .read(b"changed: playlist\nOK\n")
            .write(b"hello\n")
            .read(b"foo: bar\nOK\n")
            .write(b"idle\n")
            .build();
        let (client, mut state_changes) = Client::connect(io).await.expect("connect failed");
        let response = client
            .raw_command(RawCommand::new("hello"))
            .await
            .expect("command failed");
        assert_eq!(response.find("foo"), Some("bar"));
        assert_matches!(
            state_changes.next().await,
            Some(ConnectionEvent::SubsystemChange(Subsystem::Queue))
        );
        assert!(state_changes.next().await.is_none());
    }

    // A response split across several reads is reassembled into one frame.
    #[tokio::test]
    async fn incomplete_response() {
        let io = MockBuilder::new()
            .read(GREETING)
            .write(b"idle\n")
            .write(b"noidle\n")
            .read(b"OK\n")
            .write(b"hello\n")
            .read(b"foo: bar\n")
            .read(b"baz: qux\nOK\n")
            .write(b"idle\n")
            .build();
        let (client, _state_changes) = Client::connect(io).await.expect("connect failed");
        let response = client
            .raw_command(RawCommand::new("hello"))
            .await
            .expect("command failed");
        assert_eq!(response.find("foo"), Some("bar"));
    }

    // A command list is sent as `command_list_ok_begin` ... `command_list_end` and yields
    // one frame per `list_OK`.
    #[tokio::test]
    async fn command_list() {
        let io = MockBuilder::new()
            .read(GREETING)
            .write(b"idle\n")
            .write(b"noidle\n")
            .read(b"OK\n")
            .write(b"command_list_ok_begin\nfoo\nbar\ncommand_list_end\n")
            .read(b"foo: asdf\nlist_OK\n")
            .read(b"baz: qux\nlist_OK\nOK\n")
            .write(b"idle\n")
            .build();
        let (client, _state_changes) = Client::connect(io).await.expect("connect failed");
        let mut commands = RawCommandList::new(RawCommand::new("foo"));
        commands.add(RawCommand::new("bar"));
        let responses = client
            .raw_command_list(commands)
            .await
            .expect("command failed");
        assert_eq!(responses.len(), 2);
        assert_eq!(responses[0].find("foo"), Some("asdf"));
    }

    // Dropping the only client handle ends the event stream.
    #[tokio::test]
    async fn dropping_client() {
        let io = MockBuilder::new().read(GREETING).write(b"idle\n").build();
        let (client, mut state_changes) = Client::connect(io).await.expect("connect failed");
        drop(client);
        assert!(state_changes.next().await.is_none());
    }

    // Embedded art is fetched in chunks at increasing offsets and keeps the reported MIME type.
    #[tokio::test]
    async fn album_art() {
        let io = MockBuilder::new()
            .read(GREETING)
            .write(b"idle\n")
            .write(b"noidle\n")
            .read(b"OK\n")
            .write(b"readpicture foo/bar.mp3 0\n")
            .read(b"size: 6\ntype: image/jpeg\nbinary: 3\nFOO\nOK\n")
            .write(b"readpicture foo/bar.mp3 3\n")
            .read(b"size: 6\ntype: image/jpeg\nbinary: 3\nBAR\nOK\n")
            .build();
        let (client, _) = Client::connect(io).await.expect("connect failed");
        let x = client
            .album_art("foo/bar.mp3")
            .await
            .expect("command failed");
        assert_eq!(
            x,
            Some((BytesMut::from("FOOBAR"), Some(String::from("image/jpeg"))))
        );
    }

    // An empty `readpicture` reply falls back to `albumart`, which carries no MIME type.
    #[tokio::test]
    async fn album_art_fallback() {
        let io = MockBuilder::new()
            .read(GREETING)
            .write(b"idle\n")
            .write(b"noidle\n")
            .read(b"OK\n")
            .write(b"readpicture foo/bar.mp3 0\n")
            .read(b"OK\n")
            .write(b"albumart foo/bar.mp3 0\n")
            .read(b"size: 6\nbinary: 3\nFOO\nOK\n")
            .write(b"albumart foo/bar.mp3 3\n")
            .read(b"size: 6\nbinary: 3\nBAR\nOK\n")
            .build();
        let (client, _) = Client::connect(io).await.expect("connect failed");
        let x = client
            .album_art("foo/bar.mp3")
            .await
            .expect("command failed");
        assert_eq!(x, Some((BytesMut::from("FOOBAR"), None)));
    }

    // An ACK with code 5 ("unknown command") for `readpicture` also falls back to `albumart`.
    #[tokio::test]
    async fn album_art_fallback_error() {
        let io = MockBuilder::new()
            .read(GREETING)
            .write(b"idle\n")
            .write(b"noidle\n")
            .read(b"OK\n")
            .write(b"readpicture foo/bar.mp3 0\n")
            .read(b"ACK [5@0] {} unknown command \"readpicture\"\n")
            .write(b"albumart foo/bar.mp3 0\n")
            .read(b"size: 6\nbinary: 3\nFOO\nOK\n")
            .write(b"albumart foo/bar.mp3 3\n")
            .read(b"size: 6\nbinary: 3\nBAR\nOK\n")
            .build();
        let (client, _) = Client::connect(io).await.expect("connect failed");
        let x = client
            .album_art("foo/bar.mp3")
            .await
            .expect("command failed");
        assert_eq!(x, Some((BytesMut::from("FOOBAR"), None)));
    }

    // When neither `readpicture` nor `albumart` returns data, the result is `None`.
    #[tokio::test]
    async fn album_art_none() {
        let io = MockBuilder::new()
            .read(GREETING)
            .write(b"idle\n")
            .write(b"noidle\n")
            .read(b"OK\n")
            .write(b"readpicture foo/bar.mp3 0\n")
            .read(b"OK\n")
            .write(b"albumart foo/bar.mp3 0\n")
            .read(b"OK\n")
            .build();
        let (client, _) = Client::connect(io).await.expect("connect failed");
        let x = client
            .album_art("foo/bar.mp3")
            .await
            .expect("command failed");
        assert_eq!(x, None);
    }

    // The protocol version is taken from the greeting line.
    #[tokio::test]
    async fn protocol_version() {
        let io = MockBuilder::new().read(GREETING).write(b"idle\n").build();
        let (client, _state_changes) = Client::connect(io).await.expect("connect failed");
        assert_eq!(client.protocol_version(), "0.21.11");
    }

    // A known variant and `Other` with the same name are equal and hash identically.
    #[test]
    fn subsystem_equality() {
        assert_eq!(Subsystem::Player, Subsystem::Other("player".into()));
        let mut a = DefaultHasher::new();
        Subsystem::Player.hash(&mut a);
        let mut b = DefaultHasher::new();
        Subsystem::Other("player".into()).hash(&mut b);
        assert_eq!(a.finish(), b.finish());
    }
}