commonware_codec/types/
net.rs

1//! Codec implementations for network-related types
2
3use crate::{EncodeSize, Error, FixedSize, Read, ReadExt, Write};
4use bytes::{Buf, BufMut};
5use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
6
7impl Write for Ipv4Addr {
8    #[inline]
9    fn write(&self, buf: &mut impl BufMut) {
10        self.to_bits().write(buf);
11    }
12}
13
14impl Read for Ipv4Addr {
15    #[inline]
16    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
17        let bits = <u32>::read(buf)?;
18        Ok(Ipv4Addr::from_bits(bits))
19    }
20}
21
22impl FixedSize for Ipv4Addr {
23    const SIZE: usize = u32::SIZE;
24}
25
26impl Write for Ipv6Addr {
27    #[inline]
28    fn write(&self, buf: &mut impl BufMut) {
29        self.to_bits().write(buf);
30    }
31}
32
33impl Read for Ipv6Addr {
34    #[inline]
35    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
36        let bits = <u128>::read(buf)?;
37        Ok(Ipv6Addr::from_bits(bits))
38    }
39}
40
41impl FixedSize for Ipv6Addr {
42    const SIZE: usize = u128::SIZE;
43}
44
45impl Write for SocketAddrV4 {
46    #[inline]
47    fn write(&self, buf: &mut impl BufMut) {
48        self.ip().write(buf);
49        self.port().write(buf);
50    }
51}
52
53impl Read for SocketAddrV4 {
54    #[inline]
55    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
56        let ip = Ipv4Addr::read(buf)?;
57        let port = u16::read(buf)?;
58        Ok(Self::new(ip, port))
59    }
60}
61
62impl FixedSize for SocketAddrV4 {
63    const SIZE: usize = Ipv4Addr::SIZE + u16::SIZE;
64}
65
66impl Write for SocketAddrV6 {
67    #[inline]
68    fn write(&self, buf: &mut impl BufMut) {
69        self.ip().write(buf);
70        self.port().write(buf);
71    }
72}
73
74impl Read for SocketAddrV6 {
75    #[inline]
76    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
77        let address = Ipv6Addr::read(buf)?;
78        let port = u16::read(buf)?;
79        Ok(SocketAddrV6::new(address, port, 0, 0))
80    }
81}
82
83impl FixedSize for SocketAddrV6 {
84    const SIZE: usize = Ipv6Addr::SIZE + u16::SIZE;
85}
86
87// SocketAddr implementation
88impl Write for SocketAddr {
89    #[inline]
90    fn write(&self, buf: &mut impl BufMut) {
91        match self {
92            SocketAddr::V4(v4) => {
93                u8::write(&4, buf);
94                v4.write(buf);
95            }
96            SocketAddr::V6(v6) => {
97                u8::write(&6, buf);
98                v6.write(buf);
99            }
100        }
101    }
102}
103
104impl EncodeSize for SocketAddr {
105    #[inline]
106    fn encode_size(&self) -> usize {
107        (match self {
108            SocketAddr::V4(_) => SocketAddrV4::SIZE,
109            SocketAddr::V6(_) => SocketAddrV6::SIZE,
110        }) + u8::SIZE
111    }
112}
113
114impl Read for SocketAddr {
115    #[inline]
116    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
117        let version = u8::read(buf)?;
118        match version {
119            4 => Ok(SocketAddr::V4(SocketAddrV4::read(buf)?)),
120            6 => Ok(SocketAddr::V6(SocketAddrV6::read(buf)?)),
121            _ => Err(Error::Invalid("SocketAddr", "Invalid version")),
122        }
123    }
124}
125
#[cfg(test)]
mod test {
    use super::*;
    use crate::{DecodeExt, Encode};
    use bytes::Bytes;

    /// Round-trips several IPv4 addresses and rejects a truncated buffer.
    #[test]
    fn test_ipv4_addr() {
        let samples = [
            Ipv4Addr::UNSPECIFIED,
            Ipv4Addr::LOCALHOST,
            Ipv4Addr::new(192, 168, 1, 1),
            Ipv4Addr::new(255, 255, 255, 255),
        ];

        for original in &samples {
            let bytes = original.encode();
            assert_eq!(bytes.len(), 4);
            let roundtrip = Ipv4Addr::decode(bytes).unwrap();
            assert_eq!(*original, roundtrip);
        }

        // Three bytes is one short of a full address.
        let truncated = Bytes::from(vec![0u8; 3]);
        assert!(Ipv4Addr::decode(truncated).is_err());
    }

