Skip to main content

ios_core/services/notificationproxy/
mod.rs

//! Minimal async client for the iOS notification proxy service.
//!
//! Service: `com.apple.mobile.notification_proxy`
//! Reference: go-ios/ios/notificationproxy/notificationproxy.go
5
6use std::collections::HashSet;
7use std::time::Duration;
8
9use serde::{Deserialize, Serialize};
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11use tokio::time::Instant;
12
/// Name of the device service this client talks to.
pub const SERVICE_NAME: &str = "com.apple.mobile.notification_proxy";

/// Notification name awaited by `NotificationProxyClient::wait_for_springboard`.
pub const SPRINGBOARD_FINISHED_STARTUP: &str = "com.apple.springboard.finishedstartup";
15
/// A decoded message from the notification proxy, as returned by
/// `NotificationProxyClient::recv_event`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NotificationEvent {
    /// A `RelayNotification` message; carries the relayed notification name.
    Notification(String),
    /// A `ProxyDeath` message was received.
    ProxyDeath,
}
21
22#[derive(Debug, thiserror::Error)]
23pub enum NotificationProxyError {
24    #[error("IO error: {0}")]
25    Io(#[from] std::io::Error),
26    #[error("plist error: {0}")]
27    Plist(String),
28    #[error("protocol error: {0}")]
29    Protocol(String),
30    #[error("proxy closed before notification arrived")]
31    ProxyDeath,
32    #[error("timed out waiting for notification")]
33    Timeout,
34}
35
/// Event returned by `NotificationProxyClient::next_event`.
///
/// Same shape as `NotificationEvent`; kept as a separate public type for
/// callers that already match on it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NotificationProxyEvent {
    /// A relayed notification, identified by its name.
    Notification(String),
    /// A `ProxyDeath` message was received.
    ProxyDeath,
}
41
/// Client for the notification proxy service running over `stream`.
///
/// Every message is framed as a 4-byte big-endian length followed by a plist
/// payload (requests are written as XML plists). The client remembers which
/// notification names it has already observed so repeated `observe` calls do
/// not resend the request.
#[derive(Debug)]
pub struct NotificationProxyClient<S> {
    /// Transport connected to the notification proxy service.
    stream: S,
    /// Names already registered with a successful `ObserveNotification` request.
    observing: HashSet<String>,
}
47
48impl<S: AsyncRead + AsyncWrite + Unpin> NotificationProxyClient<S> {
49    pub fn new(stream: S) -> Self {
50        Self {
51            stream,
52            observing: HashSet::new(),
53        }
54    }
55
56    pub async fn observe(&mut self, notification: &str) -> Result<(), NotificationProxyError> {
57        if self.observing.contains(notification) {
58            return Ok(());
59        }
60
61        self.send_request(NotificationProxyRequest {
62            command: "ObserveNotification",
63            name: Some(notification),
64        })
65        .await?;
66        self.observing.insert(notification.to_string());
67        Ok(())
68    }
69
70    pub async fn post(&mut self, notification: &str) -> Result<(), NotificationProxyError> {
71        self.send_request(NotificationProxyRequest {
72            command: "PostNotification",
73            name: Some(notification),
74        })
75        .await
76    }
77
78    pub async fn wait_for(
79        &mut self,
80        notification: &str,
81        timeout: Duration,
82    ) -> Result<(), NotificationProxyError> {
83        self.observe(notification).await?;
84
85        let deadline = Instant::now() + timeout;
86        loop {
87            let remaining = deadline.saturating_duration_since(Instant::now());
88            if remaining.is_zero() {
89                return Err(NotificationProxyError::Timeout);
90            }
91
92            let event = tokio::time::timeout(remaining, self.recv_event())
93                .await
94                .map_err(|_| NotificationProxyError::Timeout)??;
95
96            match event {
97                NotificationEvent::Notification(name) if name == notification => return Ok(()),
98                NotificationEvent::ProxyDeath => return Err(NotificationProxyError::ProxyDeath),
99                NotificationEvent::Notification(_) => {}
100            }
101        }
102    }
103
104    pub async fn wait_for_springboard(
105        &mut self,
106        timeout: Duration,
107    ) -> Result<(), NotificationProxyError> {
108        self.wait_for(SPRINGBOARD_FINISHED_STARTUP, timeout).await
109    }
110
111    pub async fn next_event(
112        &mut self,
113        timeout: Duration,
114    ) -> Result<NotificationProxyEvent, NotificationProxyError> {
115        let message = tokio::time::timeout(timeout, self.recv_message())
116            .await
117            .map_err(|_| NotificationProxyError::Timeout)??;
118
119        match message.command.as_deref() {
120            Some("RelayNotification") => message
121                .name
122                .map(NotificationProxyEvent::Notification)
123                .ok_or_else(|| {
124                    NotificationProxyError::Protocol("RelayNotification missing Name field".into())
125                }),
126            Some("ProxyDeath") => Ok(NotificationProxyEvent::ProxyDeath),
127            other => Err(NotificationProxyError::Protocol(format!(
128                "unexpected notification proxy command: {}",
129                other.unwrap_or("<missing>")
130            ))),
131        }
132    }
133
134    pub async fn shutdown(&mut self) -> Result<(), NotificationProxyError> {
135        self.send_request(NotificationProxyRequest {
136            command: "Shutdown",
137            name: None,
138        })
139        .await
140    }
141
142    pub async fn recv_event(&mut self) -> Result<NotificationEvent, NotificationProxyError> {
143        let message = self.recv_message().await?;
144        match message.command.as_deref() {
145            Some("RelayNotification") => Ok(NotificationEvent::Notification(
146                message.name.ok_or_else(|| {
147                    NotificationProxyError::Protocol("RelayNotification missing Name".to_string())
148                })?,
149            )),
150            Some("ProxyDeath") => Ok(NotificationEvent::ProxyDeath),
151            Some(other) => Err(NotificationProxyError::Protocol(format!(
152                "unexpected notification proxy command: {other}"
153            ))),
154            None => Err(NotificationProxyError::Protocol(
155                "notification proxy message missing Command".to_string(),
156            )),
157        }
158    }
159
160    async fn send_request(
161        &mut self,
162        request: NotificationProxyRequest<'_>,
163    ) -> Result<(), NotificationProxyError> {
164        let mut buf = Vec::new();
165        plist::to_writer_xml(&mut buf, &request)
166            .map_err(|e| NotificationProxyError::Plist(e.to_string()))?;
167        self.stream
168            .write_all(&(buf.len() as u32).to_be_bytes())
169            .await?;
170        self.stream.write_all(&buf).await?;
171        self.stream.flush().await?;
172        Ok(())
173    }
174
175    async fn recv_message(&mut self) -> Result<NotificationProxyMessage, NotificationProxyError> {
176        let mut len_buf = [0u8; 4];
177        self.stream.read_exact(&mut len_buf).await?;
178        let len = u32::from_be_bytes(len_buf) as usize;
179        const MAX_PLIST_SIZE: usize = 4 * 1024 * 1024;
180        if len > MAX_PLIST_SIZE {
181            return Err(NotificationProxyError::Protocol(format!(
182                "plist length {len} exceeds max {MAX_PLIST_SIZE}"
183            )));
184        }
185        let mut buf = vec![0u8; len];
186        self.stream.read_exact(&mut buf).await?;
187        plist::from_bytes(&buf).map_err(|e| NotificationProxyError::Plist(e.to_string()))
188    }
189}
190
191#[derive(Serialize)]
192#[serde(rename_all = "PascalCase")]
193struct NotificationProxyRequest<'a> {
194    command: &'static str,
195    #[serde(skip_serializing_if = "Option::is_none")]
196    name: Option<&'a str>,
197}
198
199#[derive(Debug, Deserialize)]
200#[serde(rename_all = "PascalCase")]
201struct NotificationProxyMessage {
202    #[serde(default)]
203    command: Option<String>,
204    #[serde(default)]
205    name: Option<String>,
206}
207
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::*;

    /// In-memory stream: reads are served from a pre-built byte buffer (then
    /// `UnexpectedEof`), and every write is captured in `written`.
    #[derive(Default)]
    struct MockStream {
        read_buf: Vec<u8>,
        written: Vec<u8>,
        read_pos: usize,
    }

    impl MockStream {
        /// Builds a stream whose read side yields each frame, in order, behind a
        /// 4-byte big-endian length prefix.
        fn with_frames(frames: Vec<Vec<u8>>) -> Self {
            let mut read_buf = Vec::new();
            for frame in frames {
                read_buf.extend_from_slice(&(frame.len() as u32).to_be_bytes());
                read_buf.extend_from_slice(&frame);
            }
            Self {
                read_buf,
                ..Self::default()
            }
        }
    }

    impl AsyncRead for MockStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let remaining = self.read_buf.len().saturating_sub(self.read_pos);
            if remaining == 0 {
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "no more test data",
                )));
            }

            let to_copy = remaining.min(buf.remaining());
            let start = self.read_pos;
            let end = start + to_copy;
            buf.put_slice(&self.read_buf[start..end]);
            self.read_pos = end;
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for MockStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.written.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Serializes `value` as an XML plist frame body (no length prefix).
    fn plist_frame(value: plist::Value) -> Vec<u8> {
        let mut buf = Vec::new();
        plist::to_writer_xml(&mut buf, &value).unwrap();
        buf
    }

    /// Encodes a dictionary of string fields as a frame body.
    fn message_frame(fields: &[(&str, &str)]) -> Vec<u8> {
        let dict: plist::Dictionary = fields
            .iter()
            .map(|&(key, value)| (key.to_string(), plist::Value::String(value.to_string())))
            .collect();
        plist_frame(plist::Value::Dictionary(dict))
    }

    /// A `RelayNotification` message for `name`.
    fn relay_frame(name: &str) -> Vec<u8> {
        message_frame(&[("Command", "RelayNotification"), ("Name", name)])
    }

    /// A message carrying only a `Command` key.
    fn command_frame(command: &str) -> Vec<u8> {
        message_frame(&[("Command", command)])
    }

    /// Decodes the first length-prefixed request the client wrote, returning its
    /// payload length and dictionary.
    fn first_request(written: &[u8]) -> (usize, plist::Dictionary) {
        let len = u32::from_be_bytes(written[..4].try_into().unwrap()) as usize;
        let dict = plist::from_bytes(&written[4..4 + len]).unwrap();
        (len, dict)
    }

    #[tokio::test]
    async fn observe_encodes_notification_request() {
        let mut stream = MockStream::default();
        let mut client = NotificationProxyClient::new(&mut stream);
        client.observe("com.apple.example.ready").await.unwrap();

        let (_, dict) = first_request(&stream.written);
        assert_eq!(dict["Command"].as_string(), Some("ObserveNotification"));
        assert_eq!(dict["Name"].as_string(), Some("com.apple.example.ready"));
    }

    #[tokio::test]
    async fn observe_sends_each_name_only_once() {
        let mut stream = MockStream::default();
        let mut client = NotificationProxyClient::new(&mut stream);
        client.observe("com.apple.example.ready").await.unwrap();
        client.observe("com.apple.example.ready").await.unwrap();

        let (len, _) = first_request(&stream.written);
        assert_eq!(stream.written.len(), 4 + len);
    }

    #[tokio::test]
    async fn post_encodes_notification_request() {
        let mut stream = MockStream::default();
        let mut client = NotificationProxyClient::new(&mut stream);
        client.post("com.apple.example.trigger").await.unwrap();

        let (_, dict) = first_request(&stream.written);
        assert_eq!(dict["Command"].as_string(), Some("PostNotification"));
        assert_eq!(dict["Name"].as_string(), Some("com.apple.example.trigger"));
    }

    #[tokio::test]
    async fn shutdown_omits_name() {
        let mut stream = MockStream::default();
        let mut client = NotificationProxyClient::new(&mut stream);
        client.shutdown().await.unwrap();

        let (_, dict) = first_request(&stream.written);
        assert_eq!(dict["Command"].as_string(), Some("Shutdown"));
        assert!(!dict.contains_key("Name"));
    }

    #[tokio::test]
    async fn wait_for_matches_relay_notification() {
        let mut stream = MockStream::with_frames(vec![relay_frame("com.apple.example.ready")]);
        let mut client = NotificationProxyClient::new(&mut stream);

        client
            .wait_for("com.apple.example.ready", Duration::from_millis(100))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn wait_for_skips_unrelated_notifications() {
        let mut stream = MockStream::with_frames(vec![
            relay_frame("com.apple.example.other"),
            relay_frame("com.apple.example.ready"),
        ]);
        let mut client = NotificationProxyClient::new(&mut stream);

        client
            .wait_for("com.apple.example.ready", Duration::from_millis(100))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn wait_for_reports_proxy_death() {
        let mut stream = MockStream::with_frames(vec![command_frame("ProxyDeath")]);
        let mut client = NotificationProxyClient::new(&mut stream);

        let err = client
            .wait_for("com.apple.example.ready", Duration::from_millis(100))
            .await
            .unwrap_err();
        assert!(matches!(err, NotificationProxyError::ProxyDeath));
    }

    #[tokio::test]
    async fn recv_event_decodes_relay_notification() {
        let mut stream = MockStream::with_frames(vec![relay_frame("com.apple.example.ready")]);
        let mut client = NotificationProxyClient::new(&mut stream);

        let event = client.recv_event().await.unwrap();
        assert_eq!(
            event,
            NotificationEvent::Notification("com.apple.example.ready".into())
        );
    }

    #[tokio::test]
    async fn recv_event_decodes_proxy_death() {
        let mut stream = MockStream::with_frames(vec![command_frame("ProxyDeath")]);
        let mut client = NotificationProxyClient::new(&mut stream);

        let event = client.recv_event().await.unwrap();
        assert_eq!(event, NotificationEvent::ProxyDeath);
    }

    #[tokio::test]
    async fn recv_event_rejects_relay_without_name() {
        let mut stream = MockStream::with_frames(vec![command_frame("RelayNotification")]);
        let mut client = NotificationProxyClient::new(&mut stream);

        let err = client.recv_event().await.unwrap_err();
        assert!(matches!(err, NotificationProxyError::Protocol(_)));
    }

    #[tokio::test]
    async fn recv_event_rejects_unknown_command() {
        let mut stream = MockStream::with_frames(vec![command_frame("Bogus")]);
        let mut client = NotificationProxyClient::new(&mut stream);

        let err = client.recv_event().await.unwrap_err();
        assert!(matches!(err, NotificationProxyError::Protocol(_)));
    }

    #[tokio::test]
    async fn recv_event_rejects_oversized_frame() {
        let mut stream = MockStream {
            read_buf: u32::MAX.to_be_bytes().to_vec(),
            ..MockStream::default()
        };
        let mut client = NotificationProxyClient::new(&mut stream);

        let err = client.recv_event().await.unwrap_err();
        assert!(matches!(err, NotificationProxyError::Protocol(_)));
    }

    #[tokio::test]
    async fn wait_for_springboard_uses_expected_name() {
        let mut stream = MockStream::with_frames(vec![relay_frame(SPRINGBOARD_FINISHED_STARTUP)]);
        let mut client = NotificationProxyClient::new(&mut stream);

        client
            .wait_for_springboard(Duration::from_millis(100))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn next_event_returns_notification_name() {
        let mut stream = MockStream::with_frames(vec![relay_frame("com.apple.example.stream")]);
        let mut client = NotificationProxyClient::new(&mut stream);

        let event = client.next_event(Duration::from_millis(100)).await.unwrap();
        assert_eq!(
            event,
            NotificationProxyEvent::Notification("com.apple.example.stream".into())
        );
    }

    #[tokio::test]
    async fn next_event_maps_proxy_death() {
        let mut stream = MockStream::with_frames(vec![command_frame("ProxyDeath")]);
        let mut client = NotificationProxyClient::new(&mut stream);

        let event = client.next_event(Duration::from_millis(100)).await.unwrap();
        assert_eq!(event, NotificationProxyEvent::ProxyDeath);
    }

    #[tokio::test]
    async fn next_event_rejects_message_without_command() {
        let mut stream = MockStream::with_frames(vec![message_frame(&[])]);
        let mut client = NotificationProxyClient::new(&mut stream);

        let err = client
            .next_event(Duration::from_millis(100))
            .await
            .unwrap_err();
        assert!(matches!(err, NotificationProxyError::Protocol(_)));
    }

    #[tokio::test]
    async fn wait_for_times_out_when_no_notification_arrives() {
        let (client_side, _server_side) = tokio::io::duplex(1024);
        let mut client = NotificationProxyClient::new(client_side);

        let err = client
            .wait_for("com.apple.example.ready", Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, NotificationProxyError::Timeout));
    }
}