#![allow(unsafe_code)]
use super::buffer_pool::BufferPool;
use super::pacer::Pacer;
use crate::crypto::aes_session::AesSession;
use crate::transport::bandwidth_estimator;
use crate::transport::handshake::{ClientHello, HandshakeResponse, HandshakeServer};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{self, Result as IoResult};
use tokio::net::UdpSocket;
/// Encrypted point-to-point UDP transport.
///
/// [`UdpTransport::bind`] creates a transport with no peer and no session.
/// Until [`UdpTransport::connect`] installs the peer address and negotiated
/// [`AesSession`], every operation that encrypts or decrypts fails with
/// [`io::ErrorKind::NotConnected`].
pub struct UdpTransport {
    socket: Arc<UdpSocket>,
    // Placeholder `0.0.0.0:0` until `connect` sets the real peer.
    peer_addr: SocketAddr,
    // `None` until `connect` installs the negotiated session. Kept as an Option
    // (rather than a placeholder key) so no datagram is ever sealed or accepted
    // under a publicly known key.
    session: Option<Arc<AesSession>>,
    // Receive-buffer pool, constructed as `BufferPool::new(65536, 16, 256)`.
    buffer_pool: Arc<BufferPool>,
}
impl UdpTransport {
    /// Binds a UDP socket on `local_addr` with `SO_BROADCAST` disabled.
    ///
    /// The returned transport has no peer and no session. Call
    /// [`UdpTransport::connect`] before sending or receiving.
    ///
    /// # Errors
    /// Propagates failures from binding the socket or from `set_broadcast`.
    pub async fn bind(local_addr: &str) -> IoResult<Self> {
        let socket = UdpSocket::bind(local_addr).await?;
        socket.set_broadcast(false)?;
        Ok(Self {
            socket: Arc::new(socket),
            peer_addr: SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0)),
            session: None,
            buffer_pool: Arc::new(BufferPool::new(65536, 16, 256)),
        })
    }

    /// Sets the peer targeted by the send methods and the session used to
    /// seal outgoing and open incoming datagrams.
    pub async fn connect(&mut self, peer_addr: SocketAddr, session: AesSession) {
        self.peer_addr = peer_addr;
        self.session = Some(Arc::new(session));
    }

    /// Returns the negotiated session, or `NotConnected` before `connect`.
    fn session(&self) -> IoResult<&AesSession> {
        self.session.as_deref().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotConnected,
                "UdpTransport has no session; call connect() first",
            )
        })
    }

    /// Encrypts `data` (empty associated data) and sends it to the peer.
    ///
    /// # Errors
    /// `NotConnected` before `connect`; encryption failures are wrapped with
    /// `io::Error::other`; socket errors are propagated.
    #[inline]
    pub async fn send(&self, data: &[u8]) -> IoResult<usize> {
        let encrypted = self
            .session()?
            .encrypt(&[], data)
            .map_err(io::Error::other)?;
        self.socket.send_to(&encrypted, self.peer_addr).await
    }

    /// Copies `data` into an owned buffer, encrypts that buffer in place and
    /// sends it to the peer.
    ///
    /// Despite the name, `data` is copied once. The 16 bytes of extra capacity
    /// are presumably room for the AEAD tag (TODO confirm against `AesSession`).
    ///
    /// # Errors
    /// Same as [`UdpTransport::send`].
    #[inline]
    pub async fn send_zero_copy(&self, data: &[u8]) -> IoResult<usize> {
        let mut buf = Vec::with_capacity(data.len() + 16);
        buf.extend_from_slice(data);
        self.session()?
            .encrypt_in_place(&[], &mut buf)
            .map_err(io::Error::other)?;
        self.socket.send_to(&buf, self.peer_addr).await
    }

    /// Receives one datagram from any address and decrypts it.
    ///
    /// The session is checked before reading, so no datagram is consumed while
    /// the transport is not connected. The source address is returned but not
    /// compared against the connected peer.
    ///
    /// # Errors
    /// `NotConnected` before `connect`; socket errors are propagated;
    /// decryption failures are wrapped with `io::Error::other`.
    #[inline]
    pub async fn recv(&self) -> IoResult<(Vec<u8>, SocketAddr)> {
        let session = self.session()?;
        let mut buf = self.buffer_pool.acquire();
        buf.resize(65536, 0);
        let (len, addr) = self.socket.recv_from(&mut buf).await?;
        let decrypted = session
            .decrypt(&[], &buf[..len])
            .map_err(io::Error::other)?;
        Ok((decrypted, addr))
    }

    /// Sends each packet in order via [`UdpTransport::send`] and returns the
    /// total number of bytes reported by the socket.
    ///
    /// # Errors
    /// Stops at the first failing packet and returns its error. The byte count
    /// for packets already sent is not reported in that case.
    #[inline]
    pub async fn send_batch(&self, packets: &[&[u8]]) -> IoResult<usize> {
        let mut total = 0;
        for packet in packets {
            total += self.send(packet).await?;
        }
        Ok(total)
    }

    /// Returns the shared underlying socket.
    pub fn socket(&self) -> &Arc<UdpSocket> {
        &self.socket
    }

    /// Caps the kernel pacing rate via the `SO_MAX_PACING_RATE` socket option.
    ///
    /// On Linux the value is applied with `setsockopt`. Values above
    /// `u32::MAX` saturate to `u32::MAX`. On other targets this is a no-op
    /// returning `Ok(())`.
    ///
    /// NOTE(review): the kernel interprets this option in bytes per second.
    /// `rate_bps` is passed through unchanged, so confirm that callers supply
    /// bytes/s rather than bits/s.
    ///
    /// # Errors
    /// Returns the OS error if `setsockopt` fails.
    pub fn set_pacing_rate(&self, rate_bps: u64) -> IoResult<()> {
        #[cfg(not(target_os = "linux"))]
        let _ = rate_bps;
        #[cfg(target_os = "linux")]
        {
            use std::os::unix::io::AsRawFd;
            let rate = u32::try_from(rate_bps).unwrap_or(u32::MAX);
            let fd = self.socket.as_raw_fd();
            // SAFETY: `fd` is an open socket kept alive by `self.socket` for the
            // whole call; `optval` points to a live, properly aligned `u32` on
            // this stack frame and `optlen` is exactly `size_of::<u32>()`.
            let ret = unsafe {
                libc::setsockopt(
                    fd,
                    libc::SOL_SOCKET,
                    // Use libc's per-architecture constant rather than a literal:
                    // the option number is not 47 on every Linux architecture.
                    libc::SO_MAX_PACING_RATE,
                    (&rate as *const u32).cast::<libc::c_void>(),
                    std::mem::size_of::<u32>() as libc::socklen_t,
                )
            };
            if ret != 0 {
                return Err(io::Error::last_os_error());
            }
        }
        Ok(())
    }

    /// Returns statistics for the receive-buffer pool.
    pub fn buffer_stats(&self) -> super::buffer_pool::PoolStats {
        self.buffer_pool.stats()
    }
}
/// Server-side UDP socket that answers a single client handshake per call.
pub struct UdpHandshakeListener {
    socket: Arc<UdpSocket>,
    buffer_pool: Arc<BufferPool>,
}
impl UdpHandshakeListener {
    /// Binds the listening socket on `local_addr` with `SO_BROADCAST` disabled.
    ///
    /// # Errors
    /// Propagates failures from binding the socket or from `set_broadcast`.
    pub async fn bind(local_addr: &str) -> IoResult<Self> {
        let listener = UdpSocket::bind(local_addr).await?;
        listener.set_broadcast(false)?;
        let socket = Arc::new(listener);
        let buffer_pool = Arc::new(BufferPool::new(65536, 16, 256));
        Ok(Self {
            socket,
            buffer_pool,
        })
    }

