Skip to main content

ios_core/services/
device_link.rs

1use serde::Serialize;
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3
// Declares `DeviceLinkError` via the project's `service_error!` macro. The code in this
// file constructs its `Protocol(String)` and `Plist(String)` variants and relies on `?`
// converting `std::io::Error` into it, so the macro must provide those.
service_error!(DeviceLinkError);
5
/// Client for the DeviceLink (`DLMessage*`) message layer over a byte stream `S`.
///
/// Messages are plists framed with a 4-byte big-endian length prefix. The protocol
/// methods are available when `S` implements tokio's `AsyncRead + AsyncWrite + Unpin`.
#[derive(Debug)]
pub struct DeviceLinkClient<S> {
    /// Underlying transport carrying the length-prefixed plist frames.
    stream: S,
}
9
10impl<S> DeviceLinkClient<S> {
11    pub fn new(stream: S) -> Self {
12        Self { stream }
13    }
14
15    pub fn into_inner(self) -> S {
16        self.stream
17    }
18
19    pub fn stream_mut(&mut self) -> &mut S {
20        &mut self.stream
21    }
22}
23
24impl<S> DeviceLinkClient<S>
25where
26    S: AsyncRead + AsyncWrite + Unpin,
27{
28    pub async fn version_exchange(&mut self) -> Result<u64, DeviceLinkError> {
29        let response = self.recv_message().await?;
30        let message = response.as_array().ok_or_else(|| {
31            DeviceLinkError::Protocol(format!(
32                "device link version exchange expected array, got {response:?}"
33            ))
34        })?;
35
36        let message_type = message
37            .first()
38            .and_then(plist::Value::as_string)
39            .ok_or_else(|| {
40                DeviceLinkError::Protocol(format!(
41                    "device link version exchange missing message type: {response:?}"
42                ))
43            })?;
44        if message_type != "DLMessageVersionExchange" {
45            return Err(DeviceLinkError::Protocol(format!(
46                "expected DLMessageVersionExchange, got {message_type}"
47            )));
48        }
49
50        let version = message
51            .get(1)
52            .and_then(|value| match value {
53                plist::Value::Integer(value) => value.as_unsigned(),
54                _ => None,
55            })
56            .ok_or_else(|| {
57                DeviceLinkError::Protocol(format!(
58                    "device link version exchange missing major version: {response:?}"
59                ))
60            })?;
61
62        self.send_message(&vec![
63            plist::Value::String("DLMessageVersionExchange".into()),
64            plist::Value::String("DLVersionsOk".into()),
65            plist::Value::Integer(version.into()),
66        ])
67        .await?;
68
69        let ready = self.recv_message().await?;
70        let ready_message = ready.as_array().ok_or_else(|| {
71            DeviceLinkError::Protocol(format!("device ready expected array, got {ready:?}"))
72        })?;
73        let ready_type = ready_message
74            .first()
75            .and_then(plist::Value::as_string)
76            .ok_or_else(|| {
77                DeviceLinkError::Protocol(format!("device ready missing message type: {ready:?}"))
78            })?;
79        if ready_type != "DLMessageDeviceReady" {
80            return Err(DeviceLinkError::Protocol(format!(
81                "expected DLMessageDeviceReady, got {ready_type}"
82            )));
83        }
84
85        Ok(version)
86    }
87
88    pub async fn send_process_message<T>(&mut self, message: &T) -> Result<(), DeviceLinkError>
89    where
90        T: Serialize,
91    {
92        self.send_message(&("DLMessageProcessMessage", message))
93            .await
94    }
95
96    pub async fn recv_process_message(&mut self) -> Result<plist::Dictionary, DeviceLinkError> {
97        let response = self.recv_message().await?;
98        let message = response.as_array().ok_or_else(|| {
99            DeviceLinkError::Protocol(format!("process message expected array, got {response:?}"))
100        })?;
101
102        let message_type = message
103            .first()
104            .and_then(plist::Value::as_string)
105            .ok_or_else(|| {
106                DeviceLinkError::Protocol(format!(
107                    "process message missing message type: {response:?}"
108                ))
109            })?;
110        if message_type != "DLMessageProcessMessage" {
111            return Err(DeviceLinkError::Protocol(format!(
112                "expected DLMessageProcessMessage, got {message_type}"
113            )));
114        }
115
116        message
117            .get(1)
118            .and_then(plist::Value::as_dictionary)
119            .cloned()
120            .ok_or_else(|| {
121                DeviceLinkError::Protocol(format!(
122                    "process message missing dictionary payload: {response:?}"
123                ))
124            })
125    }
126
127    pub async fn send_message<T>(&mut self, message: &T) -> Result<(), DeviceLinkError>
128    where
129        T: Serialize,
130    {
131        let mut payload = Vec::new();
132        plist::to_writer_xml(&mut payload, message)
133            .map_err(|e| DeviceLinkError::Plist(e.to_string()))?;
134        self.stream
135            .write_all(&(payload.len() as u32).to_be_bytes())
136            .await?;
137        self.stream.write_all(&payload).await?;
138        self.stream.flush().await?;
139        Ok(())
140    }
141
142    pub async fn recv_message(&mut self) -> Result<plist::Value, DeviceLinkError> {
143        let mut len_buf = [0u8; 4];
144        self.stream.read_exact(&mut len_buf).await?;
145        let len = u32::from_be_bytes(len_buf) as usize;
146        const MAX_PLIST_SIZE: usize = 4 * 1024 * 1024;
147        if len > MAX_PLIST_SIZE {
148            return Err(DeviceLinkError::Protocol(format!(
149                "plist length {len} exceeds maximum of {MAX_PLIST_SIZE}"
150            )));
151        }
152
153        let mut payload = vec![0u8; len];
154        self.stream.read_exact(&mut payload).await?;
155        plist::from_bytes(&payload).map_err(|e| DeviceLinkError::Plist(e.to_string()))
156    }
157
158    pub async fn disconnect(&mut self) -> Result<(), DeviceLinkError> {
159        self.send_message(&vec![
160            plist::Value::String("DLMessageDisconnect".into()),
161            plist::Value::String("___EmptyParameterString___".into()),
162        ])
163        .await
164    }
165}
166
#[cfg(test)]
mod tests {
    use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// Encodes `value` as a length-prefixed XML plist frame (4-byte big-endian length).
    fn encode_frame(value: &plist::Value) -> Vec<u8> {
        let mut payload = Vec::new();
        plist::to_writer_xml(&mut payload, value).expect("plist serialization");
        let mut frame = Vec::with_capacity(payload.len() + 4);
        frame.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        frame.extend_from_slice(&payload);
        frame
    }

