use std::{
cell::RefCell,
collections::{BTreeMap, HashMap, VecDeque, hash_map::Entry},
io::ErrorKind,
net::SocketAddr,
os::unix::io::AsRawFd,
rc::Rc,
time::{Duration, Instant},
};
use mio::{Interest, Registry, Token, net::UdpSocket, unix::SourceFd};
use sozu_command::{
logging::ansi_palette,
proto::command::{
Cluster, LoadBalancingAlgorithms, LoadMetric, RequestUdpFrontend, UdpAffinityKey,
UdpListenerConfig, UpdateUdpListenerConfig, WorkerRequest, WorkerResponse,
request::RequestType,
},
};
use crate::metrics::names;
use crate::{
CachedTags, ListenerError, ListenerHandler, Protocol, ProxyError, ProxySession,
SessionIsToBeClosed,
backends::BackendMap,
pool::Pool,
protocol::udp::{
CloseReason, ClusterConfig, ConfigEvent, DropReason, FlowId, ManagerInput, MetricEvent,
Output, UdpManager,
},
server::{SessionManager, TIMER},
socket::{udp_bind, udp_connect},
sozu_command::{ready::Ready, state::ClusterId},
};
mod health;
pub use health::UdpHealthChecker;
/// Log prefix for a listener-level message: colored `UDP Listener(token=…, address=…)`.
/// Expects `$self` to expose `listener_token` and `address`.
macro_rules! log_context {
    ($self:expr) => {{
        let (open, reset, grey, gray, white) = ansi_palette();
        // Palette entries are captured implicitly; only the `$self` fields need naming.
        format!(
            "[- - - -]\t{open}UDP{reset}\t{grey}Listener{reset}({gray}token{reset}={white}{token}{reset}, {gray}address{reset}={white}{address}{reset})\t >>>",
            token = $self.listener_token.0,
            address = $self.address,
        )
    }};
}
/// Log prefix for module-level messages that are not tied to one listener.
macro_rules! log_module_context {
    () => {{
        let (open, reset, _grey, _gray, _white) = sozu_command::logging::ansi_palette();
        format!("{open}UDP{reset}\t >>>")
    }};
}
/// Log prefix for a single flow: colored `UDP-FLOW Flow(id=…, client=…, backend=…)`.
/// `$backend` is printed with `Debug` formatting.
macro_rules! log_flow_context {
    ($flow:expr, $client:expr, $backend:expr) => {{
        let (open, reset, grey, gray, white) = sozu_command::logging::ansi_palette();
        format!(
            "[- - - -]\t{open}UDP-FLOW{reset}\t{grey}Flow{reset}({gray}id{reset}={white}{id}{reset}, {gray}client{reset}={white}{client}{reset}, {gray}backend{reset}={white}{backend:?}{reset})\t >>>",
            id = $flow,
            client = $client,
            backend = $backend,
        )
    }};
}
/// Maximum datagrams buffered per upstream socket while it reports `WouldBlock`.
const UPSTREAM_WRITE_QUEUE_CAP: usize = 1 << 6;
/// Maximum datagrams buffered on the shared listener socket while it reports `WouldBlock`.
const CLIENT_WRITE_QUEUE_CAP: usize = 1 << 8;
/// Result of one send attempt while draining a [`WriteQueue`].
enum SendOutcome {
    /// The datagram went out; it is removed from the queue.
    Sent,
    /// The socket is full; draining stops and the datagram stays queued.
    WouldBlock,
    /// The send failed for another reason; the datagram is discarded.
    Dropped,
}
/// Bounded FIFO of `(destination, payload)` datagrams awaiting a writable socket.
struct WriteQueue {
    /// Pending datagrams, oldest first.
    queue: VecDeque<(SocketAddr, Vec<u8>)>,
    /// Maximum number of entries `push` accepts.
    cap: usize,
}
impl WriteQueue {
    /// Creates an empty queue that accepts at most `cap` datagrams.
    fn new(cap: usize) -> Self {
        Self {
            cap,
            queue: VecDeque::new(),
        }
    }

    /// Whether no datagram is pending.
    fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    /// Number of pending datagrams (test-only helper).
    #[cfg(test)]
    fn len(&self) -> usize {
        self.queue.len()
    }

    /// Appends a datagram. Returns `false` (and drops nothing already queued)
    /// when the queue is at capacity; the caller owns the rejected payload's fate.
    #[must_use]
    fn push(&mut self, dst: SocketAddr, payload: Vec<u8>) -> bool {
        let has_room = self.queue.len() < self.cap;
        if has_room {
            self.queue.push_back((dst, payload));
            debug_assert!(
                self.queue.len() <= self.cap,
                "WriteQueue overran its cap: len {} > cap {}",
                self.queue.len(),
                self.cap,
            );
        }
        has_room
    }

    /// Sends queued datagrams front-to-back through `send`.
    ///
    /// `Sent` and `Dropped` entries are removed; `WouldBlock` stops the drain and
    /// keeps that entry at the front. Returns whether the queue ended up empty.
    fn drain<F: FnMut(&SocketAddr, &[u8]) -> SendOutcome>(&mut self, mut send: F) -> bool {
        loop {
            let Some((dst, payload)) = self.queue.front() else {
                break;
            };
            if matches!(send(dst, payload), SendOutcome::WouldBlock) {
                break;
            }
            self.queue.pop_front();
        }
        self.queue.is_empty()
    }
}
/// A UDP listening socket plus its configuration and front-end metadata.
pub struct UdpListener {
    /// Key of this listener in the proxy's listener map.
    token: Token,
    /// Bound address, copied from `config.address`.
    address: SocketAddr,
    /// Listener configuration (timeouts, datagram size, flow cap, …).
    config: UdpListenerConfig,
    /// The socket, present once activated and until it is taken back.
    socket: Option<UdpSocket>,
    /// Whether `activate` has registered a socket for this listener.
    active: SessionIsToBeClosed,
    /// Cluster of the front currently attached to this listener, if any.
    cluster_id: Option<String>,
    /// Per-front metric tags, keyed by front address string.
    tags: BTreeMap<String, CachedTags>,
}
impl ListenerHandler for UdpListener {
    /// Address the listener is bound to.
    fn get_addr(&self) -> &SocketAddr {
        &self.address
    }

    /// Cached tags for the front identified by `key`.
    fn get_tags(&self, key: &str) -> Option<&CachedTags> {
        self.tags.get(key)
    }

    /// Installs (`Some`) or clears (`None`) the tags for the front identified by `key`.
    fn set_tags(&mut self, key: String, tags: Option<BTreeMap<String, String>>) {
        if let Some(tags) = tags {
            self.tags.insert(key, CachedTags::new(tags));
        } else {
            self.tags.remove(&key);
        }
    }

    fn protocol(&self) -> Protocol {
        Protocol::UDP
    }

    /// Configured public address, falling back to the bound address.
    fn public_address(&self) -> SocketAddr {
        match self.config.public_address {
            Some(public) => public.into(),
            None => self.address,
        }
    }
}
impl UdpListener {
    /// Creates an inactive listener with no socket and no attached front.
    /// Currently infallible; the `Result` is kept for the caller's error mapping.
    fn new(config: UdpListenerConfig, token: Token) -> Result<UdpListener, ListenerError> {
        let address = config.address.into();
        Ok(UdpListener {
            token,
            address,
            config,
            socket: None,
            active: false,
            cluster_id: None,
            tags: BTreeMap::new(),
        })
    }

