1use bytes::{Buf, BufMut};
4use commonware_codec::{EncodeSize, Error as CodecError, FixedSize, Read, ReadExt, Write};
5use commonware_runtime::{Error as RuntimeError, Resolver};
6use commonware_utils::{Hostname, IpAddrExt};
7use std::net::{IpAddr, SocketAddr};
8
// One-byte wire tags written before each variant's payload. The `Ingress`
// and `Address` tag spaces are independent of each other.

/// Tag for `Ingress::Socket`.
const INGRESS_SOCKET_PREFIX: u8 = 0x00;
/// Tag for `Ingress::Dns`.
const INGRESS_DNS_PREFIX: u8 = 0x01;

/// Tag for `Address::Symmetric`.
const ADDRESS_SYMMETRIC_PREFIX: u8 = 0x00;
/// Tag for `Address::Asymmetric`.
const ADDRESS_ASYMMETRIC_PREFIX: u8 = 0x01;
14
15#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
17pub enum Ingress {
18 Socket(SocketAddr),
20 Dns {
22 host: Hostname,
24 port: u16,
26 },
27}
28
29impl Ingress {
30 pub const fn port(&self) -> u16 {
32 match self {
33 Self::Socket(addr) => addr.port(),
34 Self::Dns { port, .. } => *port,
35 }
36 }
37
38 pub const fn ip(&self) -> Option<IpAddr> {
40 match self {
41 Self::Socket(addr) => Some(addr.ip()),
42 Self::Dns { .. } => None,
43 }
44 }
45
46 pub fn is_valid(&self, allow_private_ips: bool, allow_dns: bool) -> bool {
54 match self {
55 Self::Socket(addr) => allow_private_ips || IpAddrExt::is_global(&addr.ip()),
56 Self::Dns { .. } => allow_dns,
57 }
58 }
59
60 pub async fn resolve(
65 &self,
66 resolver: &impl Resolver,
67 ) -> Result<impl Iterator<Item = SocketAddr>, RuntimeError> {
68 match self {
69 Self::Socket(addr) => Ok(vec![*addr].into_iter()),
70 Self::Dns { host, port } => {
71 let ips = resolver.resolve(host.as_str()).await?;
72 if ips.is_empty() {
73 return Err(RuntimeError::ResolveFailed(host.to_string()));
74 }
75 Ok(ips
76 .into_iter()
77 .map(move |ip| SocketAddr::new(ip, *port))
78 .collect::<Vec<_>>()
79 .into_iter())
80 }
81 }
82 }
83
84 pub async fn resolve_filtered(
86 &self,
87 resolver: &impl Resolver,
88 allow_private_ips: bool,
89 ) -> Option<impl Iterator<Item = SocketAddr>> {
90 let addrs = self.resolve(resolver).await.ok()?;
91 Some(addrs.filter(move |addr| allow_private_ips || IpAddrExt::is_global(&addr.ip())))
92 }
93}
94
95impl Write for Ingress {
96 fn write(&self, buf: &mut impl BufMut) {
97 match self {
98 Self::Socket(addr) => {
99 INGRESS_SOCKET_PREFIX.write(buf);
100 addr.write(buf);
101 }
102 Self::Dns { host, port } => {
103 INGRESS_DNS_PREFIX.write(buf);
104 host.write(buf);
105 port.write(buf);
106 }
107 }
108 }
109}
110
111impl EncodeSize for Ingress {
112 fn encode_size(&self) -> usize {
113 u8::SIZE
114 + match self {
115 Self::Socket(addr) => addr.encode_size(),
116 Self::Dns { host, port } => host.encode_size() + port.encode_size(),
117 }
118 }
119}
120
121impl Read for Ingress {
122 type Cfg = ();
123
124 fn read_cfg(buf: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
125 let prefix = u8::read(buf)?;
126 match prefix {
127 INGRESS_SOCKET_PREFIX => {
128 let addr = SocketAddr::read(buf)?;
129 Ok(Self::Socket(addr))
130 }
131 INGRESS_DNS_PREFIX => {
132 let host = Hostname::read(buf)?;
133 let port = u16::read(buf)?;
134 Ok(Self::Dns { host, port })
135 }
136 other => Err(CodecError::InvalidEnum(other)),
137 }
138 }
139}
140
141impl From<SocketAddr> for Ingress {
142 fn from(addr: SocketAddr) -> Self {
143 Self::Socket(addr)
144 }
145}
146
147#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
149pub enum Address {
150 Symmetric(SocketAddr),
152 Asymmetric {
154 ingress: Ingress,
156 egress: SocketAddr,
158 },
159}
160
161impl Address {
162 pub fn ingress(&self) -> Ingress {
164 match self {
165 Self::Symmetric(addr) => Ingress::Socket(*addr),
166 Self::Asymmetric { ingress, .. } => ingress.clone(),
167 }
168 }
169
170 pub const fn egress_ip(&self) -> IpAddr {
172 match self {
173 Self::Symmetric(addr) => addr.ip(),
174 Self::Asymmetric { egress, .. } => egress.ip(),
175 }
176 }
177
178 pub const fn egress(&self) -> SocketAddr {
180 match self {
181 Self::Symmetric(addr) => *addr,
182 Self::Asymmetric { egress, .. } => *egress,
183 }
184 }
185}
186
187impl Write for Address {
188 fn write(&self, buf: &mut impl BufMut) {
189 match self {
190 Self::Symmetric(addr) => {
191 ADDRESS_SYMMETRIC_PREFIX.write(buf);
192 addr.write(buf);
193 }
194 Self::Asymmetric { ingress, egress } => {
195 ADDRESS_ASYMMETRIC_PREFIX.write(buf);
196 ingress.write(buf);
197 egress.write(buf);
198 }
199 }
200 }
201}
202
203impl EncodeSize for Address {
204 fn encode_size(&self) -> usize {
205 u8::SIZE
206 + match self {
207 Self::Symmetric(addr) => addr.encode_size(),
208 Self::Asymmetric { ingress, egress } => {
209 ingress.encode_size() + egress.encode_size()
210 }
211 }
212 }
213}
214
215impl Read for Address {
216 type Cfg = ();
217
218 fn read_cfg(buf: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
219 let prefix = u8::read(buf)?;
220 match prefix {
221 ADDRESS_SYMMETRIC_PREFIX => {
222 let addr = SocketAddr::read(buf)?;
223 Ok(Self::Symmetric(addr))
224 }
225 ADDRESS_ASYMMETRIC_PREFIX => {
226 let ingress = Ingress::read(buf)?;
227 let egress = SocketAddr::read(buf)?;
228 Ok(Self::Asymmetric { ingress, egress })
229 }
230 other => Err(CodecError::InvalidEnum(other)),
231 }
232 }
233}
234
235impl From<SocketAddr> for Address {
236 fn from(addr: SocketAddr) -> Self {
237 Self::Symmetric(addr)
238 }
239}
240
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Ingress {
    /// Picks either variant with equal odds, then draws its fields.
    ///
    /// The fields are drawn in declaration order (host before port), which
    /// keeps the consumed-byte sequence stable.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let ingress = if u.ratio(1, 2)? {
            Self::Socket(u.arbitrary()?)
        } else {
            Self::Dns {
                host: u.arbitrary::<Hostname>()?,
                port: u.arbitrary()?,
            }
        };
        Ok(ingress)
    }
}
253
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Address {
    /// Picks either variant with equal odds.
    ///
    /// For `Asymmetric`, the ingress is drawn before the egress.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let address = if u.ratio(1, 2)? {
            Self::Symmetric(u.arbitrary()?)
        } else {
            let ingress = u.arbitrary()?;
            let egress = u.arbitrary()?;
            Self::Asymmetric { ingress, egress }
        };
        Ok(address)
    }
}
267
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{DecodeExt, Encode};
    use commonware_utils::hostname;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn test_ingress_socket_roundtrip() {
        let addrs = [
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
            SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 443),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 65535),
        ];

        for addr in addrs {
            let ingress = Ingress::Socket(addr);
            let encoded = ingress.encode();
            let decoded = Ingress::decode(encoded).unwrap();
            assert_eq!(ingress, decoded);
        }
    }

    #[test]
    fn test_ingress_dns_roundtrip() {
        let cases = [
            ("localhost", 8080),
            ("example.com", 443),
            ("a.b.c.d.e.f.g", 1234),
        ];

        for (host, port) in cases {
            let ingress = Ingress::Dns {
                host: hostname!(host),
                port,
            };
            let encoded = ingress.encode();
            let decoded = Ingress::decode(encoded).unwrap();
            assert_eq!(ingress, decoded);
        }
    }

    #[test]
    fn test_ingress_dns_max_len_exceeded() {
        // Hand-encode a DNS ingress with a 300-byte hostname, bypassing the
        // `Hostname` constructor. Presumably this exceeds Hostname's length
        // limit; decoding is expected to reject it.
        let mut buf = Vec::new();
        INGRESS_DNS_PREFIX.write(&mut buf);
        let long_hostname = "a".repeat(300);
        long_hostname.len().write(&mut buf);
        buf.extend(long_hostname.as_bytes());
        8080u16.write(&mut buf);

        let result = Ingress::decode(bytes::Bytes::from(buf));
        assert!(result.is_err());
    }

    #[test]
    fn test_ingress_invalid_prefix() {
        // A tag byte that matches neither variant must be rejected.
        let result = Ingress::decode(bytes::Bytes::from(vec![2u8]));
        assert!(matches!(result, Err(CodecError::InvalidEnum(2))));
    }

    #[test]
    fn test_ingress_port() {
        let socket = Ingress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080));
        assert_eq!(socket.port(), 8080);

        let dns = Ingress::Dns {
            host: hostname!("example.com"),
            port: 443,
        };
        assert_eq!(dns.port(), 443);
    }

    #[test]
    fn test_ingress_ip() {
        let socket = Ingress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080));
        assert_eq!(socket.ip(), Some(IpAddr::V4(Ipv4Addr::LOCALHOST)));

        let dns = Ingress::Dns {
            host: hostname!("example.com"),
            port: 443,
        };
        assert_eq!(dns.ip(), None);
    }

    #[test]
    fn test_address_symmetric_roundtrip() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080);
        let address = Address::Symmetric(addr);
        let encoded = address.encode();
        let decoded = Address::decode(encoded).unwrap();
        assert_eq!(address, decoded);
    }

    #[test]
    fn test_address_asymmetric_socket_roundtrip() {
        let ingress_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 8080);
        let egress_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 9090);
        let address = Address::Asymmetric {
            ingress: Ingress::Socket(ingress_addr),
            egress: egress_addr,
        };
        let encoded = address.encode();
        let decoded = Address::decode(encoded).unwrap();
        assert_eq!(address, decoded);
    }

    #[test]
    fn test_address_asymmetric_dns_roundtrip() {
        let egress_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 9090);
        let address = Address::Asymmetric {
            ingress: Ingress::Dns {
                host: hostname!("node.example.com"),
                port: 8080,
            },
            egress: egress_addr,
        };
        let encoded = address.encode();
        let decoded = Address::decode(encoded).unwrap();
        assert_eq!(address, decoded);
    }

    #[test]
    fn test_address_invalid_prefix() {
        // Unknown outer tag.
        let result = Address::decode(bytes::Bytes::from(vec![2u8]));
        assert!(matches!(result, Err(CodecError::InvalidEnum(2))));

        // Valid outer tag, but the nested ingress tag is unknown.
        let result = Address::decode(bytes::Bytes::from(vec![ADDRESS_ASYMMETRIC_PREFIX, 2u8]));
        assert!(matches!(result, Err(CodecError::InvalidEnum(2))));
    }

    #[test]
    fn test_encode_size_matches_encoding() {
        // `encode_size` must agree with the number of bytes `write` emits
        // for every variant, including nested ones.
        let egress = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 9090);
        let ingresses = [
            Ingress::Socket(SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 443)),
            Ingress::Dns {
                host: hostname!("example.com"),
                port: 8080,
            },
        ];

        for ingress in ingresses {
            assert_eq!(ingress.encode_size(), ingress.encode().len());
            let address = Address::Asymmetric { ingress, egress };
            assert_eq!(address.encode_size(), address.encode().len());
        }

        let symmetric = Address::Symmetric(egress);
        assert_eq!(symmetric.encode_size(), symmetric.encode().len());
    }

    #[test]
    fn test_address_helpers() {
        let socket_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 8080);
        let egress_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 9090);

        let symmetric = Address::Symmetric(socket_addr);
        assert_eq!(symmetric.ingress(), Ingress::Socket(socket_addr));
        assert_eq!(
            symmetric.egress_ip(),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))
        );
        assert_eq!(symmetric.egress(), socket_addr);

        let asymmetric = Address::Asymmetric {
            ingress: Ingress::Dns {
                host: hostname!("example.com"),
                port: 8080,
            },
            egress: egress_addr,
        };
        assert_eq!(
            asymmetric.ingress(),
            Ingress::Dns {
                host: hostname!("example.com"),
                port: 8080
            }
        );
        assert_eq!(
            asymmetric.egress_ip(),
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))
        );
        assert_eq!(asymmetric.egress(), egress_addr);
    }

    #[test]
    fn test_from_socket_addr() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);

        let ingress: Ingress = addr.into();
        assert_eq!(ingress, Ingress::Socket(addr));

        let address: Address = addr.into();
        assert_eq!(address, Address::Symmetric(addr));
    }

    #[test]
    fn test_ingress_is_valid() {
        let public_socket =
            Ingress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 8080));
        let private_socket = Ingress::Socket(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            8080,
        ));
        let dns = Ingress::Dns {
            host: hostname!("example.com"),
            port: 8080,
        };

        // A public socket is valid under every flag combination.
        assert!(public_socket.is_valid(false, false));
        assert!(public_socket.is_valid(false, true));
        assert!(public_socket.is_valid(true, false));
        assert!(public_socket.is_valid(true, true));

        // A private socket is valid only when private IPs are allowed;
        // `allow_dns` has no effect on it.
        assert!(!private_socket.is_valid(false, false));
        assert!(!private_socket.is_valid(false, true));
        assert!(private_socket.is_valid(true, false));
        assert!(private_socket.is_valid(true, true));

        // DNS is valid only when DNS is allowed; `allow_private_ips` has no
        // effect on it.
        assert!(!dns.is_valid(false, false));
        assert!(dns.is_valid(false, true));
        assert!(!dns.is_valid(true, false));
        assert!(dns.is_valid(true, true));
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Ingress>,
            CodecConformance<Address>,
        }
    }
}