Skip to main content

ios_core/services/
device_link.rs

1use serde::Serialize;
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3
4#[derive(Debug, thiserror::Error)]
5pub enum DeviceLinkError {
6    #[error("IO error: {0}")]
7    Io(#[from] std::io::Error),
8    #[error("plist error: {0}")]
9    Plist(String),
10    #[error("protocol error: {0}")]
11    Protocol(String),
12}
13
/// Client for the DeviceLink message protocol over a byte stream `S`.
///
/// Messages are exchanged as plists, each framed by a 4-byte big-endian length
/// prefix.
#[derive(Debug)]
pub struct DeviceLinkClient<S> {
    /// Transport carrying the length-prefixed plist frames.
    stream: S,
}
17
18impl<S> DeviceLinkClient<S> {
19    pub fn new(stream: S) -> Self {
20        Self { stream }
21    }
22
23    pub fn into_inner(self) -> S {
24        self.stream
25    }
26
27    pub fn stream_mut(&mut self) -> &mut S {
28        &mut self.stream
29    }
30}
31
32impl<S> DeviceLinkClient<S>
33where
34    S: AsyncRead + AsyncWrite + Unpin,
35{
36    pub async fn version_exchange(&mut self) -> Result<u64, DeviceLinkError> {
37        let response = self.recv_message().await?;
38        let message = response.as_array().ok_or_else(|| {
39            DeviceLinkError::Protocol(format!(
40                "device link version exchange expected array, got {response:?}"
41            ))
42        })?;
43
44        let message_type = message
45            .first()
46            .and_then(plist::Value::as_string)
47            .ok_or_else(|| {
48                DeviceLinkError::Protocol(format!(
49                    "device link version exchange missing message type: {response:?}"
50                ))
51            })?;
52        if message_type != "DLMessageVersionExchange" {
53            return Err(DeviceLinkError::Protocol(format!(
54                "expected DLMessageVersionExchange, got {message_type}"
55            )));
56        }
57
58        let version = message
59            .get(1)
60            .and_then(|value| match value {
61                plist::Value::Integer(value) => value.as_unsigned(),
62                _ => None,
63            })
64            .ok_or_else(|| {
65                DeviceLinkError::Protocol(format!(
66                    "device link version exchange missing major version: {response:?}"
67                ))
68            })?;
69
70        self.send_message(&vec![
71            plist::Value::String("DLMessageVersionExchange".into()),
72            plist::Value::String("DLVersionsOk".into()),
73            plist::Value::Integer(version.into()),
74        ])
75        .await?;
76
77        let ready = self.recv_message().await?;
78        let ready_message = ready.as_array().ok_or_else(|| {
79            DeviceLinkError::Protocol(format!("device ready expected array, got {ready:?}"))
80        })?;
81        let ready_type = ready_message
82            .first()
83            .and_then(plist::Value::as_string)
84            .ok_or_else(|| {
85                DeviceLinkError::Protocol(format!("device ready missing message type: {ready:?}"))
86            })?;
87        if ready_type != "DLMessageDeviceReady" {
88            return Err(DeviceLinkError::Protocol(format!(
89                "expected DLMessageDeviceReady, got {ready_type}"
90            )));
91        }
92
93        Ok(version)
94    }
95
96    pub async fn send_process_message<T>(&mut self, message: &T) -> Result<(), DeviceLinkError>
97    where
98        T: Serialize,
99    {
100        self.send_message(&("DLMessageProcessMessage", message))
101            .await
102    }
103
104    pub async fn recv_process_message(&mut self) -> Result<plist::Dictionary, DeviceLinkError> {
105        let response = self.recv_message().await?;
106        let message = response.as_array().ok_or_else(|| {
107            DeviceLinkError::Protocol(format!("process message expected array, got {response:?}"))
108        })?;
109
110        let message_type = message
111            .first()
112            .and_then(plist::Value::as_string)
113            .ok_or_else(|| {
114                DeviceLinkError::Protocol(format!(
115                    "process message missing message type: {response:?}"
116                ))
117            })?;
118        if message_type != "DLMessageProcessMessage" {
119            return Err(DeviceLinkError::Protocol(format!(
120                "expected DLMessageProcessMessage, got {message_type}"
121            )));
122        }
123
124        message
125            .get(1)
126            .and_then(plist::Value::as_dictionary)
127            .cloned()
128            .ok_or_else(|| {
129                DeviceLinkError::Protocol(format!(
130                    "process message missing dictionary payload: {response:?}"
131                ))
132            })
133    }
134
135    pub async fn send_message<T>(&mut self, message: &T) -> Result<(), DeviceLinkError>
136    where
137        T: Serialize,
138    {
139        let mut payload = Vec::new();
140        plist::to_writer_xml(&mut payload, message)
141            .map_err(|e| DeviceLinkError::Plist(e.to_string()))?;
142        self.stream
143            .write_all(&(payload.len() as u32).to_be_bytes())
144            .await?;
145        self.stream.write_all(&payload).await?;
146        self.stream.flush().await?;
147        Ok(())
148    }
149
150    pub async fn recv_message(&mut self) -> Result<plist::Value, DeviceLinkError> {
151        let mut len_buf = [0u8; 4];
152        self.stream.read_exact(&mut len_buf).await?;
153        let len = u32::from_be_bytes(len_buf) as usize;
154        const MAX_PLIST_SIZE: usize = 4 * 1024 * 1024;
155        if len > MAX_PLIST_SIZE {
156            return Err(DeviceLinkError::Protocol(format!(
157                "plist length {len} exceeds maximum of {MAX_PLIST_SIZE}"
158            )));
159        }
160
161        let mut payload = vec![0u8; len];
162        self.stream.read_exact(&mut payload).await?;
163        plist::from_bytes(&payload).map_err(|e| DeviceLinkError::Plist(e.to_string()))
164    }
165
166    pub async fn disconnect(&mut self) -> Result<(), DeviceLinkError> {
167        self.send_message(&vec![
168            plist::Value::String("DLMessageDisconnect".into()),
169            plist::Value::String("___EmptyParameterString___".into()),
170        ])
171        .await
172    }
173}
174
#[cfg(test)]
mod tests {
    use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// Builds a length-prefixed XML plist frame, as the peer would send it.
    fn encode_frame(value: &plist::Value) -> Vec<u8> {
        let mut payload = Vec::new();
        plist::to_writer_xml(&mut payload, value).expect("plist serialization");
        let mut frame = Vec::with_capacity(payload.len() + 4);
        frame.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        frame.extend_from_slice(&payload);
        frame
    }

    /// Reads one length-prefixed plist frame written by the client.
    async fn read_frame(stream: &mut tokio::io::DuplexStream) -> plist::Value {
        let mut len_buf = [0u8; 4];
        stream.read_exact(&mut len_buf).await.expect("frame length");
        let len = u32::from_be_bytes(len_buf) as usize;
        let mut payload = vec![0u8; len];
        stream
            .read_exact(&mut payload)
            .await
            .expect("frame payload");
        plist::from_bytes(&payload).expect("plist decode")
    }

    /// A small dictionary used as a process-message payload.
    fn sample_payload() -> plist::Dictionary {
        let mut payload = plist::Dictionary::new();
        payload.insert(
            "MessageName".to_owned(),
            plist::Value::String("Hello".to_owned()),
        );
        payload
    }

    #[tokio::test]
    async fn version_exchange_sends_versions_ok_and_returns_major_version() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.version_exchange().await.unwrap()
        });

        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageVersionExchange".into()),
                plist::Value::Integer(300u64.into()),
            ])))
            .await
            .unwrap();

        let versions_ok = read_frame(&mut server_stream).await;
        assert_eq!(
            versions_ok.as_array(),
            Some(&vec![
                plist::Value::String("DLMessageVersionExchange".into()),
                plist::Value::String("DLVersionsOk".into()),
                plist::Value::Integer(300u64.into()),
            ])
        );

        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageDeviceReady".into()),
            ])))
            .await
            .unwrap();

        assert_eq!(task.await.unwrap(), 300);
    }

    #[tokio::test]
    async fn version_exchange_rejects_unexpected_message_type() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.version_exchange().await
        });

        // Skipping straight to "device ready" must be rejected.
        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageDeviceReady".into()),
            ])))
            .await
            .unwrap();

        let err = task
            .await
            .unwrap()
            .expect_err("wrong message type must fail");
        assert!(err
            .to_string()
            .contains("expected DLMessageVersionExchange, got DLMessageDeviceReady"));
    }

    #[tokio::test]
    async fn recv_message_rejects_oversized_frame() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.recv_message().await
        });

        // Only the length prefix is sent; the client must bail out before reading more.
        server_stream
            .write_all(&u32::MAX.to_be_bytes())
            .await
            .unwrap();

        let err = task
            .await
            .unwrap()
            .expect_err("oversized frame must fail");
        assert!(err.to_string().contains("exceeds maximum"));
    }

    #[tokio::test]
    async fn send_process_message_wraps_payload() {
        let (client_stream, mut server_stream) = duplex(4096);
        let payload = sample_payload();
        let expected = payload.clone();
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.send_process_message(&payload).await.unwrap();
        });

        let frame = read_frame(&mut server_stream).await;
        assert_eq!(
            frame.as_array(),
            Some(&vec![
                plist::Value::String("DLMessageProcessMessage".into()),
                plist::Value::Dictionary(expected),
            ])
        );

        task.await.unwrap();
    }

    #[tokio::test]
    async fn recv_process_message_returns_dictionary_payload() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.recv_process_message().await.unwrap()
        });

        let payload = sample_payload();
        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageProcessMessage".into()),
                plist::Value::Dictionary(payload.clone()),
            ])))
            .await
            .unwrap();

        assert_eq!(task.await.unwrap(), payload);
    }

    #[tokio::test]
    async fn recv_process_message_requires_dictionary_payload() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.recv_process_message().await
        });

        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageProcessMessage".into()),
                plist::Value::String("not-a-dict".into()),
            ])))
            .await
            .unwrap();

        let err = task
            .await
            .unwrap()
            .expect_err("non-dictionary payload must fail");
        assert!(err
            .to_string()
            .contains("process message missing dictionary payload"));
    }

    #[tokio::test]
    async fn disconnect_sends_expected_message() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.disconnect().await.unwrap();
        });

        let disconnect = read_frame(&mut server_stream).await;
        assert_eq!(
            disconnect.as_array(),
            Some(&vec![
                plist::Value::String("DLMessageDisconnect".into()),
                plist::Value::String("___EmptyParameterString___".into()),
            ])
        );

        task.await.unwrap();
    }
}