use std::{
collections::HashMap,
io,
net::{SocketAddr, ToSocketAddrs},
sync::Arc,
task::Poll,
};
use bytes::BytesMut;
use rand::seq::IteratorRandom;
use rasi::{
executor::spawn,
net::UdpSocket,
syscall::{global_network, Network},
};
use futures::{SinkExt, Stream, StreamExt};
use crate::utils::ReadBuf;
/// The pair of endpoints a datagram travels between.
///
/// For an item yielded by [`Receiver`], `from` is the remote peer and `to` is the
/// local group socket that received it; [`PathInfo::reverse`] of that value can be
/// handed to [`Sender::send_to_on_path`] to reply from the same local socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PathInfo {
    /// Source address of the datagram.
    pub from: SocketAddr,
    /// Destination address of the datagram.
    pub to: SocketAddr,
}

impl PathInfo {
    /// Returns the same path in the opposite direction (`from` and `to` swapped).
    #[must_use]
    pub fn reverse(self) -> Self {
        Self {
            from: self.to,
            to: self.from,
        }
    }
}
/// A set of UDP sockets, each bound to its own local address, that can be
/// split into a cloneable [`Sender`] and a single merged [`Receiver`].
pub struct UdpGroup {
    /// Bound sockets, keyed by the local address each socket reports.
    sockets: HashMap<SocketAddr, Arc<UdpSocket>>,
    /// Capacity, in bytes, of the buffer allocated for each received datagram.
    max_recv_buf_len: u16,
}
impl UdpGroup {
    /// Receive buffer length used unless overridden by
    /// [`UdpGroup::with_max_recv_buf_len`].
    const DEFAULT_MAX_RECV_BUF_LEN: u16 = 2048;

    /// Binds one UDP socket per address in `laddrs`, using the network provider
    /// returned by `global_network()`.
    ///
    /// See [`UdpGroup::bind_with`] for details and error conditions.
    pub async fn bind<A: ToSocketAddrs>(laddrs: A) -> io::Result<Self> {
        Self::bind_with(laddrs, global_network()).await
    }

