1use std::collections::{HashMap, HashSet, VecDeque};
13
14use crate::proto::dtx::DTX_MAGIC;
15use crate::proto::nskeyedarchiver_encode;
16use bytes::{BufMut, Bytes, BytesMut};
17use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
18
19use super::primitive_enc::{encode_primitive_dict, PrimArg};
20use super::types::{DtxMessage, DtxPayload, NSObject};
21
/// Largest body length (in bytes) a single DTX frame header may declare: 128 MiB.
const MAX_DTX_MESSAGE_SIZE: usize = 1 << 27;
/// Largest fragment count a DTX frame header may declare.
const MAX_DTX_FRAGMENTS: u16 = 1024;
26
27const MSG_OK: u32 = 0;
28const MSG_UNKNOWN_TYPE_ONE: u32 = 1; const MSG_METHOD_INVOCATION: u32 = 2;
30const MSG_RESPONSE: u32 = 3;
31const MSG_ERROR: u32 = 4;
32const MSG_BARRIER: u32 = 5;
33const _MSG_LZ4_COMPRESSED: u32 = 0x0707; #[derive(Debug, thiserror::Error)]
38pub enum DtxError {
39 #[error("IO error: {0}")]
40 Io(#[from] std::io::Error),
41 #[error("bad magic: 0x{0:08X}")]
42 BadMagic(u32),
43 #[error("protocol error: {0}")]
44 Protocol(String),
45}
46
47pub fn encode_dtx(
51 identifier: u32,
52 conv_idx: u32,
53 channel_code: i32,
54 expects_reply: bool,
55 msg_type: u32,
56 payload: &[u8],
57 aux_bytes: &[u8],
58) -> Bytes {
59 let aux_len = aux_bytes.len();
60 let payload_len = payload.len();
61
62 let aux_with_hdr = if aux_len > 0 { aux_len + 16 } else { 0 };
64 let total_payload = aux_with_hdr + payload_len;
65 let msg_len = 16 + aux_with_hdr + payload_len;
66
67 let mut out = BytesMut::with_capacity(32 + msg_len);
68
69 out.put_u32(DTX_MAGIC); out.put_u32_le(32); out.put_u16_le(0); out.put_u16_le(1); out.put_u32_le(msg_len as u32); out.put_u32_le(identifier); out.put_u32_le(conv_idx); out.put_u32_le(channel_code as u32); out.put_u32_le(if expects_reply { 1 } else { 0 }); out.put_u32_le(msg_type);
82 out.put_u32_le(aux_with_hdr as u32);
83 out.put_u32_le(total_payload as u32);
84 out.put_u32_le(0); if aux_len > 0 {
87 out.put_u32_le(496);
89 out.put_u32_le(0);
90 out.put_u32_le(aux_len as u32);
91 out.put_u32_le(0);
92 out.put_slice(aux_bytes);
93 }
94 out.put_slice(payload);
95
96 out.freeze()
97}
98
99pub fn encode_ack(msg: &DtxMessage) -> Bytes {
101 let mut out = BytesMut::with_capacity(48);
102 out.put_u32(DTX_MAGIC);
103 out.put_u32_le(32);
104 out.put_u16_le(0);
105 out.put_u16_le(1);
106 out.put_u32_le(16); out.put_u32_le(msg.identifier);
108 out.put_u32_le(msg.conversation_idx + 1);
109 out.put_u32_le(msg.channel_code as u32);
110 out.put_u32_le(0); out.put_u32_le(MSG_OK);
113 out.put_u32_le(0);
114 out.put_u32_le(0);
115 out.put_u32_le(0);
116 out.freeze()
117}
118
/// Validated fields of a DTX frame header (see `parse_dtx_header`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct DtxHeader {
    /// Total header length in bytes; anything beyond the fixed 32-byte prefix is
    /// an extended header that readers skip.
    header_len: usize,
    /// Zero-based index of this fragment.
    frag_idx: u16,
    /// Number of fragments making up the message.
    frag_cnt: u16,
    /// Length of the body that follows this header. For the header-only first
    /// fragment of a multi-fragment message it is the size of the whole reassembled
    /// body; for later fragments it is the size of that fragment's chunk.
    msg_len: usize,
    /// Message identifier; fragments of one message are grouped by it.
    identifier: u32,
    /// Conversation index (replies carry a non-zero value; see `is_reply_message`).
    conv_idx: u32,
    /// Channel code exactly as it appears on the wire.
    channel_code: i32,
    /// Whether the sender asked for a reply/acknowledgement.
    expects_reply: bool,
}
132
133fn parse_dtx_header(hdr: &[u8; 32]) -> Result<DtxHeader, DtxError> {
135 let magic = u32::from_be_bytes(hdr[0..4].try_into().unwrap());
136 if magic != DTX_MAGIC {
137 return Err(DtxError::BadMagic(magic));
138 }
139 let header_len = u32::from_le_bytes(hdr[4..8].try_into().unwrap()) as usize;
140 if header_len < 32 {
141 return Err(DtxError::Protocol(format!(
142 "invalid DTX header length: {header_len}"
143 )));
144 }
145 let frag_idx = u16::from_le_bytes(hdr[8..10].try_into().unwrap());
146 let frag_cnt = u16::from_le_bytes(hdr[10..12].try_into().unwrap());
147 if frag_cnt == 0 {
148 return Err(DtxError::Protocol("invalid DTX fragment count: 0".into()));
149 }
150 if frag_cnt > MAX_DTX_FRAGMENTS {
151 return Err(DtxError::Protocol(format!(
152 "DTX message has too many fragments: {frag_cnt} exceeds {MAX_DTX_FRAGMENTS}"
153 )));
154 }
155 if frag_idx >= frag_cnt {
156 return Err(DtxError::Protocol(format!(
157 "invalid DTX fragment index {frag_idx} for count {frag_cnt}"
158 )));
159 }
160 let msg_len = u32::from_le_bytes(hdr[12..16].try_into().unwrap()) as usize;
161 if msg_len > MAX_DTX_MESSAGE_SIZE {
162 return Err(DtxError::Protocol(format!(
163 "DTX message length {msg_len} exceeds max {MAX_DTX_MESSAGE_SIZE}"
164 )));
165 }
166 if frag_cnt > 1 && frag_idx == 0 && msg_len == 0 {
167 return Err(DtxError::Protocol(
168 "multi-fragment first header declares zero total size".into(),
169 ));
170 }
171 Ok(DtxHeader {
172 header_len,
173 frag_idx,
174 frag_cnt,
175 msg_len,
176 identifier: u32::from_le_bytes(hdr[16..20].try_into().unwrap()),
177 conv_idx: u32::from_le_bytes(hdr[20..24].try_into().unwrap()),
178 channel_code: i32::from_le_bytes(hdr[24..28].try_into().unwrap()),
179 expects_reply: u32::from_le_bytes(hdr[28..32].try_into().unwrap()) != 0,
180 })
181}
182
183async fn read_dtx_header<R: AsyncRead + Unpin>(reader: &mut R) -> Result<DtxHeader, DtxError> {
184 let mut hdr = [0u8; 32];
185 reader.read_exact(&mut hdr).await?;
186 let header = parse_dtx_header(&hdr)?;
187 if header.header_len > 32 {
188 let mut extra = vec![0u8; header.header_len - 32];
189 reader.read_exact(&mut extra).await?;
190 }
191 Ok(header)
192}
193
194fn decode_dtx_body_from_slice(h: &DtxHeader, body_slice: &[u8]) -> Result<DtxMessage, DtxError> {
195 if body_slice.len() < 16 {
196 return Ok(DtxMessage {
197 identifier: h.identifier,
198 conversation_idx: h.conv_idx,
199 channel_code: h.channel_code,
200 expects_reply: h.expects_reply,
201 payload: DtxPayload::Empty,
202 });
203 }
204
205 let ph = &body_slice[0..16];
206 let msg_type = u32::from_le_bytes(ph[0..4].try_into().unwrap());
207 let aux_len = u32::from_le_bytes(ph[4..8].try_into().unwrap()) as usize;
208 let total_pay = u32::from_le_bytes(ph[8..12].try_into().unwrap()) as usize;
209
210 if aux_len > total_pay {
211 return Err(DtxError::Protocol(format!(
212 "aux_len ({aux_len}) exceeds total_pay ({total_pay})"
213 )));
214 }
215 let pay_len = total_pay - aux_len;
216 let rest = &body_slice[16..];
217
218 let aux_data = if aux_len > 0 {
219 if rest.len() < 16 {
220 return Err(DtxError::Protocol("aux header truncated".into()));
221 }
222 let actual_aux = u32::from_le_bytes(rest[8..12].try_into().unwrap()) as usize;
223 if actual_aux > aux_len.saturating_sub(16) {
224 return Err(DtxError::Protocol(format!(
225 "auxiliary data size ({actual_aux}) exceeds available space ({})",
226 aux_len.saturating_sub(16)
227 )));
228 }
229 let aux_start = 16;
230 let aux_end = aux_start + actual_aux;
231 if rest.len() < aux_end {
232 return Err(DtxError::Protocol("aux data truncated".into()));
233 }
234 Some(Bytes::copy_from_slice(&rest[aux_start..aux_end]))
235 } else {
236 None
237 };
238
239 let pay_start = aux_len;
240 let pay_end = pay_start + pay_len;
241 let payload_bytes = if pay_len > 0 && rest.len() >= pay_end {
242 Bytes::copy_from_slice(&rest[pay_start..pay_end])
243 } else {
244 Bytes::new()
245 };
246
247 let payload = decode_payload(msg_type, payload_bytes, aux_data);
248 Ok(DtxMessage {
249 identifier: h.identifier,
250 conversation_idx: h.conv_idx,
251 channel_code: h.channel_code,
252 expects_reply: h.expects_reply,
253 payload,
254 })
255}
256
257pub fn decode_dtx_message_from_bytes(data: &[u8]) -> Result<Option<(DtxMessage, usize)>, DtxError> {
258 if data.len() < 32 {
259 return Ok(None);
260 }
261
262 let header_bytes: &[u8; 32] = data[..32]
263 .try_into()
264 .map_err(|_| DtxError::Protocol("DTX header truncated".into()))?;
265 let header = parse_dtx_header(header_bytes)?;
266 let total_len = header
267 .header_len
268 .checked_add(header.msg_len)
269 .ok_or_else(|| DtxError::Protocol("DTX frame length overflow".into()))?;
270 if data.len() < total_len {
271 return Ok(None);
272 }
273
274 let body = &data[header.header_len..total_len];
275 let message = decode_dtx_body_from_slice(&header, body)?;
276 Ok(Some((message, total_len)))
277}
278
279async fn read_dtx_body<R: AsyncRead + Unpin>(
282 reader: &mut R,
283 h: &DtxHeader,
284 body: &[u8], ) -> Result<DtxMessage, DtxError> {
286 let body_owned: Vec<u8>;
288 let body_slice: &[u8] = if body.is_empty() && h.msg_len > 0 {
289 body_owned = {
290 let mut b = vec![0u8; h.msg_len];
291 reader.read_exact(&mut b).await?;
292 b
293 };
294 &body_owned
295 } else {
296 body
297 };
298
299 decode_dtx_body_from_slice(h, body_slice)
300}
301
302pub async fn read_dtx_frame<R: AsyncRead + Unpin>(reader: &mut R) -> Result<DtxMessage, DtxError> {
303 let h = read_dtx_header(reader).await?;
304 tracing::trace!(
305 "read_dtx_frame: frag_idx={} frag_cnt={} msg_len={} id={}",
306 h.frag_idx,
307 h.frag_cnt,
308 h.msg_len,
309 h.identifier
310 );
311
312 if h.frag_cnt > 1 && h.frag_idx == 0 {
314 return Ok(DtxMessage {
315 identifier: h.identifier,
316 conversation_idx: h.conv_idx,
317 channel_code: h.channel_code,
318 expects_reply: h.expects_reply,
319 payload: DtxPayload::Empty,
320 });
321 }
322
323 if h.msg_len == 0 {
324 return Ok(DtxMessage {
325 identifier: h.identifier,
326 conversation_idx: h.conv_idx,
327 channel_code: h.channel_code,
328 expects_reply: h.expects_reply,
329 payload: DtxPayload::Empty,
330 });
331 }
332
333 read_dtx_body(reader, &h, &[]).await
334}
335
336fn decode_payload(msg_type: u32, payload: Bytes, aux: Option<Bytes>) -> DtxPayload {
337 tracing::trace!(
338 "decode_payload: msg_type={msg_type} payload_len={} aux={}",
339 payload.len(),
340 aux.is_some()
341 );
342 match msg_type {
343 MSG_OK => DtxPayload::Empty,
344 MSG_METHOD_INVOCATION => {
345 let mut args = aux
346 .map(super::primitive::decode_auxiliary)
347 .unwrap_or_default();
348 let selector = if payload.is_empty() {
349 String::new()
350 } else {
351 match crate::proto::nskeyedarchiver::unarchive(&payload)
352 .ok()
353 .and_then(|v| v.as_str().map(String::from))
354 {
355 Some(selector) => selector,
356 None => {
357 tracing::debug!(
358 "decode_payload: method invocation payload decode failed, preserving {} raw bytes",
359 payload.len()
360 );
361 args.insert(0, NSObject::Data(payload));
362 String::new()
363 }
364 }
365 };
366 DtxPayload::MethodInvocation { selector, args }
367 }
368 MSG_RESPONSE | MSG_ERROR => {
369 if payload.is_empty() {
370 DtxPayload::Response(NSObject::Null)
371 } else {
372 let obj = crate::proto::nskeyedarchiver::unarchive(&payload)
373 .map(archive_to_ns)
374 .unwrap_or(NSObject::Data(payload));
375 DtxPayload::Response(obj)
376 }
377 }
378 MSG_BARRIER => DtxPayload::Empty,
379 MSG_UNKNOWN_TYPE_ONE => match aux {
380 Some(aux) => DtxPayload::RawWithAux {
381 payload,
382 aux: super::primitive::decode_auxiliary(aux),
383 },
384 None => DtxPayload::Raw(payload),
385 },
386 _ => {
387 if payload.is_empty() {
388 DtxPayload::Empty
389 } else {
390 DtxPayload::Raw(payload)
391 }
392 }
393 }
394}
395
396fn archive_to_ns(v: crate::proto::nskeyedarchiver::ArchiveValue) -> NSObject {
397 use crate::proto::nskeyedarchiver::ArchiveValue;
398 match v {
399 ArchiveValue::Null => NSObject::Null,
400 ArchiveValue::Bool(b) => NSObject::Bool(b),
401 ArchiveValue::Int(n) => NSObject::Int(n),
402 ArchiveValue::Float(f) => NSObject::Double(f),
403 ArchiveValue::String(s) => NSObject::String(s),
404 ArchiveValue::Data(d) => NSObject::Data(d),
405 ArchiveValue::Array(a) => NSObject::Array(a.into_iter().map(archive_to_ns).collect()),
406 ArchiveValue::Dict(d) => {
407 NSObject::Dict(d.into_iter().map(|(k, v)| (k, archive_to_ns(v))).collect())
408 }
409 ArchiveValue::Unknown(s) => NSObject::String(format!("<{s}>")),
410 }
411}
412
413struct FragmentAccum {
417 header: DtxHeader,
419 fragments: Vec<Option<Vec<u8>>>,
421 remaining: u16,
423}
424
425pub struct DtxConnection<S> {
429 stream: S,
430 identifier: u32,
436 channel_counter: i32,
437 pending_replies: HashMap<u32, DtxMessage>,
439 outstanding_reply_ids: HashSet<u32>,
441 queued_messages: VecDeque<DtxMessage>,
443 fragments: HashMap<u32, FragmentAccum>,
445}
446
447impl<S: AsyncRead + AsyncWrite + Unpin + Send> DtxConnection<S> {
448 pub fn new(stream: S) -> Self {
449 Self {
451 stream,
452 identifier: 5,
453 channel_counter: 1,
454 pending_replies: HashMap::new(),
455 outstanding_reply_ids: HashSet::new(),
456 queued_messages: VecDeque::new(),
457 fragments: HashMap::new(),
458 }
459 }
460
461 fn next_id(&mut self) -> u32 {
462 let id = self.identifier;
463 self.identifier += 1;
464 id
465 }
466
467 fn next_channel_code(&mut self) -> i32 {
468 let code = self.channel_counter;
469 self.channel_counter += 1;
470 code
471 }
472
473 pub async fn send_raw(&mut self, data: &[u8]) -> Result<(), DtxError> {
474 self.stream.write_all(data).await?;
475 self.stream.flush().await?;
476 Ok(())
477 }
478
479 pub async fn send_ack(&mut self, msg: &DtxMessage) -> Result<(), DtxError> {
480 self.send_raw(&encode_ack(msg)).await
481 }
482
483 fn buffer_reply(&mut self, msg: DtxMessage) {
484 if let Some(previous) = self.pending_replies.insert(msg.identifier, msg.clone()) {
485 tracing::trace!(
486 "buffer_reply: replacing pending reply id={} old_conv={} new_conv={}",
487 previous.identifier,
488 previous.conversation_idx,
489 msg.conversation_idx
490 );
491 }
492 }
493
494 fn is_reply_message(&self, msg: &DtxMessage) -> bool {
495 msg.conversation_idx > 0 && self.outstanding_reply_ids.contains(&msg.identifier)
496 }
497
498 async fn recv_from_stream(&mut self) -> Result<DtxMessage, DtxError> {
499 loop {
500 let h = read_dtx_header(&mut self.stream).await?;
501 tracing::trace!(
502 "recv: frag_idx={} frag_cnt={} msg_len={} id={}",
503 h.frag_idx,
504 h.frag_cnt,
505 h.msg_len,
506 h.identifier
507 );
508
509 if h.frag_cnt <= 1 {
510 if h.msg_len == 0 {
512 return Ok(DtxMessage {
513 identifier: h.identifier,
514 conversation_idx: h.conv_idx,
515 channel_code: normalize_incoming_channel_code(h.channel_code, h.conv_idx),
516 expects_reply: h.expects_reply,
517 payload: DtxPayload::Empty,
518 });
519 }
520 let mut msg = read_dtx_body(&mut self.stream, &h, &[]).await?;
521 msg.channel_code =
522 normalize_incoming_channel_code(msg.channel_code, msg.conversation_idx);
523 return Ok(msg);
524 }
525
526 if h.frag_idx == 0 {
527 if self.fragments.contains_key(&h.identifier) {
529 return Err(DtxError::Protocol(format!(
530 "duplicate first fragment for id={}",
531 h.identifier
532 )));
533 }
534 self.fragments.insert(
535 h.identifier,
536 FragmentAccum {
537 fragments: vec![None; (h.frag_cnt - 1) as usize],
538 remaining: h.frag_cnt - 1,
539 header: h,
540 },
541 );
542 continue;
543 }
544
545 let mut frag_body = vec![0u8; h.msg_len];
547 self.stream.read_exact(&mut frag_body).await?;
548
549 let id = h.identifier;
550 if let Some(accum) = self.fragments.get_mut(&id) {
551 if h.frag_cnt != accum.header.frag_cnt {
552 return Err(DtxError::Protocol(format!(
553 "fragment count mismatch for id={id}: got={} expected={}",
554 h.frag_cnt, accum.header.frag_cnt
555 )));
556 }
557 if h.conv_idx != accum.header.conv_idx
558 || h.channel_code != accum.header.channel_code
559 || h.expects_reply != accum.header.expects_reply
560 {
561 return Err(DtxError::Protocol(format!(
562 "fragment metadata mismatch for id={id}"
563 )));
564 }
565 let slot_idx = h
566 .frag_idx
567 .checked_sub(1)
568 .map(|idx| idx as usize)
569 .ok_or_else(|| {
570 DtxError::Protocol(format!(
571 "invalid fragment index {} for id={id}",
572 h.frag_idx
573 ))
574 })?;
575 let slot = accum.fragments.get_mut(slot_idx).ok_or_else(|| {
576 DtxError::Protocol(format!(
577 "fragment index {} out of range for id={id}",
578 h.frag_idx
579 ))
580 })?;
581 if slot.is_some() {
582 return Err(DtxError::Protocol(format!(
583 "duplicate fragment {} for id={id}",
584 h.frag_idx
585 )));
586 }
587 *slot = Some(frag_body);
588 accum.remaining -= 1;
589 if accum.remaining == 0 {
590 let accum = self.fragments.remove(&id).ok_or_else(|| {
591 DtxError::Protocol(format!("missing fragment accumulator for id={id}"))
592 })?;
593 let mut body = Vec::with_capacity(accum.header.msg_len);
594 for (index, fragment) in accum.fragments.into_iter().enumerate() {
595 let fragment = fragment.ok_or_else(|| {
596 DtxError::Protocol(format!(
597 "missing fragment {} for id={id}",
598 index + 1
599 ))
600 })?;
601 body.extend_from_slice(&fragment);
602 }
603 if body.len() != accum.header.msg_len {
604 return Err(DtxError::Protocol(format!(
605 "fragmented body size mismatch for id={id}: assembled={} expected={}",
606 body.len(),
607 accum.header.msg_len
608 )));
609 }
610 let mut msg = read_dtx_body(&mut self.stream, &accum.header, &body).await?;
611 msg.channel_code =
612 normalize_incoming_channel_code(msg.channel_code, msg.conversation_idx);
613 return Ok(msg);
614 }
615 } else {
616 return Err(DtxError::Protocol(format!(
617 "fragment id={id} frag_idx={} without first fragment",
618 h.frag_idx
619 )));
620 }
621 }
622 }
623
624 async fn wait_for_reply(&mut self, id: u32) -> Result<DtxMessage, DtxError> {
625 if let Some(msg) = self.pending_replies.remove(&id) {
626 self.outstanding_reply_ids.remove(&id);
627 return Ok(msg);
628 }
629
630 loop {
631 let msg = self.recv_from_stream().await?;
632 tracing::trace!(
633 "wait_for_reply: target_id={} recv id={} conv_idx={} ch={} expects_reply={}",
634 id,
635 msg.identifier,
636 msg.conversation_idx,
637 msg.channel_code,
638 msg.expects_reply
639 );
640
641 if self.is_reply_message(&msg) {
642 if msg.identifier == id {
643 self.outstanding_reply_ids.remove(&id);
644 return Ok(msg);
645 }
646 self.buffer_reply(msg);
647 continue;
648 }
649
650 if msg.expects_reply {
651 self.send_ack(&msg).await?;
652 }
653 self.queued_messages.push_back(msg);
654 }
655 }
656
657 pub async fn recv(&mut self) -> Result<DtxMessage, DtxError> {
659 if let Some(msg) = self.queued_messages.pop_front() {
660 return Ok(msg);
661 }
662
663 loop {
664 let msg = self.recv_from_stream().await?;
665 if self.is_reply_message(&msg) {
666 self.buffer_reply(msg);
667 continue;
668 }
669 return Ok(msg);
670 }
671 }
672
673 pub async fn request_channel(&mut self, service_name: &str) -> Result<i32, DtxError> {
676 let channel_code = self.next_channel_code();
677 let id = self.next_id();
678
679 let selector =
680 nskeyedarchiver_encode::archive_string("_requestChannelWithCode:identifier:");
681 let arg_name = nskeyedarchiver_encode::archive_string(service_name);
682
683 let aux = encode_primitive_dict(&[
685 PrimArg::Int32(channel_code),
686 PrimArg::Bytes(Bytes::from(arg_name)),
687 ]);
688
689 let frame = encode_dtx(id, 0, 0, true, MSG_METHOD_INVOCATION, &selector, &aux);
690 self.send_raw(&frame).await?;
691 self.outstanding_reply_ids.insert(id);
692
693 let msg = self.wait_for_reply(id).await?;
695 tracing::debug!(
696 "request_channel recv: id={} conv_idx={} ch={} expects_reply={}",
697 msg.identifier,
698 msg.conversation_idx,
699 msg.channel_code,
700 msg.expects_reply
701 );
702 Ok(channel_code)
703 }
704
705 pub async fn method_call(
707 &mut self,
708 channel_code: i32,
709 selector: &str,
710 args: &[PrimArg],
711 ) -> Result<DtxMessage, DtxError> {
712 let id = self.next_id();
713 let sel_bytes = nskeyedarchiver_encode::archive_string(selector);
714 let aux = if args.is_empty() {
715 Bytes::new()
716 } else {
717 encode_primitive_dict(args)
718 };
719 let frame = encode_dtx(
720 id,
721 0,
722 channel_code,
723 true,
724 MSG_METHOD_INVOCATION,
725 &sel_bytes,
726 &aux,
727 );
728 self.send_raw(&frame).await?;
729 self.outstanding_reply_ids.insert(id);
730 tracing::debug!("method_call '{selector}' id={id} ch={channel_code}");
731
732 let msg = self.wait_for_reply(id).await?;
733 tracing::debug!(
734 "method_call recv: id={} conv_idx={} ch={}",
735 msg.identifier,
736 msg.conversation_idx,
737 msg.channel_code
738 );
739 Ok(msg)
740 }
741
742 pub async fn method_call_async(
744 &mut self,
745 channel_code: i32,
746 selector: &str,
747 args: &[PrimArg],
748 ) -> Result<(), DtxError> {
749 let id = self.next_id();
750 let sel_bytes = nskeyedarchiver_encode::archive_string(selector);
751 let aux = if args.is_empty() {
752 Bytes::new()
753 } else {
754 encode_primitive_dict(args)
755 };
756 let frame = encode_dtx(
757 id,
758 0,
759 channel_code,
760 false,
761 MSG_METHOD_INVOCATION,
762 &sel_bytes,
763 &aux,
764 );
765 self.send_raw(&frame).await
766 }
767}
768
/// Normalises the channel code of an incoming message.
///
/// Messages with an even conversation index have their channel code negated; odd
/// conversation indices leave it unchanged. Negation wraps, so a peer-supplied
/// `i32::MIN` maps to itself instead of panicking on overflow in debug builds.
fn normalize_incoming_channel_code(channel_code: i32, conversation_idx: u32) -> i32 {
    if conversation_idx % 2 == 0 {
        channel_code.wrapping_neg()
    } else {
        channel_code
    }
}
776
#[cfg(test)]
mod tests {
    use bytes::BufMut;

