Skip to main content

oxideav_rtmp/
chunk.rs

//! RTMP chunk stream reader + writer.
//!
//! RTMP splits every message (command / audio / video / control) into
//! one or more **chunks** of at most the current chunk size. Each chunk
//! has a 1..=3 byte basic header carrying `(fmt, csid)`, a 0..=11 byte
//! message header whose shape depends on `fmt`, an optional 4-byte
//! extended timestamp, and the payload. Each sender announces its max
//! per-chunk payload with the `SetChunkSize` protocol control message
//! (default 128, typically bumped to 4096+ after connect).
//!
//! This module is the minimum needed to round-trip real RTMP traffic:
//!
//! * The reader coalesces chunks back into whole `Message`s, handling
//!   fmt 0 / 1 / 2 / 3 headers, extended timestamps, and the per-csid
//!   state tables automatically.
//! * The writer emits messages as a sequence of chunks, picking the
//!   densest header fmt allowed by how the message's `(msg_stream_id,
//!   msg_type_id, msg_length, timestamp)` compare with the previous
//!   message on that csid (fmt 3 if nothing changed, else fmt 2 / 1 / 0).
//! * `SetChunkSize` is not applied automatically: the reader and the
//!   writer each expose a `set_chunk_size` setter, which the message
//!   layer one level up calls when the peer announces a new size or when
//!   we announce one to the peer.
//!
//! Chunk stream id (csid) conventions we use:
//! * `2` — protocol control (SetChunkSize, Ack, WindowAckSize, SetPeerBandwidth, UserControl)
//! * `3` — AMF command messages (connect, createStream, publish, …)
//! * `4` — audio messages
//! * `5` — video messages
//! * `6` — data messages (@setDataFrame / onMetaData)
30
31use std::collections::HashMap;
32use std::io::{Read, Write};
33
34use crate::error::{Error, Result};
35
/// Default max payload bytes per chunk. It stays in effect for each
/// direction until that direction's sender announces a different size
/// with a `SetChunkSize` message (RTMP spec §5.4.1).
pub const DEFAULT_CHUNK_SIZE: usize = 128;

