use serde::Serialize;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
// Declares the `DeviceLinkError` type via the project's `service_error!` macro (defined
// elsewhere). NOTE(review): judging by its use in this file, the generated type provides
// `Protocol(String)` and `Plist(String)` variants plus a conversion from `std::io::Error`
// (needed by `?` on the stream reads/writes) — confirm against the macro definition.
service_error!(DeviceLinkError);
/// Client for DeviceLink messages exchanged as length-prefixed plist frames over `S`.
///
/// Construction and stream access place no bounds on `S`; the messaging methods require
/// `S` to be an `AsyncRead + AsyncWrite + Unpin` stream.
pub struct DeviceLinkClient<S> {
    stream: S,
}

impl<S> DeviceLinkClient<S> {
    /// Wraps `stream`; no I/O is performed.
    pub fn new(stream: S) -> Self {
        DeviceLinkClient { stream }
    }

    /// Consumes the client and returns the wrapped stream.
    pub fn into_inner(self) -> S {
        let DeviceLinkClient { stream } = self;
        stream
    }

    /// Returns a mutable borrow of the wrapped stream.
    pub fn stream_mut(&mut self) -> &mut S {
        let DeviceLinkClient { stream } = self;
        stream
    }
}
impl<S> DeviceLinkClient<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Performs the DeviceLink version handshake and waits for the device-ready message.
    ///
    /// Expects the peer to open with an array whose first element is
    /// `"DLMessageVersionExchange"` and whose second element is an unsigned integer major
    /// version. Replies with `["DLMessageVersionExchange", "DLVersionsOk", <major>]`, echoing
    /// that version, then expects an array whose first element is `"DLMessageDeviceReady"`.
    ///
    /// Returns the major version announced by the peer.
    ///
    /// # Errors
    ///
    /// Returns `DeviceLinkError::Protocol` when either message is not an array, has no string
    /// message type, has the wrong message type, or lacks an unsigned major version. Framing,
    /// plist and I/O failures from [`Self::recv_message`] / [`Self::send_message`] are
    /// propagated.
    pub async fn version_exchange(&mut self) -> Result<u64, DeviceLinkError> {
        let response = self.recv_message().await?;
        let message = Self::expect_message(
            &response,
            "device link version exchange",
            "DLMessageVersionExchange",
        )?;
        let version = message
            .get(1)
            .and_then(|value| match value {
                plist::Value::Integer(value) => value.as_unsigned(),
                _ => None,
            })
            .ok_or_else(|| {
                DeviceLinkError::Protocol(format!(
                    "device link version exchange missing major version: {response:?}"
                ))
            })?;

        self.send_message(&[
            plist::Value::String("DLMessageVersionExchange".into()),
            plist::Value::String("DLVersionsOk".into()),
            plist::Value::Integer(version.into()),
        ])
        .await?;

        let ready = self.recv_message().await?;
        Self::expect_message(&ready, "device ready", "DLMessageDeviceReady")?;
        Ok(version)
    }

    /// Sends the tuple `("DLMessageProcessMessage", message)` as a single frame.
    ///
    /// # Errors
    ///
    /// Propagates serialization, framing and I/O failures from [`Self::send_message`].
    pub async fn send_process_message<T>(&mut self, message: &T) -> Result<(), DeviceLinkError>
    where
        T: Serialize,
    {
        self.send_message(&("DLMessageProcessMessage", message))
            .await
    }

    /// Receives one frame and returns the dictionary payload of a
    /// `["DLMessageProcessMessage", {...}]` message.
    ///
    /// # Errors
    ///
    /// Returns `DeviceLinkError::Protocol` if the frame is not an array, its message type is
    /// missing or is not `"DLMessageProcessMessage"`, or its second element is not a
    /// dictionary. Framing, plist and I/O failures are propagated.
    pub async fn recv_process_message(&mut self) -> Result<plist::Dictionary, DeviceLinkError> {
        let response = self.recv_message().await?;
        let message =
            Self::expect_message(&response, "process message", "DLMessageProcessMessage")?;
        message
            .get(1)
            .and_then(plist::Value::as_dictionary)
            .cloned()
            .ok_or_else(|| {
                DeviceLinkError::Protocol(format!(
                    "process message missing dictionary payload: {response:?}"
                ))
            })
    }

    /// Serializes `message` as an XML plist and writes it as one frame: a 4-byte big-endian
    /// payload length followed by the payload. The stream is flushed afterwards.
    ///
    /// # Errors
    ///
    /// Returns `DeviceLinkError::Plist` if serialization fails, `DeviceLinkError::Protocol`
    /// if the payload is too large for the 32-bit length prefix, and propagates I/O errors.
    pub async fn send_message<T>(&mut self, message: &T) -> Result<(), DeviceLinkError>
    where
        T: Serialize,
    {
        const HEADER_LEN: usize = std::mem::size_of::<u32>();

        // Reserve the length prefix in the same buffer so the header and the body go out in
        // a single write instead of two small ones (avoids write-write-read latency stalls).
        let mut frame = vec![0u8; HEADER_LEN];
        plist::to_writer_xml(&mut frame, message)
            .map_err(|e| DeviceLinkError::Plist(e.to_string()))?;

        let payload_len = frame.len() - HEADER_LEN;
        // A plain `as u32` cast would silently truncate and desynchronize the stream.
        let len = u32::try_from(payload_len).map_err(|_| {
            DeviceLinkError::Protocol(format!(
                "plist length {payload_len} does not fit in a 32-bit frame header"
            ))
        })?;
        frame[..HEADER_LEN].copy_from_slice(&len.to_be_bytes());

        self.stream.write_all(&frame).await?;
        self.stream.flush().await?;
        Ok(())
    }

    /// Reads one frame (a 4-byte big-endian payload length followed by that many bytes) and
    /// decodes the payload with `plist::from_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `DeviceLinkError::Protocol` if the announced length exceeds 4 MiB (checked
    /// before allocating), `DeviceLinkError::Plist` if the payload does not decode, and
    /// propagates I/O errors such as the stream ending mid-frame.
    pub async fn recv_message(&mut self) -> Result<plist::Value, DeviceLinkError> {
        // Bound a single frame so a corrupt or hostile length prefix cannot force a huge
        // allocation.
        const MAX_PLIST_SIZE: usize = 4 * 1024 * 1024;

        let mut len_buf = [0u8; 4];
        self.stream.read_exact(&mut len_buf).await?;
        let len = u32::from_be_bytes(len_buf) as usize;
        if len > MAX_PLIST_SIZE {
            return Err(DeviceLinkError::Protocol(format!(
                "plist length {len} exceeds maximum of {MAX_PLIST_SIZE}"
            )));
        }

        let mut payload = vec![0u8; len];
        self.stream.read_exact(&mut payload).await?;
        plist::from_bytes(&payload).map_err(|e| DeviceLinkError::Plist(e.to_string()))
    }

    /// Sends `["DLMessageDisconnect", "___EmptyParameterString___"]`.
    ///
    /// Only the message is written; the stream itself is not shut down and can still be
    /// reclaimed with [`DeviceLinkClient::into_inner`].
    ///
    /// # Errors
    ///
    /// Propagates serialization, framing and I/O failures from [`Self::send_message`].
    pub async fn disconnect(&mut self) -> Result<(), DeviceLinkError> {
        self.send_message(&[
            plist::Value::String("DLMessageDisconnect".into()),
            plist::Value::String("___EmptyParameterString___".into()),
        ])
        .await
    }

    /// Validates that `value` is an array whose first element is the string `expected_type`,
    /// returning the array's elements.
    ///
    /// `context` prefixes the "expected array" / "missing message type" error messages; a
    /// type mismatch reports `expected <expected_type>, got <actual>`.
    fn expect_message<'a>(
        value: &'a plist::Value,
        context: &str,
        expected_type: &str,
    ) -> Result<&'a [plist::Value], DeviceLinkError> {
        let message = value.as_array().ok_or_else(|| {
            DeviceLinkError::Protocol(format!("{context} expected array, got {value:?}"))
        })?;
        let message_type = message
            .first()
            .and_then(plist::Value::as_string)
            .ok_or_else(|| {
                DeviceLinkError::Protocol(format!("{context} missing message type: {value:?}"))
            })?;
        if message_type != expected_type {
            return Err(DeviceLinkError::Protocol(format!(
                "expected {expected_type}, got {message_type}"
            )));
        }
        Ok(message.as_slice())
    }
}
#[cfg(test)]
mod tests {
    use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// Encodes `value` the way the peer sends it: a 4-byte big-endian length followed by the
    /// XML plist bytes. Test payloads are tiny, so the `as u32` cast cannot truncate here.
    fn encode_frame(value: &plist::Value) -> Vec<u8> {
        let mut payload = Vec::new();
        plist::to_writer_xml(&mut payload, value).expect("plist serialization");
        let mut frame = Vec::with_capacity(payload.len() + 4);
        frame.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        frame.extend_from_slice(&payload);
        frame
    }

