use bytes::BytesMut;
use duplicate::duplicate_item;
use std::{
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
sync::{Arc, atomic::AtomicU16},
};
use tokio::sync::{Mutex, OwnedMutexGuard, mpsc};
use zerocopy::{FromBytes, IntoBytes};
use crate::{
packet::{
IpNextProtocol, Ipv4, Ipv4Header, Ipv6, Ipv6Header, Packet, PacketBufPool, PseudoHeaderV4,
PseudoHeaderV6, Udp, UdpHeader, checksum, checksum_udp,
},
tun::{
MtuWatcher,
channel::{Ipv4Fragments, TunChannelRx, TunChannelTx},
},
udp::{UdpRecv, UdpSend, UdpTransportFactory},
};
use super::UdpTransportFactoryParams;
/// Sending half of a UDP "socket" bound on top of the in-process TUN channel.
///
/// Every clone shares the same underlying channels and the same source addressing.
#[derive(Clone)]
pub struct UdpChannelTx {
    /// UDP source port stamped on every datagram sent through this handle.
    source_port: u16,
    /// Per-bind identifier; written into the flow label of outgoing IPv6 packets.
    connection_id: u32,
    /// Source address of outgoing IPv4 packets.
    source_ip_v4: Ipv4Addr,
    /// Queue of finished IPv4/UDP packets headed for the TUN side.
    udp_tx_v4: mpsc::Sender<Packet<Ipv4<Udp>>>,
    /// Source address of outgoing IPv6 packets.
    source_ip_v6: Ipv6Addr,
    /// Queue of finished IPv6/UDP packets headed for the TUN side.
    udp_tx_v6: mpsc::Sender<Packet<Ipv6<Udp>>>,
}
/// Receiving end of the channel that carries IPv4/UDP packets from the TUN side.
type Ipv4UdpReceiver = mpsc::Receiver<Packet<Ipv4<Udp>>>;
/// Receiving end of the channel that carries IPv6/UDP packets from the TUN side.
type Ipv6UdpReceiver = mpsc::Receiver<Packet<Ipv6<Udp>>>;

/// IPv4 receiving half of a bound UDP channel.
///
/// Owns the factory's IPv4 receiver lock for as long as it lives, so a later `bind` waits
/// until this value is dropped.
pub struct UdpChannelV4Rx {
    udp_rx_v4: OwnedMutexGuard<Ipv4UdpReceiver>,
}

/// IPv6 receiving half of a bound UDP channel.
///
/// Owns the factory's IPv6 receiver lock for as long as it lives, so a later `bind` waits
/// until this value is dropped.
pub struct UdpChannelV6Rx {
    udp_rx_v6: OwnedMutexGuard<Ipv6UdpReceiver>,
}
/// Creates UDP "sockets" that exchange packets with the TUN side of an in-process channel.
///
/// Senders are cloned for every bind. Receivers sit behind shared mutexes, so only one bound
/// socket at a time can receive.
pub struct UdpChannelFactory {
    /// Source address for IPv4 packets sent by bound sockets.
    source_ip_v4: Ipv4Addr,
    /// Outgoing IPv4/UDP packets toward the TUN side.
    udp_tx_v4: mpsc::Sender<Packet<Ipv4<Udp>>>,
    /// Incoming IPv4/UDP packets from the TUN side; locked by each bind.
    udp_rx_v4: Arc<Mutex<Ipv4UdpReceiver>>,
    /// Source address for IPv6 packets sent by bound sockets.
    source_ip_v6: Ipv6Addr,
    /// Outgoing IPv6/UDP packets toward the TUN side.
    udp_tx_v6: mpsc::Sender<Packet<Ipv6<Udp>>>,
    /// Incoming IPv6/UDP packets from the TUN side; locked by each bind.
    udp_rx_v6: Arc<Mutex<Ipv6UdpReceiver>>,
}
pub fn new_udp_tun_channel(
capacity: usize,
source_ip_v4: Ipv4Addr,
source_ip_v6: Ipv6Addr,
tun_link_mtu: MtuWatcher,
) -> (TunChannelTx, TunChannelRx, UdpChannelFactory) {
let (udp_v4, tun_v4) = UdpChannelV4::new_pair(capacity);
let (udp_v6, tun_v6) = UdpChannelV6::new_pair(capacity);
let tun_tx = TunChannelTx {
tun_tx_v4: tun_v4.tx,
tun_tx_v6: tun_v6.tx,
fragments_v4: Ipv4Fragments::default(),
};
let tun_rx = TunChannelRx {
tun_rx_v4: tun_v4.rx,
tun_rx_v6: tun_v6.rx,
mtu: tun_link_mtu,
};
let udp = UdpChannelFactory::new(source_ip_v4, udp_v4, source_ip_v6, udp_v6);
(tun_tx, tun_rx, udp)
}
/// One endpoint of a bidirectional IPv4/UDP packet channel (see `new_pair`).
pub(crate) struct UdpChannelV4 {
    /// Delivers packets to the peer endpoint's `rx`.
    pub tx: mpsc::Sender<Packet<Ipv4<Udp>>>,
    /// Yields packets sent by the peer endpoint's `tx`.
    pub rx: mpsc::Receiver<Packet<Ipv4<Udp>>>,
}