/// Largest chunk size that has any effect. The `set_chunk_size` setters
/// clamp to this bound.
///
/// The `SetChunkSize` field itself is 31 bits wide: the top bit must be
/// zero, so values up to `0x7FFF_FFFF` are legal on the wire. However, no
/// chunk is larger than one message, and a message length is a 24-bit
/// field, so every size above `0x00FF_FFFF` (16_777_215) behaves the same.
pub const MAX_CHUNK_SIZE: usize = 0x00FF_FFFF;
43
/// A complete RTMP message with no chunk framing. [`ChunkReader`]
/// produces these by stripping chunk headers and concatenating the
/// payload pieces; [`ChunkWriter`] splits them back into chunks.
#[derive(Clone, Debug)]
pub struct Message {
    /// RTMP message type id. Types 1..=6 are protocol-control messages
    /// (see [`Message::validate_protocol_control_invariants`]).
    pub msg_type_id: u8,
    /// Message stream id; `0` is the control stream (see
    /// [`Message::stream_kind`]).
    pub msg_stream_id: u32,
    /// Absolute timestamp, in the unit defined by the message type
    /// (milliseconds for audio / video / data / command messages).
    pub timestamp: u32,
    /// The reassembled message body.
    pub payload: Vec<u8>,
}
54
/// What a message's `msg_stream_id` denotes, following the Message
/// Formats spec: §5 puts all protocol-control traffic on stream id 0
/// (the "control stream"), and §4.1 gives the stream id only 3 bytes.
///
/// A server hands out a NetStream id in its `_result` reply to
/// `createStream` (Commands Messages spec §4.1.3). The publisher then
/// stamps that id into every subsequent audio / video / metadata
/// message header. On the wire the chunk header carries the id as a
/// 32-bit little-endian field (Chunk Stream spec §6.1.2.1), so a peer can
/// send ids with a non-zero top byte. Those fall outside the 3-byte §4.1
/// field and are reported as [`MessageStreamKind::Reserved`] so that
/// callers can reject them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageStreamKind {
    /// Stream id `0`: the control stream. It carries NetConnection
    /// commands (`connect`, `createStream`, `_result`, `_error`, `call`)
    /// plus the protocol-control and user-control messages (types 1..=6).
    Control,
    /// Stream id in `1..=0x00FF_FFFF`: a NetStream handle. This is the
    /// value `_result(createStream)` returned, and it appears on every
    /// audio / video / data / aggregate message of that NetStream.
    NetStream(u32),
    /// Stream id with bits set in the top byte. The §6.1.2.1 wire field
    /// can hold such values, but the §4.1 3-byte field reserves that
    /// byte. Kept distinct so that strict consumers can refuse the message.
    Reserved(u32),
}
87
88impl Message {
89    /// Classify [`Self::msg_stream_id`] per Message Formats spec §4.1 /
90    /// §5 — `0` is the "control stream", `1..=0x00FF_FFFF` is a
91    /// NetStream handle, anything with bits set above the §4.1 3-byte
92    /// field is reserved.
93    pub fn stream_kind(&self) -> MessageStreamKind {
94        match self.msg_stream_id {
95            0 => MessageStreamKind::Control,
96            id if id & 0xFF00_0000 == 0 => MessageStreamKind::NetStream(id),
97            other => MessageStreamKind::Reserved(other),
98        }
99    }
100
101    /// True iff this message rides the control stream (`msg_stream_id
102    /// == 0`). All protocol-control / user-control / NetConnection
103    /// command traffic does, per Message Formats spec §5.
104    pub fn is_control_stream(&self) -> bool {
105        matches!(self.stream_kind(), MessageStreamKind::Control)
106    }
107
108    /// Validate the spec §5 mandate that "Protocol control messages
109    /// MUST have message stream ID 0 (called as control stream)". The
110    /// protocol-control range is message type IDs 1..=6 (Set Chunk
111    /// Size / Abort / Acknowledgement / User Control / Window Ack Size
112    /// / Set Peer Bandwidth) — type id 4 (User Control) is grouped
113    /// with the §5 protocol-control set in the same spec section. The
114    /// §6.1.2.1 reserved top-byte rule on `msg_stream_id` is also
115    /// enforced here so a single accessor catches both spec invariants.
116    pub fn validate_protocol_control_invariants(&self) -> Result<()> {
117        if matches!(self.stream_kind(), MessageStreamKind::Reserved(_)) {
118            return Err(Error::ProtocolViolation(format!(
119                "message stream id {:#010x} sets reserved high byte (spec §4.1: 3-byte field)",
120                self.msg_stream_id
121            )));
122        }
123        if matches!(self.msg_type_id, 1..=6) && self.msg_stream_id != 0 {
124            return Err(Error::ProtocolViolation(format!(
125                "protocol-control message type {} carries non-zero msg_stream_id {} (spec §5 requires 0)",
126                self.msg_type_id, self.msg_stream_id
127            )));
128        }
129        Ok(())
130    }
131}
132
133// ---------------------------------------------------------------------------
134// Per-csid state (reader)
135// ---------------------------------------------------------------------------
136
/// Receive-side header state for one chunk stream id. It is created by
/// the first fmt 0 header on the csid and updated by every chunk header
/// read there afterwards.
#[derive(Default, Debug, Clone)]
struct InState {
    /// Message type id from the last fmt 0 / fmt 1 header.
    msg_type_id: u8,
    /// Message stream id from the last fmt 0 header (fmt 1/2/3 inherit it).
    msg_stream_id: u32,
    /// Total payload length of the current message (set by fmt 0 / fmt 1).
    msg_length: u32,
    /// Absolute timestamp in ms of the current message. Updated as follows:
    /// * a fmt 0 header sets it;
    /// * fmt 1 / 2 headers add their delta;
    /// * a fmt 3 chunk that *starts* a new message adds
    ///   [`last_delta`](Self::last_delta);
    /// * a fmt 3 continuation chunk leaves it unchanged.
    timestamp: u32,
    /// Delta that a fmt 3 chunk starting a new message applies. It is the
    /// last fmt 1 / 2 delta, or the absolute timestamp after a fmt 0 header.
    last_delta: u32,
    /// Whether the last fmt 0/1/2 header on this csid used an extended
    /// timestamp. If so, every following fmt 3 chunk must also read the
    /// 4-byte extended timestamp, even though fmt 3 normally carries no
    /// header fields.
    last_had_ext_ts: bool,
    /// Partial payload while a multi-chunk message is being received.
    partial: Vec<u8>,
}
155
156pub struct ChunkReader<R: Read> {
157    stream: R,
158    chunk_size: usize,
159    states: HashMap<u32, InState>,
160    /// Cumulative number of bytes consumed off the wire — basic
161    /// headers, message headers, extended timestamps, and payload all
162    /// count. This is the running "sequence number" the §5.3
163    /// Acknowledgement reports. It wraps at `u32::MAX` like the wire
164    /// field; we keep it as `u32` so [`ack_due`](Self::ack_due) emits
165    /// the value verbatim.
166    received_bytes: u32,
167    /// The window size the *peer* asked us to use (its §5.5 Window
168    /// Acknowledgement Size). We owe an Acknowledgement once we have
169    /// received this many bytes since the last one. `0` disables the
170    /// obligation (no window negotiated yet / peer asked for none).
171    window_ack_size: u32,
172    /// Value of [`received_bytes`](Self::received_bytes) at the moment
173    /// we last emitted an Acknowledgement. The next ack is due once
174    /// `received_bytes - last_ack_bytes >= window_ack_size`.
175    last_ack_bytes: u32,
176}
177
178impl<R: Read> ChunkReader<R> {
179    pub fn new(stream: R) -> Self {
180        Self {
181            stream,
182            chunk_size: DEFAULT_CHUNK_SIZE,
183            states: HashMap::new(),
184            received_bytes: 0,
185            window_ack_size: 0,
186            last_ack_bytes: 0,
187        }
188    }
189
190    /// Read exactly `buf.len()` bytes off the wire, charging them to
191    /// the §5.3 received-byte sequence counter. Every wire read in the
192    /// reader funnels through here so the Acknowledgement sequence
193    /// number stays exact regardless of chunk framing.
194    fn read_exact_counted(&mut self, buf: &mut [u8]) -> Result<()> {
195        self.stream.read_exact(buf)?;
196        self.received_bytes = self.received_bytes.wrapping_add(buf.len() as u32);
197        Ok(())
198    }
199
200    /// Total bytes consumed off the wire so far (the §5.3 sequence
201    /// number). Wraps at `u32::MAX` like the wire field.
202    pub fn received_bytes(&self) -> u32 {
203        self.received_bytes
204    }
205
206    /// The peer's current §5.5 Window Acknowledgement Size, or `0` if
207    /// none has been negotiated.
208    pub fn window_ack_size(&self) -> u32 {
209        self.window_ack_size
210    }
211
212    /// Record the peer's §5.5 Window Acknowledgement Size. Callers
213    /// dispatch the 4-byte big-endian body of an inbound
214    /// [`MSG_WINDOW_ACK_SIZE`](crate::message::MSG_WINDOW_ACK_SIZE)
215    /// message here; per §5.6 a Set Peer Bandwidth also carries an
216    /// output-bandwidth value equal to the window size, so the same
217    /// setter applies. Setting it to `0` disables the ack obligation.
218    ///
219    /// Resetting the window re-bases the "bytes since last ack"
220    /// accounting to the current sequence number so a freshly-shrunk
221    /// window doesn't make a single already-counted byte instantly
222    /// owe an Acknowledgement.
223    pub fn set_window_ack_size(&mut self, size: u32) {
224        self.window_ack_size = size;
225        self.last_ack_bytes = self.received_bytes;
226    }
227
228    /// Return the §5.3 Acknowledgement sequence number to emit if one
229    /// is now due, advancing the internal "last acked" mark so the
230    /// next ack only fires after another full window of bytes.
231    ///
232    /// Per §5.3 a peer "sends the acknowledgment to the peer after
233    /// receiving bytes equal to the window size". Returns `None` when
234    /// no window is negotiated (`window_ack_size == 0`) or fewer than
235    /// `window_ack_size` bytes have arrived since the last ack. The
236    /// caller is expected to call this after each
237    /// [`read_message`](Self::read_message) and, when it yields
238    /// `Some(seq)`, write a [`build_ack(seq)`](crate::message::build_ack)
239    /// back to the peer.
240    pub fn ack_due(&mut self) -> Option<u32> {
241        if self.window_ack_size == 0 {
242            return None;
243        }
244        // Use wrapping subtraction so a `received_bytes` that wrapped
245        // past u32::MAX since the last ack still measures the true
246        // gap (the wire counter is defined modulo 2^32).
247        let since = self.received_bytes.wrapping_sub(self.last_ack_bytes);
248        if since >= self.window_ack_size {
249            self.last_ack_bytes = self.received_bytes;
250            Some(self.received_bytes)
251        } else {
252            None
253        }
254    }
255
256    /// Override the current max chunk payload. Callers typically react
257    /// to an incoming `SetChunkSize` control message by propagating
258    /// the new value here; the reader itself does NOT auto-apply
259    /// SetChunkSize because the control flow lives at the message
260    /// layer one level up.
261    pub fn set_chunk_size(&mut self, size: usize) {
262        self.chunk_size = size.clamp(1, MAX_CHUNK_SIZE);
263    }
264
265    pub fn chunk_size(&self) -> usize {
266        self.chunk_size
267    }
268
269    /// React to an inbound Abort Message (RTMP 1.0 §5.2) by discarding the
270    /// partially-received message on the named chunk stream id.
271    ///
272    /// Per §5.2, the Abort Message tells a receiver that "is waiting for
273    /// chunks to complete a message" to "discard the partially received
274    /// message over a chunk stream and abort processing of that message."
275    /// The sender uses it after transmitting part of a message it has
276    /// decided not to finish, so the receiver must drop the half-filled
277    /// reassembly buffer rather than splice the abandoned bytes onto the
278    /// next message that arrives on the same csid.
279    ///
280    /// This clears only the in-flight payload bytes; the csid's header
281    /// state (last timestamp / type / length / extended-timestamp latch)
282    /// is left intact, because a subsequent fmt-1/2/3 chunk on the csid
283    /// still relies on it per §5.3.2, and a fmt-0 chunk would overwrite
284    /// it anyway. An Abort for a csid that has no in-flight message (or
285    /// one this reader has never seen) is a no-op, matching the spec's
286    /// "if it is waiting for chunks" precondition. Returns `true` when a
287    /// non-empty partial buffer was actually discarded.
288    ///
289    /// The control flow lives at the message layer one level up — like
290    /// [`ChunkReader::set_chunk_size`], the reader does not auto-apply an
291    /// inbound Abort; the caller dispatches a
292    /// [`MSG_ABORT`](crate::message::MSG_ABORT) message's 4-byte
293    /// big-endian chunk stream id here.
294    pub fn abort_partial(&mut self, chunk_stream_id: u32) -> bool {
295        match self.states.get_mut(&chunk_stream_id) {
296            Some(st) if !st.partial.is_empty() => {
297                st.partial.clear();
298                true
299            }
300            _ => false,
301        }
302    }
303
304    /// Borrow the underlying reader (for splitting, timeout config, …).
305    pub fn inner_mut(&mut self) -> &mut R {
306        &mut self.stream
307    }
308
309    /// Read chunks off the wire until one full message is reassembled.
310    /// Blocks until at least one complete message is available.
311    pub fn read_message(&mut self) -> Result<Message> {
312        loop {
313            let (csid, fmt) = self.read_basic_header()?;
314            match fmt {
315                0 => self.read_fmt0_header(csid)?,
316                1 => self.read_fmt1_header(csid)?,
317                2 => self.read_fmt2_header(csid)?,
318                3 => self.read_fmt3_header(csid)?,
319                _ => unreachable!("fmt is 2 bits"),
320            }
321            // Read up to `chunk_size` bytes of payload (or the
322            // remaining message length, whichever is smaller). Compute
323            // the take size from a short immutable lookup, do the
324            // counted wire read into a scratch buffer (so the
325            // `&mut self` ack-accounting borrow doesn't overlap the
326            // per-csid state borrow), then append to the partial.
327            let take = {
328                let state = self.states.get(&csid).ok_or_else(|| {
329                    Error::InvalidChunk(format!(
330                        "fmt {fmt} chunk on csid {csid} without prior fmt-0 state"
331                    ))
332                })?;
333                let need = state.msg_length as usize - state.partial.len();
334                need.min(self.chunk_size)
335            };
336            let mut buf = vec![0u8; take];
337            self.read_exact_counted(&mut buf)?;
338            let state = self
339                .states
340                .get_mut(&csid)
341                .expect("csid state present (checked immediately above)");
342            state.partial.extend_from_slice(&buf);
343
344            if state.partial.len() as u32 >= state.msg_length {
345                let payload = std::mem::take(&mut state.partial);
346                let msg = Message {
347                    msg_type_id: state.msg_type_id,
348                    msg_stream_id: state.msg_stream_id,
349                    timestamp: state.timestamp,
350                    payload,
351                };
352                return Ok(msg);
353            }
354        }
355    }
356
357    fn read_basic_header(&mut self) -> Result<(u32, u8)> {
358        let mut b = [0u8; 1];
359        self.read_exact_counted(&mut b)?;
360        let fmt = (b[0] >> 6) & 0x03;
361        let low = b[0] & 0x3F;
362        let csid = match low {
363            0 => {
364                let mut b1 = [0u8; 1];
365                self.read_exact_counted(&mut b1)?;
366                b1[0] as u32 + 64
367            }
368            1 => {
369                let mut b2 = [0u8; 2];
370                self.read_exact_counted(&mut b2)?;
371                // spec: second byte is high order, third byte low — but
372                // commodity peers in the wild interpret it the other way;
373                // the official spec reads "2nd + 3rd byte * 256" which is
374                // little-endian.
375                b2[0] as u32 + (b2[1] as u32) * 256 + 64
376            }
377            other => other as u32,
378        };
379        Ok((csid, fmt))
380    }
381
382    fn read_u24_be(&mut self) -> Result<u32> {
383        let mut b = [0u8; 3];
384        self.read_exact_counted(&mut b)?;
385        Ok(((b[0] as u32) << 16) | ((b[1] as u32) << 8) | (b[2] as u32))
386    }
387
388    fn read_u32_le_stream_id(&mut self) -> Result<u32> {
389        let mut b = [0u8; 4];
390        self.read_exact_counted(&mut b)?;
391        // msg_stream_id is explicitly little-endian — only field in
392        // the RTMP wire format that is.
393        Ok(u32::from_le_bytes(b))
394    }
395
396    fn read_fmt0_header(&mut self, csid: u32) -> Result<()> {
397        let mut ts = self.read_u24_be()?;
398        let len = self.read_u24_be()?;
399        let mut t = [0u8; 1];
400        self.read_exact_counted(&mut t)?;
401        let ty = t[0];
402        let stream_id = self.read_u32_le_stream_id()?;
403        let had_ext_ts = ts == 0x00FF_FFFF;
404        if had_ext_ts {
405            ts = self.read_u32_be()?;
406        }
407        let st = self.states.entry(csid).or_default();
408        // fmt 0 wipes any half-received message on this csid.
409        st.partial.clear();
410        st.msg_type_id = ty;
411        st.msg_stream_id = stream_id;
412        st.msg_length = len;
413        st.timestamp = ts;
414        st.last_delta = ts;
415        st.last_had_ext_ts = had_ext_ts;
416        Ok(())
417    }
418
419    fn read_fmt1_header(&mut self, csid: u32) -> Result<()> {
420        let mut delta = self.read_u24_be()?;
421        let len = self.read_u24_be()?;
422        let mut t = [0u8; 1];
423        self.read_exact_counted(&mut t)?;
424        let ty = t[0];
425        let had_ext_ts = delta == 0x00FF_FFFF;
426        if had_ext_ts {
427            delta = self.read_u32_be()?;
428        }
429        let st = self
430            .states
431            .get_mut(&csid)
432            .ok_or_else(|| Error::InvalidChunk("fmt 1 without prior fmt 0".into()))?;
433        st.msg_type_id = ty;
434        st.msg_length = len;
435        st.timestamp = st.timestamp.wrapping_add(delta);
436        st.last_delta = delta;
437        st.last_had_ext_ts = had_ext_ts;
438        // fmt 1 also starts a new message.
439        st.partial.clear();
440        Ok(())
441    }
442
443    fn read_fmt2_header(&mut self, csid: u32) -> Result<()> {
444        let mut delta = self.read_u24_be()?;
445        let had_ext_ts = delta == 0x00FF_FFFF;
446        if had_ext_ts {
447            delta = self.read_u32_be()?;
448        }
449        let st = self
450            .states
451            .get_mut(&csid)
452            .ok_or_else(|| Error::InvalidChunk("fmt 2 without prior fmt 0/1".into()))?;
453        st.timestamp = st.timestamp.wrapping_add(delta);
454        st.last_delta = delta;
455        st.last_had_ext_ts = had_ext_ts;
456        st.partial.clear();
457        Ok(())
458    }
459
460    fn read_fmt3_header(&mut self, csid: u32) -> Result<()> {
461        // fmt 3 reads no message header normally — but if the last
462        // message on this csid used an extended timestamp, the
463        // extended-timestamp field is repeated here too. This is the
464        // most common source of chunk-stream desync bugs in RTMP
465        // implementations.
466        let (had_ext_ts, partial_empty, last_delta) = {
467            let st = self
468                .states
469                .get(&csid)
470                .ok_or_else(|| Error::InvalidChunk("fmt 3 without prior fmt 0/1/2".into()))?;
471            (st.last_had_ext_ts, st.partial.is_empty(), st.last_delta)
472        };
473        if had_ext_ts {
474            let _dup = self.read_u32_be()?;
475        }
476        // If this is a continuation of a multi-chunk message (partial
477        // buffer non-empty), we keep appending; otherwise we start a
478        // new message with the same metadata as the previous one and
479        // extend the timestamp by the last recorded delta.
480        if partial_empty {
481            let st = self.states.get_mut(&csid).unwrap();
482            st.timestamp = st.timestamp.wrapping_add(last_delta);
483        }
484        Ok(())
485    }
486
487    fn read_u32_be(&mut self) -> Result<u32> {
488        let mut b = [0u8; 4];
489        self.read_exact_counted(&mut b)?;
490        Ok(u32::from_be_bytes(b))
491    }
492}
493
494// ---------------------------------------------------------------------------
495// Writer
496// ---------------------------------------------------------------------------
497
/// Send-side header state for one chunk stream id: what the writer last
/// put on the wire there, used to choose compressed headers.
#[derive(Clone, Debug, Default)]
struct OutState {
    /// Type id of the last message written on the csid.
    msg_type_id: u8,
    /// Message stream id of the last message written on the csid.
    msg_stream_id: u32,
    /// Payload length of the last message written on the csid.
    msg_length: u32,
    /// Absolute timestamp of the last message written on the csid.
    timestamp: u32,
    /// The absolute timestamp after a fmt 0 header; otherwise the last
    /// timestamp delta written.
    last_delta: u32,
    /// Whether the last message's head chunk carried an extended timestamp.
    last_had_ext_ts: bool,
    /// Set once the first fmt-0 chunk has been emitted on this csid.
    primed: bool,
}
509
510pub struct ChunkWriter<W: Write> {
511    stream: W,
512    chunk_size: usize,
513    states: HashMap<u32, OutState>,
514}
515
516impl<W: Write> ChunkWriter<W> {
517    pub fn new(stream: W) -> Self {
518        Self {
519            stream,
520            chunk_size: DEFAULT_CHUNK_SIZE,
521            states: HashMap::new(),
522        }
523    }
524
525    pub fn set_chunk_size(&mut self, size: usize) {
526        self.chunk_size = size.clamp(1, MAX_CHUNK_SIZE);
527    }
528
529    pub fn chunk_size(&self) -> usize {
530        self.chunk_size
531    }
532
533    pub fn inner_mut(&mut self) -> &mut W {
534        &mut self.stream
535    }
536
537    pub fn flush(&mut self) -> Result<()> {
538        self.stream.flush()?;
539        Ok(())
540    }
541
542    /// Emit `msg` on `csid`, splitting into chunks if the payload
543    /// exceeds the current per-chunk limit. Picks the densest valid
544    /// chunk header format:
545    ///
546    /// * first message on the csid → fmt 0 (full header);
547    /// * same (stream, type, length) as previous + monotonic
548    ///   timestamp delta → fmt 2 or 3;
549    /// * same stream but different (type, length) → fmt 1;
550    /// * anything else → fmt 0.
551    pub fn write_message(&mut self, csid: u32, msg: &Message) -> Result<()> {
552        let payload_len = msg.payload.len() as u32;
553
554        // Snapshot the state fields we need for the header decision
555        // and the continuation-chunk timestamp repeat. Dropping the
556        // borrow here lets us reach for `self.stream` / `self.chunk_size`
557        // below without NLL complaints.
558        let (prev_primed, prev_stream_id, prev_type_id, prev_length, prev_timestamp) = {
559            let st = self.states.entry(csid).or_default();
560            (
561                st.primed,
562                st.msg_stream_id,
563                st.msg_type_id,
564                st.msg_length,
565                st.timestamp,
566            )
567        };
568
569        // Pick the most compact fmt we can get away with.
570        let fmt = if !prev_primed || prev_stream_id != msg.msg_stream_id {
571            0
572        } else if prev_type_id != msg.msg_type_id || prev_length != payload_len {
573            1
574        } else if msg.timestamp < prev_timestamp {
575            // Timestamp went backwards (rare, e.g. seek in the source).
576            // Re-prime with fmt 0.
577            0
578        } else if msg.timestamp == prev_timestamp {
579            3
580        } else {
581            2
582        };
583        let delta = msg.timestamp.wrapping_sub(prev_timestamp);
584        let ext_ts_needed = (fmt == 0 && msg.timestamp >= 0x00FF_FFFF)
585            || (fmt != 0 && fmt != 3 && delta >= 0x00FF_FFFF);
586
587        // Query the sticky "previous chunk used ext ts" bit we'll need
588        // to mirror on any fmt-3 continuation.
589        let prev_last_had_ext_ts = self
590            .states
591            .get(&csid)
592            .map(|s| s.last_had_ext_ts)
593            .unwrap_or(false);
594
595        let chunk_size = self.chunk_size;
596        let mut first_chunk_done = false;
597        let mut cursor = 0usize;
598        while cursor < msg.payload.len() || !first_chunk_done {
599            let chunk_fmt = if !first_chunk_done { fmt } else { 3 };
600            self.write_basic_header(chunk_fmt, csid)?;
601            match chunk_fmt {
602                0 => {
603                    let ts_field = if ext_ts_needed {
604                        0x00FF_FFFF
605                    } else {
606                        msg.timestamp
607                    };
608                    self.write_u24_be(ts_field)?;
609                    self.write_u24_be(payload_len)?;
610                    self.stream.write_all(&[msg.msg_type_id])?;
611                    self.stream.write_all(&msg.msg_stream_id.to_le_bytes())?;
612                    if ext_ts_needed {
613                        self.stream.write_all(&msg.timestamp.to_be_bytes())?;
614                    }
615                }
616                1 => {
617                    let ts_field = if ext_ts_needed { 0x00FF_FFFF } else { delta };
618                    self.write_u24_be(ts_field)?;
619                    self.write_u24_be(payload_len)?;
620                    self.stream.write_all(&[msg.msg_type_id])?;
621                    if ext_ts_needed {
622                        self.stream.write_all(&msg.timestamp.to_be_bytes())?;
623                    }
624                }
625                2 => {
626                    let ts_field = if ext_ts_needed { 0x00FF_FFFF } else { delta };
627                    self.write_u24_be(ts_field)?;
628                    if ext_ts_needed {
629                        self.stream.write_all(&msg.timestamp.to_be_bytes())?;
630                    }
631                }
632                3 => {
633                    // Continuation chunk must repeat the extended
634                    // timestamp iff the head chunk used one. Use the
635                    // current-message decision for the first round,
636                    // fall back to the previous message's sticky bit
637                    // otherwise.
638                    let ext_repeat = if !first_chunk_done {
639                        ext_ts_needed
640                    } else {
641                        prev_last_had_ext_ts && cursor == 0
642                    };
643                    if ext_repeat {
644                        self.stream.write_all(&msg.timestamp.to_be_bytes())?;
645                    }
646                }
647                _ => unreachable!(),
648            }
649            let end = (cursor + chunk_size).min(msg.payload.len());
650            self.stream.write_all(&msg.payload[cursor..end])?;
651            cursor = end;
652            first_chunk_done = true;
653        }
654
655        // Commit the updated state.
656        let st = self.states.entry(csid).or_default();
657        st.msg_type_id = msg.msg_type_id;
658        st.msg_stream_id = msg.msg_stream_id;
659        st.msg_length = payload_len;
660        st.timestamp = msg.timestamp;
661        st.last_delta = if fmt == 0 { msg.timestamp } else { delta };
662        st.last_had_ext_ts = ext_ts_needed;
663        st.primed = true;
664        Ok(())
665    }
666
667    fn write_basic_header(&mut self, fmt: u8, csid: u32) -> Result<()> {
668        match csid {
669            2..=63 => {
670                self.stream.write_all(&[(fmt << 6) | (csid as u8)])?;
671            }
672            64..=319 => {
673                self.stream.write_all(&[fmt << 6, (csid - 64) as u8])?;
674            }
675            320..=65_599 => {
676                let v = (csid - 64) as u16;
677                self.stream
678                    .write_all(&[(fmt << 6) | 1, (v & 0xFF) as u8, (v >> 8) as u8])?;
679            }
680            other => {
681                return Err(Error::ProtocolViolation(format!(
682                    "chunk stream id {other} out of range"
683                )))
684            }
685        }
686        Ok(())
687    }
688
689    fn write_u24_be(&mut self, v: u32) -> Result<()> {
690        let v = v & 0x00FF_FFFF;
691        self.stream
692            .write_all(&[(v >> 16) as u8, (v >> 8) as u8, v as u8])?;
693        Ok(())
694    }
695}
696
697#[cfg(test)]
698mod tests {
699    use super::*;
700    use std::io::Cursor;
701
702    /// Send one small message through ChunkWriter → ChunkReader and
703    /// make sure the payload, timestamp, type id, and stream id all
704    /// survive the round-trip.
705    #[test]
706    fn chunk_roundtrip_short_message() {
707        let mut buf = Vec::new();
708        {
709            let mut w = ChunkWriter::new(&mut buf);
710            w.write_message(
711                3,
712                &Message {
713                    msg_type_id: 20,
714                    msg_stream_id: 0,
715                    timestamp: 12345,
716                    payload: b"hello world".to_vec(),
717                },
718            )
719            .unwrap();
720        }
721        let mut r = ChunkReader::new(Cursor::new(&buf));
722        let msg = r.read_message().unwrap();
723        assert_eq!(msg.msg_type_id, 20);
724        assert_eq!(msg.timestamp, 12345);
725        assert_eq!(msg.payload, b"hello world");
726    }
727
728    /// Force a multi-chunk payload (small chunk size, bigger message)
729    /// and check it still reassembles correctly.
730    #[test]
731    fn chunk_roundtrip_multi_chunk_message() {
732        let payload: Vec<u8> = (0..4096u16).map(|i| (i & 0xFF) as u8).collect();
733        let mut buf = Vec::new();
734        {
735            let mut w = ChunkWriter::new(&mut buf);
736            w.set_chunk_size(128);
737            w.write_message(
738                3,
739                &Message {
740                    msg_type_id: 9,
741                    msg_stream_id: 1,
742                    timestamp: 7000,
743                    payload: payload.clone(),
744                },
745            )
746            .unwrap();
747        }
748        let mut r = ChunkReader::new(Cursor::new(&buf));
749        r.set_chunk_size(128);
750        let msg = r.read_message().unwrap();
751        assert_eq!(msg.payload, payload);
752        assert_eq!(msg.msg_type_id, 9);
753        assert_eq!(msg.timestamp, 7000);
754    }
755
756    /// §5.3 Acknowledgement accounting: `received_bytes` counts every
757    /// byte the reader consumes off the wire — basic header, message
758    /// header, extended timestamp, and payload — so the sequence
759    /// number it reports matches the peer's view of "bytes sent".
760    #[test]
761    fn received_bytes_counts_full_wire_size() {
762        let mut buf = Vec::new();
763        {
764            let mut w = ChunkWriter::new(&mut buf);
765            w.write_message(
766                3,
767                &Message {
768                    msg_type_id: 20,
769                    msg_stream_id: 0,
770                    timestamp: 1000,
771                    payload: b"abcdef".to_vec(),
772                },
773            )
774            .unwrap();
775        }
776        let wire_len = buf.len() as u32;
777        let mut r = ChunkReader::new(Cursor::new(&buf));
778        assert_eq!(r.received_bytes(), 0);
779        let _ = r.read_message().unwrap();
780        // The whole single-chunk frame was consumed; nothing left over.
781        assert_eq!(r.received_bytes(), wire_len);
782    }
783
784    /// With no window negotiated, `ack_due` never fires regardless of
785    /// how many bytes flow.
786    #[test]
787    fn ack_not_due_without_window() {
788        let mut buf = Vec::new();
789        {
790            let mut w = ChunkWriter::new(&mut buf);
791            w.write_message(
792                3,
793                &Message {
794                    msg_type_id: 20,
795                    msg_stream_id: 0,
796                    timestamp: 0,
797                    payload: vec![0u8; 500],
798                },
799            )
800            .unwrap();
801        }
802        let mut r = ChunkReader::new(Cursor::new(&buf));
803        let _ = r.read_message().unwrap();
804        assert_eq!(r.window_ack_size(), 0);
805        assert_eq!(r.ack_due(), None);
806    }
807
808    /// §5.3 / §5.5: once a window is set, `ack_due` fires the first
809    /// time the received-byte count crosses it, reports the running
810    /// sequence number, and re-arms only after another full window.
811    #[test]
812    fn ack_due_fires_once_per_window() {
813        // Two messages, each ~big enough that the first crosses a
814        // small window and the second crosses the next one.
815        let mut buf = Vec::new();
816        {
817            let mut w = ChunkWriter::new(&mut buf);
818            for ts in [10u32, 20] {
819                w.write_message(
820                    4,
821                    &Message {
822                        msg_type_id: 8,
823                        msg_stream_id: 1,
824                        timestamp: ts,
825                        payload: vec![0xAB; 200],
826                    },
827                )
828                .unwrap();
829            }
830        }
831        let mut r = ChunkReader::new(Cursor::new(&buf));
832        r.set_window_ack_size(150);
833        assert_eq!(r.window_ack_size(), 150);
834
835        let _ = r.read_message().unwrap();
836        // First message (≈ 211 bytes ≥ 150) owes an ack at the current
837        // sequence number, and it is not due a second time immediately.
838        let first = r.ack_due().expect("first ack due after window crossed");
839        assert_eq!(first, r.received_bytes());
840        assert_eq!(r.ack_due(), None, "ack must not re-fire within a window");
841
842        let _ = r.read_message().unwrap();
843        // Second message pushes another full window past the last ack.
844        let second = r.ack_due().expect("second ack due after second window");
845        assert!(second > first);
846        assert_eq!(second, r.received_bytes());
847    }
848
849    /// §5.5: shrinking / resetting the window re-bases the
850    /// "bytes since last ack" mark to the current sequence so a
851    /// single already-counted byte doesn't instantly owe an ack.
852    #[test]
853    fn set_window_rebases_accounting() {
854        let mut buf = Vec::new();
855        {
856            let mut w = ChunkWriter::new(&mut buf);
857            w.write_message(
858                4,
859                &Message {
860                    msg_type_id: 8,
861                    msg_stream_id: 1,
862                    timestamp: 0,
863                    payload: vec![0u8; 400],
864                },
865            )
866            .unwrap();
867        }
868        let mut r = ChunkReader::new(Cursor::new(&buf));
869        let _ = r.read_message().unwrap();
870        // Set the window AFTER 400+ bytes already arrived: because the
871        // setter re-bases, no ack is immediately due even though the
872        // total received exceeds the new window.
873        r.set_window_ack_size(100);
874        assert_eq!(r.ack_due(), None);
875    }
876
877    /// RTMP 1.0 §5.2 Abort Message: after a publisher sends part of a
878    /// multi-chunk message and then aborts it, the receiver must discard
879    /// the half-filled reassembly buffer for that csid. We build a
880    /// two-chunk message, hand the reader only the first chunk (so it is
881    /// "waiting for chunks to complete a message" and `read_message`
882    /// surfaces `UnexpectedEof`), then assert `abort_partial` reports it
883    /// discarded a non-empty buffer.
884    #[test]
885    fn abort_partial_discards_in_flight_message() {
886        let payload: Vec<u8> = (0..200u16).map(|i| (i & 0xFF) as u8).collect();
887        let mut full = Vec::new();
888        {
889            let mut w = ChunkWriter::new(&mut full);
890            w.set_chunk_size(128);
891            w.write_message(
892                5,
893                &Message {
894                    msg_type_id: 9,
895                    msg_stream_id: 1,
896                    timestamp: 1000,
897                    payload: payload.clone(),
898                },
899            )
900            .unwrap();
901        }
902        // First chunk = fmt-0 header (12 bytes) + 128 payload bytes. Hand
903        // the reader only that prefix so the second chunk never arrives.
904        let first_chunk = &full[..12 + 128];
905        let mut r = ChunkReader::new(Cursor::new(first_chunk));
906        r.set_chunk_size(128);
907        // Reading blocks for the missing chunk, hitting EOF — the csid-5
908        // partial buffer now holds the first 128 bytes.
909        let err = r.read_message().unwrap_err();
910        assert!(matches!(err, Error::Io(_) | Error::UnexpectedEof));
911        // §5.2 discard: a non-empty partial exists, so abort returns true
912        // and clears it; a second abort on the now-empty csid is a no-op.
913        assert!(r.abort_partial(5), "first abort should discard 128 bytes");
914        assert!(!r.abort_partial(5), "second abort has nothing to discard");
915        // An abort for a csid the reader never saw is also a no-op.
916        assert!(!r.abort_partial(9));
917    }
918
919    /// Two back-to-back messages on the same csid should use fmt 3 for
920    /// the second when every field matches, keeping the wire compact.
921    #[test]
922    fn back_to_back_same_message_uses_fmt3() {
923        let msg = Message {
924            msg_type_id: 9,
925            msg_stream_id: 1,
926            timestamp: 1000,
927            payload: vec![0xAA; 32],
928        };
929        let mut buf = Vec::new();
930        {
931            let mut w = ChunkWriter::new(&mut buf);
932            w.write_message(5, &msg).unwrap();
933            w.write_message(5, &msg).unwrap();
934        }
935        // First byte of the second basic header: fmt=3 (bits 7-6 = 11),
936        // csid=5 (bits 5-0 = 000101) → 0xC5.
937        let first_headers_len = 1 + 11 + 32; // fmt 0 on csid 5
938        assert_eq!(buf[first_headers_len], 0xC5);
939    }
940}