amf_rs/amf0/string.rs

1use crate::amf0::type_marker::TypeMarker;
2use crate::amf0::utf8::AmfUtf8;
3use crate::errors::AmfError;
4use crate::traits::{Marshall, MarshallLength, Unmarshall};
5use std::borrow::Borrow;
6use std::fmt::{Display, Formatter};
7use std::ops::Deref;
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash)]
10pub struct AmfUtf8ValuedType<const LBW: usize, const TM: u8> {
11    inner: AmfUtf8<LBW>,
12}
13
14impl<const LBW: usize, const TM: u8> AmfUtf8ValuedType<LBW, TM> {
15    pub fn new(inner: AmfUtf8<LBW>) -> Self {
16        Self { inner }
17    }
18
19    pub fn new_from_string(value: String) -> Result<Self, AmfError> {
20        let inner = AmfUtf8::<LBW>::new(value)?;
21        Ok(Self::new(inner))
22    }
23
24    pub fn new_from_str(value: &str) -> Result<Self, AmfError> {
25        Self::new_from_string(value.to_string())
26    }
27}
28
29impl<const LBW: usize, const TM: u8> Marshall for AmfUtf8ValuedType<LBW, TM> {
30    fn marshall(&self) -> Result<Vec<u8>, AmfError> {
31        let mut vec = Vec::with_capacity(self.marshall_length());
32        vec.push(TM);
33        let inner_vec = self.inner.marshall()?;
34        vec.extend_from_slice(inner_vec.as_slice());
35        Ok(vec)
36    }
37}
38
39impl<const LBW: usize, const TM: u8> MarshallLength for AmfUtf8ValuedType<LBW, TM> {
40    fn marshall_length(&self) -> usize {
41        1 + self.inner.marshall_length()
42    }
43}
44
45impl<const LBW: usize, const TM: u8> Unmarshall for AmfUtf8ValuedType<LBW, TM> {
46    fn unmarshall(buf: &[u8]) -> Result<(Self, usize), AmfError> {
47        let required_size = 1 + LBW;
48        if buf.len() < required_size {
49            return Err(AmfError::BufferTooSmall {
50                want: required_size,
51                got: buf.len(),
52            });
53        }
54
55        if buf[0] != TM {
56            return Err(AmfError::TypeMarkerValueMismatch {
57                want: TM,
58                got: buf[0],
59            });
60        }
61        let inner = AmfUtf8::unmarshall(&buf[1..])?;
62        Ok((Self::new(inner.0), 1 + inner.1))
63    }
64}
65
// Idiomatic Rust conversion traits, provided for caller convenience
67
68impl<const LBW: usize, const TM: u8> TryFrom<&[u8]> for AmfUtf8ValuedType<LBW, TM> {
69    type Error = AmfError;
70
71    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
72        Self::unmarshall(value).map(|(inner, _)| inner)
73    }
74}
75
76impl<const LBW: usize, const TM: u8> TryFrom<Vec<u8>> for AmfUtf8ValuedType<LBW, TM> {
77    type Error = AmfError;
78
79    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
80        Self::try_from(value.as_slice())
81    }
82}
83
84impl<const LBW: usize, const TM: u8> TryFrom<AmfUtf8ValuedType<LBW, TM>> for Vec<u8> {
85    type Error = AmfError;
86
87    fn try_from(value: AmfUtf8ValuedType<LBW, TM>) -> Result<Self, Self::Error> {
88        value.marshall()
89    }
90}
91
92impl<const LBW: usize, const TM: u8> TryFrom<String> for AmfUtf8ValuedType<LBW, TM> {
93    type Error = AmfError;
94
95    fn try_from(value: String) -> Result<Self, Self::Error> {
96        Self::new_from_string(value)
97    }
98}
99
100impl<const LBW: usize, const TM: u8> TryFrom<AmfUtf8ValuedType<LBW, TM>> for String {
101    type Error = AmfError;
102
103    fn try_from(value: AmfUtf8ValuedType<LBW, TM>) -> Result<Self, Self::Error> {
104        value.inner.try_into()
105    }
106}
107
108impl<const LBW: usize, const TM: u8> TryFrom<&str> for AmfUtf8ValuedType<LBW, TM> {
109    type Error = AmfError;
110
111    fn try_from(value: &str) -> Result<Self, Self::Error> {
112        Self::new_from_str(value)
113    }
114}
115
116impl<const LBW: usize, const TM: u8> From<AmfUtf8<LBW>> for AmfUtf8ValuedType<LBW, TM> {
117    fn from(value: AmfUtf8<LBW>) -> Self {
118        Self::new(value)
119    }
120}
121
122impl<const LBW: usize, const TM: u8> AsRef<AmfUtf8<LBW>> for AmfUtf8ValuedType<LBW, TM> {
123    fn as_ref(&self) -> &AmfUtf8<LBW> {
124        &self.inner
125    }
126}
127
128impl<const LBW: usize, const TM: u8> Deref for AmfUtf8ValuedType<LBW, TM> {
129    type Target = AmfUtf8<LBW>;
130
131    fn deref(&self) -> &Self::Target {
132        self.as_ref()
133    }
134}
135
136impl<const LBW: usize, const TM: u8> Borrow<AmfUtf8<LBW>> for AmfUtf8ValuedType<LBW, TM> {
137    fn borrow(&self) -> &AmfUtf8<LBW> {
138        self.as_ref()
139    }
140}
141
142impl<const LBW: usize, const TM: u8> Display for AmfUtf8ValuedType<LBW, TM> {
143    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
144        write!(f, "\"{}\"", self.inner)
145    }
146}
147
148impl<const LBW: usize, const TM: u8> Default for AmfUtf8ValuedType<LBW, TM> {
149    fn default() -> Self {
150        Self::new(AmfUtf8::<LBW>::default())
151    }
152}
153
// Type aliases
155
156//	All strings in AMF are encoded using UTF-8; however, the byte-length header format
157//	may vary. The AMF 0 String type uses the standard byte-length header (i.e. U16). For
158//	long Strings that require more than 65535 bytes to encode in UTF-8, the AMF 0 Long
159//	String type should be used.
160pub type StringType = AmfUtf8ValuedType<2, { TypeMarker::String as u8 }>;
161
162//	A long string is used in AMF 0 to encode strings that would occupy more than 65535
163//	bytes when UTF-8 encoded. The byte-length header of the UTF-8 encoded string is a 32-
164//	bit integer instead of the regular 16-bit integer.
165pub type LongStringType = AmfUtf8ValuedType<4, { TypeMarker::LongString as u8 }>;
166
#[cfg(test)]
mod tests {
    //! Unit tests for `AmfUtf8ValuedType` and its `StringType` /
    //! `LongStringType` aliases.

