use std::{
    net::{IpAddr, Ipv4Addr, SocketAddr},
    sync::{Arc, Mutex},
    task::{Poll, ready},
    time::Duration,
};

use bytes::BytesMut;
use futures::future::OptionFuture;
use lru::LruCache;
use rhai::{Dynamic, Engine, FnPtr, NativeCallContext};
use tokio::{net::UdpSocket, sync::mpsc::error::TrySendError, time::Instant};
use tracing::{debug, debug_span, error, trace, warn, Instrument};

use crate::{
    copy_common_bind_options, copy_common_udp_options,
    scenario_executor::{
        scenario::{ScenarioAccess, callback_and_continue},
        socketopts::{BindOptions, UdpOptions},
        types::{DatagramRead, DatagramSocket, DatagramWrite, Handle},
        utils1::{HandleExt, NEUTRAL_SOCKADDR4, SimpleErr, TaskHandleExt2},
        utils2::{AddressOrFd, DefragmenterAddChunkResult},
    },
};

use super::{
    types::{BufferFlag, PacketRead, PacketReadResult, PacketWrite, Task},
    utils1::RhResult,
    utils2::Defragmenter,
};
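
/// Mutable, mutex-guarded part of a client session: the optional inactivity
/// deadline, a oneshot sender that resolves the session's hangup future
/// (`None` once the session is dead), and the queue feeding received
/// datagrams to that client's `UdpRecv`.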
struct VolatileClientInfo {
deadline: Option<Instant>,
removal_notifier: Option<tokio::sync::oneshot::Sender<()>>,
sink: tokio::sync::mpsc::Sender<bytes::Bytes>,
}
impl VolatileClientInfo {
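    /// A session counts as dead once its removal notifier has been consumed.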
fn dead(&self) -> bool {
self.removal_notifier.is_none()
}
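
    /// Resolves the session's hangup future (if not already resolved) and
    /// marks the session dead.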
fn terminate(&mut self) {
if let Some(rn) = self.removal_notifier.take() {
let _ = rn.send(());
}
}
}
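
/// Per-client session entry: the peer's address plus the volatile state,
/// shared via `Arc` between the server loop, the client's `UdpSend`, and its
/// hangup monitor.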
struct ClientInfo {
addr: SocketAddr,
v: Mutex<VolatileClientInfo>,
}
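
/// Backs each client's `Hangup` future: completes when the session is
/// explicitly terminated or its inactivity deadline passes. The deadline is
/// re-read on every wakeup because the server loop pushes it forward each
/// time a datagram arrives from this peer.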
async fn hangup_monitor(
ci: Arc<ClientInfo>,
mut removal_notifier: tokio::sync::oneshot::Receiver<()>,
) {
debug!(addr=?ci.addr, "Started hangup monitor");
loop {
trace!("hgmon loop");
let (timeout, has_timeout): (OptionFuture<_>, bool) = {
let mut l = ci.v.lock().unwrap();
if l.dead() {
trace!("hgmon dead");
return;
}
let deadline = l.deadline;
let now = Instant::now();
if let Some(ref deadl) = deadline {
if now >= *deadl {
debug!("Hangup monitor expired based on timeout");
l.terminate();
return;
}
}
drop(l);
(
deadline.map(|d| tokio::time::sleep_until(d)).into(),
deadline.is_some(),
)
};
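        // `biased` gives the explicit removal notification priority over the
        // deadline timer when both are ready at once. An elapsed timer does
        // not return directly: the loop re-checks the deadline under the
        // lock, since it may have been extended in the meantime.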
let do_expire = tokio::select! {
biased;
_ret = &mut removal_notifier => {
true
}
_ret = timeout, if has_timeout => {
false
}
};
if do_expire {
debug!("Hangup monitor expired based on removal notifier");
return;
}
}
}
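
/// `PacketWrite` half of a client session: reassembles fragmented writes
/// with a `Defragmenter`, then sends each complete datagram to the client's
/// address over the shared server socket. With `inhibit_send_errors` set,
/// send failures are logged instead of propagated.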
struct UdpSend {
s: Arc<UdpSocket>,
ci: Arc<ClientInfo>,
defragmenter: Defragmenter,
inhibit_send_errors: bool,
}
impl UdpSend {
fn new(
s: Arc<UdpSocket>,
ci: Arc<ClientInfo>,
inhibit_send_errors: bool,
max_send_datagram_size: usize,
) -> Self {
Self {
s,
ci,
defragmenter: Defragmenter::new(max_send_datagram_size),
inhibit_send_errors,
}
}
}
impl PacketWrite for UdpSend {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
flags: super::types::BufferFlags,
) -> std::task::Poll<std::io::Result<()>> {
trace!("poll_write");
let this = self.get_mut();
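        // Accumulate non-final chunks; only touch the socket once the
        // defragmenter hands back a complete datagram.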
let data: &[u8] = match this.defragmenter.add_chunk(buf, flags) {
DefragmenterAddChunkResult::DontSendYet => {
return Poll::Ready(Ok(()));
}
DefragmenterAddChunkResult::Continunous(x) => x,
DefragmenterAddChunkResult::SizeLimitExceeded(_x) => {
warn!("Exceeded maximum allowed outgoing datagram size. Closing this session.");
return Poll::Ready(Err(std::io::ErrorKind::InvalidData.into()));
}
};
let inhibit_send_errors = this.inhibit_send_errors;
let addr = this.ci.addr;
{
let v = this.ci.v.lock().unwrap();
if v.dead() {
return Poll::Ready(Err(std::io::ErrorKind::ConnectionAborted.into()));
}
}
let ret = this.s.poll_send_to(cx, data, addr);
match ready!(ret) {
Ok(n) => {
if n != data.len() {
warn!("short UDP send");
}
}
Err(e) => {
this.defragmenter.clear();
if inhibit_send_errors {
warn!("Failed to send to UDP socket: {e}");
} else {
return Poll::Ready(Err(e));
}
}
}
this.defragmenter.clear();
Poll::Ready(Ok(()))
}
}
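
/// `PacketRead` half of a client session: yields the datagrams that the main
/// server loop routed into this client's queue. With `tag_as_text` set, each
/// buffer is flagged as `BufferFlag::Text` instead of the default flags.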
struct UdpRecv {
recv: tokio::sync::mpsc::Receiver<bytes::Bytes>,
tag_as_text: bool,
}
impl PacketRead for UdpRecv {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> std::task::Poll<std::io::Result<PacketReadResult>> {
trace!("poll_read");
let this = self.get_mut();
let flags = if this.tag_as_text {
BufferFlag::Text.into()
} else {
Default::default()
};
let l;
match ready!(this.recv.poll_recv(cx)) {
Some(b) => {
trace!(len = b.len(), "recv");
if b.len() > buf.len() {
warn!("Incoming UDP datagram too big for a supplied buffer");
return Poll::Ready(Err(std::io::ErrorKind::InvalidInput.into()));
}
l = b.len();
buf[..l].copy_from_slice(&b);
}
None => {
debug!("conn abort");
return Poll::Ready(Err(std::io::ErrorKind::ConnectionAborted.into()));
}
}
Poll::Ready(Ok(PacketReadResult {
flags,
buffer_subset: 0..l,
}))
}
}
const fn default_max_send_datagram_size() -> usize {
4096
}
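
/// Rhai scenario node `udp_server`: binds (or inherits via `fd`/`named_fd`) a
/// single UDP socket and demultiplexes incoming datagrams by source address
/// into per-client `DatagramSocket`s. `when_listening` is called with the
/// resolved local address; `on_accept` is called once per new peer. Clients
/// are tracked in an LRU map bounded by `max_clients`, with an optional
/// inactivity timeout (`timeout_ms`).
///
/// Illustrative invocation from a scenario script (a sketch only; option
/// names are those of `Opts` below):
///
/// ```rhai
/// udp_server(#{
///     bind: "127.0.0.1:5000",
///     timeout_ms: 60000,
/// }, |listen_addr| {
///     // the socket is bound and ready
/// }, |socket, from_addr| {
///     // handle this peer's DatagramSocket
/// })
/// ```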
fn udp_server(
ctx: NativeCallContext,
opts: Dynamic,
when_listening: FnPtr,
on_accept: FnPtr,
) -> RhResult<Handle<Task>> {
let original_span = tracing::Span::current();
let span = debug_span!(parent: original_span, "udp_server");
let the_scenario = ctx.get_scenario()?;
debug!(parent: &span, "node created");
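    /// Options for this node. Many of the socket-level options are forwarded
    /// to `BindOptions` / `UdpOptions` by the `copy_common_*` macros below.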
#[derive(serde::Deserialize)]
struct Opts {
bind: Option<SocketAddr>,
fd: Option<i32>,
named_fd: Option<String>,
#[serde(default)]
fd_force: bool,
timeout_ms: Option<u64>,
max_clients: Option<usize>,
buffer_size: Option<usize>,
qlen: Option<usize>,
#[serde(default)]
tag_as_text: bool,
#[serde(default)]
backpressure: bool,
#[serde(default)]
inhibit_send_errors: bool,
#[serde(default = "default_max_send_datagram_size")]
max_send_datagram_size: usize,
reuseaddr: Option<bool>,
#[serde(default)]
reuseport: bool,
bind_device: Option<String>,
#[serde(default)]
transparent: bool,
#[serde(default)]
freebind: bool,
only_v6: Option<bool>,
tclass_v6: Option<u32>,
tos_v4: Option<u32>,
ttl: Option<u32>,
cpu_affinity: Option<usize>,
priority: Option<u32>,
recv_buffer_size: Option<usize>,
send_buffer_size: Option<usize>,
mark: Option<u32>,
#[serde(default)]
broadcast: bool,
multicast: Option<IpAddr>,
multicast_interface_addr: Option<Ipv4Addr>,
multicast_interface_index: Option<u32>,
multicast_specific_source: Option<Ipv4Addr>,
multicast_all: Option<bool>,
multicast_loop: Option<bool>,
multicast_ttl: Option<u32>,
}
let opts: Opts = rhai::serde::from_dynamic(&opts)?;
let mut bindopts = BindOptions::new();
let mut udpopts = UdpOptions::new();
copy_common_bind_options!(bindopts, opts);
copy_common_udp_options!(udpopts, opts);
let mut lru: LruCache<SocketAddr, Arc<ClientInfo>> = match opts.max_clients {
None => LruCache::unbounded(),
Some(0) => return Err(ctx.err("max_clients cannot be 0")),
Some(n) => LruCache::new(std::num::NonZeroUsize::new(n).unwrap()),
};
let buffer_size = opts.buffer_size.unwrap_or(4096);
let qlen = opts.qlen.unwrap_or(1);
let backpressure = opts.backpressure;
if buffer_size == 0 {
return Err(ctx.err("Invalid buffer_size 0"));
}
let a = AddressOrFd::interpret(&ctx, &span, opts.bind, opts.fd, opts.named_fd, None)?;
Ok(async move {
debug!("node started");
let mut buf = BytesMut::new();
let mut clients_add_events: usize = 0;
let mut address_to_report = NEUTRAL_SOCKADDR4;
let s = match a {
AddressOrFd::Addr(a) => {
address_to_report = a;
bindopts.bind_udp(a).await?
}
#[cfg(not(unix))]
AddressOrFd::Fd(..) | AddressOrFd::NamedFd(..) => {
error!("Inheriting listeners from parent processes is not supported outside UNIX platforms");
anyhow::bail!("Unsupported feature");
}
#[cfg(unix)]
AddressOrFd::Fd(_) | AddressOrFd::NamedFd(_) => {
bindopts.warn_if_options_set();
use super::unix1::{listen_from_fd, listen_from_fd_named, ListenFromFdType};
let force_addr = opts.fd_force.then_some(ListenFromFdType::Udp);
let assert_addr = Some(ListenFromFdType::Udp);
let ret = match a {
AddressOrFd::Addr(_) => unreachable!(),
AddressOrFd::Fd(fd) => unsafe { listen_from_fd(fd, force_addr, assert_addr) },
AddressOrFd::NamedFd(ref fd) => unsafe {
listen_from_fd_named(fd, force_addr, assert_addr)
},
};
ret?.unwrap_udp()
}
};
udpopts.apply_socket_opts(
&s,
s.local_addr().map(|x| x.is_ipv6()).unwrap_or_else(|_| {
warn!("Failed to determine local address of an UDP socket");
false
}),
)?;
if address_to_report.port() == 0 {
if let Ok(a) = s.local_addr() {
address_to_report = a;
} else {
warn!("Failed to obtain actual listening port");
}
}
callback_and_continue::<(SocketAddr,)>(
the_scenario.clone(),
when_listening,
(address_to_report,),
)
.await;
let s = Arc::new(s);
'main_loop: loop {
trace!("loop");
if clients_add_events == 1024 && opts.max_clients.unwrap_or(4096) >= 4096 {
debug!("vacuum");
let mut ctr = 0;
let dead_clients = Vec::from_iter(
lru.iter()
.filter(|x| x.1.v.lock().unwrap().dead())
.map(|x| *x.0),
);
for x in dead_clients {
if lru.pop(&x).is_some() {
ctr += 1;
}
}
if ctr > 0 {
debug!("Vacuumed {ctr} stale entries");
}
clients_add_events = 0;
}
buf.reserve(buffer_size.saturating_sub(buf.capacity()));
let (b, from_addr) = match s.recv_buf_from(&mut buf).await {
Ok((n, from_addr)) => {
trace!(n, %from_addr, "recv");
let b = buf.split_to(n).freeze();
(b, from_addr)
}
Err(e) => {
error!("Error receiving from udp: {e}");
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
continue 'main_loop;
}
};
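            // Find this peer's session, or create one (spawning `on_accept`
            // in a fresh task). A dead entry found during lookup is dropped
            // and the lookup retried.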
            let ci: &Arc<ClientInfo> = 'obtaining_entry: loop {
trace!("lookup");
break match lru.get(&from_addr) {
None => {
trace!("not found");
clients_add_events += 1;
let (tx, rx) = tokio::sync::mpsc::channel(qlen);
let (tx2, rx2) = tokio::sync::oneshot::channel();
let ci = Arc::new(ClientInfo {
addr: from_addr,
v: Mutex::new(VolatileClientInfo {
deadline: None,
removal_notifier: Some(tx2),
sink: tx,
}),
});
{
assert!(!ci.v.lock().unwrap().dead());
}
let ci2 = ci.clone();
let ci3 = ci.clone();
if let Some((_, evicted)) = lru.push(from_addr, ci) {
debug!(peeraddr=%evicted.addr, "evicting");
let mut ev = evicted.v.lock().unwrap();
ev.terminate();
}
                        let udp_send = UdpSend::new(
                            s.clone(),
                            ci2,
                            opts.inhibit_send_errors,
                            opts.max_send_datagram_size,
                        );
let udp_recv = UdpRecv {
recv: rx,
tag_as_text: opts.tag_as_text,
};
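                        // The hangup monitor doubles as the socket's `close`
                        // future: it resolves on eviction, explicit
                        // termination, or inactivity timeout.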
let hangup =
Some(Box::pin(hangup_monitor(ci3, rx2)) as super::types::Hangup);
let socket = DatagramSocket {
read: Some(DatagramRead {
src: Box::pin(udp_recv),
}),
write: Some(DatagramWrite {
snk: Box::pin(udp_send),
}),
close: hangup,
fd: None,
};
let the_scenario = the_scenario.clone();
let on_accept = on_accept.clone();
tokio::spawn(async move {
                            let newspan = debug_span!("udp_accept", from=%from_addr);
                            debug!(parent: &newspan, "accepted");
callback_and_continue::<(Handle<DatagramSocket>, SocketAddr)>(
the_scenario,
on_accept,
(Some(socket).wrap(), from_addr),
)
.instrument(newspan)
.await;
});
lru.get(&from_addr).unwrap()
}
Some(x) => {
let dead = { x.v.lock().unwrap().dead() };
trace!(dead, "found");
if dead {
lru.pop(&from_addr);
continue 'obtaining_entry;
}
x
}
};
};
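            // Hand the datagram to the client's queue. `try_send` keeps the
            // receive loop from blocking while the mutex is held; if the
            // queue is full and `backpressure` is set, the send is completed
            // with an awaited `send` after the lock is released.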
let mut send_debt = None;
{
let mut v = ci.v.lock().unwrap();
if v.dead() {
warn!("A rare case of a dropped incoming datagram because of timer expiration in an unfortunate moment.");
continue 'main_loop;
}
if let Some(tmo) = opts.timeout_ms {
let deadline = Instant::now() + Duration::from_millis(tmo);
v.deadline = Some(deadline);
}
match v.sink.try_send(b) {
Ok(()) => (),
Err(TrySendError::Closed(_)) => {
v.terminate();
}
Err(TrySendError::Full(b)) => {
if backpressure {
send_debt = Some((v.sink.clone(), b));
} else {
                            debug!(peer_addr=%from_addr, "dropping a datagram because the handler is too slow")
}
}
}
}
if let Some((sink2, b)) = send_debt {
debug!(peer_addr=%from_addr, "buffer full, sending later");
match sink2.send(b).await {
Ok(()) => (),
Err(_) => {
let mut vv = ci.v.lock().unwrap();
vv.terminate();
}
}
}
}
}
.instrument(span)
.wrap())
}
pub fn register(engine: &mut Engine) {
engine.register_fn("udp_server", udp_server);
}