use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
use network_interface::{NetworkInterface, NetworkInterfaceConfig};
use serde::{Deserialize, Serialize};
#[cfg(feature = "python")]
use pyo3::prelude::*;
use tracing::info;
use crate::controller::context::ControllerCtx;
use crate::{LoopMethod, py_json_methods};
use super::*;
use deimos_shared::peripherals::PeripheralId;
use deimos_shared::{CONTROLLER_RX_PORT, PERIPHERAL_RX_PORT};
/// Computes the directed broadcast address of the subnet described by `addr`
/// and `netmask`.
///
/// Every bit that is clear in `netmask` (the host portion) is set in the
/// result; bits covered by the mask are taken from `addr` unchanged.
fn broadcast_from_addr_and_mask(addr: Ipv4Addr, netmask: Ipv4Addr) -> Ipv4Addr {
    let mut octets = addr.octets();
    for (octet, mask) in octets.iter_mut().zip(netmask.octets()) {
        // Inverting the mask octet leaves exactly the host bits set.
        *octet |= !mask;
    }
    Ipv4Addr::from(octets)
}
/// Returns the distinct addresses from `targets`, sorted in ascending order.
fn dedup_broadcast_targets(targets: impl IntoIterator<Item = Ipv4Addr>) -> Vec<Ipv4Addr> {
    let mut unique: Vec<Ipv4Addr> = targets.into_iter().collect();
    // Sorting first makes duplicates adjacent so `dedup` removes all of them.
    unique.sort_unstable();
    unique.dedup();
    unique
}
/// Lists the IPv4 broadcast addresses worth trying from this host.
///
/// The limited broadcast address (`255.255.255.255`) is always included. If
/// the network interfaces can be enumerated, one directed broadcast address
/// is added for each non-loopback, non-unspecified IPv4 interface address:
/// the broadcast address reported for it when that is usable, otherwise one
/// derived from its netmask. A candidate is usable when it differs from the
/// interface address itself and is not `0.0.0.0`.
///
/// The result is free of duplicates and sorted in ascending order. When
/// interface enumeration fails, only the limited broadcast address is
/// returned.
pub fn possible_broadcast_targets() -> Vec<Ipv4Addr> {
    let mut found = BTreeSet::new();
    found.insert(Ipv4Addr::BROADCAST);

    if let Ok(interfaces) = NetworkInterface::show() {
        let directed = interfaces
            .into_iter()
            .flat_map(|interface| interface.addr)
            .filter_map(|addr| {
                let ip = match addr.ip() {
                    IpAddr::V4(ip) if !(ip.is_loopback() || ip.is_unspecified()) => ip,
                    _ => return None,
                };
                let usable =
                    |candidate: Ipv4Addr| candidate != ip && !candidate.is_unspecified();

                match addr.broadcast() {
                    Some(IpAddr::V4(reported)) if usable(reported) => Some(reported),
                    // No usable reported broadcast: fall back to the netmask.
                    _ => match addr.netmask() {
                        Some(IpAddr::V4(netmask)) => {
                            Some(broadcast_from_addr_and_mask(ip, netmask))
                                .filter(|derived| usable(*derived))
                        }
                        _ => None,
                    },
                }
            });
        found.extend(directed);
    }

    found.into_iter().collect()
}
/// UDP-backed socket used by the controller to talk to peripherals.
///
/// Only `broadcast_targets` takes part in (de)serialization; every other
/// field is runtime state marked `#[serde(skip)]`.
#[derive(Serialize, Deserialize, Default)]
#[cfg_attr(feature = "python", pyclass)]
pub struct UdpSocket {
/// Explicit broadcast destinations. Omitted from serialized output when
/// empty; when empty, the `broadcast_targets()` accessor falls back to
/// `possible_broadcast_targets()`.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
broadcast_targets: Vec<Ipv4Addr>,
/// The bound OS socket; `None` until `open` succeeds and again after `close`.
#[serde(skip)]
socket: Option<std::net::UdpSocket>,
/// Whether `open` switched the socket to nonblocking mode.
#[serde(skip)]
nonblocking: bool,
/// Peripheral ID -> socket address, filled in by `update_map`.
#[serde(skip)]
addrs: BTreeMap<PeripheralId, std::net::SocketAddr>,
/// Socket address -> peripheral ID; the inverse of `addrs`.
#[serde(skip)]
pids: BTreeMap<std::net::SocketAddr, PeripheralId>,
/// Source address -> token assigned by `recv` the first time that address
/// is seen.
#[serde(skip)]
addr_tokens: BTreeMap<std::net::SocketAddr, SocketAddrToken>,
/// Token -> source address; the inverse of `addr_tokens`.
#[serde(skip)]
token_addrs: BTreeMap<SocketAddrToken, std::net::SocketAddr>,
/// Token that `recv` hands to the next previously unseen source address.
#[serde(skip)]
next_addr_token: SocketAddrToken,
}
impl UdpSocket {
    /// Creates a closed socket with no explicit broadcast targets and empty
    /// address bookkeeping.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a closed socket that broadcasts only to `targets`.
    ///
    /// Duplicate entries are removed and the remaining addresses are kept in
    /// ascending order.
    pub fn with_broadcast_targets(targets: Vec<Ipv4Addr>) -> Self {
        let mut socket = Self::new();
        socket.broadcast_targets = dedup_broadcast_targets(targets);
        socket
    }