    /// Registers the listener's socket for READABLE events under its token.
    ///
    /// Uses `udp_socket` when given (e.g. a socket handed over by another
    /// process), otherwise binds `config.address`. A no-op returning the token
    /// when already active.
    ///
    /// # Errors
    /// `ProxyError::BindToSocket` if binding fails, `ProxyError::RegisterListener`
    /// if registration fails (the socket is then dropped).
    pub fn activate(
        &mut self,
        registry: &Registry,
        udp_socket: Option<UdpSocket>,
    ) -> Result<Token, ProxyError> {
        if !self.active {
            let mut socket = if let Some(provided) = udp_socket {
                provided
            } else {
                let bind_addr: SocketAddr = self.config.address.into();
                udp_bind(bind_addr).map_err(|e| ProxyError::BindToSocket(bind_addr, e))?
            };
            registry
                .register(&mut socket, self.token, Interest::READABLE)
                .map_err(ProxyError::RegisterListener)?;
            self.socket = Some(socket);
            self.active = true;
        }
        Ok(self.token)
    }

    /// Applies every field set in `patch`, leaving unset fields untouched.
    pub fn update_config(&mut self, patch: &UpdateUdpListenerConfig) {
        let cfg = &mut self.config;
        cfg.public_address = patch.public_address.or(cfg.public_address);
        cfg.front_timeout = patch.front_timeout.unwrap_or(cfg.front_timeout);
        cfg.back_timeout = patch.back_timeout.unwrap_or(cfg.back_timeout);
        cfg.max_rx_datagram_size = patch
            .max_rx_datagram_size
            .unwrap_or(cfg.max_rx_datagram_size);
        cfg.max_flows = patch.max_flows.unwrap_or(cfg.max_flows);
    }
}
/// Worker-side UDP proxy state, keyed by listener token.
pub struct UdpProxy {
    /// Listeners by token.
    listeners: HashMap<Token, Rc<RefCell<UdpListener>>>,
    /// Flow managers (one per listener) by listener token.
    managers: HashMap<Token, Rc<RefCell<UdpManager>>>,
    /// Listener sessions built by `build_session`, by listener token.
    listener_sessions: HashMap<Token, Rc<RefCell<UdpListenerSession>>>,
    /// Cluster id -> token of the listener whose front serves it.
    fronts: HashMap<String, Token>,
    /// Listener token -> cluster attached via a front.
    cluster_for_listener: HashMap<Token, ClusterId>,
    /// Per-cluster UDP knobs from the last `AddCluster`.
    cluster_udp_config: HashMap<ClusterId, sozu_command::proto::command::UdpClusterConfig>,
    /// Shared backend table used for load balancing.
    backends: Rc<RefCell<BackendMap>>,
    /// Health checker for UDP backends.
    health: UdpHealthChecker,
    registry: Registry,
    sessions: Rc<RefCell<SessionManager>>,
    // Kept for parity with the other proxies' constructors; not read by the UDP path.
    #[allow(dead_code)]
    pool: Rc<RefCell<Pool>>,
    /// Seed handed to every `UdpManager`.
    hash_seed: u64,
    /// Upper bound for automatic per-listener flow caps (0 = unbounded).
    max_connections: usize,
    /// Upper bound for `max_rx_datagram_size` (0 = unbounded).
    buffer_size: usize,
}
/// Worker-side UDP proxy.
///
/// Each listener token owns one [`UdpListener`], one [`UdpManager`] (created in
/// `add_listener`) and, once `build_session` runs, one [`UdpListenerSession`] that
/// performs the socket I/O. Worker requests are translated here into manager
/// configuration events.
impl UdpProxy {
    /// Builds a proxy with no listeners.
    ///
    /// `max_connections` caps the automatic per-listener flow limit and
    /// `buffer_size` clamps `max_rx_datagram_size`; a zero value disables
    /// either bound.
    pub fn new(
        registry: Registry,
        sessions: Rc<RefCell<SessionManager>>,
        pool: Rc<RefCell<Pool>>,
        backends: Rc<RefCell<BackendMap>>,
        max_connections: usize,
        buffer_size: usize,
    ) -> UdpProxy {
        let hash_seed = crate::load_balancing::DEFAULT_HASH_SEED;
        UdpProxy {
            backends,
            listeners: HashMap::new(),
            listener_sessions: HashMap::new(),
            managers: HashMap::new(),
            cluster_for_listener: HashMap::new(),
            cluster_udp_config: HashMap::new(),
            fronts: HashMap::new(),
            registry,
            sessions,
            pool,
            hash_seed,
            max_connections,
            buffer_size,
            health: UdpHealthChecker::new(),
        }
    }

    /// Runs one health-check round; silently skipped if the registry handle
    /// cannot be cloned.
    pub fn health_poll(&mut self) {
        if let Ok(registry) = self.registry.try_clone() {
            self.health.poll(&self.backends, &registry);
        }
    }

    /// Forwards a readiness event for `token` to the health checker.
    pub fn health_ready(&mut self, token: Token) {
        self.health.ready(token);
    }

    /// Whether the health checker claims `token` (presumably one of its probe
    /// sockets — see `health` module).
    pub fn health_owns_token(&self, token: Token) -> bool {
        self.health.owns_token(token)
    }

    /// Registers a new (inactive) listener and its flow manager under `token`.
    ///
    /// `max_rx_datagram_size` is clamped to `buffer_size` before being stored.
    ///
    /// # Errors
    /// `ProxyError::ListenerAlreadyPresent` if `token` is taken, or
    /// `ProxyError::AddListener` if the listener cannot be built.
    pub fn add_listener(
        &mut self,
        config: UdpListenerConfig,
        token: Token,
    ) -> Result<Token, ProxyError> {
        match self.listeners.entry(token) {
            Entry::Vacant(entry) => {
                let mut config = config;
                let max_flows = effective_max_flows(config.max_flows, self.max_connections);
                let max_rx = clamp_max_rx(config.max_rx_datagram_size as usize, self.buffer_size);
                // Lossless: max_rx never exceeds the original u32 value.
                config.max_rx_datagram_size = max_rx as u32;
                let front = Duration::from_secs(u64::from(config.front_timeout));
                let back = Duration::from_secs(u64::from(config.back_timeout));
                let listener = UdpListener::new(config, token).map_err(ProxyError::AddListener)?;
                entry.insert(Rc::new(RefCell::new(listener)));
                let cluster_cfg = ClusterConfig {
                    front_timeout: front,
                    back_timeout: back,
                    ..Default::default()
                };
                self.managers.insert(
                    token,
                    Rc::new(RefCell::new(UdpManager::new(
                        cluster_cfg,
                        max_flows,
                        max_rx,
                        self.hash_seed,
                    ))),
                );
                Ok(token)
            }
            _ => Err(ProxyError::ListenerAlreadyPresent),
        }
    }