    use super::*;
    use crate::amf0::utf8::AmfUtf8;
    use std::hash::{DefaultHasher, Hash, Hasher};

    // Generic AmfUtf8ValuedType behaviour
    #[test]
    fn test_new() {
        let utf8 = AmfUtf8::<2>::new_from_str("test").unwrap();
        let valued = AmfUtf8ValuedType::<2, 0x02>::new(utf8.clone());
        assert_eq!(valued.inner, utf8);
    }

    #[test]
    fn test_default() {
        let valued = AmfUtf8ValuedType::<2, 0x02>::default();
        assert_eq!(valued.inner, AmfUtf8::<2>::default());
    }

    #[test]
    fn test_try_from() {
        let utf8 = AmfUtf8::<2>::new_from_str("test").unwrap();
        let valued: AmfUtf8ValuedType<2, 0x02> = utf8.clone().try_into().unwrap();
        assert_eq!(valued.inner, utf8);
    }

    #[test]
    fn test_as_ref() {
        let utf8 = AmfUtf8::<2>::new_from_str("test").unwrap();
        let valued = AmfUtf8ValuedType::<2, 0x02>::new(utf8.clone());
        assert_eq!(valued.as_ref(), &utf8);
    }

    #[test]
    fn test_deref() {
        let utf8 = AmfUtf8::<2>::new_from_str("test").unwrap();
        let valued = AmfUtf8ValuedType::<2, 0x02>::new(utf8.clone());
        assert_eq!(&*valued, &utf8);
    }

    #[test]
    fn test_display() {
        let valued = AmfUtf8ValuedType::<2, 0x02>::new(AmfUtf8::<2>::new_from_str("test").unwrap());
        assert_eq!(format!("{}", valued), "\"test\"");
    }

    // StringType (2-byte length prefix)
    #[test]
    fn test_string_type_marshall() {
        let s = StringType::new(AmfUtf8::<2>::new_from_str("hello").unwrap());
        let data = s.marshall().unwrap();
        assert_eq!(data[0], TypeMarker::String as u8);
        assert_eq!(&data[1..], [0x00, 0x05, b'h', b'e', b'l', b'l', b'o']);
    }

    #[test]
    fn test_string_type_marshall_length() {
        let s = StringType::new(AmfUtf8::<2>::new_from_str("hello").unwrap());
        assert_eq!(s.marshall_length(), 8); // 1 marker + 2 length + 5 chars
    }

