ios-core 0.1.7

High-level device API, pairing transport, and discovery for iOS devices
Documentation
//! Minimal notification proxy client.
//!
//! Service: `com.apple.mobile.notification_proxy`
//! Reference: go-ios/ios/notificationproxy/notificationproxy.go

use std::collections::HashSet;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::time::Instant;

/// Name of the device service this module's client speaks to.
pub const SERVICE_NAME: &str = "com.apple.mobile.notification_proxy";
/// Notification name awaited by [`NotificationProxyClient::wait_for_springboard`]
/// (presumably posted once SpringBoard has finished starting up — confirm on device).
pub const SPRINGBOARD_FINISHED_STARTUP: &str = "com.apple.springboard.finishedstartup";

/// An event decoded from one incoming proxy message by
/// [`NotificationProxyClient::recv_event`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NotificationEvent {
    /// A `RelayNotification` message; holds the value of its `Name` key.
    Notification(String),
    /// A `ProxyDeath` message.
    ProxyDeath,
}

// Error type for this module, generated by the crate's `service_error!` macro.
// Besides the two variants listed here, the client code constructs
// `NotificationProxyError::Protocol(String)` and uses `?` on I/O and plist
// results, so the macro presumably also supplies a `Protocol` variant and
// `From` conversions for those errors — verify against the macro definition.
service_error!(
    NotificationProxyError,
    // Returned by `wait_for` when a `ProxyDeath` message arrives before the
    // awaited notification.
    #[error("proxy closed before notification arrived")]
    ProxyDeath,
    // Returned when a `wait_for` / `next_event` deadline elapses.
    #[error("timed out waiting for notification")]
    Timeout,
);

/// An event returned by [`NotificationProxyClient::next_event`].
///
/// Carries the same information as [`NotificationEvent`]; the two types are
/// separate public types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NotificationProxyEvent {
    /// A `RelayNotification` message; holds the value of its `Name` key.
    Notification(String),
    /// A `ProxyDeath` message.
    ProxyDeath,
}

/// Client for the notification proxy service ([`SERVICE_NAME`]).
///
/// Every message, in both directions, is an XML plist preceded by its length
/// as a big-endian `u32`. Incoming bytes are accumulated in an internal buffer
/// so that a read interrupted by a timeout (in [`Self::wait_for`] or
/// [`Self::next_event`]) never loses part of a frame: the next receive call
/// continues exactly where the interrupted one stopped.
#[derive(Debug)]
pub struct NotificationProxyClient<S> {
    stream: S,
    /// Notification names already registered with `ObserveNotification`.
    observing: HashSet<String>,
    /// Bytes read from `stream` that have not yet been returned as a message.
    read_buf: Vec<u8>,
}

impl<S: AsyncRead + AsyncWrite + Unpin> NotificationProxyClient<S> {
    /// Wraps an already-connected service stream. No bytes are exchanged until
    /// a method is called.
    pub fn new(stream: S) -> Self {
        Self {
            stream,
            observing: HashSet::new(),
            read_buf: Vec::new(),
        }
    }

    /// Sends an `ObserveNotification` request for `notification`.
    ///
    /// If this client has already observed the name, nothing is sent. The name
    /// is only recorded after the request was written successfully, so a
    /// failed attempt can be retried. No reply is awaited.
    ///
    /// # Errors
    /// Returns any serialization or I/O error from sending the request.
    pub async fn observe(&mut self, notification: &str) -> Result<(), NotificationProxyError> {
        if self.observing.contains(notification) {
            return Ok(());
        }

        self.send_request(NotificationProxyRequest {
            command: "ObserveNotification",
            name: Some(notification),
        })
        .await?;
        self.observing.insert(notification.to_string());
        Ok(())
    }

    /// Sends a `PostNotification` request for `notification`. No reply is awaited.
    ///
    /// # Errors
    /// Returns any serialization or I/O error from sending the request.
    pub async fn post(&mut self, notification: &str) -> Result<(), NotificationProxyError> {
        self.send_request(NotificationProxyRequest {
            command: "PostNotification",
            name: Some(notification),
        })
        .await
    }