    /// Removes every listener bound to `address`, tearing down its flows,
    /// manager, front mapping and — if still present — its socket.
    ///
    /// Returns whether anything was removed.
    pub fn remove_listener(&mut self, address: SocketAddr) -> SessionIsToBeClosed {
        let mut removed = Vec::new();
        self.listeners.retain(|token, l| {
            if l.borrow().address == address {
                removed.push((*token, l.clone()));
                false
            } else {
                true
            }
        });
        if removed.is_empty() {
            return false;
        }
        let now = Instant::now();
        for (token, listener) in removed {
            self.cluster_for_listener.remove(&token);
            if let Some(session) = self.listener_sessions.remove(&token) {
                session.borrow_mut().close_all_flows(now);
            }
            self.managers.remove(&token);
            // Fronts pointing at a listener that no longer exists are stale.
            self.fronts.retain(|_, front_token| *front_token != token);
            // A listener removed without being deactivated first still owns a
            // registered socket; the session's handle on the listener would keep
            // it alive and receiving traffic. Deregister and drop it here.
            let mut owned = listener.borrow_mut();
            owned.active = false;
            if let Some(mut socket) = owned.socket.take()
                && let Err(e) = self.registry.deregister(&mut socket)
            {
                debug!(
                    "{} could not deregister removed listener {}: {}",
                    log_module_context!(),
                    address,
                    e
                );
            }
        }
        true
    }

    /// Activates the listener bound to `addr` (see [`UdpListener::activate`]).
    ///
    /// # Errors
    /// `ProxyError::NoListenerFound` if no listener is bound to `addr`, plus any
    /// activation error.
    pub fn activate_listener(
        &self,
        addr: &SocketAddr,
        udp_socket: Option<UdpSocket>,
    ) -> Result<Token, ProxyError> {
        let listener = self
            .listeners
            .values()
            .find(|listener| listener.borrow().address == *addr)
            .ok_or(ProxyError::NoListenerFound(*addr))?;
        listener.borrow_mut().activate(&self.registry, udp_socket)
    }

    /// Builds (and records) the I/O session for the listener at `token`.
    /// Returns `None` if the listener or manager is unknown or the registry
    /// cannot be cloned.
    pub fn build_session(&mut self, token: Token) -> Option<Rc<RefCell<UdpListenerSession>>> {
        let listener = self.listeners.get(&token)?.clone();
        let manager = self.managers.get(&token)?.clone();
        let registry = self.registry.try_clone().ok()?;
        let session = Rc::new(RefCell::new(UdpListenerSession::new(
            listener,
            manager,
            self.backends.clone(),
            registry,
            self.sessions.clone(),
            token,
        )));
        self.listener_sessions.insert(token, session.clone());
        Some(session)
    }

    /// Takes the socket out of every active listener (marking it inactive) and
    /// returns them with their addresses, presumably for a socket hand-over.
    pub fn give_back_listeners(&mut self) -> Vec<(SocketAddr, UdpSocket)> {
        self.listeners
            .values()
            .filter_map(|listener| {
                let mut owned = listener.borrow_mut();
                let socket = owned.socket.take()?;
                owned.active = false;
                Some((owned.address, socket))
            })
            .collect()
    }

    /// Takes the socket out of the listener at `address`, marks it inactive and
    /// closes its session's flows.
    ///
    /// # Errors
    /// `ProxyError::NoListenerFound` or `ProxyError::UnactivatedListener`.
    pub fn give_back_listener(
        &mut self,
        address: SocketAddr,
    ) -> Result<(Token, UdpSocket), ProxyError> {
        let listener = self
            .listeners
            .values()
            .find(|listener| listener.borrow().address == address)
            .ok_or(ProxyError::NoListenerFound(address))?;
        let (token, taken) = {
            let mut owned = listener.borrow_mut();
            let taken = owned.socket.take().ok_or(ProxyError::UnactivatedListener)?;
            owned.active = false;
            (owned.token, taken)
        };
        if let Some(session) = self.listener_sessions.remove(&token) {
            session.borrow_mut().close_all_flows(Instant::now());
        }
        Ok((token, taken))
    }

    /// Applies a listener config patch and pushes the resulting cluster config,
    /// flow cap and datagram size to the listener's manager and session.
    ///
    /// # Errors
    /// `ProxyError::NoListenerFound` if no listener is bound to `patch.address`.
    pub fn update_listener(&mut self, patch: UpdateUdpListenerConfig) -> Result<(), ProxyError> {
        let address: SocketAddr = patch.address.into();
        let listener = self
            .listeners
            .values()
            .find(|l| l.borrow().address == address)
            .ok_or(ProxyError::NoListenerFound(address))?;
        let (token, max_rx) = {
            let mut l = listener.borrow_mut();
            l.update_config(&patch);
            let max_rx = clamp_max_rx(l.config.max_rx_datagram_size as usize, self.buffer_size);
            // Lossless: max_rx never exceeds the u32 it was derived from.
            l.config.max_rx_datagram_size = max_rx as u32;
            (l.token, max_rx)
        };
        let Some(mgr) = self.managers.get(&token) else {
            return Ok(());
        };
        let (cfg, max_flows) = {
            let l = listener.borrow();
            (
                self.cluster_config_for(&l, token),
                effective_max_flows(l.config.max_flows, self.max_connections),
            )
        };
        let now = Instant::now();
        {
            let mut m = mgr.borrow_mut();
            m.handle_input(ManagerInput::Config(ConfigEvent::SetCluster(cfg)), now);
            m.handle_input(
                ManagerInput::Config(ConfigEvent::SetMaxFlows(max_flows)),
                now,
            );
            m.handle_input(
                ManagerInput::Config(ConfigEvent::SetMaxRxDatagramSize(max_rx)),
                now,
            );
        }
        if let Some(session) = self.listener_sessions.get(&token) {
            session.borrow_mut().resize_recv_buf(max_rx);
        }
        Ok(())
    }

    /// Attaches a front (cluster + tags) to the listener at `front.address` and
    /// pushes the resulting cluster config to its manager.
    ///
    /// # Errors
    /// `ProxyError::NoListenerFound` if no listener is bound to that address.
    pub fn add_udp_front(&mut self, front: RequestUdpFrontend) -> Result<(), ProxyError> {
        let address = front.address.into();
        let token = {
            let mut listener = self
                .listeners
                .values()
                .find(|l| l.borrow().address == address)
                .ok_or(ProxyError::NoListenerFound(address))?
                .borrow_mut();
            self.fronts
                .insert(front.cluster_id.to_string(), listener.token);
            listener.set_tags(address.to_string(), Some(front.tags));
            listener.cluster_id = Some(front.cluster_id.clone());
            listener.token
        };
        self.cluster_for_listener.insert(token, front.cluster_id);
        if let Some(mgr) = self.managers.get(&token)
            && let Some(listener) = self.listeners.get(&token)
        {
            let cfg = self.cluster_config_for(&listener.borrow(), token);
            mgr.borrow_mut().handle_input(
                ManagerInput::Config(ConfigEvent::SetCluster(cfg)),
                Instant::now(),
            );
        }
        Ok(())
    }

    /// Detaches whatever front the listener at `front.address` serves and resets
    /// its manager to the default cluster config.
    ///
    /// # Errors
    /// `ProxyError::NoListenerFound` if no listener is bound to that address.
    pub fn remove_udp_front(&mut self, front: RequestUdpFrontend) -> Result<(), ProxyError> {
        let address = front.address.into();
        let token = {
            let mut listener = match self
                .listeners
                .values()
                .find(|l| l.borrow().address == address)
            {
                Some(l) => l.borrow_mut(),
                None => return Err(ProxyError::NoListenerFound(address)),
            };
            listener.set_tags(address.to_string(), None);
            if let Some(cluster_id) = listener.cluster_id.take() {
                self.fronts.remove(&cluster_id);
            }
            listener.token
        };
        self.cluster_for_listener.remove(&token);
        if let Some(mgr) = self.managers.get(&token) {
            mgr.borrow_mut().handle_input(
                ManagerInput::Config(ConfigEvent::SetCluster(ClusterConfig::default())),
                Instant::now(),
            );
        }
        Ok(())
    }

