use std::{
pin::Pin,
task::{Context, Poll},
};
use crate::{
multiaddr::{Multiaddr, Protocol},
transport::{DialOpts, ListenerId, TransportError, TransportEvent},
};
/// Transport wrapper whose `dial` only forwards addresses that start with a
/// globally reachable IPv4 or IPv6 address; every other operation is passed
/// straight through to the wrapped transport.
#[derive(Debug, Clone, Default)]
pub struct Transport<T> {
/// The wrapped transport that accepted calls are delegated to.
inner: T,
}
/// IPv4 global-reachability predicate, hand-rolled on top of the stable
/// `std::net::Ipv4Addr` accessors.
mod ipv4_global {
    use std::net::Ipv4Addr;

    /// `240.0.0.0/4` (first octet has its four high bits set), excluding the
    /// broadcast address `255.255.255.255`.
    #[must_use]
    #[inline]
    const fn is_reserved(a: Ipv4Addr) -> bool {
        a.octets()[0] & 240 == 240 && !a.is_broadcast()
    }

    /// `198.18.0.0/15`: second octet is 18 or 19 once the low bit is masked off.
    #[must_use]
    #[inline]
    const fn is_benchmarking(a: Ipv4Addr) -> bool {
        a.octets()[0] == 198 && (a.octets()[1] & 0xfe) == 18
    }

    /// `100.64.0.0/10`: the two high bits of the second octet are `01`.
    #[must_use]
    #[inline]
    const fn is_shared(a: Ipv4Addr) -> bool {
        a.octets()[0] == 100 && (a.octets()[1] & 0b1100_0000 == 0b0100_0000)
    }

    /// `10.0.0.0/8`, `172.16.0.0/12` or `192.168.0.0/16`.
    #[must_use]
    #[inline]
    const fn is_private(a: Ipv4Addr) -> bool {
        matches!(a.octets(), [10, ..] | [172, 16..=31, ..] | [192, 168, ..])
    }

    /// `192.0.0.0/24`, except `192.0.0.9` and `192.0.0.10`.
    ///
    /// Those two anycast addresses are registered as globally reachable and are
    /// the IPv4 counterparts of `2001:1::1` / `2001:1::2`, which the IPv6 check
    /// in this file already lets through.
    #[must_use]
    #[inline]
    const fn is_protocol_assignment(a: Ipv4Addr) -> bool {
        matches!(a.octets(), [192, 0, 0, last] if last != 9 && last != 10)
    }

    /// Returns `true` when `a` falls in none of the non-global ranges checked
    /// below.
    ///
    /// Rejected: `0.0.0.0/8`, private, shared (`100.64.0.0/10`), loopback,
    /// link-local, `192.0.0.0/24` (minus the two anycast exceptions),
    /// documentation, benchmarking, reserved (`240.0.0.0/4`) and broadcast
    /// addresses.
    #[must_use]
    #[inline]
    pub(crate) const fn is_global(a: Ipv4Addr) -> bool {
        !(a.octets()[0] == 0
            || is_private(a)
            || is_shared(a)
            || a.is_loopback()
            || a.is_link_local()
            || is_protocol_assignment(a)
            || a.is_documentation()
            || is_benchmarking(a)
            || is_reserved(a)
            || a.is_broadcast())
    }
}
/// IPv6 global-reachability predicate, hand-rolled on top of the stable
/// `std::net::Ipv6Addr` accessors.
mod ipv6_global {
    use std::net::Ipv6Addr;

    /// `2001:1::1`, exempted from the `2001::/23` rejection.
    const ANYCAST_2001_1_1: u128 = 0x2001_0001_0000_0000_0000_0000_0000_0001;
    /// `2001:1::2`, exempted from the `2001::/23` rejection.
    const ANYCAST_2001_1_2: u128 = 0x2001_0001_0000_0000_0000_0000_0000_0002;

    /// `fe80::/10`.
    #[must_use]
    #[inline]
    const fn is_unicast_link_local(a: Ipv6Addr) -> bool {
        matches!(a.segments(), [0xfe80..=0xfebf, ..])
    }

    /// `fc00::/7`.
    #[must_use]
    #[inline]
    const fn is_unique_local(a: Ipv6Addr) -> bool {
        matches!(a.segments(), [0xfc00..=0xfdff, ..])
    }

    /// `2001:db8::/32`.
    #[must_use]
    #[inline]
    const fn is_documentation(a: Ipv6Addr) -> bool {
        matches!(a.segments(), [0x2001, 0xdb8, ..])
    }

