use std::{
net::SocketAddr,
task::{ready, Poll},
};
use crate::scenario_executor::{
types::{DatagramRead, DatagramSocket, DatagramWrite},
utils1::{SimpleErr, ToNeutralAddress},
utils2::AddressOrFd,
};
use futures::FutureExt;
use rhai::{Dynamic, Engine, NativeCallContext};
use tokio::{io::ReadBuf, net::UdpSocket};
#[allow(unused)]
use tracing::{debug, debug_span, error, info, warn};
use crate::scenario_executor::types::Handle;
use std::sync::{Arc, RwLock};
use super::{
types::{BufferFlag, PacketRead, PacketReadResult, PacketWrite},
utils1::RhResult,
utils2::{Defragmenter, DefragmenterAddChunkResult},
};
/// Mutable peer-address state shared between the send and receive halves of a UDP endpoint.
struct UdpAddrInner {
    /// Number of times the receive half has replaced `target_address` with a newly seen source.
    address_change_counter: u32,
    /// Destination used by `send_to` in `sendto_mode`; incoming datagram sources are compared against it.
    target_address: SocketAddr,
}
/// Socket plus peer state, held behind an `Arc` by both endpoint halves.
struct UdpInner {
    /// The tokio UDP socket every send and receive goes through.
    s: UdpSocket,
    /// Current peer address: read by the send half, rewritten by the receive half on redirect.
    peer: RwLock<UdpAddrInner>,
}
/// Write half of a UDP endpoint.
struct UdpSend {
    /// Shared socket and peer state.
    s: Arc<UdpInner>,
    /// Reassembles fragmented outgoing messages into whole datagrams, enforcing a size limit.
    degragmenter: Defragmenter,
    /// Use `send_to` with the shared peer address instead of `send` on a connected socket.
    sendto_mode: bool,
    /// Log send failures instead of returning them to the caller.
    inhibit_send_errors: bool,
}
/// Split one UDP socket into a send half and a receive half.
///
/// Both halves share the same `Arc<UdpInner>`. A peer-address change made by
/// the receive half (when `redirect_to_last_seen_address` is set) is therefore
/// what the send half uses for subsequent `send_to` calls in `sendto_mode`.
///
/// `toaddr` seeds the shared peer address with a change counter of zero.
/// `max_send_datagram_size` bounds the send half's defragmenter.
fn new_udp_endpoint(
    s: UdpSocket,
    toaddr: SocketAddr,
    sendto_mode: bool,
    allow_other_addresses: bool,
    redirect_to_last_seen_address: bool,
    connect_to_first_seen_address: bool,
    tag_as_text: bool,
    inhibit_send_errors: bool,
    max_send_datagram_size: usize,
) -> (UdpSend, UdpRecv) {
    let initial_peer = UdpAddrInner {
        target_address: toaddr,
        address_change_counter: 0,
    };
    let shared = Arc::new(UdpInner {
        s,
        peer: RwLock::new(initial_peer),
    });

    let receiver = UdpRecv {
        s: Arc::clone(&shared),
        sendto_mode,
        allow_other_addresses,
        redirect_to_last_seen_address,
        connect_to_first_seen_address,
        tag_as_text,
    };
    let sender = UdpSend {
        s: shared,
        sendto_mode,
        degragmenter: Defragmenter::new(max_send_datagram_size),
        inhibit_send_errors,
    };

    (sender, receiver)
}
impl PacketWrite for UdpSend {
    /// Buffer one outgoing chunk and, once a whole message is available, send it as one datagram.
    ///
    /// Chunks go through the defragmenter first. `DontSendYet` reports success
    /// without touching the socket. A complete message is sent with `send` when
    /// the socket is connected. In `sendto_mode` it is sent with `send_to` to the
    /// shared peer address instead.
    ///
    /// After every completed send attempt (success or failure) the defragmenter is
    /// cleared exactly once, so the next message starts from an empty buffer.
    ///
    /// # Errors
    /// * `InvalidData` when the reassembled message exceeds the configured
    ///   maximum datagram size.
    /// * The socket's send error. The error is only logged (and the datagram
    ///   dropped) when `inhibit_send_errors` is set, or when in `sendto_mode` the
    ///   peer address has an unspecified IP (e.g. `0.0.0.0`).
    ///
    /// NOTE(review): on `Pending` this returns before clearing the defragmenter.
    /// If the caller re-submits the same final chunk on the next poll, confirm
    /// that `add_chunk` does not append it a second time.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
        flags: super::types::BufferFlags,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let data: &[u8] = match this.degragmenter.add_chunk(buf, flags) {
            DefragmenterAddChunkResult::DontSendYet => {
                return Poll::Ready(Ok(()));
            }
            DefragmenterAddChunkResult::Continunous(x) => x,
            DefragmenterAddChunkResult::SizeLimitExceeded(_x) => {
                warn!("Exceeded maximum allowed outgoing datagram size. Closing this session.");
                return Poll::Ready(Err(std::io::ErrorKind::InvalidData.into()));
            }
        };

