Skip to main content

ios_core/services/file_relay/
mod.rs

1//! File relay service client.
2//!
3//! Service: `com.apple.mobile.file_relay`
4
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6
/// Name of the device service this client talks to.
pub const SERVICE_NAME: &str = "com.apple.mobile.file_relay";
8
9#[derive(Debug, thiserror::Error)]
10pub enum FileRelayError {
11    #[error("IO error: {0}")]
12    Io(#[from] std::io::Error),
13    #[error("plist error: {0}")]
14    Plist(String),
15    #[error("protocol error: {0}")]
16    Protocol(String),
17}
18
/// Client for the file relay service over an arbitrary byte stream `S`.
///
/// Request methods are available when `S: AsyncRead + AsyncWrite + Unpin`.
/// `Debug` is derived so the client can be logged whenever `S` is `Debug`.
#[derive(Debug)]
pub struct FileRelayClient<S> {
    /// Transport carrying length-prefixed plist requests and the raw bytes the
    /// device sends after acknowledging a request.
    stream: S,
}
22
23impl<S: AsyncRead + AsyncWrite + Unpin> FileRelayClient<S> {
24    pub fn new(stream: S) -> Self {
25        Self { stream }
26    }
27
28    pub async fn request_sources(&mut self, sources: &[&str]) -> Result<Vec<u8>, FileRelayError> {
29        let request = plist::Dictionary::from_iter([(
30            "Sources".to_string(),
31            plist::Value::Array(
32                sources
33                    .iter()
34                    .map(|source| plist::Value::String((*source).to_string()))
35                    .collect(),
36            ),
37        )]);
38        send_plist(&mut self.stream, &plist::Value::Dictionary(request)).await?;
39        let response = recv_plist(&mut self.stream).await?;
40        match response.get("Status").and_then(plist::Value::as_string) {
41            Some("Acknowledged") => {}
42            Some(other) => {
43                let error = response
44                    .get("Error")
45                    .and_then(plist::Value::as_string)
46                    .unwrap_or(other);
47                return Err(FileRelayError::Protocol(error.to_string()));
48            }
49            None => {
50                return Err(FileRelayError::Protocol(
51                    "file relay response missing Status".into(),
52                ));
53            }
54        }
55
56        let mut data = Vec::new();
57        self.stream.read_to_end(&mut data).await?;
58        Ok(data)
59    }
60}
61
62async fn send_plist<S: AsyncWrite + Unpin>(
63    stream: &mut S,
64    value: &plist::Value,
65) -> Result<(), FileRelayError> {
66    let mut buf = Vec::new();
67    plist::to_writer_xml(&mut buf, value).map_err(|e| FileRelayError::Plist(e.to_string()))?;
68    stream.write_all(&(buf.len() as u32).to_be_bytes()).await?;
69    stream.write_all(&buf).await?;
70    stream.flush().await?;
71    Ok(())
72}
73
74async fn recv_plist<S: AsyncRead + Unpin>(
75    stream: &mut S,
76) -> Result<plist::Dictionary, FileRelayError> {
77    let mut len_buf = [0u8; 4];
78    stream.read_exact(&mut len_buf).await?;
79    let len = u32::from_be_bytes(len_buf) as usize;
80    const MAX_PLIST_SIZE: usize = 1024 * 1024;
81    if len > MAX_PLIST_SIZE {
82        return Err(FileRelayError::Protocol(format!(
83            "plist length {len} exceeds max {MAX_PLIST_SIZE}"
84        )));
85    }
86    let mut buf = vec![0u8; len];
87    stream.read_exact(&mut buf).await?;
88    plist::from_bytes(&buf).map_err(|e| FileRelayError::Plist(e.to_string()))
89}
90
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::*;

    /// In-memory stream. Reads are served from a canned buffer and then hit
    /// EOF; every write is captured in `written`.
    struct MockStream {
        read_buf: Vec<u8>,
        written: Vec<u8>,
        read_pos: usize,
    }

    impl MockStream {
        /// Queues `plist_value` as one length-prefixed XML plist frame,
        /// followed by the `raw` bytes.
        fn with_response(plist_value: plist::Value, raw: &[u8]) -> Self {
            let mut payload = Vec::new();
            plist::to_writer_xml(&mut payload, &plist_value).unwrap();
            let mut read_buf = Vec::new();
            read_buf.extend_from_slice(&(payload.len() as u32).to_be_bytes());
            read_buf.extend_from_slice(&payload);
            read_buf.extend_from_slice(raw);
            Self {
                read_buf,
                written: Vec::new(),
                read_pos: 0,
            }
        }
    }

    impl AsyncRead for MockStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let remaining = self.read_buf.len().saturating_sub(self.read_pos);
            if remaining == 0 {
                return Poll::Ready(Ok(()));
            }
            let to_copy = remaining.min(buf.remaining());
            let start = self.read_pos;
            let end = start + to_copy;
            buf.put_slice(&self.read_buf[start..end]);
            self.read_pos = end;
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for MockStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.written.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Builds a reply dictionary whose values are all strings.
    fn string_dict(entries: &[(&str, &str)]) -> plist::Value {
        plist::Value::Dictionary(
            entries
                .iter()
                .map(|&(key, value)| (key.to_string(), plist::Value::String(value.to_string())))
                .collect(),
        )
    }

    /// Runs `request_sources(&["Network"])` against a canned `response` and
    /// returns the protocol error message, panicking on any other outcome.
    async fn protocol_error_for(response: plist::Value) -> String {
        let mut stream = MockStream::with_response(response, b"");
        let mut client = FileRelayClient::new(&mut stream);
        match client.request_sources(&["Network"]).await {
            Err(FileRelayError::Protocol(message)) => message,
            other => panic!("expected protocol error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn request_sources_reads_acknowledged_archive() {
        let response = string_dict(&[("Status", "Acknowledged")]);
        let mut stream = MockStream::with_response(response, b"archive-bytes");
        let mut client = FileRelayClient::new(&mut stream);

        let archive = client.request_sources(&["Network"]).await.unwrap();
        assert_eq!(archive, b"archive-bytes");

        let len = u32::from_be_bytes(stream.written[..4].try_into().unwrap()) as usize;
        // Exactly one request frame must have been written.
        assert_eq!(stream.written.len(), 4 + len);
        let payload = &stream.written[4..];
        let dict: plist::Dictionary = plist::from_bytes(payload).unwrap();
        let sources = dict["Sources"].as_array().unwrap();
        assert_eq!(sources.len(), 1);
        assert_eq!(sources[0].as_string(), Some("Network"));
    }

    #[tokio::test]
    async fn request_sources_prefers_device_error_text() {
        let response = string_dict(&[("Status", "Failed"), ("Error", "StagingEmpty")]);
        assert_eq!(protocol_error_for(response).await, "StagingEmpty");
    }

    #[tokio::test]
    async fn request_sources_falls_back_to_status_text() {
        let response = string_dict(&[("Status", "Denied")]);
        assert_eq!(protocol_error_for(response).await, "Denied");
    }

    #[tokio::test]
    async fn request_sources_rejects_reply_without_status() {
        assert_eq!(
            protocol_error_for(string_dict(&[])).await,
            "file relay response missing Status"
        );
    }

    #[tokio::test]
    async fn recv_plist_rejects_oversized_frame() {
        let mut stream = MockStream {
            read_buf: 2_000_000u32.to_be_bytes().to_vec(),
            written: Vec::new(),
            read_pos: 0,
        };
        let err = recv_plist(&mut stream).await.unwrap_err();
        assert!(matches!(err, FileRelayError::Protocol(_)), "{err:?}");
    }
}