oxicode/features/
impl_std.rs

1//! Encode/Decode implementations for std-dependent types
2
3use crate::{
4    de::{read::Reader, Decode, Decoder},
5    enc::{write::Writer, Encode, Encoder},
6    error::Error,
7};
8use std::{
9    collections::{HashMap, HashSet},
10    ffi::{CStr, CString},
11    hash::{BuildHasher, Hash},
12    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
13    path::{Path, PathBuf},
14    sync::{Mutex, RwLock},
15    time::{Duration, SystemTime, UNIX_EPOCH},
16};
17
18// ===== HashMap<K, V> =====
19
20impl<K, V, S> Encode for HashMap<K, V, S>
21where
22    K: Encode,
23    V: Encode,
24{
25    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
26        (self.len() as u64).encode(encoder)?;
27        for (key, value) in self.iter() {
28            key.encode(encoder)?;
29            value.encode(encoder)?;
30        }
31        Ok(())
32    }
33}
34
35impl<K, V, S> Decode for HashMap<K, V, S>
36where
37    K: Decode + Eq + Hash,
38    V: Decode,
39    S: BuildHasher + Default,
40{
41    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
42        let len = u64::decode(decoder)? as usize;
43
44        let mut map = HashMap::with_capacity_and_hasher(len, S::default());
45        for _ in 0..len {
46            let key = K::decode(decoder)?;
47            let value = V::decode(decoder)?;
48            map.insert(key, value);
49        }
50        Ok(map)
51    }
52}
53
54// ===== HashSet<T> =====
55
56impl<T, S> Encode for HashSet<T, S>
57where
58    T: Encode,
59{
60    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
61        (self.len() as u64).encode(encoder)?;
62        for item in self.iter() {
63            item.encode(encoder)?;
64        }
65        Ok(())
66    }
67}
68
69impl<T, S> Decode for HashSet<T, S>
70where
71    T: Decode + Eq + Hash,
72    S: BuildHasher + Default,
73{
74    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
75        let len = u64::decode(decoder)? as usize;
76
77        let mut set = HashSet::with_capacity_and_hasher(len, S::default());
78        for _ in 0..len {
79            set.insert(T::decode(decoder)?);
80        }
81        Ok(set)
82    }
83}
84
85// ===== Mutex<T> =====
86
87impl<T: Encode> Encode for Mutex<T> {
88    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
89        let guard = self.lock().map_err(|_| Error::Custom {
90            message: "Mutex poisoned",
91        })?;
92        (*guard).encode(encoder)
93    }
94}
95
96impl<T: Decode> Decode for Mutex<T> {
97    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
98        Ok(Mutex::new(T::decode(decoder)?))
99    }
100}
101
102// ===== RwLock<T> =====
103
104impl<T: Encode> Encode for RwLock<T> {
105    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
106        let guard = self.read().map_err(|_| Error::Custom {
107            message: "RwLock poisoned",
108        })?;
109        (*guard).encode(encoder)
110    }
111}
112
113impl<T: Decode> Decode for RwLock<T> {
114    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
115        Ok(RwLock::new(T::decode(decoder)?))
116    }
117}
118
119// ===== Duration =====
120
121impl Encode for Duration {
122    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
123        self.as_secs().encode(encoder)?;
124        self.subsec_nanos().encode(encoder)
125    }
126}
127
128impl Decode for Duration {
129    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
130        let secs = u64::decode(decoder)?;
131        let nanos = u32::decode(decoder)?;
132
133        // Validate nanos < 1_000_000_000
134        if nanos >= 1_000_000_000 {
135            return Err(Error::InvalidDuration { secs, nanos });
136        }
137
138        Ok(Duration::new(secs, nanos))
139    }
140}
141
142// ===== SystemTime =====
143
144impl Encode for SystemTime {
145    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
146        match self.duration_since(UNIX_EPOCH) {
147            Ok(duration) => {
148                0u8.encode(encoder)?;
149                duration.encode(encoder)
150            }
151            Err(e) => {
152                1u8.encode(encoder)?;
153                e.duration().encode(encoder)
154            }
155        }
156    }
157}
158
159impl Decode for SystemTime {
160    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
161        let variant = u8::decode(decoder)?;
162        let duration = Duration::decode(decoder)?;
163
164        match variant {
165            0 => Ok(UNIX_EPOCH + duration),
166            1 => Ok(UNIX_EPOCH - duration),
167            _ => Err(Error::InvalidData {
168                message: "Invalid SystemTime variant",
169            }),
170        }
171    }
172}
173
174// ===== Path & PathBuf =====
175
176impl Encode for Path {
177    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
178        let os_str = self.as_os_str();
179        #[cfg(unix)]
180        {
181            use std::os::unix::ffi::OsStrExt;
182            let bytes = os_str.as_bytes();
183            (bytes.len() as u64).encode(encoder)?;
184            encoder.writer().write(bytes)
185        }
186        #[cfg(windows)]
187        {
188            use std::os::windows::ffi::OsStrExt;
189            let wide: Vec<u16> = os_str.encode_wide().collect();
190            (wide.len() as u64).encode(encoder)?;
191            for code_unit in wide {
192                code_unit.encode(encoder)?;
193            }
194            Ok(())
195        }
196        #[cfg(not(any(unix, windows)))]
197        {
198            // Fallback: convert to string lossy
199            let string = os_str.to_string_lossy();
200            string.encode(encoder)
201        }
202    }
203}
204
205impl Encode for PathBuf {
206    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
207        self.as_path().encode(encoder)
208    }
209}
210
211impl Decode for PathBuf {
212    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
213        #[cfg(unix)]
214        {
215            use std::ffi::OsStr;
216            use std::os::unix::ffi::OsStrExt;
217
218            let len = u64::decode(decoder)? as usize;
219            decoder.claim_bytes_read(len)?;
220
221            let mut bytes = alloc::vec![0u8; len];
222            decoder.reader().read(&mut bytes)?;
223
224            Ok(PathBuf::from(OsStr::from_bytes(&bytes)))
225        }
226        #[cfg(windows)]
227        {
228            use std::ffi::OsString;
229            use std::os::windows::ffi::OsStringExt;
230
231            let len = u64::decode(decoder)? as usize;
232            let mut wide = alloc::vec![0u16; len];
233            for code_unit in &mut wide {
234                *code_unit = u16::decode(decoder)?;
235            }
236
237            Ok(PathBuf::from(OsString::from_wide(&wide)))
238        }
239        #[cfg(not(any(unix, windows)))]
240        {
241            let string = String::decode(decoder)?;
242            Ok(PathBuf::from(string))
243        }
244    }
245}
246
247// ===== IpAddr =====
248
249impl Encode for IpAddr {
250    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
251        match self {
252            IpAddr::V4(addr) => {
253                0u8.encode(encoder)?;
254                addr.encode(encoder)
255            }
256            IpAddr::V6(addr) => {
257                1u8.encode(encoder)?;
258                addr.encode(encoder)
259            }
260        }
261    }
262}
263
264impl Decode for IpAddr {
265    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
266        let variant = u8::decode(decoder)?;
267        match variant {
268            0 => Ok(IpAddr::V4(Ipv4Addr::decode(decoder)?)),
269            1 => Ok(IpAddr::V6(Ipv6Addr::decode(decoder)?)),
270            _ => Err(Error::InvalidData {
271                message: "Invalid IpAddr variant",
272            }),
273        }
274    }
275}
276
277// ===== Ipv4Addr =====
278
279impl Encode for Ipv4Addr {
280    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
281        encoder.writer().write(&self.octets())
282    }
283}
284
285impl Decode for Ipv4Addr {
286    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
287        let mut octets = [0u8; 4];
288        decoder.reader().read(&mut octets)?;
289        Ok(Ipv4Addr::from(octets))
290    }
291}
292
293// ===== Ipv6Addr =====
294
295impl Encode for Ipv6Addr {
296    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
297        encoder.writer().write(&self.octets())
298    }
299}
300
301impl Decode for Ipv6Addr {
302    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
303        let mut octets = [0u8; 16];
304        decoder.reader().read(&mut octets)?;
305        Ok(Ipv6Addr::from(octets))
306    }
307}
308
309// ===== SocketAddr =====
310
311impl Encode for SocketAddr {
312    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
313        match self {
314            SocketAddr::V4(addr) => {
315                0u8.encode(encoder)?;
316                addr.encode(encoder)
317            }
318            SocketAddr::V6(addr) => {
319                1u8.encode(encoder)?;
320                addr.encode(encoder)
321            }
322        }
323    }
324}
325
326impl Decode for SocketAddr {
327    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
328        let variant = u8::decode(decoder)?;
329        match variant {
330            0 => Ok(SocketAddr::V4(SocketAddrV4::decode(decoder)?)),
331            1 => Ok(SocketAddr::V6(SocketAddrV6::decode(decoder)?)),
332            _ => Err(Error::InvalidData {
333                message: "Invalid SocketAddr variant",
334            }),
335        }
336    }
337}
338
339// ===== SocketAddrV4 =====
340
341impl Encode for SocketAddrV4 {
342    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
343        self.ip().encode(encoder)?;
344        self.port().encode(encoder)
345    }
346}
347
348impl Decode for SocketAddrV4 {
349    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
350        let ip = Ipv4Addr::decode(decoder)?;
351        let port = u16::decode(decoder)?;
352        Ok(SocketAddrV4::new(ip, port))
353    }
354}
355
356// ===== SocketAddrV6 =====
357
358impl Encode for SocketAddrV6 {
359    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
360        self.ip().encode(encoder)?;
361        self.port().encode(encoder)?;
362        self.flowinfo().encode(encoder)?;
363        self.scope_id().encode(encoder)
364    }
365}
366
367impl Decode for SocketAddrV6 {
368    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
369        let ip = Ipv6Addr::decode(decoder)?;
370        let port = u16::decode(decoder)?;
371        let flowinfo = u32::decode(decoder)?;
372        let scope_id = u32::decode(decoder)?;
373        Ok(SocketAddrV6::new(ip, port, flowinfo, scope_id))
374    }
375}
376
377// ===== CString =====
378
379impl Encode for CString {
380    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
381        // Encode as bytes without the null terminator
382        let bytes = self.as_bytes();
383        (bytes.len() as u64).encode(encoder)?;
384        encoder.writer().write(bytes)
385    }
386}
387
388impl Decode for CString {
389    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, Error> {
390        let len = u64::decode(decoder)? as usize;
391        decoder.claim_bytes_read(len)?;
392
393        let mut bytes = alloc::vec![0u8; len];
394        decoder.reader().read(&mut bytes)?;
395
396        // Verify no null bytes in the middle
397        if bytes.contains(&0) {
398            return Err(Error::Custom {
399                message: "CString contains null byte",
400            });
401        }
402
403        CString::new(bytes).map_err(|_| Error::Custom {
404            message: "CString contains null byte",
405        })
406    }
407}
408
409// ===== CStr =====
410
411impl Encode for CStr {
412    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), Error> {
413        let bytes = self.to_bytes();
414        (bytes.len() as u64).encode(encoder)?;
415        encoder.writer().write(bytes)
416    }
417}