commonware_codec/types/
net.rs

1//! Codec implementations for network-related types
2
3use crate::{EncodeSize, Error, FixedSize, Read, ReadExt, Write};
4use bytes::{Buf, BufMut};
5use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
6
7impl Write for Ipv4Addr {
8    #[inline]
9    fn write(&self, buf: &mut impl BufMut) {
10        self.to_bits().to_be().write(buf);
11    }
12}
13
14impl Read for Ipv4Addr {
15    type Cfg = ();
16
17    #[inline]
18    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
19        let bits = <u32>::read(buf)?;
20        Ok(Ipv4Addr::from_bits(u32::from_be(bits)))
21    }
22}
23
24impl FixedSize for Ipv4Addr {
25    const SIZE: usize = u32::SIZE;
26}
27
28impl Write for Ipv6Addr {
29    #[inline]
30    fn write(&self, buf: &mut impl BufMut) {
31        self.to_bits().to_be().write(buf);
32    }
33}
34
35impl Read for Ipv6Addr {
36    type Cfg = ();
37
38    #[inline]
39    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
40        let bits = <u128>::read(buf)?;
41        Ok(Ipv6Addr::from_bits(u128::from_be(bits)))
42    }
43}
44
45impl FixedSize for Ipv6Addr {
46    const SIZE: usize = u128::SIZE;
47}
48
49impl Write for SocketAddrV4 {
50    #[inline]
51    fn write(&self, buf: &mut impl BufMut) {
52        self.ip().write(buf);
53        self.port().write(buf);
54    }
55}
56
57impl Read for SocketAddrV4 {
58    type Cfg = ();
59
60    #[inline]
61    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
62        let ip = Ipv4Addr::read(buf)?;
63        let port = u16::read(buf)?;
64        Ok(Self::new(ip, port))
65    }
66}
67
68impl FixedSize for SocketAddrV4 {
69    const SIZE: usize = Ipv4Addr::SIZE + u16::SIZE;
70}
71
72impl Write for SocketAddrV6 {
73    #[inline]
74    fn write(&self, buf: &mut impl BufMut) {
75        self.ip().write(buf);
76        self.port().write(buf);
77    }
78}
79
80impl Read for SocketAddrV6 {
81    type Cfg = ();
82
83    #[inline]
84    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
85        let address = Ipv6Addr::read(buf)?;
86        let port = u16::read(buf)?;
87        Ok(SocketAddrV6::new(address, port, 0, 0))
88    }
89}
90
91impl FixedSize for SocketAddrV6 {
92    const SIZE: usize = Ipv6Addr::SIZE + u16::SIZE;
93}
94
95// SocketAddr implementation
96impl Write for SocketAddr {
97    #[inline]
98    fn write(&self, buf: &mut impl BufMut) {
99        match self {
100            SocketAddr::V4(v4) => {
101                4u8.write(buf);
102                v4.write(buf);
103            }
104            SocketAddr::V6(v6) => {
105                6u8.write(buf);
106                v6.write(buf);
107            }
108        }
109    }
110}
111
112impl EncodeSize for SocketAddr {
113    #[inline]
114    fn encode_size(&self) -> usize {
115        (match self {
116            SocketAddr::V4(_) => SocketAddrV4::SIZE,
117            SocketAddr::V6(_) => SocketAddrV6::SIZE,
118        }) + u8::SIZE
119    }
120}
121
122impl Read for SocketAddr {
123    type Cfg = ();
124
125    #[inline]
126    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
127        let version = u8::read(buf)?;
128        match version {
129            4 => Ok(SocketAddr::V4(SocketAddrV4::read(buf)?)),
130            6 => Ok(SocketAddr::V6(SocketAddrV6::read(buf)?)),
131            _ => Err(Error::Invalid("SocketAddr", "Invalid version")),
132        }
133    }
134}
135
#[cfg(test)]
mod test {
    use super::*;
    use crate::{DecodeExt, Encode};
    use bytes::Bytes;

    #[test]
    fn test_ipv4_addr() {
        // Test various IPv4 addresses
        let ips = [
            Ipv4Addr::UNSPECIFIED,
            Ipv4Addr::LOCALHOST,
            Ipv4Addr::new(192, 168, 1, 1),
            Ipv4Addr::new(255, 255, 255, 255),
        ];

        for ip in ips.iter() {
            let encoded = ip.encode();
            assert_eq!(encoded.len(), 4);
            let decoded = Ipv4Addr::decode(encoded).unwrap();
            assert_eq!(*ip, decoded);
        }

        // Test insufficient data
        let insufficient = vec![0, 0, 0]; // 3 bytes instead of 4
        assert!(Ipv4Addr::decode(Bytes::from(insufficient)).is_err());
    }

    #[test]
    fn test_ipv6_addr() {
        // Test various IPv6 addresses
        let ips = [
            Ipv6Addr::UNSPECIFIED,
            Ipv6Addr::LOCALHOST,
            Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1),
            Ipv6Addr::new(
                0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
            ),
        ];

