1use std::collections::{HashMap, HashSet, VecDeque};
13
14use crate::proto::dtx::DTX_MAGIC;
15use crate::proto::nskeyedarchiver_encode;
16use bytes::{BufMut, Bytes, BytesMut};
17use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
18
19use super::primitive_enc::{encode_primitive_dict, PrimArg};
20use super::types::{DtxMessage, DtxPayload, NSObject};
21
/// Largest `msg_len` a DTX header may declare (128 MiB); `parse_dtx_header`
/// rejects anything above this.
const MAX_DTX_MESSAGE_SIZE: usize = 128 * 1024 * 1024;

// Message type codes carried in the first word of the payload header.

/// Written by `encode_ack`; decoded by `decode_payload` as an empty payload.
const MSG_OK: u32 = 0;
/// Decoded by `decode_payload` as raw bytes, plus decoded auxiliary values
/// when an auxiliary section is present.
const MSG_UNKNOWN_TYPE_ONE: u32 = 1;
/// Method invocation: the payload carries the archived selector and the
/// auxiliary section carries the arguments.
const MSG_METHOD_INVOCATION: u32 = 2;
/// Reply to a method invocation; the payload is an archived object.
const MSG_RESPONSE: u32 = 3;
/// Decoded by `decode_payload` exactly like `MSG_RESPONSE`.
const MSG_ERROR: u32 = 4;
/// Decoded by `decode_payload` as an empty payload.
const MSG_BARRIER: u32 = 5;
/// Message type code that is declared but not referenced by the visible code
/// in this file (hence the leading underscore).
const _MSG_LZ4_COMPRESSED: u32 = 0x0707;