    /// Cluster config for a listener: its cluster id and timeouts, plus the
    /// cluster's UDP knobs when an `AddCluster` provided them.
    fn cluster_config_for(&self, listener: &UdpListener, _token: Token) -> ClusterConfig {
        let cluster = listener.cluster_id.clone().unwrap_or_default();
        let mut cfg = ClusterConfig {
            cluster: cluster.clone(),
            front_timeout: Duration::from_secs(u64::from(listener.config.front_timeout)),
            back_timeout: Duration::from_secs(u64::from(listener.config.back_timeout)),
            ..Default::default()
        };
        if let Some(udp) = self.cluster_udp_config.get(&cluster) {
            apply_udp_knobs(&mut cfg, udp);
        }
        cfg
    }

    /// Records a cluster's load-balancing policy, health settings and UDP knobs,
    /// then re-pushes config to every listener serving that cluster.
    fn apply_cluster(&mut self, cluster: &Cluster) {
        self.backends
            .borrow_mut()
            .set_load_balancing_policy_for_cluster(
                &cluster.cluster_id,
                LoadBalancingAlgorithms::try_from(cluster.load_balancing).unwrap_or_default(),
                cluster
                    .load_metric
                    .and_then(|n| LoadMetric::try_from(n).ok()),
            );
        // An explicit HealthOff disables checks; any other (or unknown) mode enables them.
        let health_settings = cluster.udp.as_ref().and_then(|udp| {
            udp.health.as_ref().and_then(|h| {
                let mode = h
                    .mode
                    .and_then(|m| sozu_command::proto::command::UdpHealthMode::try_from(m).ok());
                match mode {
                    Some(sozu_command::proto::command::UdpHealthMode::HealthOff) => None,
                    _ => Some(health::UdpHealthSettings::from_proto(h)),
                }
            })
        });
        self.health
            .set_cluster(&cluster.cluster_id, health_settings, &self.registry);
        match &cluster.udp {
            Some(udp) => {
                self.cluster_udp_config
                    .insert(cluster.cluster_id.clone(), udp.clone());
            }
            None => {
                self.cluster_udp_config.remove(&cluster.cluster_id);
            }
        }
        let now = Instant::now();
        let tokens: Vec<Token> = self
            .cluster_for_listener
            .iter()
            .filter(|(_, c)| **c == cluster.cluster_id)
            .map(|(t, _)| *t)
            .collect();
        for token in tokens {
            let Some(listener) = self.listeners.get(&token) else {
                continue;
            };
            let cfg = {
                let l = listener.borrow();
                self.cluster_config_for(&l, token)
            };
            if let Some(mgr) = self.managers.get(&token) {
                mgr.borrow_mut()
                    .handle_input(ManagerInput::Config(ConfigEvent::SetCluster(cfg)), now);
            }
        }
    }

    /// Removes every listener and deregisters (then drops) its socket, if any.
    /// Deregistration errors are ignored: the socket is being closed anyway.
    fn drain_and_deregister_listeners(&mut self) {
        for (_, listener) in self.listeners.drain() {
            if let Some(mut socket) = listener.borrow_mut().socket.take() {
                let _ = self.registry.deregister(&mut socket);
            }
        }
    }

    /// Handles a worker request addressed to the UDP proxy.
    pub fn notify(&mut self, message: WorkerRequest) -> WorkerResponse {
        let request_type = match message.content.request_type {
            Some(t) => t,
            None => return WorkerResponse::error(message.id, "Empty request"),
        };
        match request_type {
            RequestType::AddUdpFrontend(front) => match self.add_udp_front(front) {
                Ok(()) => WorkerResponse::ok(message.id),
                Err(err) => WorkerResponse::error(message.id, err),
            },
            RequestType::RemoveUdpFrontend(front) => match self.remove_udp_front(front) {
                Ok(()) => WorkerResponse::ok(message.id),
                Err(err) => WorkerResponse::error(message.id, err),
            },
            RequestType::AddCluster(cluster) => {
                self.apply_cluster(&cluster);
                WorkerResponse::ok(message.id)
            }
            RequestType::RemoveCluster(cluster_id) => {
                let now = Instant::now();
                let tokens: Vec<Token> = self
                    .cluster_for_listener
                    .iter()
                    .filter(|(_, c)| **c == cluster_id)
                    .map(|(t, _)| *t)
                    .collect();
                for token in tokens {
                    if let Some(mgr) = self.managers.get(&token) {
                        mgr.borrow_mut().handle_input(
                            ManagerInput::Config(ConfigEvent::SetCluster(ClusterConfig::default())),
                            now,
                        );
                    }
                }
                self.cluster_udp_config.remove(&cluster_id);
                self.health.remove_cluster(&cluster_id, &self.registry);
                WorkerResponse::ok(message.id)
            }
            RequestType::SoftStop(_) => {
                info!(
                    "{} {} processing soft shutdown",
                    log_module_context!(),
                    message.id
                );
                let now = Instant::now();
                for mgr in self.managers.values() {
                    mgr.borrow_mut()
                        .handle_input(ManagerInput::Config(ConfigEvent::Drain), now);
                }
                for session in self.listener_sessions.values() {
                    session.borrow_mut().close_all_flows(now);
                }
                self.listener_sessions.clear();
                self.drain_and_deregister_listeners();
                WorkerResponse::processing(message.id)
            }
            RequestType::HardStop(_) => {
                info!("{} {} hard shutdown", log_module_context!(), message.id);
                let now = Instant::now();
                for session in self.listener_sessions.values() {
                    session.borrow_mut().close_all_flows(now);
                }
                self.listener_sessions.clear();
                self.drain_and_deregister_listeners();
                self.managers.clear();
                WorkerResponse::ok(message.id)
            }
            RequestType::Status(_) => {
                info!("{} {} status", log_module_context!(), message.id);
                WorkerResponse::ok(message.id)
            }
            RequestType::RemoveListener(remove) => {
                if !self.remove_listener(remove.address.into()) {
                    WorkerResponse::error(
                        message.id,
                        format!("no UDP listener to remove at address {:?}", remove.address),
                    )
                } else {
                    WorkerResponse::ok(message.id)
                }
            }
            command => {
                debug!(
                    "{} {} unsupported message for UDP proxy, ignoring {:?}",
                    log_module_context!(),
                    message.id,
                    command
                );
                WorkerResponse::error(message.id, "unsupported message")
            }
        }
    }
}
/// Copies a cluster's optional UDP knobs into `cfg`.
///
/// Unset knobs fall back to `0` / `false`. Port-aware affinity is enabled
/// only for an explicit `SourceIpPort` affinity key; an unknown key value
/// leaves it disabled.
fn apply_udp_knobs(cfg: &mut ClusterConfig, udp: &sozu_command::proto::command::UdpClusterConfig) {
    let affinity = udp
        .affinity_key
        .and_then(|raw| UdpAffinityKey::try_from(raw).ok());
    cfg.affinity_with_port = matches!(affinity, Some(UdpAffinityKey::SourceIpPort));
    cfg.requests = udp.requests.unwrap_or(0);
    cfg.responses = udp.responses.unwrap_or(0);
    cfg.proxy_protocol_every_datagram = udp.proxy_protocol_every_datagram.unwrap_or(false);
    cfg.send_proxy_protocol = udp.send_proxy_protocol.unwrap_or(false);
}
/// Automatic flow cap used when the file-descriptor limit cannot be read
/// (or on non-unix targets).
const DEFAULT_AUTO_MAX_FLOWS: usize = 1024;

