ios_core/services/power_assertion/
mod.rs1use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6
/// Identifier of the on-device service that [`PowerAssertionClient`] is meant to talk to.
///
/// Callers use this name when opening the service connection whose stream is handed to
/// [`PowerAssertionClient::new`].
pub const SERVICE_NAME: &str = concat!("com.apple.mobile.", "assertion_agent");
8
9#[derive(Debug, thiserror::Error)]
10pub enum PowerAssertionError {
11 #[error("IO error: {0}")]
12 Io(#[from] std::io::Error),
13 #[error("plist error: {0}")]
14 Plist(String),
15 #[error("protocol error: {0}")]
16 Protocol(String),
17}
18
/// Client for the device's power-assertion service.
///
/// Wraps an already-established connection `S`. Through the methods on the `impl` block
/// (which require `S: AsyncRead + AsyncWrite + Unpin`), requests and replies are exchanged
/// as XML plists, each preceded by a 4-byte big-endian length.
///
/// `Debug` is derived so the client can appear in logs and `{:?}` output whenever the
/// wrapped stream is itself `Debug`.
#[derive(Debug)]
pub struct PowerAssertionClient<S> {
    /// Transport used for every request/response exchange.
    stream: S,
}
22
23impl<S: AsyncRead + AsyncWrite + Unpin> PowerAssertionClient<S> {
24 pub fn new(stream: S) -> Self {
25 Self { stream }
26 }
27
28 pub async fn create_assertion(
29 &mut self,
30 assertion_type: &str,
31 name: &str,
32 timeout_seconds: f64,
33 details: Option<&str>,
34 ) -> Result<plist::Dictionary, PowerAssertionError> {
35 let mut request = plist::Dictionary::from_iter([
36 (
37 "CommandKey".to_string(),
38 plist::Value::String("CommandCreateAssertion".into()),
39 ),
40 (
41 "AssertionTypeKey".to_string(),
42 plist::Value::String(assertion_type.to_string()),
43 ),
44 (
45 "AssertionNameKey".to_string(),
46 plist::Value::String(name.to_string()),
47 ),
48 (
49 "AssertionTimeoutKey".to_string(),
50 plist::Value::Real(timeout_seconds),
51 ),
52 ]);
53 if let Some(details) = details {
54 request.insert(
55 "AssertionDetailKey".to_string(),
56 plist::Value::String(details.to_string()),
57 );
58 }
59
60 send_plist(&mut self.stream, &plist::Value::Dictionary(request)).await?;
61 let response = recv_plist(&mut self.stream).await?;
62 if let Some(error) = response.get("Error").and_then(plist::Value::as_string) {
63 return Err(PowerAssertionError::Protocol(error.to_string()));
64 }
65 Ok(response)
66 }
67}
68
69async fn send_plist<S: AsyncWrite + Unpin>(
70 stream: &mut S,
71 value: &plist::Value,
72) -> Result<(), PowerAssertionError> {
73 let mut buf = Vec::new();
74 plist::to_writer_xml(&mut buf, value).map_err(|e| PowerAssertionError::Plist(e.to_string()))?;
75 stream.write_all(&(buf.len() as u32).to_be_bytes()).await?;
76 stream.write_all(&buf).await?;
77 stream.flush().await?;
78 Ok(())
79}
80
81async fn recv_plist<S: AsyncRead + Unpin>(
82 stream: &mut S,
83) -> Result<plist::Dictionary, PowerAssertionError> {
84 let mut len_buf = [0u8; 4];
85 stream.read_exact(&mut len_buf).await?;
86 let len = u32::from_be_bytes(len_buf) as usize;
87 const MAX_PLIST_SIZE: usize = 1024 * 1024;
88 if len > MAX_PLIST_SIZE {
89 return Err(PowerAssertionError::Protocol(format!(
90 "plist length {len} exceeds max {MAX_PLIST_SIZE}"
91 )));
92 }
93 let mut buf = vec![0u8; len];
94 stream.read_exact(&mut buf).await?;
95 plist::from_bytes(&buf).map_err(|e| PowerAssertionError::Plist(e.to_string()))
96}
97
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::*;

    /// In-memory transport: serves a scripted byte sequence to readers and records every write.
    struct MockStream {
        read_buf: Vec<u8>,
        written: Vec<u8>,
        read_pos: usize,
    }

    impl MockStream {
        /// Creates a stream that yields exactly `bytes` and then reports `UnexpectedEof`.
        fn with_raw(bytes: Vec<u8>) -> Self {
            Self {
                read_buf: bytes,
                written: Vec::new(),
                read_pos: 0,
            }
        }

        /// Creates a stream whose readable side holds one length-prefixed XML plist frame.
        fn with_response(value: plist::Value) -> Self {
            let mut payload = Vec::new();
            plist::to_writer_xml(&mut payload, &value).unwrap();
            let mut read_buf = Vec::with_capacity(4 + payload.len());
            read_buf.extend_from_slice(&(payload.len() as u32).to_be_bytes());
            read_buf.extend_from_slice(&payload);
            Self::with_raw(read_buf)
        }

        /// Decodes the request written to this stream, asserting that exactly one
        /// complete frame (header + payload, nothing more) was written.
        fn sent_request(&self) -> plist::Dictionary {
            assert!(self.written.len() >= 4, "no length header was written");
            let (header, payload) = self.written.split_at(4);
            let len = u32::from_be_bytes(header.try_into().unwrap()) as usize;
            assert_eq!(payload.len(), len, "written bytes should form exactly one frame");
            plist::from_bytes(payload).unwrap()
        }
    }

    impl AsyncRead for MockStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let remaining = self.read_buf.len().saturating_sub(self.read_pos);
            if remaining == 0 {
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "no more test data",
                )));
            }
            let to_copy = remaining.min(buf.remaining());
            let start = self.read_pos;
            let end = start + to_copy;
            buf.put_slice(&self.read_buf[start..end]);
            self.read_pos = end;
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for MockStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.written.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn create_assertion_sends_expected_payload() {
        let response = plist::Value::Dictionary(plist::Dictionary::from_iter([(
            "TestMarker".to_string(),
            plist::Value::String("present".into()),
        )]));
        let mut stream = MockStream::with_response(response);
        let mut client = PowerAssertionClient::new(&mut stream);

        let reply = client
            .create_assertion("PreventUserIdleSystemSleep", "ios-cli", 30.0, Some("test"))
            .await
            .unwrap();

        // The reply dictionary is handed back to the caller unchanged.
        assert_eq!(
            reply.get("TestMarker").and_then(plist::Value::as_string),
            Some("present")
        );

        let dict = stream.sent_request();
        assert_eq!(
            dict.get("CommandKey").and_then(plist::Value::as_string),
            Some("CommandCreateAssertion")
        );
        assert_eq!(
            dict.get("AssertionTypeKey").and_then(plist::Value::as_string),
            Some("PreventUserIdleSystemSleep")
        );
        assert_eq!(
            dict.get("AssertionNameKey").and_then(plist::Value::as_string),
            Some("ios-cli")
        );
        assert_eq!(
            dict.get("AssertionTimeoutKey").and_then(plist::Value::as_real),
            Some(30.0)
        );
        assert_eq!(
            dict.get("AssertionDetailKey").and_then(plist::Value::as_string),
            Some("test")
        );
    }

    #[tokio::test]
    async fn create_assertion_omits_detail_key_when_none() {
        let response = plist::Value::Dictionary(plist::Dictionary::new());
        let mut stream = MockStream::with_response(response);
        let mut client = PowerAssertionClient::new(&mut stream);

        client
            .create_assertion("PreventUserIdleSystemSleep", "ios-cli", 5.0, None)
            .await
            .unwrap();

        let dict = stream.sent_request();
        assert!(dict.get("AssertionDetailKey").is_none());
        assert_eq!(
            dict.get("AssertionTimeoutKey").and_then(plist::Value::as_real),
            Some(5.0)
        );
    }

    #[tokio::test]
    async fn create_assertion_maps_error_reply_to_protocol_error() {
        let response = plist::Value::Dictionary(plist::Dictionary::from_iter([(
            "Error".to_string(),
            plist::Value::String("InvalidAssertionType".into()),
        )]));
        let mut stream = MockStream::with_response(response);
        let mut client = PowerAssertionClient::new(&mut stream);

        let err = client
            .create_assertion("Bogus", "ios-cli", 1.0, None)
            .await
            .unwrap_err();

        match err {
            PowerAssertionError::Protocol(message) => {
                assert_eq!(message, "InvalidAssertionType")
            }
            other => panic!("expected protocol error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn recv_plist_rejects_oversized_frame_before_reading_payload() {
        let oversized: u32 = 2 * 1024 * 1024;
        let mut stream = MockStream::with_raw(oversized.to_be_bytes().to_vec());

        let err = recv_plist(&mut stream).await.unwrap_err();

        assert!(matches!(err, PowerAssertionError::Protocol(_)), "got {err:?}");
        // Only the 4-byte header was consumed; the limit is enforced before any payload read.
        assert_eq!(stream.read_pos, 4);
    }
}