rtmp_rs/protocol/
chunk.rs

1//! RTMP chunk stream codec
2//!
3//! RTMP messages are split into chunks for multiplexing. Each chunk has a header
4//! that identifies the chunk stream and message being sent.
5//!
6//! ```text
7//! Chunk Format:
8//! +-------------+----------------+-------------------+
9//! | Basic Header| Message Header | Chunk Data        |
10//! | (1-3 bytes) | (0,3,7,11 bytes)| (variable)       |
11//! +-------------+----------------+-------------------+
12//!
13//! Basic Header formats:
//! - 1 byte:  fmt(2) + csid(6)                         for csid 2-63
//! - 2 bytes: fmt(2) + 0(6) + (csid - 64)(8)           for csid 64-319
//! - 3 bytes: fmt(2) + 1(6) + (csid - 64)(16, low byte first) for csid 64-65599
17//!
18//! Message Header formats (based on fmt):
19//! - Type 0 (11 bytes): timestamp(3) + length(3) + type(1) + stream_id(4)
20//! - Type 1 (7 bytes):  timestamp_delta(3) + length(3) + type(1)
21//! - Type 2 (3 bytes):  timestamp_delta(3)
22//! - Type 3 (0 bytes):  (use previous chunk's values)
23//!
24//! Extended timestamp (4 bytes) is appended when timestamp >= 0xFFFFFF
25//! ```
26//!
27//! Reference: RTMP Specification Section 5.3
28
29use bytes::{Buf, BufMut, Bytes, BytesMut};
30use std::collections::HashMap;
31
32use crate::error::{ProtocolError, Result};
33use crate::protocol::constants::*;
34
35/// A complete RTMP message (reassembled from chunks)
36#[derive(Debug, Clone)]
37pub struct RtmpChunk {
38    /// Chunk stream ID (for multiplexing)
39    pub csid: u32,
40    /// Message timestamp (milliseconds)
41    pub timestamp: u32,
42    /// Message type ID
43    pub message_type: u8,
44    /// Message stream ID
45    pub stream_id: u32,
46    /// Message payload
47    pub payload: Bytes,
48}
49
50/// Per-chunk-stream state for reassembly
51#[derive(Debug, Clone, Default)]
52struct ChunkStreamState {
53    /// Last timestamp (absolute)
54    timestamp: u32,
55    /// Last timestamp delta
56    timestamp_delta: u32,
57    /// Last message length
58    message_length: u32,
59    /// Last message type
60    message_type: u8,
61    /// Last message stream ID
62    stream_id: u32,
63    /// Whether we've received extended timestamp
64    has_extended_timestamp: bool,
65    /// Buffer for partial message reassembly
66    partial_message: BytesMut,
67    /// Expected total length of current message
68    expected_length: u32,
69}
70
71/// Chunk stream decoder
72///
73/// Handles chunk demultiplexing and message reassembly.
74pub struct ChunkDecoder {
75    /// Maximum incoming chunk size
76    chunk_size: u32,
77    /// Per-chunk-stream state
78    streams: HashMap<u32, ChunkStreamState>,
79    /// Maximum message size (sanity limit)
80    max_message_size: u32,
81}
82
83impl ChunkDecoder {
84    /// Create a new decoder with default chunk size
85    pub fn new() -> Self {
86        Self {
87            chunk_size: DEFAULT_CHUNK_SIZE,
88            streams: HashMap::new(),
89            max_message_size: MAX_MESSAGE_SIZE,
90        }
91    }
92
93    /// Set the chunk size (called when receiving SetChunkSize message)
94    pub fn set_chunk_size(&mut self, size: u32) {
95        self.chunk_size = size.min(MAX_CHUNK_SIZE);
96    }
97
98    /// Get current chunk size
99    pub fn chunk_size(&self) -> u32 {
100        self.chunk_size
101    }
102
103    /// Try to decode a complete message from the buffer
104    ///
105    /// Returns Ok(Some(chunk)) if a complete message was decoded,
106    /// Ok(None) if more data is needed, or Err on protocol error.
107    pub fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<RtmpChunk>> {
108        if buf.is_empty() {
109            return Ok(None);
110        }
111
112        // Parse basic header to get csid and fmt (peek only, don't advance)
113        let (fmt, csid, header_len) = match self.parse_basic_header(buf)? {
114            Some(v) => v,
115            None => return Ok(None),
116        };
117
118        tracing::trace!(
119            fmt = fmt,
120            csid = csid,
121            header_len = header_len,
122            first_byte = format!("0x{:02x}", buf[0]),
123            "Parsing chunk"
124        );
125
126        // Get or create chunk stream state
127        let state = self.streams.entry(csid).or_default();
128
129        // Calculate message header size based on fmt
130        let msg_header_size = match fmt {
131            0 => 11,
132            1 => 7,
133            2 => 3,
134            3 => 0,
135            _ => return Err(ProtocolError::InvalidChunkHeader.into()),
136        };
137
138        // Check if we have enough data for headers
139        let needs_extended = if fmt == 3 {
140            state.has_extended_timestamp
141        } else if buf.len() > header_len + 2 {
142            // Peek at timestamp field to check for extended timestamp
143            let ts_bytes = &buf[header_len..header_len + 3];
144            let ts =
145                ((ts_bytes[0] as u32) << 16) | ((ts_bytes[1] as u32) << 8) | (ts_bytes[2] as u32);
146            ts >= EXTENDED_TIMESTAMP_THRESHOLD
147        } else {
148            false
149        };
150
151        let extended_size = if needs_extended { 4 } else { 0 };
152        let total_header_size = header_len + msg_header_size + extended_size;
153
154        if buf.len() < total_header_size {
155            return Ok(None); // Need more header data
156        }
157
158        // PEEK at message header to determine chunk data length BEFORE consuming anything
159        // For fmt 3, we use state values; for others, we peek at the buffer
160        let (_peeked_message_length, peeked_expected_length) = match fmt {
161            0 | 1 => {
162                // Message length is at offset: header_len + 3 (timestamp bytes)
163                let len_offset = header_len + 3;
164                let len_bytes = &buf[len_offset..len_offset + 3];
165                let len = ((len_bytes[0] as u32) << 16)
166                    | ((len_bytes[1] as u32) << 8)
167                    | (len_bytes[2] as u32);
168                (len, len)
169            }
170            2 | 3 => {
171                // Use state for message length
172                let msg_len = state.message_length;
173                let expected = if state.partial_message.is_empty() {
174                    msg_len
175                } else {
176                    state.expected_length
177                };
178                (msg_len, expected)
179            }
180            _ => unreachable!(),
181        };
182
183        // Calculate chunk data length
184        let partial_len = state.partial_message.len() as u32;
185        let remaining = peeked_expected_length.saturating_sub(partial_len);
186        let chunk_data_len = remaining.min(self.chunk_size) as usize;
187
188        // Now check if we have enough data for header + payload
189        let total_chunk_size = total_header_size + chunk_data_len;
190        if buf.len() < total_chunk_size {
191            return Ok(None); // Need more data - don't consume anything
192        }
193
194        // NOW we can safely consume the data since we have enough for the entire chunk
195        buf.advance(header_len);
196
197        let (timestamp_field, message_length, message_type, stream_id) = match fmt {
198            0 => {
199                // Full header
200                let ts = buf.get_uint(3) as u32;
201                let len = buf.get_uint(3) as u32;
202                let typ = buf.get_u8();
203                let sid = buf.get_u32_le(); // Stream ID is little-endian!
204                (ts, len, typ, sid)
205            }
206            1 => {
207                // No stream ID
208                let ts = buf.get_uint(3) as u32;
209                let len = buf.get_uint(3) as u32;
210                let typ = buf.get_u8();
211                (ts, len, typ, state.stream_id)
212            }
213            2 => {
214                // Timestamp delta only
215                let ts = buf.get_uint(3) as u32;
216                (
217                    ts,
218                    state.message_length,
219                    state.message_type,
220                    state.stream_id,
221                )
222            }
223            3 => {
224                // No header, use previous values
225                (
226                    state.timestamp_delta,
227                    state.message_length,
228                    state.message_type,
229                    state.stream_id,
230                )
231            }
232            _ => unreachable!(),
233        };
234
235        // Handle extended timestamp (we already checked we have enough bytes)
236        let timestamp = if timestamp_field >= EXTENDED_TIMESTAMP_THRESHOLD
237            || (fmt == 3 && state.has_extended_timestamp)
238        {
239            state.has_extended_timestamp = true;
240            buf.get_u32()
241        } else {
242            state.has_extended_timestamp = false;
243            timestamp_field
244        };
245
246        // Update state
247        let absolute_timestamp = if fmt == 0 {
248            timestamp
249        } else if fmt == 3 && !state.partial_message.is_empty() {
250            // continuation chunk, timestamp stays the same
251            state.timestamp
252        } else {
253            state.timestamp.wrapping_add(timestamp)
254        };
255
256        state.timestamp_delta = timestamp;
257        state.message_length = message_length;
258        state.message_type = message_type;
259        state.stream_id = stream_id;
260        state.timestamp = absolute_timestamp;
261
262        // Validate message size
263        if message_length > self.max_message_size {
264            return Err(ProtocolError::MessageTooLarge {
265                size: message_length,
266                max: self.max_message_size,
267            }
268            .into());
269        }
270
271        // Initialize reassembly buffer if this is a new message
272        if state.partial_message.is_empty() {
273            state.expected_length = message_length;
274            state.partial_message.reserve(message_length as usize);
275        }
276
277        // Read chunk data (we already verified we have enough)
278        state.partial_message.put_slice(&buf[..chunk_data_len]);
279        buf.advance(chunk_data_len);
280
281        // Check if message is complete
282        if state.partial_message.len() as u32 >= state.expected_length {
283            let payload = state.partial_message.split().freeze();
284            state.expected_length = 0;
285
286            Ok(Some(RtmpChunk {
287                csid,
288                timestamp: state.timestamp,
289                message_type: state.message_type,
290                stream_id: state.stream_id,
291                payload,
292            }))
293        } else {
294            Ok(None) // Message not yet complete
295        }
296    }
297
298    /// Parse basic header and return (fmt, csid, header_length)
299    fn parse_basic_header(&self, buf: &[u8]) -> Result<Option<(u8, u32, usize)>> {
300        if buf.is_empty() {
301            return Ok(None);
302        }
303
304        let first = buf[0];
305        let fmt = (first >> 6) & 0x03;
306        let csid_low = first & 0x3F;
307
308        match csid_low {
309            0 => {
310                // 2-byte header: csid = 64 + second byte
311                if buf.len() < 2 {
312                    return Ok(None);
313                }
314                let csid = 64 + buf[1] as u32;
315                Ok(Some((fmt, csid, 2)))
316            }
317            1 => {
318                // 3-byte header: csid = 64 + second + third*256
319                if buf.len() < 3 {
320                    return Ok(None);
321                }
322                let csid = 64 + buf[1] as u32 + (buf[2] as u32) * 256;
323                Ok(Some((fmt, csid, 3)))
324            }
325            _ => {
326                // 1-byte header: csid = 2-63
327                Ok(Some((fmt, csid_low as u32, 1)))
328            }
329        }
330    }
331
332    /// Abort a message on a chunk stream (when receiving Abort message)
333    pub fn abort(&mut self, csid: u32) {
334        if let Some(state) = self.streams.get_mut(&csid) {
335            state.partial_message.clear();
336            state.expected_length = 0;
337        }
338    }
339}
340
341impl Default for ChunkDecoder {
342    fn default() -> Self {
343        Self::new()
344    }
345}
346
347/// Chunk stream encoder
348///
349/// Encodes messages into chunks for transmission.
350pub struct ChunkEncoder {
351    /// Outgoing chunk size
352    chunk_size: u32,
353    /// Per-chunk-stream state for compression
354    streams: HashMap<u32, ChunkStreamState>,
355}
356
357impl ChunkEncoder {
358    /// Create a new encoder with default chunk size
359    pub fn new() -> Self {
360        Self {
361            chunk_size: DEFAULT_CHUNK_SIZE,
362            streams: HashMap::new(),
363        }
364    }
365
366    /// Set the chunk size (call before encoding to use larger chunks)
367    pub fn set_chunk_size(&mut self, size: u32) {
368        self.chunk_size = size.min(MAX_CHUNK_SIZE);
369    }
370
371    /// Get current chunk size
372    pub fn chunk_size(&self) -> u32 {
373        self.chunk_size
374    }
375
376    /// Encode a message into chunks
377    pub fn encode(&mut self, chunk: &RtmpChunk, buf: &mut BytesMut) {
378        let csid = chunk.csid;
379        let chunk_size = self.chunk_size;
380
381        // Get or create state, and compute format based on current state
382        let state = self.streams.entry(csid).or_default();
383
384        // Compute format based on state comparison
385        let fmt = select_format(chunk, state);
386
387        // Determine if we need extended timestamp
388        let needs_extended = chunk.timestamp >= EXTENDED_TIMESTAMP_THRESHOLD;
389        let timestamp_field = if needs_extended {
390            EXTENDED_TIMESTAMP_THRESHOLD
391        } else {
392            chunk.timestamp
393        };
394
395        let timestamp_delta = chunk.timestamp.wrapping_sub(state.timestamp);
396        let delta_field = if needs_extended {
397            EXTENDED_TIMESTAMP_THRESHOLD
398        } else {
399            timestamp_delta
400        };
401
402        let had_extended_timestamp = state.has_extended_timestamp;
403
404        // Update state before encoding
405        state.timestamp = chunk.timestamp;
406        state.timestamp_delta = timestamp_delta;
407        state.message_length = chunk.payload.len() as u32;
408        state.message_type = chunk.message_type;
409        state.stream_id = chunk.stream_id;
410        state.has_extended_timestamp = needs_extended;
411
412        // Encode chunks
413        let mut offset = 0;
414        let payload_len = chunk.payload.len();
415        let mut first_chunk = true;
416
417        while offset < payload_len {
418            let chunk_data_len = (payload_len - offset).min(chunk_size as usize);
419
420            // Write basic header
421            write_basic_header(csid, if first_chunk { fmt } else { 3 }, buf);
422
423            // Write message header based on format
424            if first_chunk {
425                match fmt {
426                    0 => {
427                        // Full header
428                        write_u24(timestamp_field, buf);
429                        write_u24(payload_len as u32, buf);
430                        buf.put_u8(chunk.message_type);
431                        buf.put_u32_le(chunk.stream_id);
432                    }
433                    1 => {
434                        // No stream ID
435                        write_u24(delta_field, buf);
436                        write_u24(payload_len as u32, buf);
437                        buf.put_u8(chunk.message_type);
438                    }
439                    2 => {
440                        // Timestamp delta only
441                        write_u24(delta_field, buf);
442                    }
443                    3 => {
444                        // No header
445                    }
446                    _ => unreachable!(),
447                }
448            }
449
450            // Write extended timestamp if needed
451            if needs_extended && (first_chunk || had_extended_timestamp) {
452                buf.put_u32(chunk.timestamp);
453            }
454
455            // Write chunk data
456            buf.put_slice(&chunk.payload[offset..offset + chunk_data_len]);
457            offset += chunk_data_len;
458            first_chunk = false;
459        }
460    }
461}
462
463/// Select the best header format for compression
464fn select_format(chunk: &RtmpChunk, state: &ChunkStreamState) -> u8 {
465    // First message on this stream must use format 0
466    if state.message_type == 0 && state.stream_id == 0 {
467        return 0;
468    }
469
470    // If stream ID differs, must use format 0
471    if chunk.stream_id != state.stream_id {
472        return 0;
473    }
474
475    // If message type or length differs, use format 1
476    if chunk.message_type != state.message_type
477        || chunk.payload.len() as u32 != state.message_length
478    {
479        return 1;
480    }
481
482    // If timestamp delta matches previous, use format 3
483    let delta = chunk.timestamp.wrapping_sub(state.timestamp);
484    if delta == state.timestamp_delta {
485        return 3;
486    }
487
488    // Otherwise use format 2 (timestamp delta only)
489    2
490}
491
492/// Write basic header
493fn write_basic_header(csid: u32, fmt: u8, buf: &mut BytesMut) {
494    if csid >= 64 + 256 {
495        // 3-byte header
496        buf.put_u8((fmt << 6) | 1);
497        let csid_offset = csid - 64;
498        buf.put_u8((csid_offset & 0xFF) as u8);
499        buf.put_u8(((csid_offset >> 8) & 0xFF) as u8);
500    } else if csid >= 64 {
501        // 2-byte header
502        buf.put_u8((fmt << 6) | 0);
503        buf.put_u8((csid - 64) as u8);
504    } else {
505        // 1-byte header
506        buf.put_u8((fmt << 6) | (csid as u8));
507    }
508}
509
510/// Write 24-bit big-endian value
511fn write_u24(value: u32, buf: &mut BytesMut) {
512    buf.put_u8(((value >> 16) & 0xFF) as u8);
513    buf.put_u8(((value >> 8) & 0xFF) as u8);
514    buf.put_u8((value & 0xFF) as u8);
515}
516
517impl Default for ChunkEncoder {
518    fn default() -> Self {
519        Self::new()
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Each of the three basic header forms yields the expected
    /// `(fmt, csid, header_length)` triple.
    #[test]
    fn test_basic_header_parsing() {
        let decoder = ChunkDecoder::new();

        // 1-byte header (csid 2-63)
        let buf = [0x03]; // fmt=0, csid=3
        let result = decoder.parse_basic_header(&buf).unwrap().unwrap();
        assert_eq!(result, (0, 3, 1));

        // 2-byte header (csid 64-319)
        let buf = [0x00, 0x00]; // fmt=0, csid=64
        let result = decoder.parse_basic_header(&buf).unwrap().unwrap();
        assert_eq!(result, (0, 64, 2));

        // 3-byte header (csid 64-65599)
        let buf = [0x01, 0x00, 0x01]; // fmt=0, csid=64+256
        let result = decoder.parse_basic_header(&buf).unwrap().unwrap();
        assert_eq!(result, (0, 320, 3));
    }

    /// A message that fits in a single chunk survives encode + decode unchanged.
    #[test]
    fn test_encode_decode_roundtrip() {
        let original = RtmpChunk {
            csid: CSID_COMMAND,
            timestamp: 1000,
            message_type: MSG_COMMAND_AMF0,
            stream_id: 0,
            payload: Bytes::from_static(b"test payload data"),
        };

        let mut encoder = ChunkEncoder::new();
        let mut decoder = ChunkDecoder::new();

        let mut encoded = BytesMut::new();
        encoder.encode(&original, &mut encoded);

        let decoded = decoder.decode(&mut encoded).unwrap().unwrap();

        assert_eq!(decoded.csid, original.csid);
        assert_eq!(decoded.timestamp, original.timestamp);
        assert_eq!(decoded.message_type, original.message_type);
        assert_eq!(decoded.stream_id, original.stream_id);
        assert_eq!(decoded.payload, original.payload);
    }

    /// A message larger than the chunk size is split into several chunks and
    /// reassembled into the original payload.
    #[test]
    fn test_large_message_chunking() {
        let large_payload = vec![0u8; 500]; // Larger than default chunk size (128)

        let original = RtmpChunk {
            csid: CSID_VIDEO,
            timestamp: 0,
            message_type: MSG_VIDEO,
            stream_id: 1,
            payload: Bytes::from(large_payload.clone()),
        };

        let mut encoder = ChunkEncoder::new();
        let mut decoder = ChunkDecoder::new();

        let mut encoded = BytesMut::new();
        encoder.encode(&original, &mut encoded);

        // Should produce multiple chunks (500 bytes / 128 = ~4 chunks)
        assert!(encoded.len() > 500);

        // Decode chunk by chunk until the message completes. The attempt count
        // is bounded so that a stalled decoder fails the test instead of
        // hanging it.
        let decoded = (0..16)
            .find_map(|_| decoder.decode(&mut encoded).unwrap())
            .expect("decoder did not complete the message within 16 chunks");

        assert_eq!(decoded.payload.len(), 500);
        assert_eq!(decoded.payload, original.payload);
        assert!(encoded.is_empty());
    }
}