    /// Waits for one acceptable `ClientHello`, answers it and returns.
    ///
    /// Some datagrams are skipped silently and the wait continues:
    /// - datagrams shorter than 1200 bytes;
    /// - datagrams that do not deserialize as a borsh `ClientHello`.
    ///
    /// The first acceptable hello goes to `server.process_client_hello`. A
    /// `Retry`, `Success` (its server hello) or `Reject` response is
    /// borsh-encoded and sent to the originating address. A `Fail` response
    /// sends nothing. Encoding and send errors are ignored.
    ///
    /// NOTE(review): the session and early data carried by
    /// `HandshakeResponse::Success` are dropped here. `Ok(())` is returned for
    /// every outcome, so callers cannot tell whether a handshake completed.
    ///
    /// # Errors
    /// Only errors from `recv_from` are propagated.
    pub async fn accept_handshake(&self, server: &HandshakeServer, difficulty: u8) -> IoResult<()> {
        // Smallest hello that is considered. Presumably an anti-amplification
        // padding requirement (cf. QUIC's 1200-byte Initial); TODO confirm.
        const MIN_HELLO_LEN: usize = 1200;

        let mut datagram = self.buffer_pool.acquire();
        datagram.resize(65536, 0);

        let (reply, origin) = loop {
            let (received, origin) = self.socket.recv_from(&mut datagram).await?;
            if received < MIN_HELLO_LEN {
                continue;
            }
            let Ok(hello) = borsh::from_slice::<ClientHello>(&datagram[..received]) else {
                continue;
            };
            let reply = match server.process_client_hello(&hello, difficulty, origin.ip()) {
                HandshakeResponse::Retry(retry_req) => borsh::to_vec(&retry_req).ok(),
                HandshakeResponse::Success(server_hello, _session, _early_data) => {
                    borsh::to_vec(&server_hello).ok()
                }
                HandshakeResponse::Reject(reject) => borsh::to_vec(&reject).ok(),
                HandshakeResponse::Fail(_) => None,
            };
            break (reply, origin);
        };

        if let Some(encoded) = reply {
            // Best effort: a failed reply is deliberately not reported.
            let _ = self.socket.send_to(&encoded, origin).await;
        }
        Ok(())
    }
}
/// Sends through a shared [`UdpTransport`], rate-limited by a [`Pacer`].
///
/// Sent and acknowledged traffic is reported to a shared bandwidth estimator.
pub struct PacedSender {
    transport: Arc<UdpTransport>,
    pacer: Arc<Pacer>,
    estimator: Arc<parking_lot::Mutex<bandwidth_estimator::BandwidthEstimator>>,
}
impl PacedSender {
    /// Bundles a transport, pacer and estimator into a sender.
    pub fn new(
        transport: Arc<UdpTransport>,
        pacer: Arc<Pacer>,
        estimator: Arc<parking_lot::Mutex<bandwidth_estimator::BandwidthEstimator>>,
    ) -> Self {
        Self {
            estimator,
            pacer,
            transport,
        }
    }