    /// Round-trips several IPv6 addresses and rejects a truncated buffer.
    #[test]
    fn test_ipv6_addr() {
        let samples = [
            Ipv6Addr::UNSPECIFIED,
            Ipv6Addr::LOCALHOST,
            Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1),
            Ipv6Addr::new(
                0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
            ),
        ];

        for original in &samples {
            let bytes = original.encode();
            assert_eq!(bytes.len(), 16);
            let roundtrip = Ipv6Addr::decode(bytes).unwrap();
            assert_eq!(*original, roundtrip);
        }

        // Fifteen bytes is one short of a full address.
        let truncated = Bytes::from(vec![0u8; 15]);
        assert!(Ipv6Addr::decode(truncated).is_err());
    }

    /// Round-trips several IPv4 socket addresses and rejects a truncated buffer.
    #[test]
    fn test_socket_addr_v4() {
        let samples = [
            SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0),
            SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080),
            SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 1), 65535),
        ];

        for original in &samples {
            let bytes = original.encode();
            assert_eq!(bytes.len(), 6);
            let roundtrip = SocketAddrV4::decode(bytes).unwrap();
            assert_eq!(*original, roundtrip);
        }

        // Five bytes: the address plus only half of the port.
        let truncated = Bytes::from(vec![0u8; 5]);
        assert!(SocketAddrV4::decode(truncated).is_err());
    }

    /// Round-trips several IPv6 socket addresses and rejects a truncated buffer.
    #[test]
    fn test_socket_addr_v6() {
        let samples = [
            SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0),
            SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0),
            SocketAddrV6::new(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1), 65535, 0, 0),
        ];

        for original in &samples {
            let bytes = original.encode();
            assert_eq!(bytes.len(), 18);
            let roundtrip = SocketAddrV6::decode(bytes).unwrap();
            assert_eq!(*original, roundtrip);
        }

        // Seventeen bytes: the address plus only half of the port.
        let truncated = Bytes::from(vec![0u8; 17]);
        assert!(SocketAddrV6::decode(truncated).is_err());
    }

    /// Covers both `SocketAddr` variants, an unknown version tag, and
    /// truncated payloads behind a valid tag.
    #[test]
    fn test_socket_addr() {
        // V4 variant: one tag byte plus six payload bytes.
        let v4 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(196, 168, 0, 1), 8080));
        let v4_bytes = v4.encode();
        assert_eq!(v4_bytes.len(), 7);
        assert_eq!(v4.encode_size(), 7);
        let v4_roundtrip = SocketAddr::decode(v4_bytes).unwrap();
        assert_eq!(v4, v4_roundtrip);

        // V6 variant: one tag byte plus eighteen payload bytes.
        let v6_ip = Ipv6Addr::new(0x2001, 0x0db8, 0xffff, 0x1234, 0x5678, 0x9abc, 0xdeff, 1);
        let v6 = SocketAddr::V6(SocketAddrV6::new(v6_ip, 8080, 0, 0));
        let v6_bytes = v6.encode();
        assert_eq!(v6_bytes.len(), 19);
        assert_eq!(v6.encode_size(), 19);
        let v6_roundtrip = SocketAddr::decode(v6_bytes).unwrap();
        assert_eq!(v6, v6_roundtrip);

        // A tag that is neither 4 nor 6 is rejected as invalid.
        let unknown_tag = [5u8];
        let outcome = SocketAddr::decode(&unknown_tag[..]);
        assert!(matches!(outcome, Err(Error::Invalid(_, _))));

        // Tag 4 followed by the IP and only one byte of the port.
        let truncated_v4: [u8; 6] = [4, 127, 0, 0, 1, 0x1f];
        assert!(SocketAddr::decode(&truncated_v4[..]).is_err());

        // Tag 6 followed by seventeen zero bytes instead of eighteen.
        let mut truncated_v6 = [0u8; 18];
        truncated_v6[0] = 6;
        assert!(SocketAddr::decode(&truncated_v6[..]).is_err());
    }
}