/// One endpoint of a bidirectional IPv6/UDP packet channel (see `new_pair`).
pub(crate) struct UdpChannelV6 {
    /// Delivers packets to the peer endpoint's `rx`.
    pub tx: mpsc::Sender<Packet<Ipv6<Udp>>>,
    /// Yields packets sent by the peer endpoint's `tx`.
    pub rx: mpsc::Receiver<Packet<Ipv6<Udp>>>,
}
#[duplicate_item(
    UdpChannel;
    [UdpChannelV4];
    [UdpChannelV6];
)]
impl UdpChannel {
    /// Create two cross-connected endpoints.
    ///
    /// Whatever is sent on one endpoint's `tx` comes out of the other endpoint's `rx`, and vice
    /// versa. Each direction is a separate channel bounded to `capacity` packets.
    pub(crate) fn new_pair(capacity: usize) -> (Self, Self) {
        let (to_second, from_first) = mpsc::channel(capacity);
        let (to_first, from_second) = mpsc::channel(capacity);

        let first = Self {
            tx: to_second,
            rx: from_second,
        };
        let second = Self {
            tx: to_first,
            rx: from_first,
        };
        (first, second)
    }
}
impl UdpChannelFactory {
pub(crate) fn new(
v4_src: Ipv4Addr,
v4: UdpChannelV4,
v6_src: Ipv6Addr,
v6: UdpChannelV6,
) -> Self {
Self {
source_ip_v4: v4_src,
source_ip_v6: v6_src,
udp_tx_v4: v4.tx,
udp_tx_v6: v6.tx,
udp_rx_v4: Arc::new(Mutex::new(v4.rx)),
udp_rx_v6: Arc::new(Mutex::new(v6.rx)),
}
}
}
impl UdpTransportFactory for UdpChannelFactory {
    type SendV4 = UdpChannelTx;
    type SendV6 = UdpChannelTx;
    type RecvV4 = UdpChannelV4Rx;
    type RecvV6 = UdpChannelV6Rx;