    /// Addresses inside `2001::/23` that are nevertheless accepted:
    /// `2001:1::1`, `2001:1::2`, `2001:3::/32`, `2001:4:112::/48` and
    /// `2001:20::/28`.
    #[must_use]
    #[inline]
    const fn is_exempt_protocol_assignment(a: Ipv6Addr) -> bool {
        let bits = u128::from_be_bytes(a.octets());
        bits == ANYCAST_2001_1_1
            || bits == ANYCAST_2001_1_2
            || matches!(
                a.segments(),
                [0x2001, 3, ..] | [0x2001, 4, 0x112, ..] | [0x2001, 0x20..=0x2f, ..]
            )
    }

    /// Returns `true` when `a` falls in none of the non-global ranges checked
    /// below.
    ///
    /// Rejected: the unspecified and loopback addresses, `::ffff:0:0/96`,
    /// `64:ff9b:1::/48`, `100::/64`, `2001::/23` (minus the exemptions listed on
    /// `is_exempt_protocol_assignment`), `2001:db8::/32`, `fc00::/7` and
    /// `fe80::/10`.
    #[must_use]
    #[inline]
    pub(crate) const fn is_global(a: Ipv6Addr) -> bool {
        if a.is_unspecified() || a.is_loopback() {
            return false;
        }
        // The arms are keyed on distinct leading segments, so their order is
        // irrelevant.
        let in_special_block = match a.segments() {
            [0, 0, 0, 0, 0, 0xffff, _, _] => true,
            [0x64, 0xff9b, 1, ..] => true,
            [0x100, 0, 0, 0, ..] => true,
            [0x2001, 0..=0x1ff, ..] => !is_exempt_protocol_assignment(a),
            _ => false,
        };
        !(in_special_block
            || is_documentation(a)
            || is_unique_local(a)
            || is_unicast_link_local(a))
    }
}
impl<T> Transport<T> {
pub fn new(transport: T) -> Self {
Transport { inner: transport }
}
}
/// Delegates to the wrapped transport; only `dial` adds behaviour, refusing
/// addresses whose first protocol is not a globally reachable IP address.
impl<T: crate::Transport + Unpin> crate::Transport for Transport<T> {
    type Output = <T as crate::Transport>::Output;
    type Error = <T as crate::Transport>::Error;
    type ListenerUpgrade = <T as crate::Transport>::ListenerUpgrade;
    type Dial = <T as crate::Transport>::Dial;

    /// Listening is not filtered; the call is forwarded unchanged.
    fn listen_on(
        &mut self,
        id: ListenerId,
        addr: Multiaddr,
    ) -> Result<(), TransportError<Self::Error>> {
        self.inner.listen_on(id, addr)
    }

    /// Forwarded unchanged to the wrapped transport.
    fn remove_listener(&mut self, id: ListenerId) -> bool {
        self.inner.remove_listener(id)
    }

    /// Dials `addr` through the wrapped transport if its first protocol is a
    /// global IPv4 or IPv6 address.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::MultiaddrNotSupported(addr)` when the first
    /// protocol is a non-global IP address, is not an IP address at all, or the
    /// address is empty. Otherwise returns whatever the wrapped transport's
    /// `dial` returns.
    fn dial(
        &mut self,
        addr: Multiaddr,
        opts: DialOpts,
    ) -> Result<Self::Dial, TransportError<Self::Error>> {
        // Decide first (logging the reason for a refusal), then either dial or
        // reject exactly once below.
        let dialable = match addr.iter().next() {
            Some(Protocol::Ip4(ip)) => {
                let global = ipv4_global::is_global(ip);
                if !global {
                    tracing::debug!(ip=%ip, "Not dialing non global IP address");
                }
                global
            }
            Some(Protocol::Ip6(ip)) => {
                let global = ipv6_global::is_global(ip);
                if !global {
                    tracing::debug!(ip=%ip, "Not dialing non global IP address");
                }
                global
            }
            _ => {
                tracing::debug!(address=%addr, "Not dialing unsupported Multiaddress");
                false
            }
        };

        if dialable {
            self.inner.dial(addr, opts)
        } else {
            Err(TransportError::MultiaddrNotSupported(addr))
        }
    }

    /// Polls the wrapped transport and returns its result as-is.
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
        // `T: Unpin` makes `Self: Unpin`, so the pin can be unwrapped and
        // re-created around the inner transport.
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll(cx)
    }
}