/// Errors produced while reading, decoding or writing DTX frames.
#[derive(Debug, thiserror::Error)]
pub enum DtxError {
    /// Failure of the underlying stream; converted automatically from
    /// `std::io::Error` by `?`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// The first four header bytes did not match `DTX_MAGIC`; carries the
    /// value that was read instead.
    #[error("bad magic: 0x{0:08X}")]
    BadMagic(u32),
    /// The frame was structurally invalid; carries a human-readable reason.
    #[error("protocol error: {0}")]
    Protocol(String),
}
45
46pub fn encode_dtx(
50 identifier: u32,
51 conv_idx: u32,
52 channel_code: i32,
53 expects_reply: bool,
54 msg_type: u32,
55 payload: &[u8],
56 aux_bytes: &[u8],
57) -> Bytes {
58 let aux_len = aux_bytes.len();
59 let payload_len = payload.len();
60
61 let aux_with_hdr = if aux_len > 0 { aux_len + 16 } else { 0 };
63 let total_payload = aux_with_hdr + payload_len;
64 let msg_len = 16 + aux_with_hdr + payload_len;
65
66 let mut out = BytesMut::with_capacity(32 + msg_len);
67
68 out.put_u32(DTX_MAGIC); out.put_u32_le(32); out.put_u16_le(0); out.put_u16_le(1); out.put_u32_le(msg_len as u32); out.put_u32_le(identifier); out.put_u32_le(conv_idx); out.put_u32_le(channel_code as u32); out.put_u32_le(if expects_reply { 1 } else { 0 }); out.put_u32_le(msg_type);
81 out.put_u32_le(aux_with_hdr as u32);
82 out.put_u32_le(total_payload as u32);
83 out.put_u32_le(0); if aux_len > 0 {
86 out.put_u32_le(496);
88 out.put_u32_le(0);
89 out.put_u32_le(aux_len as u32);
90 out.put_u32_le(0);
91 out.put_slice(aux_bytes);
92 }
93 out.put_slice(payload);
94
95 out.freeze()
96}
97
98pub fn encode_ack(msg: &DtxMessage) -> Bytes {
100 let mut out = BytesMut::with_capacity(48);
101 out.put_u32(DTX_MAGIC);
102 out.put_u32_le(32);
103 out.put_u16_le(0);
104 out.put_u16_le(1);
105 out.put_u32_le(16); out.put_u32_le(msg.identifier);
107 out.put_u32_le(msg.conversation_idx + 1);
108 out.put_u32_le(msg.channel_code as u32);
109 out.put_u32_le(0); out.put_u32_le(MSG_OK);
112 out.put_u32_le(0);
113 out.put_u32_le(0);
114 out.put_u32_le(0);
115 out.freeze()
116}
117
/// Fields of the fixed 32-byte DTX frame header, as produced by
/// `parse_dtx_header`.
struct DtxHeader {
    /// Declared header length (bytes 4..8). Never below 32; anything beyond
    /// the fixed 32 bytes is skipped by `read_dtx_header`.
    header_len: usize,
    /// Index of this fragment (bytes 8..10); always below `frag_cnt`.
    frag_idx: u16,
    /// Number of fragments the message is split into (bytes 10..12); never 0.
    frag_cnt: u16,
    /// Declared message length (bytes 12..16), at most `MAX_DTX_MESSAGE_SIZE`.
    /// On the first header of a multi-fragment message the connection treats
    /// it as the size of the reassembled body.
    msg_len: usize,
    /// Message identifier (bytes 16..20).
    identifier: u32,
    /// Conversation index (bytes 20..24).
    conv_idx: u32,
    /// Channel code (bytes 24..28), read as a signed value.
    channel_code: i32,
    /// True when the flag word (bytes 28..32) is non-zero.
    expects_reply: bool,
}
131
132fn parse_dtx_header(hdr: &[u8; 32]) -> Result<DtxHeader, DtxError> {
134 let magic = u32::from_be_bytes(hdr[0..4].try_into().unwrap());
135 if magic != DTX_MAGIC {
136 return Err(DtxError::BadMagic(magic));
137 }
138 let header_len = u32::from_le_bytes(hdr[4..8].try_into().unwrap()) as usize;
139 if header_len < 32 {
140 return Err(DtxError::Protocol(format!(
141 "invalid DTX header length: {header_len}"
142 )));
143 }
144 let frag_idx = u16::from_le_bytes(hdr[8..10].try_into().unwrap());
145 let frag_cnt = u16::from_le_bytes(hdr[10..12].try_into().unwrap());
146 if frag_cnt == 0 {
147 return Err(DtxError::Protocol("invalid DTX fragment count: 0".into()));
148 }
149 if frag_idx >= frag_cnt {
150 return Err(DtxError::Protocol(format!(
151 "invalid DTX fragment index {frag_idx} for count {frag_cnt}"
152 )));
153 }
154 let msg_len = u32::from_le_bytes(hdr[12..16].try_into().unwrap()) as usize;
155 if msg_len > MAX_DTX_MESSAGE_SIZE {
156 return Err(DtxError::Protocol(format!(
157 "DTX message length {msg_len} exceeds max {MAX_DTX_MESSAGE_SIZE}"
158 )));
159 }
160 if frag_cnt > 1 && frag_idx == 0 && msg_len == 0 {
161 return Err(DtxError::Protocol(
162 "multi-fragment first header declares zero total size".into(),
163 ));
164 }
165 Ok(DtxHeader {
166 header_len,
167 frag_idx,
168 frag_cnt,
169 msg_len,
170 identifier: u32::from_le_bytes(hdr[16..20].try_into().unwrap()),
171 conv_idx: u32::from_le_bytes(hdr[20..24].try_into().unwrap()),
172 channel_code: i32::from_le_bytes(hdr[24..28].try_into().unwrap()),
173 expects_reply: u32::from_le_bytes(hdr[28..32].try_into().unwrap()) != 0,
174 })
175}
176
177async fn read_dtx_header<R: AsyncRead + Unpin>(reader: &mut R) -> Result<DtxHeader, DtxError> {
178 let mut hdr = [0u8; 32];
179 reader.read_exact(&mut hdr).await?;
180 let header = parse_dtx_header(&hdr)?;
181 if header.header_len > 32 {
182 let mut extra = vec![0u8; header.header_len - 32];
183 reader.read_exact(&mut extra).await?;
184 }
185 Ok(header)
186}
187
188fn decode_dtx_body_from_slice(h: &DtxHeader, body_slice: &[u8]) -> Result<DtxMessage, DtxError> {
189 if body_slice.len() < 16 {
190 return Ok(DtxMessage {
191 identifier: h.identifier,
192 conversation_idx: h.conv_idx,
193 channel_code: h.channel_code,
194 expects_reply: h.expects_reply,
195 payload: DtxPayload::Empty,
196 });
197 }
198
199 let ph = &body_slice[0..16];
200 let msg_type = u32::from_le_bytes(ph[0..4].try_into().unwrap());
201 let aux_len = u32::from_le_bytes(ph[4..8].try_into().unwrap()) as usize;
202 let total_pay = u32::from_le_bytes(ph[8..12].try_into().unwrap()) as usize;
203
204 if aux_len > total_pay {
205 return Err(DtxError::Protocol(format!(
206 "aux_len ({aux_len}) exceeds total_pay ({total_pay})"
207 )));
208 }
209 let pay_len = total_pay - aux_len;
210 let rest = &body_slice[16..];
211
212 let aux_data = if aux_len > 0 {
213 if rest.len() < 16 {
214 return Err(DtxError::Protocol("aux header truncated".into()));
215 }
216 let actual_aux = u32::from_le_bytes(rest[8..12].try_into().unwrap()) as usize;
217 if actual_aux > aux_len.saturating_sub(16) {
218 return Err(DtxError::Protocol(format!(
219 "auxiliary data size ({actual_aux}) exceeds available space ({})",
220 aux_len.saturating_sub(16)
221 )));
222 }
223 let aux_start = 16;
224 let aux_end = aux_start + actual_aux;
225 if rest.len() < aux_end {
226 return Err(DtxError::Protocol("aux data truncated".into()));
227 }
228 Some(Bytes::copy_from_slice(&rest[aux_start..aux_end]))
229 } else {
230 None
231 };
232
233 let pay_start = aux_len;
234 let pay_end = pay_start + pay_len;
235 let payload_bytes = if pay_len > 0 && rest.len() >= pay_end {
236 Bytes::copy_from_slice(&rest[pay_start..pay_end])
237 } else {
238 Bytes::new()
239 };
240
241 let payload = decode_payload(msg_type, payload_bytes, aux_data);
242 Ok(DtxMessage {
243 identifier: h.identifier,
244 conversation_idx: h.conv_idx,
245 channel_code: h.channel_code,
246 expects_reply: h.expects_reply,
247 payload,
248 })
249}
250
251pub fn decode_dtx_message_from_bytes(data: &[u8]) -> Result<Option<(DtxMessage, usize)>, DtxError> {
252 if data.len() < 32 {
253 return Ok(None);
254 }
255
256 let header_bytes: &[u8; 32] = data[..32]
257 .try_into()
258 .map_err(|_| DtxError::Protocol("DTX header truncated".into()))?;
259 let header = parse_dtx_header(header_bytes)?;
260 let total_len = header
261 .header_len
262 .checked_add(header.msg_len)
263 .ok_or_else(|| DtxError::Protocol("DTX frame length overflow".into()))?;
264 if data.len() < total_len {
265 return Ok(None);
266 }
267
268 let body = &data[header.header_len..total_len];
269 let message = decode_dtx_body_from_slice(&header, body)?;
270 Ok(Some((message, total_len)))
271}
272
273async fn read_dtx_body<R: AsyncRead + Unpin>(
276 reader: &mut R,
277 h: &DtxHeader,
278 body: &[u8], ) -> Result<DtxMessage, DtxError> {
280 let body_owned: Vec<u8>;
282 let body_slice: &[u8] = if body.is_empty() && h.msg_len > 0 {
283 body_owned = {
284 let mut b = vec![0u8; h.msg_len];
285 reader.read_exact(&mut b).await?;
286 b
287 };
288 &body_owned
289 } else {
290 body
291 };
292
293 decode_dtx_body_from_slice(h, body_slice)
294}
295
296pub async fn read_dtx_frame<R: AsyncRead + Unpin>(reader: &mut R) -> Result<DtxMessage, DtxError> {
297 let h = read_dtx_header(reader).await?;
298 tracing::trace!(
299 "read_dtx_frame: frag_idx={} frag_cnt={} msg_len={} id={}",
300 h.frag_idx,
301 h.frag_cnt,
302 h.msg_len,
303 h.identifier
304 );
305
306 if h.frag_cnt > 1 && h.frag_idx == 0 {
308 return Ok(DtxMessage {
309 identifier: h.identifier,
310 conversation_idx: h.conv_idx,
311 channel_code: h.channel_code,
312 expects_reply: h.expects_reply,
313 payload: DtxPayload::Empty,
314 });
315 }
316
317 if h.msg_len == 0 {
318 return Ok(DtxMessage {
319 identifier: h.identifier,
320 conversation_idx: h.conv_idx,
321 channel_code: h.channel_code,
322 expects_reply: h.expects_reply,
323 payload: DtxPayload::Empty,
324 });
325 }
326
327 read_dtx_body(reader, &h, &[]).await
328}
329
330fn decode_payload(msg_type: u32, payload: Bytes, aux: Option<Bytes>) -> DtxPayload {
331 tracing::trace!(
332 "decode_payload: msg_type={msg_type} payload_len={} aux={}",
333 payload.len(),
334 aux.is_some()
335 );
336 match msg_type {
337 MSG_OK => DtxPayload::Empty,
338 MSG_METHOD_INVOCATION => {
339 let mut args = aux
340 .map(super::primitive::decode_auxiliary)
341 .unwrap_or_default();
342 let selector = if payload.is_empty() {
343 String::new()
344 } else {
345 match crate::proto::nskeyedarchiver::unarchive(&payload)
346 .ok()
347 .and_then(|v| v.as_str().map(String::from))
348 {
349 Some(selector) => selector,
350 None => {
351 tracing::debug!(
352 "decode_payload: method invocation payload decode failed, preserving {} raw bytes",
353 payload.len()
354 );
355 args.insert(0, NSObject::Data(payload));
356 String::new()
357 }
358 }
359 };
360 DtxPayload::MethodInvocation { selector, args }
361 }
362 MSG_RESPONSE | MSG_ERROR => {
363 if payload.is_empty() {
364 DtxPayload::Response(NSObject::Null)
365 } else {
366 let obj = crate::proto::nskeyedarchiver::unarchive(&payload)
367 .map(archive_to_ns)
368 .unwrap_or(NSObject::Data(payload));
369 DtxPayload::Response(obj)
370 }
371 }
372 MSG_BARRIER => DtxPayload::Empty,
373 MSG_UNKNOWN_TYPE_ONE => match aux {
374 Some(aux) => DtxPayload::RawWithAux {
375 payload,
376 aux: super::primitive::decode_auxiliary(aux),
377 },
378 None => DtxPayload::Raw(payload),
379 },
380 _ => {
381 if payload.is_empty() {
382 DtxPayload::Empty
383 } else {
384 DtxPayload::Raw(payload)
385 }
386 }
387 }
388}
389
390fn archive_to_ns(v: crate::proto::nskeyedarchiver::ArchiveValue) -> NSObject {
391 use crate::proto::nskeyedarchiver::ArchiveValue;
392 match v {
393 ArchiveValue::Null => NSObject::Null,
394 ArchiveValue::Bool(b) => NSObject::Bool(b),
395 ArchiveValue::Int(n) => NSObject::Int(n),
396 ArchiveValue::Float(f) => NSObject::Double(f),
397 ArchiveValue::String(s) => NSObject::String(s),
398 ArchiveValue::Data(d) => NSObject::Data(d),
399 ArchiveValue::Array(a) => NSObject::Array(a.into_iter().map(archive_to_ns).collect()),
400 ArchiveValue::Dict(d) => {
401 NSObject::Dict(d.into_iter().map(|(k, v)| (k, archive_to_ns(v))).collect())
402 }
403 ArchiveValue::Unknown(s) => NSObject::String(format!("<{s}>")),
404 }
405}
406
/// Reassembly state for one multi-fragment message, keyed by message
/// identifier in `DtxConnection::fragments`.
struct FragmentAccum {
    /// Header of the first fragment (index 0). Its `msg_len` is the size the
    /// reassembled body must have, and later fragments must agree with its
    /// fragment count, conversation index, channel code and reply flag.
    header: DtxHeader,
    /// One slot per non-first fragment; fragment `i` is stored at `i - 1`.
    /// A slot stays `None` until that fragment arrives.
    fragments: Vec<Option<Vec<u8>>>,
    /// Number of slots that are still empty.
    remaining: u16,
}
418
/// A DTX connection over a bidirectional byte stream.
///
/// It sends method invocations, matches replies to the requests that are
/// waiting for them, and reassembles fragmented incoming messages.
pub struct DtxConnection<S> {
    /// Underlying transport.
    stream: S,
    /// Identifier that the next outgoing message will use.
    identifier: u32,
    /// Channel code that the next `request_channel` call will hand out.
    channel_counter: i32,
    /// Replies that arrived while a different reply (or `recv`) was being
    /// served, keyed by message identifier.
    pending_replies: HashMap<u32, DtxMessage>,
    /// Identifiers of sent requests whose reply has not been returned yet.
    outstanding_reply_ids: HashSet<u32>,
    /// Non-reply messages received while waiting for a reply; `recv` drains
    /// this queue before reading from the stream.
    queued_messages: VecDeque<DtxMessage>,
    /// In-progress fragment reassembly, keyed by message identifier.
    fragments: HashMap<u32, FragmentAccum>,
}
440
441impl<S: AsyncRead + AsyncWrite + Unpin + Send> DtxConnection<S> {
442 pub fn new(stream: S) -> Self {
443 Self {
445 stream,
446 identifier: 5,
447 channel_counter: 1,
448 pending_replies: HashMap::new(),
449 outstanding_reply_ids: HashSet::new(),
450 queued_messages: VecDeque::new(),
451 fragments: HashMap::new(),
452 }
453 }
454
455 fn next_id(&mut self) -> u32 {
456 let id = self.identifier;
457 self.identifier += 1;
458 id
459 }
460
461 fn next_channel_code(&mut self) -> i32 {
462 let code = self.channel_counter;
463 self.channel_counter += 1;
464 code
465 }
466
467 pub async fn send_raw(&mut self, data: &[u8]) -> Result<(), DtxError> {
468 self.stream.write_all(data).await?;
469 self.stream.flush().await?;
470 Ok(())
471 }
472
473 pub async fn send_ack(&mut self, msg: &DtxMessage) -> Result<(), DtxError> {
474 self.send_raw(&encode_ack(msg)).await
475 }
476
477 fn buffer_reply(&mut self, msg: DtxMessage) {
478 if let Some(previous) = self.pending_replies.insert(msg.identifier, msg.clone()) {
479 tracing::trace!(
480 "buffer_reply: replacing pending reply id={} old_conv={} new_conv={}",
481 previous.identifier,
482 previous.conversation_idx,
483 msg.conversation_idx
484 );
485 }
486 }
487
488 fn is_reply_message(&self, msg: &DtxMessage) -> bool {
489 msg.conversation_idx > 0 && self.outstanding_reply_ids.contains(&msg.identifier)
490 }
491
492 async fn recv_from_stream(&mut self) -> Result<DtxMessage, DtxError> {
493 loop {
494 let h = read_dtx_header(&mut self.stream).await?;
495 tracing::trace!(
496 "recv: frag_idx={} frag_cnt={} msg_len={} id={}",
497 h.frag_idx,
498 h.frag_cnt,
499 h.msg_len,
500 h.identifier
501 );
502
503 if h.frag_cnt <= 1 {
504 if h.msg_len == 0 {
506 return Ok(DtxMessage {
507 identifier: h.identifier,
508 conversation_idx: h.conv_idx,
509 channel_code: normalize_incoming_channel_code(h.channel_code, h.conv_idx),
510 expects_reply: h.expects_reply,
511 payload: DtxPayload::Empty,
512 });
513 }
514 let mut msg = read_dtx_body(&mut self.stream, &h, &[]).await?;
515 msg.channel_code =
516 normalize_incoming_channel_code(msg.channel_code, msg.conversation_idx);
517 return Ok(msg);
518 }
519
520 if h.frag_idx == 0 {
521 if self.fragments.contains_key(&h.identifier) {
523 return Err(DtxError::Protocol(format!(
524 "duplicate first fragment for id={}",
525 h.identifier
526 )));
527 }
528 self.fragments.insert(
529 h.identifier,
530 FragmentAccum {
531 fragments: vec![None; (h.frag_cnt - 1) as usize],
532 remaining: h.frag_cnt - 1,
533 header: h,
534 },
535 );
536 continue;
537 }
538
539 let mut frag_body = vec![0u8; h.msg_len];
541 self.stream.read_exact(&mut frag_body).await?;
542
543 let id = h.identifier;
544 if let Some(accum) = self.fragments.get_mut(&id) {
545 if h.frag_cnt != accum.header.frag_cnt {
546 return Err(DtxError::Protocol(format!(
547 "fragment count mismatch for id={id}: got={} expected={}",
548 h.frag_cnt, accum.header.frag_cnt
549 )));
550 }
551 if h.conv_idx != accum.header.conv_idx
552 || h.channel_code != accum.header.channel_code
553 || h.expects_reply != accum.header.expects_reply
554 {
555 return Err(DtxError::Protocol(format!(
556 "fragment metadata mismatch for id={id}"
557 )));
558 }
559 let slot_idx = h
560 .frag_idx
561 .checked_sub(1)
562 .map(|idx| idx as usize)
563 .ok_or_else(|| {
564 DtxError::Protocol(format!(
565 "invalid fragment index {} for id={id}",
566 h.frag_idx
567 ))
568 })?;
569 let slot = accum.fragments.get_mut(slot_idx).ok_or_else(|| {
570 DtxError::Protocol(format!(
571 "fragment index {} out of range for id={id}",
572 h.frag_idx
573 ))
574 })?;
575 if slot.is_some() {
576 return Err(DtxError::Protocol(format!(
577 "duplicate fragment {} for id={id}",
578 h.frag_idx
579 )));
580 }
581 *slot = Some(frag_body);
582 accum.remaining -= 1;
583 if accum.remaining == 0 {
584 let accum = self.fragments.remove(&id).ok_or_else(|| {
585 DtxError::Protocol(format!("missing fragment accumulator for id={id}"))
586 })?;
587 let mut body = Vec::with_capacity(accum.header.msg_len);
588 for (index, fragment) in accum.fragments.into_iter().enumerate() {
589 let fragment = fragment.ok_or_else(|| {
590 DtxError::Protocol(format!(
591 "missing fragment {} for id={id}",
592 index + 1
593 ))
594 })?;
595 body.extend_from_slice(&fragment);
596 }
597 if body.len() != accum.header.msg_len {
598 return Err(DtxError::Protocol(format!(
599 "fragmented body size mismatch for id={id}: assembled={} expected={}",
600 body.len(),
601 accum.header.msg_len
602 )));
603 }
604 let mut msg = read_dtx_body(&mut self.stream, &accum.header, &body).await?;
605 msg.channel_code =
606 normalize_incoming_channel_code(msg.channel_code, msg.conversation_idx);
607 return Ok(msg);
608 }
609 } else {
610 return Err(DtxError::Protocol(format!(
611 "fragment id={id} frag_idx={} without first fragment",
612 h.frag_idx
613 )));
614 }
615 }
616 }
617
618 async fn wait_for_reply(&mut self, id: u32) -> Result<DtxMessage, DtxError> {
619 if let Some(msg) = self.pending_replies.remove(&id) {
620 self.outstanding_reply_ids.remove(&id);
621 return Ok(msg);
622 }
623
624 loop {
625 let msg = self.recv_from_stream().await?;
626 tracing::trace!(
627 "wait_for_reply: target_id={} recv id={} conv_idx={} ch={} expects_reply={}",
628 id,
629 msg.identifier,
630 msg.conversation_idx,
631 msg.channel_code,
632 msg.expects_reply
633 );
634
635 if self.is_reply_message(&msg) {
636 if msg.identifier == id {
637 self.outstanding_reply_ids.remove(&id);
638 return Ok(msg);
639 }
640 self.buffer_reply(msg);
641 continue;
642 }
643
644 if msg.expects_reply {
645 self.send_ack(&msg).await?;
646 }
647 self.queued_messages.push_back(msg);
648 }
649 }
650
651 pub async fn recv(&mut self) -> Result<DtxMessage, DtxError> {
653 if let Some(msg) = self.queued_messages.pop_front() {
654 return Ok(msg);
655 }
656
657 loop {
658 let msg = self.recv_from_stream().await?;
659 if self.is_reply_message(&msg) {
660 self.buffer_reply(msg);
661 continue;
662 }
663 return Ok(msg);
664 }
665 }
666
667 pub async fn request_channel(&mut self, service_name: &str) -> Result<i32, DtxError> {
670 let channel_code = self.next_channel_code();
671 let id = self.next_id();
672
673 let selector =
674 nskeyedarchiver_encode::archive_string("_requestChannelWithCode:identifier:");
675 let arg_name = nskeyedarchiver_encode::archive_string(service_name);
676
677 let aux = encode_primitive_dict(&[
679 PrimArg::Int32(channel_code),
680 PrimArg::Bytes(Bytes::from(arg_name)),
681 ]);
682
683 let frame = encode_dtx(id, 0, 0, true, MSG_METHOD_INVOCATION, &selector, &aux);
684 self.send_raw(&frame).await?;
685 self.outstanding_reply_ids.insert(id);
686
687 let msg = self.wait_for_reply(id).await?;
689 tracing::debug!(
690 "request_channel recv: id={} conv_idx={} ch={} expects_reply={}",
691 msg.identifier,
692 msg.conversation_idx,
693 msg.channel_code,
694 msg.expects_reply
695 );
696 Ok(channel_code)
697 }
698
699 pub async fn method_call(
701 &mut self,
702 channel_code: i32,
703 selector: &str,
704 args: &[PrimArg],
705 ) -> Result<DtxMessage, DtxError> {
706 let id = self.next_id();
707 let sel_bytes = nskeyedarchiver_encode::archive_string(selector);
708 let aux = if args.is_empty() {
709 Bytes::new()
710 } else {
711 encode_primitive_dict(args)
712 };
713 let frame = encode_dtx(
714 id,
715 0,
716 channel_code,
717 true,
718 MSG_METHOD_INVOCATION,
719 &sel_bytes,
720 &aux,
721 );
722 self.send_raw(&frame).await?;
723 self.outstanding_reply_ids.insert(id);
724 tracing::debug!("method_call '{selector}' id={id} ch={channel_code}");
725
726 let msg = self.wait_for_reply(id).await?;
727 tracing::debug!(
728 "method_call recv: id={} conv_idx={} ch={}",
729 msg.identifier,
730 msg.conversation_idx,
731 msg.channel_code
732 );
733 Ok(msg)
734 }
735
736 pub async fn method_call_async(
738 &mut self,
739 channel_code: i32,
740 selector: &str,
741 args: &[PrimArg],
742 ) -> Result<(), DtxError> {
743 let id = self.next_id();
744 let sel_bytes = nskeyedarchiver_encode::archive_string(selector);
745 let aux = if args.is_empty() {
746 Bytes::new()
747 } else {
748 encode_primitive_dict(args)
749 };
750 let frame = encode_dtx(
751 id,
752 0,
753 channel_code,
754 false,
755 MSG_METHOD_INVOCATION,
756 &sel_bytes,
757 &aux,
758 );
759 self.send_raw(&frame).await
760 }
761}
762
/// Normalizes the channel code of an incoming message.
///
/// Messages with an even conversation index have their channel code negated;
/// messages with an odd conversation index keep it unchanged.
///
/// The channel code comes straight from the wire, so the negation wraps:
/// `i32::MIN` has no positive counterpart and is returned unchanged rather
/// than panicking in overflow-checked builds.
fn normalize_incoming_channel_code(channel_code: i32, conversation_idx: u32) -> i32 {
    if conversation_idx % 2 == 0 {
        channel_code.wrapping_neg()
    } else {
        channel_code
    }
}
770
#[cfg(test)]
mod tests {
    use bytes::BufMut;

