rtmp_rs/amf/
amf0.rs

//! AMF0 encoder and decoder
//!
//! AMF0 is the original Action Message Format used in Flash/RTMP.
//! Reference: AMF0 File Format Specification (amf0-file-format-specification.pdf)
//!
//! Type Markers:
//! ```text
//! 0x00 - Number (IEEE 754 double)
//! 0x01 - Boolean
//! 0x02 - String (UTF-8, 16-bit length prefix)
//! 0x03 - Object (key-value pairs until 0x000009)
//! 0x04 - MovieClip (reserved, not supported)
//! 0x05 - Null
//! 0x06 - Undefined
//! 0x07 - Reference (16-bit index)
//! 0x08 - ECMA Array (associative array)
//! 0x09 - Object End (0x000009 sequence)
//! 0x0A - Strict Array (dense array)
//! 0x0B - Date (double + timezone)
//! 0x0C - Long String (UTF-8, 32-bit length prefix)
//! 0x0D - Unsupported
//! 0x0E - RecordSet (reserved, not supported)
//! 0x0F - XML Document
//! 0x10 - Typed Object (class name + properties)
//! 0x11 - AVM+ (switch to AMF3)
//! ```
27
28use bytes::{Buf, BufMut, Bytes, BytesMut};
29use std::collections::HashMap;
30
31use super::value::AmfValue;
32use crate::error::AmfError;
33
// AMF0 type markers (one byte, written in front of every encoded value).
// 0x04 (MovieClip) and 0x0E (RecordSet) are reserved and have no constant.

/// Number: IEEE 754 double.
const MARKER_NUMBER: u8 = 0x00;
/// Boolean: a single byte, non-zero means `true`.
const MARKER_BOOLEAN: u8 = 0x01;
/// String with a 16-bit length prefix.
const MARKER_STRING: u8 = 0x02;
/// Anonymous object: key/value pairs up to the object-end sequence.
const MARKER_OBJECT: u8 = 0x03;
/// Null.
const MARKER_NULL: u8 = 0x05;
/// Undefined.
const MARKER_UNDEFINED: u8 = 0x06;
/// Reference: 16-bit index into the reference table.
const MARKER_REFERENCE: u8 = 0x07;
/// ECMA (associative) array.
const MARKER_ECMA_ARRAY: u8 = 0x08;
/// Object end; follows an empty key.
const MARKER_OBJECT_END: u8 = 0x09;
/// Strict (dense) array.
const MARKER_STRICT_ARRAY: u8 = 0x0A;
/// Date: double timestamp followed by a 16-bit timezone.
const MARKER_DATE: u8 = 0x0B;
/// String with a 32-bit length prefix.
const MARKER_LONG_STRING: u8 = 0x0C;
/// Unsupported type.
const MARKER_UNSUPPORTED: u8 = 0x0D;
/// XML document (32-bit length prefix).
const MARKER_XML_DOCUMENT: u8 = 0x0F;
/// Typed object: class name followed by properties.
const MARKER_TYPED_OBJECT: u8 = 0x10;
/// AVM+: the value that follows is AMF3-encoded.
const MARKER_AVMPLUS: u8 = 0x11;

