use std::{
io,
net::{IpAddr, Ipv4Addr, SocketAddr},
};
use mio::net::UdpSocket;
use log::{debug, error, info, trace, warn};
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use bytes::{Bytes, BytesMut};
use crate::{
network::util::{
get_local_multicast_ip_addrs, get_local_multicast_locators, get_local_unicast_locators,
},
structure::locator::Locator,
};
/// Number of bytes exposed to each `recv` call, i.e. the largest datagram
/// this listener can take in one read.
const MAX_MESSAGE_SIZE: usize = 64 * 1024;

/// Size of every fresh receive-buffer allocation. One chunk is carved up
/// between consecutive datagrams before a new chunk is allocated.
const MESSAGE_BUFFER_ALLOCATION_CHUNK: usize = 256 * 1024;
/// UDP receive endpoint: a mio UDP socket plus a reusable receive buffer.
///
/// If the listener was created for a multicast group, the group is remembered
/// so that membership can be dropped when the listener is dropped.
#[derive(Debug)]
pub struct UDPListener {
    /// The bound socket that datagrams are read from.
    socket: UdpSocket,
    /// Scratch allocation that incoming datagrams are received into and then
    /// split off from.
    receive_buffer: BytesMut,
    /// `Some(group)` for a multicast listener, `None` for a unicast listener.
    multicast_group: Option<Ipv4Addr>,
}
impl Drop for UDPListener {
    /// Leaves the multicast group (if any) when the listener goes away.
    ///
    /// This is best effort: a failure is logged and otherwise ignored, since
    /// `drop` has no way of reporting an error.
    fn drop(&mut self) {
        // Unicast listeners have no membership to give up.
        let group = match self.multicast_group {
            Some(group) => group,
            None => return,
        };

        let leave_result = self
            .socket
            .leave_multicast_v4(&group, &Ipv4Addr::UNSPECIFIED);
        if let Err(e) = leave_result {
            error!("leave_multicast_group: {:?}", e);
        }
    }
}
impl UDPListener {
fn new_listening_socket(host: &str, port: u16, reuse_addr: bool) -> io::Result<UdpSocket> {
let raw_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
if reuse_addr {
raw_socket.set_reuse_address(true)?;
}
#[cfg(not(any(target_os = "solaris", target_os = "illumos", windows)))]
{
if reuse_addr {
raw_socket.set_reuse_port(true)?;
}
}
let address = SocketAddr::new(
host
.parse()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?,
port,
);
if let Err(e) = raw_socket.bind(&SockAddr::from(address)) {
info!("new_socket - cannot bind socket: {:?}", e);
return Err(e);
}
let std_socket = std::net::UdpSocket::from(raw_socket);
std_socket
.set_nonblocking(true)
.expect("Failed to set std socket to non blocking.");
let mio_socket = UdpSocket::from_socket(std_socket).expect("Unable to create mio socket");
info!(
"UDPListener: new socket with address {:?}",
mio_socket.local_addr()
);
Ok(mio_socket)
}
/// Returns the locators under which this listener is reachable.
///
/// The port is read from the bound socket. A multicast listener gets the
/// local multicast locators for that port, a unicast listener the local
/// unicast locators.
///
/// # Errors
///
/// Fails if the socket's local address cannot be queried.
pub fn to_locator_address(&self) -> io::Result<Vec<Locator>> {
    let bound_port = self.socket.local_addr()?.port();
    let locators = if self.multicast_group.is_some() {
        get_local_multicast_locators(bound_port)
    } else {
        get_local_unicast_locators(bound_port)
    };
    Ok(locators)
}
/// Creates a unicast listener bound to `host:port`.
///
/// The socket is opened without address/port reuse and the listener is not
/// associated with any multicast group.
///
/// # Errors
///
/// Returns the error from socket creation or binding.
pub fn new_unicast(host: &str, port: u16) -> io::Result<Self> {
    Self::new_listening_socket(host, port, false).map(|socket| Self {
        socket,
        receive_buffer: BytesMut::with_capacity(MESSAGE_BUFFER_ALLOCATION_CHUNK),
        multicast_group: None,
    })
}
pub fn new_multicast(host: &str, port: u16, multicast_group: Ipv4Addr) -> io::Result<Self> {
if !multicast_group.is_multicast() {
return io::Result::Err(io::Error::new(
io::ErrorKind::Other,
"Not a multicast address",
));
}
let mio_socket = Self::new_listening_socket(host, port, true)?;
for multicast_if_ipaddr in get_local_multicast_ip_addrs()? {
match multicast_if_ipaddr {
IpAddr::V4(a) => mio_socket
.join_multicast_v4(&multicast_group, &a)
.unwrap_or_else(|e| {
warn!(
"join_multicast_v4 failed: {:?}. multicast_group [{:?}] interface [{:?}]",
e, multicast_group, a
);
}),
IpAddr::V6(_a) => error!("UDPListener::new_multicast() not implemented for IpV6"), }
}
Ok(Self {
socket: mio_socket,
receive_buffer: BytesMut::with_capacity(MESSAGE_BUFFER_ALLOCATION_CHUNK),
multicast_group: Some(multicast_group),
})
}
pub fn mio_socket(&mut self) -> &mut UdpSocket {
&mut self.socket
}
/// Test helper: the port the socket is bound to, or `0` if the local address
/// cannot be queried.
#[cfg(test)]
pub fn port(&self) -> u16 {
    self
        .socket
        .local_addr()
        .map(|addr| addr.port())
        .unwrap_or(0)
}
/// Test helper: performs a single receive on the socket and returns the
/// datagram's bytes.
///
/// If the receive fails, the failure is logged at debug level and an empty
/// vector is returned.
#[cfg(test)]
pub fn get_message(&self) -> Vec<u8> {
    let mut datagram = [0u8; MAX_MESSAGE_SIZE];
    match self.socket.recv(&mut datagram) {
        Ok(received) => datagram[..received].to_vec(),
        Err(e) => {
            debug!("UDPListener::get_message failed: {:?}", e);
            Vec::new()
        }
    }
}
/// Prepares `receive_buffer` for the next receive call.
///
/// After this returns, the buffer has a length of exactly `MAX_MESSAGE_SIZE`
/// bytes, so it can be passed to `recv` as a full-size target slice. A new
/// allocation of `MESSAGE_BUFFER_ALLOCATION_CHUNK` bytes replaces the current
/// buffer when its remaining capacity is smaller than `MAX_MESSAGE_SIZE`.
fn ensure_receive_buffer_capacity(&mut self) {
    // Not enough room left for a maximum-size datagram: start a fresh chunk.
    if self.receive_buffer.capacity() < MAX_MESSAGE_SIZE {
        self.receive_buffer = BytesMut::with_capacity(MESSAGE_BUFFER_ALLOCATION_CHUNK);
        debug!("ensure_receive_buffer_capacity - reallocated receive_buffer");
    }
    // SAFETY: the check above guarantees a capacity of at least
    // MAX_MESSAGE_SIZE, so the new length stays inside the allocation.
    // The bytes exposed here are NOT initialised by this function.
    // NOTE(review): this relies on callers overwriting the buffer via `recv`
    // and truncating to the received length before any byte is read — confirm
    // this holds for every caller.
    unsafe {
        self.receive_buffer.set_len(MAX_MESSAGE_SIZE);
    }
    // Logs the buffer's current capacity.
    trace!(
        "ensure_receive_buffer_capacity - {} bytes left",
        self.receive_buffer.capacity()
    );
}
pub fn messages(&mut self) -> Vec<Bytes> {
let mut messages = Vec::with_capacity(4); self.ensure_receive_buffer_capacity();
while let Ok(nbytes) = self.socket.recv(&mut self.receive_buffer) {
self.receive_buffer.truncate(nbytes);
while self.receive_buffer.len() % 4 != 0 {
self.receive_buffer.extend_from_slice(&[0xCC]); }
let mut message = self.receive_buffer.split_to(self.receive_buffer.len());
self.ensure_receive_buffer_capacity();
message.truncate(nbytes); messages.push(Bytes::from(message)); }
messages
}
/// Test helper: leaves the multicast group `address` on the unspecified
/// interface.
///
/// # Errors
///
/// Returns an error if `address` is not a multicast address, or if the socket
/// fails to leave the group.
#[cfg(test)]
pub fn leave_multicast(&self, address: &Ipv4Addr) -> io::Result<()> {
    if !address.is_multicast() {
        return Err(io::Error::new(
            io::ErrorKind::Other,
            "Not a multicast address",
        ));
    }
    self
        .socket
        .leave_multicast_v4(address, &Ipv4Addr::UNSPECIFIED)
}
}
#[cfg(test)]
mod tests {
    use std::{thread, time};