/// Resolves a listener's flow cap.
///
/// * `configured` — the listener's `max_flows`; any non-zero value is used verbatim.
/// * `slab_headroom` — the worker's connection budget; when non-zero, the
///   automatic value is capped by it and kept at least 1.
///
/// With `configured == 0` the automatic value is 70% of the soft
/// `RLIMIT_NOFILE` (each flow opens its own upstream socket), falling back to
/// [`DEFAULT_AUTO_MAX_FLOWS`] when the limit is unavailable.
fn effective_max_flows(configured: u32, slab_headroom: usize) -> usize {
    if configured != 0 {
        return configured as usize;
    }
    let auto = {
        #[cfg(unix)]
        {
            let mut limit = libc::rlimit {
                rlim_cur: 0,
                rlim_max: 0,
            };
            // SAFETY: `limit` is a valid, initialized `rlimit` that outlives the
            // call; getrlimit only writes into it.
            let ret = unsafe { libc::getrlimit(libc::RLIMIT_NOFILE, &mut limit) };
            if ret == 0 && limit.rlim_cur > 0 {
                let budget = (limit.rlim_cur.saturating_mul(7) / 10).max(1);
                // `rlim_t` is 64-bit (and may be RLIM_INFINITY); saturate rather
                // than truncate when `usize` is narrower.
                usize::try_from(budget).unwrap_or(usize::MAX)
            } else {
                DEFAULT_AUTO_MAX_FLOWS
            }
        }
        #[cfg(not(unix))]
        {
            DEFAULT_AUTO_MAX_FLOWS
        }
    };
    if slab_headroom == 0 {
        auto
    } else {
        auto.min(slab_headroom).max(1)
    }
}
/// Caps the configured receive datagram size at `buffer_size`; a zero
/// `buffer_size` means "no cap".
fn clamp_max_rx(configured: usize, buffer_size: usize) -> usize {
    match buffer_size {
        0 => configured,
        cap => configured.min(cap),
    }
}
/// Socket-I/O side of one UDP listener: feeds datagrams to the listener's
/// [`UdpManager`] and executes the outputs it produces.
pub struct UdpListenerSession {
    // --- shared handles ---
    listener: Rc<RefCell<UdpListener>>,
    manager: Rc<RefCell<UdpManager>>,
    backends: Rc<RefCell<BackendMap>>,
    registry: Registry,
    sessions: Rc<RefCell<SessionManager>>,
    /// Poll token of the listener socket (also the timer token).
    listener_token: Token,
    /// Listener address, cached for logging.
    address: SocketAddr,

    // --- per-flow upstream state ---
    /// Connected upstream sockets by poll token.
    upstream_sockets: HashMap<Token, UdpSocket>,
    /// Datagrams waiting for an upstream socket to become writable.
    upstream_write_queues: HashMap<Token, WriteQueue>,
    /// Upstream token -> flow.
    upstream_to_flow: HashMap<Token, FlowId>,
    /// Flow -> upstream token.
    flow_to_upstream: HashMap<FlowId, Token>,
    /// When each flow's upstream socket was opened (for the duration metric).
    flow_started: HashMap<FlowId, Instant>,
    /// Client and backend addresses per flow, for close-time logging.
    flow_endpoints: HashMap<FlowId, (SocketAddr, Option<SocketAddr>)>,
    /// Client affinity key -> flow.
    client_key_to_flow: HashMap<SocketAddr, FlowId>,

    // --- client side ---
    /// Datagrams waiting for the listener socket to become writable.
    client_write_queue: WriteQueue,
    /// Source of the client datagram being processed by `ingest_client`.
    in_flight_client: Option<SocketAddr>,
    /// Flow opened while processing the current client datagram.
    in_flight_flow: Option<FlowId>,

