amf_rs/amf0/
number.rs

1use crate::amf0::type_marker::TypeMarker;
2use crate::errors::AmfError;
3use crate::traits::{Marshall, MarshallLength, Unmarshall};
4use std::fmt::{Display, Formatter};
5use std::ops::{Add, Deref};
6
7// An AMF 0 Number type is used to encode an ActionScript Number.
8// The data following a Number type marker is always an 8 byte IEEE-754 double precision floating point value in network byte order (sign bit in low memory).
9#[derive(Debug, Clone, PartialEq)]
10pub struct NumberType {
11    type_marker: TypeMarker,
12    value: f64,
13}
14
15impl NumberType {
16    pub fn new(value: f64) -> Self {
17        Self {
18            type_marker: TypeMarker::Number,
19            value,
20        }
21    }
22}
23
24impl Marshall for NumberType {
25    fn marshall(&self) -> Result<Vec<u8>, AmfError> {
26        debug_assert!(self.type_marker == TypeMarker::Number);
27        let mut buf = [0u8; 9];
28        buf[0] = self.type_marker as u8;
29        buf[1..9].copy_from_slice(&self.value.to_be_bytes());
30        Ok(buf.to_vec())
31    }
32}
33
34impl MarshallLength for NumberType {
35    fn marshall_length(&self) -> usize {
36        1 + 8 // 1 byte for type marker + 8 bytes for value
37    }
38}
39
40impl Unmarshall for NumberType {
41    fn unmarshall(buf: &[u8]) -> Result<(Self, usize), AmfError> {
42        if buf.len() < 9 {
43            return Err(AmfError::BufferTooSmall {
44                want: 9,
45                got: buf.len(),
46            });
47        }
48        let type_marker = TypeMarker::try_from(buf[0])?;
49        if type_marker != TypeMarker::Number {
50            return Err(AmfError::TypeMarkerValueMismatch {
51                want: TypeMarker::Number as u8,
52                got: buf[0],
53            });
54        }
55        let value = f64::from_be_bytes(buf[1..9].try_into().unwrap()); // 前边已经校验了 buf 的长度,这里直接用 .unwrap() 是安全的
56        Ok((Self { type_marker, value }, 9))
57    }
58}
59
// Implement idiomatic Rust conversion traits for user convenience.
61
62impl TryFrom<&[u8]> for NumberType {
63    type Error = AmfError;
64
65    fn try_from(buf: &[u8]) -> Result<Self, Self::Error> {
66        Self::unmarshall(buf).map(|(n, _)| n)
67    }
68}
69
70impl TryFrom<Vec<u8>> for NumberType {
71    type Error = AmfError;
72
73    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
74        Self::try_from(value.as_slice())
75    }
76}
77
78impl TryFrom<NumberType> for Vec<u8> {
79    type Error = AmfError;
80
81    fn try_from(value: NumberType) -> Result<Self, Self::Error> {
82        value.marshall()
83    }
84}
85
86impl From<f64> for NumberType {
87    fn from(value: f64) -> Self {
88        Self::new(value)
89    }
90}
91
92impl From<NumberType> for f64 {
93    fn from(value: NumberType) -> Self {
94        value.value
95    }
96}
97
98impl AsRef<f64> for NumberType {
99    fn as_ref(&self) -> &f64 {
100        &self.value
101    }
102}
103
104impl Deref for NumberType {
105    type Target = f64;
106
107    fn deref(&self) -> &Self::Target {
108        self.as_ref()
109    }
110}
111
112impl Display for NumberType {
113    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
114        write!(f, "{}", self.value)
115    }
116}
117
118impl Default for NumberType {
119    fn default() -> Self {
120        Self::new(0.0)
121    }
122}
123
124impl Add for NumberType {
125    type Output = NumberType;
126
127    fn add(self, rhs: Self) -> Self::Output {
128        Self::new(self.value + rhs.value)
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::amf0::type_marker::TypeMarker;

    // Sample values deliberately avoid literals such as 3.14 or 2.718, which
    // clippy's deny-by-default `approx_constant` lint rejects as approximations
    // of PI / E. 1.25 and 2.5 are also exactly representable in binary.
    const SAMPLE: f64 = 1.25;
    const OTHER: f64 = 2.5;

    /// Non-NaN special values, compared bit-for-bit so that `-0.0` is
    /// distinguished from `0.0`.
    const SPECIAL_VALUES: [f64; 6] = [
        0.0,
        -0.0,
        f64::INFINITY,
        f64::NEG_INFINITY,
        f64::MIN,
        f64::MAX,
    ];

    /// Hand-builds the expected 9-byte AMF0 encoding of `value` (Number marker
    /// followed by the big-endian IEEE-754 bits), independently of `marshall`.
    fn encoded(value: f64) -> [u8; 9] {
        let mut data = [0u8; 9];
        data[0] = TypeMarker::Number as u8;
        data[1..9].copy_from_slice(&value.to_be_bytes());
        data
    }

    #[test]
    fn test_new() {
        let num = NumberType::new(SAMPLE);
        assert_eq!(num.type_marker, TypeMarker::Number);
        assert!((num.value - SAMPLE).abs() < f64::EPSILON);
    }

    #[test]
    fn test_default() {
        let num = NumberType::default();
        assert_eq!(num.type_marker, TypeMarker::Number);
        // Bitwise check: the default must be positive zero, not -0.0.
        assert_eq!(num.value.to_bits(), 0.0f64.to_bits());
    }

    #[test]
    fn test_from_f64() {
        let num: NumberType = SAMPLE.into();
        assert_eq!(num.type_marker, TypeMarker::Number);
        assert!((num.value - SAMPLE).abs() < f64::EPSILON);
    }

    #[test]
    fn test_into_f64() {
        let value: f64 = NumberType::new(SAMPLE).into();
        assert!((value - SAMPLE).abs() < f64::EPSILON);
    }

    #[test]
    fn test_clone_eq() {
        let original = NumberType::new(OTHER);
        let cloned = original.clone();
        assert_eq!(cloned, original);
    }

    #[test]
    fn test_partial_eq() {
        let a = NumberType::new(1.0);
        let b = NumberType::new(1.0);
        let c = NumberType::new(2.0);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_partial_eq_nan() {
        // The derived `PartialEq` follows IEEE-754, so NaN never equals NaN.
        let a = NumberType::new(f64::NAN);
        let b = NumberType::new(f64::NAN);
        assert_ne!(a, b);
    }

    #[test]
    fn test_marshall() {
        let data = NumberType::new(SAMPLE).marshall().unwrap();
        assert_eq!(data, encoded(SAMPLE));
    }

    #[test]
    fn test_marshall_special_values() {
        for value in SPECIAL_VALUES {
            let data = NumberType::new(value).marshall().unwrap();
            assert_eq!(data, encoded(value), "value: {value}");
        }
    }

    #[test]
    fn test_marshall_nan() {
        let data = NumberType::new(f64::NAN).marshall().unwrap();
        assert_eq!(data, encoded(f64::NAN));
    }

    #[test]
    fn test_marshall_length() {
        let num = NumberType::new(SAMPLE);
        assert_eq!(num.marshall_length(), 9);
        assert_eq!(num.marshall_length(), num.marshall().unwrap().len());
    }

    #[test]
    fn test_unmarshall() {
        let (num, bytes_read) = NumberType::unmarshall(&encoded(SAMPLE)).unwrap();
        assert_eq!(bytes_read, 9);
        assert_eq!(num.type_marker, TypeMarker::Number);
        assert!((num.value - SAMPLE).abs() < f64::EPSILON);
    }

    #[test]
    fn test_unmarshall_ignores_trailing_bytes() {
        let mut data = encoded(SAMPLE).to_vec();
        data.extend_from_slice(&[0xAA, 0xBB]);
        let (num, bytes_read) = NumberType::unmarshall(&data).unwrap();
        assert_eq!(bytes_read, 9);
        assert!((num.value - SAMPLE).abs() < f64::EPSILON);
    }

    #[test]
    fn test_unmarshall_special_values() {
        for value in SPECIAL_VALUES {
            let (num, _) = NumberType::unmarshall(&encoded(value)).unwrap();
            assert_eq!(num.value.to_bits(), value.to_bits(), "value: {value}");
        }
    }

    #[test]
    fn test_unmarshall_nan() {
        let (num, _) = NumberType::unmarshall(&encoded(f64::NAN)).unwrap();
        assert!(num.value.is_nan());
    }

    #[test]
    fn test_round_trip() {
        for value in SPECIAL_VALUES.into_iter().chain([SAMPLE, f64::NAN]) {
            let bytes = NumberType::new(value).marshall().unwrap();
            let (decoded, consumed) = NumberType::unmarshall(&bytes).unwrap();
            assert_eq!(consumed, bytes.len());
            assert_eq!(decoded.value.to_bits(), value.to_bits(), "value: {value}");
        }
    }

    #[test]
    fn test_unmarshall_buffer_too_small() {
        let data = [0u8; 8];
        let result = NumberType::unmarshall(&data);
        assert!(matches!(
            result,
            Err(AmfError::BufferTooSmall { want: 9, got: 8 })
        ));
    }

    #[test]
    fn test_unmarshall_empty_buffer() {
        let result = NumberType::unmarshall(&[]);
        assert!(matches!(
            result,
            Err(AmfError::BufferTooSmall { want: 9, got: 0 })
        ));
    }

    #[test]
    fn test_unmarshall_invalid_marker() {
        let mut data = encoded(SAMPLE);
        data[0] = TypeMarker::Null as u8; // wrong type marker
        let result = NumberType::unmarshall(&data);
        assert!(matches!(
            result,
            Err(AmfError::TypeMarkerValueMismatch { want, got })
                if want == TypeMarker::Number as u8 && got == TypeMarker::Null as u8
        ));
    }

    #[test]
    fn test_try_from_slice() {
        let data = encoded(SAMPLE);
        let num = NumberType::try_from(&data[..]).unwrap();
        assert!((num.value - SAMPLE).abs() < f64::EPSILON);
    }

    #[test]
    fn test_try_from_vec() {
        let num = NumberType::try_from(encoded(SAMPLE).to_vec()).unwrap();
        assert!((num.value - SAMPLE).abs() < f64::EPSILON);
    }

    #[test]
    fn test_try_into_vec() {
        let bytes: Vec<u8> = NumberType::new(SAMPLE).try_into().unwrap();
        assert_eq!(bytes, encoded(SAMPLE));
    }

    #[test]
    fn test_deref() {
        let num = NumberType::new(SAMPLE);
        assert!((*num - SAMPLE).abs() < f64::EPSILON);
    }

    #[test]
    fn test_as_ref() {
        let num = NumberType::new(SAMPLE);
        let value_ref: &f64 = num.as_ref();
        assert!((*value_ref - SAMPLE).abs() < f64::EPSILON);
    }

    #[test]
    fn test_add() {
        let sum = NumberType::new(SAMPLE) + NumberType::new(OTHER);
        assert_eq!(sum.type_marker, TypeMarker::Number);
        assert!((sum.value - 3.75).abs() < f64::EPSILON);
    }

    #[test]
    fn test_display() {
        assert_eq!(NumberType::new(SAMPLE).to_string(), "1.25");
        assert_eq!(NumberType::new(-42.0).to_string(), "-42");
        assert_eq!(NumberType::new(f64::INFINITY).to_string(), "inf");
        assert_eq!(NumberType::new(f64::NEG_INFINITY).to_string(), "-inf");
        assert_eq!(NumberType::new(f64::NAN).to_string(), "NaN");
    }
}