#![cfg_attr(not(unix), allow(dead_code))]
use crate::node::session_wire::FSP_HEADER_SIZE;
use crate::node::wire::ESTABLISHED_HEADER_SIZE;
use crate::transport::udp::socket::AsyncUdpSocket;
#[cfg(not(target_os = "macos"))]
use crossbeam_channel::{Receiver, SendError, Sender, TrySendError, bounded};
use ring::aead::{Aad, LessSafeKey, Nonce};
#[cfg(target_os = "macos")]
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::net::SocketAddr;
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
use std::sync::Arc;
use std::sync::OnceLock;
#[cfg(target_os = "macos")]
use std::sync::{Condvar, Mutex};
use tracing::{debug, trace, warn};
/// One outbound packet awaiting AEAD sealing and UDP transmission on an
/// encrypt worker thread.
pub(crate) struct FmpSendJob {
    // Outer (FMP) AEAD key used to seal `wire_buf` past the established header.
    pub cipher: LessSafeKey,
    // Outer nonce counter; written little-endian into bytes 4..12 of the nonce.
    pub counter: u64,
    // Wire bytes: established header followed by plaintext sealed in place;
    // the worker appends the AEAD tag.
    pub wire_buf: Vec<u8>,
    // Optional inner FSP seal performed before the outer FMP seal.
    pub fsp_seal: Option<FspSealJob>,
    pub socket: AsyncUdpSocket,
    pub dest_addr: SocketAddr,
    // Connected-socket fast path used in preference to `socket` when present.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    pub connected_socket:
        Option<std::sync::Arc<crate::transport::udp::connected_peer::ConnectedPeerSocket>>,
    // When true, the packet may be dropped under kernel send backpressure
    // instead of being treated as a hard error.
    pub drop_on_backpressure: bool,
    // Enqueue timestamp used to record time spent waiting in the worker queue.
    pub queued_at: Option<std::time::Instant>,
}
/// Parameters for an inner FSP seal applied to a sub-range of the wire buffer
/// before the outer FMP seal.
pub(crate) struct FspSealJob {
    // Inner AEAD key.
    pub cipher: LessSafeKey,
    // Inner nonce counter (little-endian into nonce bytes 4..12).
    pub counter: u64,
    // Start of the FSP header used as AAD; the AAD spans FSP_HEADER_SIZE bytes.
    pub aad_offset: usize,
    // Start of the region sealed in place (from here to the end of the buffer).
    pub plaintext_offset: usize,
}
/// An `FmpSendJob` as it travels through a worker channel, optionally tagged
/// with macOS per-flow ordering metadata.
struct QueuedFmpSendJob {
    job: FmpSendJob,
    // Flow this job belongs to when the macOS ordered sender is enabled;
    // `None` on the direct (hash-sharded) path.
    #[cfg(target_os = "macos")]
    macos_flow: Option<Arc<MacSequencedSendFlow>>,
    // Sequence number reserved on `macos_flow`; the flow's send thread emits
    // packets strictly in this order.
    #[cfg(target_os = "macos")]
    macos_seq: u64,
}
impl QueuedFmpSendJob {
    // Wrap a job with no ordering metadata (hash-sharded worker path).
    #[allow(dead_code)] fn direct(job: FmpSendJob) -> Self {
        Self {
            job,
            #[cfg(target_os = "macos")]
            macos_flow: None,
            #[cfg(target_os = "macos")]
            macos_seq: 0,
        }
    }
    // Wrap a job for the macOS ordered-sender path, reserving the next
    // per-flow sequence number so the flow thread can restore send order.
    #[cfg(target_os = "macos")]
    fn macos_sequenced(job: FmpSendJob, macos_flow: Arc<MacSequencedSendFlow>) -> Self {
        let macos_seq = macos_flow.reserve_seq();
        Self {
            job,
            macos_flow: Some(macos_flow),
            macos_seq,
        }
    }
}
// Bounded capacity of each encrypt worker's job channel; overflow triggers
// blocking backpressure in `dispatch_to_worker`.
const WORKER_CHANNEL_CAP: usize = 1024;
// macOS stand-in for a bounded crossbeam channel: a mutex-protected VecDeque
// plus condvars, so the receiver can drain a whole batch under a single lock
// acquisition.
#[cfg(target_os = "macos")]
struct MacWorkerSender {
    inner: Arc<MacWorkerQueueInner>,
}
#[cfg(target_os = "macos")]
struct MacWorkerReceiver {
    inner: Arc<MacWorkerQueueInner>,
}
#[cfg(target_os = "macos")]
struct MacWorkerQueueInner {
    state: Mutex<MacWorkerQueueState>,
    // Signalled on an empty -> non-empty transition while the receiver waits.
    not_empty: Condvar,
    // Signalled when the receiver frees capacity.
    not_full: Condvar,
    // Maximum queued jobs before senders block.
    cap: usize,
}
#[cfg(target_os = "macos")]
#[derive(Default)]
struct MacWorkerQueueState {
    queue: VecDeque<QueuedFmpSendJob>,
    // True while the receiver is parked in `not_empty.wait`; lets senders skip
    // a notify when nobody is waiting.
    waiting: bool,
    // Set when the sender handle is dropped; producers fail and the receiver
    // exits once drained.
    closed: bool,
}
#[cfg(target_os = "macos")]
enum MacWorkerTryPushError {
    // Queue at capacity; the job is handed back (boxed) for a blocking retry.
    Full(Box<QueuedFmpSendJob>),
    Closed,
}
#[cfg(target_os = "macos")]
struct MacWorkerPushError;
/// Builds one bounded sender/receiver pair sharing a condvar-protected queue.
#[cfg(target_os = "macos")]
fn mac_worker_channel(cap: usize) -> (MacWorkerSender, MacWorkerReceiver) {
    let shared = Arc::new(MacWorkerQueueInner {
        state: Mutex::new(MacWorkerQueueState {
            queue: VecDeque::with_capacity(cap),
            waiting: false,
            closed: false,
        }),
        not_empty: Condvar::new(),
        not_full: Condvar::new(),
        cap,
    });
    let tx = MacWorkerSender {
        inner: Arc::clone(&shared),
    };
    let rx = MacWorkerReceiver { inner: shared };
    (tx, rx)
}
#[cfg(target_os = "macos")]
impl MacWorkerSender {
    /// Non-blocking enqueue. Returns the job back inside `Full` when the queue
    /// is at capacity, or `Closed` once the queue has shut down.
    fn try_push(&self, job: QueuedFmpSendJob) -> Result<(), MacWorkerTryPushError> {
        let mut state = self
            .inner
            .state
            .lock()
            .expect("encrypt worker queue poisoned");
        if state.closed {
            drop(job);
            return Err(MacWorkerTryPushError::Closed);
        }
        if state.queue.len() >= self.inner.cap {
            return Err(MacWorkerTryPushError::Full(Box::new(job)));
        }
        // Only an empty -> non-empty transition with a parked receiver needs a
        // wakeup; otherwise the receiver is already draining.
        let was_empty = state.queue.is_empty();
        let should_notify = was_empty && state.waiting;
        state.queue.push_back(job);
        drop(state);
        if should_notify {
            self.inner.not_empty.notify_one();
        }
        Ok(())
    }
    /// Blocking enqueue: waits on `not_full` until capacity frees up.
    /// Fails only if the queue closes before the job could be enqueued.
    fn push_blocking(&self, job: QueuedFmpSendJob) -> Result<(), MacWorkerPushError> {
        let mut state = self
            .inner
            .state
            .lock()
            .expect("encrypt worker queue poisoned");
        loop {
            if state.closed {
                drop(job);
                return Err(MacWorkerPushError);
            }
            if state.queue.len() < self.inner.cap {
                let was_empty = state.queue.is_empty();
                let should_notify = was_empty && state.waiting;
                state.queue.push_back(job);
                drop(state);
                if should_notify {
                    self.inner.not_empty.notify_one();
                }
                return Ok(());
            }
            // Queue full: wait for the receiver to drain. The loop re-checks
            // the predicate because condvar waits can wake spuriously.
            state = self
                .inner
                .not_full
                .wait(state)
                .expect("encrypt worker queue poisoned");
        }
    }
}
#[cfg(target_os = "macos")]
impl Drop for MacWorkerSender {
    // Sender dropped: mark the queue closed and wake every waiter so the
    // worker thread and any blocked producers observe shutdown and exit.
    fn drop(&mut self) {
        let mut state = self
            .inner
            .state
            .lock()
            .expect("encrypt worker queue poisoned");
        state.closed = true;
        drop(state);
        self.inner.not_empty.notify_all();
        self.inner.not_full.notify_all();
    }
}
#[cfg(target_os = "macos")]
impl MacWorkerReceiver {
    /// Blocks until at least one job is available (or the channel closes) and
    /// moves up to `max` queued jobs into `batch` in FIFO order.
    ///
    /// Returns `true` when `batch` holds at least one job, `false` once the
    /// sender has been dropped and the queue is fully drained.
    fn recv_batch(&self, batch: &mut Vec<QueuedFmpSendJob>, max: usize) -> bool {
        debug_assert!(batch.is_empty());
        let mut state = self
            .inner
            .state
            .lock()
            .expect("encrypt worker queue poisoned");
        loop {
            while let Some(job) = state.queue.pop_front() {
                batch.push(job);
                if batch.len() >= max {
                    break;
                }
            }
            if !batch.is_empty() {
                // Up to `max` capacity slots were just freed, so wake every
                // producer blocked in `push_blocking`. `notify_one` here could
                // leave all but one blocked producer parked until the next
                // drain even though there is room for all of them; the waits
                // are loop-guarded, so over-notifying is safe.
                self.inner.not_full.notify_all();
                return true;
            }
            if state.closed {
                return false;
            }
            // Queue empty: record that the consumer is parked so producers
            // know a wakeup is needed, then wait for the next push.
            state.waiting = true;
            state = self
                .inner
                .not_empty
                .wait(state)
                .expect("encrypt worker queue poisoned");
            state.waiting = false;
        }
    }
}
// Per-platform worker channel sender: the custom batch-draining condvar queue
// on macOS, a bounded crossbeam channel everywhere else.
#[cfg(target_os = "macos")]
type WorkerSender = MacWorkerSender;
#[cfg(not(target_os = "macos"))]
type WorkerSender = Sender<QueuedFmpSendJob>;
/// Cloneable handle to a fixed pool of OS threads that AEAD-seal and transmit
/// outbound packets off the async runtime.
#[derive(Clone)]
pub(crate) struct EncryptWorkerPool {
    senders: Arc<[WorkerSender]>,
    // macOS ordered-sender state: one sequenced flow per (socket, peer).
    #[cfg(target_os = "macos")]
    macos_senders: Arc<MacSequencedSendFlows>,
    // Round-robin ticket counter for spreading sequenced jobs across workers.
    #[cfg(target_os = "macos")]
    next_worker: Arc<std::sync::atomic::AtomicUsize>,
}
impl EncryptWorkerPool {
    /// Spawns `n` (at least 1) dedicated OS threads, each owning a bounded job
    /// channel, and returns the dispatch handle.
    pub fn spawn(n: usize) -> Self {
        let n = n.max(1);
        let mut senders = Vec::with_capacity(n);
        for i in 0..n {
            #[cfg(target_os = "macos")]
            {
                let (tx, rx) = mac_worker_channel(WORKER_CHANNEL_CAP);
                std::thread::Builder::new()
                    .name(format!("fips-encrypt-{i}"))
                    .spawn(move || run_worker_macos(i, rx))
                    .expect("failed to spawn fips-encrypt OS thread");
                senders.push(tx);
            }
            #[cfg(not(target_os = "macos"))]
            {
                let (tx, rx) = bounded::<QueuedFmpSendJob>(WORKER_CHANNEL_CAP);
                std::thread::Builder::new()
                    .name(format!("fips-encrypt-{i}"))
                    .spawn(move || run_worker(i, rx))
                    .expect("failed to spawn fips-encrypt OS thread");
                senders.push(tx);
            }
        }
        Self {
            senders: senders.into(),
            #[cfg(target_os = "macos")]
            macos_senders: Arc::new(MacSequencedSendFlows::default()),
            #[cfg(target_os = "macos")]
            next_worker: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        }
    }
    /// Routes one job to a worker thread for sealing and transmission.
    pub fn dispatch(&self, job: FmpSendJob) {
        if self.senders.is_empty() {
            debug!("EncryptWorkerPool has no workers; dropping job");
            return;
        }
        let (idx, job) = self.prepare_dispatch(job);
        self.dispatch_to_worker(idx, job);
    }
    // macOS: by default, pin each (socket, peer) flow to one worker by hash so
    // its packets stay ordered within that worker's channel. When the ordered
    // sender is enabled, instead tag the job with a per-flow sequence number
    // and spread work round-robin; the flow's send thread restores order.
    #[cfg(target_os = "macos")]
    fn prepare_dispatch(&self, job: FmpSendJob) -> (usize, QueuedFmpSendJob) {
        if !macos_ordered_sender_enabled() {
            use std::hash::{Hash, Hasher};
            let key = MacSendFlowKey {
                socket_fd: job.socket.as_raw_fd(),
                connected_fd: job.connected_socket.as_ref().map(|s| s.as_raw_fd()),
                dest_addr: job.dest_addr,
            };
            let mut h = std::collections::hash_map::DefaultHasher::new();
            key.hash(&mut h);
            let idx = (h.finish() as usize) % self.senders.len();
            return (idx, QueuedFmpSendJob::direct(job));
        }
        let flow = self.macos_senders.flow_for(&job);
        // Dividing the ticket by the stride keeps `stride` consecutive jobs on
        // the same worker before moving on.
        let ticket = self
            .next_worker
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            / macos_worker_stride();
        let idx = ticket % self.senders.len();
        (idx, QueuedFmpSendJob::macos_sequenced(job, flow))
    }
    // Non-macOS: pin each destination address to one worker by hash so a
    // peer's packets stay ordered within that worker's channel.
    #[cfg(not(target_os = "macos"))]
    fn prepare_dispatch(&self, job: FmpSendJob) -> (usize, QueuedFmpSendJob) {
        use std::hash::{Hash, Hasher};
        let mut h = std::collections::hash_map::DefaultHasher::new();
        job.dest_addr.hash(&mut h);
        let idx = (h.finish() as usize) % self.senders.len();
        (idx, QueuedFmpSendJob::direct(job))
    }
    // Try a non-blocking enqueue first; when the channel is full, log
    // (rate-limited) and fall back to a blocking push so backpressure
    // propagates to the caller.
    #[cfg(target_os = "macos")]
    fn dispatch_to_worker(&self, idx: usize, job: QueuedFmpSendJob) {
        match self.senders[idx].try_push(job) {
            Ok(()) => {}
            Err(MacWorkerTryPushError::Full(job)) => {
                static FULL_COUNT: std::sync::atomic::AtomicU64 =
                    std::sync::atomic::AtomicU64::new(0);
                let n = FULL_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                if n < 8 || n.is_multiple_of(10000) {
                    warn!(
                        worker = idx,
                        full_events = n + 1,
                        "EncryptWorker channel full; applying outbound backpressure"
                    );
                }
                if let Err(MacWorkerPushError) = self.senders[idx].push_blocking(*job) {
                    debug!(worker = idx, "EncryptWorker thread gone; dropping job");
                }
            }
            Err(MacWorkerTryPushError::Closed) => {
                debug!(worker = idx, "EncryptWorker thread gone; dropping job");
            }
        }
    }
    #[cfg(not(target_os = "macos"))]
    fn dispatch_to_worker(&self, idx: usize, job: QueuedFmpSendJob) {
        match self.senders[idx].try_send(job) {
            Ok(()) => {}
            Err(TrySendError::Full(job)) => {
                static FULL_COUNT: std::sync::atomic::AtomicU64 =
                    std::sync::atomic::AtomicU64::new(0);
                let n = FULL_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                if n < 8 || n.is_multiple_of(10000) {
                    warn!(
                        worker = idx,
                        full_events = n + 1,
                        "EncryptWorker channel full; applying outbound backpressure"
                    );
                }
                if let Err(SendError(_)) = self.senders[idx].send(job) {
                    debug!(worker = idx, "EncryptWorker thread gone; dropping job");
                }
            }
            Err(TrySendError::Disconnected(_)) => {
                debug!(worker = idx, "EncryptWorker thread gone; dropping job");
            }
        }
    }
}
/// Identity of one outbound flow: the sending socket(s) plus the destination.
#[cfg(target_os = "macos")]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
struct MacSendFlowKey {
    socket_fd: std::os::unix::io::RawFd,
    connected_fd: Option<std::os::unix::io::RawFd>,
    dest_addr: SocketAddr,
}
/// Registry of live sequenced send flows, keyed by socket/destination, with
/// periodic pruning of idle flows.
#[cfg(target_os = "macos")]
#[derive(Default)]
struct MacSequencedSendFlows {
    flows: Mutex<HashMap<MacSendFlowKey, Arc<MacSequencedSendFlow>>>,
    // Wall-clock ms of the last prune pass; gates pruning to once per 10s.
    last_prune_ms: std::sync::atomic::AtomicU64,
}
#[cfg(target_os = "macos")]
impl MacSequencedSendFlows {
    /// Returns the sequenced flow for this job's socket/destination, spawning
    /// a new flow (and its send thread) on first use.
    fn flow_for(&self, job: &FmpSendJob) -> Arc<MacSequencedSendFlow> {
        let now_ms = mac_now_ms();
        let key = MacSendFlowKey {
            socket_fd: job.socket.as_raw_fd(),
            connected_fd: job.connected_socket.as_ref().map(|s| s.as_raw_fd()),
            dest_addr: job.dest_addr,
        };
        let mut flows = self.flows.lock().expect("mac send flow map poisoned");
        self.prune_idle_locked(&mut flows, now_ms);
        if let Some(flow) = flows.get(&key) {
            flow.mark_used(now_ms);
            return Arc::clone(flow);
        }
        let flow = MacSequencedSendFlow::spawn(
            key,
            job.socket.clone(),
            job.connected_socket.clone(),
            job.dest_addr,
            now_ms,
        );
        flows.insert(key, Arc::clone(&flow));
        flow
    }
    // Drops (and closes) flows idle longer than the configured threshold.
    // Caller must hold the `flows` lock; runs at most once per 10 seconds.
    fn prune_idle_locked(
        &self,
        flows: &mut HashMap<MacSendFlowKey, Arc<MacSequencedSendFlow>>,
        now_ms: u64,
    ) {
        let last = self
            .last_prune_ms
            .load(std::sync::atomic::Ordering::Relaxed);
        if now_ms.saturating_sub(last) < 10_000 {
            return;
        }
        // CAS claims this prune window; a racing caller simply skips.
        if self
            .last_prune_ms
            .compare_exchange(
                last,
                now_ms,
                std::sync::atomic::Ordering::Relaxed,
                std::sync::atomic::Ordering::Relaxed,
            )
            .is_err()
        {
            return;
        }
        let idle_ms = mac_send_flow_idle_ms();
        flows.retain(|_, flow| {
            if flow.is_idle(now_ms, idle_ms) {
                // Closing stops the flow's send thread.
                flow.close();
                false
            } else {
                true
            }
        });
    }
}
/// Whether the macOS ordered-sender path is enabled via env var. Cached for
/// the life of the process. Setting the variable to anything other than
/// "0"/"false"/"no"/"off" enables it; unset means disabled.
#[cfg(target_os = "macos")]
fn macos_ordered_sender_enabled() -> bool {
    static VALUE: OnceLock<bool> = OnceLock::new();
    *VALUE.get_or_init(|| match std::env::var("FIPS_MACOS_ORDERED_SENDER") {
        Ok(raw) => {
            let normalized = raw.trim().to_ascii_lowercase();
            !matches!(normalized.as_str(), "0" | "false" | "no" | "off")
        }
        Err(_) => false,
    })
}
/// How many consecutive round-robin tickets map to the same worker (1..=64,
/// env-overridable, cached). Default 1 = pure round-robin.
#[cfg(target_os = "macos")]
fn macos_worker_stride() -> usize {
    static VALUE: OnceLock<usize> = OnceLock::new();
    *VALUE.get_or_init(|| {
        let configured = std::env::var("FIPS_MACOS_WORKER_STRIDE")
            .ok()
            .and_then(|raw| raw.trim().parse::<usize>().ok());
        configured.unwrap_or(1).clamp(1, 64)
    })
}
/// Max jobs a macOS worker drains per wakeup (1..=64, env-overridable,
/// cached). Default 8.
#[cfg(target_os = "macos")]
fn macos_worker_batch_size() -> usize {
    static VALUE: OnceLock<usize> = OnceLock::new();
    *VALUE.get_or_init(|| {
        let configured = std::env::var("FIPS_MACOS_WORKER_BATCH")
            .ok()
            .and_then(|raw| raw.trim().parse::<usize>().ok());
        configured.unwrap_or(8).clamp(1, 64)
    })
}
/// Idle threshold (ms) before a sequenced send flow is pruned. Env-overridable
/// and cached; floored at 10s so flows are never pruned too aggressively.
#[cfg(target_os = "macos")]
fn mac_send_flow_idle_ms() -> u64 {
    static VALUE: OnceLock<u64> = OnceLock::new();
    *VALUE.get_or_init(|| {
        let configured = std::env::var("FIPS_MACOS_SEND_FLOW_IDLE_MS")
            .ok()
            .and_then(|raw| raw.trim().parse::<u64>().ok());
        configured.unwrap_or(120_000).max(10_000)
    })
}
/// Milliseconds since the Unix epoch; 0 if the system clock reads before the
/// epoch.
#[cfg(target_os = "macos")]
fn mac_now_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
/// One destination's ordered send pipeline: workers complete encrypted packets
/// out of order, and a dedicated thread emits them in reserved-sequence order.
#[cfg(target_os = "macos")]
struct MacSequencedSendFlow {
    key: MacSendFlowKey,
    socket: AsyncUdpSocket,
    connected_socket:
        Option<std::sync::Arc<crate::transport::udp::connected_peer::ConnectedPeerSocket>>,
    dest_addr: SocketAddr,
    // Next sequence number to hand out at dispatch time.
    next_seq: std::sync::atomic::AtomicU64,
    // Wall-clock ms of last use; read by the idle pruner.
    last_used_ms: std::sync::atomic::AtomicU64,
    state: Mutex<MacSendFlowState>,
    // Wakes the send thread when its next-in-order item arrives.
    ready_cv: Condvar,
    // Wakes producers blocked on the pending-map capacity cap.
    space_cv: Condvar,
}
#[cfg(target_os = "macos")]
#[derive(Default)]
struct MacSendFlowState {
    // Sequence number the send thread will emit next.
    next_send_seq: u64,
    // Completed items waiting for their turn, keyed (and ordered) by sequence.
    pending: BTreeMap<u64, MacSendItem>,
    closed: bool,
}
/// Batch of completed (seq, item) pairs destined for one flow, so the flow's
/// lock is taken once per batch rather than once per packet.
#[cfg(target_os = "macos")]
struct MacCompletionGroup {
    flow: Arc<MacSequencedSendFlow>,
    items: Vec<(u64, MacSendItem)>,
}
/// One slot in a flow's ordered stream: either an encrypted packet to send, or
/// a placeholder for a job that failed to seal (so the sequence never stalls).
#[cfg(target_os = "macos")]
enum MacSendItem {
    Packet {
        packet: Vec<u8>,
        drop_on_backpressure: bool,
    },
    Skip,
}
#[cfg(target_os = "macos")]
impl MacSequencedSendFlow {
    /// Creates the flow and spawns its dedicated in-order send thread.
    fn spawn(
        key: MacSendFlowKey,
        socket: AsyncUdpSocket,
        connected_socket: Option<
            std::sync::Arc<crate::transport::udp::connected_peer::ConnectedPeerSocket>,
        >,
        dest_addr: SocketAddr,
        now_ms: u64,
    ) -> Arc<Self> {
        let flow = Arc::new(Self {
            key,
            socket,
            connected_socket,
            dest_addr,
            next_seq: std::sync::atomic::AtomicU64::new(0),
            last_used_ms: std::sync::atomic::AtomicU64::new(now_ms),
            state: Mutex::new(MacSendFlowState::default()),
            ready_cv: Condvar::new(),
            space_cv: Condvar::new(),
        });
        let thread_flow = Arc::clone(&flow);
        std::thread::Builder::new()
            .name(format!("fips-mac-send-{}", key.socket_fd))
            .spawn(move || thread_flow.run())
            .expect("failed to spawn fips macOS send thread");
        flow
    }
    /// Hands out the next sequence number; the holder must eventually complete
    /// it (packet or skip) or the send thread stalls at that gap.
    fn reserve_seq(&self) -> u64 {
        self.next_seq
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
    }
    fn mark_used(&self, now_ms: u64) {
        self.last_used_ms
            .store(now_ms, std::sync::atomic::Ordering::Relaxed);
    }
    /// Idle = unused past the threshold AND fully drained (nothing pending and
    /// every reserved sequence number already sent).
    fn is_idle(&self, now_ms: u64, idle_ms: u64) -> bool {
        let last_used = self.last_used_ms.load(std::sync::atomic::Ordering::Relaxed);
        if now_ms.saturating_sub(last_used) < idle_ms {
            return false;
        }
        let state = self.state.lock().expect("mac send flow state poisoned");
        state.pending.is_empty()
            && state.next_send_seq == self.next_seq.load(std::sync::atomic::Ordering::Relaxed)
    }
    /// Marks the flow closed and wakes the send thread plus any blocked
    /// completers so they can exit.
    fn close(&self) {
        let mut state = self.state.lock().expect("mac send flow state poisoned");
        state.closed = true;
        drop(state);
        self.ready_cv.notify_one();
        self.space_cv.notify_all();
    }
    /// Inserts completed (seq, item) pairs into the pending map, blocking on a
    /// capacity cap, and wakes the send thread when its next-in-order item was
    /// among them.
    fn complete_many(&self, items: Vec<(u64, MacSendItem)>) {
        const PENDING_CAP: usize = 4096;
        if items.is_empty() {
            return;
        }
        let mut state = self.state.lock().expect("mac send flow state poisoned");
        if state.closed {
            return;
        }
        let mut wakes_sender = false;
        for (seq, item) in items {
            // Respect the cap, but never block when inserting the item the
            // send thread is waiting for (or once we already queued it this
            // call): the sender will drain the backlog.
            while state.pending.len() >= PENDING_CAP && seq != state.next_send_seq && !wakes_sender
            {
                state = self
                    .space_cv
                    .wait(state)
                    .expect("mac send flow state poisoned");
            }
            if seq == state.next_send_seq {
                wakes_sender = true;
            }
            state.pending.insert(seq, item);
        }
        drop(state);
        if wakes_sender {
            self.ready_cv.notify_one();
        }
    }
    /// Send-thread main loop: pop items in strict sequence order, pace, and
    /// transmit; exits when the flow is closed and no in-order item is ready.
    fn run(self: Arc<Self>) {
        trace!(
            socket_fd = self.key.socket_fd,
            connected_fd = ?self.key.connected_fd,
            dest = %self.dest_addr,
            "macOS ordered UDP sender starting"
        );
        // Prefer the connected socket's fd when one exists.
        let (fd, connected) = match self.connected_socket.as_ref() {
            Some(socket) => (socket.as_raw_fd(), true),
            None => (self.socket.as_raw_fd(), false),
        };
        let mut backpressure = SendBackpressurePacer::default();
        let mut rate_pacer = MacSendRatePacer::default();
        loop {
            // Wait for the exact next sequence number; completing it frees a
            // pending-map slot, so notify one blocked completer.
            let item = {
                let mut state = self.state.lock().expect("mac send flow state poisoned");
                loop {
                    let next = state.next_send_seq;
                    if let Some(item) = state.pending.remove(&next) {
                        state.next_send_seq = next.wrapping_add(1);
                        self.space_cv.notify_one();
                        break item;
                    }
                    if state.closed {
                        return;
                    }
                    state = self
                        .ready_cv
                        .wait(state)
                        .expect("mac send flow state poisoned");
                }
            };
            match item {
                MacSendItem::Packet {
                    packet,
                    drop_on_backpressure,
                } => {
                    let _t = crate::perf_profile::Timer::start(crate::perf_profile::Stage::UdpSend);
                    rate_pacer.pace(packet.len());
                    if let Err(err) = send_one_with_backpressure(
                        fd,
                        connected,
                        &self.dest_addr,
                        &packet,
                        &mut backpressure,
                        drop_on_backpressure,
                    ) {
                        debug!(
                            socket_fd = self.key.socket_fd,
                            connected_fd = ?self.key.connected_fd,
                            dest = %self.dest_addr,
                            error = %err,
                            "macOS ordered UDP send failed"
                        );
                    }
                }
                // A job that failed to seal: consume the slot and move on.
                MacSendItem::Skip => {}
            }
        }
    }
}
/// Appends one completed (seq, item) pair to the completion group for `flow`,
/// creating the group on first sight. Grouping by flow identity (pointer
/// equality) lets each flow's lock be taken once per batch.
#[cfg(target_os = "macos")]
fn push_mac_completion(
    groups: &mut Vec<MacCompletionGroup>,
    flow: Arc<MacSequencedSendFlow>,
    seq: u64,
    item: MacSendItem,
) {
    let existing = groups.iter().position(|g| Arc::ptr_eq(&g.flow, &flow));
    match existing {
        Some(i) => groups[i].items.push((seq, item)),
        None => groups.push(MacCompletionGroup {
            flow,
            items: vec![(seq, item)],
        }),
    }
}
/// Non-macOS worker loop: block for one job, opportunistically drain up to a
/// batch without blocking, then seal-and-send the batch. Exits when the
/// channel disconnects.
#[cfg(not(target_os = "macos"))]
fn run_worker(idx: usize, rx: Receiver<QueuedFmpSendJob>) {
    trace!(worker = idx, "FMP encrypt worker thread starting");
    const BATCH_SIZE: usize = 32;
    let mut pending: Vec<QueuedFmpSendJob> = Vec::with_capacity(BATCH_SIZE);
    while let Ok(first) = rx.recv() {
        pending.push(first);
        // Fill the rest of the batch with whatever is already queued.
        pending.extend(rx.try_iter().take(BATCH_SIZE - 1));
        if let Err(err) = flush_batch_sync(&mut pending) {
            debug!(worker = idx, error = %err, "FMP encrypt worker batch flush failed");
        }
    }
    trace!(worker = idx, "FMP encrypt worker thread exiting");
}
/// macOS worker loop: blocks in `recv_batch` for up to a configured number of
/// jobs per wakeup, then seals and hands them off. Exits when the channel
/// closes and drains empty.
#[cfg(target_os = "macos")]
fn run_worker_macos(idx: usize, rx: MacWorkerReceiver) {
    trace!(worker = idx, "FMP encrypt worker thread starting");
    let max_batch = macos_worker_batch_size();
    let mut jobs: Vec<QueuedFmpSendJob> = Vec::with_capacity(max_batch);
    loop {
        if !rx.recv_batch(&mut jobs, max_batch) {
            break;
        }
        match flush_batch_sync(&mut jobs) {
            Ok(()) => {}
            Err(err) => {
                debug!(worker = idx, error = %err, "FMP encrypt worker batch flush failed");
                // Defensive: drop anything the failed flush left behind.
                jobs.clear();
            }
        }
    }
    trace!(worker = idx, "FMP encrypt worker thread exiting");
}
/// Seals every queued job in `batch` (optional inner FSP seal, then the outer
/// FMP seal over the established header as AAD), groups the resulting packets
/// by socket/destination, and transmits them: GSO or sendmmsg batches on
/// Linux, per-packet sendto on other Unix, and — for macOS sequenced jobs —
/// hand-off to the flow's ordered send thread. Drains `batch` even on error.
fn flush_batch_sync(
    batch: &mut Vec<QueuedFmpSendJob>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    if batch.is_empty() {
        return Ok(());
    }
    let _t = crate::perf_profile::Timer::start(crate::perf_profile::Stage::FmpEncrypt);
    // Packets that share a socket + destination, batched for one send call.
    #[cfg(unix)]
    struct EncryptedGroup {
        socket: AsyncUdpSocket,
        #[cfg(any(target_os = "linux", target_os = "macos"))]
        connected_socket:
            Option<std::sync::Arc<crate::transport::udp::connected_peer::ConnectedPeerSocket>>,
        dest_addr: SocketAddr,
        wire_packets: Vec<Vec<u8>>,
        drop_on_backpressure: bool,
    }
    #[cfg(unix)]
    let mut groups: Vec<EncryptedGroup> = Vec::with_capacity(1);
    #[cfg(target_os = "macos")]
    let mut macos_completions: Vec<MacCompletionGroup> = Vec::with_capacity(1);
    for queued in batch.drain(..) {
        #[cfg(target_os = "macos")]
        let QueuedFmpSendJob {
            job,
            macos_flow,
            macos_seq,
        } = queued;
        #[cfg(not(target_os = "macos"))]
        let QueuedFmpSendJob { job } = queued;
        let FmpSendJob {
            cipher,
            counter,
            mut wire_buf,
            fsp_seal,
            socket,
            dest_addr,
            #[cfg(any(target_os = "linux", target_os = "macos"))]
            connected_socket,
            drop_on_backpressure,
            queued_at,
        } = job;
        crate::perf_profile::record_since(
            crate::perf_profile::Stage::FmpWorkerQueueWait,
            queued_at,
        );
        // Optional inner FSP seal over wire_buf[plaintext_offset..], with the
        // FSP header as AAD. Invalid offsets or a seal failure skip the job —
        // but a sequenced macOS job must still complete its slot (as Skip) so
        // the flow's sequence never stalls.
        if let Some(fsp) = fsp_seal {
            if fsp.aad_offset + FSP_HEADER_SIZE > fsp.plaintext_offset
                || fsp.plaintext_offset > wire_buf.len()
            {
                #[cfg(target_os = "macos")]
                if let Some(flow) = macos_flow.as_ref() {
                    push_mac_completion(
                        &mut macos_completions,
                        Arc::clone(flow),
                        macos_seq,
                        MacSendItem::Skip,
                    );
                }
                continue;
            }
            let mut nonce_bytes = [0u8; 12];
            nonce_bytes[4..12].copy_from_slice(&fsp.counter.to_le_bytes());
            let nonce = Nonce::assume_unique_for_key(nonce_bytes);
            let (prefix, plaintext_slice) = wire_buf.split_at_mut(fsp.plaintext_offset);
            let aad = &prefix[fsp.aad_offset..fsp.aad_offset + FSP_HEADER_SIZE];
            let tag =
                match fsp
                    .cipher
                    .seal_in_place_separate_tag(nonce, Aad::from(aad), plaintext_slice)
                {
                    Ok(tag) => tag,
                    Err(_) => {
                        #[cfg(target_os = "macos")]
                        if let Some(flow) = macos_flow.as_ref() {
                            push_mac_completion(
                                &mut macos_completions,
                                Arc::clone(flow),
                                macos_seq,
                                MacSendItem::Skip,
                            );
                        }
                        continue;
                    }
                };
            wire_buf.extend_from_slice(tag.as_ref());
        }
        // Outer FMP seal: established header is AAD, everything after it is
        // sealed in place, tag appended.
        let mut nonce_bytes = [0u8; 12];
        nonce_bytes[4..12].copy_from_slice(&counter.to_le_bytes());
        let nonce = Nonce::assume_unique_for_key(nonce_bytes);
        let (header_slice, plaintext_slice) = wire_buf.split_at_mut(ESTABLISHED_HEADER_SIZE);
        let tag = match cipher.seal_in_place_separate_tag(
            nonce,
            Aad::from(&*header_slice),
            plaintext_slice,
        ) {
            Ok(tag) => tag,
            Err(_) => {
                #[cfg(target_os = "macos")]
                if let Some(flow) = macos_flow {
                    push_mac_completion(&mut macos_completions, flow, macos_seq, MacSendItem::Skip);
                }
                continue;
            }
        };
        wire_buf.extend_from_slice(tag.as_ref());
        // Sequenced macOS jobs are handed to their flow thread instead of
        // being sent inline.
        #[cfg(target_os = "macos")]
        if let Some(flow) = macos_flow {
            push_mac_completion(
                &mut macos_completions,
                flow,
                macos_seq,
                MacSendItem::Packet {
                    packet: wire_buf,
                    drop_on_backpressure,
                },
            );
            continue;
        }
        // Direct path: coalesce with an existing group that matches on
        // destination and file descriptors.
        #[cfg(unix)]
        {
            let socket_fd = socket.as_raw_fd();
            #[cfg(any(target_os = "linux", target_os = "macos"))]
            let connected_fd = connected_socket.as_ref().map(|s| s.as_raw_fd());
            let matched = groups.iter_mut().position(|g| {
                if g.dest_addr != dest_addr {
                    return false;
                }
                if g.socket.as_raw_fd() != socket_fd {
                    return false;
                }
                #[cfg(any(target_os = "linux", target_os = "macos"))]
                {
                    if g.connected_socket.as_ref().map(|s| s.as_raw_fd()) != connected_fd {
                        return false;
                    }
                }
                true
            });
            if let Some(idx) = matched {
                groups[idx].wire_packets.push(wire_buf);
                // A group may drop on backpressure only if every member may.
                groups[idx].drop_on_backpressure &= drop_on_backpressure;
            } else {
                groups.push(EncryptedGroup {
                    socket,
                    #[cfg(any(target_os = "linux", target_os = "macos"))]
                    connected_socket,
                    dest_addr,
                    wire_packets: vec![wire_buf],
                    drop_on_backpressure,
                });
            }
        }
        #[cfg(not(unix))]
        {
            let _ = (socket, dest_addr, wire_buf, drop_on_backpressure);
        }
    }
    // Deliver all sequenced completions, one lock acquisition per flow.
    #[cfg(target_os = "macos")]
    for group in macos_completions {
        group.flow.complete_many(group.items);
    }
    drop(_t);
    let _t2 = crate::perf_profile::Timer::start(crate::perf_profile::Stage::UdpSend);
    // Linux: try UDP GSO first (one sendmsg for same-size packets), falling
    // back to sendmmsg; GSO is disabled for the whole process the first time
    // the kernel refuses it.
    #[cfg(target_os = "linux")]
    for group in groups {
        let mut backpressure = SendBackpressurePacer::default();
        let EncryptedGroup {
            socket,
            connected_socket,
            dest_addr,
            wire_packets,
            drop_on_backpressure: _,
        } = group;
        let (fd, connected) = match connected_socket.as_ref() {
            Some(s) => (s.as_raw_fd(), true),
            None => (socket.as_raw_fd(), false),
        };
        if !GSO_DISABLED.load(std::sync::atomic::Ordering::Relaxed)
            && gso_eligible_sizes(&wire_packets)
        {
            match send_batch_gso(fd, &wire_packets, dest_addr, connected) {
                Ok(()) => {
                    record_udp_send_path(connected, wire_packets.len() as u64);
                    continue;
                }
                Err(err)
                    if err.kind() == std::io::ErrorKind::InvalidInput
                        || err.raw_os_error() == Some(libc::EOPNOTSUPP)
                        || err.raw_os_error() == Some(libc::ENOPROTOOPT) =>
                {
                    GSO_DISABLED.store(true, std::sync::atomic::Ordering::Relaxed);
                    warn!(
                        error = %err,
                        "UDP_GSO refused by kernel; falling back to sendmmsg for life of process"
                    );
                }
                // Backpressure on the GSO attempt: fall through and let the
                // sendmmsg loop below pace and retry.
                Err(err) if is_send_backpressure(&err) => {
                }
                Err(err) => {
                    return Err(format!("sendmsg+UDP_GSO failed: {err}").into());
                }
            }
        }
        let mut sent = 0usize;
        while sent < wire_packets.len() {
            let n = match send_batch_raw(fd, &wire_packets[sent..], dest_addr, connected) {
                Ok(n) => n,
                Err(err) if is_send_backpressure(&err) => {
                    // NOTE(review): `pause()` returns true when the configured
                    // drop threshold is hit, but this loop ignores that and
                    // keeps retrying — confirm that is intended on Linux
                    // (drop_after defaults to 0 here).
                    backpressure.pause(&err);
                    continue;
                }
                Err(err) => {
                    return Err(format!("sendmmsg(2) failed: {err}").into());
                }
            };
            if n == 0 {
                break;
            }
            sent += n;
            backpressure.record_success();
            record_udp_send_path(connected, n as u64);
        }
    }
    // Other Unix: send one packet at a time with backpressure pacing.
    #[cfg(all(unix, not(target_os = "linux")))]
    for group in groups {
        let mut backpressure = SendBackpressurePacer::default();
        #[cfg(target_os = "macos")]
        let (fd, connected) = match group.connected_socket.as_ref() {
            Some(s) => (s.as_raw_fd(), true),
            None => (group.socket.as_raw_fd(), false),
        };
        #[cfg(not(target_os = "macos"))]
        let (fd, connected) = (group.socket.as_raw_fd(), false);
        for data in &group.wire_packets {
            if let Err(err) = send_one_with_backpressure(
                fd,
                connected,
                &group.dest_addr,
                data,
                &mut backpressure,
                group.drop_on_backpressure,
            ) {
                if group.drop_on_backpressure && is_send_backpressure(&err) {
                    continue;
                }
                return Err(format!("sendto failed: {err}").into());
            }
        }
    }
    Ok(())
}
/// Test helper: wraps raw jobs in the direct (unsequenced) queue form and
/// runs them through the shared flush path.
#[cfg(all(test, unix))]
fn flush_direct_batch_sync(
    batch: &mut Vec<FmpSendJob>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let mut wrapped = Vec::with_capacity(batch.len());
    for job in batch.drain(..) {
        wrapped.push(QueuedFmpSendJob::direct(job));
    }
    flush_batch_sync(&mut wrapped)
}
fn record_udp_send_path(connected: bool, count: u64) {
let event = if connected {
crate::perf_profile::Event::UdpSendConnected
} else {
crate::perf_profile::Event::UdpSendWildcard
};
crate::perf_profile::record_event_count(event, count);
}
fn is_send_backpressure(err: &std::io::Error) -> bool {
err.kind() == std::io::ErrorKind::WouldBlock
|| err.raw_os_error().is_some_and(raw_send_backpressure_code)
}
// Platform-specific "kernel send buffers exhausted" errno checks.
#[cfg(unix)]
fn raw_send_backpressure_code(code: i32) -> bool {
    code == libc::ENOBUFS || code == libc::ENOMEM
}
#[cfg(windows)]
fn raw_send_backpressure_code(code: i32) -> bool {
    // Winsock equivalents of ENOBUFS/ENOMEM.
    const WSAENOBUFS: i32 = 10055;
    const ERROR_NOT_ENOUGH_MEMORY: i32 = 8;
    code == WSAENOBUFS || code == ERROR_NOT_ENOUGH_MEMORY
}
#[cfg(not(any(unix, windows)))]
fn raw_send_backpressure_code(_code: i32) -> bool {
    false
}
/// Tracks consecutive full-queue send failures and decides between yielding,
/// sleeping, or (when configured) dropping the packet.
#[derive(Default)]
struct SendBackpressurePacer {
    // Full events since the last successful send; drives the drop threshold.
    consecutive_full: u32,
    // Full events since the last sleep; drives the sleep threshold.
    full_since_sleep: u32,
}
impl SendBackpressurePacer {
    // A send went through: reset both failure streaks.
    fn record_success(&mut self) {
        self.consecutive_full = 0;
        self.full_since_sleep = 0;
    }
    /// Called after a backpressure error. Returns `true` when the caller
    /// should drop the packet (configured drop threshold reached); otherwise
    /// yields or sleeps briefly and returns `false` so the caller retries.
    fn pause(&mut self, err: &std::io::Error) -> bool {
        crate::perf_profile::record_event(crate::perf_profile::Event::UdpSendBackpressure);
        // WouldBlock means the socket is non-blocking and momentarily full:
        // just yield, and do not count it toward drop/sleep thresholds.
        if err.kind() == std::io::ErrorKind::WouldBlock {
            self.consecutive_full = 0;
            self.full_since_sleep = 0;
            std::thread::yield_now();
            return false;
        }
        static SEND_BACKPRESSURE_COUNT: std::sync::atomic::AtomicU64 =
            std::sync::atomic::AtomicU64::new(0);
        let n = SEND_BACKPRESSURE_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        // Rate-limited warning: first 8 events, then every 100k.
        if n < 8 || n.is_multiple_of(100_000) {
            warn!(
                error = %err,
                events = n + 1,
                "UDP send queue full; applying kernel backpressure"
            );
        }
        self.consecutive_full = self.consecutive_full.saturating_add(1);
        self.full_since_sleep = self.full_since_sleep.saturating_add(1);
        let drop_after = send_backpressure_drop_after();
        if drop_after > 0 && self.consecutive_full >= drop_after {
            self.consecutive_full = 0;
            self.full_since_sleep = 0;
            return true;
        }
        let sleep_after = send_backpressure_sleep_after();
        if sleep_after > 0 && self.full_since_sleep >= sleep_after {
            self.full_since_sleep = 0;
            crate::perf_profile::record_event(crate::perf_profile::Event::UdpSendBackpressureSleep);
            std::thread::sleep(std::time::Duration::from_micros(
                send_backpressure_sleep_micros(),
            ));
        } else {
            std::thread::yield_now();
        }
        false
    }
}
fn send_backpressure_sleep_after() -> u32 {
static VALUE: OnceLock<u32> = OnceLock::new();
*VALUE.get_or_init(|| {
std::env::var("FIPS_SEND_BACKPRESSURE_SLEEP_AFTER")
.ok()
.and_then(|raw| raw.trim().parse::<u32>().ok())
.unwrap_or(default_send_backpressure_sleep_after())
})
}
fn send_backpressure_sleep_micros() -> u64 {
static VALUE: OnceLock<u64> = OnceLock::new();
*VALUE.get_or_init(|| {
std::env::var("FIPS_SEND_BACKPRESSURE_SLEEP_MICROS")
.ok()
.and_then(|raw| raw.trim().parse::<u64>().ok())
.unwrap_or(default_send_backpressure_sleep_micros())
.max(1)
})
}
fn send_backpressure_drop_after() -> u32 {
static VALUE: OnceLock<u32> = OnceLock::new();
*VALUE.get_or_init(|| {
std::env::var("FIPS_SEND_BACKPRESSURE_DROP_AFTER")
.ok()
.and_then(|raw| raw.trim().parse::<u32>().ok())
.unwrap_or(default_send_backpressure_drop_after())
})
}
// Platform defaults for the backpressure tunables above. macOS gets sleeping
// and dropping enabled; other platforms default both off (yield-only pacing).
#[cfg(target_os = "macos")]
fn default_send_backpressure_sleep_after() -> u32 {
    4
}
#[cfg(not(target_os = "macos"))]
fn default_send_backpressure_sleep_after() -> u32 {
    0
}
#[cfg(target_os = "macos")]
fn default_send_backpressure_sleep_micros() -> u64 {
    100
}
#[cfg(not(target_os = "macos"))]
fn default_send_backpressure_sleep_micros() -> u64 {
    1
}
#[cfg(target_os = "macos")]
fn default_send_backpressure_drop_after() -> u32 {
    256
}
#[cfg(not(target_os = "macos"))]
fn default_send_backpressure_drop_after() -> u32 {
    0
}
/// Counts packets dropped due to persistent send-queue backpressure and logs
/// a rate-limited warning (first 8 drops, then every 100k).
#[cfg(all(unix, not(target_os = "linux")))]
fn record_udp_send_backpressure_drop(err: &std::io::Error) {
    static SEND_BACKPRESSURE_DROP_COUNT: std::sync::atomic::AtomicU64 =
        std::sync::atomic::AtomicU64::new(0);
    let drops_so_far =
        SEND_BACKPRESSURE_DROP_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let should_log = drops_so_far < 8 || drops_so_far.is_multiple_of(100_000);
    if should_log {
        warn!(
            error = %err,
            drops = drops_so_far + 1,
            "UDP send queue full; dropping bulk data packet"
        );
    }
}
/// Token-bucket rate limiter for the macOS ordered send thread (disabled when
/// `bytes_per_sec` is 0).
#[cfg(target_os = "macos")]
struct MacSendRatePacer {
    // Refill rate; 0 disables pacing entirely.
    bytes_per_sec: f64,
    // Bucket capacity (maximum burst).
    burst_bytes: f64,
    // Current bucket level.
    credit_bytes: f64,
    // Last refill time.
    last: std::time::Instant,
}
#[cfg(target_os = "macos")]
impl Default for MacSendRatePacer {
    // Reads the rate (Mbps) and burst size from env vars; a missing, zero,
    // negative, or non-finite rate disables pacing. The bucket starts full.
    fn default() -> Self {
        let mbps = std::env::var("FIPS_MACOS_SEND_PACE_MBPS")
            .ok()
            .and_then(|raw| raw.trim().parse::<f64>().ok())
            .unwrap_or(0.0);
        let bytes_per_sec = if mbps.is_finite() && mbps > 0.0 {
            // Megabits per second -> bytes per second.
            mbps * 1_000_000.0 / 8.0
        } else {
            0.0
        };
        let burst_bytes = std::env::var("FIPS_MACOS_SEND_PACE_BURST_BYTES")
            .ok()
            .and_then(|raw| raw.trim().parse::<f64>().ok())
            .filter(|value| value.is_finite() && *value > 0.0)
            .unwrap_or(64.0 * 1024.0);
        Self {
            bytes_per_sec,
            burst_bytes,
            credit_bytes: burst_bytes,
            last: std::time::Instant::now(),
        }
    }
}
#[cfg(target_os = "macos")]
impl MacSendRatePacer {
    /// Blocks just long enough to keep the send rate under `bytes_per_sec`.
    /// No-op when pacing is disabled or the packet is empty.
    fn pace(&mut self, bytes: usize) {
        if self.bytes_per_sec <= 0.0 || bytes == 0 {
            return;
        }
        let needed = bytes as f64;
        // Refill the bucket for the time elapsed since the last call, capped
        // at the burst size.
        let now = std::time::Instant::now();
        let elapsed = now.saturating_duration_since(self.last).as_secs_f64();
        self.credit_bytes =
            (self.credit_bytes + elapsed * self.bytes_per_sec).min(self.burst_bytes);
        self.last = now;
        if self.credit_bytes >= needed {
            self.credit_bytes -= needed;
            return;
        }
        // Not enough credit: compute how long the deficit takes to refill,
        // sleep most of it, then spin the final stretch for precision (sleep
        // granularity is coarser than the spin window).
        let wait_secs = (needed - self.credit_bytes) / self.bytes_per_sec;
        self.credit_bytes = 0.0;
        let deadline = now + std::time::Duration::from_secs_f64(wait_secs);
        let spin_window = std::time::Duration::from_micros(75);
        loop {
            let now = std::time::Instant::now();
            if now >= deadline {
                self.last = now;
                break;
            }
            let remaining = deadline - now;
            if remaining > spin_window {
                std::thread::sleep(remaining - spin_window);
            } else {
                std::hint::spin_loop();
            }
        }
    }
}
// Set once the kernel refuses UDP_SEGMENT; all later sends skip GSO for the
// life of the process.
#[cfg(target_os = "linux")]
static GSO_DISABLED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
/// Whether a packet batch can be sent as one UDP_GSO sendmsg: at least two
/// packets, all but the last exactly the (non-zero) segment size, and the
/// last no larger than the segment size.
#[cfg(target_os = "linux")]
fn gso_eligible_sizes(packets: &[Vec<u8>]) -> bool {
    if packets.len() < 2 {
        return false;
    }
    // Safe split: len >= 2, so `head` is non-empty.
    let (last, head) = match packets.split_last() {
        Some(pair) => pair,
        None => return false,
    };
    let seg = head[0].len();
    seg != 0 && head.iter().all(|p| p.len() == seg) && last.len() <= seg
}
/// Sends up to 64 equal-size packets in one sendmsg(2) call using the
/// UDP_SEGMENT (GSO) control message; the kernel splits the gathered buffer
/// into `seg_size` datagrams. Caller must have checked `gso_eligible_sizes`.
#[cfg(target_os = "linux")]
fn send_batch_gso(
    fd: std::os::unix::io::RawFd,
    packets: &[Vec<u8>],
    dest: SocketAddr,
    connected: bool,
) -> std::io::Result<()> {
    debug_assert!(!packets.is_empty());
    const MAX_BATCH: usize = 64;
    let n = packets.len().min(MAX_BATCH);
    if n == 0 {
        return Ok(());
    }
    let seg_size = packets[0].len() as u16;
    let sa: socket2::SockAddr = dest.into();
    // One iovec per packet; the kernel treats the gather list as a single
    // payload that it re-segments at seg_size boundaries.
    let mut iovs: [libc::iovec; MAX_BATCH] = unsafe { std::mem::zeroed() };
    for (i, data) in packets[..n].iter().enumerate() {
        iovs[i].iov_base = data.as_ptr() as *mut libc::c_void;
        iovs[i].iov_len = data.len();
    }
    // Destination address is only supplied for unconnected sockets.
    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
    let sa_len = sa.len();
    if !connected {
        unsafe {
            std::ptr::copy_nonoverlapping(
                sa.as_ptr() as *const u8,
                &mut storage as *mut _ as *mut u8,
                sa_len as usize,
            );
        }
    }
    let cmsg_space = unsafe { libc::CMSG_SPACE(std::mem::size_of::<u16>() as u32) as usize };
    let mut cmsg_buf = [0u8; 64];
    debug_assert!(cmsg_space <= cmsg_buf.len());
    let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
    if connected {
        msg.msg_name = std::ptr::null_mut();
        msg.msg_namelen = 0;
    } else {
        msg.msg_name = &mut storage as *mut _ as *mut libc::c_void;
        msg.msg_namelen = sa_len;
    }
    msg.msg_iov = iovs.as_mut_ptr();
    msg.msg_iovlen = n as _;
    msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void;
    msg.msg_controllen = cmsg_space as _;
    // Attach the UDP_SEGMENT cmsg carrying the segment size.
    unsafe {
        let cmsg = libc::CMSG_FIRSTHDR(&msg);
        if cmsg.is_null() {
            return Err(std::io::Error::other("CMSG_FIRSTHDR returned null"));
        }
        (*cmsg).cmsg_level = libc::IPPROTO_UDP as _;
        (*cmsg).cmsg_type = libc::UDP_SEGMENT as _;
        (*cmsg).cmsg_len = libc::CMSG_LEN(std::mem::size_of::<u16>() as u32) as _;
        let data = libc::CMSG_DATA(cmsg) as *mut u16;
        *data = seg_size;
    }
    let r = unsafe { libc::sendmsg(fd, &msg, 0) };
    if r < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(())
    }
}
/// Sends up to 32 packets with one sendmmsg(2) call, one datagram per packet.
/// Returns how many packets the kernel accepted (may be fewer than requested;
/// the caller loops over the remainder).
#[cfg(target_os = "linux")]
fn send_batch_raw(
    fd: std::os::unix::io::RawFd,
    packets: &[Vec<u8>],
    dest: SocketAddr,
    connected: bool,
) -> std::io::Result<usize> {
    const MAX_BATCH: usize = 32;
    let n = packets.len().min(MAX_BATCH);
    if n == 0 {
        return Ok(0);
    }
    let mut iovs: [libc::iovec; MAX_BATCH] = unsafe { std::mem::zeroed() };
    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
    let mut storage_len: libc::socklen_t = 0;
    let mut msgs: [libc::mmsghdr; MAX_BATCH] = unsafe { std::mem::zeroed() };
    // A single sockaddr_storage is shared by every message header (same
    // destination); it is only populated for unconnected sockets.
    if !connected {
        let sa: socket2::SockAddr = dest.into();
        let sa_len = sa.len();
        unsafe {
            std::ptr::copy_nonoverlapping(
                sa.as_ptr() as *const u8,
                &mut storage as *mut _ as *mut u8,
                sa_len as usize,
            );
        }
        storage_len = sa_len;
    }
    for i in 0..n {
        let data = &packets[i];
        iovs[i].iov_base = data.as_ptr() as *mut libc::c_void;
        iovs[i].iov_len = data.len();
        msgs[i].msg_hdr.msg_iov = &mut iovs[i];
        msgs[i].msg_hdr.msg_iovlen = 1 as _;
        if connected {
            msgs[i].msg_hdr.msg_name = std::ptr::null_mut();
            msgs[i].msg_hdr.msg_namelen = 0;
        } else {
            msgs[i].msg_hdr.msg_name = &mut storage as *mut _ as *mut libc::c_void;
            msgs[i].msg_hdr.msg_namelen = storage_len;
        }
    }
    let r = unsafe { libc::sendmmsg(fd, msgs.as_mut_ptr(), n as libc::c_uint, 0) };
    if r < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(r as usize)
    }
}
#[cfg(all(test, unix))]
mod unix_tests {
    use super::*;
    use crate::transport::udp::socket::UdpRawSocket;
    use ring::aead::{LessSafeKey, UnboundKey};
    use std::net::UdpSocket;

