use std::{
net::{IpAddr, Ipv6Addr, SocketAddr},
sync::atomic::{AtomicUsize, Ordering},
};
pub use crate::cmsg::{AsPtr, EcnCodepoint, Source, Transmit};
use imp::LastSendError;
use tracing::warn;
mod cmsg;
#[path = "unix.rs"]
mod imp;
pub use imp::{sync, UdpSocket};
pub mod framed;
/// Platform-specific cap on the batch size, re-exported from the backend
/// (`imp::BATCH_SIZE_CAP`).
pub const BATCH_SIZE_CAP: usize = imp::BATCH_SIZE_CAP;
/// Platform-specific default batch size, re-exported from the backend
/// (`imp::DEFAULT_BATCH_SIZE`).
/// NOTE(review): presumably never larger than `BATCH_SIZE_CAP` — confirm in `imp`.
pub const DEFAULT_BATCH_SIZE: usize = imp::DEFAULT_BATCH_SIZE;
/// Segment-count settings (`max_gso_segments`, `gro_segments`) produced by the
/// platform backend via `imp::udp_state()`.
#[derive(Debug)]
pub struct UdpState {
    /// Stored atomically so it can be changed through a shared reference.
    /// NOTE(review): presumably lowered by the backend when a GSO send fails — confirm in `imp`.
    max_gso_segments: AtomicUsize,
    /// Plain (non-atomic) count read via `UdpState::gro_segments`.
    gro_segments: usize,
}
impl UdpState {
    /// Builds a fresh state by delegating to the platform backend (`imp::udp_state`).
    pub fn new() -> Self {
        imp::udp_state()
    }

    /// Returns the current `max_gso_segments` value.
    ///
    /// The load uses `Ordering::Relaxed`, so it observes this counter alone and
    /// establishes no ordering with any other memory.
    #[inline]
    pub fn max_gso_segments(&self) -> usize {
        let Self { max_gso_segments, .. } = self;
        max_gso_segments.load(Ordering::Relaxed)
    }