    use super::*;

    /// The frame header and the message type land at their fixed offsets.
    #[test]
    fn test_encode_dtx_layout() {
        let sel = nskeyedarchiver_encode::archive_string("test");
        let frame = encode_dtx(1, 0, 1, true, MSG_METHOD_INVOCATION, &sel, &[]);
        assert_eq!(
            u32::from_be_bytes(frame[0..4].try_into().unwrap()),
            DTX_MAGIC
        );
        assert_eq!(u32::from_le_bytes(frame[4..8].try_into().unwrap()), 32);
        assert_eq!(u32::from_le_bytes(frame[28..32].try_into().unwrap()), 1);
        assert_eq!(
            u32::from_le_bytes(frame[32..36].try_into().unwrap()),
            MSG_METHOD_INVOCATION
        );
    }

    /// An acknowledgement is exactly header + payload header, typed `MSG_OK`.
    #[test]
    fn test_encode_ack_length() {
        let msg = DtxMessage {
            identifier: 5,
            conversation_idx: 0,
            channel_code: 1,
            expects_reply: true,
            payload: DtxPayload::Empty,
        };
        let ack = encode_ack(&msg);
        assert_eq!(ack.len(), 48);
        assert_eq!(u32::from_le_bytes(ack[32..36].try_into().unwrap()), MSG_OK);
    }

