use parking_lot::RwLock;
use socket2::{Domain, Protocol, Type};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::str::FromStr;
use crate::device::{AllowedIps, Error};
use crate::noise::{Tunn, TunnResult};
/// Remote endpoint state for a peer: the address to talk to and, once
/// [`Peer::connect_endpoint`] has succeeded, a socket connected to it.
#[derive(Default, Debug)]
pub struct Endpoint {
/// The peer's remote address, if known.
pub addr: Option<SocketAddr>,
/// Handle to a socket connected to `addr`, if one has been opened.
/// Cleared when the socket is shut down or when `addr` changes.
pub conn: Option<socket2::Socket>,
}
/// A remote peer: its tunnel session, index, endpoint, allowed-IP table and
/// optional preshared key.
pub struct Peer {
/// Tunnel session state for this peer (see [`Tunn`]).
pub(crate) tunnel: Tunn,
/// Index supplied at construction; exposed via [`Peer::index`].
index: u32,
/// Remote address and connected socket. Kept behind a lock so it can be
/// updated through `&self` (see [`Peer::set_endpoint`]).
endpoint: RwLock<Endpoint>,
/// Allowed-IP table built from the list passed to [`Peer::new`]; entries
/// carry no payload (`()`), only presence matters.
allowed_ips: AllowedIps<()>,
/// Optional 32-byte preshared key; exposed via [`Peer::preshared_key`].
preshared_key: Option<[u8; 32]>,
}
/// One allowed-IP entry: a network address plus a CIDR prefix length.
///
/// Parsed from text of the form `"<addr>/<cidr>"`, e.g. `"10.0.0.0/24"` or
/// `"fd00::/64"`.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub struct AllowedIP {
    /// Network address (IPv4 or IPv6).
    pub addr: IpAddr,
    /// Prefix length. When produced by `FromStr` it is at most 32 for IPv4
    /// and at most 128 for IPv6.
    pub cidr: u8,
}

impl FromStr for AllowedIP {
    type Err = String;

    /// Parses `"<addr>/<cidr>"`.
    ///
    /// Every failure yields the same `"Invalid IP format"` message. The input
    /// must contain exactly one `/`, the left side must parse as an IP address,
    /// and the right side must be a `u8` no larger than the address family's
    /// width (32 bits for IPv4, 128 for IPv6). No whitespace trimming is done.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let invalid = || "Invalid IP format".to_owned();

        // A second '/' ends up in `cidr_text`, where the u8 parse rejects it.
        let (addr_text, cidr_text) = s.split_once('/').ok_or_else(invalid)?;
        let addr: IpAddr = addr_text.parse().map_err(|_| invalid())?;
        let cidr: u8 = cidr_text.parse().map_err(|_| invalid())?;

        let widest_prefix = match addr {
            IpAddr::V4(_) => 32,
            IpAddr::V6(_) => 128,
        };
        if cidr > widest_prefix {
            return Err(invalid());
        }

        Ok(AllowedIP { addr, cidr })
    }
}
impl Peer {
pub fn new(
tunnel: Tunn,
index: u32,
endpoint: Option<SocketAddr>,
allowed_ips: &[AllowedIP],
preshared_key: Option<[u8; 32]>,
) -> Peer {
Peer {
tunnel,
index,
endpoint: RwLock::new(Endpoint {
addr: endpoint,
conn: None,
}),
allowed_ips: allowed_ips.iter().map(|ip| (ip, ())).collect(),
preshared_key,
}
}
pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
self.tunnel.update_timers(dst)
}
pub fn endpoint(&self) -> parking_lot::RwLockReadGuard<'_, Endpoint> {
self.endpoint.read()
}
pub(crate) fn endpoint_mut(&self) -> parking_lot::RwLockWriteGuard<'_, Endpoint> {
self.endpoint.write()
}
pub fn shutdown_endpoint(&self) {
if let Some(conn) = self.endpoint.write().conn.take() {
tracing::info!("Disconnecting from endpoint");
conn.shutdown(Shutdown::Both).unwrap();
}
}
pub fn set_endpoint(&self, addr: SocketAddr) {
let mut endpoint = self.endpoint.write();
if endpoint.addr != Some(addr) {
if let Some(conn) = endpoint.conn.take() {
conn.shutdown(Shutdown::Both).unwrap();
}
endpoint.addr = Some(addr);
}
}
pub fn connect_endpoint(
&self,
port: u16,
fwmark: Option<u32>,
) -> Result<socket2::Socket, Error> {
let mut endpoint = self.endpoint.write();
if endpoint.conn.is_some() {
return Err(Error::Connect("Connected".to_owned()));
}
let addr = endpoint
.addr
.expect("Attempt to connect to undefined endpoint");
let udp_conn =
socket2::Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
udp_conn.set_reuse_address(true)?;
let bind_addr = if addr.is_ipv4() {
SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into()
} else {
SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into()
};
udp_conn.bind(&bind_addr)?;
udp_conn.connect(&addr.into())?;
udp_conn.set_nonblocking(true)?;
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(fwmark) = fwmark {
udp_conn.set_mark(fwmark)?;
}
tracing::info!(
message="Connected endpoint",
port=port,
endpoint=?endpoint.addr.unwrap()
);
endpoint.conn = Some(udp_conn.try_clone().unwrap());
Ok(udp_conn)
}
pub fn is_allowed_ip<I: Into<IpAddr>>(&self, addr: I) -> bool {
self.allowed_ips.find(addr.into()).is_some()
}
pub fn allowed_ips(&self) -> impl Iterator<Item = (IpAddr, u8)> + '_ {
self.allowed_ips.iter().map(|(_, ip, cidr)| (ip, cidr))
}
pub fn time_since_last_handshake(&self) -> Option<std::time::Duration> {
self.tunnel.time_since_last_handshake()
}
pub fn persistent_keepalive(&self) -> Option<u16> {
self.tunnel.persistent_keepalive()
}
pub fn preshared_key(&self) -> Option<&[u8; 32]> {
self.preshared_key.as_ref()
}
pub fn index(&self) -> u32 {
self.index
}
}