        for ip in ips.iter() {
            let encoded = ip.encode();
            assert_eq!(encoded.len(), 16);
            let decoded = Ipv6Addr::decode(encoded).unwrap();
            assert_eq!(*ip, decoded);
        }

        // Test insufficient data
        let insufficient = Bytes::from(vec![0u8; 15]); // 15 bytes instead of 16
        assert!(Ipv6Addr::decode(insufficient).is_err());
    }

    #[test]
    fn test_socket_addr_v4() {
        // Test various SocketAddrV4 instances
        let addrs = [
            SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0),
            SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080),
            SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 1), 65535),
        ];

        for addr in addrs.iter() {
            let encoded = addr.encode();
            assert_eq!(encoded.len(), 6);
            let decoded = SocketAddrV4::decode(encoded).unwrap();
            assert_eq!(*addr, decoded);
        }

        // Test insufficient data
        let insufficient = Bytes::from(vec![0u8; 5]); // 5 bytes instead of 6
        assert!(SocketAddrV4::decode(insufficient).is_err());
    }

    #[test]
    fn test_socket_addr_v6() {
        // Test various SocketAddrV6 instances
        let addrs = [
            SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0),
            SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0),
            SocketAddrV6::new(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1), 65535, 0, 0),
        ];

        for addr in addrs.iter() {
            let encoded = addr.encode();
            assert_eq!(encoded.len(), 18);
            let decoded = SocketAddrV6::decode(encoded).unwrap();
            assert_eq!(*addr, decoded);
        }

        // Test insufficient data
        let insufficient = Bytes::from(vec![0u8; 17]); // 17 bytes instead of 18
        assert!(SocketAddrV6::decode(insufficient).is_err());
    }

    #[test]
    fn test_socket_addr_v6_drops_flowinfo_and_scope_id() {
        // Only the address and port are encoded; flowinfo and scope_id decode as 0.
        let addr = SocketAddrV6::new(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), 8080, 7, 3);
        let encoded = addr.encode();
        assert_eq!(encoded.len(), SocketAddrV6::SIZE);
        let decoded = SocketAddrV6::decode(encoded).unwrap();
        assert_eq!(decoded.ip(), addr.ip());
        assert_eq!(decoded.port(), addr.port());
        assert_eq!(decoded.flowinfo(), 0);
        assert_eq!(decoded.scope_id(), 0);
    }

    #[test]
    fn test_socket_addr() {
        // Test SocketAddr::V4
        let addr_v4 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(196, 168, 0, 1), 8080));
        let encoded_v4 = addr_v4.encode();
        assert_eq!(encoded_v4.len(), 7);
        assert_eq!(addr_v4.encode_size(), 7);
        let decoded_v4 = SocketAddr::decode(encoded_v4).unwrap();
        assert_eq!(addr_v4, decoded_v4);

        // Test SocketAddr::V6
        let addr_v6 = SocketAddr::V6(SocketAddrV6::new(
            Ipv6Addr::new(0x2001, 0x0db8, 0xffff, 0x1234, 0x5678, 0x9abc, 0xdeff, 1),
            8080,
            0,
            0,
        ));
        let encoded_v6 = addr_v6.encode();
        assert_eq!(encoded_v6.len(), 19);
        assert_eq!(addr_v6.encode_size(), 19);
        let decoded_v6 = SocketAddr::decode(encoded_v6).unwrap();
        assert_eq!(addr_v6, decoded_v6);

        // Test invalid version
        let invalid_version = [5]; // Neither 4 nor 6
        assert!(matches!(
            SocketAddr::decode(&invalid_version[..]),
            Err(Error::Invalid(_, _))
        ));

        // Test insufficient data for V4
        let mut insufficient_v4 = vec![4]; // Version byte
        insufficient_v4.extend_from_slice(&[127, 0, 0, 1, 0x1f]); // IP + 1 byte of port (5 bytes total)
        assert!(SocketAddr::decode(&insufficient_v4[..]).is_err());

        // Test insufficient data for V6
        let mut insufficient_v6 = vec![6]; // Version byte
        insufficient_v6.extend_from_slice(&[0; 17]); // 17 bytes instead of 18
        assert!(SocketAddr::decode(&insufficient_v6[..]).is_err());
    }

    #[test]
    fn test_socket_addr_wire_layout() {
        // A SocketAddr is the version tag followed by the inner socket address.
        let v4 = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 8080);
        let encoded_v4 = SocketAddr::V4(v4).encode();
        assert_eq!(encoded_v4[0], 4);
        assert_eq!(&encoded_v4[1..], &v4.encode()[..]);

        let v6 = SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0);
        let encoded_v6 = SocketAddr::V6(v6).encode();
        assert_eq!(encoded_v6[0], 6);
        assert_eq!(&encoded_v6[1..], &v6.encode()[..]);

        // Within an inner socket address, the IP precedes the port.
        let ip = Ipv4Addr::new(10, 0, 0, 1);
        let port: u16 = 8080;
        let inner = v4.encode();
        assert_eq!(&inner[..Ipv4Addr::SIZE], &ip.encode()[..]);
        assert_eq!(&inner[Ipv4Addr::SIZE..], &port.encode()[..]);
    }
}
270}