    #[test]
    fn test_string_type_unmarshall() {
        let data = [
            TypeMarker::String as u8,
            0x00,
            0x05, // length 5
            b'h',
            b'e',
            b'l',
            b'l',
            b'o',
        ];
        let (s, bytes_read) = StringType::unmarshall(&data).unwrap();
        assert_eq!(bytes_read, 8);
        assert_eq!(s.as_ref().as_ref(), "hello");
    }

    #[test]
    fn test_string_type_unmarshall_invalid_marker() {
        let data = [
            TypeMarker::Number as u8, // wrong marker
            0x00,
            0x05,
            b'h',
            b'e',
            b'l',
            b'l',
            b'o',
        ];
        let result = StringType::unmarshall(&data);
        assert!(matches!(
            result,
            Err(AmfError::TypeMarkerValueMismatch {
                want: 0x02,
                got: 0x00
            })
        ));
    }

    #[test]
    fn test_string_type_unmarshall_buffer_too_small() {
        let data = [TypeMarker::String as u8, 0x00]; // incomplete
        let result = StringType::unmarshall(&data);
        assert!(matches!(
            result,
            Err(AmfError::BufferTooSmall {
                want: 3, // marker + 2-byte length
                got: 2
            })
        ));
    }

    // LongStringType (4-byte length prefix)
    #[test]
    fn test_long_string_type_marshall() {
        let s = LongStringType::new(AmfUtf8::<4>::new_from_str("hello").unwrap());
        let data = s.marshall().unwrap();
        assert_eq!(data[0], TypeMarker::LongString as u8);
        assert_eq!(
            &data[1..],
            [0x00, 0x00, 0x00, 0x05, b'h', b'e', b'l', b'l', b'o']
        );
    }

    #[test]
    fn test_long_string_type_marshall_length() {
        let s = LongStringType::new(AmfUtf8::<4>::new_from_str("hello").unwrap());
        assert_eq!(s.marshall_length(), 10); // 1 marker + 4 length + 5 chars
    }

    #[test]
    fn test_long_string_type_unmarshall() {
        let data = [
            TypeMarker::LongString as u8,
            0x00,
            0x00,
            0x00,
            0x05, // length 5
            b'h',
            b'e',
            b'l',
            b'l',
            b'o',
        ];
        let (s, bytes_read) = LongStringType::unmarshall(&data).unwrap();
        assert_eq!(bytes_read, 10);
        assert_eq!(s.as_ref().as_ref(), "hello");
    }

    #[test]
    fn test_long_string_type_unmarshall_large_string() {
        let long_str = "a".repeat(70_000);
        let mut data = vec![TypeMarker::LongString as u8];
        let len_bytes = (long_str.len() as u32).to_be_bytes();
        data.extend_from_slice(&len_bytes);
        data.extend_from_slice(long_str.as_bytes());

        let (s, bytes_read) = LongStringType::unmarshall(&data).unwrap();
        assert_eq!(bytes_read, 1 + 4 + long_str.len());
        assert_eq!(s.as_ref().as_ref(), long_str);
    }

    #[test]
    fn test_long_string_type_unmarshall_invalid_marker() {
        let data = [
            TypeMarker::String as u8, // wrong marker
            0x00,
            0x00,
            0x00,
            0x05,
            b'h',
            b'e',
            b'l',
            b'l',
            b'o',
        ];
        let result = LongStringType::unmarshall(&data);
        assert!(matches!(
            result,
            Err(AmfError::TypeMarkerValueMismatch {
                want: 0x0C,
                got: 0x02
            })
        ));
    }

    #[test]
    fn test_long_string_type_unmarshall_buffer_too_small() {
        let data = [TypeMarker::LongString as u8, 0x00, 0x00, 0x00]; // incomplete
        let result = LongStringType::unmarshall(&data);
        assert!(matches!(
            result,
            Err(AmfError::BufferTooSmall {
                want: 5, // marker + 4-byte length
                got: 4
            })
        ));
    }

    // Type aliases
    #[test]
    fn test_string_type_alias() {
        let s: StringType = AmfUtf8::<2>::new_from_str("test")
            .unwrap()
            .try_into()
            .unwrap();
        assert_eq!(s.as_ref().as_ref(), "test");
    }

    #[test]
    fn test_long_string_type_alias() {
        let s: LongStringType = AmfUtf8::<4>::new_from_str("test")
            .unwrap()
            .try_into()
            .unwrap();
        assert_eq!(s.as_ref().as_ref(), "test");
    }