    /// Binds one UDP socket per address resolved from `laddrs`, using `syscall`.
    ///
    /// Each socket is keyed by the local address it reports after binding, so a
    /// requested port of `0` is recorded as the port actually assigned.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if `laddrs` resolves to no addresses (an empty group
    ///   would have no socket to send from).
    /// - Any error from address resolution, binding, or querying `local_addr`.
    ///   Sockets bound before the failure are dropped along with the partial map.
    pub async fn bind_with<A: ToSocketAddrs>(
        laddrs: A,
        syscall: &'static dyn Network,
    ) -> io::Result<Self> {
        // NOTE: `to_socket_addrs` may perform blocking name resolution when
        // given a hostname rather than a literal address.
        let laddrs = laddrs.to_socket_addrs()?.collect::<Vec<_>>();

        if laddrs.is_empty() {
            // Same error kind/message as `std::net::UdpSocket::bind`.
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "could not resolve to any addresses",
            ));
        }

        let mut sockets = HashMap::with_capacity(laddrs.len());

        for laddr in laddrs {
            let socket = Arc::new(UdpSocket::bind_with([laddr].as_slice(), syscall).await?);
            let laddr = socket.local_addr()?;
            sockets.insert(laddr, socket);
        }

        Ok(Self {
            sockets,
            max_recv_buf_len: Self::DEFAULT_MAX_RECV_BUF_LEN,
        })
    }

    /// Sets the buffer length, in bytes, allocated for each received datagram.
    ///
    /// # Panics
    ///
    /// Panics if `len` is zero.
    pub fn with_max_recv_buf_len(mut self, len: u16) -> Self {
        assert!(len > 0, "sets max_recv_buf_len to zero");
        self.max_recv_buf_len = len;
        self
    }

    /// Splits the group into a sending half and a receiving half.
    ///
    /// Spawns one receive task per socket. Every task forwards into a single
    /// channel that backs the returned [`Receiver`]. The channel is created
    /// with a zero buffer, so each task gets one slot and then waits for the
    /// consumer before receiving again.
    pub fn split(self) -> (Sender, Receiver) {
        let (sender, receiver) = futures::channel::mpsc::channel(0);

        for (laddr, socket) in &self.sockets {
            spawn(Self::recv_loop(
                *laddr,
                socket.clone(),
                sender.clone(),
                self.max_recv_buf_len as usize,
            ));
        }

        (Sender::new(self.sockets), Receiver::new(receiver))
    }

    /// Returns the local addresses of all sockets in the group.
    pub fn local_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
        self.sockets.keys()
    }

    /// Receives datagrams on `socket` (bound to `laddr`) and forwards them to
    /// `sender`.
    ///
    /// The loop ends when forwarding fails because the [`Receiver`] was dropped.
    /// That is only noticed after the next datagram arrives. The loop also ends
    /// on a non-transient receive error, which is forwarded to the receiver
    /// first. Transient errors are logged and skipped.
    async fn recv_loop(
        laddr: SocketAddr,
        socket: Arc<UdpSocket>,
        mut sender: futures::channel::mpsc::Sender<UdpGroupData>,
        max_recv_buf_len: usize,
    ) {
        loop {
            // A fresh buffer per datagram: ownership of the filled buffer is
            // handed to the consumer through the channel.
            let mut read_buf = ReadBuf::with_capacity(max_recv_buf_len);

            let data = match socket.recv_from(read_buf.chunk_mut()).await {
                Ok((read_size, raddr)) => {
                    log::trace!(
                        "UdpGroup recv_from, raddr={:?}, read_size={}",
                        raddr,
                        read_size
                    );

                    UdpGroupData {
                        result: Ok((read_buf.into_bytes_mut(Some(read_size)), raddr)),
                        to: laddr,
                    }
                }
                Err(err) if Self::is_transient_recv_error(&err) => {
                    log::warn!(
                        "socket({:?}) in udp group, ignoring transient recv error: {}",
                        laddr,
                        err
                    );
                    continue;
                }
                Err(err) => {
                    log::error!(
                        "socket({:?}) in udp group, shutdown with error: {}",
                        laddr,
                        err
                    );

                    // Best effort: let the consumer observe the failure. The
                    // result is ignored because this loop is stopping anyway.
                    let _ = sender
                        .send(UdpGroupData {
                            result: Err(err),
                            to: laddr,
                        })
                        .await;
                    return;
                }
            };

            if sender.send(data).await.is_err() {
                log::trace!("socket({:?}) in udp group, stop recv loop", laddr);
                return;
            }
        }
    }

    /// Returns `true` for receive errors that describe a single event (e.g. an
    /// ICMP "port unreachable" surfacing as `ConnectionReset` on some
    /// platforms) rather than a broken socket.
    fn is_transient_recv_error(err: &io::Error) -> bool {
        matches!(
            err.kind(),
            io::ErrorKind::ConnectionReset
                | io::ErrorKind::ConnectionRefused
                | io::ErrorKind::Interrupted
        )
    }
}
/// One message sent from a per-socket receive loop to the group [`Receiver`].
struct UdpGroupData {
    /// Received payload plus the remote address it came from; the `Err`
    /// variant carries an I/O error instead.
    result: io::Result<(BytesMut, SocketAddr)>,
    /// Local address of the group socket this message belongs to.
    to: SocketAddr,
}
/// Receiving half of a split [`UdpGroup`].
///
/// Merges the datagrams received by every socket in the group into one stream.
pub struct Receiver {
    /// Channel end fed by the per-socket receive tasks.
    inner: futures::channel::mpsc::Receiver<UdpGroupData>,
}

impl Receiver {
    /// Wraps the receiving end of the group's internal channel.
    fn new(inner: futures::channel::mpsc::Receiver<UdpGroupData>) -> Self {
        Receiver { inner }
    }
}
impl Stream for Receiver {
type Item = io::Result<(BytesMut, PathInfo)>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
match self.inner.poll_next_unpin(cx) {
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(udp_group_data)) => {
Poll::Ready(Some(udp_group_data.result.map(|(buf, raddr)| {
(
buf,
PathInfo {
from: raddr,
to: udp_group_data.to,
},
)
})))
}
Poll::Pending => Poll::Pending,
}
}
}
/// Sending half of a split [`UdpGroup`].
///
/// Cloning is cheap: all clones share the same socket map.
#[derive(Clone)]
pub struct Sender {
    /// Group sockets keyed by local address, shared between clones.
    sockets: Arc<HashMap<SocketAddr, Arc<UdpSocket>>>,
}