    use super::*;
    use crate::network::udp_sender::*;

    /// A datagram sent to a unicast listener on loopback is read back
    /// unchanged.
    #[test]
    fn udpl_single_address() {
        let port = 10001;
        let listener = UDPListener::new_unicast("127.0.0.1", port).unwrap();
        let sender = UDPSender::new_with_random_port().expect("failed to create UDPSender");

        let payload: Vec<u8> = vec![0, 1, 2, 3, 4];
        let targets = vec![SocketAddr::new("127.0.0.1".parse().unwrap(), port)];
        sender.send_to_all(&payload, &targets);

        let received = listener.get_message();
        assert_eq!(received.len(), 5);
        assert_eq!(received, payload);
    }

    /// A datagram sent to the multicast group is delivered to a listener that
    /// joined that group.
    #[test]
    fn udpl_multicast_address() {
        let group = Ipv4Addr::new(239, 255, 0, 1);
        let listener = UDPListener::new_multicast("0.0.0.0", 10002, group).unwrap();
        let sender = UDPSender::new_with_random_port().unwrap();

        let payload: Vec<u8> = vec![2, 4, 6];
        sender
            .send_multicast(&payload, group, 10002)
            .expect("Failed to send multicast");

        // Give the datagram time to arrive before the single non-blocking read.
        thread::sleep(time::Duration::from_secs(1));

        let received = listener.get_message();
        listener.leave_multicast(&group).unwrap();

        assert_eq!(received.len(), 3);
        assert_eq!(received, payload);
    }
}