Skip to main content

ios_core/services/power_assertion/
mod.rs

1//! Power assertion service client.
2//!
3//! Service: `com.apple.mobile.assertion_agent`
4
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6
/// Name of the on-device service that [`PowerAssertionClient`] talks to.
pub const SERVICE_NAME: &str = "com.apple.mobile.assertion_agent";
8
9#[derive(Debug, thiserror::Error)]
10pub enum PowerAssertionError {
11    #[error("IO error: {0}")]
12    Io(#[from] std::io::Error),
13    #[error("plist error: {0}")]
14    Plist(String),
15    #[error("protocol error: {0}")]
16    Protocol(String),
17}
18
/// Client for the power assertion service, driving requests over an already-open stream.
///
/// `Debug` is derived (and is therefore available whenever `S: Debug`) so the client can
/// be logged or embedded in other `Debug` types.
#[derive(Debug)]
pub struct PowerAssertionClient<S> {
    stream: S,
}
22
23impl<S: AsyncRead + AsyncWrite + Unpin> PowerAssertionClient<S> {
24    pub fn new(stream: S) -> Self {
25        Self { stream }
26    }
27
28    pub async fn create_assertion(
29        &mut self,
30        assertion_type: &str,
31        name: &str,
32        timeout_seconds: f64,
33        details: Option<&str>,
34    ) -> Result<plist::Dictionary, PowerAssertionError> {
35        let mut request = plist::Dictionary::from_iter([
36            (
37                "CommandKey".to_string(),
38                plist::Value::String("CommandCreateAssertion".into()),
39            ),
40            (
41                "AssertionTypeKey".to_string(),
42                plist::Value::String(assertion_type.to_string()),
43            ),
44            (
45                "AssertionNameKey".to_string(),
46                plist::Value::String(name.to_string()),
47            ),
48            (
49                "AssertionTimeoutKey".to_string(),
50                plist::Value::Real(timeout_seconds),
51            ),
52        ]);
53        if let Some(details) = details {
54            request.insert(
55                "AssertionDetailKey".to_string(),
56                plist::Value::String(details.to_string()),
57            );
58        }
59
60        send_plist(&mut self.stream, &plist::Value::Dictionary(request)).await?;
61        let response = recv_plist(&mut self.stream).await?;
62        if let Some(error) = response.get("Error").and_then(plist::Value::as_string) {
63            return Err(PowerAssertionError::Protocol(error.to_string()));
64        }
65        Ok(response)
66    }
67}
68
69async fn send_plist<S: AsyncWrite + Unpin>(
70    stream: &mut S,
71    value: &plist::Value,
72) -> Result<(), PowerAssertionError> {
73    let mut buf = Vec::new();
74    plist::to_writer_xml(&mut buf, value).map_err(|e| PowerAssertionError::Plist(e.to_string()))?;
75    stream.write_all(&(buf.len() as u32).to_be_bytes()).await?;
76    stream.write_all(&buf).await?;
77    stream.flush().await?;
78    Ok(())
79}
80
81async fn recv_plist<S: AsyncRead + Unpin>(
82    stream: &mut S,
83) -> Result<plist::Dictionary, PowerAssertionError> {
84    let mut len_buf = [0u8; 4];
85    stream.read_exact(&mut len_buf).await?;
86    let len = u32::from_be_bytes(len_buf) as usize;
87    const MAX_PLIST_SIZE: usize = 1024 * 1024;
88    if len > MAX_PLIST_SIZE {
89        return Err(PowerAssertionError::Protocol(format!(
90            "plist length {len} exceeds max {MAX_PLIST_SIZE}"
91        )));
92    }
93    let mut buf = vec![0u8; len];
94    stream.read_exact(&mut buf).await?;
95    plist::from_bytes(&buf).map_err(|e| PowerAssertionError::Plist(e.to_string()))
96}
97
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::*;

    /// In-memory stream: serves canned bytes on read and records every write.
    struct MockStream {
        read_buf: Vec<u8>,
        written: Vec<u8>,
        read_pos: usize,
    }

    impl MockStream {
        /// Builds a stream whose read side yields `read_buf` verbatim.
        fn with_raw(read_buf: Vec<u8>) -> Self {
            Self {
                read_buf,
                written: Vec::new(),
                read_pos: 0,
            }
        }

        /// Builds a stream whose read side yields `value` as one length-prefixed XML frame.
        fn with_response(value: plist::Value) -> Self {
            let mut payload = Vec::new();
            plist::to_writer_xml(&mut payload, &value).unwrap();
            let mut read_buf = Vec::new();
            read_buf.extend_from_slice(&(payload.len() as u32).to_be_bytes());
            read_buf.extend_from_slice(&payload);
            Self::with_raw(read_buf)
        }

        /// Decodes the request frame the client wrote, asserting nothing trails it.
        fn written_request(&self) -> plist::Dictionary {
            let len = u32::from_be_bytes(self.written[..4].try_into().unwrap()) as usize;
            assert_eq!(
                self.written.len(),
                4 + len,
                "exactly one frame should have been written"
            );
            plist::from_bytes(&self.written[4..]).unwrap()
        }
    }

    impl AsyncRead for MockStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let remaining = self.read_buf.len().saturating_sub(self.read_pos);
            if remaining == 0 {
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "no more test data",
                )));
            }
            let to_copy = remaining.min(buf.remaining());
            let start = self.read_pos;
            let end = start + to_copy;
            buf.put_slice(&self.read_buf[start..end]);
            self.read_pos = end;
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for MockStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.written.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn create_assertion_sends_expected_payload() {
        let response = plist::Value::Dictionary(plist::Dictionary::new());
        let mut stream = MockStream::with_response(response);
        let mut client = PowerAssertionClient::new(&mut stream);

        client
            .create_assertion("PreventUserIdleSystemSleep", "ios-cli", 30.0, Some("test"))
            .await
            .unwrap();

        let dict = stream.written_request();
        assert_eq!(
            dict.get("CommandKey").and_then(plist::Value::as_string),
            Some("CommandCreateAssertion")
        );
        assert_eq!(
            dict.get("AssertionTypeKey")
                .and_then(plist::Value::as_string),
            Some("PreventUserIdleSystemSleep")
        );
        assert_eq!(
            dict.get("AssertionNameKey")
                .and_then(plist::Value::as_string),
            Some("ios-cli")
        );
        assert_eq!(
            dict.get("AssertionTimeoutKey")
                .and_then(plist::Value::as_real),
            Some(30.0)
        );
        assert_eq!(
            dict.get("AssertionDetailKey")
                .and_then(plist::Value::as_string),
            Some("test")
        );
    }

    #[tokio::test]
    async fn create_assertion_omits_detail_key_when_none() {
        let response = plist::Value::Dictionary(plist::Dictionary::new());
        let mut stream = MockStream::with_response(response);
        let mut client = PowerAssertionClient::new(&mut stream);

        client
            .create_assertion("PreventUserIdleSystemSleep", "ios-cli", 30.0, None)
            .await
            .unwrap();

        let dict = stream.written_request();
        assert!(dict.get("AssertionDetailKey").is_none());
    }

    #[tokio::test]
    async fn create_assertion_maps_error_key_to_protocol_error() {
        let mut response = plist::Dictionary::new();
        response.insert(
            "Error".to_string(),
            plist::Value::String("InvalidAssertion".into()),
        );
        let mut stream = MockStream::with_response(plist::Value::Dictionary(response));
        let mut client = PowerAssertionClient::new(&mut stream);

        let err = client
            .create_assertion("PreventUserIdleSystemSleep", "ios-cli", 30.0, None)
            .await
            .unwrap_err();

        assert!(
            matches!(&err, PowerAssertionError::Protocol(msg) if msg.as_str() == "InvalidAssertion"),
            "unexpected error: {err:?}"
        );
    }

    #[tokio::test]
    async fn recv_plist_rejects_oversized_length() {
        // One byte over the 1 MiB limit enforced by `recv_plist`; no payload follows.
        let header = (1024u32 * 1024 + 1).to_be_bytes().to_vec();
        let mut stream = MockStream::with_raw(header);

        let err = recv_plist(&mut stream).await.unwrap_err();

        assert!(
            matches!(err, PowerAssertionError::Protocol(_)),
            "unexpected error: {err:?}"
        );
    }
}