// ios_core/services/file_relay/mod.rs

//! File relay service client.
//!
//! Service: `com.apple.mobile.file_relay`
4
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
6
/// Name of the file relay service this module's client talks to.
7pub const SERVICE_NAME: &str = "com.apple.mobile.file_relay";
8
// Invokes the `service_error!` macro (defined outside this file) with `FileRelayError`;
// presumably this declares the module's error type. Code in this module relies on that
// type having `Io`, `Plist` and `Protocol` variants.
9service_error!(FileRelayError);
10
11impl From<super::plist_frame::PlistFrameError> for FileRelayError {
12    fn from(error: super::plist_frame::PlistFrameError) -> Self {
13        match error {
14            super::plist_frame::PlistFrameError::Io(error) => Self::Io(error),
15            super::plist_frame::PlistFrameError::Plist(error) => Self::Plist(error),
16            super::plist_frame::PlistFrameError::Protocol(message) => Self::Protocol(message),
17        }
18    }
19}
20
/// Client for the `com.apple.mobile.file_relay` service, generic over the
/// stream it communicates on.
///
/// `Debug` is available whenever the underlying stream type is `Debug`.
#[derive(Debug)]
pub struct FileRelayClient<S> {
    /// Stream used both for the plist request/response exchange and for
    /// reading the payload that follows an acknowledged request.
    stream: S,
}
24
25impl<S: AsyncRead + AsyncWrite + Unpin> FileRelayClient<S> {
26    pub fn new(stream: S) -> Self {
27        Self { stream }
28    }
29
30    pub async fn request_sources(&mut self, sources: &[&str]) -> Result<Vec<u8>, FileRelayError> {
31        self.send_request_sources(sources).await?;
32        let mut data = Vec::new();
33        self.stream.read_to_end(&mut data).await?;
34        Ok(data)
35    }
36
37    pub async fn request_sources_to_writer<W>(
38        &mut self,
39        sources: &[&str],
40        writer: &mut W,
41    ) -> Result<u64, FileRelayError>
42    where
43        W: AsyncWrite + Unpin,
44    {
45        self.send_request_sources(sources).await?;
46        tokio::io::copy(&mut self.stream, writer)
47            .await
48            .map_err(FileRelayError::from)
49    }
50
51    async fn send_request_sources(&mut self, sources: &[&str]) -> Result<(), FileRelayError> {
52        let request = plist::Dictionary::from_iter([(
53            "Sources".to_string(),
54            plist::Value::Array(
55                sources
56                    .iter()
57                    .map(|source| plist::Value::String((*source).to_string()))
58                    .collect(),
59            ),
60        )]);
61        send_plist(&mut self.stream, &plist::Value::Dictionary(request)).await?;
62        let response = recv_plist(&mut self.stream).await?;
63        match response.get("Status").and_then(plist::Value::as_string) {
64            Some("Acknowledged") => {}
65            Some(other) => {
66                let error = response
67                    .get("Error")
68                    .and_then(plist::Value::as_string)
69                    .unwrap_or(other);
70                return Err(FileRelayError::Protocol(error.to_string()));
71            }
72            None => {
73                return Err(FileRelayError::Protocol(
74                    "file relay response missing Status".into(),
75                ));
76            }
77        }
78        Ok(())
79    }
80}
81
82async fn send_plist<S: AsyncWrite + Unpin>(
83    stream: &mut S,
84    value: &plist::Value,
85) -> Result<(), FileRelayError> {
86    const MAX_PLIST_SIZE: usize = 1024 * 1024;
87    super::plist_frame::write_xml_plist_frame(stream, value, MAX_PLIST_SIZE)
88        .await
89        .map_err(FileRelayError::from)
90}
91
92async fn recv_plist<S: AsyncRead + Unpin>(
93    stream: &mut S,
94) -> Result<plist::Dictionary, FileRelayError> {
95    const MAX_PLIST_SIZE: usize = 1024 * 1024;
96    super::plist_frame::read_plist_frame(stream, MAX_PLIST_SIZE)
97        .await
98        .map_err(FileRelayError::from)
99}
100
#[cfg(test)]
mod tests {
    use crate::test_util::MockStream;

    use super::*;

    /// Builds the `{ "Status": "Acknowledged" }` reply used by both tests.
    fn acknowledged() -> plist::Value {
        let mut reply = plist::Dictionary::new();
        reply.insert(
            "Status".to_string(),
            plist::Value::String("Acknowledged".into()),
        );
        plist::Value::Dictionary(reply)
    }

    #[tokio::test]
    async fn request_sources_reads_acknowledged_archive() {
        let mut stream =
            MockStream::with_plist_response_and_trailing_bytes(acknowledged(), b"archive-bytes");

        let archive = FileRelayClient::new(&mut stream)
            .request_sources(&["Network"])
            .await
            .unwrap();
        assert_eq!(archive, b"archive-bytes");

        // The written request is a 4-byte big-endian length followed by the plist payload.
        let header: [u8; 4] = stream.written[..4].try_into().unwrap();
        let len = u32::from_be_bytes(header) as usize;
        let request: plist::Dictionary = plist::from_bytes(&stream.written[4..4 + len]).unwrap();
        let requested = request
            .get("Sources")
            .and_then(plist::Value::as_array)
            .unwrap();
        assert_eq!(requested[0].as_string(), Some("Network"));
    }

    #[tokio::test]
    async fn request_sources_to_writer_streams_acknowledged_archive() {
        let mut stream =
            MockStream::with_plist_response_and_trailing_bytes(acknowledged(), b"archive-bytes");
        let mut output = Vec::new();

        let bytes = FileRelayClient::new(&mut stream)
            .request_sources_to_writer(&["Network"], &mut output)
            .await
            .unwrap();

        assert_eq!(bytes, 13);
        assert_eq!(output, b"archive-bytes");
    }
}