1use commonware_codec::{EncodeSize, Error as CodecError, FixedSize, Read, ReadExt, Write};
4use commonware_runtime::{Buf, BufMut, Error as RuntimeError, Resolver};
5use commonware_utils::{Hostname, IpAddrExt};
6use std::net::{IpAddr, SocketAddr};
7
// One-byte variant tags that prefix each encoded value on the wire.
// The two enums have independent tag spaces, so the same numeric values are
// reused; the decoder always knows which enum it is reading.

/// Tag written before the payload of an `Ingress::Socket`.
const INGRESS_SOCKET_PREFIX: u8 = 0x00;
/// Tag written before the payload of an `Ingress::Dns`.
const INGRESS_DNS_PREFIX: u8 = 0x01;

/// Tag written before the payload of an `Address::Symmetric`.
const ADDRESS_SYMMETRIC_PREFIX: u8 = 0x00;
/// Tag written before the payload of an `Address::Asymmetric`.
const ADDRESS_ASYMMETRIC_PREFIX: u8 = 0x01;
13
14#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
16pub enum Ingress {
17 Socket(SocketAddr),
19 Dns {
21 host: Hostname,
23 port: u16,
25 },
26}
27
28impl Ingress {
29 pub const fn port(&self) -> u16 {
31 match self {
32 Self::Socket(addr) => addr.port(),
33 Self::Dns { port, .. } => *port,
34 }
35 }
36
37 pub const fn ip(&self) -> Option<IpAddr> {
39 match self {
40 Self::Socket(addr) => Some(addr.ip()),
41 Self::Dns { .. } => None,
42 }
43 }
44
45 pub fn is_valid(&self, allow_private_ips: bool, allow_dns: bool) -> bool {
53 match self {
54 Self::Socket(addr) => allow_private_ips || IpAddrExt::is_global(&addr.ip()),
55 Self::Dns { .. } => allow_dns,
56 }
57 }
58
59 pub async fn resolve(
64 &self,
65 resolver: &impl Resolver,
66 ) -> Result<impl Iterator<Item = SocketAddr>, RuntimeError> {
67 match self {
68 Self::Socket(addr) => Ok(vec![*addr].into_iter()),
69 Self::Dns { host, port } => {
70 let ips = resolver.resolve(host.as_str()).await?;
71 if ips.is_empty() {
72 return Err(RuntimeError::ResolveFailed(host.to_string()));
73 }
74 Ok(ips
75 .into_iter()
76 .map(move |ip| SocketAddr::new(ip, *port))
77 .collect::<Vec<_>>()
78 .into_iter())
79 }
80 }
81 }
82
83 pub async fn resolve_filtered(
85 &self,
86 resolver: &impl Resolver,
87 allow_private_ips: bool,
88 ) -> Option<impl Iterator<Item = SocketAddr>> {
89 let addrs = self.resolve(resolver).await.ok()?;
90 Some(addrs.filter(move |addr| allow_private_ips || IpAddrExt::is_global(&addr.ip())))
91 }
92}
93
94impl Write for Ingress {
95 fn write(&self, buf: &mut impl BufMut) {
96 match self {
97 Self::Socket(addr) => {
98 INGRESS_SOCKET_PREFIX.write(buf);
99 addr.write(buf);
100 }
101 Self::Dns { host, port } => {
102 INGRESS_DNS_PREFIX.write(buf);
103 host.write(buf);
104 port.write(buf);
105 }
106 }
107 }
108}
109
110impl EncodeSize for Ingress {
111 fn encode_size(&self) -> usize {
112 u8::SIZE
113 + match self {
114 Self::Socket(addr) => addr.encode_size(),
115 Self::Dns { host, port } => host.encode_size() + port.encode_size(),
116 }
117 }
118}
119
120impl Read for Ingress {
121 type Cfg = ();
122
123 fn read_cfg(buf: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
124 let prefix = u8::read(buf)?;
125 match prefix {
126 INGRESS_SOCKET_PREFIX => {
127 let addr = SocketAddr::read(buf)?;
128 Ok(Self::Socket(addr))
129 }
130 INGRESS_DNS_PREFIX => {
131 let host = Hostname::read(buf)?;
132 let port = u16::read(buf)?;
133 Ok(Self::Dns { host, port })
134 }
135 other => Err(CodecError::InvalidEnum(other)),
136 }
137 }
138}
139
140impl From<SocketAddr> for Ingress {
141 fn from(addr: SocketAddr) -> Self {
142 Self::Socket(addr)
143 }
144}
145
146#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
148pub enum Address {
149 Symmetric(SocketAddr),
151 Asymmetric {
153 ingress: Ingress,
155 egress: SocketAddr,
157 },
158}
159
160impl Address {
161 pub fn ingress(&self) -> Ingress {
163 match self {
164 Self::Symmetric(addr) => Ingress::Socket(*addr),
165 Self::Asymmetric { ingress, .. } => ingress.clone(),
166 }
167 }
168
169 pub const fn egress_ip(&self) -> IpAddr {
171 match self {
172 Self::Symmetric(addr) => addr.ip(),
173 Self::Asymmetric { egress, .. } => egress.ip(),
174 }
175 }
176
177 pub const fn egress(&self) -> SocketAddr {
179 match self {
180 Self::Symmetric(addr) => *addr,
181 Self::Asymmetric { egress, .. } => *egress,
182 }
183 }
184}
185
186impl Write for Address {
187 fn write(&self, buf: &mut impl BufMut) {
188 match self {
189 Self::Symmetric(addr) => {
190 ADDRESS_SYMMETRIC_PREFIX.write(buf);
191 addr.write(buf);
192 }
193 Self::Asymmetric { ingress, egress } => {
194 ADDRESS_ASYMMETRIC_PREFIX.write(buf);
195 ingress.write(buf);
196 egress.write(buf);
197 }
198 }
199 }
200}
201
202impl EncodeSize for Address {
203 fn encode_size(&self) -> usize {
204 u8::SIZE
205 + match self {
206 Self::Symmetric(addr) => addr.encode_size(),
207 Self::Asymmetric { ingress, egress } => {
208 ingress.encode_size() + egress.encode_size()
209 }
210 }
211 }
212}
213
214impl Read for Address {
215 type Cfg = ();
216
217 fn read_cfg(buf: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
218 let prefix = u8::read(buf)?;
219 match prefix {
220 ADDRESS_SYMMETRIC_PREFIX => {
221 let addr = SocketAddr::read(buf)?;
222 Ok(Self::Symmetric(addr))
223 }
224 ADDRESS_ASYMMETRIC_PREFIX => {
225 let ingress = Ingress::read(buf)?;
226 let egress = SocketAddr::read(buf)?;
227 Ok(Self::Asymmetric { ingress, egress })
228 }
229 other => Err(CodecError::InvalidEnum(other)),
230 }
231 }
232}
233
234impl From<SocketAddr> for Address {
235 fn from(addr: SocketAddr) -> Self {
236 Self::Symmetric(addr)
237 }
238}
239
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Ingress {
    /// Picks either variant with equal weight, then draws its fields.
    ///
    /// Fields are drawn in declaration order (`host` before `port`); struct
    /// expression fields are evaluated in the order written.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let socket_variant = u.ratio(1, 2)?;
        let ingress = if socket_variant {
            Self::Socket(u.arbitrary()?)
        } else {
            Self::Dns {
                host: u.arbitrary()?,
                port: u.arbitrary()?,
            }
        };
        Ok(ingress)
    }
}
252
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Address {
    /// Picks either variant with equal weight, then draws its fields in
    /// declaration order (`ingress` before `egress`).
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let symmetric_variant = u.ratio(1, 2)?;
        if !symmetric_variant {
            let ingress = u.arbitrary()?;
            let egress = u.arbitrary()?;
            return Ok(Self::Asymmetric { ingress, egress });
        }
        u.arbitrary().map(Self::Symmetric)
    }
}
266
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{DecodeExt, Encode};
    use commonware_runtime::IoBuf;
    use commonware_utils::hostname;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn test_ingress_socket_roundtrip() {
        let addrs = [
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
            SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 443),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 65535),
        ];

        for addr in addrs {
            let ingress = Ingress::Socket(addr);
            let encoded = ingress.encode();
            let decoded = Ingress::decode(encoded).unwrap();
            assert_eq!(ingress, decoded);
        }
    }

    #[test]
    fn test_ingress_dns_roundtrip() {
        let cases = [
            ("localhost", 8080),
            ("example.com", 443),
            ("a.b.c.d.e.f.g", 1234),
        ];

        for (host, port) in cases {
            let ingress = Ingress::Dns {
                host: hostname!(host),
                port,
            };
            let encoded = ingress.encode();
            let decoded = Ingress::decode(encoded).unwrap();
            assert_eq!(ingress, decoded);
        }
    }

    #[test]
    fn test_ingress_dns_max_len_exceeded() {
        // Hand-encode a DNS ingress with a 300-byte hostname (longer than a
        // DNS name may be); decoding it must be rejected.
        let mut buf = Vec::new();
        INGRESS_DNS_PREFIX.write(&mut buf);
        let long_hostname = "a".repeat(300);
        long_hostname.len().write(&mut buf);
        buf.extend(long_hostname.as_bytes());
        8080u16.write(&mut buf);

        let result = Ingress::decode(IoBuf::from(buf));
        assert!(result.is_err());
    }

    #[test]
    fn test_ingress_invalid_prefix() {
        // Any tag other than the socket/DNS prefixes must be rejected with
        // the offending tag reported.
        let result = Ingress::decode(IoBuf::from(vec![2u8]));
        assert!(matches!(result, Err(CodecError::InvalidEnum(2))));
    }

    #[test]
    fn test_ingress_port() {
        let socket = Ingress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080));
        assert_eq!(socket.port(), 8080);

        let dns = Ingress::Dns {
            host: hostname!("example.com"),
            port: 443,
        };
        assert_eq!(dns.port(), 443);
    }

    #[test]
    fn test_ingress_ip() {
        let socket = Ingress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080));
        assert_eq!(socket.ip(), Some(IpAddr::V4(Ipv4Addr::LOCALHOST)));

        let dns = Ingress::Dns {
            host: hostname!("example.com"),
            port: 443,
        };
        assert_eq!(dns.ip(), None);
    }

    #[test]
    fn test_address_symmetric_roundtrip() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080);
        let address = Address::Symmetric(addr);
        let encoded = address.encode();
        let decoded = Address::decode(encoded).unwrap();
        assert_eq!(address, decoded);
    }

    #[test]
    fn test_address_asymmetric_socket_roundtrip() {
        let ingress_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 8080);
        let egress_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 9090);
        let address = Address::Asymmetric {
            ingress: Ingress::Socket(ingress_addr),
            egress: egress_addr,
        };
        let encoded = address.encode();
        let decoded = Address::decode(encoded).unwrap();
        assert_eq!(address, decoded);
    }

    #[test]
    fn test_address_asymmetric_dns_roundtrip() {
        let egress_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 9090);
        let address = Address::Asymmetric {
            ingress: Ingress::Dns {
                host: hostname!("node.example.com"),
                port: 8080,
            },
            egress: egress_addr,
        };
        let encoded = address.encode();
        let decoded = Address::decode(encoded).unwrap();
        assert_eq!(address, decoded);
    }

    #[test]
    fn test_address_invalid_prefix() {
        // Unknown top-level tag.
        let result = Address::decode(IoBuf::from(vec![2u8]));
        assert!(matches!(result, Err(CodecError::InvalidEnum(2))));

        // Valid asymmetric tag, but the nested ingress carries an unknown tag:
        // the error from `Ingress` must propagate unchanged.
        let result = Address::decode(IoBuf::from(vec![ADDRESS_ASYMMETRIC_PREFIX, 7u8]));
        assert!(matches!(result, Err(CodecError::InvalidEnum(7))));
    }

    #[test]
    fn test_address_helpers() {
        let socket_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 8080);
        let egress_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 9090);

        let symmetric = Address::Symmetric(socket_addr);
        assert_eq!(symmetric.ingress(), Ingress::Socket(socket_addr));
        assert_eq!(
            symmetric.egress_ip(),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))
        );
        assert_eq!(symmetric.egress(), socket_addr);

        let asymmetric = Address::Asymmetric {
            ingress: Ingress::Dns {
                host: hostname!("example.com"),
                port: 8080,
            },
            egress: egress_addr,
        };
        assert_eq!(
            asymmetric.ingress(),
            Ingress::Dns {
                host: hostname!("example.com"),
                port: 8080
            }
        );
        assert_eq!(
            asymmetric.egress_ip(),
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))
        );
        assert_eq!(asymmetric.egress(), egress_addr);
    }

    #[test]
    fn test_from_socket_addr() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);

        let ingress: Ingress = addr.into();
        assert_eq!(ingress, Ingress::Socket(addr));

        let address: Address = addr.into();
        assert_eq!(address, Address::Symmetric(addr));
    }

    #[test]
    fn test_ingress_is_valid() {
        let public_socket =
            Ingress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 8080));
        let private_socket = Ingress::Socket(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            8080,
        ));
        let dns = Ingress::Dns {
            host: hostname!("example.com"),
            port: 8080,
        };

        // A public socket is accepted under every policy.
        assert!(public_socket.is_valid(false, false));
        assert!(public_socket.is_valid(false, true));
        assert!(public_socket.is_valid(true, false));
        assert!(public_socket.is_valid(true, true));

        // A private socket is accepted only when private IPs are allowed;
        // the DNS flag is irrelevant.
        assert!(!private_socket.is_valid(false, false));
        assert!(!private_socket.is_valid(false, true));
        assert!(private_socket.is_valid(true, false));
        assert!(private_socket.is_valid(true, true));

        // A DNS ingress is accepted only when DNS is allowed; the private-IP
        // flag is irrelevant.
        assert!(!dns.is_valid(false, false));
        assert!(dns.is_valid(false, true));
        assert!(!dns.is_valid(true, false));
        assert!(dns.is_valid(true, true));
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Ingress>,
            CodecConformance<Address>,
        }
    }
}