    /// Returns the addresses a broadcast should be sent to: the explicitly
    /// configured list when it is non-empty, otherwise whatever
    /// `possible_broadcast_targets` reports at the time of the call.
    fn broadcast_targets(&self) -> Vec<Ipv4Addr> {
        match self.broadcast_targets.as_slice() {
            [] => possible_broadcast_targets(),
            explicit => explicit.to_vec(),
        }
    }
}
// Python-facing methods for `UdpSocket`, passed to the crate's
// `py_json_methods!` macro.
// NOTE(review): the macro is defined elsewhere; presumably it also generates
// JSON (de)serialization helpers and uses `Socket` as the trait name - confirm.
// Only plain `//` comments are used in here so that no `#[doc]` attribute
// tokens are fed to the macro.
py_json_methods!(
UdpSocket,
Socket,
// Python constructor: equivalent to `UdpSocket::new()`.
#[new]
fn py_new() -> PyResult<Self> {
Ok(Self::new())
},
// Python static method `with_broadcast_targets`: parses every string as an
// IPv4 address and builds a socket restricted to those targets. The first
// string that fails to parse is reported as a `ValueError`.
#[staticmethod]
#[pyo3(name = "with_broadcast_targets")]
fn py_with_broadcast_targets(targets: Vec<String>) -> PyResult<Self> {
let targets = targets
.into_iter()
.map(|target| {
target.parse::<Ipv4Addr>().map_err(|e| {
pyo3::exceptions::PyValueError::new_err(format!(
"Invalid IPv4 broadcast target `{target}`: {e}"
))
})
})
.collect::<PyResult<Vec<_>>>()?;
Ok(Self::with_broadcast_targets(targets))
},
// Python static method `possible_broadcast_targets`: the addresses returned by
// the Rust function of the same name, formatted as strings.
#[staticmethod]
#[pyo3(name = "possible_broadcast_targets")]
fn py_possible_broadcast_targets() -> Vec<String> {
possible_broadcast_targets()
.into_iter()
.map(|addr| addr.to_string())
.collect()
}
);
#[typetag::serde]
impl Socket for UdpSocket {
    /// Returns `true` while a bound socket is held.
    fn is_open(&self) -> bool {
        self.socket.is_some()
    }

