1use bytes::{Buf, BufMut, Bytes, BytesMut};
30use std::collections::HashMap;
31
32use crate::error::{ProtocolError, Result};
33use crate::protocol::constants::*;
34
35#[derive(Debug, Clone)]
37pub struct RtmpChunk {
38 pub csid: u32,
40 pub timestamp: u32,
42 pub message_type: u8,
44 pub stream_id: u32,
46 pub payload: Bytes,
48}
49
50#[derive(Debug, Clone, Default)]
52struct ChunkStreamState {
53 timestamp: u32,
55 timestamp_delta: u32,
57 message_length: u32,
59 message_type: u8,
61 stream_id: u32,
63 has_extended_timestamp: bool,
65 partial_message: BytesMut,
67 expected_length: u32,
69}
70
71pub struct ChunkDecoder {
75 chunk_size: u32,
77 streams: HashMap<u32, ChunkStreamState>,
79 max_message_size: u32,
81}
82
83impl ChunkDecoder {
84 pub fn new() -> Self {
86 Self {
87 chunk_size: DEFAULT_CHUNK_SIZE,
88 streams: HashMap::new(),
89 max_message_size: MAX_MESSAGE_SIZE,
90 }
91 }
92
93 pub fn set_chunk_size(&mut self, size: u32) {
95 self.chunk_size = size.min(MAX_CHUNK_SIZE);
96 }
97
98 pub fn chunk_size(&self) -> u32 {
100 self.chunk_size
101 }
102
103 pub fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<RtmpChunk>> {
108 if buf.is_empty() {
109 return Ok(None);
110 }
111
112 let (fmt, csid, header_len) = match self.parse_basic_header(buf)? {
114 Some(v) => v,
115 None => return Ok(None),
116 };
117
118 tracing::trace!(
119 fmt = fmt,
120 csid = csid,
121 header_len = header_len,
122 first_byte = format!("0x{:02x}", buf[0]),
123 "Parsing chunk"
124 );
125
126 let state = self.streams.entry(csid).or_default();
128
129 let msg_header_size = match fmt {
131 0 => 11,
132 1 => 7,
133 2 => 3,
134 3 => 0,
135 _ => return Err(ProtocolError::InvalidChunkHeader.into()),
136 };
137
138 let needs_extended = if fmt == 3 {
140 state.has_extended_timestamp
141 } else if buf.len() > header_len + 2 {
142 let ts_bytes = &buf[header_len..header_len + 3];
144 let ts =
145 ((ts_bytes[0] as u32) << 16) | ((ts_bytes[1] as u32) << 8) | (ts_bytes[2] as u32);
146 ts >= EXTENDED_TIMESTAMP_THRESHOLD
147 } else {
148 false
149 };
150
151 let extended_size = if needs_extended { 4 } else { 0 };
152 let total_header_size = header_len + msg_header_size + extended_size;
153
154 if buf.len() < total_header_size {
155 return Ok(None); }
157
158 let (_peeked_message_length, peeked_expected_length) = match fmt {
161 0 | 1 => {
162 let len_offset = header_len + 3;
164 let len_bytes = &buf[len_offset..len_offset + 3];
165 let len = ((len_bytes[0] as u32) << 16)
166 | ((len_bytes[1] as u32) << 8)
167 | (len_bytes[2] as u32);
168 (len, len)
169 }
170 2 | 3 => {
171 let msg_len = state.message_length;
173 let expected = if state.partial_message.is_empty() {
174 msg_len
175 } else {
176 state.expected_length
177 };
178 (msg_len, expected)
179 }
180 _ => unreachable!(),
181 };
182
183 let partial_len = state.partial_message.len() as u32;
185 let remaining = peeked_expected_length.saturating_sub(partial_len);
186 let chunk_data_len = remaining.min(self.chunk_size) as usize;
187
188 let total_chunk_size = total_header_size + chunk_data_len;
190 if buf.len() < total_chunk_size {
191 return Ok(None); }
193
194 buf.advance(header_len);
196
197 let (timestamp_field, message_length, message_type, stream_id) = match fmt {
198 0 => {
199 let ts = buf.get_uint(3) as u32;
201 let len = buf.get_uint(3) as u32;
202 let typ = buf.get_u8();
203 let sid = buf.get_u32_le(); (ts, len, typ, sid)
205 }
206 1 => {
207 let ts = buf.get_uint(3) as u32;
209 let len = buf.get_uint(3) as u32;
210 let typ = buf.get_u8();
211 (ts, len, typ, state.stream_id)
212 }
213 2 => {
214 let ts = buf.get_uint(3) as u32;
216 (
217 ts,
218 state.message_length,
219 state.message_type,
220 state.stream_id,
221 )
222 }
223 3 => {
224 (
226 state.timestamp_delta,
227 state.message_length,
228 state.message_type,
229 state.stream_id,
230 )
231 }
232 _ => unreachable!(),
233 };
234
235 let timestamp = if timestamp_field >= EXTENDED_TIMESTAMP_THRESHOLD
237 || (fmt == 3 && state.has_extended_timestamp)
238 {
239 state.has_extended_timestamp = true;
240 buf.get_u32()
241 } else {
242 state.has_extended_timestamp = false;
243 timestamp_field
244 };
245
246 let absolute_timestamp = if fmt == 0 {
248 timestamp
249 } else if fmt == 3 && !state.partial_message.is_empty() {
250 state.timestamp
252 } else {
253 state.timestamp.wrapping_add(timestamp)
254 };
255
256 state.timestamp_delta = timestamp;
257 state.message_length = message_length;
258 state.message_type = message_type;
259 state.stream_id = stream_id;
260 state.timestamp = absolute_timestamp;
261
262 if message_length > self.max_message_size {
264 return Err(ProtocolError::MessageTooLarge {
265 size: message_length,
266 max: self.max_message_size,
267 }
268 .into());
269 }
270
271 if state.partial_message.is_empty() {
273 state.expected_length = message_length;
274 state.partial_message.reserve(message_length as usize);
275 }
276
277 state.partial_message.put_slice(&buf[..chunk_data_len]);
279 buf.advance(chunk_data_len);
280
281 if state.partial_message.len() as u32 >= state.expected_length {
283 let payload = state.partial_message.split().freeze();
284 state.expected_length = 0;
285
286 Ok(Some(RtmpChunk {
287 csid,
288 timestamp: state.timestamp,
289 message_type: state.message_type,
290 stream_id: state.stream_id,
291 payload,
292 }))
293 } else {
294 Ok(None) }
296 }
297
298 fn parse_basic_header(&self, buf: &[u8]) -> Result<Option<(u8, u32, usize)>> {
300 if buf.is_empty() {
301 return Ok(None);
302 }
303
304 let first = buf[0];
305 let fmt = (first >> 6) & 0x03;
306 let csid_low = first & 0x3F;
307
308 match csid_low {
309 0 => {
310 if buf.len() < 2 {
312 return Ok(None);
313 }
314 let csid = 64 + buf[1] as u32;
315 Ok(Some((fmt, csid, 2)))
316 }
317 1 => {
318 if buf.len() < 3 {
320 return Ok(None);
321 }
322 let csid = 64 + buf[1] as u32 + (buf[2] as u32) * 256;
323 Ok(Some((fmt, csid, 3)))
324 }
325 _ => {
326 Ok(Some((fmt, csid_low as u32, 1)))
328 }
329 }
330 }
331
332 pub fn abort(&mut self, csid: u32) {
334 if let Some(state) = self.streams.get_mut(&csid) {
335 state.partial_message.clear();
336 state.expected_length = 0;
337 }
338 }
339}
340
341impl Default for ChunkDecoder {
342 fn default() -> Self {
343 Self::new()
344 }
345}
346
347pub struct ChunkEncoder {
351 chunk_size: u32,
353 streams: HashMap<u32, ChunkStreamState>,
355}
356
357impl ChunkEncoder {
358 pub fn new() -> Self {
360 Self {
361 chunk_size: DEFAULT_CHUNK_SIZE,
362 streams: HashMap::new(),
363 }
364 }
365
366 pub fn set_chunk_size(&mut self, size: u32) {
368 self.chunk_size = size.min(MAX_CHUNK_SIZE);
369 }
370
371 pub fn chunk_size(&self) -> u32 {
373 self.chunk_size
374 }
375
376 pub fn encode(&mut self, chunk: &RtmpChunk, buf: &mut BytesMut) {
378 let csid = chunk.csid;
379 let chunk_size = self.chunk_size;
380
381 let state = self.streams.entry(csid).or_default();
383
384 let fmt = select_format(chunk, state);
386
387 let needs_extended = chunk.timestamp >= EXTENDED_TIMESTAMP_THRESHOLD;
389 let timestamp_field = if needs_extended {
390 EXTENDED_TIMESTAMP_THRESHOLD
391 } else {
392 chunk.timestamp
393 };
394
395 let timestamp_delta = chunk.timestamp.wrapping_sub(state.timestamp);
396 let delta_field = if needs_extended {
397 EXTENDED_TIMESTAMP_THRESHOLD
398 } else {
399 timestamp_delta
400 };
401
402 let had_extended_timestamp = state.has_extended_timestamp;
403
404 state.timestamp = chunk.timestamp;
406 state.timestamp_delta = timestamp_delta;
407 state.message_length = chunk.payload.len() as u32;
408 state.message_type = chunk.message_type;
409 state.stream_id = chunk.stream_id;
410 state.has_extended_timestamp = needs_extended;
411
412 let mut offset = 0;
414 let payload_len = chunk.payload.len();
415 let mut first_chunk = true;
416
417 while offset < payload_len {
418 let chunk_data_len = (payload_len - offset).min(chunk_size as usize);
419
420 write_basic_header(csid, if first_chunk { fmt } else { 3 }, buf);
422
423 if first_chunk {
425 match fmt {
426 0 => {
427 write_u24(timestamp_field, buf);
429 write_u24(payload_len as u32, buf);
430 buf.put_u8(chunk.message_type);
431 buf.put_u32_le(chunk.stream_id);
432 }
433 1 => {
434 write_u24(delta_field, buf);
436 write_u24(payload_len as u32, buf);
437 buf.put_u8(chunk.message_type);
438 }
439 2 => {
440 write_u24(delta_field, buf);
442 }
443 3 => {
444 }
446 _ => unreachable!(),
447 }
448 }
449
450 if needs_extended && (first_chunk || had_extended_timestamp) {
452 buf.put_u32(chunk.timestamp);
453 }
454
455 buf.put_slice(&chunk.payload[offset..offset + chunk_data_len]);
457 offset += chunk_data_len;
458 first_chunk = false;
459 }
460 }
461}
462
463fn select_format(chunk: &RtmpChunk, state: &ChunkStreamState) -> u8 {
465 if state.message_type == 0 && state.stream_id == 0 {
467 return 0;
468 }
469
470 if chunk.stream_id != state.stream_id {
472 return 0;
473 }
474
475 if chunk.message_type != state.message_type
477 || chunk.payload.len() as u32 != state.message_length
478 {
479 return 1;
480 }
481
482 let delta = chunk.timestamp.wrapping_sub(state.timestamp);
484 if delta == state.timestamp_delta {
485 return 3;
486 }
487
488 2
490}
491
492fn write_basic_header(csid: u32, fmt: u8, buf: &mut BytesMut) {
494 if csid >= 64 + 256 {
495 buf.put_u8((fmt << 6) | 1);
497 let csid_offset = csid - 64;
498 buf.put_u8((csid_offset & 0xFF) as u8);
499 buf.put_u8(((csid_offset >> 8) & 0xFF) as u8);
500 } else if csid >= 64 {
501 buf.put_u8((fmt << 6) | 0);
503 buf.put_u8((csid - 64) as u8);
504 } else {
505 buf.put_u8((fmt << 6) | (csid as u8));
507 }
508}
509
510fn write_u24(value: u32, buf: &mut BytesMut) {
512 buf.put_u8(((value >> 16) & 0xFF) as u8);
513 buf.put_u8(((value >> 8) & 0xFF) as u8);
514 buf.put_u8((value & 0xFF) as u8);
515}
516
517impl Default for ChunkEncoder {
518 fn default() -> Self {
519 Self::new()
520 }
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// The basic header has 1-, 2- and 3-byte forms, selected by the low six
    /// bits of the first byte.
    #[test]
    fn test_basic_header_parsing() {
        let decoder = ChunkDecoder::new();

        // One-byte form: fmt 0, csid 3.
        let buf = [0x03];
        let result = decoder.parse_basic_header(&buf).unwrap().unwrap();
        assert_eq!(result, (0, 3, 1));

        // Two-byte form: csid = 64 + 0.
        let buf = [0x00, 0x00];
        let result = decoder.parse_basic_header(&buf).unwrap().unwrap();
        assert_eq!(result, (0, 64, 2));

        // Three-byte form: csid = 64 + 0 + 1 * 256.
        let buf = [0x01, 0x00, 0x01];
        let result = decoder.parse_basic_header(&buf).unwrap().unwrap();
        assert_eq!(result, (0, 320, 3));
    }

    /// A message that fits in one chunk survives an encode/decode round trip.
    #[test]
    fn test_encode_decode_roundtrip() {
        let original = RtmpChunk {
            csid: CSID_COMMAND,
            timestamp: 1000,
            message_type: MSG_COMMAND_AMF0,
            stream_id: 0,
            payload: Bytes::from_static(b"test payload data"),
        };

        let mut encoder = ChunkEncoder::new();
        let mut decoder = ChunkDecoder::new();

        let mut encoded = BytesMut::new();
        encoder.encode(&original, &mut encoded);

        let decoded = decoder.decode(&mut encoded).unwrap().unwrap();

        assert_eq!(decoded.csid, original.csid);
        assert_eq!(decoded.timestamp, original.timestamp);
        assert_eq!(decoded.message_type, original.message_type);
        assert_eq!(decoded.stream_id, original.stream_id);
        assert_eq!(decoded.payload, original.payload);
    }

    /// A message larger than the chunk size is split by the encoder and
    /// reassembled by the decoder.
    #[test]
    fn test_large_message_chunking() {
        let large_payload = vec![0u8; 500];
        let original = RtmpChunk {
            csid: CSID_VIDEO,
            timestamp: 0,
            message_type: MSG_VIDEO,
            stream_id: 1,
            payload: Bytes::from(large_payload.clone()),
        };

        let mut encoder = ChunkEncoder::new();
        let mut decoder = ChunkDecoder::new();

        let mut encoded = BytesMut::new();
        encoder.encode(&original, &mut encoded);

        // Chunk headers add overhead on top of the payload.
        assert!(encoded.len() > 500);

        // `decode` returns `None` for every chunk that does not complete the
        // message. Bound the attempts so a decoder bug fails the test instead
        // of hanging it.
        let max_attempts = encoded.len();
        let decoded = (0..max_attempts)
            .find_map(|_| decoder.decode(&mut encoded).unwrap())
            .expect("decoder never produced a complete message");

        assert_eq!(decoded.payload.len(), 500);
        assert_eq!(&decoded.payload[..], &large_payload[..]);
    }
}