ios-core 0.1.7

High-level device API, pairing transport, and discovery for iOS devices
Documentation
//! Companion proxy service client.
//!
//! Service: `com.apple.companion_proxy`

use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

pub const SERVICE_NAME: &str = "com.apple.companion_proxy";

service_error!(CompanionError);

pub struct CompanionProxyClient<S> {
    stream: S,
}

impl<S: AsyncRead + AsyncWrite + Unpin> CompanionProxyClient<S> {
    pub fn new(stream: S) -> Self {
        Self { stream }
    }

    pub async fn list(&mut self) -> Result<Vec<plist::Value>, CompanionError> {
        let request = plist::Dictionary::from_iter([(
            "Command".to_string(),
            plist::Value::String("GetDeviceRegistry".into()),
        )]);
        send_plist(&mut self.stream, &plist::Value::Dictionary(request)).await?;
        let response = recv_plist(&mut self.stream).await?;
        Ok(response
            .get("PairedDevicesArray")
            .and_then(plist::Value::as_array)
            .cloned()
            .unwrap_or_default())
    }

    pub async fn get_value(
        &mut self,
        udid: &str,
        key: &str,
    ) -> Result<plist::Value, CompanionError> {
        let request = plist::Dictionary::from_iter([
            (
                "Command".to_string(),
                plist::Value::String("GetValueFromRegistry".into()),
            ),
            (
                "GetValueGizmoUDIDKey".to_string(),
                plist::Value::String(udid.to_string()),
            ),
            (
                "GetValueKeyKey".to_string(),
                plist::Value::String(key.to_string()),
            ),
        ]);
        send_plist(&mut self.stream, &plist::Value::Dictionary(request)).await?;
        let response = recv_plist(&mut self.stream).await?;
        if let Some(value) = response.get("RetrievedValueDictionary") {
            return Ok(value.clone());
        }
        let error = response
            .get("Error")
            .and_then(plist::Value::as_string)
            .unwrap_or("missing RetrievedValueDictionary");
        Err(CompanionError::Protocol(error.to_string()))
    }
}

async fn send_plist<S: AsyncWrite + Unpin>(
    stream: &mut S,
    value: &plist::Value,
) -> Result<(), CompanionError> {
    let mut buf = Vec::new();
    plist::to_writer_xml(&mut buf, value)?;
    stream.write_all(&(buf.len() as u32).to_be_bytes()).await?;
    stream.write_all(&buf).await?;
    stream.flush().await?;
    Ok(())
}

async fn recv_plist<S: AsyncRead + Unpin>(
    stream: &mut S,
) -> Result<plist::Dictionary, CompanionError> {
    let mut len_buf = [0u8; 4];
    stream.read_exact(&mut len_buf).await?;
    let len = u32::from_be_bytes(len_buf) as usize;
    const MAX_PLIST_SIZE: usize = 1024 * 1024;
    if len > MAX_PLIST_SIZE {
        return Err(CompanionError::Protocol(format!(
            "plist length {len} exceeds max {MAX_PLIST_SIZE}"
        )));
    }
    let mut buf = vec![0u8; len];
    stream.read_exact(&mut buf).await?;
    Ok(plist::from_bytes(&buf)?)
}

#[cfg(test)]
mod tests {
    use crate::test_util::MockStream;

    use super::*;

    #[tokio::test]
    async fn list_sends_get_device_registry_command() {
        let response = plist::Value::Dictionary(plist::Dictionary::from_iter([(
            "PairedDevicesArray".to_string(),
            plist::Value::Array(vec![plist::Value::String("watch".into())]),
        )]));
        let mut stream = MockStream::with_response(response);
        let mut client = CompanionProxyClient::new(&mut stream);

        let devices = client.list().await.unwrap();
        assert_eq!(devices.len(), 1);

        let len = u32::from_be_bytes(stream.written[..4].try_into().unwrap()) as usize;
        let payload = &stream.written[4..4 + len];
        let dict: plist::Dictionary = plist::from_bytes(payload).unwrap();
        assert_eq!(dict["Command"].as_string(), Some("GetDeviceRegistry"));
    }

    #[tokio::test]
    async fn get_value_sends_registry_lookup_request() {
        let response = plist::Value::Dictionary(plist::Dictionary::from_iter([(
            "RetrievedValueDictionary".to_string(),
            plist::Value::String("AppleWatch".into()),
        )]));
        let mut stream = MockStream::with_response(response);
        let mut client = CompanionProxyClient::new(&mut stream);

        let value = client.get_value("watch-udid", "name").await.unwrap();
        assert_eq!(value.as_string(), Some("AppleWatch"));

        let len = u32::from_be_bytes(stream.written[..4].try_into().unwrap()) as usize;
        let payload = &stream.written[4..4 + len];
        let dict: plist::Dictionary = plist::from_bytes(payload).unwrap();
        assert_eq!(dict["Command"].as_string(), Some("GetValueFromRegistry"));
        assert_eq!(dict["GetValueGizmoUDIDKey"].as_string(), Some("watch-udid"));
        assert_eq!(dict["GetValueKeyKey"].as_string(), Some("name"));
    }
}