    /// ChaCha20-Poly1305 key whose 32 key bytes are all `byte`.
    fn test_cipher(byte: u8) -> LessSafeKey {
        let unbound =
            UnboundKey::new(&ring::aead::CHACHA20_POLY1305, &[byte; 32]).expect("build key");
        LessSafeKey::new(unbound)
    }

    /// Verifies seal ordering: the FSP (inner) layer must be sealed before the
    /// FMP (outer) layer, so a receiver can peel outer first, then inner.
    #[test]
    fn fsp_preseal_runs_before_outer_fmp_seal() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .build()
            .expect("tokio rt");
        rt.block_on(async {
            let receiver = UdpSocket::bind("127.0.0.1:0").expect("bind recv");
            receiver
                .set_read_timeout(Some(std::time::Duration::from_millis(500)))
                .expect("set_read_timeout");
            let dest = receiver.local_addr().expect("recv local_addr");
            let sender = UdpRawSocket::open("127.0.0.1:0".parse().unwrap(), 1 << 20, 1 << 20)
                .expect("open send socket")
                .into_async()
                .expect("into_async");

            let outer_key = test_cipher(1);
            let inner_key = test_cipher(2);
            let outer_ctr = 11;
            let inner_ctr = 22;
            let fmp_header = [0xA5; ESTABLISHED_HEADER_SIZE];
            let fsp_header = [0x5A; FSP_HEADER_SIZE];
            let fsp_plaintext = b"inner payload";

            // Final wire size: both headers, the payload, plus one AEAD tag
            // per seal layer (inner FSP + outer FMP).
            let expected_wire_len = ESTABLISHED_HEADER_SIZE
                + FSP_HEADER_SIZE
                + fsp_plaintext.len()
                + 2 * crate::noise::TAG_SIZE;

            // Lay out [fmp_header | fsp_header | plaintext], recording where
            // the inner AAD and plaintext start so sealing can run in place.
            let mut wire_buf = Vec::with_capacity(expected_wire_len);
            wire_buf.extend_from_slice(&fmp_header);
            let fsp_aad_offset = wire_buf.len();
            wire_buf.extend_from_slice(&fsp_header);
            let fsp_plaintext_offset = wire_buf.len();
            wire_buf.extend_from_slice(fsp_plaintext);

            let mut batch = vec![FmpSendJob {
                cipher: outer_key.clone(),
                counter: outer_ctr,
                wire_buf,
                fsp_seal: Some(FspSealJob {
                    cipher: inner_key.clone(),
                    counter: inner_ctr,
                    aad_offset: fsp_aad_offset,
                    plaintext_offset: fsp_plaintext_offset,
                }),
                socket: sender,
                dest_addr: dest,
                #[cfg(any(target_os = "linux", target_os = "macos"))]
                connected_socket: None,
                drop_on_backpressure: true,
                queued_at: None,
            }];
            flush_direct_batch_sync(&mut batch).expect("flush ok");
            assert!(batch.is_empty(), "flush must drain the batch");

            let mut buf = [0u8; 256];
            let (len, _) = receiver.recv_from(&mut buf).expect("recv");
            assert_eq!(len, expected_wire_len);
            assert_eq!(&buf[..ESTABLISHED_HEADER_SIZE], &fmp_header);

            // Peel the outer FMP layer first, then the inner FSP layer.
            let outer_plaintext = crate::noise::open(
                Some(&outer_key),
                outer_ctr,
                &fmp_header,
                &buf[ESTABLISHED_HEADER_SIZE..len],
            )
            .expect("outer open");
            assert_eq!(&outer_plaintext[..FSP_HEADER_SIZE], &fsp_header);
            let inner_plaintext = crate::noise::open(
                Some(&inner_key),
                inner_ctr,
                &outer_plaintext[..FSP_HEADER_SIZE],
                &outer_plaintext[FSP_HEADER_SIZE..],
            )
            .expect("inner open");
            assert_eq!(inner_plaintext, fsp_plaintext);
        });
    }
}
#[cfg(all(test, target_os = "linux"))]
mod tests {
    use super::*;