    /// Bind a new UDP "socket" on top of the TUN channel.
    ///
    /// A `params.port` of 0 selects a random non-zero source port. Each bind also draws a random,
    /// non-zero connection id, which is written into the flow label of outgoing IPv6 packets.
    ///
    /// Both returned senders share the same addressing. The receivers take ownership of the
    /// factory's receiver locks, so this waits until the receivers of any previous bind have
    /// been dropped.
    async fn bind(
        &mut self,
        params: &UdpTransportFactoryParams,
    ) -> io::Result<((Self::SendV4, Self::RecvV4), (Self::SendV6, Self::RecvV6))> {
        // The connection id is only used as the IPv6 Flow Label, a 20-bit field (RFC 8200).
        // Drawing from the whole `u32` range would yield ids that do not fit that field, and
        // whose low 20 bits may be zero. Restrict the range to non-zero 20-bit values.
        const MAX_FLOW_LABEL: u32 = 0xF_FFFF;
        let connection_id = rand::random_range(1..=MAX_FLOW_LABEL);

        let source_port = match params.port {
            0 => rand::random_range(1..=u16::MAX),
            p => p,
        };
        let channel_tx = UdpChannelTx {
            source_ip_v4: self.source_ip_v4,
            source_ip_v6: self.source_ip_v6,
            source_port,
            connection_id,
            udp_tx_v4: self.udp_tx_v4.clone(),
            udp_tx_v6: self.udp_tx_v6.clone(),
        };
        let channel_rx_v4 = UdpChannelV4Rx {
            udp_rx_v4: self.udp_rx_v4.clone().lock_owned().await,
        };
        let channel_rx_v6 = UdpChannelV6Rx {
            udp_rx_v6: self.udp_rx_v6.clone().lock_owned().await,
        };
        Ok((
            (channel_tx.clone(), channel_rx_v4),
            (channel_tx, channel_rx_v6),
        ))
    }
}
impl UdpSend for UdpChannelTx {
type SendManyBuf = ();
async fn send_to(&self, udp_payload: Packet, destination: SocketAddr) -> io::Result<()> {
match destination {
SocketAddr::V4(dest) => {
let ipv4 = create_ipv4_payload(
self.source_ip_v4,
self.source_port,
*dest.ip(),
dest.port(),
&udp_payload,
);
self.udp_tx_v4.send(ipv4).await.expect("receiver exists");
}
SocketAddr::V6(dest) => {
let ipv6 = create_ipv6_payload(
&self.source_ip_v6,
self.source_port,
dest.ip(),
dest.port(),
&udp_payload,
self.connection_id,
);
self.udp_tx_v6.send(ipv6).await.expect("receiver exists");
}
};
Ok(())
}
}
impl UdpRecv for UdpChannelV4Rx {
    type RecvManyBuf = ();

    /// Wait for the next IPv4/UDP packet from the TUN side.
    ///
    /// Returns its UDP payload together with the sender's address and port. `_pool` is not
    /// used; the returned payload is taken out of the received packet.
    ///
    /// # Panics
    ///
    /// Panics if the channel reports that all senders (the TUN side) are gone.
    async fn recv_from(&mut self, _pool: &mut PacketBufPool) -> io::Result<(Packet, SocketAddr)> {
        let packet = self.udp_rx_v4.recv().await.expect("sender exists");
        let peer_ip = packet.header.source();
        let datagram = packet.into_payload();
        let peer_port = datagram.header.source_port.get();
        let peer = SocketAddr::from((peer_ip, peer_port));
        Ok((datagram.into_payload(), peer))
    }
}
impl UdpRecv for UdpChannelV6Rx {
    type RecvManyBuf = ();