    /// Helper to compute the hash of any `T: Hash`
    fn hash_of<T: Hash>(t: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        t.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn string_type_clone_and_eq() {
        // create an original value
        let inner = AmfUtf8::<2>::new_from_str("hello").unwrap();
        let orig: StringType = StringType::new(inner);

        // Clone should produce an equal value
        let cloned = orig.clone();
        assert_eq!(orig, cloned, "Clone must preserve value (PartialEq/ Eq)");

        // Hash of orig and cloned should be the same
        let h1 = hash_of(&orig);
        let h2 = hash_of(&cloned);
        assert_eq!(h1, h2, "Hash must be consistent for equal values");
    }

    #[test]
    fn string_type_hash_differs_on_content_change() {
        let a = StringType::new(AmfUtf8::<2>::new_from_str("foo").unwrap());
        let b = StringType::new(AmfUtf8::<2>::new_from_str("bar").unwrap());
        // different strings must produce different hashes (very likely)
        assert_ne!(
            hash_of(&a),
            hash_of(&b),
            "Different values should hash differently"
        );
    }

    #[test]
    fn long_string_type_clone_and_eq() {
        let inner = AmfUtf8::<4>::new_from_str("a very long string").unwrap();
        let orig: LongStringType = LongStringType::new(inner);

        // Clone ↔ Eq
        let cloned = orig.clone();
        assert_eq!(orig, cloned);

        // Hash consistency
        assert_eq!(hash_of(&orig), hash_of(&cloned));
    }

    #[test]
    fn long_string_type_hash_differs_on_content_change() {
        let a = LongStringType::new(AmfUtf8::<4>::new_from_str("one").unwrap());
        let b = LongStringType::new(AmfUtf8::<4>::new_from_str("two").unwrap());
        assert_ne!(hash_of(&a), hash_of(&b));
    }

    #[test]
    fn test_string_type_clone_partial_eq() {
        let s1: StringType = StringType::default();
        let s2 = s1.clone();
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_long_string_type_clone_partial_eq() {
        let ls1: LongStringType = LongStringType::default();
        let ls2 = ls1.clone();
        assert_eq!(ls1, ls2);
    }

    // Round trips, boundaries and conversion traits

    #[test]
    fn test_string_type_round_trip() {
        let original = StringType::new_from_str("round trip").unwrap();
        let data = original.marshall().unwrap();
        assert_eq!(data.len(), original.marshall_length());

        let (decoded, bytes_read) = StringType::unmarshall(&data).unwrap();
        assert_eq!(bytes_read, data.len());
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_long_string_type_round_trip() {
        let original = LongStringType::new_from_str("round trip").unwrap();
        let data = original.marshall().unwrap();
        assert_eq!(data.len(), original.marshall_length());

        let (decoded, bytes_read) = LongStringType::unmarshall(&data).unwrap();
        assert_eq!(bytes_read, data.len());
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_string_type_empty_round_trip() {
        let empty = StringType::new_from_str("").unwrap();
        let data = empty.marshall().unwrap();
        assert_eq!(data, [TypeMarker::String as u8, 0x00, 0x00]);

        let (decoded, bytes_read) = StringType::unmarshall(&data).unwrap();
        assert_eq!(bytes_read, 3);
        assert_eq!(decoded, empty);
    }

    #[test]
    fn test_unmarshall_ignores_trailing_bytes() {
        let data = [TypeMarker::String as u8, 0x00, 0x02, b'h', b'i', 0xFF, 0xFF];
        let (s, bytes_read) = StringType::unmarshall(&data).unwrap();
        assert_eq!(bytes_read, 5); // trailing 0xFF bytes are not consumed
        assert_eq!(s.as_ref().as_ref(), "hi");
    }

    #[test]
    fn test_unmarshall_truncated_payload_is_error() {
        // Header is complete but announces 5 payload bytes while only 1 is present.
        let data = [TypeMarker::String as u8, 0x00, 0x05, b'h'];
        assert!(StringType::unmarshall(&data).is_err());
    }

    #[test]
    fn test_try_from_bytes_matches_unmarshall() {
        let data = vec![TypeMarker::String as u8, 0x00, 0x02, b'h', b'i'];
        let (expected, _) = StringType::unmarshall(&data).unwrap();
        assert_eq!(StringType::try_from(data.as_slice()).unwrap(), expected);
        assert_eq!(StringType::try_from(data).unwrap(), expected);
    }

    #[test]
    fn test_vec_try_from_matches_marshall() {
        let s = StringType::new_from_str("abc").unwrap();
        let expected = s.marshall().unwrap();
        let data: Vec<u8> = Vec::try_from(s).unwrap();
        assert_eq!(data, expected);
    }

    #[test]
    fn test_string_conversions() {
        let via_str = StringType::try_from("abc").unwrap();
        let via_string = StringType::try_from(String::from("abc")).unwrap();
        assert_eq!(via_str, via_string);

        let back = String::try_from(via_str).unwrap();
        assert_eq!(back, "abc");
    }
}