    use super::*;

    #[test]
    fn test_encode_dtx_layout() {
        let sel = nskeyedarchiver_encode::archive_string("test");
        let frame = encode_dtx(1, 0, 1, true, MSG_METHOD_INVOCATION, &sel, &[]);
        assert_eq!(
            u32::from_be_bytes(frame[0..4].try_into().unwrap()),
            DTX_MAGIC
        );
        assert_eq!(u32::from_le_bytes(frame[4..8].try_into().unwrap()), 32);
        assert_eq!(u32::from_le_bytes(frame[28..32].try_into().unwrap()), 1);
        assert_eq!(
            u32::from_le_bytes(frame[32..36].try_into().unwrap()),
            MSG_METHOD_INVOCATION
        );
    }

    #[test]
    fn test_encode_ack_length() {
        let msg = DtxMessage {
            identifier: 5,
            conversation_idx: 0,
            channel_code: 1,
            expects_reply: true,
            payload: DtxPayload::Empty,
        };
        let ack = encode_ack(&msg);
        assert_eq!(ack.len(), 48);
        assert_eq!(u32::from_le_bytes(ack[32..36].try_into().unwrap()), MSG_OK);
    }

    #[tokio::test]
    async fn test_dtx_encode_decode_roundtrip() {
        let sel = nskeyedarchiver_encode::archive_string("setConfig:");
        let frame = encode_dtx(7, 0, 2, true, MSG_METHOD_INVOCATION, &sel, &[]);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();
        assert_eq!(msg.identifier, 7);
        assert_eq!(msg.channel_code, 2);
        assert!(msg.expects_reply);
        if let DtxPayload::MethodInvocation { selector, .. } = &msg.payload {
            assert_eq!(selector, "setConfig:");
        } else {
            panic!("expected MethodInvocation");
        }
    }

