Skip to main content

ios_core/services/arbitration/
mod.rs

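//! Device arbitration service client.
//!
//! Service: `com.apple.dt.devicearbitration`
//!
//! Requests are serialized as XML property lists and framed with a 4-byte
//! big-endian length prefix; responses are read back with the same framing.
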
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6
7pub const SERVICE_NAME: &str = "com.apple.dt.devicearbitration";
8const MAX_PLIST_SIZE: usize = 4 * 1024 * 1024;
9
10service_error!(ArbitrationError);
11
/// Client for the `com.apple.dt.devicearbitration` service.
///
/// Wraps an already-established stream. Each request is written as one
/// length-prefixed plist frame and answered by one response frame.
#[derive(Debug)]
pub struct ArbitrationClient<S> {
    /// Transport carrying the length-prefixed plist frames.
    stream: S,
}
15
16impl<S: AsyncRead + AsyncWrite + Unpin> ArbitrationClient<S> {
17    pub fn new(stream: S) -> Self {
18        Self { stream }
19    }
20
21    pub async fn version(&mut self) -> Result<plist::Dictionary, ArbitrationError> {
22        self.send_command(plist::Dictionary::from_iter([(
23            "command".to_string(),
24            plist::Value::String("version".into()),
25        )]))
26        .await
27    }
28
29    pub async fn check_in(&mut self, hostname: &str, force: bool) -> Result<(), ArbitrationError> {
30        let response = self
31            .send_command(plist::Dictionary::from_iter([
32                (
33                    "command".to_string(),
34                    plist::Value::String(if force { "force-check-in" } else { "check-in" }.into()),
35                ),
36                (
37                    "hostname".to_string(),
38                    plist::Value::String(hostname.to_string()),
39                ),
40            ]))
41            .await?;
42        ensure_success(&response)
43    }
44
45    pub async fn check_out(&mut self) -> Result<(), ArbitrationError> {
46        let response = self
47            .send_command(plist::Dictionary::from_iter([(
48                "command".to_string(),
49                plist::Value::String("check-out".into()),
50            )]))
51            .await?;
52        ensure_success(&response)
53    }
54
55    async fn send_command(
56        &mut self,
57        request: plist::Dictionary,
58    ) -> Result<plist::Dictionary, ArbitrationError> {
59        send_plist(&mut self.stream, &plist::Value::Dictionary(request)).await?;
60        recv_plist(&mut self.stream).await
61    }
62}
63
64fn ensure_success(response: &plist::Dictionary) -> Result<(), ArbitrationError> {
65    match response.get("result").and_then(plist::Value::as_string) {
66        Some("success") => Ok(()),
67        Some(other) => Err(ArbitrationError::Protocol(other.to_string())),
68        None => Err(ArbitrationError::Protocol(
69            "device arbitration response missing result".into(),
70        )),
71    }
72}
73
74async fn send_plist<S: AsyncWrite + Unpin>(
75    stream: &mut S,
76    value: &plist::Value,
77) -> Result<(), ArbitrationError> {
78    let mut buf = Vec::new();
79    plist::to_writer_xml(&mut buf, value)?;
80    stream.write_all(&(buf.len() as u32).to_be_bytes()).await?;
81    stream.write_all(&buf).await?;
82    stream.flush().await?;
83    Ok(())
84}
85
86async fn recv_plist<S: AsyncRead + Unpin>(
87    stream: &mut S,
88) -> Result<plist::Dictionary, ArbitrationError> {
89    let mut len_buf = [0u8; 4];
90    stream.read_exact(&mut len_buf).await?;
91    let len = u32::from_be_bytes(len_buf) as usize;
92    if len > MAX_PLIST_SIZE {
93        return Err(ArbitrationError::Protocol(format!(
94            "plist length {len} exceeds max {MAX_PLIST_SIZE}"
95        )));
96    }
97    let mut buf = vec![0u8; len];
98    stream.read_exact(&mut buf).await?;
99    Ok(plist::from_bytes(&buf)?)
100}
101
#[cfg(test)]
mod tests {
    use crate::test_util::MockStream;

    use super::*;

    /// A response dictionary whose `result` is `"success"`.
    fn success_response() -> plist::Value {
        plist::Value::Dictionary(plist::Dictionary::from_iter([(
            "result".to_string(),
            plist::Value::String("success".into()),
        )]))
    }

    /// Decodes the first length-prefixed plist frame in `written`.
    fn first_request(written: &[u8]) -> plist::Dictionary {
        let len = u32::from_be_bytes(written[..4].try_into().unwrap()) as usize;
        plist::from_bytes(&written[4..4 + len]).unwrap()
    }

    #[tokio::test]
    async fn check_in_sends_hostname() {
        let mut stream = MockStream::with_response(success_response());
        let mut client = ArbitrationClient::new(&mut stream);

        client.check_in("host", false).await.unwrap();

        let dict = first_request(&stream.written[..]);
        assert_eq!(dict["command"].as_string(), Some("check-in"));
        assert_eq!(dict["hostname"].as_string(), Some("host"));
    }

    #[tokio::test]
    async fn forced_check_in_sends_force_command() {
        let mut stream = MockStream::with_response(success_response());
        let mut client = ArbitrationClient::new(&mut stream);

        client.check_in("host", true).await.unwrap();

        let dict = first_request(&stream.written[..]);
        assert_eq!(dict["command"].as_string(), Some("force-check-in"));
        assert_eq!(dict["hostname"].as_string(), Some("host"));
    }

    #[tokio::test]
    async fn check_out_sends_command_only() {
        let mut stream = MockStream::with_response(success_response());
        let mut client = ArbitrationClient::new(&mut stream);

        client.check_out().await.unwrap();

        let dict = first_request(&stream.written[..]);
        assert_eq!(dict["command"].as_string(), Some("check-out"));
        assert!(dict.get("hostname").is_none());
    }

    #[test]
    fn ensure_success_rejects_failure_result() {
        let response = plist::Dictionary::from_iter([(
            "result".to_string(),
            plist::Value::String("failed".into()),
        )]);
        assert!(matches!(
            ensure_success(&response),
            Err(ArbitrationError::Protocol(_))
        ));
    }

    #[test]
    fn ensure_success_rejects_missing_result() {
        assert!(matches!(
            ensure_success(&plist::Dictionary::new()),
            Err(ArbitrationError::Protocol(_))
        ));
    }

    #[tokio::test]
    async fn recv_plist_rejects_oversized_frame() {
        let mut read_buf = ((MAX_PLIST_SIZE as u32) + 1).to_be_bytes().to_vec();
        read_buf.extend_from_slice(b"ignored");
        let mut stream = MockStream::new(read_buf);

        let err = recv_plist(&mut stream).await.unwrap_err();
        assert!(
            matches!(err, ArbitrationError::Protocol(message) if message.contains("exceeds max"))
        );
    }
}