// ios_core/services/arbitration/mod.rs

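//! Device arbitration service client.
//!
//! Service: `com.apple.dt.devicearbitration`
//!
//! Supports the `version`, `check-in` / `force-check-in` and `check-out` commands,
//! each exchanged as a plist message framed by a 4-byte big-endian length prefix.
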
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

/// Name of the device arbitration service that [`ArbitrationClient`] talks to.
pub const SERVICE_NAME: &str = "com.apple.dt.devicearbitration";

9#[derive(Debug, thiserror::Error)]
10pub enum ArbitrationError {
11    #[error("IO error: {0}")]
12    Io(#[from] std::io::Error),
13    #[error("plist error: {0}")]
14    Plist(String),
15    #[error("protocol error: {0}")]
16    Protocol(String),
17}
18
/// Client for the device arbitration service, generic over the byte stream `S`
/// that carries its length-prefixed plist messages.
pub struct ArbitrationClient<S> {
    /// Transport for every request and reply frame.
    stream: S,
}

23impl<S: AsyncRead + AsyncWrite + Unpin> ArbitrationClient<S> {
24    pub fn new(stream: S) -> Self {
25        Self { stream }
26    }
27
28    pub async fn version(&mut self) -> Result<plist::Dictionary, ArbitrationError> {
29        self.send_command(plist::Dictionary::from_iter([(
30            "command".to_string(),
31            plist::Value::String("version".into()),
32        )]))
33        .await
34    }
35
36    pub async fn check_in(&mut self, hostname: &str, force: bool) -> Result<(), ArbitrationError> {
37        let response = self
38            .send_command(plist::Dictionary::from_iter([
39                (
40                    "command".to_string(),
41                    plist::Value::String(if force { "force-check-in" } else { "check-in" }.into()),
42                ),
43                (
44                    "hostname".to_string(),
45                    plist::Value::String(hostname.to_string()),
46                ),
47            ]))
48            .await?;
49        ensure_success(&response)
50    }
51
52    pub async fn check_out(&mut self) -> Result<(), ArbitrationError> {
53        let response = self
54            .send_command(plist::Dictionary::from_iter([(
55                "command".to_string(),
56                plist::Value::String("check-out".into()),
57            )]))
58            .await?;
59        ensure_success(&response)
60    }
61
62    async fn send_command(
63        &mut self,
64        request: plist::Dictionary,
65    ) -> Result<plist::Dictionary, ArbitrationError> {
66        send_plist(&mut self.stream, &plist::Value::Dictionary(request)).await?;
67        recv_plist(&mut self.stream).await
68    }
69}
70
71fn ensure_success(response: &plist::Dictionary) -> Result<(), ArbitrationError> {
72    match response.get("result").and_then(plist::Value::as_string) {
73        Some("success") => Ok(()),
74        Some(other) => Err(ArbitrationError::Protocol(other.to_string())),
75        None => Err(ArbitrationError::Protocol(
76            "device arbitration response missing result".into(),
77        )),
78    }
79}
80
81async fn send_plist<S: AsyncWrite + Unpin>(
82    stream: &mut S,
83    value: &plist::Value,
84) -> Result<(), ArbitrationError> {
85    let mut buf = Vec::new();
86    plist::to_writer_xml(&mut buf, value).map_err(|e| ArbitrationError::Plist(e.to_string()))?;
87    stream.write_all(&(buf.len() as u32).to_be_bytes()).await?;
88    stream.write_all(&buf).await?;
89    stream.flush().await?;
90    Ok(())
91}
92
93async fn recv_plist<S: AsyncRead + Unpin>(
94    stream: &mut S,
95) -> Result<plist::Dictionary, ArbitrationError> {
96    let mut len_buf = [0u8; 4];
97    stream.read_exact(&mut len_buf).await?;
98    let len = u32::from_be_bytes(len_buf) as usize;
99    let mut buf = vec![0u8; len];
100    stream.read_exact(&mut buf).await?;
101    plist::from_bytes(&buf).map_err(|e| ArbitrationError::Plist(e.to_string()))
102}
103
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::*;

    /// In-memory duplex stream. It serves a pre-loaded byte sequence to reads and
    /// records every byte written to it.
    struct MockStream {
        read_buf: Vec<u8>,
        written: Vec<u8>,
        read_pos: usize,
    }

    impl MockStream {
        /// Pre-loads `value` as a single length-prefixed XML plist frame to be read back.
        fn with_response(value: plist::Value) -> Self {
            let mut payload = Vec::new();
            plist::to_writer_xml(&mut payload, &value).unwrap();
            let mut read_buf = Vec::new();
            read_buf.extend_from_slice(&(payload.len() as u32).to_be_bytes());
            read_buf.extend_from_slice(&payload);
            Self {
                read_buf,
                written: Vec::new(),
                read_pos: 0,
            }
        }

        /// Decodes the first length-prefixed frame the client wrote.
        fn sent_request(&self) -> plist::Dictionary {
            let len = u32::from_be_bytes(self.written[..4].try_into().unwrap()) as usize;
            plist::from_bytes(&self.written[4..4 + len]).unwrap()
        }
    }

    impl AsyncRead for MockStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            // Filling zero bytes is how `AsyncRead` signals EOF. `read_exact` turns that
            // into `UnexpectedEof` when the caller still expects more data.
            let start = self.read_pos;
            let remaining = self.read_buf.len().saturating_sub(start);
            let end = start + remaining.min(buf.remaining());
            buf.put_slice(&self.read_buf[start..end]);
            self.read_pos = end;
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for MockStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.written.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// A reply dictionary of the form `{ "result": result }`.
    fn result_response(result: &str) -> plist::Value {
        plist::Value::Dictionary(plist::Dictionary::from_iter([(
            "result".to_string(),
            plist::Value::String(result.into()),
        )]))
    }

    #[tokio::test]
    async fn check_in_sends_hostname() {
        let mut stream = MockStream::with_response(result_response("success"));
        let mut client = ArbitrationClient::new(&mut stream);

        client.check_in("host", false).await.unwrap();

        let dict = stream.sent_request();
        assert_eq!(dict["command"].as_string(), Some("check-in"));
        assert_eq!(dict["hostname"].as_string(), Some("host"));
    }

    #[tokio::test]
    async fn forced_check_in_uses_force_command() {
        let mut stream = MockStream::with_response(result_response("success"));
        let mut client = ArbitrationClient::new(&mut stream);

        client.check_in("host", true).await.unwrap();

        let dict = stream.sent_request();
        assert_eq!(dict["command"].as_string(), Some("force-check-in"));
        assert_eq!(dict["hostname"].as_string(), Some("host"));
    }

    #[tokio::test]
    async fn check_out_sends_only_command() {
        let mut stream = MockStream::with_response(result_response("success"));
        let mut client = ArbitrationClient::new(&mut stream);

        client.check_out().await.unwrap();

        let dict = stream.sent_request();
        assert_eq!(dict["command"].as_string(), Some("check-out"));
        assert!(dict.get("hostname").is_none());
    }

    #[tokio::test]
    async fn non_success_result_is_protocol_error() {
        let mut stream = MockStream::with_response(result_response("failure"));
        let mut client = ArbitrationClient::new(&mut stream);

        let err = client.check_in("host", false).await.unwrap_err();

        assert!(matches!(err, ArbitrationError::Protocol(ref msg) if msg == "failure"));
    }

    #[tokio::test]
    async fn missing_result_is_protocol_error() {
        let empty = plist::Value::Dictionary(plist::Dictionary::new());
        let mut stream = MockStream::with_response(empty);
        let mut client = ArbitrationClient::new(&mut stream);

        let err = client.check_out().await.unwrap_err();

        assert!(matches!(err, ArbitrationError::Protocol(_)));
    }

    #[tokio::test]
    async fn truncated_reply_is_io_error() {
        // Only half of the 4-byte length prefix is available before EOF.
        let mut stream = MockStream {
            read_buf: vec![0, 0],
            written: Vec::new(),
            read_pos: 0,
        };
        let mut client = ArbitrationClient::new(&mut stream);

        let err = client.version().await.unwrap_err();

        assert!(matches!(
            err,
            ArbitrationError::Io(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof
        ));
    }
}