Skip to main content

ios_core/services/dtx/
codec.rs

//! DTX message encoder/decoder and connection manager.
//!
//! Reference: go-ios/ios/dtx_codec/decoder.go + encoder.go + connection.go
//!
//! Key wire format details (from encoder.go):
//! - Header: 32 bytes (magic BE + rest LE)
//! - Payload header: 16 bytes LE
//! - Aux header: 16 bytes LE (buffer_size=496, unknown=0, aux_size, unknown=0)
//! - Message type 0 = OK/Ack, type 1 = raw data (sysmontap data messages),
//!   type 2 = MethodInvocation, type 3 = Response/Object, type 4 = Error,
//!   type 5 = Barrier
11
12use std::collections::{HashMap, HashSet, VecDeque};
13
14use crate::proto::dtx::DTX_MAGIC;
15use crate::proto::nskeyedarchiver_encode;
16use bytes::{BufMut, Bytes, BytesMut};
17use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
18
19use super::primitive_enc::{encode_primitive_dict, PrimArg};
20use super::types::{DtxMessage, DtxPayload, NSObject};
21
// ── DTX size limits and message type constants ───────────────────────────────
23
/// Largest `message_length` accepted from the wire (128 MiB).
const MAX_DTX_MESSAGE_SIZE: usize = 128 << 20;
/// Largest `fragment_count` accepted from the wire.
const MAX_DTX_FRAGMENTS: u16 = 1 << 10;

