Skip to main content

ios_core/services/arbitration/
mod.rs

//! Device arbitration service client.
//!
//! Service: `com.apple.dt.devicearbitration`
//!
//! Requests and responses are XML plists, each framed by a 4-byte big-endian length prefix.
4
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6
/// Service identifier used to open a connection to the device arbitration service.
pub const SERVICE_NAME: &str = "com.apple.dt.devicearbitration";

/// Upper bound, in bytes, on an incoming plist frame (4 MiB).
///
/// Frames whose length prefix exceeds this are rejected before any buffer is
/// allocated, so a bogus or hostile prefix cannot trigger a huge allocation.
const MAX_PLIST_SIZE: usize = 4 << 20;
9
10#[derive(Debug, thiserror::Error)]
11pub enum ArbitrationError {
12    #[error("IO error: {0}")]
13    Io(#[from] std::io::Error),
14    #[error("plist error: {0}")]
15    Plist(String),
16    #[error("protocol error: {0}")]
17    Protocol(String),
18}
19
/// Client for the device arbitration service.
///
/// Wraps an already-established byte stream `S` (presumably a connection to
/// the service named by `SERVICE_NAME` — the caller is responsible for opening it).
/// Each request is written to the stream and its response read back before the
/// next call can proceed.
pub struct ArbitrationClient<S> {
    /// Transport over which length-prefixed plist messages are exchanged.
    stream: S,
}
23
24impl<S: AsyncRead + AsyncWrite + Unpin> ArbitrationClient<S> {
25    pub fn new(stream: S) -> Self {
26        Self { stream }
27    }
28
29    pub async fn version(&mut self) -> Result<plist::Dictionary, ArbitrationError> {
30        self.send_command(plist::Dictionary::from_iter([(
31            "command".to_string(),
32            plist::Value::String("version".into()),
33        )]))
34        .await
35    }
36
37    pub async fn check_in(&mut self, hostname: &str, force: bool) -> Result<(), ArbitrationError> {
38        let response = self
39            .send_command(plist::Dictionary::from_iter([
40                (
41                    "command".to_string(),
42                    plist::Value::String(if force { "force-check-in" } else { "check-in" }.into()),
43                ),
44                (
45                    "hostname".to_string(),
46                    plist::Value::String(hostname.to_string()),
47                ),
48            ]))
49            .await?;
50        ensure_success(&response)
51    }
52
53    pub async fn check_out(&mut self) -> Result<(), ArbitrationError> {
54        let response = self
55            .send_command(plist::Dictionary::from_iter([(
56                "command".to_string(),
57                plist::Value::String("check-out".into()),
58            )]))
59            .await?;
60        ensure_success(&response)
61    }
62
63    async fn send_command(
64        &mut self,
65        request: plist::Dictionary,
66    ) -> Result<plist::Dictionary, ArbitrationError> {
67        send_plist(&mut self.stream, &plist::Value::Dictionary(request)).await?;
68        recv_plist(&mut self.stream).await
69    }
70}
71
72fn ensure_success(response: &plist::Dictionary) -> Result<(), ArbitrationError> {
73    match response.get("result").and_then(plist::Value::as_string) {
74        Some("success") => Ok(()),
75        Some(other) => Err(ArbitrationError::Protocol(other.to_string())),
76        None => Err(ArbitrationError::Protocol(
77            "device arbitration response missing result".into(),
78        )),
79    }
80}
81
82async fn send_plist<S: AsyncWrite + Unpin>(
83    stream: &mut S,
84    value: &plist::Value,
85) -> Result<(), ArbitrationError> {
86    let mut buf = Vec::new();
87    plist::to_writer_xml(&mut buf, value).map_err(|e| ArbitrationError::Plist(e.to_string()))?;
88    stream.write_all(&(buf.len() as u32).to_be_bytes()).await?;
89    stream.write_all(&buf).await?;
90    stream.flush().await?;
91    Ok(())
92}
93
94async fn recv_plist<S: AsyncRead + Unpin>(
95    stream: &mut S,
96) -> Result<plist::Dictionary, ArbitrationError> {
97    let mut len_buf = [0u8; 4];
98    stream.read_exact(&mut len_buf).await?;
99    let len = u32::from_be_bytes(len_buf) as usize;
100    if len > MAX_PLIST_SIZE {
101        return Err(ArbitrationError::Protocol(format!(
102            "plist length {len} exceeds max {MAX_PLIST_SIZE}"
103        )));
104    }
105    let mut buf = vec![0u8; len];
106    stream.read_exact(&mut buf).await?;
107    plist::from_bytes(&buf).map_err(|e| ArbitrationError::Plist(e.to_string()))
108}
109
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::*;

    /// In-memory stream: reads are served from a canned buffer, writes are captured.
    struct MockStream {
        read_buf: Vec<u8>,
        written: Vec<u8>,
        read_pos: usize,
    }

    impl MockStream {
        /// Stream whose read side yields exactly `read_buf`, then fails with `UnexpectedEof`.
        fn with_bytes(read_buf: Vec<u8>) -> Self {
            Self {
                read_buf,
                written: Vec::new(),
                read_pos: 0,
            }
        }

        /// Stream whose read side yields a single length-prefixed XML plist frame.
        fn with_response(value: plist::Value) -> Self {
            let mut payload = Vec::new();
            plist::to_writer_xml(&mut payload, &value).unwrap();
            let len = u32::try_from(payload.len()).unwrap();
            let mut read_buf = len.to_be_bytes().to_vec();
            read_buf.extend_from_slice(&payload);
            Self::with_bytes(read_buf)
        }

        /// Decodes the request the client wrote, asserting it wrote exactly one frame.
        fn sent_request(&self) -> plist::Dictionary {
            let len = u32::from_be_bytes(self.written[..4].try_into().unwrap()) as usize;
            assert_eq!(
                self.written.len(),
                4 + len,
                "expected exactly one request frame"
            );
            plist::from_bytes(&self.written[4..]).unwrap()
        }
    }

    /// Response dictionary whose only entry is `result`.
    fn result_response(result: &str) -> plist::Value {
        plist::Value::Dictionary(plist::Dictionary::from_iter([(
            "result".to_string(),
            plist::Value::String(result.into()),
        )]))
    }

    impl AsyncRead for MockStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let remaining = self.read_buf.len().saturating_sub(self.read_pos);
            if remaining == 0 {
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "no more test data",
                )));
            }
            let to_copy = remaining.min(buf.remaining());
            let start = self.read_pos;
            let end = start + to_copy;
            buf.put_slice(&self.read_buf[start..end]);
            self.read_pos = end;
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for MockStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.written.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn check_in_sends_hostname() {
        let mut stream = MockStream::with_response(result_response("success"));
        let mut client = ArbitrationClient::new(&mut stream);

        client.check_in("host", false).await.unwrap();

        let dict = stream.sent_request();
        assert_eq!(dict["command"].as_string(), Some("check-in"));
        assert_eq!(dict["hostname"].as_string(), Some("host"));
    }

    #[tokio::test]
    async fn forced_check_in_uses_force_command() {
        let mut stream = MockStream::with_response(result_response("success"));
        let mut client = ArbitrationClient::new(&mut stream);

        client.check_in("host", true).await.unwrap();

        let dict = stream.sent_request();
        assert_eq!(dict["command"].as_string(), Some("force-check-in"));
        assert_eq!(dict["hostname"].as_string(), Some("host"));
    }

    #[tokio::test]
    async fn check_out_sends_command_without_hostname() {
        let mut stream = MockStream::with_response(result_response("success"));
        let mut client = ArbitrationClient::new(&mut stream);

        client.check_out().await.unwrap();

        let dict = stream.sent_request();
        assert_eq!(dict["command"].as_string(), Some("check-out"));
        assert!(!dict.contains_key("hostname"));
    }

    #[tokio::test]
    async fn version_returns_raw_response() {
        // No `result` key: `version` must not apply the success check.
        let response = plist::Value::Dictionary(plist::Dictionary::from_iter([(
            "version".to_string(),
            plist::Value::String("1.0".into()),
        )]));
        let mut stream = MockStream::with_response(response);
        let mut client = ArbitrationClient::new(&mut stream);

        let reply = client.version().await.unwrap();

        assert_eq!(
            reply.get("version").and_then(plist::Value::as_string),
            Some("1.0")
        );
        assert_eq!(stream.sent_request()["command"].as_string(), Some("version"));
    }

    #[tokio::test]
    async fn check_in_reports_non_success_result() {
        let mut stream = MockStream::with_response(result_response("failure"));
        let mut client = ArbitrationClient::new(&mut stream);

        let err = client.check_in("host", false).await.unwrap_err();

        assert!(matches!(err, ArbitrationError::Protocol(message) if message == "failure"));
    }

    #[tokio::test]
    async fn check_out_reports_missing_result() {
        let empty = plist::Value::Dictionary(plist::Dictionary::new());
        let mut stream = MockStream::with_response(empty);
        let mut client = ArbitrationClient::new(&mut stream);

        let err = client.check_out().await.unwrap_err();

        assert!(
            matches!(err, ArbitrationError::Protocol(message) if message.contains("missing result"))
        );
    }

    #[tokio::test]
    async fn recv_plist_rejects_oversized_frame() {
        let mut read_buf = ((MAX_PLIST_SIZE as u32) + 1).to_be_bytes().to_vec();
        read_buf.extend_from_slice(b"ignored");
        let mut stream = MockStream::with_bytes(read_buf);

        let err = recv_plist(&mut stream).await.unwrap_err();
        assert!(
            matches!(err, ArbitrationError::Protocol(message) if message.contains("exceeds max"))
        );
    }

    #[tokio::test]
    async fn recv_plist_reports_truncated_frame() {
        // The header promises 16 payload bytes but only 3 follow.
        let mut read_buf = 16u32.to_be_bytes().to_vec();
        read_buf.extend_from_slice(b"abc");
        let mut stream = MockStream::with_bytes(read_buf);

        let err = recv_plist(&mut stream).await.unwrap_err();
        assert!(matches!(err, ArbitrationError::Io(_)));
    }
}