amf_rs/amf0/
utf8.rs

1use crate::errors::AmfError;
2use crate::traits::{Marshall, MarshallLength, Unmarshall};
3use std::borrow::Borrow;
4use std::fmt::{Debug, Display, Formatter};
5use std::ops::Deref;
6
7#[derive(Debug, Clone, PartialEq, Eq, Hash)]
8pub struct AmfUtf8<const LBW: usize> {
9    inner: String,
10}
11
12impl<const LBW: usize> AmfUtf8<LBW> {
13    pub fn new(inner: String) -> Result<Self, AmfError> {
14        debug_assert!(LBW == 2 || LBW == 4);
15        let len = inner.len();
16        if (LBW == 2 && len > u16::MAX as usize) || (LBW == 4 && len > u32::MAX as usize) {
17            return Err(AmfError::StringTooLong { max: LBW, got: len });
18        }
19        Ok(Self {
20            inner: inner.to_string(),
21        })
22    }
23
24    pub fn new_from_str(inner: &str) -> Result<Self, AmfError> {
25        Self::new(inner.to_string())
26    }
27}
28
29impl<const LBW: usize> Marshall for AmfUtf8<LBW> {
30    fn marshall(&self) -> Result<Vec<u8>, AmfError> {
31        debug_assert!(LBW == 2 || LBW == 4);
32        let mut vec = Vec::with_capacity(self.marshall_length());
33        if LBW == 2 {
34            vec.extend_from_slice((self.inner.len() as u16).to_be_bytes().as_slice())
35        } else if LBW == 4 {
36            vec.extend_from_slice((self.inner.len() as u32).to_be_bytes().as_slice())
37        } else {
38            return Err(AmfError::Custom("Invalid length byte width".to_string()));
39        }
40        vec.extend_from_slice(self.inner.as_bytes());
41        Ok(vec)
42    }
43}
44
45impl<const LBW: usize> MarshallLength for AmfUtf8<LBW> {
46    fn marshall_length(&self) -> usize {
47        debug_assert!(LBW == 2 || LBW == 4);
48        LBW + self.inner.len()
49    }
50}
51
52impl<const LBW: usize> Unmarshall for AmfUtf8<LBW> {
53    fn unmarshall(buf: &[u8]) -> Result<(Self, usize), AmfError> {
54        debug_assert!(LBW == 2 || LBW == 4);
55        let length;
56        if LBW == 2 {
57            if buf.len() < 2 {
58                return Err(AmfError::BufferTooSmall {
59                    want: 2,
60                    got: buf.len(),
61                });
62            }
63            length = u16::from_be_bytes(buf[0..2].try_into().unwrap()) as usize;
64        } else if LBW == 4 {
65            if buf.len() < 4 {
66                return Err(AmfError::BufferTooSmall {
67                    want: 4,
68                    got: buf.len(),
69                });
70            }
71            length = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
72        } else {
73            return Err(AmfError::Custom("Invalid length byte width".to_string()));
74        }
75
76        let start = LBW;
77        let end = start + length;
78        if buf.len() < end {
79            return Err(AmfError::BufferTooSmall {
80                want: end,
81                got: buf.len(),
82            });
83        }
84        let value = std::str::from_utf8(&buf[start..end]).map_err(|e| AmfError::InvalidUtf8(e))?;
85        Ok((
86            Self {
87                inner: value.to_string(),
88            },
89            end,
90        ))
91    }
92}
93
// Idiomatic Rust conversion traits, implemented for caller convenience.
95
96impl<const LBW: usize> TryFrom<&[u8]> for AmfUtf8<LBW> {
97    type Error = AmfError;
98
99    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
100        Self::unmarshall(value).map(|(v, _)| v)
101    }
102}
103
104impl<const LBW: usize> TryFrom<Vec<u8>> for AmfUtf8<LBW> {
105    type Error = AmfError;
106
107    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
108        Self::try_from(value.as_slice())
109    }
110}
111
112impl<const LBW: usize> TryFrom<AmfUtf8<LBW>> for Vec<u8> {
113    type Error = AmfError;
114
115    fn try_from(value: AmfUtf8<LBW>) -> Result<Self, Self::Error> {
116        value.marshall()
117    }
118}
119
120impl<const LBW: usize> TryFrom<String> for AmfUtf8<LBW> {
121    type Error = AmfError;
122
123    fn try_from(value: String) -> Result<Self, Self::Error> {
124        Self::new(value)
125    }
126}
127
128impl<const LBW: usize> TryFrom<AmfUtf8<LBW>> for String {
129    type Error = AmfError;
130
131    fn try_from(value: AmfUtf8<LBW>) -> Result<Self, Self::Error> {
132        Ok(value.inner)
133    }
134}
135
136impl<const LBW: usize> TryFrom<&str> for AmfUtf8<LBW> {
137    type Error = AmfError;
138
139    fn try_from(value: &str) -> Result<Self, Self::Error> {
140        Self::new_from_str(value)
141    }
142}
143
144impl<const LBW: usize> AsRef<str> for AmfUtf8<LBW> {
145    fn as_ref(&self) -> &str {
146        self.inner.as_ref()
147    }
148}
149impl<const LBW: usize> Deref for AmfUtf8<LBW> {
150    type Target = str;
151
152    fn deref(&self) -> &Self::Target {
153        Self::as_ref(self)
154    }
155}
156impl<const LBW: usize> Borrow<str> for AmfUtf8<LBW> {
157    fn borrow(&self) -> &str {
158        Self::as_ref(self)
159    }
160}
161
162impl<const LBW: usize> Display for AmfUtf8<LBW> {
163    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
164        write!(f, "{}", self.inner)
165    }
166}
167
168impl<const LBW: usize> Default for AmfUtf8<LBW> {
169    fn default() -> Self {
170        Self::new_from_str("").unwrap()
171    }
172}
173
// Type aliases

