use std::collections::HashSet;
use std::net::Ipv4Addr;
use std::sync::Mutex;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};
use crate::error::{Result, VmmError};
/// Serde fallback for [`NetworkAllocation::prefix_len`] when a serialized allocation
/// omits the field.
///
/// NOTE(review): presumably this exists for records written before `prefix_len` was
/// added; 16 corresponds to a `/16` network — confirm against the deployed CIDR.
const fn default_prefix_len() -> u8 {
    const FALLBACK_PREFIX_LEN: u8 = 16;
    FALLBACK_PREFIX_LEN
}
/// Network configuration handed to a single guest by `NetworkManager::allocate`.
///
/// The struct is serde-serializable. When deserializing, a missing `prefix_len`
/// falls back to `default_prefix_len()` (16).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkAllocation {
    /// Name of the host TAP device for this guest, formatted as
    /// `vmtap<third octet>-<fourth octet>` of `ip_address`.
    pub tap_name: String,
    /// Address assigned to the guest. On Linux it is also set as the TAP's
    /// destination (peer) address via `SIOCSIFDSTADDR`.
    pub ip_address: Ipv4Addr,
    /// Prefix length of the manager's subnet; 16 when absent on deserialize.
    #[serde(default = "default_prefix_len")]
    pub prefix_len: u8,
    /// Gateway address given to the guest. On Linux it is also assigned as the
    /// TAP's local address via `SIOCSIFADDR`.
    pub gateway: Ipv4Addr,
    /// Guest MAC address as lowercase colon-separated hex, derived from the VM id.
    pub mac_address: String,
    /// DNS servers copied from the manager's configuration.
    pub dns_servers: Vec<String>,
}
impl NetworkAllocation {
    /// Returns the dotted-quad netmask corresponding to `prefix_len`.
    ///
    /// Prefix lengths above 32 are treated as 32 (`255.255.255.255`). A prefix of 0
    /// yields `0.0.0.0`.
    pub fn netmask(&self) -> Ipv4Addr {
        let prefix = u32::from(self.prefix_len.min(32));
        // Shifting a u32 by 32 is out of range. `checked_shl` reports that as `None`,
        // which is exactly the /0 case, whose mask has no bits set.
        let bits = u32::MAX.checked_shl(32 - prefix).unwrap_or(0);
        Ipv4Addr::from(bits)
    }
}
/// Hands out guest IPv4 addresses from one subnet and manages their TAP devices.
///
/// All operations take `&self`. The set of addresses in use is guarded by a `Mutex`.
pub struct NetworkManager {
    /// Prefix length of the pool, validated to `1..=30` by `new`.
    prefix_len: u8,
    /// Address parsed from the configured CIDR. Host bits are masked off when allocating.
    base: Ipv4Addr,
    /// Gateway address handed to every guest.
    gateway: Ipv4Addr,
    /// DNS servers copied into every allocation.
    dns: Vec<String>,
    /// Addresses currently handed out, stored as `u32::from(Ipv4Addr)` values.
    allocated: Mutex<HashSet<u32>>,
}
impl NetworkManager {
pub fn new(cidr: &str, gateway: &str, dns: Vec<String>) -> Result<Self> {
let (base, prefix_len) = parse_cidr(cidr)?;
if !(1..=30).contains(&prefix_len) {
return Err(VmmError::Network(format!(
"prefix length {prefix_len} out of range 1–30"
)));
}
let gateway = gateway
.parse::<Ipv4Addr>()
.map_err(|e| VmmError::Network(format!("invalid gateway: {e}")))?;
Ok(Self {
base,
prefix_len,
gateway,
dns,
allocated: Mutex::new(HashSet::new()),
})
}
pub fn allocate(&self, vm_id: &str) -> Result<NetworkAllocation> {
let ip = self.next_ip()?;
let tap_name = tap_name_from_ip(ip);
let mac = mac_from_vm_id(vm_id);
info!(vm_id, tap = %tap_name, ip = %ip, "allocating network");
#[cfg(target_os = "linux")]
if let Err(e) = self.create_tap(&tap_name, ip) {
self.allocated.lock().unwrap().remove(&u32::from(ip));
return Err(e);
}
Ok(NetworkAllocation {
tap_name,
ip_address: ip,
prefix_len: self.prefix_len,
gateway: self.gateway,
mac_address: mac,
dns_servers: self.dns.clone(),
})
}
pub fn release(&self, alloc: &NetworkAllocation) {
let ip_int = u32::from(alloc.ip_address);
self.allocated.lock().unwrap().remove(&ip_int);
debug!(tap = %alloc.tap_name, ip = %alloc.ip_address, "releasing network");
#[cfg(target_os = "linux")]
destroy_tap(&alloc.tap_name);
}
fn next_ip(&self) -> Result<Ipv4Addr> {
let host_bits = 32 - u32::from(self.prefix_len);
let mask = !((1u32 << host_bits) - 1);
let host_max = (1u32 << host_bits) - 2;
let network_base = u32::from(self.base) & mask;
let mut allocated = self.allocated.lock().unwrap();
for offset in 2..=host_max {
let candidate = network_base + offset;
if !allocated.contains(&candidate) {
allocated.insert(candidate);
return Ok(Ipv4Addr::from(candidate));
}
}
Err(VmmError::Network("IP pool exhausted".into()))
}
#[cfg(target_os = "linux")]
fn create_tap(&self, tap_name: &str, ip: Ipv4Addr) -> Result<()> {
use std::os::fd::FromRawFd;
use std::os::unix::io::AsRawFd;
destroy_tap(tap_name);
let name_bytes = tap_name.as_bytes();
if name_bytes.len() >= libc::IFNAMSIZ {
return Err(VmmError::Network(format!("TAP name too long: {tap_name}")));
}
let tun = std::fs::OpenOptions::new()
.read(true)
.write(true)
.open("/dev/net/tun")
.map_err(|e| VmmError::Network(format!("open /dev/net/tun: {e}")))?;
let mut ifr = new_ifreq(name_bytes);
ifr.ifr_ifru.ifru_flags = (libc::IFF_TAP | libc::IFF_NO_PI) as i16;
const TUNSETIFF: libc::c_ulong = 0x400454ca;
const TUNSETPERSIST: libc::c_ulong = 0x400454cb;
if unsafe { libc::ioctl(tun.as_raw_fd(), TUNSETIFF as _, &ifr) } < 0 {
return Err(VmmError::Network(format!(
"TUNSETIFF {tap_name}: {}",
std::io::Error::last_os_error()
)));
}
if unsafe { libc::ioctl(tun.as_raw_fd(), TUNSETPERSIST as _, 1i32) } < 0 {
return Err(VmmError::Network(format!(
"TUNSETPERSIST {tap_name}: {}",
std::io::Error::last_os_error()
)));
}
drop(tun);
let sock = unsafe { libc::socket(libc::AF_INET, libc::SOCK_DGRAM, 0) };
if sock < 0 {
destroy_tap(tap_name);
return Err(VmmError::Network(format!(
"socket: {}",
std::io::Error::last_os_error()
)));
}
let sock = unsafe { std::os::fd::OwnedFd::from_raw_fd(sock) };
if unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCGIFFLAGS as _, &ifr) } < 0 {
destroy_tap(tap_name);
return Err(VmmError::Network(format!(
"SIOCGIFFLAGS {tap_name}: {}",
std::io::Error::last_os_error()
)));
}
unsafe { ifr.ifr_ifru.ifru_flags |= libc::IFF_UP as i16 };
if unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFFLAGS as _, &ifr) } < 0 {
destroy_tap(tap_name);
return Err(VmmError::Network(format!(
"SIOCSIFFLAGS UP {tap_name}: {}",
std::io::Error::last_os_error()
)));
}
if let Err(e) = (|| -> Result<()> {
set_ifaddr(
&sock,
&ifr,
libc::SIOCSIFADDR,
self.gateway,
tap_name,
"SIOCSIFADDR",
)?;
set_ifaddr(
&sock,
&ifr,
libc::SIOCSIFDSTADDR,
ip,
tap_name,
"SIOCSIFDSTADDR",
)?;
set_ifaddr(
&sock,
&ifr,
libc::SIOCSIFNETMASK,
Ipv4Addr::BROADCAST, tap_name,
"SIOCSIFNETMASK",
)?;
Ok(())
})() {
destroy_tap(tap_name);
return Err(e);
}
Ok(())
}
}
/// Returns a zeroed `ifreq` whose `ifr_name` holds `name_bytes`.
///
/// At most `IFNAMSIZ - 1` bytes are copied, so `ifr_name` always stays NUL-terminated
/// and the copy can never run past the field. This keeps the function sound even if a
/// caller forgets the length check that `create_tap` and `destroy_tap` perform.
#[cfg(target_os = "linux")]
fn new_ifreq(name_bytes: &[u8]) -> libc::ifreq {
    debug_assert!(
        name_bytes.len() < libc::IFNAMSIZ,
        "interface name must be shorter than IFNAMSIZ"
    );
    // SAFETY: `ifreq` is a plain C struct/union; all-zero bytes is a valid value.
    let mut ifr: libc::ifreq = unsafe { std::mem::zeroed() };
    for (dst, &src) in ifr
        .ifr_name
        .iter_mut()
        .take(libc::IFNAMSIZ - 1)
        .zip(name_bytes)
    {
        *dst = src as libc::c_char;
    }
    ifr
}
/// Applies one address-setting interface `ioctl` to the device named in `ifr`.
///
/// `addr` is encoded as an `AF_INET` `sockaddr_in` and written into the address slot
/// of a private copy of `ifr`, so the caller's request is left untouched.
/// `request` is the ioctl number (e.g. `SIOCSIFADDR`). `tap_name` and `label` are used
/// only to build the error message.
///
/// # Errors
///
/// Returns [`VmmError::Network`] carrying the OS error if the `ioctl` fails.
#[cfg(target_os = "linux")]
fn set_ifaddr(
    sock: &std::os::fd::OwnedFd,
    ifr: &libc::ifreq,
    request: libc::c_ulong,
    addr: Ipv4Addr,
    tap_name: &str,
    label: &str,
) -> Result<()> {
    use std::os::unix::io::AsRawFd;

    // SAFETY: `sockaddr_in` is a plain C struct; all-zero bytes is a valid value.
    let mut sin: libc::sockaddr_in = unsafe { std::mem::zeroed() };
    sin.sin_family = libc::AF_INET as libc::sa_family_t;
    // `s_addr` holds the address in network byte order.
    sin.sin_addr.s_addr = u32::from(addr).to_be();

    let mut request_buf = *ifr;
    // SAFETY: `ifr_ifru` is a union that contains a `sockaddr`, so it is at least as
    // large as `sockaddr_in`. `write_unaligned` imposes no alignment requirement.
    unsafe {
        (&raw mut request_buf.ifr_ifru)
            .cast::<libc::sockaddr_in>()
            .write_unaligned(sin);
    }

    // SAFETY: `sock` is an open descriptor and `request_buf` outlives the call.
    let rc = unsafe { libc::ioctl(sock.as_raw_fd(), request as _, &request_buf) };
    if rc >= 0 {
        return Ok(());
    }
    Err(VmmError::Network(format!(
        "{label} {tap_name} {addr}: {}",
        std::io::Error::last_os_error()
    )))
}
/// Best-effort removal of the TAP interface named `tap_name`; never returns an error.
///
/// The removal runs in two steps:
/// 1. Attach to the name through `/dev/net/tun` and clear its persistent flag.
/// 2. If `/sys/class/net/<name>` still exists afterwards, run
///    `/usr/sbin/ip link delete`.
///
/// Failures of the `ip` fallback are logged at `warn` level. Both `create_tap` (to
/// clear stale interfaces) and `NetworkManager::release` call this.
#[cfg(target_os = "linux")]
fn destroy_tap(tap_name: &str) {
    use std::os::unix::io::AsRawFd;
    // _IOW('T', 202, int) and _IOW('T', 203, int) in the asm-generic ioctl encoding.
    const TUNSETIFF: libc::c_ulong = 0x400454ca;
    const TUNSETPERSIST: libc::c_ulong = 0x400454cb;
    let name_bytes = tap_name.as_bytes();
    // `create_tap` rejects such names, so no interface by this name can be ours.
    if name_bytes.len() >= libc::IFNAMSIZ {
        return;
    }
    if let Ok(tun) = std::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .open("/dev/net/tun")
    {
        let mut ifr = new_ifreq(name_bytes);
        ifr.ifr_ifru.ifru_flags = (libc::IFF_TAP | libc::IFF_NO_PI) as i16;
        // SAFETY: `tun` is open and `ifr` is an initialised `ifreq`. TUNSETIFF writes
        // the interface name back into the request, so it needs a mutable pointer.
        if unsafe { libc::ioctl(tun.as_raw_fd(), TUNSETIFF as _, &raw mut ifr) } >= 0 {
            // SAFETY: TUNSETPERSIST takes its argument by value. A failure here is
            // tolerated because the `ip link delete` fallback below still runs.
            let _ = unsafe { libc::ioctl(tun.as_raw_fd(), TUNSETPERSIST as _, 0i32) };
        }
        drop(tun);
    }
    if std::path::Path::new(&format!("/sys/class/net/{tap_name}")).exists() {
        match std::process::Command::new("/usr/sbin/ip")
            .args(["link", "delete", tap_name])
            .output()
        {
            Ok(o) if !o.status.success() => {
                tracing::warn!(
                    tap = tap_name,
                    stderr = %String::from_utf8_lossy(&o.stderr),
                    "ip link delete failed"
                );
            }
            Err(e) => tracing::warn!(tap = tap_name, error = %e, "ip link delete failed"),
            _ => {}
        }
    }
}
/// Splits `"a.b.c.d/len"` into its address and prefix length.
///
/// Only syntax is checked here. The prefix may be any `u8` value (range checks are the
/// caller's job), and the address may have host bits set.
///
/// # Errors
///
/// Returns [`VmmError::Network`] if the string does not contain exactly one `/`, or if
/// either half fails to parse.
fn parse_cidr(cidr: &str) -> Result<(Ipv4Addr, u8)> {
    let malformed = || VmmError::Network(format!("invalid CIDR: {cidr}"));
    let (addr_text, prefix_text) = cidr.split_once('/').ok_or_else(malformed)?;
    if prefix_text.contains('/') {
        return Err(malformed());
    }
    let addr = addr_text
        .parse::<Ipv4Addr>()
        .map_err(|e| VmmError::Network(format!("invalid CIDR address: {e}")))?;
    let prefix = prefix_text
        .parse::<u8>()
        .map_err(|e| VmmError::Network(format!("invalid prefix length: {e}")))?;
    Ok((addr, prefix))
}
/// Builds the TAP device name for a guest address from its last two octets.
///
/// The longest possible result, `vmtap255-255`, is 12 bytes, well under the `IFNAMSIZ`
/// limit `create_tap` enforces. Addresses that differ only in their first two octets
/// map to the same name.
fn tap_name_from_ip(ip: Ipv4Addr) -> String {
    let [_, _, third, fourth] = ip.octets();
    format!("vmtap{third}-{fourth}")
}
/// Derives a deterministic, locally administered unicast MAC address from `vm_id`.
///
/// The bytes come from a 64-bit FNV-1a hash of the id's UTF-8 bytes. FNV-1a is fixed
/// by its specification. `std`'s `DefaultHasher` is not: its algorithm is documented
/// as unspecified and subject to change between Rust releases, which could silently
/// change every guest's MAC after a toolchain upgrade.
///
/// The low 48 bits of the hash become the six octets, most significant first. The
/// first octet has the multicast bit (0x01) cleared and the locally-administered bit
/// (0x02) set. The result is lowercase hex, formatted as `aa:bb:cc:dd:ee:ff`.
fn mac_from_vm_id(vm_id: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    let hash = vm_id.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    let [_, _, b0, b1, b2, b3, b4, b5] = hash.to_be_bytes();
    let b0 = (b0 & 0xfe) | 0x02;
    format!("{b0:02x}:{b1:02x}:{b2:02x}:{b3:02x}:{b4:02x}:{b5:02x}")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises the tests that create real TAP devices.
    ///
    /// TAP names encode only the last two octets of the guest address. The
    /// `172.20.0.0/16`, `10.0.0.0/30` and `10.0.0.0/29` managers below therefore all
    /// start by creating `vmtap0-2`. `create_tap` destroys any existing interface with
    /// that name first, so running these tests concurrently would tear down each
    /// other's TAPs.
    static TAP_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`TAP_LOCK`], ignoring poisoning left behind by a failed test.
    fn tap_guard() -> std::sync::MutexGuard<'static, ()> {
        TAP_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns true when the effective UID (third field of the `Uid:` line in
    /// `/proc/self/status`) is 0.
    #[cfg(target_os = "linux")]
    fn is_root() -> bool {
        std::fs::read_to_string("/proc/self/status")
            .map(|s| {
                s.lines()
                    .find(|l| l.starts_with("Uid:"))
                    .and_then(|l| l.split_whitespace().nth(2))
                    .map(|uid| uid == "0")
                    .unwrap_or(false)
            })
            .unwrap_or(false)
    }

    /// Builds an allocation whose only meaningful field is `prefix_len`.
    fn alloc_with_prefix(prefix_len: u8) -> NetworkAllocation {
        NetworkAllocation {
            tap_name: String::new(),
            ip_address: Ipv4Addr::UNSPECIFIED,
            prefix_len,
            gateway: Ipv4Addr::UNSPECIFIED,
            mac_address: String::new(),
            dns_servers: vec![],
        }
    }

    #[test]
    fn test_allocate_sequential_ips() {
        #[cfg(target_os = "linux")]
        if !is_root() {
            eprintln!("SKIP test_allocate_sequential_ips — requires root (TAP creation)");
            return;
        }
        let _guard = tap_guard();
        let mgr = NetworkManager::new("172.20.0.0/16", "172.20.0.1", vec![]).unwrap();
        let a1 = mgr.allocate("vm-1").unwrap();
        let a2 = mgr.allocate("vm-2").unwrap();
        assert_ne!(a1.ip_address, a2.ip_address);
        // The TAPs are persistent; remove them so they do not outlive the test.
        mgr.release(&a1);
        mgr.release(&a2);
    }

    #[test]
    fn test_release_returns_ip_to_pool() {
        #[cfg(target_os = "linux")]
        if !is_root() {
            eprintln!("SKIP test_release_returns_ip_to_pool — requires root (TAP creation)");
            return;
        }
        let _guard = tap_guard();
        let mgr = NetworkManager::new("172.20.0.0/16", "172.20.0.1", vec![]).unwrap();
        let a1 = mgr.allocate("vm-1").unwrap();
        let first_ip = a1.ip_address;
        mgr.release(&a1);
        let a2 = mgr.allocate("vm-1").unwrap();
        assert_eq!(a2.ip_address, first_ip);
        mgr.release(&a2);
    }

    #[test]
    fn test_mac_deterministic() {
        assert_eq!(mac_from_vm_id("abc"), mac_from_vm_id("abc"));
        assert_ne!(mac_from_vm_id("abc"), mac_from_vm_id("xyz"));
    }

    #[test]
    fn test_invalid_prefix_len_rejected() {
        assert!(NetworkManager::new("10.0.0.0/0", "10.0.0.1", vec![]).is_err());
        assert!(NetworkManager::new("10.0.0.0/31", "10.0.0.1", vec![]).is_err());
        assert!(NetworkManager::new("10.0.0.0/32", "10.0.0.1", vec![]).is_err());
        assert!(NetworkManager::new("10.0.0.0/24", "10.0.0.1", vec![]).is_ok());
    }

    #[test]
    fn test_next_ip_respects_subnet_boundary() {
        #[cfg(target_os = "linux")]
        if !is_root() {
            eprintln!("SKIP test_next_ip_respects_subnet_boundary — requires root (TAP creation)");
            return;
        }
        let _guard = tap_guard();
        let mgr = NetworkManager::new("10.0.0.0/30", "10.0.0.1", vec![]).unwrap();
        let a = mgr.allocate("vm-1").unwrap();
        assert_eq!(a.ip_address, "10.0.0.2".parse::<Ipv4Addr>().unwrap());
        assert!(mgr.allocate("vm-2").is_err());
        mgr.release(&a);
    }

    #[test]
    fn test_pool_exhaustion_on_slash29() {
        #[cfg(target_os = "linux")]
        if !is_root() {
            eprintln!("SKIP test_pool_exhaustion_on_slash29 — requires root (TAP creation)");
            return;
        }
        let _guard = tap_guard();
        let mgr = NetworkManager::new("10.0.0.0/29", "10.0.0.1", vec![]).unwrap();
        let allocs: Vec<NetworkAllocation> = (0..5)
            .map(|i| mgr.allocate(&format!("vm-{i}")).unwrap())
            .collect();
        assert!(mgr.allocate("vm-overflow").is_err());
        for alloc in &allocs {
            mgr.release(alloc);
        }
    }

    #[test]
    fn test_mac_unicast_and_locally_administered_bits() {
        let mac = mac_from_vm_id("test-vm");
        let first_byte = u8::from_str_radix(&mac[..2], 16).unwrap();
        assert_eq!(
            first_byte & 0x02,
            0x02,
            "locally administered bit must be set"
        );
        assert_eq!(first_byte & 0x01, 0x00, "multicast bit must be clear");
    }

    #[test]
    fn test_netmask_slash0() {
        assert_eq!(alloc_with_prefix(0).netmask(), Ipv4Addr::UNSPECIFIED);
    }

    #[test]
    fn test_netmask_slash8() {
        assert_eq!(alloc_with_prefix(8).netmask(), Ipv4Addr::new(255, 0, 0, 0));
    }

    #[test]
    fn test_netmask_slash16() {
        assert_eq!(alloc_with_prefix(16).netmask(), Ipv4Addr::new(255, 255, 0, 0));
    }

    #[test]
    fn test_netmask_slash24() {
        assert_eq!(
            alloc_with_prefix(24).netmask(),
            Ipv4Addr::new(255, 255, 255, 0)
        );
    }

    #[test]
    fn test_netmask_slash30() {
        assert_eq!(
            alloc_with_prefix(30).netmask(),
            Ipv4Addr::new(255, 255, 255, 252)
        );
    }

    #[test]
    fn test_netmask_slash32() {
        assert_eq!(alloc_with_prefix(32).netmask(), Ipv4Addr::BROADCAST);
    }

    #[test]
    fn test_netmask_out_of_range_clamps_to_32() {
        assert_eq!(alloc_with_prefix(33).netmask(), Ipv4Addr::BROADCAST);
    }

    #[test]
    fn test_tap_name_encodes_last_two_octets() {
        let ip: Ipv4Addr = "172.20.3.17".parse().unwrap();
        assert_eq!(tap_name_from_ip(ip), "vmtap3-17");
        let ip2: Ipv4Addr = "10.0.255.1".parse().unwrap();
        assert_eq!(tap_name_from_ip(ip2), "vmtap255-1");
    }
}