    /// Returns the stored `gro_segments` value.
    #[inline]
    pub fn gro_segments(&self) -> usize {
        let Self { gro_segments, .. } = self;
        *gro_segments
    }
}
impl Default for UdpState {
fn default() -> Self {
Self::new()
}
}
/// Metadata describing a received datagram, as returned by `recv_msg`.
#[derive(Debug, Copy, Clone)]
pub struct RecvMeta {
    /// Address of the sending peer.
    pub addr: SocketAddr,
    /// Number of bytes received.
    /// NOTE(review): may cover several coalesced segments when GRO is active — confirm in `imp`.
    pub len: usize,
    /// NOTE(review): presumably the per-segment size when several datagrams are
    /// coalesced into one buffer — confirm in `imp`.
    pub stride: usize,
    /// ECN codepoint reported for the datagram, if any.
    pub ecn: Option<EcnCodepoint>,
    /// Destination IP of the datagram, if the platform reported one.
    /// NOTE(review): how this differs from `dst_local_ip` is not visible here — confirm in `imp`.
    pub dst_ip: Option<IpAddr>,
    /// Local IP the datagram was delivered to, if known.
    /// NOTE(review): presumably taken from packet-info control messages — confirm.
    pub dst_local_ip: Option<IpAddr>,
    /// Interface index associated with the datagram (`0` in the `Default` value).
    /// NOTE(review): presumably the receiving interface — confirm.
    pub ifindex: u32,
}
impl Default for RecvMeta {
    /// An empty record: peer address `[::]:0`, zero `len`/`stride`, no ECN or
    /// destination information, and interface index 0.
    fn default() -> Self {
        let unspecified: SocketAddr = (Ipv6Addr::UNSPECIFIED, 0).into();
        Self {
            ifindex: 0,
            dst_local_ip: None,
            dst_ip: None,
            ecn: None,
            stride: 0,
            len: 0,
            addr: unspecified,
        }
    }
}
/// NOTE(review): presumably the minimum spacing between logged I/O (`sendmsg`)
/// errors; the unit is not stated here (likely seconds) and no visible code in this
/// file reads it — confirm how `imp` / `LastSendError` consumes it.
const IO_ERROR_LOG_INTERVAL: u64 = 60;
/// Emits a `warn!` describing a failed `sendmsg` call and the `Transmit` that caused it.
///
/// Logging is gated on `last_send_error.should_log()`. That presumably throttles
/// repeated failures so they don't flood the log; confirm in `imp`.
///
/// * `last_send_error` – backend state deciding whether this error is logged.
/// * `err` – the error returned by the send attempt (printed with `Debug`).
/// * `transmit` – the datagram that failed; its destination, source, ECN codepoint,
///   payload length and segment size are included in the message.
fn log_sendmsg_error<B: AsPtr<u8>>(
    last_send_error: LastSendError,
    err: impl core::fmt::Debug,
    transmit: &Transmit<B>,
) {
    if last_send_error.should_log() {
        // The ECN field is labelled `ecn` (was the typo `enc`) so log output matches
        // the `Transmit` field it reports.
        warn!(
            "sendmsg error: {:?}, Transmit: {{ destination: {:?}, src_ip: {:?}, ecn: {:?}, len: {:?}, segment_size: {:?} }}",
            err,
            transmit.dst,
            transmit.src,
            transmit.ecn,
            transmit.contents.len(),
            transmit.segment_size
        );
    }
}
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use super::*;

    /// Payload exchanged by every round-trip test in this module.
    const PAYLOAD: &[u8; 11] = b"hello world";

    /// Binding a socket on a fixed wildcard port succeeds.
    #[test]
    fn test_create() {
        let bound = sync::UdpSocket::bind("0.0.0.0:9909");
        assert!(bound.is_ok());
    }

    /// Plain `send_to` / `recv_from` round trip between two local sockets.
    #[test]
    fn test_send_recv() {
        let dst = "0.0.0.0:9901".parse().unwrap();
        let receiver = sync::UdpSocket::bind(dst).unwrap();
        let sender = sync::UdpSocket::bind("0.0.0.0:0").unwrap();
        sender.send_to(&PAYLOAD[..], dst).unwrap();

        let mut inbox = [0; 1024];
        receiver.recv_from(&mut inbox).unwrap();
        assert_eq!(PAYLOAD[..], inbox[..PAYLOAD.len()]);
    }

    /// Sends `PAYLOAD` with `send_msg` from an ephemeral socket to `dst`, selecting the
    /// source via `src`. Then checks the payload and the metadata the receiver reports.
    fn assert_msg_roundtrip(dst: &str, src: Source) {
        let dst = dst.parse().unwrap();
        let receiver = sync::UdpSocket::bind(dst).unwrap();
        let sender = sync::UdpSocket::bind("0.0.0.0:0").unwrap();
        let sender_local = sender.local_addr().unwrap();
        let (send_port, send_addr) = (sender_local.port(), sender_local.ip());

        let transmit = Transmit::new(dst, *PAYLOAD).src_ip(src);
        sender.send_msg(&UdpState::new(), transmit).unwrap();

        let mut inbox = [0; 1024];
        let meta = receiver.recv_msg(&mut inbox).unwrap();
        assert_eq!(PAYLOAD[..], inbox[..PAYLOAD.len()]);
        assert_eq!(send_port, meta.addr.port());
        // NOTE(review): index 1 is presumably the loopback interface on the test host.
        assert_eq!(meta.ifindex, 1);
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        assert!(matches!(
            meta.dst_local_ip,
            Some(addr) if addr == send_addr || addr == loopback
        ));
    }

    /// `send_msg` with the source chosen by interface index.
    #[test]
    fn test_send_recv_msg() {
        assert_msg_roundtrip("0.0.0.0:9902", Source::Interface(1));
    }

    /// `send_msg` with the source chosen by (wildcard) IP address.
    #[test]
    fn test_send_recv_msg_ip() {
        assert_msg_roundtrip("0.0.0.0:9903", Source::Ip("0.0.0.0".parse().unwrap()));
    }
}