    #[tokio::test]
    async fn test_data_frame_keeps_raw_payload() {
        let payload = b"trace-binary-payload";
        let frame = encode_dtx(11, 0, 4, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();
        match msg.payload {
            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
            other => panic!("expected raw payload, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_data_frame_preserves_auxiliary_arguments() {
        let payload = b"trace-binary-payload";
        let aux =
            encode_primitive_dict(&[PrimArg::Bytes(Bytes::from_static(b"kperf-aux-payload"))]);
        let frame = encode_dtx(13, 0, 4, false, MSG_UNKNOWN_TYPE_ONE, payload, &aux);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();

        match msg.payload {
            DtxPayload::RawWithAux { payload: body, aux } => {
                assert_eq!(body.as_ref(), payload);
                assert!(matches!(
                    aux.first(),
                    Some(NSObject::Data(bytes)) if bytes.as_ref() == b"kperf-aux-payload"
                ));
            }
            other => panic!("expected raw payload with aux, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_method_invocation_preserves_raw_payload_when_selector_decode_fails() {
        let payload = b"not-a-selector";
        let frame = encode_dtx(12, 0, 4, false, MSG_METHOD_INVOCATION, payload, &[]);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();

        match msg.payload {
            DtxPayload::MethodInvocation { selector, args } => {
                assert!(selector.is_empty());
                assert!(
                    matches!(args.first(), Some(NSObject::Data(bytes)) if bytes.as_ref() == payload)
                );
            }
            other => panic!("expected method invocation, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_method_call_buffers_unrelated_notifications() {
        let (client, mut server) = tokio::io::duplex(4096);
        let mut conn = DtxConnection::new(client);

        let call = tokio::spawn(async move {
            conn.method_call(2, "startSampling", &[])
                .await
                .map(|reply| (conn, reply))
        });

        let outbound = read_dtx_frame(&mut server).await.unwrap();
        assert_eq!(outbound.identifier, 5);
        assert!(outbound.expects_reply);

        let notify_selector = nskeyedarchiver_encode::archive_string("note:");
        let notify = encode_dtx(77, 0, 1, true, MSG_METHOD_INVOCATION, &notify_selector, &[]);
        server.write_all(&notify).await.unwrap();

        let ack = read_dtx_frame(&mut server).await.unwrap();
        assert_eq!(ack.identifier, 77);
        assert_eq!(ack.conversation_idx, 1);

        let reply = encode_dtx(5, 1, 2, false, MSG_RESPONSE, &[], &[]);
        server.write_all(&reply).await.unwrap();

        let (mut conn, reply) = call.await.unwrap().unwrap();
        assert_eq!(reply.identifier, 5);
        assert_eq!(reply.conversation_idx, 1);

        let queued = conn.recv().await.unwrap();
        assert_eq!(queued.identifier, 77);
        assert_eq!(queued.channel_code, -1);
        assert!(queued.expects_reply);
    }

    #[tokio::test]
    async fn test_recv_normalizes_even_conversation_channel_codes() {
        let (client, mut server) = tokio::io::duplex(256);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv().await });

        let payload = b"trace-binary-payload";
        let frame = encode_dtx(42, 0, -2, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]);
        server.write_all(&frame).await.unwrap();

        let msg = recv_task.await.unwrap().unwrap();
        assert_eq!(msg.identifier, 42);
        assert_eq!(msg.conversation_idx, 0);
        assert_eq!(msg.channel_code, 2);
        match msg.payload {
            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
            other => panic!("expected raw payload, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_wait_for_reply_returns_buffered_reply_immediately() {
        let (client, _server) = tokio::io::duplex(64);
        let mut conn = DtxConnection::new(client);

        conn.buffer_reply(DtxMessage {
            identifier: 9,
            conversation_idx: 1,
            channel_code: 3,
            expects_reply: false,
            payload: DtxPayload::Empty,
        });

        let reply = conn.wait_for_reply(9).await.unwrap();
        assert_eq!(reply.identifier, 9);
        assert_eq!(reply.conversation_idx, 1);
        assert_eq!(reply.channel_code, 3);
    }

    #[tokio::test]
    async fn test_recv_treats_unsolicited_conversation_message_as_live_event() {
        let (client, mut server) = tokio::io::duplex(256);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv().await });

        let payload = b"trace-binary-payload";
        let frame = encode_dtx(42, 1, 2, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]);
        server.write_all(&frame).await.unwrap();

        let msg = recv_task.await.unwrap().unwrap();
        assert_eq!(msg.identifier, 42);
        assert_eq!(msg.conversation_idx, 1);
        assert_eq!(msg.channel_code, 2);
        match msg.payload {
            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
            other => panic!("expected raw payload, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_read_dtx_frame_skips_extended_header_bytes() {
        let payload = b"trace-binary-payload";
        let mut frame = encode_dtx(21, 0, 4, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]).to_vec();

        // Declare a 36-byte header and splice 4 extension bytes after the fixed prefix.
        frame[4..8].copy_from_slice(&36u32.to_le_bytes());
        frame.splice(32..32, [0xAA, 0xBB, 0xCC, 0xDD]);

        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur)
            .await
            .expect("extended headers should be skipped before parsing payload");

        match msg.payload {
            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
            other => panic!("expected raw payload, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_ok_reply_decodes_as_empty_payload() {
        let frame = encode_dtx(23, 1, 4, false, MSG_OK, &[], &[]);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();
        assert!(matches!(msg.payload, DtxPayload::Empty));
    }

    #[tokio::test]
    async fn test_error_reply_decodes_like_response_object() {
        let payload = nskeyedarchiver_encode::archive_string("selector failed");
        let frame = encode_dtx(24, 1, 4, false, MSG_ERROR, &payload, &[]);
        let mut cur = std::io::Cursor::new(frame);
        let msg = read_dtx_frame(&mut cur).await.unwrap();
        assert!(matches!(
            msg.payload,
            DtxPayload::Response(NSObject::String(ref value)) if value == "selector failed"
        ));
    }

    /// Builds a single fragment frame (32-byte header with conversation index 0,
    /// followed by `body`) with an explicit fragment index, count and `msg_len`.
    fn encode_fragment(
        identifier: u32,
        frag_idx: u16,
        frag_cnt: u16,
        channel_code: i32,
        expects_reply: bool,
        msg_len: usize,
        body: &[u8],
    ) -> Bytes {
        let mut out = BytesMut::with_capacity(32 + body.len());
        out.put_u32(DTX_MAGIC);
        out.put_u32_le(32);
        out.put_u16_le(frag_idx);
        out.put_u16_le(frag_cnt);
        out.put_u32_le(msg_len as u32);
        out.put_u32_le(identifier);
        out.put_u32_le(0);
        out.put_u32_le(channel_code as u32);
        out.put_u32_le(if expects_reply { 1 } else { 0 });
        out.extend_from_slice(body);
        out.freeze()
    }

    /// Builds a complete message body of type `MSG_UNKNOWN_TYPE_ONE`: a payload header
    /// (no auxiliary section) followed by `payload`.
    fn raw_message_body(payload: &[u8]) -> Bytes {
        let mut body = BytesMut::with_capacity(16 + payload.len());
        body.put_u32_le(MSG_UNKNOWN_TYPE_ONE);
        body.put_u32_le(0);
        body.put_u32_le(payload.len() as u32);
        body.put_u32_le(0);
        body.extend_from_slice(payload);
        body.freeze()
    }

    #[tokio::test]
    async fn test_recv_reassembles_out_of_order_fragments_by_index() {
        let payload = b"fragmented-trace-payload";
        let body = raw_message_body(payload);

        let split_at = 10;
        let first = encode_fragment(31, 0, 3, 4, false, body.len(), &[]);
        let second = encode_fragment(31, 1, 3, 4, false, split_at, &body[..split_at]);
        let third = encode_fragment(31, 2, 3, 4, false, body.len() - split_at, &body[split_at..]);

        let (client, mut server) = tokio::io::duplex(512);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
        server.write_all(&first).await.unwrap();
        server.write_all(&third).await.unwrap();
        server.write_all(&second).await.unwrap();

        let msg = recv_task
            .await
            .unwrap()
            .expect("fragment order should not affect reassembly");

        match msg.payload {
            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
            other => panic!("expected raw payload, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_recv_rejects_duplicate_first_fragment() {
        let payload = b"fragmented-trace-payload";
        let body = raw_message_body(payload);

        let first = encode_fragment(41, 0, 2, 4, false, body.len(), &[]);
        let duplicate_first = encode_fragment(41, 0, 2, 4, false, body.len(), &[]);

        let (client, mut server) = tokio::io::duplex(512);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
        server.write_all(&first).await.unwrap();
        server.write_all(&duplicate_first).await.unwrap();

        let err = recv_task.await.unwrap().unwrap_err();
        assert!(matches!(
            err,
            DtxError::Protocol(message) if message.contains("duplicate first fragment")
        ));
    }

    #[tokio::test]
    async fn test_recv_rejects_fragment_without_first_fragment() {
        let payload = b"fragmented-trace-payload";
        let body = raw_message_body(payload);

        let stray = encode_fragment(43, 1, 2, 4, false, body.len(), &body);

        let (client, mut server) = tokio::io::duplex(512);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
        server.write_all(&stray).await.unwrap();

        let err = recv_task.await.unwrap().unwrap_err();
        assert!(matches!(
            err,
            DtxError::Protocol(message) if message.contains("without first fragment")
        ));
    }

    #[tokio::test]
    async fn test_recv_rejects_fragment_metadata_mismatch() {
        let payload = b"fragmented-trace-payload";
        let body = raw_message_body(payload);

        let split_at = 10;
        let first = encode_fragment(45, 0, 3, 4, false, body.len(), &[]);
        let bad_second = encode_fragment(45, 1, 3, 5, false, split_at, &body[..split_at]);

        let (client, mut server) = tokio::io::duplex(512);
        let mut conn = DtxConnection::new(client);

        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
        server.write_all(&first).await.unwrap();
        server.write_all(&bad_second).await.unwrap();

        let err = recv_task.await.unwrap().unwrap_err();
        assert!(matches!(
            err,
            DtxError::Protocol(message) if message.contains("fragment metadata mismatch")
        ));
    }

    #[tokio::test]
    async fn test_recv_rejects_excessive_fragment_count_before_allocation() {
        let first = encode_fragment(72, 0, MAX_DTX_FRAGMENTS + 1, 4, false, 16, &[]);
        let mut cursor = std::io::Cursor::new(first);

        let err = match read_dtx_header(&mut cursor).await {
            Ok(_) => panic!("excessive fragment count should be rejected"),
            Err(err) => err,
        };

        assert!(matches!(
            err,
            DtxError::Protocol(message) if message.contains("too many fragments")
        ));
    }
}