use std::collections::HashMap;
use std::io;
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd};
use std::sync::{Arc, Mutex, RwLock};
use std::time::Instant;
use crate::config::VmSocketEndpoint;
/// Shortest datagram the forwarding loop will process; anything smaller is
/// discarded. 14 bytes covers the destination MAC (bytes 0..6), the source MAC
/// (bytes 6..12) and the two trailing type bytes the tests place at 12..14.
const MIN_FRAME_SIZE: usize = 14;
/// Size of the receive buffer used by the forwarding loop, and therefore the
/// largest number of bytes read for a single frame.
// NOTE(review): 1518 looks like the classic untagged Ethernet maximum — confirm
// whether VLAN-tagged or jumbo frames need to be supported.
const MAX_FRAME_SIZE: usize = 1518;
/// A six-byte Ethernet hardware (MAC) address.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MacAddress([u8; 6]);

impl MacAddress {
    /// Builds an address from the first six bytes of `bytes`.
    ///
    /// Returns `None` when fewer than six bytes are supplied. Bytes past the
    /// sixth are ignored.
    pub fn from_bytes(bytes: &[u8]) -> Option<MacAddress> {
        match *bytes {
            [b0, b1, b2, b3, b4, b5, ..] => Some(MacAddress([b0, b1, b2, b3, b4, b5])),
            _ => None,
        }
    }

    /// Returns `true` when every octet is `0xff` (the broadcast address).
    pub fn is_broadcast(&self) -> bool {
        self.0.iter().all(|&octet| octet == 0xff)
    }

    /// Returns `true` when the least-significant bit of the first octet is set.
    ///
    /// The broadcast address also satisfies this test.
    pub fn is_multicast(&self) -> bool {
        let first_octet = self.0[0];
        (first_octet & 0x01) == 0x01
    }
}
/// Formats the address as six lowercase, zero-padded hex octets separated by
/// colons, e.g. `02:42:ac:11:00:02`.
impl std::fmt::Display for MacAddress {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let [o0, o1, o2, o3, o4, o5] = self.0;
        write!(
            f,
            "{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}",
            o0, o1, o2, o3, o4, o5
        )
    }
}
/// The switch-side end of one socketpair attached to a network.
#[derive(Debug)]
struct SwitchPort {
    /// Descriptor owned by the switch; it is closed when the port is dropped.
    fd: OwnedFd,
}

impl AsRawFd for SwitchPort {
    /// Exposes the underlying descriptor without giving up ownership.
    fn as_raw_fd(&self) -> RawFd {
        let SwitchPort { fd } = self;
        fd.as_raw_fd()
    }
}
/// Learned forwarding table for one network: source MAC address mapped to the
/// index of the port it was last seen on and the time of that sighting.
type MacTable = HashMap<MacAddress, (usize, Instant)>;
/// A software Ethernet switch that forwards frames between the ports of each
/// named network on a background thread.
pub struct NetworkSwitch {
// Ports grouped by network id; a port's position in the Vec is its port index.
networks: Arc<Mutex<HashMap<String, Vec<SwitchPort>>>>,
// One learned MAC table per network id.
mac_tables: Arc<RwLock<HashMap<String, MacTable>>>,
// Set while the forwarding thread should keep running; cleared to stop it.
running: Arc<std::sync::atomic::AtomicBool>,
// Join handle of the forwarding thread, if one has been started.
worker: Mutex<Option<std::thread::JoinHandle<()>>>,
}
impl NetworkSwitch {
    /// Creates a switch with no networks, no ports and no forwarding thread.
    pub fn new() -> Self {
        Self {
            networks: Arc::new(Mutex::new(HashMap::new())),
            mac_tables: Arc::new(RwLock::new(HashMap::new())),
            running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            worker: Mutex::new(None),
        }
    }