    #[tokio::test]
    async fn test_dtx_encode_decode_roundtrip() {
        let sel = nskeyedarchiver_encode::archive_string("setConfig:");
        let frame = encode_dtx(7, 0, 2, true, MSG_METHOD_INVOCATION, &sel, &[]);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();
        assert_eq!(msg.identifier, 7);
        assert_eq!(msg.channel_code, 2);
        assert!(msg.expects_reply);
        if let DtxPayload::MethodInvocation { selector, .. } = &msg.payload {
            assert_eq!(selector, "setConfig:");
        } else {
            panic!("expected MethodInvocation");
        }
    }

    #[tokio::test]
    async fn test_data_frame_keeps_raw_payload() {
        let payload = b"trace-binary-payload";
        let frame = encode_dtx(11, 0, 4, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();
        match msg.payload {
            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
            other => panic!("expected raw payload, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_data_frame_preserves_auxiliary_arguments() {
        let payload = b"trace-binary-payload";
        let aux =
            encode_primitive_dict(&[PrimArg::Bytes(Bytes::from_static(b"kperf-aux-payload"))]);
        let frame = encode_dtx(13, 0, 4, false, MSG_UNKNOWN_TYPE_ONE, payload, &aux);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();

        match msg.payload {
            DtxPayload::RawWithAux { payload: body, aux } => {
                assert_eq!(body.as_ref(), payload);
                assert!(matches!(
                    aux.first(),
                    Some(NSObject::Data(bytes)) if bytes.as_ref() == b"kperf-aux-payload"
                ));
            }
            other => panic!("expected raw payload with aux, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_method_invocation_preserves_raw_payload_when_selector_decode_fails() {
        let payload = b"not-a-selector";
        let frame = encode_dtx(12, 0, 4, false, MSG_METHOD_INVOCATION, payload, &[]);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();

        match msg.payload {
            DtxPayload::MethodInvocation { selector, args } => {
                assert!(selector.is_empty());
                assert!(
                    matches!(args.first(), Some(NSObject::Data(bytes)) if bytes.as_ref() == payload)
                );
            }
            other => panic!("expected method invocation, got {other:?}"),
        }
    }

    /// A notification that arrives while a call is waiting for its reply is
    /// acknowledged, queued, and later handed out by `recv`.
    #[tokio::test]
    async fn test_method_call_buffers_unrelated_notifications() {
        let (client, mut server) = tokio::io::duplex(4096);
        let mut conn = DtxConnection::new(client);

        let call = tokio::spawn(async move {
            conn.method_call(2, "startSampling", &[])
                .await
                .map(|reply| (conn, reply))
        });

        let outbound = read_dtx_frame(&mut server).await.unwrap();
        assert_eq!(outbound.identifier, 5);
        assert!(outbound.expects_reply);

        let notify_selector = nskeyedarchiver_encode::archive_string("note:");
        let notify = encode_dtx(77, 0, 1, true, MSG_METHOD_INVOCATION, &notify_selector, &[]);
        server.write_all(&notify).await.unwrap();

        let ack = read_dtx_frame(&mut server).await.unwrap();
        assert_eq!(ack.identifier, 77);
        assert_eq!(ack.conversation_idx, 1);

        let reply = encode_dtx(5, 1, 2, false, MSG_RESPONSE, &[], &[]);
        server.write_all(&reply).await.unwrap();

        let (mut conn, reply) = call.await.unwrap().unwrap();
        assert_eq!(reply.identifier, 5);
        assert_eq!(reply.conversation_idx, 1);

        let queued = conn.recv().await.unwrap();
        assert_eq!(queued.identifier, 77);
        assert_eq!(queued.channel_code, -1);
        assert!(queued.expects_reply);
    }

    #[tokio::test]
    async fn test_recv_normalizes_even_conversation_channel_codes() {
        let (client, mut server) = tokio::io::duplex(256);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv().await });

        let payload = b"trace-binary-payload";
        let frame = encode_dtx(42, 0, -2, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]);
        server.write_all(&frame).await.unwrap();

        let msg = recv_task.await.unwrap().unwrap();
        assert_eq!(msg.identifier, 42);
        assert_eq!(msg.conversation_idx, 0);
        assert_eq!(msg.channel_code, 2);
        match msg.payload {
            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
            other => panic!("expected raw payload, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_wait_for_reply_returns_buffered_reply_immediately() {
        let (client, _server) = tokio::io::duplex(64);
        let mut conn = DtxConnection::new(client);

        conn.buffer_reply(DtxMessage {
            identifier: 9,
            conversation_idx: 1,
            channel_code: 3,
            expects_reply: false,
            payload: DtxPayload::Empty,
        });

        let reply = conn.wait_for_reply(9).await.unwrap();
        assert_eq!(reply.identifier, 9);
        assert_eq!(reply.conversation_idx, 1);
        assert_eq!(reply.channel_code, 3);
    }

    #[tokio::test]
    async fn test_recv_treats_unsolicited_conversation_message_as_live_event() {
        let (client, mut server) = tokio::io::duplex(256);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv().await });

        let payload = b"trace-binary-payload";
        let frame = encode_dtx(42, 1, 2, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]);
        server.write_all(&frame).await.unwrap();

        let msg = recv_task.await.unwrap().unwrap();
        assert_eq!(msg.identifier, 42);
        assert_eq!(msg.conversation_idx, 1);
        assert_eq!(msg.channel_code, 2);
        match msg.payload {
            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
            other => panic!("expected raw payload, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_read_dtx_frame_skips_extended_header_bytes() {
        let payload = b"trace-binary-payload";
        let mut frame = encode_dtx(21, 0, 4, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]).to_vec();

        // Declare a 36-byte header and insert the four extension bytes.
        frame[4..8].copy_from_slice(&36u32.to_le_bytes());
        frame.splice(32..32, [0xAA, 0xBB, 0xCC, 0xDD]);

        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur)
            .await
            .expect("extended headers should be skipped before parsing payload");

        match msg.payload {
            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
            other => panic!("expected raw payload, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_ok_reply_decodes_as_empty_payload() {
        let frame = encode_dtx(23, 1, 4, false, MSG_OK, &[], &[]);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();
        assert!(matches!(msg.payload, DtxPayload::Empty));
    }

    #[tokio::test]
    async fn test_error_reply_decodes_like_response_object() {
        let payload = nskeyedarchiver_encode::archive_string("selector failed");
        let frame = encode_dtx(24, 1, 4, false, MSG_ERROR, &payload, &[]);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();
        assert!(matches!(
            msg.payload,
            DtxPayload::Response(NSObject::String(ref value)) if value == "selector failed"
        ));
    }

    /// Builds one raw fragment frame: a 32-byte header with conversation
    /// index 0, followed by `body`.
    fn encode_fragment(
        identifier: u32,
        frag_idx: u16,
        frag_cnt: u16,
        channel_code: i32,
        expects_reply: bool,
        msg_len: usize,
        body: &[u8],
    ) -> Bytes {
        let mut out = BytesMut::with_capacity(32 + body.len());
        out.put_u32(DTX_MAGIC);
        out.put_u32_le(32);
        out.put_u16_le(frag_idx);
        out.put_u16_le(frag_cnt);
        out.put_u32_le(msg_len as u32);
        out.put_u32_le(identifier);
        out.put_u32_le(0);
        out.put_u32_le(channel_code as u32);
        out.put_u32_le(if expects_reply { 1 } else { 0 });
        out.extend_from_slice(body);
        out.freeze()
    }

    /// Builds a complete message body of type `MSG_UNKNOWN_TYPE_ONE` with no
    /// auxiliary section: a 16-byte payload header followed by `payload`.
    fn raw_message_body(payload: &[u8]) -> Bytes {
        let mut body = BytesMut::with_capacity(16 + payload.len());
        body.put_u32_le(MSG_UNKNOWN_TYPE_ONE);
        body.put_u32_le(0);
        body.put_u32_le(payload.len() as u32);
        body.put_u32_le(0);
        body.extend_from_slice(payload);
        body.freeze()
    }

    #[tokio::test]
    async fn test_recv_reassembles_out_of_order_fragments_by_index() {
        let payload = b"fragmented-trace-payload";
        let body = raw_message_body(payload);

        let split_at = 10;
        let first = encode_fragment(31, 0, 3, 4, false, body.len(), &[]);
        let second = encode_fragment(31, 1, 3, 4, false, split_at, &body[..split_at]);
        let third = encode_fragment(31, 2, 3, 4, false, body.len() - split_at, &body[split_at..]);

        let (client, mut server) = tokio::io::duplex(512);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
        server.write_all(&first).await.unwrap();
        server.write_all(&third).await.unwrap();
        server.write_all(&second).await.unwrap();

        let msg = recv_task
            .await
            .unwrap()
            .expect("fragment order should not affect reassembly");

        match msg.payload {
            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
            other => panic!("expected raw payload, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_recv_rejects_duplicate_first_fragment() {
        let body = raw_message_body(b"fragmented-trace-payload");

        let first = encode_fragment(41, 0, 2, 4, false, body.len(), &[]);
        let duplicate_first = encode_fragment(41, 0, 2, 4, false, body.len(), &[]);

        let (client, mut server) = tokio::io::duplex(512);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
        server.write_all(&first).await.unwrap();
        server.write_all(&duplicate_first).await.unwrap();

        let err = recv_task.await.unwrap().unwrap_err();
        assert!(matches!(
            err,
            DtxError::Protocol(message) if message.contains("duplicate first fragment")
        ));
    }

    #[tokio::test]
    async fn test_recv_rejects_fragment_without_first_fragment() {
        let body = raw_message_body(b"fragmented-trace-payload");

        let stray = encode_fragment(43, 1, 2, 4, false, body.len(), &body);

        let (client, mut server) = tokio::io::duplex(512);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
        server.write_all(&stray).await.unwrap();

        let err = recv_task.await.unwrap().unwrap_err();
        assert!(matches!(
            err,
            DtxError::Protocol(message) if message.contains("without first fragment")
        ));
    }

    #[tokio::test]
    async fn test_recv_rejects_fragment_metadata_mismatch() {
        let body = raw_message_body(b"fragmented-trace-payload");

        let split_at = 10;
        let first = encode_fragment(45, 0, 3, 4, false, body.len(), &[]);
        // Same message id, but a different channel code than the first fragment.
        let bad_second = encode_fragment(45, 1, 3, 5, false, split_at, &body[..split_at]);

        let (client, mut server) = tokio::io::duplex(512);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
        server.write_all(&first).await.unwrap();
        server.write_all(&bad_second).await.unwrap();

        let err = recv_task.await.unwrap().unwrap_err();
        assert!(matches!(
            err,
            DtxError::Protocol(message) if message.contains("fragment metadata mismatch")
        ));
    }
}