/// Payload-header message type: OK / acknowledgement.
const MSG_OK: u32 = 0;
/// Payload-header message type 1: sysmontap data messages.
const MSG_UNKNOWN_TYPE_ONE: u32 = 1;
/// Payload-header message type: method invocation.
const MSG_METHOD_INVOCATION: u32 = 2;
/// Payload-header message type: response / object.
const MSG_RESPONSE: u32 = 3;
/// Payload-header message type: error.
const MSG_ERROR: u32 = 4;
/// Payload-header message type: barrier.
const MSG_BARRIER: u32 = 5;
/// Payload-header message type for an LZ4 compressed payload (not referenced in the visible code).
const _MSG_LZ4_COMPRESSED: u32 = 0x0707;
34
35// ── Error ─────────────────────────────────────────────────────────────────────
36
37#[derive(Debug, thiserror::Error)]
38pub enum DtxError {
39    #[error("IO error: {0}")]
40    Io(#[from] std::io::Error),
41    #[error("bad magic: 0x{0:08X}")]
42    BadMagic(u32),
43    #[error("protocol error: {0}")]
44    Protocol(String),
45}
46
47// ── DTX Encoder (matches go-ios encoder.go exactly) ──────────────────────────
48
49/// Encode a full DTX message to bytes.
50pub fn encode_dtx(
51    identifier: u32,
52    conv_idx: u32,
53    channel_code: i32,
54    expects_reply: bool,
55    msg_type: u32,
56    payload: &[u8],
57    aux_bytes: &[u8],
58) -> Bytes {
59    let aux_len = aux_bytes.len();
60    let payload_len = payload.len();
61
62    // aux_length_with_header = aux_len + 16 (header) if aux_len > 0
63    let aux_with_hdr = if aux_len > 0 { aux_len + 16 } else { 0 };
64    let total_payload = aux_with_hdr + payload_len;
65    let msg_len = 16 + aux_with_hdr + payload_len;
66
67    let mut out = BytesMut::with_capacity(32 + msg_len);
68
69    // Header (32 bytes)
70    out.put_u32(DTX_MAGIC); // magic (BE)
71    out.put_u32_le(32); // header_length
72    out.put_u16_le(0); // fragment_index
73    out.put_u16_le(1); // fragment_count
74    out.put_u32_le(msg_len as u32); // message_length
75    out.put_u32_le(identifier); // identifier
76    out.put_u32_le(conv_idx); // conversation_index
77    out.put_u32_le(channel_code as u32); // channel_code
78    out.put_u32_le(if expects_reply { 1 } else { 0 }); // expects_reply
79
80    // Payload header (16 bytes)
81    out.put_u32_le(msg_type);
82    out.put_u32_le(aux_with_hdr as u32);
83    out.put_u32_le(total_payload as u32);
84    out.put_u32_le(0); // flags
85
86    if aux_len > 0 {
87        // Aux header (16 bytes): buffer_size=496 as per go-ios writeAuxHeader
88        out.put_u32_le(496);
89        out.put_u32_le(0);
90        out.put_u32_le(aux_len as u32);
91        out.put_u32_le(0);
92        out.put_slice(aux_bytes);
93    }
94    out.put_slice(payload);
95
96    out.freeze()
97}
98
99/// Encode a DTX ack message (48 bytes total).
100pub fn encode_ack(msg: &DtxMessage) -> Bytes {
101    let mut out = BytesMut::with_capacity(48);
102    out.put_u32(DTX_MAGIC);
103    out.put_u32_le(32);
104    out.put_u16_le(0);
105    out.put_u16_le(1);
106    out.put_u32_le(16); // message_length = 16 (payload header only)
107    out.put_u32_le(msg.identifier);
108    out.put_u32_le(msg.conversation_idx + 1);
109    out.put_u32_le(msg.channel_code as u32);
110    out.put_u32_le(0); // expects_reply = false
111                       // Payload header: type=0 (OK/Ack)
112    out.put_u32_le(MSG_OK);
113    out.put_u32_le(0);
114    out.put_u32_le(0);
115    out.put_u32_le(0);
116    out.freeze()
117}
118
119// ── DTX frame reader ──────────────────────────────────────────────────────────
120
/// Raw DTX header fields parsed from the 32-byte wire header.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so headers can be logged, copied
/// and compared in diagnostics and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
struct DtxHeader {
    /// Declared header length in bytes (wire bytes 4..8).
    header_len: usize,
    /// Index of this fragment (wire bytes 8..10).
    frag_idx: u16,
    /// Total number of fragments in the message (wire bytes 10..12).
    frag_cnt: u16,
    /// Declared message length in bytes (wire bytes 12..16).
    msg_len: usize,
    /// Message identifier (wire bytes 16..20).
    identifier: u32,
    /// Conversation index (wire bytes 20..24).
    conv_idx: u32,
    /// Signed channel code (wire bytes 24..28).
    channel_code: i32,
    /// Whether the sender set the expects-reply word to non-zero (wire bytes 28..32).
    expects_reply: bool,
}
132
133// Safety: hdr is &[u8; 32], so all fixed-size slice conversions are infallible.
134fn parse_dtx_header(hdr: &[u8; 32]) -> Result<DtxHeader, DtxError> {
135    let magic = u32::from_be_bytes(hdr[0..4].try_into().unwrap());
136    if magic != DTX_MAGIC {
137        return Err(DtxError::BadMagic(magic));
138    }
139    let header_len = u32::from_le_bytes(hdr[4..8].try_into().unwrap()) as usize;
140    if header_len < 32 {
141        return Err(DtxError::Protocol(format!(
142            "invalid DTX header length: {header_len}"
143        )));
144    }
145    let frag_idx = u16::from_le_bytes(hdr[8..10].try_into().unwrap());
146    let frag_cnt = u16::from_le_bytes(hdr[10..12].try_into().unwrap());
147    if frag_cnt == 0 {
148        return Err(DtxError::Protocol("invalid DTX fragment count: 0".into()));
149    }
150    if frag_cnt > MAX_DTX_FRAGMENTS {
151        return Err(DtxError::Protocol(format!(
152            "DTX message has too many fragments: {frag_cnt} exceeds {MAX_DTX_FRAGMENTS}"
153        )));
154    }
155    if frag_idx >= frag_cnt {
156        return Err(DtxError::Protocol(format!(
157            "invalid DTX fragment index {frag_idx} for count {frag_cnt}"
158        )));
159    }
160    let msg_len = u32::from_le_bytes(hdr[12..16].try_into().unwrap()) as usize;
161    if msg_len > MAX_DTX_MESSAGE_SIZE {
162        return Err(DtxError::Protocol(format!(
163            "DTX message length {msg_len} exceeds max {MAX_DTX_MESSAGE_SIZE}"
164        )));
165    }
166    if frag_cnt > 1 && frag_idx == 0 && msg_len == 0 {
167        return Err(DtxError::Protocol(
168            "multi-fragment first header declares zero total size".into(),
169        ));
170    }
171    Ok(DtxHeader {
172        header_len,
173        frag_idx,
174        frag_cnt,
175        msg_len,
176        identifier: u32::from_le_bytes(hdr[16..20].try_into().unwrap()),
177        conv_idx: u32::from_le_bytes(hdr[20..24].try_into().unwrap()),
178        channel_code: i32::from_le_bytes(hdr[24..28].try_into().unwrap()),
179        expects_reply: u32::from_le_bytes(hdr[28..32].try_into().unwrap()) != 0,
180    })
181}
182
183async fn read_dtx_header<R: AsyncRead + Unpin>(reader: &mut R) -> Result<DtxHeader, DtxError> {
184    let mut hdr = [0u8; 32];
185    reader.read_exact(&mut hdr).await?;
186    let header = parse_dtx_header(&hdr)?;
187    if header.header_len > 32 {
188        let mut extra = vec![0u8; header.header_len - 32];
189        reader.read_exact(&mut extra).await?;
190    }
191    Ok(header)
192}
193
194fn decode_dtx_body_from_slice(h: &DtxHeader, body_slice: &[u8]) -> Result<DtxMessage, DtxError> {
195    if body_slice.len() < 16 {
196        return Ok(DtxMessage {
197            identifier: h.identifier,
198            conversation_idx: h.conv_idx,
199            channel_code: h.channel_code,
200            expects_reply: h.expects_reply,
201            payload: DtxPayload::Empty,
202        });
203    }
204
205    let ph = &body_slice[0..16];
206    let msg_type = u32::from_le_bytes(ph[0..4].try_into().unwrap());
207    let aux_len = u32::from_le_bytes(ph[4..8].try_into().unwrap()) as usize;
208    let total_pay = u32::from_le_bytes(ph[8..12].try_into().unwrap()) as usize;
209
210    if aux_len > total_pay {
211        return Err(DtxError::Protocol(format!(
212            "aux_len ({aux_len}) exceeds total_pay ({total_pay})"
213        )));
214    }
215    let pay_len = total_pay - aux_len;
216    let rest = &body_slice[16..];
217
218    let aux_data = if aux_len > 0 {
219        if rest.len() < 16 {
220            return Err(DtxError::Protocol("aux header truncated".into()));
221        }
222        let actual_aux = u32::from_le_bytes(rest[8..12].try_into().unwrap()) as usize;
223        if actual_aux > aux_len.saturating_sub(16) {
224            return Err(DtxError::Protocol(format!(
225                "auxiliary data size ({actual_aux}) exceeds available space ({})",
226                aux_len.saturating_sub(16)
227            )));
228        }
229        let aux_start = 16;
230        let aux_end = aux_start + actual_aux;
231        if rest.len() < aux_end {
232            return Err(DtxError::Protocol("aux data truncated".into()));
233        }
234        Some(Bytes::copy_from_slice(&rest[aux_start..aux_end]))
235    } else {
236        None
237    };
238
239    let pay_start = aux_len;
240    let pay_end = pay_start + pay_len;
241    let payload_bytes = if pay_len > 0 && rest.len() >= pay_end {
242        Bytes::copy_from_slice(&rest[pay_start..pay_end])
243    } else {
244        Bytes::new()
245    };
246
247    let payload = decode_payload(msg_type, payload_bytes, aux_data);
248    Ok(DtxMessage {
249        identifier: h.identifier,
250        conversation_idx: h.conv_idx,
251        channel_code: h.channel_code,
252        expects_reply: h.expects_reply,
253        payload,
254    })
255}
256
257pub fn decode_dtx_message_from_bytes(data: &[u8]) -> Result<Option<(DtxMessage, usize)>, DtxError> {
258    if data.len() < 32 {
259        return Ok(None);
260    }
261
262    let header_bytes: &[u8; 32] = data[..32]
263        .try_into()
264        .map_err(|_| DtxError::Protocol("DTX header truncated".into()))?;
265    let header = parse_dtx_header(header_bytes)?;
266    let total_len = header
267        .header_len
268        .checked_add(header.msg_len)
269        .ok_or_else(|| DtxError::Protocol("DTX frame length overflow".into()))?;
270    if data.len() < total_len {
271        return Ok(None);
272    }
273
274    let body = &data[header.header_len..total_len];
275    let message = decode_dtx_body_from_slice(&header, body)?;
276    Ok(Some((message, total_len)))
277}
278
279/// Read a single non-fragmented DTX message body (payload header + aux + payload).
280/// `msg_len` is the number of bytes after the 32-byte header.
281async fn read_dtx_body<R: AsyncRead + Unpin>(
282    reader: &mut R,
283    h: &DtxHeader,
284    body: &[u8], // pre-read body bytes (for reassembled fragments)
285) -> Result<DtxMessage, DtxError> {
286    // body may be pre-supplied (reassembled) or empty (read from stream)
287    let body_owned: Vec<u8>;
288    let body_slice: &[u8] = if body.is_empty() && h.msg_len > 0 {
289        body_owned = {
290            let mut b = vec![0u8; h.msg_len];
291            reader.read_exact(&mut b).await?;
292            b
293        };
294        &body_owned
295    } else {
296        body
297    };
298
299    decode_dtx_body_from_slice(h, body_slice)
300}
301
302pub async fn read_dtx_frame<R: AsyncRead + Unpin>(reader: &mut R) -> Result<DtxMessage, DtxError> {
303    let h = read_dtx_header(reader).await?;
304    tracing::trace!(
305        "read_dtx_frame: frag_idx={} frag_cnt={} msg_len={} id={}",
306        h.frag_idx,
307        h.frag_cnt,
308        h.msg_len,
309        h.identifier
310    );
311
312    // First fragment of multi-fragment message: no body, just a size announcement
313    if h.frag_cnt > 1 && h.frag_idx == 0 {
314        return Ok(DtxMessage {
315            identifier: h.identifier,
316            conversation_idx: h.conv_idx,
317            channel_code: h.channel_code,
318            expects_reply: h.expects_reply,
319            payload: DtxPayload::Empty,
320        });
321    }
322
323    if h.msg_len == 0 {
324        return Ok(DtxMessage {
325            identifier: h.identifier,
326            conversation_idx: h.conv_idx,
327            channel_code: h.channel_code,
328            expects_reply: h.expects_reply,
329            payload: DtxPayload::Empty,
330        });
331    }
332
333    read_dtx_body(reader, &h, &[]).await
334}
335
336fn decode_payload(msg_type: u32, payload: Bytes, aux: Option<Bytes>) -> DtxPayload {
337    tracing::trace!(
338        "decode_payload: msg_type={msg_type} payload_len={} aux={}",
339        payload.len(),
340        aux.is_some()
341    );
342    match msg_type {
343        MSG_OK => DtxPayload::Empty,
344        MSG_METHOD_INVOCATION => {
345            let mut args = aux
346                .map(super::primitive::decode_auxiliary)
347                .unwrap_or_default();
348            let selector = if payload.is_empty() {
349                String::new()
350            } else {
351                match crate::proto::nskeyedarchiver::unarchive(&payload)
352                    .ok()
353                    .and_then(|v| v.as_str().map(String::from))
354                {
355                    Some(selector) => selector,
356                    None => {
357                        tracing::debug!(
358                            "decode_payload: method invocation payload decode failed, preserving {} raw bytes",
359                            payload.len()
360                        );
361                        args.insert(0, NSObject::Data(payload));
362                        String::new()
363                    }
364                }
365            };
366            DtxPayload::MethodInvocation { selector, args }
367        }
368        MSG_RESPONSE | MSG_ERROR => {
369            if payload.is_empty() {
370                DtxPayload::Response(NSObject::Null)
371            } else {
372                let obj = crate::proto::nskeyedarchiver::unarchive(&payload)
373                    .map(archive_to_ns)
374                    .unwrap_or(NSObject::Data(payload));
375                DtxPayload::Response(obj)
376            }
377        }
378        MSG_BARRIER => DtxPayload::Empty,
379        MSG_UNKNOWN_TYPE_ONE => match aux {
380            Some(aux) => DtxPayload::RawWithAux {
381                payload,
382                aux: super::primitive::decode_auxiliary(aux),
383            },
384            None => DtxPayload::Raw(payload),
385        },
386        _ => {
387            if payload.is_empty() {
388                DtxPayload::Empty
389            } else {
390                DtxPayload::Raw(payload)
391            }
392        }
393    }
394}
395
396fn archive_to_ns(v: crate::proto::nskeyedarchiver::ArchiveValue) -> NSObject {
397    use crate::proto::nskeyedarchiver::ArchiveValue;
398    match v {
399        ArchiveValue::Null => NSObject::Null,
400        ArchiveValue::Bool(b) => NSObject::Bool(b),
401        ArchiveValue::Int(n) => NSObject::Int(n),
402        ArchiveValue::Float(f) => NSObject::Double(f),
403        ArchiveValue::String(s) => NSObject::String(s),
404        ArchiveValue::Data(d) => NSObject::Data(d),
405        ArchiveValue::Array(a) => NSObject::Array(a.into_iter().map(archive_to_ns).collect()),
406        ArchiveValue::Dict(d) => {
407            NSObject::Dict(d.into_iter().map(|(k, v)| (k, archive_to_ns(v))).collect())
408        }
409        ArchiveValue::Unknown(s) => NSObject::String(format!("<{s}>")),
410    }
411}
412
413// ── Fragment reassembly state ─────────────────────────────────────────────────
414
415/// In-progress multi-fragment message accumulator.
416struct FragmentAccum {
417    /// Header fields from the first fragment (index=0).
418    header: DtxHeader,
419    /// Body fragments keyed by fragment index - 1.
420    fragments: Vec<Option<Vec<u8>>>,
421    /// Number of body fragments still expected.
422    remaining: u16,
423}
424
425// ── DtxConnection ─────────────────────────────────────────────────────────────
426
427/// A managed DTX connection with channel multiplexing and method call support.
428pub struct DtxConnection<S> {
429    stream: S,
430    /// Connection-wide message identifier counter.
431    ///
432    /// We intentionally keep this global (instead of per-channel) to mirror
433    /// pymobiledevice3's reply-correlation model, where responses are matched
434    /// centrally by identifier and non-target traffic is buffered separately.
435    identifier: u32,
436    channel_counter: i32,
437    /// Replies buffered while another request is waiting on its own response.
438    pending_replies: HashMap<u32, DtxMessage>,
439    /// Identifiers for synchronous requests that are currently awaiting a reply.
440    outstanding_reply_ids: HashSet<u32>,
441    /// Non-reply messages buffered while a request is synchronously awaiting its reply.
442    queued_messages: VecDeque<DtxMessage>,
443    /// In-progress multi-fragment messages keyed by DTX identifier.
444    fragments: HashMap<u32, FragmentAccum>,
445}
446
447impl<S: AsyncRead + AsyncWrite + Unpin + Send> DtxConnection<S> {
448    pub fn new(stream: S) -> Self {
449        // Start identifier at 5 to match go-ios global channel messageIdentifier initial value
450        Self {
451            stream,
452            identifier: 5,
453            channel_counter: 1,
454            pending_replies: HashMap::new(),
455            outstanding_reply_ids: HashSet::new(),
456            queued_messages: VecDeque::new(),
457            fragments: HashMap::new(),
458        }
459    }
460
461    fn next_id(&mut self) -> u32 {
462        let id = self.identifier;
463        self.identifier += 1;
464        id
465    }
466
467    fn next_channel_code(&mut self) -> i32 {
468        let code = self.channel_counter;
469        self.channel_counter += 1;
470        code
471    }
472
473    pub async fn send_raw(&mut self, data: &[u8]) -> Result<(), DtxError> {
474        self.stream.write_all(data).await?;
475        self.stream.flush().await?;
476        Ok(())
477    }
478
479    pub async fn send_ack(&mut self, msg: &DtxMessage) -> Result<(), DtxError> {
480        self.send_raw(&encode_ack(msg)).await
481    }
482
483    fn buffer_reply(&mut self, msg: DtxMessage) {
484        if let Some(previous) = self.pending_replies.insert(msg.identifier, msg.clone()) {
485            tracing::trace!(
486                "buffer_reply: replacing pending reply id={} old_conv={} new_conv={}",
487                previous.identifier,
488                previous.conversation_idx,
489                msg.conversation_idx
490            );
491        }
492    }
493
494    fn is_reply_message(&self, msg: &DtxMessage) -> bool {
495        msg.conversation_idx > 0 && self.outstanding_reply_ids.contains(&msg.identifier)
496    }
497
498    async fn recv_from_stream(&mut self) -> Result<DtxMessage, DtxError> {
499        loop {
500            let h = read_dtx_header(&mut self.stream).await?;
501            tracing::trace!(
502                "recv: frag_idx={} frag_cnt={} msg_len={} id={}",
503                h.frag_idx,
504                h.frag_cnt,
505                h.msg_len,
506                h.identifier
507            );
508
509            if h.frag_cnt <= 1 {
510                // Single-fragment message
511                if h.msg_len == 0 {
512                    return Ok(DtxMessage {
513                        identifier: h.identifier,
514                        conversation_idx: h.conv_idx,
515                        channel_code: normalize_incoming_channel_code(h.channel_code, h.conv_idx),
516                        expects_reply: h.expects_reply,
517                        payload: DtxPayload::Empty,
518                    });
519                }
520                let mut msg = read_dtx_body(&mut self.stream, &h, &[]).await?;
521                msg.channel_code =
522                    normalize_incoming_channel_code(msg.channel_code, msg.conversation_idx);
523                return Ok(msg);
524            }
525
526            if h.frag_idx == 0 {
527                // First fragment: no body, just announces total size
528                if self.fragments.contains_key(&h.identifier) {
529                    return Err(DtxError::Protocol(format!(
530                        "duplicate first fragment for id={}",
531                        h.identifier
532                    )));
533                }
534                self.fragments.insert(
535                    h.identifier,
536                    FragmentAccum {
537                        fragments: vec![None; (h.frag_cnt - 1) as usize],
538                        remaining: h.frag_cnt - 1,
539                        header: h,
540                    },
541                );
542                continue;
543            }
544
545            // Subsequent fragment: read msg_len bytes of body
546            let mut frag_body = vec![0u8; h.msg_len];
547            self.stream.read_exact(&mut frag_body).await?;
548
549            let id = h.identifier;
550            if let Some(accum) = self.fragments.get_mut(&id) {
551                if h.frag_cnt != accum.header.frag_cnt {
552                    return Err(DtxError::Protocol(format!(
553                        "fragment count mismatch for id={id}: got={} expected={}",
554                        h.frag_cnt, accum.header.frag_cnt
555                    )));
556                }
557                if h.conv_idx != accum.header.conv_idx
558                    || h.channel_code != accum.header.channel_code
559                    || h.expects_reply != accum.header.expects_reply
560                {
561                    return Err(DtxError::Protocol(format!(
562                        "fragment metadata mismatch for id={id}"
563                    )));
564                }
565                let slot_idx = h
566                    .frag_idx
567                    .checked_sub(1)
568                    .map(|idx| idx as usize)
569                    .ok_or_else(|| {
570                        DtxError::Protocol(format!(
571                            "invalid fragment index {} for id={id}",
572                            h.frag_idx
573                        ))
574                    })?;
575                let slot = accum.fragments.get_mut(slot_idx).ok_or_else(|| {
576                    DtxError::Protocol(format!(
577                        "fragment index {} out of range for id={id}",
578                        h.frag_idx
579                    ))
580                })?;
581                if slot.is_some() {
582                    return Err(DtxError::Protocol(format!(
583                        "duplicate fragment {} for id={id}",
584                        h.frag_idx
585                    )));
586                }
587                *slot = Some(frag_body);
588                accum.remaining -= 1;
589                if accum.remaining == 0 {
590                    let accum = self.fragments.remove(&id).ok_or_else(|| {
591                        DtxError::Protocol(format!("missing fragment accumulator for id={id}"))
592                    })?;
593                    let mut body = Vec::with_capacity(accum.header.msg_len);
594                    for (index, fragment) in accum.fragments.into_iter().enumerate() {
595                        let fragment = fragment.ok_or_else(|| {
596                            DtxError::Protocol(format!(
597                                "missing fragment {} for id={id}",
598                                index + 1
599                            ))
600                        })?;
601                        body.extend_from_slice(&fragment);
602                    }
603                    if body.len() != accum.header.msg_len {
604                        return Err(DtxError::Protocol(format!(
605                            "fragmented body size mismatch for id={id}: assembled={} expected={}",
606                            body.len(),
607                            accum.header.msg_len
608                        )));
609                    }
610                    let mut msg = read_dtx_body(&mut self.stream, &accum.header, &body).await?;
611                    msg.channel_code =
612                        normalize_incoming_channel_code(msg.channel_code, msg.conversation_idx);
613                    return Ok(msg);
614                }
615            } else {
616                return Err(DtxError::Protocol(format!(
617                    "fragment id={id} frag_idx={} without first fragment",
618                    h.frag_idx
619                )));
620            }
621        }
622    }
623
624    async fn wait_for_reply(&mut self, id: u32) -> Result<DtxMessage, DtxError> {
625        if let Some(msg) = self.pending_replies.remove(&id) {
626            self.outstanding_reply_ids.remove(&id);
627            return Ok(msg);
628        }
629
630        loop {
631            let msg = self.recv_from_stream().await?;
632            tracing::trace!(
633                "wait_for_reply: target_id={} recv id={} conv_idx={} ch={} expects_reply={}",
634                id,
635                msg.identifier,
636                msg.conversation_idx,
637                msg.channel_code,
638                msg.expects_reply
639            );
640
641            if self.is_reply_message(&msg) {
642                if msg.identifier == id {
643                    self.outstanding_reply_ids.remove(&id);
644                    return Ok(msg);
645                }
646                self.buffer_reply(msg);
647                continue;
648            }
649
650            if msg.expects_reply {
651                self.send_ack(&msg).await?;
652            }
653            self.queued_messages.push_back(msg);
654        }
655    }
656
657    /// Receive the next fully-assembled DTX message, transparently reassembling fragments.
658    pub async fn recv(&mut self) -> Result<DtxMessage, DtxError> {
659        if let Some(msg) = self.queued_messages.pop_front() {
660            return Ok(msg);
661        }
662
663        loop {
664            let msg = self.recv_from_stream().await?;
665            if self.is_reply_message(&msg) {
666                self.buffer_reply(msg);
667                continue;
668            }
669            return Ok(msg);
670        }
671    }
672
673    /// Request a DTX channel by service name.
674    /// Returns the assigned channel code.
675    pub async fn request_channel(&mut self, service_name: &str) -> Result<i32, DtxError> {
676        let channel_code = self.next_channel_code();
677        let id = self.next_id();
678
679        let selector =
680            nskeyedarchiver_encode::archive_string("_requestChannelWithCode:identifier:");
681        let arg_name = nskeyedarchiver_encode::archive_string(service_name);
682
683        // channel_code is passed as raw Int32 (not NSKeyedArchiver), matching go-ios AddInt32()
684        let aux = encode_primitive_dict(&[
685            PrimArg::Int32(channel_code),
686            PrimArg::Bytes(Bytes::from(arg_name)),
687        ]);
688
689        let frame = encode_dtx(id, 0, 0, true, MSG_METHOD_INVOCATION, &selector, &aux);
690        self.send_raw(&frame).await?;
691        self.outstanding_reply_ids.insert(id);
692
693        // Read reply (skip unrelated notifications)
694        let msg = self.wait_for_reply(id).await?;
695        tracing::debug!(
696            "request_channel recv: id={} conv_idx={} ch={} expects_reply={}",
697            msg.identifier,
698            msg.conversation_idx,
699            msg.channel_code,
700            msg.expects_reply
701        );
702        Ok(channel_code)
703    }
704
705    /// Call a method on a channel and wait for the response.
706    pub async fn method_call(
707        &mut self,
708        channel_code: i32,
709        selector: &str,
710        args: &[PrimArg],
711    ) -> Result<DtxMessage, DtxError> {
712        let id = self.next_id();
713        let sel_bytes = nskeyedarchiver_encode::archive_string(selector);
714        let aux = if args.is_empty() {
715            Bytes::new()
716        } else {
717            encode_primitive_dict(args)
718        };
719        let frame = encode_dtx(
720            id,
721            0,
722            channel_code,
723            true,
724            MSG_METHOD_INVOCATION,
725            &sel_bytes,
726            &aux,
727        );
728        self.send_raw(&frame).await?;
729        self.outstanding_reply_ids.insert(id);
730        tracing::debug!("method_call '{selector}' id={id} ch={channel_code}");
731
732        let msg = self.wait_for_reply(id).await?;
733        tracing::debug!(
734            "method_call recv: id={} conv_idx={} ch={}",
735            msg.identifier,
736            msg.conversation_idx,
737            msg.channel_code
738        );
739        Ok(msg)
740    }
741
742    /// Fire-and-forget method call.
743    pub async fn method_call_async(
744        &mut self,
745        channel_code: i32,
746        selector: &str,
747        args: &[PrimArg],
748    ) -> Result<(), DtxError> {
749        let id = self.next_id();
750        let sel_bytes = nskeyedarchiver_encode::archive_string(selector);
751        let aux = if args.is_empty() {
752            Bytes::new()
753        } else {
754            encode_primitive_dict(args)
755        };
756        let frame = encode_dtx(
757            id,
758            0,
759            channel_code,
760            false,
761            MSG_METHOD_INVOCATION,
762            &sel_bytes,
763            &aux,
764        );
765        self.send_raw(&frame).await
766    }
767}
768
/// Adjust the sign of an incoming channel code based on the conversation index.
///
/// Codes of messages with an even `conversation_idx` are negated; codes of
/// messages with an odd index are returned unchanged.
///
/// The code comes straight from the wire, so negation wraps: `i32::MIN` maps
/// to itself instead of overflowing (which would panic in builds with
/// overflow checks).
fn normalize_incoming_channel_code(channel_code: i32, conversation_idx: u32) -> i32 {
    if conversation_idx % 2 == 0 {
        channel_code.wrapping_neg()
    } else {
        channel_code
    }
}
776
777#[cfg(test)]
778mod tests {
779    use bytes::BufMut;
780
781    use super::*;
782
783    #[test]
784    fn test_encode_dtx_layout() {
785        let sel = nskeyedarchiver_encode::archive_string("test");
786        let frame = encode_dtx(1, 0, 1, true, MSG_METHOD_INVOCATION, &sel, &[]);
787        assert_eq!(
788            u32::from_be_bytes(frame[0..4].try_into().unwrap()),
789            DTX_MAGIC
790        );
791        assert_eq!(u32::from_le_bytes(frame[4..8].try_into().unwrap()), 32);
792        assert_eq!(u32::from_le_bytes(frame[28..32].try_into().unwrap()), 1); // expects_reply
793        assert_eq!(
794            u32::from_le_bytes(frame[32..36].try_into().unwrap()),
795            MSG_METHOD_INVOCATION
796        );
797    }
798
799    #[test]
800    fn test_encode_ack_length() {
801        let msg = DtxMessage {
802            identifier: 5,
803            conversation_idx: 0,
804            channel_code: 1,
805            expects_reply: true,
806            payload: DtxPayload::Empty,
807        };
808        let ack = encode_ack(&msg);
809        assert_eq!(ack.len(), 48);
810        assert_eq!(u32::from_le_bytes(ack[32..36].try_into().unwrap()), MSG_OK);
811    }
812
813    #[tokio::test]
814    async fn test_dtx_encode_decode_roundtrip() {
815        let sel = nskeyedarchiver_encode::archive_string("setConfig:");
816        let frame = encode_dtx(7, 0, 2, true, MSG_METHOD_INVOCATION, &sel, &[]);
817        let mut cur = std::io::Cursor::new(frame);
818        let msg = read_dtx_frame(&mut cur).await.unwrap();
819        assert_eq!(msg.identifier, 7);
820        assert_eq!(msg.channel_code, 2);
821        assert!(msg.expects_reply);
822        // Selector should be recoverable
823        if let DtxPayload::MethodInvocation { selector, .. } = &msg.payload {
824            assert_eq!(selector, "setConfig:");
825        } else {
826            panic!("expected MethodInvocation");
827        }
828    }
829
830    #[tokio::test]
831    async fn test_data_frame_keeps_raw_payload() {
832        let payload = b"trace-binary-payload";
833        let frame = encode_dtx(11, 0, 4, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]);
834        let mut cur = std::io::Cursor::new(frame);
835        let msg = read_dtx_frame(&mut cur).await.unwrap();
836        match msg.payload {
837            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
838            other => panic!("expected raw payload, got {other:?}"),
839        }
840    }
841
842    #[tokio::test]
843    async fn test_data_frame_preserves_auxiliary_arguments() {
844        let payload = b"trace-binary-payload";
845        let aux =
846            encode_primitive_dict(&[PrimArg::Bytes(Bytes::from_static(b"kperf-aux-payload"))]);
847        let frame = encode_dtx(13, 0, 4, false, MSG_UNKNOWN_TYPE_ONE, payload, &aux);
848        let mut cur = std::io::Cursor::new(frame);
849        let msg = read_dtx_frame(&mut cur).await.unwrap();
850
851        match msg.payload {
852            DtxPayload::RawWithAux { payload: body, aux } => {
853                assert_eq!(body.as_ref(), payload);
854                assert!(matches!(
855                    aux.first(),
856                    Some(NSObject::Data(bytes)) if bytes.as_ref() == b"kperf-aux-payload"
857                ));
858            }
859            other => panic!("expected raw payload with aux, got {other:?}"),
860        }
861    }
862
863    #[tokio::test]
864    async fn test_method_invocation_preserves_raw_payload_when_selector_decode_fails() {
865        let payload = b"not-a-selector";
866        let frame = encode_dtx(12, 0, 4, false, MSG_METHOD_INVOCATION, payload, &[]);
867        let mut cur = std::io::Cursor::new(frame);
868        let msg = read_dtx_frame(&mut cur).await.unwrap();
869
870        match msg.payload {
871            DtxPayload::MethodInvocation { selector, args } => {
872                assert!(selector.is_empty());
873                assert!(
874                    matches!(args.first(), Some(NSObject::Data(bytes)) if bytes.as_ref() == payload)
875                );
876            }
877            other => panic!("expected method invocation, got {other:?}"),
878        }
879    }
880
881    #[tokio::test]
882    async fn test_method_call_buffers_unrelated_notifications() {
883        let (client, mut server) = tokio::io::duplex(4096);
884        let mut conn = DtxConnection::new(client);
885
886        let call = tokio::spawn(async move {
887            conn.method_call(2, "startSampling", &[])
888                .await
889                .map(|reply| (conn, reply))
890        });
891
892        let outbound = read_dtx_frame(&mut server).await.unwrap();
893        assert_eq!(outbound.identifier, 5);
894        assert!(outbound.expects_reply);
895
896        let notify_selector = nskeyedarchiver_encode::archive_string("note:");
897        let notify = encode_dtx(77, 0, 1, true, MSG_METHOD_INVOCATION, &notify_selector, &[]);
898        server.write_all(&notify).await.unwrap();
899
900        let ack = read_dtx_frame(&mut server).await.unwrap();
901        assert_eq!(ack.identifier, 77);
902        assert_eq!(ack.conversation_idx, 1);
903
904        let reply = encode_dtx(5, 1, 2, false, MSG_RESPONSE, &[], &[]);
905        server.write_all(&reply).await.unwrap();
906
907        let (mut conn, reply) = call.await.unwrap().unwrap();
908        assert_eq!(reply.identifier, 5);
909        assert_eq!(reply.conversation_idx, 1);
910
911        let queued = conn.recv().await.unwrap();
912        assert_eq!(queued.identifier, 77);
913        assert_eq!(queued.channel_code, -1);
914        assert!(queued.expects_reply);
915    }
916
917    #[tokio::test]
918    async fn test_recv_normalizes_even_conversation_channel_codes() {
919        let (client, mut server) = tokio::io::duplex(256);
920        let mut conn = DtxConnection::new(client);
921
922        let recv_task = tokio::spawn(async move { conn.recv().await });
923
924        let payload = b"trace-binary-payload";
925        let frame = encode_dtx(42, 0, -2, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]);
926        server.write_all(&frame).await.unwrap();
927
928        let msg = recv_task.await.unwrap().unwrap();
929        assert_eq!(msg.identifier, 42);
930        assert_eq!(msg.conversation_idx, 0);
931        assert_eq!(msg.channel_code, 2);
932        match msg.payload {
933            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
934            other => panic!("expected raw payload, got {other:?}"),
935        }
936    }
937
938    #[tokio::test]
939    async fn test_wait_for_reply_returns_buffered_reply_immediately() {
940        let (client, _server) = tokio::io::duplex(64);
941        let mut conn = DtxConnection::new(client);
942
943        conn.buffer_reply(DtxMessage {
944            identifier: 9,
945            conversation_idx: 1,
946            channel_code: 3,
947            expects_reply: false,
948            payload: DtxPayload::Empty,
949        });
950
951        let reply = conn.wait_for_reply(9).await.unwrap();
952        assert_eq!(reply.identifier, 9);
953        assert_eq!(reply.conversation_idx, 1);
954        assert_eq!(reply.channel_code, 3);
955    }
956
957    #[tokio::test]
958    async fn test_recv_treats_unsolicited_conversation_message_as_live_event() {
959        let (client, mut server) = tokio::io::duplex(256);
960        let mut conn = DtxConnection::new(client);
961
962        let recv_task = tokio::spawn(async move { conn.recv().await });
963
964        let payload = b"trace-binary-payload";
965        let frame = encode_dtx(42, 1, 2, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]);
966        server.write_all(&frame).await.unwrap();
967
968        let msg = recv_task.await.unwrap().unwrap();
969        assert_eq!(msg.identifier, 42);
970        assert_eq!(msg.conversation_idx, 1);
971        assert_eq!(msg.channel_code, 2);
972        match msg.payload {
973            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
974            other => panic!("expected raw payload, got {other:?}"),
975        }
976    }
977
978    #[tokio::test]
979    async fn test_read_dtx_frame_skips_extended_header_bytes() {
980        let payload = b"trace-binary-payload";
981        let mut frame = encode_dtx(21, 0, 4, false, MSG_UNKNOWN_TYPE_ONE, payload, &[]).to_vec();
982
983        frame[4..8].copy_from_slice(&36u32.to_le_bytes());
984        frame.splice(32..32, [0xAA, 0xBB, 0xCC, 0xDD]);
985
986        let mut cur = std::io::Cursor::new(frame);
987        let msg = read_dtx_frame(&mut cur)
988            .await
989            .expect("extended headers should be skipped before parsing payload");
990
991        match msg.payload {
992            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
993            other => panic!("expected raw payload, got {other:?}"),
994        }
995    }
996
997    #[tokio::test]
998    async fn test_ok_reply_decodes_as_empty_payload() {
999        let frame = encode_dtx(23, 1, 4, false, MSG_OK, &[], &[]);
1000        let mut cur = std::io::Cursor::new(frame);
1001        let msg = read_dtx_frame(&mut cur).await.unwrap();
1002        assert!(matches!(msg.payload, DtxPayload::Empty));
1003    }
1004
1005    #[tokio::test]
1006    async fn test_error_reply_decodes_like_response_object() {
1007        let payload = nskeyedarchiver_encode::archive_string("selector failed");
1008        let frame = encode_dtx(24, 1, 4, false, MSG_ERROR, &payload, &[]);
1009        let mut cur = std::io::Cursor::new(frame);
1010        let msg = read_dtx_frame(&mut cur).await.unwrap();
1011        assert!(matches!(
1012            msg.payload,
1013            DtxPayload::Response(NSObject::String(ref value)) if value == "selector failed"
1014        ));
1015    }
1016
1017    fn encode_fragment(
1018        identifier: u32,
1019        frag_idx: u16,
1020        frag_cnt: u16,
1021        channel_code: i32,
1022        expects_reply: bool,
1023        msg_len: usize,
1024        body: &[u8],
1025    ) -> Bytes {
1026        let mut out = BytesMut::with_capacity(32 + body.len());
1027        out.put_u32(DTX_MAGIC);
1028        out.put_u32_le(32);
1029        out.put_u16_le(frag_idx);
1030        out.put_u16_le(frag_cnt);
1031        out.put_u32_le(msg_len as u32);
1032        out.put_u32_le(identifier);
1033        out.put_u32_le(0);
1034        out.put_u32_le(channel_code as u32);
1035        out.put_u32_le(if expects_reply { 1 } else { 0 });
1036        out.extend_from_slice(body);
1037        out.freeze()
1038    }
1039
1040    #[tokio::test]
1041    async fn test_recv_reassembles_out_of_order_fragments_by_index() {
1042        let payload = b"fragmented-trace-payload";
1043        let mut body = BytesMut::with_capacity(16 + payload.len());
1044        body.put_u32_le(MSG_UNKNOWN_TYPE_ONE);
1045        body.put_u32_le(0);
1046        body.put_u32_le(payload.len() as u32);
1047        body.put_u32_le(0);
1048        body.extend_from_slice(payload);
1049        let body = body.freeze();
1050
1051        let split_at = 10;
1052        let first = encode_fragment(31, 0, 3, 4, false, body.len(), &[]);
1053        let second = encode_fragment(31, 1, 3, 4, false, split_at, &body[..split_at]);
1054        let third = encode_fragment(31, 2, 3, 4, false, body.len() - split_at, &body[split_at..]);
1055
1056        let (client, mut server) = tokio::io::duplex(512);
1057        let mut conn = DtxConnection::new(client);
1058
1059        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
1060        server.write_all(&first).await.unwrap();
1061        server.write_all(&third).await.unwrap();
1062        server.write_all(&second).await.unwrap();
1063
1064        let msg = recv_task
1065            .await
1066            .unwrap()
1067            .expect("fragment order should not affect reassembly");
1068
1069        match msg.payload {
1070            DtxPayload::Raw(bytes) => assert_eq!(bytes.as_ref(), payload),
1071            other => panic!("expected raw payload, got {other:?}"),
1072        }
1073    }
1074
1075    #[tokio::test]
1076    async fn test_recv_rejects_duplicate_first_fragment() {
1077        let payload = b"fragmented-trace-payload";
1078        let mut body = BytesMut::with_capacity(16 + payload.len());
1079        body.put_u32_le(MSG_UNKNOWN_TYPE_ONE);
1080        body.put_u32_le(0);
1081        body.put_u32_le(payload.len() as u32);
1082        body.put_u32_le(0);
1083        body.extend_from_slice(payload);
1084        let body = body.freeze();
1085
1086        let first = encode_fragment(41, 0, 2, 4, false, body.len(), &[]);
1087        let duplicate_first = encode_fragment(41, 0, 2, 4, false, body.len(), &[]);
1088
1089        let (client, mut server) = tokio::io::duplex(512);
1090        let mut conn = DtxConnection::new(client);
1091
1092        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
1093        server.write_all(&first).await.unwrap();
1094        server.write_all(&duplicate_first).await.unwrap();
1095
1096        let err = recv_task.await.unwrap().unwrap_err();
1097        assert!(matches!(
1098            err,
1099            DtxError::Protocol(message) if message.contains("duplicate first fragment")
1100        ));
1101    }
1102
1103    #[tokio::test]
1104    async fn test_recv_rejects_fragment_without_first_fragment() {
1105        let payload = b"fragmented-trace-payload";
1106        let mut body = BytesMut::with_capacity(16 + payload.len());
1107        body.put_u32_le(MSG_UNKNOWN_TYPE_ONE);
1108        body.put_u32_le(0);
1109        body.put_u32_le(payload.len() as u32);
1110        body.put_u32_le(0);
1111        body.extend_from_slice(payload);
1112        let body = body.freeze();
1113
1114        let stray = encode_fragment(43, 1, 2, 4, false, body.len(), &body);
1115
1116        let (client, mut server) = tokio::io::duplex(512);
1117        let mut conn = DtxConnection::new(client);
1118
1119        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
1120        server.write_all(&stray).await.unwrap();
1121
1122        let err = recv_task.await.unwrap().unwrap_err();
1123        assert!(matches!(
1124            err,
1125            DtxError::Protocol(message) if message.contains("without first fragment")
1126        ));
1127    }
1128
1129    #[tokio::test]
1130    async fn test_recv_rejects_fragment_metadata_mismatch() {
1131        let payload = b"fragmented-trace-payload";
1132        let mut body = BytesMut::with_capacity(16 + payload.len());
1133        body.put_u32_le(MSG_UNKNOWN_TYPE_ONE);
1134        body.put_u32_le(0);
1135        body.put_u32_le(payload.len() as u32);
1136        body.put_u32_le(0);
1137        body.extend_from_slice(payload);
1138        let body = body.freeze();
1139
1140        let split_at = 10;
1141        let first = encode_fragment(45, 0, 3, 4, false, body.len(), &[]);
1142        let bad_second = encode_fragment(45, 1, 3, 5, false, split_at, &body[..split_at]);
1143
1144        let (client, mut server) = tokio::io::duplex(512);
1145        let mut conn = DtxConnection::new(client);
1146
1147        let recv_task = tokio::spawn(async move { conn.recv_from_stream().await });
1148        server.write_all(&first).await.unwrap();
1149        server.write_all(&bad_second).await.unwrap();
1150
1151        let err = recv_task.await.unwrap().unwrap_err();
1152        assert!(matches!(
1153            err,
1154            DtxError::Protocol(message) if message.contains("fragment metadata mismatch")
1155        ));
1156    }
1157
1158    #[tokio::test]
1159    async fn test_recv_rejects_excessive_fragment_count_before_allocation() {
1160        let first = encode_fragment(72, 0, MAX_DTX_FRAGMENTS + 1, 4, false, 16, &[]);
1161        let mut cursor = std::io::Cursor::new(first);
1162
1163        let err = match read_dtx_header(&mut cursor).await {
1164            Ok(_) => panic!("excessive fragment count should be rejected"),
1165            Err(err) => err,
1166        };
1167
1168        assert!(matches!(
1169            err,
1170            DtxError::Protocol(message) if message.contains("too many fragments")
1171        ));
1172    }
1173}