    /// Attaches a new port to `network_id`, creating the network on first use.
    ///
    /// A fresh socketpair is created; the switch keeps one end and the other
    /// is returned wrapped in a [`VmSocketEndpoint`]. `_label` is currently
    /// unused. Ports may be added while the switch is running: the forwarding
    /// loop rebuilds its poll set on every iteration.
    ///
    /// # Errors
    ///
    /// Returns an error if the socketpair cannot be created or configured, or
    /// if either internal lock is poisoned. On error nothing is registered.
    pub fn add_port(&self, network_id: &str, _label: &str) -> io::Result<VmSocketEndpoint> {
        let (switch_fd, vm_fd) = create_socketpair()?;

        // Take both locks before touching either map so that a poisoned lock
        // cannot leave a port registered without its MAC table (and without a
        // VM endpoint handed back to the caller).
        let mut networks = self
            .networks
            .lock()
            .map_err(|e| io::Error::other(format!("lock poisoned: {}", e)))?;
        let mut mac_tables = self
            .mac_tables
            .write()
            .map_err(|e| io::Error::other(format!("lock poisoned: {}", e)))?;

        networks
            .entry(network_id.to_string())
            .or_default()
            .push(SwitchPort { fd: switch_fd });
        mac_tables.entry(network_id.to_string()).or_default();

        Ok(VmSocketEndpoint::new(vm_fd))
    }

    /// Starts the background forwarding thread if it is not already running.
    ///
    /// Calling this on a running switch is a no-op that returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns an error if the worker lock is poisoned or the thread cannot be
    /// spawned. In both cases the switch is left stopped.
    pub fn start(&self) -> io::Result<()> {
        use std::sync::atomic::Ordering;

        // Holding the worker lock for the whole call serialises start() with
        // stop(), so the flag and the stored join handle always agree.
        let mut worker = self
            .worker
            .lock()
            .map_err(|e| io::Error::other(format!("lock poisoned: {}", e)))?;

        // Atomically claim the "running" state; losing the race means a
        // forwarding thread already exists.
        if self
            .running
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            return Ok(());
        }

        let networks = Arc::clone(&self.networks);
        let mac_tables = Arc::clone(&self.mac_tables);
        let running = Arc::clone(&self.running);
        let spawned = std::thread::Builder::new()
            .name("network-switch".to_string())
            .spawn(move || {
                forwarding_loop(&networks, &mac_tables, &running);
            });

