ios-core 0.1.7

High-level device API, pairing transport, and discovery for iOS devices
Documentation
use serde::{de::DeserializeOwned, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

#[derive(Debug, thiserror::Error)]
pub(crate) enum PlistFrameError {
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    #[error("plist error: {0}")]
    Plist(#[from] plist::Error),
    #[error("protocol error: {0}")]
    Protocol(String),
}

pub(crate) async fn write_xml_plist_frame<W, T>(
    writer: &mut W,
    value: &T,
    max_len: usize,
) -> Result<(), PlistFrameError>
where
    W: AsyncWrite + Unpin,
    T: Serialize,
{
    let mut payload = Vec::new();
    plist::to_writer_xml(&mut payload, value)?;
    write_raw_plist_frame(writer, &payload, max_len).await
}

#[allow(dead_code)]
pub(crate) async fn write_binary_plist_frame<W, T>(
    writer: &mut W,
    value: &T,
    max_len: usize,
) -> Result<(), PlistFrameError>
where
    W: AsyncWrite + Unpin,
    T: Serialize,
{
    let mut payload = Vec::new();
    plist::to_writer_binary(&mut payload, value)?;
    write_raw_plist_frame(writer, &payload, max_len).await
}

pub(crate) async fn write_raw_plist_frame<W>(
    writer: &mut W,
    payload: &[u8],
    max_len: usize,
) -> Result<(), PlistFrameError>
where
    W: AsyncWrite + Unpin,
{
    if payload.len() > max_len {
        return Err(PlistFrameError::Protocol(format!(
            "plist frame length {} exceeds max {max_len}",
            payload.len()
        )));
    }
    let len = u32::try_from(payload.len()).map_err(|_| {
        PlistFrameError::Protocol(format!("plist frame too large: {}", payload.len()))
    })?;
    writer.write_all(&len.to_be_bytes()).await?;
    writer.write_all(payload).await?;
    writer.flush().await?;
    Ok(())
}

pub(crate) async fn read_plist_frame<R, T>(
    reader: &mut R,
    max_len: usize,
) -> Result<T, PlistFrameError>
where
    R: AsyncRead + Unpin,
    T: DeserializeOwned,
{
    let mut len_buf = [0u8; 4];
    reader.read_exact(&mut len_buf).await?;
    let len = u32::from_be_bytes(len_buf) as usize;
    if len > max_len {
        return Err(PlistFrameError::Protocol(format!(
            "plist frame length {len} exceeds max {max_len}"
        )));
    }
    let mut payload = vec![0u8; len];
    reader.read_exact(&mut payload).await?;
    plist::from_bytes(&payload).map_err(PlistFrameError::from)
}

#[cfg(test)]
mod tests {
    use super::*;
    use plist::Value;
    use tokio::io::AsyncReadExt;

    #[tokio::test]
    async fn write_xml_plist_frame_writes_big_endian_len_and_payload() {
        let value = Value::Dictionary(plist::Dictionary::from_iter([(
            "Status".to_string(),
            Value::String("Acknowledged".to_string()),
        )]));
        let mut output = Vec::new();

        write_xml_plist_frame(&mut output, &value, 1024)
            .await
            .unwrap();

        let len = u32::from_be_bytes(output[..4].try_into().unwrap()) as usize;
        let decoded: Value = plist::from_bytes(&output[4..4 + len]).unwrap();
        assert_eq!(decoded, value);
    }

    #[tokio::test]
    async fn write_xml_plist_frame_rejects_oversized_payload() {
        let value = Value::String("payload".repeat(64));
        let mut output = Vec::new();

        let err = write_xml_plist_frame(&mut output, &value, 8)
            .await
            .unwrap_err();

        assert!(err.to_string().contains("plist frame"));
        assert!(output.is_empty());
    }

    #[tokio::test]
    async fn read_plist_frame_rejects_oversized_len_before_allocating() {
        let mut input = std::io::Cursor::new((1025u32).to_be_bytes());

        let err = read_plist_frame::<_, Value>(&mut input, 1024)
            .await
            .unwrap_err();

        assert!(err.to_string().contains("exceeds max"));
    }

    #[tokio::test]
    async fn read_plist_frame_reads_exact_payload() {
        let value = Value::String("ok".to_string());
        let mut framed = Vec::new();
        write_xml_plist_frame(&mut framed, &value, 1024)
            .await
            .unwrap();
        framed.extend_from_slice(b"trailing");
        let mut input = std::io::Cursor::new(framed);

        let decoded: Value = read_plist_frame(&mut input, 1024).await.unwrap();

        assert_eq!(decoded, value);
        let mut trailing = Vec::new();
        input.read_to_end(&mut trailing).await.unwrap();
        assert_eq!(trailing, b"trailing");
    }
}