    /// Zero-filled packet of `bytes` bytes.
    fn pkt(bytes: usize) -> Vec<u8> {
        vec![0u8; bytes]
    }

    #[test]
    fn gso_eligible_rejects_single_packet() {
        // A lone packet gains nothing from GSO.
        assert!(!gso_eligible_sizes(&[pkt(1500)]));
    }

    #[test]
    fn gso_eligible_accepts_uniform_batch() {
        let batch = vec![pkt(1500); 18];
        assert!(gso_eligible_sizes(&batch));
    }

    #[test]
    fn gso_eligible_accepts_short_trailer() {
        // Uniform segments plus one short final segment is valid GSO shape.
        let mut batch = vec![pkt(1500); 18];
        batch.push(pkt(900));
        assert!(gso_eligible_sizes(&batch));
    }

    #[test]
    fn gso_eligible_rejects_mixed_sizes() {
        // A short segment anywhere but the tail disqualifies the batch.
        let mut batch = vec![pkt(1500); 18];
        batch[3] = pkt(800);
        batch.push(pkt(1500));
        assert!(!gso_eligible_sizes(&batch));
    }

    #[test]
    fn gso_roundtrip_loopback() {
        use std::net::UdpSocket;
        use std::os::unix::io::AsRawFd;
        const SEG: usize = 200;
        const N: usize = 18;
        let recv_sock = UdpSocket::bind("127.0.0.1:0").expect("bind recv");
        let recv_addr = recv_sock.local_addr().expect("recv local_addr");
        recv_sock
            .set_read_timeout(Some(std::time::Duration::from_millis(500)))
            .expect("set_read_timeout");
        let send_sock = UdpSocket::bind("127.0.0.1:0").expect("bind send");
        // Stamp byte 0 of each segment so the receiver can check ordering.
        let batch: Vec<Vec<u8>> = (0..N)
            .map(|i| {
                let mut seg = vec![0u8; SEG];
                seg[0] = i as u8;
                seg
            })
            .collect();
        match send_batch_gso(send_sock.as_raw_fd(), &batch, recv_addr, false) {
            Ok(()) => {}
            // Older kernels reject UDP_SEGMENT; treat that as a skip, not a failure.
            Err(err)
                if err.raw_os_error() == Some(libc::EOPNOTSUPP)
                    || err.raw_os_error() == Some(libc::ENOPROTOOPT)
                    || err.kind() == std::io::ErrorKind::InvalidInput =>
            {
                eprintln!(
                    "gso_roundtrip_loopback: kernel doesn't support UDP_GSO ({err}); skipping"
                );
                return;
            }
            Err(err) => panic!("send_batch_gso failed: {err}"),
        }
        let mut recv_buf = [0u8; SEG + 32];
        for i in 0..N {
            let (len, _from) = recv_sock
                .recv_from(&mut recv_buf)
                .unwrap_or_else(|e| panic!("recv {i}: {e}"));
            assert_eq!(len, SEG, "datagram {i} has wrong length");
            assert_eq!(
                recv_buf[0], i as u8,
                "datagram {i} arrived out of order or with wrong stamp"
            );
        }
    }