    /// Observes `notification` (if not already observed), then waits up to
    /// `timeout` for a `RelayNotification` carrying that name.
    ///
    /// Relayed notifications with any other name that arrive in the meantime
    /// are consumed and discarded.
    ///
    /// # Errors
    /// - [`NotificationProxyError::Timeout`] if the deadline passes first.
    /// - [`NotificationProxyError::ProxyDeath`] if a `ProxyDeath` message arrives first.
    /// - Any send, I/O, or decoding error encountered on the way.
    pub async fn wait_for(
        &mut self,
        notification: &str,
        timeout: Duration,
    ) -> Result<(), NotificationProxyError> {
        self.observe(notification).await?;

        let deadline = Instant::now() + timeout;
        loop {
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                return Err(NotificationProxyError::Timeout);
            }

            // Safe to cancel: `recv_event` keeps partially received frames in
            // `read_buf`, so a timeout here does not desynchronize the stream.
            let event = tokio::time::timeout(remaining, self.recv_event())
                .await
                .map_err(|_| NotificationProxyError::Timeout)??;

            match event {
                NotificationEvent::Notification(name) if name == notification => return Ok(()),
                NotificationEvent::ProxyDeath => return Err(NotificationProxyError::ProxyDeath),
                NotificationEvent::Notification(_) => {}
            }
        }
    }

    /// Waits up to `timeout` for [`SPRINGBOARD_FINISHED_STARTUP`]; see [`Self::wait_for`].
    pub async fn wait_for_springboard(
        &mut self,
        timeout: Duration,
    ) -> Result<(), NotificationProxyError> {
        self.wait_for(SPRINGBOARD_FINISHED_STARTUP, timeout).await
    }

    /// Receives the next event, waiting at most `timeout`.
    ///
    /// Unlike [`Self::wait_for`], this registers no observation and does not
    /// filter by name. Cancelling or timing out leaves the stream in sync.
    ///
    /// # Errors
    /// [`NotificationProxyError::Timeout`] if nothing arrives in time, otherwise
    /// the same errors as [`Self::recv_event`].
    pub async fn next_event(
        &mut self,
        timeout: Duration,
    ) -> Result<NotificationProxyEvent, NotificationProxyError> {
        let event = tokio::time::timeout(timeout, self.recv_event())
            .await
            .map_err(|_| NotificationProxyError::Timeout)??;

        Ok(match event {
            NotificationEvent::Notification(name) => NotificationProxyEvent::Notification(name),
            NotificationEvent::ProxyDeath => NotificationProxyEvent::ProxyDeath,
        })
    }

    /// Sends a `Shutdown` request (without a `Name`). No reply is awaited.
    ///
    /// # Errors
    /// Returns any serialization or I/O error from sending the request.
    pub async fn shutdown(&mut self) -> Result<(), NotificationProxyError> {
        self.send_request(NotificationProxyRequest {
            command: "Shutdown",
            name: None,
        })
        .await
    }

    /// Receives and decodes the next message, waiting without a time limit.
    ///
    /// # Errors
    /// - `Protocol` if the message lacks `Command`, has an unknown `Command`,
    ///   or is a `RelayNotification` without `Name`.
    /// - Any I/O, framing, or plist decoding error from reading the message.
    pub async fn recv_event(&mut self) -> Result<NotificationEvent, NotificationProxyError> {
        let message = self.recv_message().await?;
        match message.command.as_deref() {
            Some("RelayNotification") => Ok(NotificationEvent::Notification(
                message.name.ok_or_else(|| {
                    NotificationProxyError::Protocol("RelayNotification missing Name".to_string())
                })?,
            )),
            Some("ProxyDeath") => Ok(NotificationEvent::ProxyDeath),
            Some(other) => Err(NotificationProxyError::Protocol(format!(
                "unexpected notification proxy command: {other}"
            ))),
            None => Err(NotificationProxyError::Protocol(
                "notification proxy message missing Command".to_string(),
            )),
        }
    }

    /// Serializes `request` as an XML plist and writes it with its big-endian
    /// `u32` length prefix in a single write, then flushes.
    async fn send_request(
        &mut self,
        request: NotificationProxyRequest<'_>,
    ) -> Result<(), NotificationProxyError> {
        let mut payload = Vec::new();
        plist::to_writer_xml(&mut payload, &request)?;
        // Checked conversion instead of `as u32`, which would silently truncate.
        let len = u32::try_from(payload.len()).map_err(|_| {
            NotificationProxyError::Protocol(format!(
                "request plist length {} does not fit in a u32 length prefix",
                payload.len()
            ))
        })?;

        let mut frame = Vec::with_capacity(4 + payload.len());
        frame.extend_from_slice(&len.to_be_bytes());
        frame.extend_from_slice(&payload);
        self.stream.write_all(&frame).await?;
        self.stream.flush().await?;
        Ok(())
    }

    /// Returns the next complete message, reading from the stream as needed.
    ///
    /// This future is cancel-safe. Bytes are only ever moved from the stream
    /// into `read_buf`, so dropping the future (for example on a timeout) keeps
    /// any partial frame for the next call.
    async fn recv_message(&mut self) -> Result<NotificationProxyMessage, NotificationProxyError> {
        const READ_CHUNK: usize = 4096;
        let mut chunk = [0u8; READ_CHUNK];
        loop {
            if let Some(message) = self.take_buffered_message()? {
                return Ok(message);
            }
            // `AsyncReadExt::read` is cancel-safe: if dropped while pending, no
            // data has been consumed. Whatever it returns is appended to
            // `read_buf` before the next await point.
            let read = self.stream.read(&mut chunk).await?;
            if read == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "notification proxy stream closed",
                )
                .into());
            }
            self.read_buf.extend_from_slice(&chunk[..read]);
        }
    }

    /// Removes and decodes one frame from `read_buf` if it is fully buffered.
    ///
    /// Returns `Ok(None)` when more bytes are needed.
    fn take_buffered_message(
        &mut self,
    ) -> Result<Option<NotificationProxyMessage>, NotificationProxyError> {
        const HEADER_LEN: usize = 4;
        const MAX_PLIST_SIZE: usize = 4 * 1024 * 1024;

        if self.read_buf.len() < HEADER_LEN {
            return Ok(None);
        }
        let mut len_bytes = [0u8; HEADER_LEN];
        len_bytes.copy_from_slice(&self.read_buf[..HEADER_LEN]);
        let len = u32::from_be_bytes(len_bytes) as usize;
        if len > MAX_PLIST_SIZE {
            return Err(NotificationProxyError::Protocol(format!(
                "plist length {len} exceeds max {MAX_PLIST_SIZE}"
            )));
        }

        let frame_end = HEADER_LEN + len;
        if self.read_buf.len() < frame_end {
            return Ok(None);
        }
        // The frame is dropped even if decoding fails. Its length is known, so
        // the following frame still starts at the right offset.
        let decoded: Result<NotificationProxyMessage, _> =
            plist::from_bytes(&self.read_buf[HEADER_LEN..frame_end]);
        self.read_buf.drain(..frame_end);
        Ok(Some(decoded?))
    }
}