impl Sender {
    /// Wraps the group's socket map for shared use.
    fn new(sockets: HashMap<SocketAddr, Arc<UdpSocket>>) -> Self {
        Self {
            sockets: Arc::new(sockets),
        }
    }

    /// Sends `buf` to `raddr` from a randomly chosen socket in the group.
    ///
    /// Sockets whose local address family (IPv4/IPv6) matches `raddr` are
    /// preferred. If none match, any socket is used and the OS reports the
    /// outcome.
    ///
    /// # Errors
    ///
    /// Returns `NotFound` if the group has no sockets, otherwise any error from
    /// the underlying `send_to`.
    pub async fn send_to(&self, buf: &[u8], raddr: SocketAddr) -> io::Result<usize> {
        let chosen = {
            // Scoped so the thread-local RNG is dropped before the `.await`.
            let mut rng = rand::thread_rng();

            self.sockets
                .iter()
                .filter(|(laddr, _)| laddr.is_ipv4() == raddr.is_ipv4())
                .map(|(_, socket)| socket)
                .choose(&mut rng)
                .or_else(|| self.sockets.values().choose(&mut rng))
        };

        let socket = chosen.ok_or_else(|| {
            io::Error::new(io::ErrorKind::NotFound, "no socket in the udp group")
        })?;

        socket.send_to(buf, raddr).await
    }

    /// Sends `buf` to `path_info.to` from the group socket bound to
    /// `path_info.from`.
    ///
    /// # Errors
    ///
    /// Returns `NotFound` if no socket in the group is bound to
    /// `path_info.from`, otherwise any error from the underlying `send_to`.
    pub async fn send_to_on_path(&self, buf: &[u8], path_info: PathInfo) -> io::Result<usize> {
        let Some(socket) = self.sockets.get(&path_info.from) else {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("Socket bound to {:?} is not in the group.", path_info.from),
            ));
        };

        socket.send_to(buf, path_info.to).await
    }
}
#[cfg(test)]
mod tests {
    use std::sync::OnceLock;

    use bytes::Bytes;
    use futures::TryStreamExt;
    use rasi_default::{executor::register_futures_executor_with_pool_size, net::MioNetwork};

    use super::*;

    /// Network backend shared by every test in this module, created on first use.
    static NETWORK: OnceLock<Box<dyn rasi::syscall::Network>> = OnceLock::new();

    /// Returns the shared network backend. The first call also registers a
    /// futures executor with a pool of 10 threads.
    fn get_syscall() -> &'static dyn rasi::syscall::Network {
        let backend = NETWORK.get_or_init(|| {
            register_futures_executor_with_pool_size(10).unwrap();
            Box::new(MioNetwork::default())
        });

        backend.as_ref()
    }

    /// A client group sends to one random server socket. The server echoes the
    /// datagram back along the reversed path, and both ends see the payload
    /// intact.
    #[futures_test::test]
    async fn test_udp_group_echo() {
        const PAYLOAD: &[u8] = b"hello world";

        let syscall = get_syscall();
        let loopback: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let laddrs = vec![loopback; 4];

        let (client_tx, mut client_rx) = UdpGroup::bind_with(laddrs.as_slice(), syscall)
            .await
            .unwrap()
            .split();

        let server_group = UdpGroup::bind_with(laddrs.as_slice(), syscall)
            .await
            .unwrap();
        let server_addrs: Vec<SocketAddr> = server_group.local_addrs().copied().collect();
        let (server_tx, mut server_rx) = server_group.split();

        let target = *server_addrs
            .iter()
            .choose(&mut rand::thread_rng())
            .unwrap();

        client_tx.send_to(PAYLOAD, target).await.unwrap();

        let (received, path) = server_rx.try_next().await.unwrap().unwrap();
        assert_eq!(received.freeze(), Bytes::from_static(PAYLOAD));

        server_tx
            .send_to_on_path(PAYLOAD, path.reverse())
            .await
            .unwrap();

        let (echoed, _) = client_rx.try_next().await.unwrap().unwrap();
        assert_eq!(echoed.freeze(), Bytes::from_static(PAYLOAD));
    }
}