    /// Receive buffer sized `max_rx_datagram_size + 1`.
    recv_buf: Vec<u8>,
    /// Currently armed manager timer, if any.
    timer_handle: Option<crate::timer::Timeout>,
}
impl UdpListenerSession {
#[allow(clippy::too_many_arguments)]
pub fn new(
listener: Rc<RefCell<UdpListener>>,
manager: Rc<RefCell<UdpManager>>,
backends: Rc<RefCell<BackendMap>>,
registry: Registry,
sessions: Rc<RefCell<SessionManager>>,
listener_token: Token,
) -> UdpListenerSession {
let (address, max_rx) = {
let l = listener.borrow();
(l.address, l.config.max_rx_datagram_size as usize)
};
UdpListenerSession {
listener,
manager,
backends,
registry,
sessions,
listener_token,
address,
upstream_sockets: HashMap::new(),
upstream_write_queues: HashMap::new(),
client_write_queue: WriteQueue::new(CLIENT_WRITE_QUEUE_CAP),
upstream_to_flow: HashMap::new(),
flow_to_upstream: HashMap::new(),
flow_started: HashMap::new(),
flow_endpoints: HashMap::new(),
in_flight_client: None,
in_flight_flow: None,
client_key_to_flow: HashMap::new(),
recv_buf: vec![0u8; max_rx.saturating_add(1).max(1)],
timer_handle: None,
}
}
fn resize_recv_buf(&mut self, max_rx: usize) {
self.recv_buf.resize(max_rx.saturating_add(1).max(1), 0u8);
}
fn client_key(&self, src: SocketAddr) -> SocketAddr {
let with_port = self.manager.borrow().affinity_with_port();
if with_port {
src
} else {
let mut s = src;
s.set_port(0);
s
}
}
fn ingest_client(&mut self, now: Instant) {
loop {
let result = {
let listener = self.listener.borrow();
let Some(socket) = listener.socket.as_ref() else {
return;
};
socket.recv_from(&mut self.recv_buf)
};
match result {
Ok((len, src)) => {
self.in_flight_client = Some(src);
self.in_flight_flow = None;
let len = len.min(self.recv_buf.len());
let payload: &[u8] = &self.recv_buf[..len];
let mgr = self.manager.clone();
mgr.borrow_mut()
.handle_input(ManagerInput::ClientDatagram { src, payload }, now);
self.drain_outputs(now);
self.in_flight_client = None;
}
Err(ref e) if e.kind() == ErrorKind::WouldBlock => break,
Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
Err(e) => {
debug!(
"{} recv_from error on UDP listener: {}",
log_context!(self),
e
);
break;
}
}
}
}
fn ingest_upstream(&mut self, upstream_token: Token, now: Instant) {
let Some(&flow) = self.upstream_to_flow.get(&upstream_token) else {
return;
};
loop {
let result = {
let Some(socket) = self.upstream_sockets.get(&upstream_token) else {
return;
};
socket.recv(&mut self.recv_buf)
};
match result {
Ok(len) => {
let len = len.min(self.recv_buf.len());
let payload: &[u8] = &self.recv_buf[..len];
let mgr = self.manager.clone();
mgr.borrow_mut()
.handle_input(ManagerInput::BackendDatagram { flow, payload }, now);
self.drain_outputs(now);
}
Err(ref e) if e.kind() == ErrorKind::WouldBlock => break,
Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
Err(e) => {
debug!(
"{} recv error on upstream socket: {}",
log_context!(self),
e
);
break;
}
}
}
}
fn drain_outputs(&mut self, now: Instant) {
let mgr = self.manager.clone();
loop {
let out = mgr.borrow_mut().poll_output();
let Some(out) = out else { break };
match out {
Output::SelectBackend { flow, cluster, key } => {
self.on_select_backend(flow, &cluster, key, now)
}
Output::OpenUpstream { flow, backend } => self.on_open_upstream(flow, backend, now),
Output::SendToBackend(transmit) => self.on_send_to_backend(transmit),
Output::SendToClient(transmit) => self.on_send_to_client(transmit),
Output::ArmTimer(deadline) => self.arm_timer(deadline, now),
Output::Metric(ev) => Self::record_metric(ev),
Output::CloseFlow(flow) => self.on_close_flow(flow),
Output::Drop(reason) => Self::record_drop(reason),
}
}
}
fn on_select_backend(&mut self, flow: FlowId, cluster: &str, key: u64, now: Instant) {
let resolved = self
.backends
.borrow_mut()
.backend_from_cluster_id_with_key(cluster, Some(key));
match resolved {
Ok((backend, addr)) => {
self.manager.borrow_mut().handle_input(
ManagerInput::BackendResolved {
flow,
backend,
addr,
},
now,
);
}
Err(e) => {
debug!(
"{} no backend for cluster {}: {}; aborting flow {}",
log_context!(self),
cluster,
e,
flow
);
incr!(names::udp::DROPPED_NO_BACKEND);
self.manager
.borrow_mut()
.abort_flow(flow, now, CloseReason::Aborted);
}
}
}
fn on_open_upstream(&mut self, flow: FlowId, backend: SocketAddr, now: Instant) {
let mut socket = match udp_connect(backend) {
Ok(socket) => socket,
Err(e) => {
warn!(
"{} could not open upstream socket to {}: {}; shedding flow {}",
log_context!(self),
backend,
e,
flow
);
incr!(names::udp::FLOWS_SHED);
self.manager
.borrow_mut()
.abort_flow(flow, now, CloseReason::Aborted);
return;
}
};
let upstream_token = {
let mut s = self.sessions.borrow_mut();
let listener_session = s.slab[self.listener_token.0].clone();
let entry = s.slab.vacant_entry();
let token = Token(entry.key());
entry.insert(listener_session);
token
};
if let Err(e) = self
.registry
.register(&mut socket, upstream_token, Interest::READABLE)
{
error!(
"{} could not register upstream socket: {}",
log_context!(self),
e
);
self.sessions.borrow_mut().slab.try_remove(upstream_token.0);
self.manager
.borrow_mut()
.abort_flow(flow, now, CloseReason::Aborted);
return;
}
self.upstream_sockets.insert(upstream_token, socket);
self.upstream_to_flow.insert(upstream_token, flow);
self.flow_to_upstream.insert(flow, upstream_token);
debug_assert_eq!(
self.upstream_to_flow.get(&upstream_token),
Some(&flow),
"upstream_to_flow must map the new token back to its flow"
);
debug_assert_eq!(
self.flow_to_upstream.get(&flow),
Some(&upstream_token),
"flow_to_upstream must map the flow back to its upstream token"
);
self.flow_started.insert(flow, Instant::now());
let client = self.in_flight_client.unwrap_or(self.address);
self.flow_endpoints.insert(flow, (client, Some(backend)));
if let Some(src) = self.in_flight_client {
let key = self.client_key(src);
self.client_key_to_flow.insert(key, flow);
debug_assert!(
self.flow_to_upstream.contains_key(&flow),
"client_key_to_flow points at flow {flow} with no live upstream token"
);
}
self.in_flight_flow = Some(flow);
}
fn on_send_to_backend(&mut self, transmit: crate::protocol::udp::Transmit) {
let flow = self.in_flight_flow.or_else(|| {
self.in_flight_client
.map(|src| self.client_key(src))
.and_then(|key| self.client_key_to_flow.get(&key).copied())
});
let token = flow.and_then(|f| self.flow_to_upstream.get(&f).copied());
let Some(token) = token else {
incr!(names::udp::DROPPED_UNKNOWN_FLOW);
return;
};
let Some(socket) = self.upstream_sockets.get(&token) else {
incr!(names::udp::DROPPED_UNKNOWN_FLOW);
return;
};
if let Some(q) = self.upstream_write_queues.get_mut(&token)
&& !q.is_empty()
{
if !q.push(transmit.dst, transmit.payload) {
debug!("{} upstream write queue full, dropping", log_context!(self));
incr!(names::udp::DROPPED_WQ_FULL);
}
return;
}
match socket.send(&transmit.payload) {
Ok(_) => {}
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
let q = self
.upstream_write_queues
.entry(token)
.or_insert_with(|| WriteQueue::new(UPSTREAM_WRITE_QUEUE_CAP));
if q.push(transmit.dst, transmit.payload) {
self.arm_upstream_writable(token);
} else {
debug!("{} upstream write queue full, dropping", log_context!(self));
incr!(names::udp::DROPPED_WQ_FULL);
}
}
Err(e) => {
debug!("{} upstream send error: {}", log_context!(self), e);
incr!(names::udp::DROPPED_SEND_ERROR);
}
}
}
fn arm_upstream_writable(&mut self, token: Token) {
if let Some(socket) = self.upstream_sockets.get_mut(&token)
&& let Err(e) =
self.registry
.reregister(socket, token, Interest::READABLE | Interest::WRITABLE)
{
debug!(
"{} could not arm WRITABLE on upstream socket: {}",
log_context!(self),
e
);
}
}
fn disarm_upstream_writable(&mut self, token: Token) {
if let Some(socket) = self.upstream_sockets.get_mut(&token)
&& let Err(e) = self.registry.reregister(socket, token, Interest::READABLE)
{
debug!(
"{} could not disarm WRITABLE on upstream socket: {}",
log_context!(self),
e
);
}
}
fn drain_upstream_queue(&mut self, token: Token) {
let Some(mut queue) = self.upstream_write_queues.remove(&token) else {
return;
};
let socket = self.upstream_sockets.get(&token);
let Some(socket) = socket else {
return;
};
let emptied = queue.drain(|_dst, payload| match socket.send(payload) {
Ok(_) => SendOutcome::Sent,
Err(ref e) if e.kind() == ErrorKind::WouldBlock => SendOutcome::WouldBlock,
Err(_) => SendOutcome::Dropped,
});
if emptied {
self.disarm_upstream_writable(token);
} else {
self.upstream_write_queues.insert(token, queue);
}
}
fn on_send_to_client(&mut self, transmit: crate::protocol::udp::Transmit) {
if !self.client_write_queue.is_empty() {
if !self.client_write_queue.push(transmit.dst, transmit.payload) {
debug!("{} client write queue full, dropping", log_context!(self));
incr!(names::udp::DROPPED_WQ_FULL);
}
return;
}
let send_result = {
let listener = self.listener.borrow();
let Some(socket) = listener.socket.as_ref() else {
return;
};
socket.send_to(&transmit.payload, transmit.dst)
};
match send_result {
Ok(_) => {}
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
if self.client_write_queue.push(transmit.dst, transmit.payload) {
self.arm_client_writable();
} else {
debug!("{} client write queue full, dropping", log_context!(self));
incr!(names::udp::DROPPED_WQ_FULL);
}
}
Err(e) => {
debug!("{} client send_to error: {}", log_context!(self), e);
incr!(names::udp::DROPPED_SEND_ERROR);
}
}
}
fn arm_client_writable(&mut self) {
let listener = self.listener.borrow();
let fd = match listener.socket.as_ref() {
Some(socket) => socket.as_raw_fd(),
None => return,
};
if let Err(e) = self.registry.reregister(
&mut SourceFd(&fd),
self.listener_token,
Interest::READABLE | Interest::WRITABLE,
) {
debug!(
"{} could not arm WRITABLE on listener socket: {}",
log_context!(self),
e
);
}
}
fn disarm_client_writable(&mut self) {
let listener = self.listener.borrow();
let fd = match listener.socket.as_ref() {
Some(socket) => socket.as_raw_fd(),
None => return,
};
if let Err(e) =
self.registry
.reregister(&mut SourceFd(&fd), self.listener_token, Interest::READABLE)
{
debug!(
"{} could not disarm WRITABLE on listener socket: {}",
log_context!(self),
e
);
}
}
fn drain_client_queue(&mut self) {
let mut queue = std::mem::replace(&mut self.client_write_queue, WriteQueue::new(0));
let emptied = {
let listener = self.listener.borrow();
let Some(socket) = listener.socket.as_ref() else {
self.client_write_queue = WriteQueue::new(CLIENT_WRITE_QUEUE_CAP);
return;
};
queue.drain(|dst, payload| match socket.send_to(payload, *dst) {
Ok(_) => SendOutcome::Sent,
Err(ref e) if e.kind() == ErrorKind::WouldBlock => SendOutcome::WouldBlock,
Err(_) => SendOutcome::Dropped,
})
};
queue.cap = CLIENT_WRITE_QUEUE_CAP;
self.client_write_queue = queue;
if emptied {
self.disarm_client_writable();
}
}
fn arm_timer(&mut self, deadline: Instant, now: Instant) {
let delay = deadline.saturating_duration_since(now);
TIMER.with(|timer| {
let mut timer = timer.borrow_mut();
if let Some(old) = self.timer_handle.take() {
let _ = timer.cancel_timeout(&old);
}
self.timer_handle = Some(timer.set_timeout(delay, self.listener_token));
});
}
fn on_close_flow(&mut self, flow: FlowId) {
if let Some(token) = self.flow_to_upstream.remove(&flow) {
if let Some(mut socket) = self.upstream_sockets.remove(&token) {
if let Err(e) = self.registry.deregister(&mut socket) {
debug!("{} deregister upstream on close: {}", log_context!(self), e);
}
}
self.upstream_write_queues.remove(&token);
self.upstream_to_flow.remove(&token);
self.sessions.borrow_mut().slab.try_remove(token.0);
debug_assert!(
!self.upstream_to_flow.values().any(|&f| f == flow),
"on_close_flow left upstream_to_flow referencing closed flow {flow}"
);
debug_assert!(
!self.flow_to_upstream.contains_key(&flow),
"on_close_flow left flow_to_upstream entry for closed flow {flow}"
);
}
if let Some(started) = self.flow_started.remove(&flow) {
let duration = started.elapsed();
time!(names::udp::FLOW_DURATION, duration.as_millis());
}
let (client, backend) = self
.flow_endpoints
.remove(&flow)
.unwrap_or((self.address, None));
let key = self.client_key(client);
if self.client_key_to_flow.get(&key) == Some(&flow) {
self.client_key_to_flow.remove(&key);
}
debug_assert!(
!self.client_key_to_flow.values().any(|&f| f == flow),
"on_close_flow left client_key_to_flow referencing closed flow {flow}"
);
info!("{} flow closed", log_flow_context!(flow, client, backend));
}
fn record_metric(ev: MetricEvent) {
match ev {
MetricEvent::FlowCreated => {
incr!(names::udp::FLOWS_CREATED);
gauge_add!(names::udp::ACTIVE_FLOWS, 1);
}
MetricEvent::FlowEvicted => {
incr!(names::udp::FLOWS_EVICTED);
gauge_add!(names::udp::ACTIVE_FLOWS, -1);
}
MetricEvent::FlowShed => {
incr!(names::udp::FLOWS_SHED);
}
MetricEvent::DatagramIn(bytes) => {
incr!(names::udp::DATAGRAMS_IN);
count!(names::udp::BYTES_IN, bytes as i64);
}
MetricEvent::DatagramOut(bytes) => {
incr!(names::udp::DATAGRAMS_OUT);
count!(names::udp::BYTES_OUT, bytes as i64);
}
MetricEvent::DatagramDropped(reason) => Self::record_drop(reason),
}
}
/// Counts one dropped datagram.
///
/// The aggregate `DATAGRAMS_DROPPED` counter is always bumped first. The
/// counter specific to `reason` is bumped second.
fn record_drop(reason: DropReason) {
    incr!(names::udp::DATAGRAMS_DROPPED);
    match reason {
        DropReason::Shed => {
            incr!(names::udp::DROPPED_SHED);
        }
        DropReason::NoBackend => {
            incr!(names::udp::DROPPED_NO_BACKEND);
        }
        DropReason::UnknownFlow => {
            incr!(names::udp::DROPPED_UNKNOWN_FLOW);
        }
        DropReason::Truncated => {
            incr!(names::udp::DROPPED_TRUNCATED);
        }
        DropReason::Invalid => {
            incr!(names::udp::DROPPED_INVALID);
        }
    }
}
/// Closes every flow tracked by the manager as of `now`.
///
/// The manager's resulting outputs are applied through `drain_outputs`, and
/// then any armed listener timer is cancelled.
pub fn close_all_flows(&mut self, now: Instant) {
    {
        // Scope the manager borrow so it is released before drain_outputs runs.
        let mut manager = self.manager.borrow_mut();
        manager.close_all(now);
    }
    self.drain_outputs(now);

    let Some(pending) = self.timer_handle.take() else {
        return;
    };
    TIMER.with(|t| {
        let _ = t.borrow_mut().cancel_timeout(&pending);
    });
}
}
impl ProxySession for UdpListenerSession {
fn protocol(&self) -> Protocol {
Protocol::UDPListen
}
fn update_readiness(&mut self, token: Token, events: Ready) {
if events.is_writable() {
if token == self.listener_token {
self.drain_client_queue();
} else if self.upstream_to_flow.contains_key(&token) {
self.drain_upstream_queue(token);
}
}
if !events.is_readable() {
return;
}
let now = Instant::now();
if token == self.listener_token {
self.ingest_client(now);
} else if self.upstream_to_flow.contains_key(&token) {
self.ingest_upstream(token, now);
}
}
fn ready(&mut self, _session: Rc<RefCell<dyn ProxySession>>) -> SessionIsToBeClosed {
false
}
fn timeout(&mut self, token: Token) -> SessionIsToBeClosed {
if token == self.listener_token {
let now = Instant::now();
self.manager.borrow_mut().handle_timeout(now);
self.drain_outputs(now);
}
false
}
fn close(&mut self) {
self.close_all_flows(Instant::now());
self.upstream_write_queues.clear();
self.client_write_queue = WriteQueue::new(CLIENT_WRITE_QUEUE_CAP);
self.upstream_to_flow.clear();
self.flow_to_upstream.clear();
let mut listener = self.listener.borrow_mut();
if let Some(socket) = listener.socket.as_ref() {
let fd = socket.as_raw_fd();
let _ = self.registry.deregister(&mut SourceFd(&fd));
}
listener.active = false;
}
fn last_event(&self) -> Instant {
Instant::now()
}
fn print_session(&self) {
error!(
"{} UDP listener session: {} active flows, {} upstream sockets",
log_context!(self),
self.manager.borrow().flow_count(),
self.upstream_sockets.len(),
);
}
fn frontend_token(&self) -> Token {
self.listener_token
}
fn shutting_down(&mut self) -> SessionIsToBeClosed {
false
}
fn cluster_id(&self) -> Option<String> {
self.listener.borrow().cluster_id.clone()
}
}
// Re-export the `macro_rules!` logging macros defined at the top of this file
// by path, so other modules in the crate can `use` them. Presumably the
// `health` submodule does; verify.
// NOTE(review): `unused_imports` is allowed, presumably because not every
// macro is imported through this path; confirm and narrow the list if possible.
#[allow(unused_imports)]
pub(crate) use {log_context, log_flow_context, log_module_context};
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn effective_max_flows_explicit_value_is_used() {
        assert_eq!(effective_max_flows(42, 0), 42);
        assert_eq!(effective_max_flows(42, 10), 42);
    }

    #[test]
    fn effective_max_flows_auto_is_positive() {
        assert!(effective_max_flows(0, 0) >= 1);
    }

    #[test]
    fn effective_max_flows_auto_is_clamped_to_slab_headroom() {
        assert!(effective_max_flows(0, 4) <= 4);
        assert!(effective_max_flows(0, 4) >= 1);
    }

    #[test]
    fn clamp_max_rx_respects_buffer_size() {
        assert_eq!(clamp_max_rx(u32::MAX as usize, 16_384), 16_384);
        assert_eq!(clamp_max_rx(1_024, 16_384), 1_024);
        assert_eq!(clamp_max_rx(u32::MAX as usize, 0), u32::MAX as usize);
    }

    /// Loopback address on `port`; used as a distinguishable destination.
    fn addr(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    /// Scripted stand-in for a non-blocking UDP socket.
    ///
    /// Each send pops the next scripted outcome. Once the script is exhausted
    /// it answers `Sent` if `default_sent` is true, otherwise `WouldBlock`.
    /// Payloads that the queue will not retry (`Sent` or `Dropped`) are
    /// recorded in `consumed`.
    struct FakeSocket {
        script: RefCell<VecDeque<SendOutcome>>,
        consumed: RefCell<Vec<Vec<u8>>>,
        default_sent: bool,
    }

    impl FakeSocket {
        fn new(script: Vec<SendOutcome>, default_sent: bool) -> Self {
            FakeSocket {
                script: RefCell::new(script.into()),
                consumed: RefCell::new(Vec::new()),
                default_sent,
            }
        }

        fn send(&self, payload: &[u8]) -> SendOutcome {
            let outcome = self.script.borrow_mut().pop_front().unwrap_or_else(|| {
                if self.default_sent {
                    SendOutcome::Sent
                } else {
                    SendOutcome::WouldBlock
                }
            });
            if matches!(outcome, SendOutcome::Sent | SendOutcome::Dropped) {
                self.consumed.borrow_mut().push(payload.to_vec());
            }
            outcome
        }
    }

    #[test]
    fn write_queue_push_until_full_then_drops() {
        let mut q = WriteQueue::new(2);
        assert!(q.is_empty());
        assert!(q.push(addr(1), vec![1]));
        assert!(q.push(addr(2), vec![2]));
        assert_eq!(q.len(), 2);
        assert!(!q.push(addr(3), vec![3]));
        assert_eq!(q.len(), 2);
    }

    #[test]
    fn write_queue_drains_in_fifo_order_on_writable() {
        let mut q = WriteQueue::new(8);
        for i in 0..4u8 {
            assert!(q.push(addr(u16::from(i)), vec![i]));
        }
        let sock = FakeSocket::new(vec![], true);
        let emptied = q.drain(|_dst, payload| sock.send(payload));
        assert!(emptied);
        assert!(q.is_empty());
        assert_eq!(
            *sock.consumed.borrow(),
            vec![vec![0u8], vec![1u8], vec![2u8], vec![3u8]]
        );
    }

    #[test]
    fn write_queue_stops_on_wouldblock_and_resumes() {
        let mut q = WriteQueue::new(8);
        for i in 0..3u8 {
            assert!(q.push(addr(u16::from(i)), vec![i]));
        }
        let sock = FakeSocket::new(vec![SendOutcome::Sent, SendOutcome::WouldBlock], false);
        let emptied = q.drain(|_dst, payload| sock.send(payload));
        assert!(!emptied);
        assert_eq!(q.len(), 2);
        assert_eq!(*sock.consumed.borrow(), vec![vec![0u8]]);
        assert_eq!(q.queue.front().unwrap().1, vec![1u8]);
        let sock2 = FakeSocket::new(vec![], true);
        let emptied2 = q.drain(|_dst, payload| sock2.send(payload));
        assert!(emptied2);
        assert!(q.is_empty());
        assert_eq!(*sock2.consumed.borrow(), vec![vec![1u8], vec![2u8]]);
    }

    #[test]
    fn write_queue_hard_error_drops_one_and_continues() {
        let mut q = WriteQueue::new(8);
        for i in 0..3u8 {
            assert!(q.push(addr(u16::from(i)), vec![i]));
        }
        let sock = FakeSocket::new(
            vec![SendOutcome::Sent, SendOutcome::Dropped, SendOutcome::Sent],
            true,
        );
        let emptied = q.drain(|_dst, payload| sock.send(payload));
        assert!(emptied);
        assert!(q.is_empty());
        assert_eq!(
            *sock.consumed.borrow(),
            vec![vec![0u8], vec![1u8], vec![2u8]]
        );
    }

    #[test]
    fn write_queue_empties_cleanly_when_already_empty() {
        let mut q = WriteQueue::new(4);
        let sock = FakeSocket::new(vec![], true);
        let emptied = q.drain(|_dst, payload| sock.send(payload));
        assert!(emptied);
        assert!(q.is_empty());
        assert!(sock.consumed.borrow().is_empty());
    }

    #[test]
    fn write_queue_passes_destination_and_frees_capacity() {
        let mut q = WriteQueue::new(2);
        assert!(q.push(addr(7), vec![7]));
        assert!(q.push(addr(8), vec![8]));
        let seen = RefCell::new(Vec::new());
        let sock = FakeSocket::new(vec![], true);
        let emptied = q.drain(|dst, payload| {
            seen.borrow_mut().push(dst.to_string());
            sock.send(payload)
        });
        assert!(emptied);
        assert_eq!(
            *seen.borrow(),
            vec!["127.0.0.1:7".to_string(), "127.0.0.1:8".to_string()]
        );
        // A fully drained queue accepts new datagrams again.
        assert!(q.push(addr(9), vec![9]));
        assert_eq!(q.len(), 1);
    }
}