/// Maximum nesting depth for objects/arrays (prevents stack overflow on
/// deeply nested input).
const MAX_NESTING_DEPTH: usize = 64;
54
55/// AMF0 decoder with lenient parsing mode
56pub struct Amf0Decoder {
57    /// Reference table for object references
58    references: Vec<AmfValue>,
59    /// Enable lenient parsing for encoder quirks
60    lenient: bool,
61    /// Current nesting depth
62    depth: usize,
63}
64
65impl Amf0Decoder {
66    /// Create a new decoder with default settings
67    pub fn new() -> Self {
68        Self {
69            references: Vec::new(),
70            lenient: true, // Default to lenient for OBS/encoder compatibility
71            depth: 0,
72        }
73    }
74
75    /// Create decoder with explicit lenient mode setting
76    pub fn with_lenient(lenient: bool) -> Self {
77        Self {
78            references: Vec::new(),
79            lenient,
80            depth: 0,
81        }
82    }
83
84    /// Reset decoder state (call between messages)
85    pub fn reset(&mut self) {
86        self.references.clear();
87        self.depth = 0;
88    }
89
90    /// Decode a single AMF0 value from the buffer
91    pub fn decode(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
92        if buf.is_empty() {
93            return Err(AmfError::UnexpectedEof);
94        }
95
96        self.depth += 1;
97        if self.depth > MAX_NESTING_DEPTH {
98            return Err(AmfError::NestingTooDeep);
99        }
100
101        let marker = buf.get_u8();
102        let result = self.decode_value(marker, buf);
103        self.depth -= 1;
104        result
105    }
106
107    /// Decode all values from buffer until exhausted
108    pub fn decode_all(&mut self, buf: &mut Bytes) -> Result<Vec<AmfValue>, AmfError> {
109        let mut values = Vec::new();
110        while buf.has_remaining() {
111            values.push(self.decode(buf)?);
112        }
113        Ok(values)
114    }
115
116    fn decode_value(&mut self, marker: u8, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
117        match marker {
118            MARKER_NUMBER => self.decode_number(buf),
119            MARKER_BOOLEAN => self.decode_boolean(buf),
120            MARKER_STRING => self.decode_string(buf),
121            MARKER_OBJECT => self.decode_object(buf),
122            MARKER_NULL => Ok(AmfValue::Null),
123            MARKER_UNDEFINED => Ok(AmfValue::Undefined),
124            MARKER_REFERENCE => self.decode_reference(buf),
125            MARKER_ECMA_ARRAY => self.decode_ecma_array(buf),
126            MARKER_STRICT_ARRAY => self.decode_strict_array(buf),
127            MARKER_DATE => self.decode_date(buf),
128            MARKER_LONG_STRING => self.decode_long_string(buf),
129            MARKER_UNSUPPORTED => Ok(AmfValue::Undefined),
130            MARKER_XML_DOCUMENT => self.decode_xml(buf),
131            MARKER_TYPED_OBJECT => self.decode_typed_object(buf),
132            MARKER_AVMPLUS => {
133                // AMF3 value embedded in AMF0 stream
134                // For now, skip and return null (full AMF3 support in amf3.rs)
135                Ok(AmfValue::Null)
136            }
137            _ => {
138                if self.lenient {
139                    // Skip unknown marker in lenient mode
140                    Ok(AmfValue::Undefined)
141                } else {
142                    Err(AmfError::UnknownMarker(marker))
143                }
144            }
145        }
146    }
147
148    fn decode_number(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
149        if buf.remaining() < 8 {
150            return Err(AmfError::UnexpectedEof);
151        }
152        Ok(AmfValue::Number(buf.get_f64()))
153    }
154
155    fn decode_boolean(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
156        if buf.is_empty() {
157            return Err(AmfError::UnexpectedEof);
158        }
159        Ok(AmfValue::Boolean(buf.get_u8() != 0))
160    }
161
162    fn decode_string(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
163        let s = self.read_utf8(buf)?;
164        Ok(AmfValue::String(s))
165    }
166
167    fn decode_long_string(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
168        let s = self.read_utf8_long(buf)?;
169        Ok(AmfValue::String(s))
170    }
171
172    fn decode_object(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
173        let mut properties = HashMap::new();
174
175        // Track this object for potential references
176        let obj_index = self.references.len();
177        self.references.push(AmfValue::Null); // Placeholder
178
179        loop {
180            let key = self.read_utf8(buf)?;
181
182            // Check for object end marker
183            if key.is_empty() {
184                if buf.is_empty() {
185                    if self.lenient {
186                        // OBS sometimes omits the object end marker
187                        break;
188                    }
189                    return Err(AmfError::UnexpectedEof);
190                }
191                let end_marker = buf.get_u8();
192                if end_marker == MARKER_OBJECT_END {
193                    break;
194                } else if self.lenient {
195                    // Some encoders omit the end marker, treat as end
196                    // Put the byte back conceptually by continuing
197                    break;
198                } else {
199                    return Err(AmfError::InvalidObjectEnd);
200                }
201            }
202
203            let value = self.decode(buf)?;
204            properties.insert(key, value);
205        }
206
207        let obj = AmfValue::Object(properties);
208        self.references[obj_index] = obj.clone();
209        Ok(obj)
210    }
211
212    fn decode_ecma_array(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
213        if buf.remaining() < 4 {
214            return Err(AmfError::UnexpectedEof);
215        }
216
217        // ECMA array count hint (not always accurate)
218        let _count = buf.get_u32();
219
220        // Track for references
221        let arr_index = self.references.len();
222        self.references.push(AmfValue::Null);
223
224        let mut properties = HashMap::new();
225
226        loop {
227            let key = self.read_utf8(buf)?;
228
229            if key.is_empty() {
230                if buf.is_empty() {
231                    if self.lenient {
232                        break;
233                    }
234                    return Err(AmfError::UnexpectedEof);
235                }
236                let end_marker = buf.get_u8();
237                if end_marker == MARKER_OBJECT_END {
238                    break;
239                } else if self.lenient {
240                    break;
241                } else {
242                    return Err(AmfError::InvalidObjectEnd);
243                }
244            }
245
246            let value = self.decode(buf)?;
247            properties.insert(key, value);
248        }
249
250        let arr = AmfValue::EcmaArray(properties);
251        self.references[arr_index] = arr.clone();
252        Ok(arr)
253    }
254
255    fn decode_strict_array(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
256        if buf.remaining() < 4 {
257            return Err(AmfError::UnexpectedEof);
258        }
259
260        let count = buf.get_u32() as usize;
261
262        // Track for references
263        let arr_index = self.references.len();
264        self.references.push(AmfValue::Null);
265
266        let mut elements = Vec::with_capacity(count.min(1024)); // Cap initial allocation
267        for _ in 0..count {
268            elements.push(self.decode(buf)?);
269        }
270
271        let arr = AmfValue::Array(elements);
272        self.references[arr_index] = arr.clone();
273        Ok(arr)
274    }
275
276    fn decode_date(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
277        if buf.remaining() < 10 {
278            return Err(AmfError::UnexpectedEof);
279        }
280
281        let timestamp = buf.get_f64();
282        let _timezone = buf.get_i16(); // Timezone offset (deprecated, usually 0)
283
284        Ok(AmfValue::Date(timestamp))
285    }
286
287    fn decode_reference(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
288        if buf.remaining() < 2 {
289            return Err(AmfError::UnexpectedEof);
290        }
291
292        let index = buf.get_u16() as usize;
293        if index >= self.references.len() {
294            return Err(AmfError::InvalidReference(index as u16));
295        }
296
297        Ok(self.references[index].clone())
298    }
299
300    fn decode_xml(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
301        let s = self.read_utf8_long(buf)?;
302        Ok(AmfValue::Xml(s))
303    }
304
305    fn decode_typed_object(&mut self, buf: &mut Bytes) -> Result<AmfValue, AmfError> {
306        let class_name = self.read_utf8(buf)?;
307
308        // Track for references
309        let obj_index = self.references.len();
310        self.references.push(AmfValue::Null);
311
312        let mut properties = HashMap::new();
313
314        loop {
315            let key = self.read_utf8(buf)?;
316
317            if key.is_empty() {
318                if buf.is_empty() {
319                    if self.lenient {
320                        break;
321                    }
322                    return Err(AmfError::UnexpectedEof);
323                }
324                let end_marker = buf.get_u8();
325                if end_marker == MARKER_OBJECT_END {
326                    break;
327                } else if self.lenient {
328                    break;
329                } else {
330                    return Err(AmfError::InvalidObjectEnd);
331                }
332            }
333
334            let value = self.decode(buf)?;
335            properties.insert(key, value);
336        }
337
338        let obj = AmfValue::TypedObject {
339            class_name,
340            properties,
341        };
342        self.references[obj_index] = obj.clone();
343        Ok(obj)
344    }
345
346    /// Read UTF-8 string with 16-bit length prefix
347    fn read_utf8(&mut self, buf: &mut Bytes) -> Result<String, AmfError> {
348        if buf.remaining() < 2 {
349            return Err(AmfError::UnexpectedEof);
350        }
351
352        let len = buf.get_u16() as usize;
353        if buf.remaining() < len {
354            return Err(AmfError::UnexpectedEof);
355        }
356
357        let bytes = buf.copy_to_bytes(len);
358        String::from_utf8(bytes.to_vec()).map_err(|_| AmfError::InvalidUtf8)
359    }
360
361    /// Read UTF-8 string with 32-bit length prefix
362    fn read_utf8_long(&mut self, buf: &mut Bytes) -> Result<String, AmfError> {
363        if buf.remaining() < 4 {
364            return Err(AmfError::UnexpectedEof);
365        }
366
367        let len = buf.get_u32() as usize;
368        if buf.remaining() < len {
369            return Err(AmfError::UnexpectedEof);
370        }
371
372        let bytes = buf.copy_to_bytes(len);
373        String::from_utf8(bytes.to_vec()).map_err(|_| AmfError::InvalidUtf8)
374    }
375}
376
377impl Default for Amf0Decoder {
378    fn default() -> Self {
379        Self::new()
380    }
381}
382
383/// AMF0 encoder
384pub struct Amf0Encoder {
385    buf: BytesMut,
386}
387
388impl Amf0Encoder {
389    /// Create a new encoder
390    pub fn new() -> Self {
391        Self {
392            buf: BytesMut::with_capacity(256),
393        }
394    }
395
396    /// Create encoder with specific capacity
397    pub fn with_capacity(capacity: usize) -> Self {
398        Self {
399            buf: BytesMut::with_capacity(capacity),
400        }
401    }
402
403    /// Get the encoded bytes and reset encoder
404    pub fn finish(&mut self) -> Bytes {
405        self.buf.split().freeze()
406    }
407
408    /// Get current encoded length
409    pub fn len(&self) -> usize {
410        self.buf.len()
411    }
412
413    /// Check if encoder is empty
414    pub fn is_empty(&self) -> bool {
415        self.buf.is_empty()
416    }
417
418    /// Encode a single AMF0 value
419    pub fn encode(&mut self, value: &AmfValue) {
420        match value {
421            AmfValue::Null => {
422                self.buf.put_u8(MARKER_NULL);
423            }
424            AmfValue::Undefined => {
425                self.buf.put_u8(MARKER_UNDEFINED);
426            }
427            AmfValue::Boolean(b) => {
428                self.buf.put_u8(MARKER_BOOLEAN);
429                self.buf.put_u8(if *b { 1 } else { 0 });
430            }
431            AmfValue::Number(n) => {
432                self.buf.put_u8(MARKER_NUMBER);
433                self.buf.put_f64(*n);
434            }
435            AmfValue::Integer(i) => {
436                // AMF0 doesn't have integer type, encode as number
437                self.buf.put_u8(MARKER_NUMBER);
438                self.buf.put_f64(*i as f64);
439            }
440            AmfValue::String(s) => {
441                if s.len() > 0xFFFF {
442                    // Long string
443                    self.buf.put_u8(MARKER_LONG_STRING);
444                    self.buf.put_u32(s.len() as u32);
445                } else {
446                    self.buf.put_u8(MARKER_STRING);
447                    self.buf.put_u16(s.len() as u16);
448                }
449                self.buf.put_slice(s.as_bytes());
450            }
451            AmfValue::Object(props) => {
452                self.buf.put_u8(MARKER_OBJECT);
453                for (key, val) in props {
454                    self.write_utf8(key);
455                    self.encode(val);
456                }
457                // Object end marker
458                self.buf.put_u16(0); // Empty key
459                self.buf.put_u8(MARKER_OBJECT_END);
460            }
461            AmfValue::EcmaArray(props) => {
462                self.buf.put_u8(MARKER_ECMA_ARRAY);
463                self.buf.put_u32(props.len() as u32);
464                for (key, val) in props {
465                    self.write_utf8(key);
466                    self.encode(val);
467                }
468                self.buf.put_u16(0);
469                self.buf.put_u8(MARKER_OBJECT_END);
470            }
471            AmfValue::Array(elements) => {
472                self.buf.put_u8(MARKER_STRICT_ARRAY);
473                self.buf.put_u32(elements.len() as u32);
474                for elem in elements {
475                    self.encode(elem);
476                }
477            }
478            AmfValue::Date(timestamp) => {
479                self.buf.put_u8(MARKER_DATE);
480                self.buf.put_f64(*timestamp);
481                self.buf.put_i16(0); // Timezone (deprecated)
482            }
483            AmfValue::Xml(s) => {
484                self.buf.put_u8(MARKER_XML_DOCUMENT);
485                self.buf.put_u32(s.len() as u32);
486                self.buf.put_slice(s.as_bytes());
487            }
488            AmfValue::TypedObject {
489                class_name,
490                properties,
491            } => {
492                self.buf.put_u8(MARKER_TYPED_OBJECT);
493                self.write_utf8(class_name);
494                for (key, val) in properties {
495                    self.write_utf8(key);
496                    self.encode(val);
497                }
498                self.buf.put_u16(0);
499                self.buf.put_u8(MARKER_OBJECT_END);
500            }
501            AmfValue::ByteArray(_) => {
502                // ByteArray is AMF3-only, encode as null in AMF0
503                self.buf.put_u8(MARKER_NULL);
504            }
505        }
506    }
507
508    /// Encode multiple values
509    pub fn encode_all(&mut self, values: &[AmfValue]) {
510        for value in values {
511            self.encode(value);
512        }
513    }
514
515    /// Write UTF-8 string with 16-bit length prefix (no type marker)
516    fn write_utf8(&mut self, s: &str) {
517        let len = s.len().min(0xFFFF);
518        self.buf.put_u16(len as u16);
519        self.buf.put_slice(&s.as_bytes()[..len]);
520    }
521}
522
523impl Default for Amf0Encoder {
524    fn default() -> Self {
525        Self::new()
526    }
527}
528
529/// Convenience function to encode a single value
530pub fn encode(value: &AmfValue) -> Bytes {
531    let mut encoder = Amf0Encoder::new();
532    encoder.encode(value);
533    encoder.finish()
534}
535
536/// Convenience function to encode multiple values
537pub fn encode_all(values: &[AmfValue]) -> Bytes {
538    let mut encoder = Amf0Encoder::new();
539    encoder.encode_all(values);
540    encoder.finish()
541}
542
543/// Convenience function to decode a single value
544pub fn decode(data: &[u8]) -> Result<AmfValue, AmfError> {
545    let mut decoder = Amf0Decoder::new();
546    let mut buf = Bytes::copy_from_slice(data);
547    decoder.decode(&mut buf)
548}
549
550/// Convenience function to decode all values
551pub fn decode_all(data: &[u8]) -> Result<Vec<AmfValue>, AmfError> {
552    let mut decoder = Amf0Decoder::new();
553    let mut buf = Bytes::copy_from_slice(data);
554    decoder.decode_all(&mut buf)
555}
556
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `value` and decode the resulting bytes again.
    fn roundtrip(value: &AmfValue) -> AmfValue {
        let wire = encode(value);
        decode(&wire).unwrap()
    }

    #[test]
    fn test_number_roundtrip() {
        let original = AmfValue::Number(42.5);
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_string_roundtrip() {
        let original = AmfValue::String("hello world".into());
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_boolean_roundtrip() {
        let original = AmfValue::Boolean(true);
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_null_roundtrip() {
        let original = AmfValue::Null;
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_object_roundtrip() {
        let mut fields = HashMap::new();
        fields.insert("name".to_string(), AmfValue::String("test".into()));
        fields.insert("value".to_string(), AmfValue::Number(123.0));
        let original = AmfValue::Object(fields);

        let decoded = roundtrip(&original);

        // Compare property by property (iteration order may differ).
        match (&original, &decoded) {
            (AmfValue::Object(expected), AmfValue::Object(actual)) => {
                assert_eq!(expected.len(), actual.len());
                for (key, expected_value) in expected {
                    assert_eq!(actual.get(key), Some(expected_value));
                }
            }
            _ => panic!("Expected objects"),
        }
    }

    #[test]
    fn test_array_roundtrip() {
        let original = AmfValue::Array(vec![
            AmfValue::Number(1.0),
            AmfValue::String("two".into()),
            AmfValue::Boolean(true),
        ]);
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_multiple_values() {
        let originals = vec![
            AmfValue::String("connect".into()),
            AmfValue::Number(1.0),
            AmfValue::Null,
        ];

        let wire = encode_all(&originals);
        assert_eq!(decode_all(&wire).unwrap(), originals);
    }

    #[test]
    fn test_long_string() {
        // 70000 bytes exceeds the 16-bit length prefix, forcing the long form.
        let text = "x".repeat(70000);
        let decoded = roundtrip(&AmfValue::String(text.clone()));
        assert_eq!(decoded, AmfValue::String(text));
    }
}