    /// Waits for pacing credit, records the send, then transmits `data`.
    ///
    /// The wait ends as soon as `try_consume` grants `data.len()` bytes. It
    /// also ends when `time_until_available` reports a zero wait; in that case
    /// the datagram is sent without a successful `try_consume`. The bytes are
    /// reported to the estimator before the transport send, so they are counted
    /// even if the send fails.
    pub async fn send(&self, data: &[u8]) -> IoResult<usize> {
        let needed = data.len() as u64;
        while !self.pacer.try_consume(needed) {
            let delay = self.pacer.time_until_available(needed);
            if delay.is_zero() {
                break;
            }
            tokio::time::sleep(delay).await;
        }
        // The guard is a temporary, so the lock is released before the await.
        self.estimator.lock().on_send(needed);
        self.transport.send(data).await
    }

    /// Sends `data` immediately, bypassing both the pacer and the estimator.
    pub async fn send_unpaced(&self, data: &[u8]) -> IoResult<usize> {
        self.transport.send(data).await
    }

    /// Feeds an ack sample to the estimator and applies its new pacing rate.
    pub fn on_ack(&self, sample: bandwidth_estimator::DeliverySample) {
        let mut estimator = self.estimator.lock();
        estimator.on_ack(sample);
        // Set the pacer while still holding the lock, so that concurrent acks
        // cannot apply their rates out of order.
        self.pacer.set_rate(estimator.pacing_rate());
    }

    /// Overrides the pacer's rate directly.
    pub fn set_rate(&self, rate_bps: u64) {
        self.pacer.set_rate(rate_bps);
    }

    /// Returns the pacer's current rate.
    pub fn rate(&self) -> u64 {
        self.pacer.rate()
    }
}
impl std::fmt::Debug for PacedSender {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rate_bps = self.pacer.rate();
        let pacer_enabled = self.pacer.is_enabled();
        f.debug_struct("PacedSender")
            .field("rate_bps", &rate_bps)
            .field("pacer_enabled", &pacer_enabled)
            .finish()
    }
}
/// Lightweight handle that encrypts and sends to one fixed peer.
///
/// It uses a caller-supplied socket and session.
pub struct FastSender {
    socket: Arc<UdpSocket>,
    session: Arc<AesSession>,
    peer_addr: SocketAddr,
}
impl FastSender {
    /// Wraps an existing socket, session and destination address.
    pub fn new(socket: Arc<UdpSocket>, session: Arc<AesSession>, peer_addr: SocketAddr) -> Self {
        Self {
            socket,
            session,
            peer_addr,
        }
    }

    /// Copies `data` into an owned buffer, encrypts it in place (empty
    /// associated data) and sends it to the peer.
    ///
    /// # Errors
    /// Encryption failures are wrapped with `io::Error::other`; socket errors
    /// are propagated.
    #[inline]
    pub async fn send(&self, data: &[u8]) -> IoResult<usize> {
        // 16 bytes of headroom, presumably for the AEAD tag appended by
        // `encrypt_in_place` (TODO confirm against AesSession).
        let mut packet: Vec<u8> = Vec::with_capacity(data.len() + 16);
        packet.extend_from_slice(data);
        if let Err(err) = self.session.encrypt_in_place(&[], &mut packet) {
            return Err(io::Error::other(err));
        }
        self.socket.send_to(&packet, self.peer_addr).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly bound transport exposes a receive pool of 16 buffers.
    #[tokio::test]
    async fn test_udp_transport_create() {
        let udp = UdpTransport::bind("127.0.0.1:0").await.unwrap();
        let stats = udp.buffer_stats();
        assert_eq!(stats.pool_size, 16);
    }

    /// The sender reports the pacer's initial rate and reflects `set_rate`.
    #[tokio::test]
    async fn test_paced_sender_creation() {
        let udp = UdpTransport::bind("127.0.0.1:0").await.unwrap();
        let sender = PacedSender::new(
            Arc::new(udp),
            Arc::new(Pacer::new(1_000_000)),
            Arc::new(parking_lot::Mutex::new(
                bandwidth_estimator::BandwidthEstimator::new(),
            )),
        );
        assert_eq!(sender.rate(), 1_000_000);

        sender.set_rate(2_000_000);
        assert_eq!(sender.rate(), 2_000_000);
    }
}