    #[test]
    fn sendmmsg_uniform_dest_roundtrip() {
        use std::net::UdpSocket;
        use std::os::unix::io::AsRawFd;
        let recv_sock = UdpSocket::bind("127.0.0.1:0").expect("bind recv");
        let recv_addr = recv_sock.local_addr().unwrap();
        recv_sock
            .set_read_timeout(Some(std::time::Duration::from_millis(500)))
            .expect("set_read_timeout");
        let send_sock = UdpSocket::bind("127.0.0.1:0").expect("bind send");
        send_sock.set_nonblocking(true).unwrap();
        // Four 16-byte packets, each stamped with its index in byte 0.
        let mut packets: Vec<Vec<u8>> = Vec::with_capacity(4);
        for i in 0..4u8 {
            let mut v = vec![0u8; 16];
            v[0] = i;
            packets.push(v);
        }
        let n =
            send_batch_raw(send_sock.as_raw_fd(), &packets, recv_addr, false).expect("sendmmsg ok");
        assert_eq!(n, 4);
        let mut buf = [0u8; 64];
        let mut stamps: Vec<u8> = Vec::new();
        for _ in 0..4 {
            let (len, _) = recv_sock.recv_from(&mut buf).expect("recv");
            assert_eq!(len, 16);
            stamps.push(buf[0]);
        }
        // Compare as a multiset — delivery order is not asserted.
        stamps.sort();
        assert_eq!(stamps, vec![0, 1, 2, 3]);
    }

