use crate::config::Config;
use crate::error;
use crate::error::Error;
use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr, ToSocketAddrs, UdpSocket};
use std::time::Instant;
use std::{cmp, fmt};
/// Normalises a socket address into the address family of another address.
trait SocketAddrExt {
    /// Returns `self` rewritten into the address family of `target_family` when a lossless
    /// conversion exists, otherwise `self` unchanged.
    ///
    /// - Target is IPv4, `self` is an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`):
    ///   returns the plain IPv4 address with the same port.
    /// - Target is IPv6, `self` is IPv4: returns the IPv4-mapped IPv6 address with the same port.
    /// - Anything else (same family, or a genuine IPv6 address such as `::1` with an IPv4
    ///   target) is returned as-is.
    fn canonical(&self, target_family: &Self) -> Self;
}
impl SocketAddrExt for SocketAddr {
    fn canonical(&self, target_family: &Self) -> Self {
        let port = self.port();
        match (target_family, self.ip()) {
            // Only unwrap true IPv4-mapped addresses. `Ipv6Addr::to_ipv4` would also accept
            // IPv4-compatible addresses and turn e.g. `::1` into `0.0.0.1`, a different host.
            (SocketAddr::V4(_), IpAddr::V6(address_v6)) => match address_v6.to_ipv4_mapped() {
                Some(canonical_v4) => SocketAddr::new(IpAddr::V4(canonical_v4), port),
                None => *self,
            },
            // Mirror image of the branch above: embed IPv4 as `::ffff:a.b.c.d`.
            (SocketAddr::V6(_), IpAddr::V4(address_v4)) => {
                SocketAddr::new(IpAddr::V6(address_v4.to_ipv6_mapped()), port)
            }
            _ => *self,
        }
    }
}
/// A relay session between one client and the server, forwarding datagrams over a borrowed
/// UDP socket.
#[derive(Debug)]
pub struct Session<'a> {
    /// Socket used to send datagrams in both directions.
    socket: &'a UdpSocket,
    /// Client peer address, normalised to the listen address family by the constructor.
    client_address: SocketAddr,
    /// Server peer address, normalised to the listen address family by the constructor.
    server_address: SocketAddr,
    /// When a packet was last forwarded client -> server (initially the creation time).
    last_uplink: Instant,
    /// When a packet was last forwarded server -> client (initially the creation time).
    last_downlink: Instant,
}
impl<'a> Session<'a> {
pub fn new(client_address: &SocketAddr, config: &Config, socket: &'a UdpSocket) -> Result<Self, Error> {
let mut server_addresses = (config.WGPROXY_SERVER.to_socket_addrs())
.map_err(|e| error!(with: e, "Failed to resolve server address"))?;
let server_address = server_addresses.next().ok_or(error!("Failed to resolve server address"))?;
let server_address = server_address.canonical(&config.WGPROXY_LISTEN);
let client_address = client_address.canonical(&config.WGPROXY_LISTEN);
let last_uplink = Instant::now();
let last_downlink = Instant::now();
Ok(Self { socket, client_address, server_address, last_uplink, last_downlink })
}
pub fn forward(&mut self, packet: &[u8], source: &SocketAddr) -> Result<(), Error> {
if self.client_address.eq(source) {
self.socket.send_to(packet, &self.server_address)?;
self.last_uplink = Instant::now();
Ok(())
} else if self.server_address.eq(source) {
self.socket.send_to(packet, &self.client_address)?;
self.last_downlink = Instant::now();
Ok(())
} else {
Err(error!("Unknown packet from {source}"))
}
}
pub fn atime(&self) -> Instant {
cmp::min(self.last_uplink, self.last_downlink)
}
}
impl Display for Session<'_> {
    /// Writes the session in `Session { .. }` struct form. The output lists:
    /// - the socket's local address (`None` if it cannot be queried);
    /// - both peer addresses;
    /// - the time elapsed since the last forwarded packet in each direction.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        // Gather every value up front, before any output is written.
        let (local, uplink_age, downlink_age) = (
            self.socket.local_addr().ok(),
            self.last_uplink.elapsed(),
            self.last_downlink.elapsed(),
        );
        let mut builder = f.debug_struct("Session");
        builder.field("socket", &local);
        builder.field("client_address", &self.client_address);
        builder.field("server_address", &self.server_address);
        builder.field("last_uplink", &uplink_age);
        builder.field("last_downlink", &downlink_age);
        builder.finish()
    }
}