/// [`AmfUtf8`] with a 2-byte length header.
pub type Utf8 = AmfUtf8<2>;
/// [`AmfUtf8`] with a 4-byte length header.
pub type Utf8Long = AmfUtf8<4>;
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{Marshall, MarshallLength, Unmarshall};
    use std::hash::{DefaultHasher, Hash, Hasher};

    /// Hashes `t` with a freshly created `DefaultHasher`.
    fn calculate_hash<T: Hash>(t: &T) -> u64 {
        let mut state = DefaultHasher::new();
        t.hash(&mut state);
        state.finish()
    }

    // Construction succeeds at the maximum payload size for LBW = 2.
    #[test]
    fn new_valid_utf8_w2() {
        let payload = "a".repeat(usize::from(u16::MAX));
        let built = Utf8::new_from_str(&payload).unwrap();
        assert_eq!(built.inner, payload);
    }

    // Construction fails one byte past the maximum for LBW = 2.
    #[test]
    fn new_too_long_utf8_w2() {
        let payload = "a".repeat(usize::from(u16::MAX) + 1);
        let outcome = Utf8::new_from_str(&payload);
        assert!(matches!(
            outcome,
            Err(AmfError::StringTooLong { max: 2, .. })
        ));
    }

    // Construction succeeds for LBW = 4 (payload well inside the u32 range).
    #[test]
    fn new_valid_utf8_w4() {
        let payload = "a".repeat(1000);
        let built = Utf8Long::new_from_str(&payload).unwrap();
        assert_eq!(built.inner, payload);
    }

    // Serialization with a 2-byte header.
    #[test]
    fn try_into_bytes_w2() {
        let encoded = Utf8::new_from_str("hello").unwrap().marshall().unwrap();
        assert_eq!(encoded, b"\x00\x05hello".to_vec());
    }

    // Serialization with a 4-byte header.
    #[test]
    fn try_into_bytes_w4() {
        let encoded = Utf8Long::new_from_str("world").unwrap().marshall().unwrap();
        assert_eq!(encoded, b"\x00\x00\x00\x05world".to_vec());
    }

    // Deserialization with a 2-byte header.
    #[test]
    fn try_from_bytes_w2() {
        let wire = b"\x00\x05hello";
        let (decoded, used) = Utf8::unmarshall(wire).unwrap();
        assert_eq!(decoded.inner, "hello");
        assert_eq!(used, 7);
    }

    // Deserialization with a 4-byte header.
    #[test]
    fn try_from_bytes_w4() {
        let wire = b"\x00\x00\x00\x05world";
        let (decoded, used) = Utf8Long::unmarshall(wire).unwrap();
        assert_eq!(decoded.inner, "world");
        assert_eq!(used, 9);
    }

    // Serialized length is the header width plus the payload bytes.
    #[test]
    fn length_calculation() {
        let short = Utf8::new_from_str("abc").unwrap();
        assert_eq!(short.marshall_length(), 2 + 3); // 2-byte header + 3 payload bytes

        let long = Utf8Long::new_from_str("abcde").unwrap();
        assert_eq!(long.marshall_length(), 4 + 5); // 4-byte header + 5 payload bytes
    }

    // `TryFrom<&[u8]>` conversion.
    #[test]
    fn try_from_slice() {
        let wire = [0x00, 0x03, b'f', b'o', b'o'];
        let decoded = Utf8::try_from(&wire[..]).unwrap();
        assert_eq!(decoded.inner, "foo");
    }

    // `Deref` and `AsRef` both expose the inner string.
    #[test]
    fn deref_and_as_ref() {
        let wrapped = Utf8::new_from_str("bar").unwrap();
        let via_deref: &str = &wrapped;
        let via_as_ref: &str = wrapped.as_ref();
        assert_eq!(via_deref, "bar");
        assert_eq!(via_as_ref, "bar");
    }

    // `Display` prints the inner string.
    #[test]
    fn display_format() {
        let shown = Utf8::new_from_str("test").unwrap();
        assert_eq!(shown.to_string(), "test");
    }

    #[test]
    fn clone_preserves_equality() {
        let source = Utf8::new_from_str("hello").unwrap();
        let copy = source.clone();
        // A clone must compare equal to the value it was made from.
        assert_eq!(source, copy);
    }

    #[test]
    fn eq_and_neq_behaviour() {
        let lower = Utf8Long::new_from_str("rust").unwrap();
        let lower_again = Utf8Long::new_from_str("rust").unwrap();
        let capitalized = Utf8Long::new_from_str("Rust").unwrap();

        // Identical content compares equal.
        assert_eq!(lower, lower_again);
        // Comparison is case-sensitive.
        assert_ne!(lower, capitalized);
    }

    #[test]
    fn equal_values_have_same_hash() {
        let left = Utf8::new_from_str("hash_me").unwrap();
        let right = Utf8::new_from_str("hash_me").unwrap();

        assert_eq!(
            calculate_hash(&left),
            calculate_hash(&right),
            "Equal values should produce the same hash"
        );
    }

    #[test]
    fn different_values_have_different_hash() {
        let left = Utf8::new_from_str("foo").unwrap();
        let right = Utf8::new_from_str("bar").unwrap();

        assert_ne!(
            calculate_hash(&left),
            calculate_hash(&right),
            "Different values should produce different hashes"
        );
    }

    #[test]
    fn clone_preserves_hash() {
        let source = Utf8Long::new_from_str("clone_hash").unwrap();
        let copy = source.clone();

        assert_eq!(
            calculate_hash(&source),
            calculate_hash(&copy),
            "Cloned instance should have the same hash as original"
        );
    }
}