    #[test]
    fn flush_batch_routes_each_target_separately() {
        use crate::transport::udp::socket::UdpRawSocket;
        use ring::aead::{LessSafeKey, UnboundKey};
        use std::net::UdpSocket;

        const A_PLAINTEXT: usize = 32;
        const B_PLAINTEXT: usize = 64;
        // Wire layout per packet: 16-byte header + plaintext + 16-byte tag.
        const A_WIRE: usize = 16 + A_PLAINTEXT + 16;
        const B_WIRE: usize = 16 + B_PLAINTEXT + 16;

        /// Builds a send job with a zeroed 16-byte header and
        /// `plaintext_size` zero bytes of payload, addressed to `dest`.
        fn make_job(
            socket: crate::transport::udp::socket::AsyncUdpSocket,
            cipher: &LessSafeKey,
            counter: u64,
            dest: SocketAddr,
            plaintext_size: usize,
        ) -> FmpSendJob {
            let mut wire_buf = Vec::with_capacity(16 + plaintext_size + 16);
            wire_buf.extend_from_slice(&[0u8; 16]);
            wire_buf.extend_from_slice(&vec![0u8; plaintext_size]);
            FmpSendJob {
                cipher: cipher.clone(),
                counter,
                wire_buf,
                fsp_seal: None,
                socket,
                dest_addr: dest,
                #[cfg(any(target_os = "linux", target_os = "macos"))]
                connected_socket: None,
                drop_on_backpressure: true,
                queued_at: None,
            }
        }

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .build()
            .expect("tokio rt");
        rt.block_on(async {
            let recv_a = UdpSocket::bind("127.0.0.1:0").expect("bind recv_a");
            let recv_b = UdpSocket::bind("127.0.0.1:0").expect("bind recv_b");
            for s in [&recv_a, &recv_b] {
                s.set_read_timeout(Some(std::time::Duration::from_millis(500)))
                    .expect("set_read_timeout");
            }
            let addr_a = recv_a.local_addr().unwrap();
            let addr_b = recv_b.local_addr().unwrap();
            let send_sock = UdpRawSocket::open("127.0.0.1:0".parse().unwrap(), 1 << 20, 1 << 20)
                .expect("open send socket")
                .into_async()
                .expect("into_async");
            let unbound = UnboundKey::new(&ring::aead::CHACHA20_POLY1305, &[0u8; 32])
                .expect("build unbound key");
            let cipher = LessSafeKey::new(unbound);

            // Interleave destinations: A, B, A — the flush must route each
            // job to its own target.
            let mut batch = vec![
                make_job(send_sock.clone(), &cipher, 1, addr_a, A_PLAINTEXT),
                make_job(send_sock.clone(), &cipher, 2, addr_b, B_PLAINTEXT),
                make_job(send_sock.clone(), &cipher, 3, addr_a, A_PLAINTEXT),
            ];
            flush_direct_batch_sync(&mut batch).expect("flush ok");
            assert!(batch.is_empty(), "flush must drain the batch");

            let mut buf = [0u8; 256];
            for i in 0..2 {
                let (len, _) = recv_a.recv_from(&mut buf).expect("recv_a");
                assert_eq!(
                    len, A_WIRE,
                    "recv_a packet {i} has wrong length: got {len}, expected {A_WIRE}"
                );
            }
            let (len, _) = recv_b.recv_from(&mut buf).expect("recv_b");
            assert_eq!(
                len, B_WIRE,
                "recv_b packet has wrong length: got {len}, expected {B_WIRE}"
            );

            // Neither receiver should see anything beyond its own packets.
            for (name, sock) in [("recv_a", &recv_a), ("recv_b", &recv_b)] {
                sock.set_read_timeout(Some(std::time::Duration::from_millis(50)))
                    .unwrap();
                let leftover = sock.recv_from(&mut buf);
                assert!(
                    leftover.is_err(),
                    "{name} got unexpected extra packet: {:?}",
                    leftover
                );
            }
        });
    }
}
#[cfg(all(unix, not(target_os = "linux")))]
/// Thin wrapper over `send(2)` for an already-connected UDP socket.
/// Returns the number of bytes sent, or the raw OS error.
fn send_connected_raw(fd: std::os::unix::io::RawFd, data: &[u8]) -> std::io::Result<usize> {
    // SAFETY: `data` is a live, valid buffer for the duration of the syscall.
    let sent = unsafe { libc::send(fd, data.as_ptr() as *const libc::c_void, data.len(), 0) };
    if sent < 0 {
        return Err(std::io::Error::last_os_error());
    }
    Ok(sent as usize)
}
#[cfg(all(unix, not(target_os = "linux")))]
/// Sends one datagram, retrying through transient socket backpressure.
///
/// A non-backpressure error aborts immediately. On a backpressure error the
/// pacer's `pause` is consulted; when it returns true AND the caller opted
/// into `drop_on_backpressure`, the datagram is dropped (recorded as a
/// backpressure drop) and the error is surfaced. Otherwise the send is
/// retried.
///
/// NOTE(review): the exact meaning of `pause`'s boolean return (saturation
/// threshold reached?) is defined by `SendBackpressurePacer` — confirm there.
fn send_one_with_backpressure(
    fd: std::os::unix::io::RawFd,
    connected: bool,
    dest: &SocketAddr,
    data: &[u8],
    backpressure: &mut SendBackpressurePacer,
    drop_on_backpressure: bool,
) -> std::io::Result<()> {
    loop {
        let attempt = if connected {
            send_connected_raw(fd, data)
        } else {
            send_one_raw(fd, data, dest)
        };
        // Extract the error, or finish on success.
        let err = match attempt {
            Ok(_) => {
                backpressure.record_success();
                record_udp_send_path(connected, 1);
                return Ok(());
            }
            Err(e) => e,
        };
        if !is_send_backpressure(&err) {
            return Err(err);
        }
        if backpressure.pause(&err) && drop_on_backpressure {
            record_udp_send_backpressure_drop(&err);
            return Err(err);
        }
        // Backpressure but not dropping: loop and retry the send.
    }
}
#[cfg(all(unix, not(target_os = "linux")))]
/// Thin wrapper over `sendto(2)` for an unconnected UDP socket: sends `data`
/// to `dest`, returning bytes sent or the raw OS error.
fn send_one_raw(
    fd: std::os::unix::io::RawFd,
    data: &[u8],
    dest: &SocketAddr,
) -> std::io::Result<usize> {
    let sa: socket2::SockAddr = (*dest).into();
    // SAFETY: `data` and `sa` are live, valid buffers for the duration of the
    // syscall; `sa.len()` matches the address bytes `sa` holds.
    let sent = unsafe {
        libc::sendto(
            fd,
            data.as_ptr() as *const libc::c_void,
            data.len(),
            0,
            sa.as_ptr() as *const libc::sockaddr,
            sa.len(),
        )
    };
    if sent < 0 {
        return Err(std::io::Error::last_os_error());
    }
    Ok(sent as usize)
}