    /// Reads and decodes one length-prefixed plist frame from the server side of the pipe.
    async fn read_frame(stream: &mut tokio::io::DuplexStream) -> plist::Value {
        let mut len_buf = [0u8; 4];
        stream.read_exact(&mut len_buf).await.expect("frame length");
        let len = u32::from_be_bytes(len_buf) as usize;
        let mut payload = vec![0u8; len];
        stream
            .read_exact(&mut payload)
            .await
            .expect("frame payload");
        plist::from_bytes(&payload).expect("plist decode")
    }

    #[tokio::test]
    async fn version_exchange_sends_versions_ok_and_returns_major_version() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.version_exchange().await.unwrap()
        });

        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageVersionExchange".into()),
                plist::Value::Integer(300u64.into()),
            ])))
            .await
            .unwrap();

        let versions_ok = read_frame(&mut server_stream).await;
        assert_eq!(
            versions_ok.as_array(),
            Some(&vec![
                plist::Value::String("DLMessageVersionExchange".into()),
                plist::Value::String("DLVersionsOk".into()),
                plist::Value::Integer(300u64.into()),
            ])
        );

        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageDeviceReady".into()),
            ])))
            .await
            .unwrap();

        assert_eq!(task.await.unwrap(), 300);
    }

    #[tokio::test]
    async fn version_exchange_rejects_unexpected_message_type() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.version_exchange().await
        });

        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageProcessMessage".into()),
            ])))
            .await
            .unwrap();

        let err = task
            .await
            .unwrap()
            .expect_err("wrong message type must fail");
        assert!(err
            .to_string()
            .contains("expected DLMessageVersionExchange, got DLMessageProcessMessage"));
    }

    #[tokio::test]
    async fn recv_process_message_returns_dictionary_payload() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.recv_process_message().await.unwrap()
        });

        let mut payload = plist::Dictionary::new();
        payload.insert(
            "MessageName".into(),
            plist::Value::String("Response".into()),
        );
        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageProcessMessage".into()),
                plist::Value::Dictionary(payload.clone()),
            ])))
            .await
            .unwrap();

        assert_eq!(task.await.unwrap(), payload);
    }

    #[tokio::test]
    async fn recv_process_message_requires_dictionary_payload() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.recv_process_message().await
        });

        server_stream
            .write_all(&encode_frame(&plist::Value::Array(vec![
                plist::Value::String("DLMessageProcessMessage".into()),
                plist::Value::String("not-a-dict".into()),
            ])))
            .await
            .unwrap();

        let err = task
            .await
            .unwrap()
            .expect_err("non-dictionary payload must fail");
        assert!(err
            .to_string()
            .contains("process message missing dictionary payload"));
    }

    #[tokio::test]
    async fn send_process_message_wraps_payload() {
        let (client_stream, mut server_stream) = duplex(4096);
        let mut payload = plist::Dictionary::new();
        payload.insert("MessageName".into(), plist::Value::String("Hello".into()));
        let message = plist::Value::Dictionary(payload);
        let expected = message.clone();
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.send_process_message(&message).await.unwrap();
        });

        let frame = read_frame(&mut server_stream).await;
        assert_eq!(
            frame.as_array(),
            Some(&vec![
                plist::Value::String("DLMessageProcessMessage".into()),
                expected,
            ])
        );

        task.await.unwrap();
    }

    #[tokio::test]
    async fn recv_message_rejects_oversized_frame() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.recv_message().await
        });

        // One byte over the 4 MiB limit; no payload follows because it must not be read.
        let oversized: u32 = 4 * 1024 * 1024 + 1;
        server_stream
            .write_all(&oversized.to_be_bytes())
            .await
            .unwrap();

        let err = task
            .await
            .unwrap()
            .expect_err("oversized frame must be rejected");
        assert!(err.to_string().contains("exceeds maximum"));
    }

    #[tokio::test]
    async fn disconnect_sends_expected_message() {
        let (client_stream, mut server_stream) = duplex(4096);
        let task = tokio::spawn(async move {
            let mut client = DeviceLinkClient::new(client_stream);
            client.disconnect().await.unwrap();
        });

        let disconnect = read_frame(&mut server_stream).await;
        assert_eq!(
            disconnect.as_array(),
            Some(&vec![
                plist::Value::String("DLMessageDisconnect".into()),
                plist::Value::String("___EmptyParameterString___".into()),
            ])
        );

        task.await.unwrap();
    }
}