    /// Wait for the next IPv6/UDP packet from the TUN side.
    ///
    /// Returns its UDP payload together with the sender's address and port. `_pool` is not
    /// used; the returned payload is taken out of the received packet.
    ///
    /// # Panics
    ///
    /// Panics if the channel reports that all senders (the TUN side) are gone.
    async fn recv_from(&mut self, _pool: &mut PacketBufPool) -> io::Result<(Packet, SocketAddr)> {
        let packet = self.udp_rx_v6.recv().await.expect("sender exists");
        let peer_ip = packet.header.source();
        let datagram = packet.into_payload();
        let peer_port = datagram.header.source_port.get();
        let peer = SocketAddr::from((peer_ip, peer_port));
        Ok((datagram.into_payload(), peer))
    }
}
/// Build an IPv4/UDP packet carrying `udp_payload`.
///
/// The packet is stamped with the next value of a process-wide IPv4 identification counter,
/// which starts at 1 and wraps around after `u16::MAX`.
fn create_ipv4_payload(
    source_ip: Ipv4Addr,
    source_port: u16,
    destination_ip: Ipv4Addr,
    destination_port: u16,
    udp_payload: &[u8],
) -> Packet<Ipv4<Udp>> {
    use std::sync::atomic::Ordering;

    // Function-local static: shared by every caller in the process, so consecutive packets get
    // distinct identification values.
    static NEXT_ID: AtomicU16 = AtomicU16::new(1);

    create_ipv4_payload_inner(
        source_ip,
        source_port,
        destination_ip,
        destination_port,
        udp_payload,
        NEXT_ID.fetch_add(1, Ordering::SeqCst),
    )
}
/// Build an IPv4 packet containing a UDP datagram with `udp_payload`.
///
/// Fills in the IPv4 header (protocol UDP, the given `identification`, recomputed header
/// checksum). It also fills in the UDP header: ports, length, and a checksum over the IPv4
/// pseudo-header.
///
/// # Panics
///
/// Panics if the resulting IPv4 packet would be larger than `u16::MAX` bytes, the most the
/// 16-bit Total Length field can describe.
fn create_ipv4_payload_inner(
    source_ip: Ipv4Addr,
    source_port: u16,
    destination_ip: Ipv4Addr,
    destination_port: u16,
    udp_payload: &[u8],
    identification: u16,
) -> Packet<Ipv4<Udp>> {
    // Check the size of the whole packet in `usize` before narrowing to `u16`. The IPv4 Total
    // Length field also covers the IP header. Validating only the UDP length and then adding
    // the header in `u16` could overflow: a panic in debug builds, a wrapped-around buffer size
    // in release builds.
    let udp_len_full = UdpHeader::LEN + udp_payload.len();
    let total_len = Ipv4Header::LEN + udp_len_full;
    assert!(
        total_len <= usize::from(u16::MAX),
        "UDP payload of {} bytes is too large for a single IPv4 packet",
        udp_payload.len(),
    );
    let udp_len =
        u16::try_from(udp_len_full).expect("UDP length is bounded by the checked total length");

    let mut packet = BytesMut::zeroed(total_len);
    let ipv4 = Ipv4::<Udp>::mut_from_bytes(&mut packet).expect("bad IP packet buffer");

    ipv4.header =
        Ipv4Header::new_for_length(source_ip, destination_ip, IpNextProtocol::Udp, udp_len);
    ipv4.header.identification = identification.into();
    // Recompute the header checksum now that `identification` has changed.
    // NOTE(review): RFC 791 computes this with the checksum field set to zero; this assumes
    // `new_for_length` leaves `header_checksum` zeroed. Confirm.
    ipv4.header.header_checksum = checksum(&[ipv4.header.as_bytes()]).into();

    let udp = &mut ipv4.payload;
    udp.header.source_port = source_port.into();
    udp.header.destination_port = destination_port.into();
    udp.header.length = udp_len.into();
    udp.payload.copy_from_slice(udp_payload);
    // The UDP checksum field is still zero here (the buffer was zero-initialised), as required
    // when computing the checksum.
    let csum = checksum_udp(
        PseudoHeaderV4::from_udp(
            source_ip.octets().into(),
            destination_ip.octets().into(),
            udp,
        ),
        udp.as_bytes(),
    );
    udp.header.checksum = csum.into();

    Packet::from_bytes(packet)
        .try_into_ip()
        .and_then(|p| p.try_into_ipvx())
        .expect("packet is valid")
        .expect_left("packet is ipv4")
        .try_into_udp()
        .expect("packet is udp")
}
/// Build an IPv6 packet containing a UDP datagram with `udp_payload`.
///
/// The IPv6 header gets version 6, hop limit 64, next header UDP and `connection_id` as its
/// flow label. How `set_flow_label` treats values wider than 20 bits is not visible here. The
/// UDP checksum is computed over the IPv6 pseudo-header.
///
/// # Panics
///
/// Panics if the UDP datagram (header plus payload) is larger than `u16::MAX` bytes, the most
/// the 16-bit UDP Length / IPv6 Payload Length fields can describe.
fn create_ipv6_payload(
    source_ip: &Ipv6Addr,
    source_port: u16,
    destination_ip: &Ipv6Addr,
    destination_port: u16,
    udp_payload: &[u8],
    connection_id: u32,
) -> Packet<Ipv6<Udp>> {
    let udp_len: u16 = (UdpHeader::LEN + udp_payload.len())
        .try_into()
        .expect("UDP payload too large: UDP length must fit in 16 bits");
    // Size the buffer in `usize`. The IPv6 header is not counted by any 16-bit length field, so
    // a valid datagram near the UDP size limit makes the total exceed `u16::MAX`. Adding in
    // `u16` would overflow.
    let total_len = Ipv6Header::LEN + usize::from(udp_len);
    let mut packet = BytesMut::zeroed(total_len);
    let ipv6 = Ipv6::<Udp>::mut_from_bytes(&mut packet).expect("bad IP packet buffer");

    ipv6.header.set_version(6);
    ipv6.header.set_flow_label(connection_id);
    ipv6.header.payload_length = udp_len.into();
    ipv6.header.next_header = IpNextProtocol::Udp;
    ipv6.header.source_address = source_ip.to_bits().into();
    ipv6.header.destination_address = destination_ip.to_bits().into();
    ipv6.header.hop_limit = 64;

    let udp = &mut ipv6.payload;
    udp.header.source_port = source_port.into();
    udp.header.destination_port = destination_port.into();
    udp.header.length = udp_len.into();
    udp.payload.copy_from_slice(udp_payload);
    // The UDP checksum field is still zero here (the buffer was zero-initialised), as required
    // when computing the checksum.
    let csum = checksum_udp(
        PseudoHeaderV6::from_udp(
            source_ip.octets().into(),
            destination_ip.octets().into(),
            udp,
        ),
        udp.as_bytes(),
    );
    udp.header.checksum = csum.into();

    Packet::from_bytes(packet)
        .try_into_ip()
        .and_then(|p| p.try_into_ipvx())
        .expect("packet is valid")
        .expect_right("packet is ipv6")
        .try_into_udp()
        .expect("packet is udp")
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        net::{Ipv4Addr, Ipv6Addr},
        str::FromStr,
    };

    // Endpoints shared by the IPv4 checksum tests.
    const V4_SRC: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 1);
    const V4_DST: Ipv4Addr = Ipv4Addr::new(192, 168, 1, 1);
    const SRC_PORT: u16 = 12345;
    const DST_PORT: u16 = 51820;

    /// UDP checksum of a freshly built IPv4/UDP packet carrying `payload`.
    fn ipv4_udp_checksum(payload: &[u8]) -> u16 {
        let packet = create_ipv4_payload(V4_SRC, SRC_PORT, V4_DST, DST_PORT, payload);
        packet.payload.header.checksum.get()
    }

    #[test]
    fn test_create_ipv4_payload_checksum() {
        assert_eq!(
            ipv4_udp_checksum(b"hello wireguard"),
            0xDEC6,
            "UDP checksum invalid"
        );
    }

    #[test]
    fn test_create_ipv6_payload_checksum() {
        let source = Ipv6Addr::from_str("fc00::1").unwrap();
        let destination = Ipv6Addr::from_str("2606:4700:4700::1111").unwrap();
        let packet = create_ipv6_payload(
            &source,
            SRC_PORT,
            &destination,
            DST_PORT,
            b"hello wireguard ipv6",
            0xABCDE,
        );
        assert_eq!(
            packet.payload.header.checksum.get(),
            0x0987,
            "UDP checksum invalid"
        );
    }

    #[test]
    fn test_create_ipv4_payload_checksum_zero() {
        // Using the checksum of an all-zero 2-byte payload as the payload itself drives the
        // computed checksum to 0, which must be transmitted as 0xffff.
        let probe = create_ipv4_payload(V4_SRC, SRC_PORT, V4_DST, DST_PORT, b"\0\0");
        let payload: [u8; 2] = *probe.payload.header.checksum.as_ref();
        assert_eq!(ipv4_udp_checksum(&payload), 0xffff, "UDP checksum invalid");
    }
}