1use crate::errors::AmfError;
2use crate::traits::{Marshall, MarshallLength, Unmarshall};
3use std::borrow::Borrow;
4use std::fmt::{Debug, Display, Formatter};
5use std::ops::Deref;
6
/// A UTF-8 string that is serialized behind a big-endian length prefix.
///
/// `LBW` ("length byte width") is the size of that prefix in bytes. The
/// implementations in this module only support `LBW == 2` (`u16` prefix) and
/// `LBW == 4` (`u32` prefix); any other width trips a `debug_assert!` and makes
/// (un)marshalling return an error.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct AmfUtf8<const LBW: usize> {
    /// String contents, stored without the length prefix.
    inner: String,
}
11
12impl<const LBW: usize> AmfUtf8<LBW> {
13 pub fn new(inner: String) -> Result<Self, AmfError> {
14 debug_assert!(LBW == 2 || LBW == 4);
15 let len = inner.len();
16 if (LBW == 2 && len > u16::MAX as usize) || (LBW == 4 && len > u32::MAX as usize) {
17 return Err(AmfError::StringTooLong { max: LBW, got: len });
18 }
19 Ok(Self {
20 inner: inner.to_string(),
21 })
22 }
23
24 pub fn new_from_str(inner: &str) -> Result<Self, AmfError> {
25 Self::new(inner.to_string())
26 }
27}
28
29impl<const LBW: usize> Marshall for AmfUtf8<LBW> {
30 fn marshall(&self) -> Result<Vec<u8>, AmfError> {
31 debug_assert!(LBW == 2 || LBW == 4);
32 let mut vec = Vec::with_capacity(self.marshall_length());
33 if LBW == 2 {
34 vec.extend_from_slice((self.inner.len() as u16).to_be_bytes().as_slice())
35 } else if LBW == 4 {
36 vec.extend_from_slice((self.inner.len() as u32).to_be_bytes().as_slice())
37 } else {
38 return Err(AmfError::Custom("Invalid length byte width".to_string()));
39 }
40 vec.extend_from_slice(self.inner.as_bytes());
41 Ok(vec)
42 }
43}
44
45impl<const LBW: usize> MarshallLength for AmfUtf8<LBW> {
46 fn marshall_length(&self) -> usize {
47 debug_assert!(LBW == 2 || LBW == 4);
48 LBW + self.inner.len()
49 }
50}
51
52impl<const LBW: usize> Unmarshall for AmfUtf8<LBW> {
53 fn unmarshall(buf: &[u8]) -> Result<(Self, usize), AmfError> {
54 debug_assert!(LBW == 2 || LBW == 4);
55 let length;
56 if LBW == 2 {
57 if buf.len() < 2 {
58 return Err(AmfError::BufferTooSmall {
59 want: 2,
60 got: buf.len(),
61 });
62 }
63 length = u16::from_be_bytes(buf[0..2].try_into().unwrap()) as usize;
64 } else if LBW == 4 {
65 if buf.len() < 4 {
66 return Err(AmfError::BufferTooSmall {
67 want: 4,
68 got: buf.len(),
69 });
70 }
71 length = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
72 } else {
73 return Err(AmfError::Custom("Invalid length byte width".to_string()));
74 }
75
76 let start = LBW;
77 let end = start + length;
78 if buf.len() < end {
79 return Err(AmfError::BufferTooSmall {
80 want: end,
81 got: buf.len(),
82 });
83 }
84 let value = std::str::from_utf8(&buf[start..end]).map_err(|e| AmfError::InvalidUtf8(e))?;
85 Ok((
86 Self {
87 inner: value.to_string(),
88 },
89 end,
90 ))
91 }
92}
93
94impl<const LBW: usize> TryFrom<&[u8]> for AmfUtf8<LBW> {
97 type Error = AmfError;
98
99 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
100 Self::unmarshall(value).map(|(v, _)| v)
101 }
102}
103
104impl<const LBW: usize> TryFrom<Vec<u8>> for AmfUtf8<LBW> {
105 type Error = AmfError;
106
107 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
108 Self::try_from(value.as_slice())
109 }
110}
111
112impl<const LBW: usize> TryFrom<AmfUtf8<LBW>> for Vec<u8> {
113 type Error = AmfError;
114
115 fn try_from(value: AmfUtf8<LBW>) -> Result<Self, Self::Error> {
116 value.marshall()
117 }
118}
119
120impl<const LBW: usize> TryFrom<String> for AmfUtf8<LBW> {
121 type Error = AmfError;
122
123 fn try_from(value: String) -> Result<Self, Self::Error> {
124 Self::new(value)
125 }
126}
127
128impl<const LBW: usize> TryFrom<AmfUtf8<LBW>> for String {
129 type Error = AmfError;
130
131 fn try_from(value: AmfUtf8<LBW>) -> Result<Self, Self::Error> {
132 Ok(value.inner)
133 }
134}
135
136impl<const LBW: usize> TryFrom<&str> for AmfUtf8<LBW> {
137 type Error = AmfError;
138
139 fn try_from(value: &str) -> Result<Self, Self::Error> {
140 Self::new_from_str(value)
141 }
142}
143
144impl<const LBW: usize> AsRef<str> for AmfUtf8<LBW> {
145 fn as_ref(&self) -> &str {
146 self.inner.as_ref()
147 }
148}
149impl<const LBW: usize> Deref for AmfUtf8<LBW> {
150 type Target = str;
151
152 fn deref(&self) -> &Self::Target {
153 Self::as_ref(self)
154 }
155}
156impl<const LBW: usize> Borrow<str> for AmfUtf8<LBW> {
157 fn borrow(&self) -> &str {
158 Self::as_ref(self)
159 }
160}
161
162impl<const LBW: usize> Display for AmfUtf8<LBW> {
163 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
164 write!(f, "{}", self.inner)
165 }
166}
167
168impl<const LBW: usize> Default for AmfUtf8<LBW> {
169 fn default() -> Self {
170 Self::new_from_str("").unwrap()
171 }
172}
173
174pub type Utf8 = AmfUtf8<2>;
177pub type Utf8Long = AmfUtf8<4>;
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{Marshall, MarshallLength, Unmarshall};
    use std::collections::HashSet;
    use std::hash::{DefaultHasher, Hash, Hasher};

    #[test]
    fn new_valid_utf8_w2() {
        let s = "a".repeat(u16::MAX as usize);
        let amf_str = AmfUtf8::<2>::new_from_str(&s).unwrap();
        assert_eq!(amf_str.inner, s);
    }

    #[test]
    fn new_too_long_utf8_w2() {
        let s = "a".repeat(u16::MAX as usize + 1);
        assert!(matches!(
            AmfUtf8::<2>::new_from_str(&s),
            Err(AmfError::StringTooLong { max: 2, got: _ })
        ));
    }

    #[test]
    fn new_valid_utf8_w4() {
        let s = "a".repeat(1000);
        let amf_str = AmfUtf8::<4>::new_from_str(&s).unwrap();
        assert_eq!(amf_str.inner, s);
    }

    #[test]
    fn try_into_bytes_w2() {
        let amf_str = AmfUtf8::<2>::new_from_str("hello").unwrap();
        let bytes = amf_str.marshall().unwrap();
        assert_eq!(bytes, &[0x00, 0x05, b'h', b'e', b'l', b'l', b'o']);
    }

    #[test]
    fn try_into_bytes_w4() {
        let amf_str = AmfUtf8::<4>::new_from_str("world").unwrap();
        let bytes = amf_str.marshall().unwrap();
        assert_eq!(
            bytes,
            &[0x00, 0x00, 0x00, 0x05, b'w', b'o', b'r', b'l', b'd']
        );
    }

    #[test]
    fn try_from_bytes_w2() {
        let data = [0x00, 0x05, b'h', b'e', b'l', b'l', b'o'];
        let (amf_str, consumed) = AmfUtf8::<2>::unmarshall(&data).unwrap();
        assert_eq!(amf_str.inner, "hello");
        assert_eq!(consumed, 7);
    }

    #[test]
    fn try_from_bytes_w4() {
        let data = [0x00, 0x00, 0x00, 0x05, b'w', b'o', b'r', b'l', b'd'];
        let (amf_str, consumed) = AmfUtf8::<4>::unmarshall(&data).unwrap();
        assert_eq!(amf_str.inner, "world");
        assert_eq!(consumed, 9);
    }

    #[test]
    fn unmarshall_ignores_trailing_bytes() {
        let data = [0x00, 0x02, b'h', b'i', 0xFF, 0xFF];
        let (amf_str, consumed) = AmfUtf8::<2>::unmarshall(&data).unwrap();
        assert_eq!(amf_str.inner, "hi");
        assert_eq!(consumed, 4);
    }

    #[test]
    fn unmarshall_rejects_missing_length_prefix() {
        assert!(matches!(
            AmfUtf8::<2>::unmarshall(&[0x00]),
            Err(AmfError::BufferTooSmall { want: 2, got: 1 })
        ));
        assert!(matches!(
            AmfUtf8::<4>::unmarshall(&[0x00, 0x00, 0x00]),
            Err(AmfError::BufferTooSmall { want: 4, got: 3 })
        ));
    }

    #[test]
    fn unmarshall_rejects_truncated_payload() {
        let data = [0x00, 0x05, b'h', b'i'];
        assert!(matches!(
            AmfUtf8::<2>::unmarshall(&data),
            Err(AmfError::BufferTooSmall { want: 7, got: 4 })
        ));
    }

    #[test]
    fn unmarshall_rejects_invalid_utf8() {
        let data = [0x00, 0x02, 0xC3, 0x28];
        assert!(matches!(
            AmfUtf8::<2>::unmarshall(&data),
            Err(AmfError::InvalidUtf8(_))
        ));
    }

    #[test]
    fn marshall_unmarshall_round_trip() {
        for text in ["", "hello", "héllo ✓"] {
            let short = Utf8::new_from_str(text).unwrap();
            let bytes = short.marshall().unwrap();
            assert_eq!(bytes.len(), short.marshall_length());
            let (decoded, consumed) = Utf8::unmarshall(&bytes).unwrap();
            assert_eq!(decoded, short);
            assert_eq!(consumed, bytes.len());

            let long = Utf8Long::new_from_str(text).unwrap();
            let bytes = long.marshall().unwrap();
            assert_eq!(bytes.len(), long.marshall_length());
            let (decoded, consumed) = Utf8Long::unmarshall(&bytes).unwrap();
            assert_eq!(decoded, long);
            assert_eq!(consumed, bytes.len());
        }
    }

    #[test]
    fn length_calculation() {
        let amf_str = AmfUtf8::<2>::new_from_str("abc").unwrap();
        assert_eq!(amf_str.marshall_length(), 2 + 3);
        let amf_str = AmfUtf8::<4>::new_from_str("abcde").unwrap();
        assert_eq!(amf_str.marshall_length(), 4 + 5);
    }

    #[test]
    fn try_from_slice() {
        let data = [0x00, 0x03, b'f', b'o', b'o'];
        let amf_str: AmfUtf8<2> = data[..].try_into().unwrap();
        assert_eq!(amf_str.inner, "foo");
    }

    #[test]
    fn owned_conversions_round_trip() {
        let amf_str = AmfUtf8::<2>::try_from("abc").unwrap();
        let bytes = Vec::<u8>::try_from(amf_str.clone()).unwrap();
        assert_eq!(bytes, [0x00, 0x03, b'a', b'b', b'c']);
        let decoded = AmfUtf8::<2>::try_from(bytes).unwrap();
        assert_eq!(decoded, amf_str);
        let text = String::try_from(decoded).unwrap();
        assert_eq!(text, "abc");
        let again = AmfUtf8::<2>::try_from(text).unwrap();
        assert_eq!(again, amf_str);
    }

    #[test]
    fn default_is_empty() {
        let amf_str = AmfUtf8::<2>::default();
        assert_eq!(amf_str.inner, "");
        assert_eq!(amf_str.marshall().unwrap(), [0x00, 0x00]);
    }

    #[test]
    fn deref_and_as_ref() {
        let amf_str = AmfUtf8::<2>::new_from_str("bar").unwrap();
        assert_eq!(&*amf_str, "bar");
        assert_eq!(amf_str.as_ref(), "bar");
    }

    #[test]
    fn borrow_allows_str_lookup() {
        let mut set = HashSet::new();
        set.insert(Utf8::new_from_str("key").unwrap());
        assert!(set.contains("key"));
        assert!(!set.contains("other"));
    }

    #[test]
    fn display_format() {
        let amf_str = AmfUtf8::<2>::new_from_str("test").unwrap();
        assert_eq!(format!("{}", amf_str), "test");
    }

    /// Hashes `t` with the standard `DefaultHasher`.
    fn calculate_hash<T: Hash>(t: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        t.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn clone_preserves_equality() {
        let original = AmfUtf8::<2>::new_from_str("hello").unwrap();
        let cloned = original.clone();
        assert_eq!(original, cloned);
    }

    #[test]
    fn eq_and_neq_behaviour() {
        let a = AmfUtf8::<4>::new_from_str("rust").unwrap();
        let b_same = AmfUtf8::<4>::new_from_str("rust").unwrap();
        let c_diff = AmfUtf8::<4>::new_from_str("Rust").unwrap();

        assert_eq!(a, b_same);
        assert_ne!(a, c_diff);
    }

    #[test]
    fn equal_values_have_same_hash() {
        let x = AmfUtf8::<2>::new_from_str("hash_me").unwrap();
        let y = AmfUtf8::<2>::new_from_str("hash_me").unwrap();

        let hx = calculate_hash(&x);
        let hy = calculate_hash(&y);
        assert_eq!(hx, hy, "Equal values should produce the same hash");
    }

    #[test]
    fn different_values_have_different_hash() {
        let x = AmfUtf8::<2>::new_from_str("foo").unwrap();
        let y = AmfUtf8::<2>::new_from_str("bar").unwrap();

        let hx = calculate_hash(&x);
        let hy = calculate_hash(&y);
        assert_ne!(hx, hy, "Different values should produce different hashes");
    }

    #[test]
    fn clone_preserves_hash() {
        let original = AmfUtf8::<4>::new_from_str("clone_hash").unwrap();
        let cloned = original.clone();

        let h1 = calculate_hash(&original);
        let h2 = calculate_hash(&cloned);
        assert_eq!(
            h1, h2,
            "Cloned instance should have the same hash as original"
        );
    }
}