    /// Binds the controller socket on `0.0.0.0:CONTROLLER_RX_PORT`.
    ///
    /// A socket that is already open is closed first, which also discards all
    /// address bookkeeping. When `ctx.loop_method` is
    /// `LoopMethod::Performant` the new socket is switched to nonblocking
    /// mode.
    ///
    /// # Errors
    /// Returns a message if binding fails or nonblocking mode cannot be set;
    /// in both cases the socket stays closed.
    fn open(&mut self, ctx: &ControllerCtx) -> Result<(), String> {
        if self.socket.is_some() {
            self.close();
        }
        self.nonblocking = false;

        let addr = format!("0.0.0.0:{CONTROLLER_RX_PORT}");
        let socket = std::net::UdpSocket::bind(&addr)
            .map_err(|e| format!("Unable to bind UDP socket: {e}"))?;

        if matches!(ctx.loop_method, LoopMethod::Performant) {
            socket
                .set_nonblocking(true)
                .map_err(|e| format!("Failed to set socket to nonblocking mode: {e}"))?;
            self.nonblocking = true;
        }

        self.socket = Some(socket);
        info!("Opened controller UDP socket at {addr:?}");
        Ok(())
    }

    /// Drops the socket (if any) and resets all address bookkeeping,
    /// including the token counter. The close message is logged even when no
    /// socket was open.
    fn close(&mut self) {
        self.socket = None;
        self.nonblocking = false;
        self.addrs.clear();
        self.pids.clear();
        self.addr_tokens.clear();
        self.token_addrs.clear();
        self.next_addr_token = 0;
        info!("Closed controller UDP socket at 0.0.0.0:{CONTROLLER_RX_PORT}");
    }

    /// Sends `msg` to the address currently mapped to peripheral `id`.
    ///
    /// # Errors
    /// Returns a message if `id` has no mapped address, the socket is not
    /// open, or the OS send fails.
    fn send(&mut self, id: PeripheralId, msg: &[u8]) -> Result<(), String> {
        // Error strings are built lazily so the success path, which may run
        // every control cycle, does not allocate.
        let addr = *self
            .addrs
            .get(&id)
            .ok_or_else(|| format!("Peripheral not present in address map: {id:?}"))?;
        let sock = self
            .socket
            .as_ref()
            .ok_or_else(|| "Unable to send before socket is bound".to_string())?;
        sock.send_to(msg, addr)
            .map_err(|e| format!("Failed to send UDP packet: {e}"))?;
        Ok(())
    }

    /// Receives at most one datagram into `buf`.
    ///
    /// Returns `None` when the socket is not open, when the receive call
    /// fails for any reason (the error itself is discarded), or when the
    /// datagram's source port is not `PERIPHERAL_RX_PORT` (the datagram is
    /// consumed and dropped). Otherwise returns the packet size, the receive
    /// timestamp, the token assigned to the source address, and the
    /// peripheral ID mapped to that address, if any.
    ///
    /// A nonzero `timeout` is installed as the socket read timeout. A zero
    /// `timeout` leaves a nonblocking socket untouched and gives a blocking
    /// socket a 1 ns read timeout.
    fn recv(&mut self, buf: &mut [u8], timeout: Duration) -> Option<SocketPacketMeta> {
        let sock = self.socket.as_ref()?;

        // Failures to set the timeout are deliberately ignored (best effort).
        if !timeout.is_zero() {
            let _ = sock.set_read_timeout(Some(timeout));
        } else if !self.nonblocking {
            // std rejects a zero read timeout, so use the smallest nonzero
            // duration to approximate a poll on a blocking socket.
            let _ = sock.set_read_timeout(Some(Duration::from_nanos(1)));
        }

        let (size, addr) = sock.recv_from(buf).ok()?;
        // Timestamp as soon as the datagram is in hand.
        let time = Instant::now();
        if addr.port() != PERIPHERAL_RX_PORT {
            return None;
        }

        // Look the source address up once; unseen addresses get a fresh token
        // recorded in both directions.
        let token = match self.addr_tokens.entry(addr) {
            std::collections::btree_map::Entry::Occupied(known) => *known.get(),
            std::collections::btree_map::Entry::Vacant(slot) => {
                let token = self.next_addr_token;
                self.next_addr_token = token.wrapping_add(1);
                self.token_addrs.insert(token, addr);
                slot.insert(token);
                token
            }
        };

        let pid = self.pids.get(&addr).copied();
        Some(SocketPacketMeta {
            pid,
            token,
            time,
            size,
        })
    }

