1extern crate alloc;
19use alloc::string::String;
20use alloc::vec::Vec;
21
22use crate::buffer::{BufferReader, BufferWriter};
23use crate::encode::{CdrDecode, CdrEncode};
24use crate::error::{DecodeError, EncodeError};
25
26impl CdrEncode for str {
31 fn encode(&self, writer: &mut BufferWriter) -> Result<(), EncodeError> {
32 let bytes = self.as_bytes();
33 let len_with_nul = bytes
35 .len()
36 .checked_add(1)
37 .and_then(|n| u32::try_from(n).ok())
38 .ok_or(EncodeError::ValueOutOfRange {
39 message: "string length exceeds u32::MAX",
40 })?;
41 writer.write_u32(len_with_nul)?;
42 writer.write_bytes(bytes)?;
43 writer.write_u8(0)?;
44 Ok(())
45 }
46}
47
48impl CdrEncode for String {
49 fn encode(&self, writer: &mut BufferWriter) -> Result<(), EncodeError> {
50 self.as_str().encode(writer)
51 }
52}
53
54impl CdrDecode for String {
55 fn decode(reader: &mut BufferReader<'_>) -> Result<Self, DecodeError> {
56 let len_with_nul = reader.read_u32()? as usize;
57 if len_with_nul == 0 {
58 return Err(DecodeError::LengthExceeded {
59 announced: 0,
60 remaining: reader.remaining(),
61 offset: reader.position(),
62 });
63 }
64 if len_with_nul > reader.remaining() {
65 return Err(DecodeError::LengthExceeded {
66 announced: len_with_nul,
67 remaining: reader.remaining(),
68 offset: reader.position(),
69 });
70 }
71 let payload_len = len_with_nul - 1;
73 let offset = reader.position();
74 let bytes = reader.read_bytes(payload_len)?;
75 let s = core::str::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8 { offset })?;
76 let owned = String::from(s);
77 let nul = reader.read_u8()?;
79 if nul != 0 {
80 return Err(DecodeError::InvalidUtf8 { offset });
81 }
82 Ok(owned)
83 }
84}
85
86impl<T: CdrEncode> CdrEncode for Vec<T> {
91 fn encode(&self, writer: &mut BufferWriter) -> Result<(), EncodeError> {
92 let len = u32::try_from(self.len()).map_err(|_| EncodeError::ValueOutOfRange {
93 message: "sequence length exceeds u32::MAX",
94 })?;
95 writer.write_u32(len)?;
96 for item in self {
97 item.encode(writer)?;
98 }
99 Ok(())
100 }
101}
102
103impl<T: CdrDecode> CdrDecode for Vec<T> {
104 fn decode(reader: &mut BufferReader<'_>) -> Result<Self, DecodeError> {
105 let len = reader.read_u32()? as usize;
106 if len > reader.remaining() {
109 return Err(DecodeError::LengthExceeded {
110 announced: len,
111 remaining: reader.remaining(),
112 offset: reader.position(),
113 });
114 }
115 let mut out = Vec::with_capacity(len);
116 for _ in 0..len {
117 out.push(T::decode(reader)?);
118 }
119 Ok(out)
120 }
121}
122
123impl<T: CdrEncode, const N: usize> CdrEncode for [T; N] {
128 fn encode(&self, writer: &mut BufferWriter) -> Result<(), EncodeError> {
129 for item in self {
130 item.encode(writer)?;
131 }
132 Ok(())
133 }
134}
135
136impl<T: CdrDecode + Default + Copy, const N: usize> CdrDecode for [T; N] {
137 fn decode(reader: &mut BufferReader<'_>) -> Result<Self, DecodeError> {
138 let mut out = [T::default(); N];
139 for slot in &mut out {
140 *slot = T::decode(reader)?;
141 }
142 Ok(out)
143 }
144}
145
146impl<T: CdrEncode> CdrEncode for Option<T> {
151 fn encode(&self, writer: &mut BufferWriter) -> Result<(), EncodeError> {
152 match self {
153 None => writer.write_u8(0),
154 Some(value) => {
155 writer.write_u8(1)?;
156 value.encode(writer)
157 }
158 }
159 }
160}
161
162impl<T: CdrDecode> CdrDecode for Option<T> {
163 fn decode(reader: &mut BufferReader<'_>) -> Result<Self, DecodeError> {
164 let offset = reader.position();
165 let flag = reader.read_u8()?;
166 match flag {
167 0 => Ok(None),
168 1 => Ok(Some(T::decode(reader)?)),
169 other => Err(DecodeError::InvalidBool {
172 value: other,
173 offset,
174 }),
175 }
176 }
177}
178
179use alloc::collections::BTreeMap;
188
189impl<K, V> CdrEncode for BTreeMap<K, V>
190where
191 K: CdrEncode + Ord,
192 V: CdrEncode,
193{
194 fn encode(&self, w: &mut BufferWriter) -> Result<(), EncodeError> {
195 let len = u32::try_from(self.len()).map_err(|_| EncodeError::ValueOutOfRange {
196 message: "map: entry-count > u32::MAX",
197 })?;
198 w.write_u32(len)?;
199 for (k, v) in self {
200 k.encode(w)?;
201 v.encode(w)?;
202 }
203 Ok(())
204 }
205}
206
207impl<K, V> CdrDecode for BTreeMap<K, V>
208where
209 K: CdrDecode + Ord,
210 V: CdrDecode,
211{
212 fn decode(r: &mut BufferReader<'_>) -> Result<Self, DecodeError> {
213 let len = r.read_u32()? as usize;
214 let mut map = BTreeMap::new();
215 for _ in 0..len {
216 let k = K::decode(r)?;
217 let v = V::decode(r)?;
218 map.insert(k, v);
219 }
220 Ok(map)
221 }
222}
223
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::panic, clippy::unwrap_used)]
    use super::*;
    use crate::Endianness;
    use alloc::string::ToString;
    use alloc::vec;

    /// Encodes `value` little-endian, decodes it back, and checks that the
    /// round trip is lossless and consumes every byte that was written.
    fn rt_le<T>(value: T)
    where
        T: CdrEncode + CdrDecode + PartialEq + core::fmt::Debug,
    {
        let mut w = BufferWriter::new(Endianness::Little);
        value.encode(&mut w).expect("encode");
        let bytes = w.into_bytes();
        let mut r = BufferReader::new(&bytes, Endianness::Little);
        let decoded = T::decode(&mut r).expect("decode");
        assert_eq!(decoded, value);
        assert_eq!(r.remaining(), 0);
    }

    // ---- strings ----

    #[test]
    fn string_roundtrip_ascii() {
        rt_le(String::from("hello"));
    }

    #[test]
    fn string_roundtrip_unicode() {
        rt_le(String::from("Hällo, 🌍 Welt"));
    }

    #[test]
    fn string_roundtrip_empty() {
        rt_le(String::new());
    }

    #[test]
    fn string_wire_format_includes_null_terminator() {
        let mut w = BufferWriter::new(Endianness::Little);
        "ab".encode(&mut w).unwrap();
        let bytes = w.into_bytes();
        // Length 3 counts the two payload bytes plus the trailing NUL.
        assert_eq!(bytes, vec![3, 0, 0, 0, b'a', b'b', 0]);
    }

    #[test]
    fn string_decode_rejects_zero_length() {
        // A CDR string always carries at least its NUL, so length 0 is invalid.
        let bytes = [0u8, 0, 0, 0];
        let mut r = BufferReader::new(&bytes, Endianness::Little);
        let res = String::decode(&mut r);
        assert!(matches!(res, Err(DecodeError::LengthExceeded { .. })));
    }

    #[test]
    fn string_decode_rejects_announced_overrun() {
        let bytes = [100u8, 0, 0, 0, b'x'];
        let mut r = BufferReader::new(&bytes, Endianness::Little);
        let res = String::decode(&mut r);
        assert!(matches!(res, Err(DecodeError::LengthExceeded { .. })));
    }

    #[test]
    fn string_decode_rejects_missing_null_terminator() {
        // Final byte is 'x' instead of the required NUL.
        let bytes = [3u8, 0, 0, 0, b'a', b'b', b'x'];
        let mut r = BufferReader::new(&bytes, Endianness::Little);
        let res = String::decode(&mut r);
        assert!(matches!(res, Err(DecodeError::InvalidUtf8 { .. })));
    }

    #[test]
    fn string_decode_rejects_invalid_utf8() {
        // 0xFF 0xFE is never valid UTF-8, even though the terminator is present.
        let bytes = [3u8, 0, 0, 0, 0xFF, 0xFE, 0];
        let mut r = BufferReader::new(&bytes, Endianness::Little);
        let res = String::decode(&mut r);
        assert!(matches!(res, Err(DecodeError::InvalidUtf8 { .. })));
    }

    // ---- sequences ----

    #[test]
    fn sequence_u8_roundtrip() {
        rt_le::<Vec<u8>>(vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn sequence_u32_roundtrip() {
        rt_le::<Vec<u32>>(vec![0xDEAD, 0xBEEF, 0x1234]);
    }

    #[test]
    fn sequence_empty_roundtrip() {
        rt_le::<Vec<u32>>(vec![]);
    }

    #[test]
    fn sequence_string_roundtrip() {
        rt_le::<Vec<String>>(vec!["alpha".to_string(), "beta".to_string()]);
    }

    #[test]
    fn sequence_decode_rejects_overrun_length() {
        // Announces 999 elements but only one byte follows.
        let bytes = [0xE7u8, 0x03, 0, 0, b'x'];
        let mut r = BufferReader::new(&bytes, Endianness::Little);
        let res = Vec::<u8>::decode(&mut r);
        assert!(matches!(res, Err(DecodeError::LengthExceeded { .. })));
    }

    #[test]
    fn sequence_alignment_4_byte_prefix() {
        let mut w = BufferWriter::new(Endianness::Little);
        // One leading byte forces the u32 length prefix to be padded to offset 4.
        1u8.encode(&mut w).unwrap();
        vec![10u8, 20, 30].encode(&mut w).unwrap();
        let bytes = w.into_bytes();
        assert_eq!(bytes[0], 1);
        assert_eq!(&bytes[1..4], &[0, 0, 0]);
        assert_eq!(&bytes[4..8], &[3, 0, 0, 0]);
        assert_eq!(&bytes[8..11], &[10, 20, 30]);
    }

    // ---- fixed-size arrays ----

    #[test]
    fn array_u8_roundtrip() {
        rt_le::<[u8; 4]>([1, 2, 3, 4]);
    }

    #[test]
    fn array_u32_roundtrip() {
        rt_le::<[u32; 3]>([100, 200, 300]);
    }

    #[test]
    fn array_no_length_prefix() {
        let mut w = BufferWriter::new(Endianness::Little);
        [1u8, 2, 3].encode(&mut w).unwrap();
        // Only the elements are written; the length is implied by the type.
        assert_eq!(w.into_bytes(), vec![1, 2, 3]);
    }

    #[test]
    fn array_zero_size() {
        let arr: [u32; 0] = [];
        let mut w = BufferWriter::new(Endianness::Little);
        arr.encode(&mut w).unwrap();
        assert!(w.into_bytes().is_empty());
    }

    // ---- optionals ----

    #[test]
    fn optional_none_roundtrip() {
        rt_le::<Option<u32>>(None);
    }

    #[test]
    fn optional_some_roundtrip() {
        rt_le::<Option<u32>>(Some(42));
    }

    #[test]
    fn optional_some_string_roundtrip() {
        rt_le::<Option<String>>(Some("hi".to_string()));
    }

    #[test]
    fn optional_wire_format_none_is_zero_byte() {
        let mut w = BufferWriter::new(Endianness::Little);
        Option::<u32>::None.encode(&mut w).unwrap();
        assert_eq!(w.into_bytes(), vec![0]);
    }

    #[test]
    fn optional_wire_format_some_is_one_then_value() {
        let mut w = BufferWriter::new(Endianness::Little);
        Some(0xABCDu32).encode(&mut w).unwrap();
        let bytes = w.into_bytes();
        assert_eq!(bytes[0], 1);
        // Padding brings the u32 payload to a 4-byte boundary.
        assert_eq!(&bytes[1..4], &[0, 0, 0]);
        assert_eq!(&bytes[4..8], &[0xCD, 0xAB, 0, 0]);
    }

    #[test]
    fn optional_decode_rejects_invalid_flag() {
        let bytes = [0xFFu8];
        let mut r = BufferReader::new(&bytes, Endianness::Little);
        let res = Option::<u32>::decode(&mut r);
        assert!(matches!(res, Err(DecodeError::InvalidBool { .. })));
    }

    // ---- maps ----

    #[test]
    fn map_roundtrip_with_string_values() {
        // Odd-length strings force padding before the following u32 key.
        let mut m = BTreeMap::new();
        m.insert(1u32, "a".to_string());
        m.insert(2u32, "bb".to_string());
        rt_le(m);
    }

    #[test]
    fn map_empty_roundtrip() {
        rt_le(BTreeMap::<u32, u32>::new());
    }

    #[test]
    fn map_wire_format_is_count_then_pairs() {
        let mut m = BTreeMap::new();
        m.insert(1u32, 2u32);
        let mut w = BufferWriter::new(Endianness::Little);
        m.encode(&mut w).unwrap();
        assert_eq!(w.into_bytes(), vec![1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0]);
    }

    #[test]
    fn map_decode_rejects_truncated_input() {
        // Announces five entries but carries only a single key.
        let bytes = [5u8, 0, 0, 0, 1, 0, 0, 0];
        let mut r = BufferReader::new(&bytes, Endianness::Little);
        let res = BTreeMap::<u32, u32>::decode(&mut r);
        assert!(res.is_err());
    }

    // ---- nesting ----

    #[test]
    fn nested_optional_sequence_string() {
        let value: Option<Vec<String>> = Some(vec!["a".to_string(), "bb".to_string()]);
        rt_le(value);
    }

    #[test]
    fn nested_array_of_optionals() {
        let value: [Option<u32>; 3] = [Some(1), None, Some(3)];
        rt_le(value);
    }
}