/// Outgoing request, serialized as a plist dictionary with PascalCase keys
/// (`Command`, and `Name` when present).
#[derive(Serialize)]
#[serde(rename_all = "PascalCase")]
struct NotificationProxyRequest<'a> {
    /// Command string, e.g. `ObserveNotification`, `PostNotification`, `Shutdown`.
    command: &'static str,
    /// Notification name; the key is left out entirely when `None` (as for `Shutdown`).
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<&'a str>,
}

/// Incoming message, deserialized from a plist dictionary with PascalCase keys.
///
/// Both keys are optional: a missing key becomes `None`, and the caller
/// reports it as a protocol error where relevant.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct NotificationProxyMessage {
    /// The `Command` key, e.g. `RelayNotification` or `ProxyDeath`.
    #[serde(default)]
    command: Option<String>,
    /// The `Name` key; expected on `RelayNotification` messages.
    #[serde(default)]
    name: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::MockStream;

    /// Timeout used by the tests that expect an event to be available immediately.
    const PROMPT: Duration = Duration::from_millis(100);

    /// Serializes `value` as an XML plist. No length prefix is added here.
    fn plist_frame(value: plist::Value) -> Vec<u8> {
        let mut xml = Vec::new();
        plist::to_writer_xml(&mut xml, &value).unwrap();
        xml
    }

    /// Builds a plist dictionary frame from `(key, string value)` pairs, in order.
    fn dict_frame(entries: &[(&str, &str)]) -> Vec<u8> {
        let dict = plist::Dictionary::from_iter(
            entries
                .iter()
                .map(|&(key, value)| (key.to_string(), plist::Value::String(value.to_string()))),
        );
        plist_frame(plist::Value::Dictionary(dict))
    }

    /// A `RelayNotification` message announcing `name`.
    fn relay_frame(name: &str) -> Vec<u8> {
        dict_frame(&[("Command", "RelayNotification"), ("Name", name)])
    }

    /// A `ProxyDeath` message.
    fn proxy_death_frame() -> Vec<u8> {
        dict_frame(&[("Command", "ProxyDeath")])
    }

    /// Decodes the first length-prefixed request found in `written`.
    fn first_request(written: &[u8]) -> plist::Dictionary {
        let len = u32::from_be_bytes(written[..4].try_into().unwrap()) as usize;
        plist::from_bytes(&written[4..4 + len]).unwrap()
    }

    #[tokio::test]
    async fn observe_encodes_notification_request() {
        let mut stream = MockStream::default();
        NotificationProxyClient::new(&mut stream)
            .observe("com.apple.example.ready")
            .await
            .unwrap();

        let request = first_request(&stream.written[..]);
        assert_eq!(request["Command"].as_string(), Some("ObserveNotification"));
        assert_eq!(request["Name"].as_string(), Some("com.apple.example.ready"));
    }

    #[tokio::test]
    async fn post_encodes_notification_request() {
        let mut stream = MockStream::default();
        NotificationProxyClient::new(&mut stream)
            .post("com.apple.example.trigger")
            .await
            .unwrap();

        let request = first_request(&stream.written[..]);
        assert_eq!(request["Command"].as_string(), Some("PostNotification"));
        assert_eq!(request["Name"].as_string(), Some("com.apple.example.trigger"));
    }

    #[tokio::test]
    async fn wait_for_matches_relay_notification() {
        let mut stream = MockStream::with_frames(vec![relay_frame("com.apple.example.ready")]);
        let mut client = NotificationProxyClient::new(&mut stream);

        client
            .wait_for("com.apple.example.ready", PROMPT)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn recv_event_decodes_relay_notification() {
        let mut stream = MockStream::with_frames(vec![relay_frame("com.apple.example.ready")]);
        let mut client = NotificationProxyClient::new(&mut stream);

        let expected = NotificationEvent::Notification("com.apple.example.ready".into());
        assert_eq!(client.recv_event().await.unwrap(), expected);
    }

    #[tokio::test]
    async fn recv_event_decodes_proxy_death() {
        let mut stream = MockStream::with_frames(vec![proxy_death_frame()]);
        let mut client = NotificationProxyClient::new(&mut stream);

        assert_eq!(
            client.recv_event().await.unwrap(),
            NotificationEvent::ProxyDeath
        );
    }

    #[tokio::test]
    async fn wait_for_springboard_uses_expected_name() {
        let mut stream = MockStream::with_frames(vec![relay_frame(SPRINGBOARD_FINISHED_STARTUP)]);
        let mut client = NotificationProxyClient::new(&mut stream);

        client.wait_for_springboard(PROMPT).await.unwrap();
    }

    #[tokio::test]
    async fn next_event_returns_notification_name() {
        let mut stream = MockStream::with_frames(vec![relay_frame("com.apple.example.stream")]);
        let mut client = NotificationProxyClient::new(&mut stream);

        let expected = NotificationProxyEvent::Notification("com.apple.example.stream".into());
        assert_eq!(client.next_event(PROMPT).await.unwrap(), expected);
    }

    #[tokio::test]
    async fn next_event_maps_proxy_death() {
        let mut stream = MockStream::with_frames(vec![proxy_death_frame()]);
        let mut client = NotificationProxyClient::new(&mut stream);

        assert_eq!(
            client.next_event(PROMPT).await.unwrap(),
            NotificationProxyEvent::ProxyDeath
        );
    }

    #[tokio::test]
    async fn wait_for_times_out_when_no_notification_arrives() {
        // Keep the peer half alive so the read blocks instead of hitting EOF.
        let (client_side, _server_side) = tokio::io::duplex(1024);
        let mut client = NotificationProxyClient::new(client_side);

        let outcome = client
            .wait_for("com.apple.example.ready", Duration::from_millis(10))
            .await;
        assert!(matches!(outcome, Err(NotificationProxyError::Timeout)));
    }
}