1use std::net::{Ipv6Addr, SocketAddr};
4
5use tokio::io::{AsyncRead, AsyncWrite};
6use tokio::net::TcpStream;
7
8use crate::xpc::h2_raw::H2Framer;
9use crate::xpc::message::{flags, XpcMessage, XpcValue};
10use crate::xpc::rsd::{initialize_xpc_connection_on_framer, XpcConnection};
11use crate::xpc::XpcError;
12
13trait XpcStream: AsyncRead + AsyncWrite + Unpin + Send {}
14impl<T> XpcStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
15
16type DynStream = Box<dyn XpcStream>;
17
18pub struct XpcClient {
20 inner: XpcConnection<DynStream>,
21}
22
23impl XpcClient {
24 pub async fn connect(addr: Ipv6Addr, port: u16) -> Result<Self, XpcError> {
26 let sock_addr = SocketAddr::new(addr.into(), port);
27 let stream = TcpStream::connect(sock_addr).await?;
28 Self::connect_stream(stream).await
29 }
30
31 pub async fn connect_stream<S>(stream: S) -> Result<Self, XpcError>
33 where
34 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
35 {
36 let stream: DynStream = Box::new(stream);
37 let mut framer = H2Framer::connect(stream)
38 .await
39 .map_err(|e| XpcError::Tls(format!("H2: {e}")))?;
40 initialize_xpc_connection_on_framer(&mut framer).await?;
41 Ok(Self {
42 inner: XpcConnection::new(framer),
43 })
44 }
45
46 pub async fn call(&mut self, body: XpcValue) -> Result<XpcMessage, XpcError> {
48 self.inner
49 .send_with_flags(body, flags::WANTING_REPLY)
50 .await?;
51 self.inner.recv().await
52 }
53
54 pub async fn send(&mut self, body: XpcValue) -> Result<(), XpcError> {
56 self.inner.send(body).await
57 }
58
59 pub async fn recv(&mut self) -> Result<XpcMessage, XpcError> {
61 self.inner.recv().await
62 }
63}
64
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use indexmap::IndexMap;
    use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
    use tokio::time::{timeout, Duration};

    use super::*;
    use crate::xpc::message::{encode_message, flags, XpcMessage, XpcValue};

    /// Size of the fixed HTTP/2 frame header (RFC 9113 §4.1).
    const FRAME_HEADER_LEN: usize = 9;

    const FRAME_DATA: u8 = 0x00;
    const FRAME_HEADERS: u8 = 0x01;
    const FRAME_SETTINGS: u8 = 0x04;
    const FRAME_WINDOW_UPDATE: u8 = 0x08;
    const FLAG_END_HEADERS: u8 = 0x04;
    const FLAG_SETTINGS_ACK: u8 = 0x01;

    /// Stream 0 carries connection-level control frames such as SETTINGS.
    const STREAM_INIT: u32 = 0;
    const STREAM_CLIENT_SERVER: u32 = 1;
    const STREAM_SERVER_CLIENT: u32 = 3;

    /// Extra flag bit the client sets on its final bootstrap message.
    /// NOTE(review): value pinned from the expected client behavior; its
    /// protocol meaning is not named anywhere in this crate's visible code.
    const BOOTSTRAP_FINAL_FLAG: u32 = 0x200;

    /// Serializes one HTTP/2 frame: 24-bit length, type, flags, stream id
    /// (reserved bit cleared), then the payload.
    ///
    /// # Panics
    ///
    /// Panics if `payload` does not fit the 24-bit frame length field.
    fn build_frame(frame_type: u8, frame_flags: u8, stream_id: u32, payload: &[u8]) -> Vec<u8> {
        let len = u32::try_from(payload.len())
            .ok()
            .filter(|&n| n < 1 << 24)
            .expect("payload exceeds the 24-bit HTTP/2 frame length");
        let mut out = Vec::with_capacity(FRAME_HEADER_LEN + payload.len());
        // Low three bytes of the big-endian length.
        out.extend_from_slice(&len.to_be_bytes()[1..]);
        out.push(frame_type);
        out.push(frame_flags);
        // The high (reserved) bit of the stream identifier must be zero.
        out.extend_from_slice(&(stream_id & 0x7FFF_FFFF).to_be_bytes());
        out.extend_from_slice(payload);
        out
    }

    fn settings_frame() -> Vec<u8> {
        build_frame(FRAME_SETTINGS, 0, STREAM_INIT, &[])
    }

    fn settings_ack_frame() -> Vec<u8> {
        build_frame(FRAME_SETTINGS, FLAG_SETTINGS_ACK, STREAM_INIT, &[])
    }

    fn headers_frame(stream_id: u32) -> Vec<u8> {
        build_frame(FRAME_HEADERS, FLAG_END_HEADERS, stream_id, &[])
    }

    fn data_frame(stream_id: u32, payload: &[u8]) -> Vec<u8> {
        build_frame(FRAME_DATA, 0, stream_id, payload)
    }

    /// Encodes the message the client is expected to send first.
    ///
    /// The body is an empty dictionary when `msg_flags` is exactly
    /// `ALWAYS_SET`, and absent otherwise.
    fn empty_message(msg_flags: u32) -> Bytes {
        let body = (msg_flags == flags::ALWAYS_SET)
            .then(|| XpcValue::Dictionary(IndexMap::new()));
        encode_message(&XpcMessage {
            flags: msg_flags,
            msg_id: 0,
            body,
        })
        .expect("message should encode")
    }

    /// Encodes a body-less `ALWAYS_SET` message; the fake server answers
    /// every bootstrap step with it.
    fn bare_message() -> Bytes {
        encode_message(&XpcMessage {
            flags: flags::ALWAYS_SET,
            msg_id: 0,
            body: None,
        })
        .expect("message should encode")
    }

    /// Reads one frame, returning its type byte and payload.
    async fn read_frame(server: &mut DuplexStream) -> (u8, Vec<u8>) {
        let mut header = [0u8; FRAME_HEADER_LEN];
        server.read_exact(&mut header).await.unwrap();
        let len = (usize::from(header[0]) << 16)
            | (usize::from(header[1]) << 8)
            | usize::from(header[2]);
        let mut payload = vec![0u8; len];
        server.read_exact(&mut payload).await.unwrap();
        (header[3], payload)
    }

    /// Reads one frame, asserts it is a DATA frame, and returns its payload.
    async fn read_data_payload(server: &mut DuplexStream) -> Vec<u8> {
        let (frame_type, payload) = read_frame(server).await;
        assert_eq!(frame_type, FRAME_DATA, "expected a DATA frame");
        payload
    }

    /// Reads exactly `expected.len()` bytes and asserts they equal `expected`.
    async fn expect_bytes(server: &mut DuplexStream, expected: &[u8]) {
        let mut actual = vec![0u8; expected.len()];
        server.read_exact(&mut actual).await.unwrap();
        assert_eq!(actual, expected);
    }

    /// Writes `bytes` to the client and flushes.
    async fn send_bytes(server: &mut DuplexStream, bytes: &[u8]) {
        server.write_all(bytes).await.unwrap();
        server.flush().await.unwrap();
    }

    /// Plays the server side of the HTTP/2 + XPC bootstrap performed by
    /// `XpcClient::connect_stream`, asserting every frame the client sends.
    async fn serve_bootstrap(server: &mut DuplexStream) {
        let bare = bare_message();

        // HTTP/2 connection preface, then the client's SETTINGS and
        // WINDOW_UPDATE frames (parsed by their length field).
        let mut preface = [0u8; 24];
        server.read_exact(&mut preface).await.unwrap();
        assert_eq!(&preface, crate::xpc::h2_raw::H2_PREFACE);

        let (frame_type, _) = read_frame(server).await;
        assert_eq!(frame_type, FRAME_SETTINGS);
        let (frame_type, _) = read_frame(server).await;
        assert_eq!(frame_type, FRAME_WINDOW_UPDATE);

        send_bytes(server, &settings_frame()).await;
        expect_bytes(server, &settings_ack_frame()).await;

        // The client opens the client->server stream with its first message.
        expect_bytes(server, &headers_frame(STREAM_CLIENT_SERVER)).await;
        assert_eq!(
            read_data_payload(server).await,
            empty_message(flags::ALWAYS_SET)
        );
        send_bytes(server, &data_frame(STREAM_CLIENT_SERVER, &bare)).await;

        // The client opens the server->client stream with INIT_HANDSHAKE.
        expect_bytes(server, &headers_frame(STREAM_SERVER_CLIENT)).await;
        let handshake = read_data_payload(server).await;
        assert_eq!(
            decode_message_payload(&handshake),
            (flags::INIT_HANDSHAKE | flags::ALWAYS_SET, 0)
        );
        send_bytes(server, &data_frame(STREAM_SERVER_CLIENT, &bare)).await;

        // Final bootstrap message on the client->server stream.
        let last = read_data_payload(server).await;
        assert_eq!(
            decode_message_payload(&last),
            (flags::ALWAYS_SET | BOOTSTRAP_FINAL_FLAG, 0)
        );
        send_bytes(server, &data_frame(STREAM_CLIENT_SERVER, &bare)).await;
    }

    #[tokio::test]
    async fn connect_stream_bootstraps_remote_xpc_before_returning() {
        let (client, mut server) = duplex(4096);

        let server_task = tokio::spawn(async move {
            serve_bootstrap(&mut server).await;
        });

        timeout(Duration::from_secs(1), XpcClient::connect_stream(client))
            .await
            .expect("connect timed out")
            .expect("connect should succeed");

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn call_sets_wanting_reply_on_outgoing_request() {
        let (client, mut server) = duplex(4096);

        let reply = encode_message(&XpcMessage {
            flags: flags::ALWAYS_SET | flags::REPLY | flags::DATA,
            msg_id: 1,
            body: Some(XpcValue::Dictionary(IndexMap::new())),
        })
        .expect("message should encode");

        let server_task = tokio::spawn(async move {
            serve_bootstrap(&mut server).await;

            let request = read_data_payload(&mut server).await;
            assert_eq!(
                decode_message_payload(&request),
                (flags::ALWAYS_SET | flags::DATA | flags::WANTING_REPLY, 1)
            );

            send_bytes(&mut server, &data_frame(STREAM_SERVER_CLIENT, &reply)).await;
        });

        let mut client = timeout(Duration::from_secs(1), XpcClient::connect_stream(client))
            .await
            .expect("connect timed out")
            .expect("connect should succeed");

        let response = timeout(
            Duration::from_secs(1),
            client.call(XpcValue::Dictionary(IndexMap::new())),
        )
        .await
        .expect("call timed out")
        .expect("call should succeed");

        assert_eq!(
            response.flags,
            flags::ALWAYS_SET | flags::REPLY | flags::DATA
        );
        assert_eq!(response.msg_id, 1);

        server_task.await.unwrap();
    }

    /// Decodes an encoded XPC message and returns its `(flags, msg_id)`.
    fn decode_message_payload(bytes: &[u8]) -> (u32, u64) {
        let msg = crate::xpc::message::decode_message(Bytes::copy_from_slice(bytes)).unwrap();
        (msg.flags, msg.msg_id)
    }
}