    /// Reads and decodes one length-prefixed plist frame written by the client.
    async fn read_frame(stream: &mut tokio::io::DuplexStream) -> plist::Value {
        let mut len_buf = [0u8; 4];
        stream.read_exact(&mut len_buf).await.expect("frame length");
        let len = u32::from_be_bytes(len_buf) as usize;
        let mut payload = vec![0u8; len];
        stream
            .read_exact(&mut payload)
            .await
            .expect("frame payload");
        plist::from_bytes(&payload).expect("plist decode")
    }

    #[tokio::test]
    async fn version_exchange_sends_versions_ok_and_returns_major_version() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.version_exchange().await.unwrap()
        });
        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageVersionExchange".into()),
                plist::Value::Integer(300u64.into()),
            ])))
            .await
            .unwrap();
        let versions_ok = read_frame(&mut server_stream).await;
        assert_eq!(
            versions_ok.as_array(),
            Some(&vec![
                plist::Value::String("DLMessageVersionExchange".into()),
                plist::Value::String("DLVersionsOk".into()),
                plist::Value::Integer(300u64.into()),
            ])
        );
        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageDeviceReady".into()),
            ])))
            .await
            .unwrap();
        assert_eq!(task.await.unwrap(), 300);
    }

    #[tokio::test]
    async fn version_exchange_rejects_unexpected_message_type() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.version_exchange().await
        });
        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageProcessMessage".into()),
            ])))
            .await
            .unwrap();
        let err = task
            .await
            .unwrap()
            .expect_err("unexpected message type must fail");
        assert!(err
            .to_string()
            .contains("expected DLMessageVersionExchange, got DLMessageProcessMessage"));
    }

    #[tokio::test]
    async fn recv_process_message_requires_dictionary_payload() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.recv_process_message().await
        });
        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageProcessMessage".into()),
                plist::Value::String("not-a-dict".into()),
            ])))
            .await
            .unwrap();
        let err = task
            .await
            .unwrap()
            .expect_err("non-dictionary payload must fail");
        assert!(err
            .to_string()
            .contains("process message missing dictionary payload"));
    }

    #[tokio::test]
    async fn process_message_round_trip() {
        let (client_stream, mut server_stream) = duplex(4096);
        let mut payload = plist::Dictionary::new();
        payload.insert(
            "MessageName".to_string(),
            plist::Value::String("Hello".into()),
        );
        let outgoing = plist::Value::Dictionary(payload.clone());
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.send_process_message(&outgoing).await.unwrap();
            client.recv_process_message().await.unwrap()
        });

        // The outgoing tuple must arrive as a two-element plist array.
        let expected = plist::Value::Array(vec![
            plist::Value::String("DLMessageProcessMessage".into()),
            plist::Value::Dictionary(payload.clone()),
        ]);
        assert_eq!(read_frame(&mut server_stream).await, expected);

        // Echo it back and check the dictionary payload is extracted.
        server_stream
            .write_all(&encode_frame(&expected))
            .await
            .unwrap();
        assert_eq!(task.await.unwrap(), payload);
    }

    #[tokio::test]
    async fn recv_message_rejects_oversized_frame() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.recv_message().await
        });
        // Only the length prefix is sent: the limit must be enforced before reading a body.
        let oversized: u32 = 4 * 1024 * 1024 + 1;
        server_stream
            .write_all(&oversized.to_be_bytes())
            .await
            .unwrap();
        let err = task
            .await
            .unwrap()
            .expect_err("oversized frame must be rejected");
        assert!(err.to_string().contains("exceeds maximum"));
    }

    #[tokio::test]
    async fn disconnect_sends_expected_message() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.disconnect().await.unwrap();
        });
        let disconnect = read_frame(&mut server_stream).await;
        assert_eq!(
            disconnect.as_array(),
            Some(&vec![
                plist::Value::String("DLMessageDisconnect".into()),
                plist::Value::String("___EmptyParameterString___".into()),
            ])
        );
        task.await.unwrap();
    }
}