        match spawned {
            Ok(handle) => {
                *worker = Some(handle);
                Ok(())
            }
            Err(e) => {
                // Roll back so a later start() can try again instead of
                // believing a worker exists.
                self.running.store(false, Ordering::SeqCst);
                Err(e)
            }
        }
    }

    /// Signals the forwarding thread to exit and waits for it to finish.
    ///
    /// Does nothing beyond clearing the flag when no thread was started. A
    /// panic in the forwarding thread or a poisoned worker lock is logged
    /// rather than propagated.
    pub fn stop(&self) {
        use std::sync::atomic::Ordering;

        let mut worker = match self.worker.lock() {
            Ok(guard) => guard,
            Err(e) => {
                // Still ask the thread to exit even though it cannot be joined.
                self.running.store(false, Ordering::SeqCst);
                tracing::error!("network switch worker lock poisoned during stop: {}", e);
                return;
            }
        };

        // Clear the flag only while holding the lock: a concurrent start()
        // cannot re-arm it between this store and the join below.
        self.running.store(false, Ordering::SeqCst);
        if let Some(handle) = worker.take() {
            if let Err(e) = handle.join() {
                tracing::error!("network switch thread panicked: {:?}", e);
            }
        }
    }
}
impl Default for NetworkSwitch {
fn default() -> Self {
Self::new()
}
}
/// Dropping the switch stops the forwarding thread, exactly as an explicit
/// call to `stop` would.
impl Drop for NetworkSwitch {
    fn drop(&mut self) {
        Self::stop(self);
    }
}
/// Creates a connected `AF_UNIX` datagram socketpair.
///
/// Returns `(switch_fd, vm_fd)`. Only `switch_fd` is switched to non-blocking
/// mode; `vm_fd` keeps the flags it was created with.
///
/// # Errors
///
/// Returns the OS error if `socketpair` or either `fcntl` call fails. Both
/// descriptors are closed on the `fcntl` failure paths because they are
/// already wrapped in `OwnedFd` by then.
fn create_socketpair() -> io::Result<(OwnedFd, OwnedFd)> {
    let mut raw_pair: [libc::c_int; 2] = [0; 2];

    // SAFETY: `raw_pair` is a writable array of two `c_int`, which is what
    // socketpair(2) requires for its output argument.
    let rc = unsafe {
        libc::socketpair(libc::AF_UNIX, libc::SOCK_DGRAM, 0, raw_pair.as_mut_ptr())
    };
    if rc != 0 {
        return Err(io::Error::last_os_error());
    }

    // SAFETY: socketpair succeeded, so both entries are freshly opened
    // descriptors that nothing else owns.
    let (switch_fd, vm_fd) = unsafe {
        (
            OwnedFd::from_raw_fd(raw_pair[0]),
            OwnedFd::from_raw_fd(raw_pair[1]),
        )
    };

    let switch_raw = switch_fd.as_raw_fd();

    // SAFETY: `switch_raw` is an open descriptor kept alive by `switch_fd`.
    let current_flags = unsafe { libc::fcntl(switch_raw, libc::F_GETFL) };
    if current_flags == -1 {
        return Err(io::Error::last_os_error());
    }

    // SAFETY: same descriptor as above; F_SETFL takes an integer flag set.
    let set_rc =
        unsafe { libc::fcntl(switch_raw, libc::F_SETFL, current_flags | libc::O_NONBLOCK) };
    if set_rc == -1 {
        return Err(io::Error::last_os_error());
    }

    Ok((switch_fd, vm_fd))
}
fn forwarding_loop(
networks: &Mutex<HashMap<String, Vec<SwitchPort>>>,
mac_tables: &RwLock<HashMap<String, MacTable>>,
running: &std::sync::atomic::AtomicBool,
) {
use std::sync::atomic::Ordering;
const MAC_AGE_INTERVAL_SECS: u64 = 30;
const MAC_ENTRY_LIFETIME_SECS: u64 = 120;
let mut buf = [0u8; MAX_FRAME_SIZE];
let mut last_aged = Instant::now();
while running.load(Ordering::Relaxed) {
if last_aged.elapsed().as_secs() >= MAC_AGE_INTERVAL_SECS {
if let Ok(mut tables) = mac_tables.write() {
for table in tables.values_mut() {
table.retain(|_mac, (_port, ts)| {
ts.elapsed().as_secs() < MAC_ENTRY_LIFETIME_SECS
});
}
}
last_aged = Instant::now();
}
let nets = match networks.lock() {
Ok(n) => n,
Err(_) => break, };
let mut pollfds: Vec<libc::pollfd> = Vec::new();
let mut fd_map: Vec<(String, usize)> = Vec::new();
let mut port_fds: HashMap<String, Vec<RawFd>> = HashMap::new();
for (net_id, ports) in nets.iter() {
let fds: Vec<RawFd> = ports.iter().map(|p| p.fd.as_raw_fd()).collect();
port_fds.insert(net_id.clone(), fds);
for (idx, port) in ports.iter().enumerate() {
pollfds.push(libc::pollfd {
fd: port.fd.as_raw_fd(),
events: libc::POLLIN,
revents: 0,
});
fd_map.push((net_id.clone(), idx));
}
}
drop(nets);
if pollfds.is_empty() {
std::thread::sleep(std::time::Duration::from_millis(50));
continue;
}
let ready = unsafe { libc::poll(pollfds.as_mut_ptr(), pollfds.len() as libc::nfds_t, 50) };
if ready <= 0 {
continue;
}
for (i, pfd) in pollfds.iter().enumerate() {
if pfd.revents & libc::POLLIN == 0 {
continue;
}
let (ref net_id, src_port_idx) = fd_map[i];
let n =
unsafe { libc::recv(pfd.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len(), 0) };
if n < MIN_FRAME_SIZE as isize {
continue;
}
let frame = &buf[..n as usize];
let dst_mac = match MacAddress::from_bytes(&frame[0..6]) {
Some(m) => m,
None => continue,
};
let src_mac = match MacAddress::from_bytes(&frame[6..12]) {
Some(m) => m,
None => continue,
};
if let Ok(mut tables) = mac_tables.write() {
if let Some(table) = tables.get_mut(net_id.as_str()) {
table.insert(src_mac, (src_port_idx, Instant::now()));
}
}
let fds = match port_fds.get(net_id.as_str()) {
Some(f) => f,
None => continue,
};
if dst_mac.is_broadcast() || dst_mac.is_multicast() {
for (idx, &fd) in fds.iter().enumerate() {
if idx == src_port_idx {
continue;
}
send_frame(fd, frame);
}
} else {
let dst_port = match mac_tables.read() {
Ok(tables) => tables
.get(net_id.as_str())
.and_then(|t| t.get(&dst_mac))
.map(|(port_idx, _ts)| *port_idx),
Err(_) => {
tracing::error!("MAC table read lock poisoned, flooding frame");
None
}
};
if let Some(dst_idx) = dst_port {
if dst_idx != src_port_idx && dst_idx < fds.len() {
send_frame(fds[dst_idx], frame);
}
} else {
for (idx, &fd) in fds.iter().enumerate() {
if idx == src_port_idx {
continue;
}
send_frame(fd, frame);
}
}
}
}
}
running.store(false, Ordering::SeqCst);
}
/// Best-effort, non-blocking transmit of `frame` on `fd`.
///
/// A failed send is logged at debug level and the frame is dropped; nothing is
/// reported to the caller.
fn send_frame(fd: RawFd, frame: &[u8]) {
    let payload = frame.as_ptr() as *const libc::c_void;
    let payload_len = frame.len();

    // SAFETY: `payload` points at `payload_len` readable bytes borrowed from
    // `frame`, which outlives the call.
    let rc = unsafe { libc::send(fd, payload, payload_len, libc::MSG_DONTWAIT) };
    if rc >= 0 {
        return;
    }

    let err = std::io::Error::last_os_error();
    tracing::debug!(fd = fd, error = %err, "dropping frame because send failed");
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Time the forwarding thread is given to move a frame between ports.
    const FORWARD_WAIT: std::time::Duration = std::time::Duration::from_millis(200);

    /// Builds a minimal 14-byte frame: broadcast destination, a locally
    /// administered unicast source, and type bytes `08 00`.
    fn broadcast_frame() -> [u8; 14] {
        let mut frame = [0u8; 14];
        frame[0..6].copy_from_slice(&[0xff, 0xff, 0xff, 0xff, 0xff, 0xff]);
        frame[6..12].copy_from_slice(&[0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
        frame[12..14].copy_from_slice(&[0x08, 0x00]);
        frame
    }

    /// Sends `frame` on `fd` with no flags and returns the raw `send` result.
    fn send_raw(fd: RawFd, frame: &[u8]) -> libc::ssize_t {
        // SAFETY: `frame` is a live slice of `frame.len()` readable bytes.
        unsafe { libc::send(fd, frame.as_ptr() as *const libc::c_void, frame.len(), 0) }
    }

    /// Attempts a non-blocking receive on `fd` and returns the raw `recv` result.
    fn recv_nonblocking(fd: RawFd, buf: &mut [u8]) -> libc::ssize_t {
        // SAFETY: `buf` is a live, writable slice of `buf.len()` bytes.
        unsafe {
            libc::recv(
                fd,
                buf.as_mut_ptr() as *mut libc::c_void,
                buf.len(),
                libc::MSG_DONTWAIT,
            )
        }
    }

    #[test]
    fn mac_from_bytes_valid() {
        let mac = MacAddress::from_bytes(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]).expect("valid MAC");
        assert_eq!(mac.0, [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
    }

    #[test]
    fn mac_from_bytes_too_short() {
        assert!(MacAddress::from_bytes(&[0xaa, 0xbb]).is_none());
    }

    #[test]
    fn mac_from_bytes_empty() {
        assert!(MacAddress::from_bytes(&[]).is_none());
    }

    #[test]
    fn mac_from_bytes_extra_bytes_ignored() {
        // Parsing must succeed here, so the failure message says what was expected.
        let mac = MacAddress::from_bytes(&[1, 2, 3, 4, 5, 6, 7, 8]).expect("valid MAC");
        assert_eq!(mac.0, [1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn mac_broadcast() {
        let mac = MacAddress([0xff, 0xff, 0xff, 0xff, 0xff, 0xff]);
        assert!(mac.is_broadcast());
        // The broadcast address also has the multicast bit set.
        assert!(mac.is_multicast());
    }

    #[test]
    fn mac_not_broadcast() {
        let mac = MacAddress([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
        assert!(!mac.is_broadcast());
    }

    #[test]
    fn mac_multicast() {
        let mac = MacAddress([0x01, 0x00, 0x5e, 0x00, 0x00, 0x01]);
        assert!(mac.is_multicast());
        assert!(!mac.is_broadcast());
    }

    #[test]
    fn mac_unicast() {
        let mac = MacAddress([0x02, 0x42, 0xac, 0x11, 0x00, 0x02]);
        assert!(!mac.is_multicast());
        assert!(!mac.is_broadcast());
    }

    #[test]
    fn mac_display_format() {
        let mac = MacAddress([0x02, 0x42, 0xac, 0x11, 0x00, 0x02]);
        assert_eq!(format!("{}", mac), "02:42:ac:11:00:02");
    }

    #[test]
    fn mac_display_zero() {
        let mac = MacAddress([0, 0, 0, 0, 0, 0]);
        assert_eq!(format!("{}", mac), "00:00:00:00:00:00");
    }

    #[test]
    fn switch_add_port_returns_fd() {
        let switch = NetworkSwitch::new();
        let vm_fd = switch.add_port("net0", "web").expect("add port");
        assert!(vm_fd.as_raw_fd() >= 0);
    }

    #[test]
    fn switch_add_multiple_ports_same_network() {
        let switch = NetworkSwitch::new();
        let fd1 = switch.add_port("net0", "web").expect("add web port");
        let fd2 = switch.add_port("net0", "db").expect("add db port");
        assert_ne!(fd1.as_raw_fd(), fd2.as_raw_fd());
    }

    #[test]
    fn switch_add_ports_different_networks() {
        let switch = NetworkSwitch::new();
        let fd1 = switch
            .add_port("frontend", "web")
            .expect("add frontend port");
        let fd2 = switch.add_port("backend", "db").expect("add backend port");
        assert_ne!(fd1.as_raw_fd(), fd2.as_raw_fd());
    }

    #[test]
    fn switch_frame_delivery_same_network() {
        let switch = NetworkSwitch::new();
        let fd1 = switch.add_port("net0", "sender").expect("add sender port");
        let fd2 = switch
            .add_port("net0", "receiver")
            .expect("add receiver port");
        switch.start().expect("start switch");

        let frame = broadcast_frame();
        let sent = send_raw(fd1.as_raw_fd(), &frame);
        assert_eq!(sent, 14);

        std::thread::sleep(FORWARD_WAIT);

        let mut buf = vec![0u8; 1518];
        let recvd = recv_nonblocking(fd2.as_raw_fd(), &mut buf);
        assert_eq!(
            recvd, 14,
            "broadcast frame should be forwarded to the other port"
        );
        assert_eq!(&buf[..14], &frame[..14]);
        switch.stop();
    }

    #[test]
    fn switch_no_cross_network_forwarding() {
        let switch = NetworkSwitch::new();
        let fd1 = switch.add_port("net-a", "sender").expect("add sender port");
        let fd2 = switch
            .add_port("net-b", "isolated")
            .expect("add isolated port");
        switch.start().expect("start switch");

        let frame = broadcast_frame();
        // Without this check a failed send would make the isolation assertion
        // below pass without anything having been tested.
        let sent = send_raw(fd1.as_raw_fd(), &frame);
        assert_eq!(sent, 14);

        std::thread::sleep(FORWARD_WAIT);

        let mut buf = vec![0u8; 1518];
        let recvd = recv_nonblocking(fd2.as_raw_fd(), &mut buf);
        assert!(recvd <= 0, "frame should NOT cross network boundaries");
        switch.stop();
    }
}