use core::net::{IpAddr, SocketAddr};
use crate::{Error, InternalErrorKind, netstack};
/// Transport-layer protocol selected by a network name passed to `parse_network`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Transport {
    /// Stream transport, selected by `"tcp"`, `"tcp4"` or `"tcp6"`.
    Tcp,
    /// Datagram transport, selected by `"udp"`, `"udp4"` or `"udp6"`.
    Udp,
}
/// IP address family restriction encoded in a network name's suffix.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Family {
    /// No suffix (`"tcp"`, `"udp"`): both IPv4 and IPv6 are acceptable.
    Any,
    /// Suffix `4` (`"tcp4"`, `"udp4"`): IPv4 only.
    V4,
    /// Suffix `6` (`"tcp6"`, `"udp6"`): IPv6 only.
    V6,
}
/// A parsed network name: which transport to use and which address family
/// it is restricted to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct Network {
    /// TCP or UDP.
    pub transport: Transport,
    /// Address-family restriction (any, IPv4-only, IPv6-only).
    pub family: Family,
}
pub(crate) fn parse_network(network: &str) -> Result<Network, Error> {
let n = match network {
"tcp" => Network {
transport: Transport::Tcp,
family: Family::Any,
},
"tcp4" => Network {
transport: Transport::Tcp,
family: Family::V4,
},
"tcp6" => Network {
transport: Transport::Tcp,
family: Family::V6,
},
"udp" => Network {
transport: Transport::Udp,
family: Family::Any,
},
"udp4" => Network {
transport: Transport::Udp,
family: Family::V4,
},
"udp6" => Network {
transport: Transport::Udp,
family: Family::V6,
},
_ => return Err(Error::Internal(InternalErrorKind::BadRequest)),
};
Ok(n)
}
/// Splits a `host:port` or `[host]:port` address into its host and numeric port.
///
/// The syntax follows Go's `net.SplitHostPort`:
/// - A host that itself contains `:` (an IPv6 literal) must be wrapped in
///   square brackets.
/// - Brackets may not appear anywhere else.
///
/// The returned host borrows from `addr` and is not otherwise validated; it
/// may be empty, e.g. for `":80"`.
///
/// Unlike Go, the port must be a plain decimal number in `0..=65535`. Service
/// names such as `"http"` are not resolved, and sign characters are rejected.
///
/// # Errors
///
/// Returns `Error::Internal(InternalErrorKind::BadRequest)` when any of these hold:
/// - the port is missing, empty, non-numeric or out of range;
/// - an unbracketed host contains `:`;
/// - `[` or `]` appear outside a single leading bracket pair.
pub(crate) fn split_host_port(addr: &str) -> Result<(&str, u16), Error> {
    let bad = || Error::Internal(InternalErrorKind::BadRequest);

    let (host, port_str) = if let Some(rest) = addr.strip_prefix('[') {
        // Bracketed form: the host runs up to the first ']', which must be
        // followed immediately by ':' and the port.
        let (host, after) = rest.split_once(']').ok_or_else(bad)?;
        let port_str = after.strip_prefix(':').ok_or_else(bad)?;
        (host, port_str)
    } else {
        // Unbracketed form: split on the last ':'. A remaining ':' in the host
        // means an unbracketed IPv6 literal, which is ambiguous.
        let (host, port_str) = addr.rsplit_once(':').ok_or_else(bad)?;
        if host.contains(':') {
            return Err(bad());
        }
        (host, port_str)
    };

    // Stray brackets (e.g. "a]:80", "a[b:80", "[a[b]:80") are malformed,
    // matching Go's SplitHostPort.
    if host.contains(|c| matches!(c, '[' | ']')) {
        return Err(bad());
    }

    // `u16::from_str` accepts a leading '+', so insist on ASCII digits only.
    // This also rejects the empty port.
    if port_str.is_empty() || !port_str.bytes().all(|b| b.is_ascii_digit()) {
        return Err(bad());
    }
    let port: u16 = port_str.parse().map_err(|_| bad())?;
    Ok((host, port))
}
/// A UDP socket whose traffic is pinned to a single remote address.
///
/// [`ConnectedUdpSocket::send`] always targets `peer`.
/// [`ConnectedUdpSocket::recv`] silently discards datagrams whose source
/// address differs from `peer`.
pub struct ConnectedUdpSocket {
    /// Underlying netstack socket used for both sending and receiving.
    sock: netstack::UdpSocket,
    /// The only remote address this socket exchanges datagrams with.
    peer: SocketAddr,
}
impl ConnectedUdpSocket {
pub(crate) fn new(sock: netstack::UdpSocket, peer: SocketAddr) -> Self {
Self { sock, peer }
}
pub fn peer(&self) -> SocketAddr {
self.peer
}
pub fn local_addr(&self) -> SocketAddr {
self.sock.local_addr()
}
pub async fn send(&self, data: &[u8]) -> Result<(), Error> {
self.sock.send_to(self.peer, data).await.map_err(Into::into)
}
pub async fn recv(&self, buf: &mut [u8]) -> Result<usize, Error> {
loop {
let (from, n) = self.sock.recv_from(buf).await?;
if from == self.peer {
return Ok(n);
}
}
}
}
/// A dialed connection, tagged by transport.
pub enum DialConn {
    /// A TCP stream from the netstack.
    Tcp(netstack::TcpStream),
    /// A UDP socket restricted to a single peer; see [`ConnectedUdpSocket`].
    Udp(ConnectedUdpSocket),
}
/// Verifies that `ip` belongs to the address family requested by `family`.
///
/// [`Family::Any`] admits every address. [`Family::V4`] admits only
/// `IpAddr::V4` and [`Family::V6`] only `IpAddr::V6`. The check is on the
/// `IpAddr` variant, so an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`)
/// counts as IPv6.
///
/// # Errors
///
/// Returns `Error::Internal(InternalErrorKind::BadRequest)` on a family mismatch.
pub(crate) fn check_family(family: Family, ip: IpAddr) -> Result<(), Error> {
    match (family, ip) {
        (Family::Any, _) | (Family::V4, IpAddr::V4(_)) | (Family::V6, IpAddr::V6(_)) => Ok(()),
        (Family::V4, IpAddr::V6(_)) | (Family::V6, IpAddr::V4(_)) => {
            Err(Error::Internal(InternalErrorKind::BadRequest))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_network_accepts_the_tsnet_set() {
        let tcp = parse_network("tcp").unwrap();
        assert_eq!(tcp.transport, Transport::Tcp);
        let udp = parse_network("udp").unwrap();
        assert_eq!(udp.transport, Transport::Udp);

        let family_cases = [
            ("tcp4", Family::V4),
            ("tcp6", Family::V6),
            ("udp4", Family::V4),
            ("udp6", Family::V6),
        ];
        for (name, expected) in family_cases {
            assert_eq!(parse_network(name).unwrap().family, expected, "{name:?}");
        }
    }

    #[test]
    fn parse_network_rejects_unsupported() {
        let unsupported = ["", "sctp", "ip", "tcp5", "unix", "TCP"];
        for name in unsupported {
            assert!(parse_network(name).is_err(), "{name:?} must be rejected");
        }
    }

    #[test]
    fn split_host_port_ipv4() {
        let (host, port) = split_host_port("1.2.3.4:80").unwrap();
        assert_eq!(host, "1.2.3.4");
        assert_eq!(port, 80);
    }

    #[test]
    fn split_host_port_ipv6_bracketed() {
        let (host, port) = split_host_port("[2001:db8::1]:443").unwrap();
        assert_eq!(host, "2001:db8::1");
        assert_eq!(port, 443);
    }

    #[test]
    fn split_host_port_name() {
        let cases = [("myhost:22", "myhost", 22), ("host.tail.ts.net:8080", "host.tail.ts.net", 8080)];
        for (addr, host, port) in cases {
            assert_eq!(split_host_port(addr).unwrap(), (host, port), "{addr:?}");
        }
    }

    #[test]
    fn split_host_port_rejects_missing_port() {
        for addr in ["myhost", "1.2.3.4", "host:"] {
            assert!(split_host_port(addr).is_err(), "{addr:?} must be rejected");
        }
    }

    #[test]
    fn split_host_port_rejects_bare_ipv6() {
        assert!(split_host_port("2001:db8::1:443").is_err());
    }

    #[test]
    fn split_host_port_rejects_bad_port() {
        for addr in ["host:99999", "host:http", "host:-1"] {
            assert!(split_host_port(addr).is_err(), "{addr:?} must be rejected");
        }
    }

    #[test]
    fn check_family_matches() {
        let v4: IpAddr = "1.2.3.4".parse().unwrap();
        let v6: IpAddr = "2001:db8::1".parse().unwrap();

        let accepted = [(Family::Any, v4), (Family::V4, v4), (Family::V6, v6)];
        for (family, ip) in accepted {
            assert!(check_family(family, ip).is_ok(), "{family:?} should accept {ip}");
        }

        let rejected = [(Family::V4, v6), (Family::V6, v4)];
        for (family, ip) in rejected {
            assert!(check_family(family, ip).is_err(), "{family:?} should reject {ip}");
        }
    }
}