        let mut inhibit_send_errors = this.inhibit_send_errors;
        let ret = if !this.sendto_mode {
            this.s.s.poll_send(cx, data)
        } else {
            let addr = this.s.peer.read().unwrap().target_address;
            // No usable peer yet: sending to an unspecified IP is expected to fail,
            // so such failures must not tear down the session.
            if addr.ip().is_unspecified() {
                inhibit_send_errors = true;
            }
            this.s.s.poll_send_to(cx, data, addr)
        };

        // Capture the length before `clear()`: `data` may borrow the defragmenter's buffer.
        let data_len = data.len();
        let outcome = ready!(ret);
        // The send attempt is finished either way; reset for the next message.
        this.degragmenter.clear();

        match outcome {
            Ok(n) if n != data_len => {
                warn!("short UDP send");
            }
            Ok(_) => {}
            Err(e) if inhibit_send_errors => {
                warn!("Failed to send to UDP socket: {e}");
            }
            Err(e) => return Poll::Ready(Err(e)),
        }
        Poll::Ready(Ok(()))
    }
}
/// Read half of a UDP endpoint.
#[derive(Clone)]
struct UdpRecv {
    /// Shared socket and peer state.
    s: Arc<UdpInner>,
    /// Socket is unconnected; filter datagrams by source address in user space.
    sendto_mode: bool,
    /// Deliver datagrams whose source differs from the current peer instead of dropping them.
    allow_other_addresses: bool,
    /// Replace the shared peer address with the source of a foreign datagram.
    redirect_to_last_seen_address: bool,
    /// On redirect, also `connect` the socket to the new source and stop filtering.
    connect_to_first_seen_address: bool,
    /// Mark every received datagram with `BufferFlag::Text`.
    tag_as_text: bool,
}
impl PacketRead for UdpRecv {
    /// Receive one datagram into `buf`.
    ///
    /// Returns a `PacketReadResult` whose `buffer_subset` is `0..len` of the
    /// received bytes. Its flags are `BufferFlag::Text` when `tag_as_text` is
    /// set, and default flags otherwise.
    ///
    /// Connected mode (`sendto_mode == false`): the datagram is returned as
    /// received, with no source-address check here.
    ///
    /// `sendto_mode`: each datagram's source is compared with the shared
    /// `target_address`:
    /// * equal — returned.
    /// * different and `allow_other_addresses` is false — logged at `info` and
    ///   dropped; the loop waits for the next datagram.
    /// * different and `redirect_to_last_seen_address` — the shared target is
    ///   replaced by the source and the change counter incremented. If
    ///   `connect_to_first_seen_address` is also set, the socket is connected to
    ///   that source and this reader switches to connected mode.
    /// * different, allowed, no redirect — returned; the target is unchanged, so
    ///   the send half keeps addressing the old peer.
    ///
    /// # Errors
    /// I/O errors from `recv`/`recv_from` and from the redirect `connect`.
    ///
    /// # Panics
    /// If the `connect` future is not ready on its first poll, or if the peer
    /// lock is poisoned.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<PacketReadResult>> {
        let this = self.get_mut();
        let flags = if this.tag_as_text {
            BufferFlag::Text.into()
        } else {
            Default::default()
        };
        // Connected socket: no user-space source filtering.
        if !this.sendto_mode {
            let mut rb = ReadBuf::new(buf);
            ready!(this.s.s.poll_recv(cx, &mut rb))?;
            return Poll::Ready(Ok(PacketReadResult {
                flags,
                buffer_subset: 0..(rb.filled().len()),
            }));
        }
        loop {
            // Fresh ReadBuf each iteration, so `filled()` covers only the current datagram.
            let mut rb = ReadBuf::new(buf);
            let from: SocketAddr = ready!(this.s.s.poll_recv_from(cx, &mut rb))?;
            let savedaddr = this.s.peer.read().unwrap();
            if savedaddr.target_address != from {
                if !this.allow_other_addresses {
                    // NOTE(review): one info-level line per foreign datagram; may be noisy under unsolicited traffic.
                    info!("Ignored incoming UDP datagram from a foreign address: {from}");
                    continue;
                }
                if this.redirect_to_last_seen_address {
                    // Release the read guard first: a std RwLock cannot be upgraded, and taking the
                    // write lock while still holding a read guard on this thread may deadlock or panic.
                    drop(savedaddr);
                    let mut savedaddr = this.s.peer.write().unwrap();
                    savedaddr.target_address = from;
                    savedaddr.address_change_counter += 1;
                    info!(
                        "Updated UDP peer address to {from} (number of address changes: {})",
                        savedaddr.address_change_counter
                    );
                    // Only reachable together with `redirect_to_last_seen_address`. The write guard is
                    // still held, so the address update and the connect happen under one lock.
                    if this.connect_to_first_seen_address {
                        match this.s.s.connect(from).now_or_never() {
                            Some(Ok(())) => {
                                // Later calls take the connected path above. Only this instance
                                // switches: `UdpRecv` is `Clone`, and clones keep their own flag. The
                                // send half keeps its own `sendto_mode` and sends to the updated target.
                                this.sendto_mode = false;
                            }
                            Some(Err(e)) => return Poll::Ready(Err(e)),
                            None => panic!(
                                "UDP connect to specific address not completed immeidately somehow"
                            ),
                        }
                    }
                }
            }
            return Poll::Ready(Ok(PacketReadResult {
                flags,
                buffer_subset: 0..(rb.filled().len()),
            }));
        }
    }
}
/// Serde default for the `max_send_datagram_size` option: 4 KiB.
const fn default_max_send_datagram_size() -> usize {
    4 * 1024
}
fn udp_socket(ctx: NativeCallContext, opts: Dynamic) -> RhResult<Handle<DatagramSocket>> {
let original_span = tracing::Span::current();
let span = debug_span!(parent: original_span, "udp_socket");
debug!(parent: &span, "node created");
#[derive(serde::Deserialize)]
struct Opts {
addr: SocketAddr,
fd: Option<i32>,
named_fd: Option<String>,
#[serde(default)]
fd_force: bool,
bind: Option<SocketAddr>,
#[serde(default)]
sendto_mode: bool,
#[serde(default)]
allow_other_addresses: bool,
#[serde(default)]
redirect_to_last_seen_address: bool,
#[serde(default)]
connect_to_first_seen_address: bool,
#[serde(default)]
tag_as_text: bool,
#[serde(default)]
inhibit_send_errors: bool,
#[serde(default = "default_max_send_datagram_size")]
max_send_datagram_size: usize,
}
let opts: Opts = rhai::serde::from_dynamic(&opts)?;
let to_addr = opts.addr;
let bind_addr = opts.bind.unwrap_or(to_addr.to_neutral_address());
let a = AddressOrFd::interpret(
&ctx,
&span,
opts.bind,
opts.fd,
opts.named_fd,
Some(bind_addr),
)?;
let s = match a {
AddressOrFd::Addr(a) => {
let Some(Ok(s)) = UdpSocket::bind(a).now_or_never() else {
return Err(ctx.err("Failed to bind UDP socket"));
};
s
}
#[cfg(not(unix))]
AddressOrFd::Fd(..) | AddressOrFd::NamedFd(..) => {
error!("Inheriting listeners from parent processes is not supported outside UNIX platforms");
return Err(ctx.err("Unsupported feature"));
}
#[cfg(unix)]
AddressOrFd::Fd(_) | AddressOrFd::NamedFd(_) => {
use super::unix1::{listen_from_fd, listen_from_fd_named, ListenFromFdType};
let force_addr = opts.fd_force.then_some(ListenFromFdType::Udp);
let assert_addr = Some(ListenFromFdType::Udp);
let ret = match a {
AddressOrFd::Addr(_) => unreachable!(),
AddressOrFd::Fd(fd) => unsafe { listen_from_fd(fd, force_addr, assert_addr) },
AddressOrFd::NamedFd(ref fd) => unsafe {
listen_from_fd_named(fd, force_addr, assert_addr)
},
};
let Ok(s) = ret else {
return Err(ctx.err("Failed to get UDP socket"));
};
s.unwrap_udp()
}
};
#[allow(unused_assignments)]
let mut fd = None;
#[cfg(unix)]
{
use std::os::fd::AsRawFd;
fd = Some(
unsafe { super::types::SocketFd::new(s.as_raw_fd()) },
);
}
if !opts.sendto_mode {
match s.connect(to_addr).now_or_never() {
Some(Ok(())) => (),
_ => return Err(ctx.err("Failed to connect UDP socket")),
}
}
let (us, ur) = new_udp_endpoint(
s,
to_addr,
opts.sendto_mode,
opts.allow_other_addresses,
opts.redirect_to_last_seen_address,
opts.connect_to_first_seen_address,
opts.tag_as_text,
opts.inhibit_send_errors,
opts.max_send_datagram_size,
);
let s = DatagramSocket {
read: Some(DatagramRead { src: Box::pin(ur) }),
write: Some(DatagramWrite { snk: Box::pin(us) }),
close: None,
fd,
};
debug!(s=?s, "created");
Ok(s.wrap())
}
/// Expose the `udp_socket` node constructor to Rhai scripts.
pub fn register(engine: &mut Engine) {
    const SCRIPT_NAME: &str = "udp_socket";
    engine.register_fn(SCRIPT_NAME, udp_socket);
}