    /// Sends `msg` to `PERIPHERAL_RX_PORT` on every broadcast target.
    ///
    /// Targets are the explicitly configured list, or the discovered ones
    /// when none are configured. Broadcast mode is enabled on the socket for
    /// the duration of the sends and disabled again afterwards.
    ///
    /// # Errors
    /// Returns a message if the socket is not open, if broadcast mode cannot
    /// be enabled or disabled, or if the send failed for every target.
    /// Individual send failures are discarded as long as at least one target
    /// succeeded.
    fn broadcast(&mut self, msg: &[u8]) -> Result<(), String> {
        let targets = self.broadcast_targets();
        let sock = self
            .socket
            .as_ref()
            .ok_or_else(|| "Unable to send before socket is bound".to_string())?;

        sock.set_broadcast(true)
            .map_err(|e| format!("Unable to set UDP socket to broadcast mode: {e}"))?;

        let mut sent_any = false;
        let mut errors = Vec::new();
        for target in targets {
            match sock.send_to(msg, (target, PERIPHERAL_RX_PORT)) {
                Ok(_) => sent_any = true,
                Err(e) => errors.push(format!("{target}: {e}")),
            }
        }

        // Always attempt to leave broadcast mode, whatever the send outcome.
        sock.set_broadcast(false)
            .map_err(|e| format!("Unable to set UDP socket to unicast mode: {e}"))?;

        if sent_any {
            Ok(())
        } else {
            Err(format!(
                "Failed to send UDP packet on any broadcast target: {}",
                errors.join("; ")
            ))
        }
    }

    /// Associates peripheral `id` with the address behind `token`.
    ///
    /// # Errors
    /// Returns a message if `token` was never handed out by `recv`, or if
    /// after the insertion the ID->address and address->ID maps differ in
    /// size, i.e. an ID or an address is now mapped more than once. In the
    /// latter case the insertion is not rolled back.
    fn update_map(&mut self, id: PeripheralId, token: SocketAddrToken) -> Result<(), String> {
        let addr = self
            .token_addrs
            .get(&token)
            .copied()
            .ok_or_else(|| format!("Unknown address token {token}"))?;

        self.addrs.insert(id, addr);
        self.pids.insert(addr, id);

        // The two maps are inverses of each other, so equal sizes mean every
        // ID and every address appears exactly once.
        if self.addrs.len() != self.pids.len() {
            return Err(format!(
                "Duplicate addresses or peripheral IDs detected.\nAddress map: {:?}\nPeripheral ID map: {:?}",
                &self.addrs, &self.pids
            ));
        }
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// A /24 netmask must turn only the last octet into host bits.
    #[test]
    fn broadcast_address_uses_interface_mask() {
        let computed = broadcast_from_addr_and_mask(
            Ipv4Addr::from([10, 42, 7, 9]),
            Ipv4Addr::from([255, 255, 255, 0]),
        );
        assert_eq!(computed, Ipv4Addr::from([10, 42, 7, 255]));
    }

    /// Explicit targets come back unique and in ascending order.
    #[test]
    fn explicit_broadcast_targets_are_deduplicated() {
        let high = Ipv4Addr::new(192, 168, 1, 255);
        let low = Ipv4Addr::new(10, 0, 0, 255);

        let socket = UdpSocket::with_broadcast_targets(vec![high, high, low]);

        assert_eq!(socket.broadcast_targets(), [low, high]);
    }

    /// Discovery always offers the limited broadcast address.
    #[test]
    fn possible_broadcast_targets_include_limited_broadcast() {
        let discovered = possible_broadcast_targets();
        assert!(discovered
            .iter()
            .any(|target| *target == Ipv4Addr::BROADCAST));
    }
}