use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::{fs, io};
use crate::error::{VmRuntimeError, VmRuntimeResult};
/// Default directory under which one sub-directory per VM (holding `vsock.uds`) is created.
pub const DEFAULT_SOCKET_DIR: &str = "/var/run/microvm/vsocks";
/// Lowest guest CID allocated by default (CIDs 0-2 are reserved by vsock; 2 is the host).
pub const DEFAULT_FIRST_CID: u32 = 3;
/// Highest guest CID allocated by default (`u32::MAX` is `VMADDR_CID_ANY`).
pub const DEFAULT_LAST_CID: u32 = 0xFFFF_FFFE;

/// Settings for the vsock manager: where sockets live and which CIDs may be allocated.
#[derive(Debug, Clone)]
pub struct VsockConfig {
    /// Directory that holds one sub-directory per VM, each containing `vsock.uds`.
    pub socket_dir: PathBuf,
    /// First CID (inclusive) of the allocation range.
    pub first_cid: u32,
    /// Last CID (inclusive) of the allocation range.
    pub last_cid: u32,
}

impl Default for VsockConfig {
    /// Uses [`DEFAULT_SOCKET_DIR`], [`DEFAULT_FIRST_CID`] and [`DEFAULT_LAST_CID`].
    fn default() -> Self {
        Self {
            socket_dir: PathBuf::from(DEFAULT_SOCKET_DIR),
            first_cid: DEFAULT_FIRST_CID,
            last_cid: DEFAULT_LAST_CID,
        }
    }
}

impl VsockConfig {
    /// Builds a configuration from environment variables, falling back to
    /// [`VsockConfig::default`] for every value that is unset or unusable.
    ///
    /// * `MICROVM_VSOCK_DIR` — socket directory. It is read as an `OsString`, so
    ///   non-UTF-8 paths are accepted. An empty value counts as unset, because an
    ///   empty path would make every socket path relative to the working directory.
    /// * `MICROVM_VSOCK_FIRST_CID` / `MICROVM_VSOCK_LAST_CID` — decimal `u32`
    ///   bounds of the CID range. Surrounding whitespace is ignored, and
    ///   unparsable values fall back to the default.
    ///
    /// The range itself (e.g. `first_cid <= last_cid`) is not validated here.
    pub fn from_env() -> Self {
        let defaults = Self::default();
        Self {
            socket_dir: Self::env_path("MICROVM_VSOCK_DIR").unwrap_or(defaults.socket_dir),
            first_cid: Self::env_u32("MICROVM_VSOCK_FIRST_CID").unwrap_or(defaults.first_cid),
            last_cid: Self::env_u32("MICROVM_VSOCK_LAST_CID").unwrap_or(defaults.last_cid),
        }
    }

    /// Reads `key` as a filesystem path; `None` when unset or empty.
    fn env_path(key: &str) -> Option<PathBuf> {
        std::env::var_os(key)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    }

    /// Reads `key` as a decimal `u32`; `None` when unset, non-UTF-8 or unparsable.
    fn env_u32(key: &str) -> Option<u32> {
        std::env::var(key).ok()?.trim().parse().ok()
    }
}
/// Vsock wiring handed back for one VM: its guest CID and host-side socket path.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VmVsock {
    /// Guest CID allocated for the VM.
    pub cid: u32,
    /// Host-side Unix socket path, `<socket_dir>/<sanitised vm id>/vsock.uds`.
    pub uds_path: PathBuf,
}
/// Hands out vsock guest CIDs and per-VM Unix-domain-socket paths.
///
/// Allocations are kept only in memory, as a `vm_id -> cid` map behind a `Mutex`.
/// A freshly created manager starts with no allocations.
#[derive(Debug)]
pub struct VsockManager {
    config: VsockConfig,
    allocations: Mutex<HashMap<String, u32>>,
}

impl VsockManager {
    /// Creates a manager with the given configuration and no allocations.
    pub fn new(config: VsockConfig) -> Self {
        Self {
            config,
            allocations: Mutex::new(HashMap::new()),
        }
    }

    /// Creates a manager configured by [`VsockConfig::from_env`].
    pub fn from_env() -> Self {
        Self::new(VsockConfig::from_env())
    }

    /// Returns the configuration this manager was created with.
    pub fn config(&self) -> &VsockConfig {
        &self.config
    }

    /// Returns the CID for `vm_id`, allocating one if needed, and ensures the
    /// parent directory of its socket path exists.
    ///
    /// Repeated calls for the same `vm_id` return the same CID. A *new* allocation
    /// is recorded only after the directory has been created, so a failed call
    /// leaves no CID reserved.
    ///
    /// # Errors
    /// - `VmRuntimeError::StatePoisoned` if the allocation mutex is poisoned.
    /// - `VmRuntimeError::Unsupported` if no CID can be allocated from the
    ///   configured range, or the socket directory cannot be created.
    pub fn attach(&self, vm_id: &str) -> VmRuntimeResult<VmVsock> {
        let mut allocations = self
            .allocations
            .lock()
            .map_err(|_| VmRuntimeError::StatePoisoned)?;
        let (cid, newly_allocated) = match allocations.get(vm_id) {
            Some(&existing) => (existing, false),
            None => {
                let used: std::collections::HashSet<u32> =
                    allocations.values().copied().collect();
                let cid =
                    find_free_cid(vm_id, self.config.first_cid, self.config.last_cid, &used)?;
                (cid, true)
            }
        };
        let uds_path = self.uds_path_for(vm_id);
        // Create the directory before recording a new CID. If this fails, the
        // caller sees an error and no CID is left reserved for a VM that never
        // attached. The lock is still held, so no other attach can claim `cid`
        // in between.
        ensure_parent_dir(&uds_path)?;
        if newly_allocated {
            allocations.insert(vm_id.to_owned(), cid);
        }
        Ok(VmVsock { cid, uds_path })
    }

    /// Releases the CID held by `vm_id` and removes its socket file.
    ///
    /// Unknown ids and an already-missing socket file are not errors. The socket's
    /// parent directory is left in place.
    ///
    /// # Errors
    /// - `VmRuntimeError::StatePoisoned` if the allocation mutex is poisoned.
    /// - `VmRuntimeError::Unsupported` if the socket file exists but cannot be
    ///   removed. The CID has already been released at that point.
    pub fn detach(&self, vm_id: &str) -> VmRuntimeResult<()> {
        {
            // Scope the guard so the lock is not held during filesystem I/O.
            let mut allocations = self
                .allocations
                .lock()
                .map_err(|_| VmRuntimeError::StatePoisoned)?;
            allocations.remove(vm_id);
        }
        let uds_path = self.uds_path_for(vm_id);
        match fs::remove_file(&uds_path) {
            Ok(()) => Ok(()),
            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(()),
            Err(err) => Err(VmRuntimeError::Unsupported(format!(
                "failed to remove vsock uds {}: {err}",
                uds_path.display()
            ))),
        }
    }

    /// Creates the parent directory of `uds_path` if it is missing.
    ///
    /// Useful when the directory may have been removed after `attach` created it.
    pub fn ensure_uds_parent(&self, uds_path: &Path) -> VmRuntimeResult<()> {
        ensure_parent_dir(uds_path)
    }

    /// Returns `<socket_dir>/<sanitised vm_id>/vsock.uds` without touching the filesystem.
    ///
    /// NOTE(review): sanitisation replaces every character outside `[A-Za-z0-9_-]`
    /// with `_`. Distinct ids such as `"a/b"` and `"a_b"` therefore resolve to the
    /// same path, even though they are given different CIDs.
    pub fn uds_path_for(&self, vm_id: &str) -> PathBuf {
        self.config
            .socket_dir
            .join(safe_vm_id(vm_id))
            .join("vsock.uds")
    }
}
/// Turns `vm_id` into a single path component.
///
/// Every character other than ASCII letters, digits, `-` and `_` becomes `_`.
/// This keeps path separators and `.`-based traversal out of the result
/// (`"../x"` becomes `"___x"`). An empty id yields an empty string.
///
/// NOTE(review): the mapping is not injective, e.g. `"a/b"` and `"a_b"` both
/// become `"a_b"`.
fn safe_vm_id(vm_id: &str) -> String {
    let mut sanitised = String::with_capacity(vm_id.len());
    for ch in vm_id.chars() {
        let allowed = ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_');
        sanitised.push(if allowed { ch } else { '_' });
    }
    sanitised
}
/// Creates every missing directory above `path`, like `mkdir -p` on its parent.
///
/// Succeeds without changes if the directory already exists.
///
/// # Errors
/// `VmRuntimeError::Unsupported` if `path` has no parent (e.g. `/`) or the
/// directories cannot be created.
fn ensure_parent_dir(path: &Path) -> VmRuntimeResult<()> {
    let Some(parent) = path.parent() else {
        return Err(VmRuntimeError::Unsupported(format!(
            "vsock uds path has no parent directory: {}",
            path.display()
        )));
    };
    match fs::create_dir_all(parent) {
        Ok(()) => Ok(()),
        Err(e) => Err(VmRuntimeError::Unsupported(format!(
            "failed to create vsock uds parent {}: {e}",
            parent.display()
        ))),
    }
}
/// Picks a CID in `first..=last` that is not already in `used`.
///
/// The search starts at a slot derived from a hash of `vm_id`. It then probes
/// upwards, wrapping from `last` back to `first`. The same `vm_id` with the same
/// `used` set therefore always gets the same CID.
///
/// NOTE(review): the hash comes from `DefaultHasher::new()`. That is deterministic
/// within one build, but std does not promise the algorithm stays the same across
/// Rust releases. Do not rely on CIDs being recomputed identically after a
/// toolchain upgrade.
///
/// # Errors
/// `VmRuntimeError::Unsupported` in three cases:
/// - the range is empty (`first > last`);
/// - the range contains a reserved CID;
/// - every CID in the range is already used.
fn find_free_cid(
    vm_id: &str,
    first: u32,
    last: u32,
    used: &std::collections::HashSet<u32>,
) -> VmRuntimeResult<u32> {
    // CIDs 0-2 are reserved (2 is the host) and u32::MAX is VMADDR_CID_ANY.
    // None of them may be assigned to a guest.
    const MIN_GUEST_CID: u32 = 3;

    if first > last {
        return Err(VmRuntimeError::Unsupported(format!(
            "vsock cid range is empty: first={first}, last={last}"
        )));
    }
    if first < MIN_GUEST_CID || last == u32::MAX {
        return Err(VmRuntimeError::Unsupported(format!(
            "vsock cid range includes reserved cids (0-2 and 0xFFFFFFFF): first={first}, last={last}"
        )));
    }
    let exhausted = || {
        VmRuntimeError::Unsupported(format!(
            "vsock cid range exhausted: first={first}, last={last}"
        ))
    };

    let span = u64::from(last - first) + 1;
    if u64::try_from(used.len()).unwrap_or(u64::MAX) >= span {
        return Err(exhausted());
    }

    let mut hasher = DefaultHasher::new();
    vm_id.hash(&mut hasher);
    // `hash % span` is below `span` (at most 2^32), so it fits in u32, and
    // `first + offset <= last` cannot overflow.
    let offset = u32::try_from(hasher.finish() % span).unwrap_or(0);
    let start = first + offset;

    // Linear probe: start..=last, then wrap around to first..start.
    (start..=last)
        .chain(first..start)
        .find(|cid| !used.contains(cid))
        .ok_or_else(exhausted)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use tempfile::TempDir;

    // Manager with the default CID range, rooted in a throwaway directory.
    fn manager_in(dir: &TempDir) -> VsockManager {
        manager_with_range(dir, DEFAULT_FIRST_CID, DEFAULT_LAST_CID)
    }

    // Manager with a custom CID range, rooted in a throwaway directory.
    fn manager_with_range(dir: &TempDir, first_cid: u32, last_cid: u32) -> VsockManager {
        VsockManager::new(VsockConfig {
            socket_dir: dir.path().to_path_buf(),
            first_cid,
            last_cid,
        })
    }

    #[test]
    fn defaults_are_sane() {
        let cfg = VsockConfig::default();
        assert_eq!(cfg.first_cid, 3);
        assert_eq!(cfg.last_cid, 0xFFFF_FFFE);
        assert_eq!(cfg.socket_dir, PathBuf::from(DEFAULT_SOCKET_DIR));
    }

    #[test]
    fn cid_is_deterministic_across_managers() {
        let tmp = TempDir::new().unwrap();
        let m1 = manager_in(&tmp);
        let m2 = manager_in(&tmp);
        let a = m1.attach("vm-deterministic").unwrap();
        let b = m2.attach("vm-deterministic").unwrap();
        assert_eq!(a.cid, b.cid);
        assert_eq!(a.uds_path, b.uds_path);
    }

    #[test]
    fn attach_is_idempotent() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_in(&tmp);
        let first = mgr.attach("vm-x").unwrap();
        let second = mgr.attach("vm-x").unwrap();
        let third = mgr.attach("vm-x").unwrap();
        assert_eq!(first, second);
        assert_eq!(second, third);
        let allocations = mgr.allocations.lock().unwrap();
        assert_eq!(allocations.len(), 1);
    }

    #[test]
    fn distinct_vm_ids_get_distinct_cids() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_in(&tmp);
        let a = mgr.attach("alpha").unwrap();
        let b = mgr.attach("beta").unwrap();
        let c = mgr.attach("gamma").unwrap();
        let set: HashSet<u32> = [a.cid, b.cid, c.cid].into_iter().collect();
        assert_eq!(set.len(), 3);
    }

    #[test]
    fn cid_collision_linear_probes() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_with_range(&tmp, 3, 4);
        let a = mgr.attach("vm-a").unwrap();
        let b = mgr.attach("vm-b").unwrap();
        assert_ne!(a.cid, b.cid);
        // Two slots, two VMs: the range must be filled exactly.
        assert_eq!(HashSet::from([a.cid, b.cid]), HashSet::from([3, 4]));
        let exhausted = mgr.attach("vm-c");
        assert!(
            matches!(exhausted, Err(VmRuntimeError::Unsupported(ref msg)) if msg.contains("exhausted"))
        );
    }

    #[test]
    fn detach_releases_cid_for_future_allocation() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_with_range(&tmp, 3, 4);
        let a = mgr.attach("vm-a").unwrap();
        let _b = mgr.attach("vm-b").unwrap();
        assert!(mgr.attach("vm-c").is_err());
        mgr.detach("vm-a").unwrap();
        let c = mgr.attach("vm-c").unwrap();
        // The only free slot is the one vm-a just released.
        assert_eq!(c.cid, a.cid);
    }

    #[test]
    fn detach_is_idempotent_when_unknown() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_in(&tmp);
        mgr.detach("never-attached").unwrap();
        mgr.detach("never-attached").unwrap();
    }

    #[test]
    fn detach_removes_socket_file_but_leaves_parent_dir() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_in(&tmp);
        let vm = mgr.attach("vm-cleanup").unwrap();
        std::fs::write(&vm.uds_path, b"").unwrap();
        assert!(vm.uds_path.exists());
        mgr.detach("vm-cleanup").unwrap();
        assert!(!vm.uds_path.exists());
        assert!(vm.uds_path.parent().unwrap().exists());
    }

    #[test]
    fn safe_vm_id_strips_unsafe_chars() {
        assert_eq!(safe_vm_id("a/b/c"), "a_b_c");
        assert_eq!(safe_vm_id("normal-id_42"), "normal-id_42");
        assert_eq!(safe_vm_id("../etc/passwd"), "___etc_passwd");
        assert_eq!(safe_vm_id("with space"), "with_space");
    }

    #[test]
    fn uds_path_uses_sanitised_vm_id() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_in(&tmp);
        let path = mgr.uds_path_for("a/b/c");
        assert_eq!(path, tmp.path().join("a_b_c").join("vsock.uds"));
    }

    #[test]
    fn ensure_uds_parent_creates_missing_dirs() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_in(&tmp);
        let deep = tmp.path().join("a").join("b").join("c").join("vsock.uds");
        assert!(!deep.parent().unwrap().exists());
        mgr.ensure_uds_parent(&deep).unwrap();
        assert!(deep.parent().unwrap().exists());
    }

    #[test]
    fn ensure_uds_parent_is_idempotent() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_in(&tmp);
        let p = tmp.path().join("a").join("vsock.uds");
        mgr.ensure_uds_parent(&p).unwrap();
        mgr.ensure_uds_parent(&p).unwrap();
        mgr.ensure_uds_parent(&p).unwrap();
        assert!(p.parent().unwrap().exists());
    }

    // The socket directory disappears after attach; ensure_uds_parent must be
    // able to recreate it. The test name refers to a Firecracker v1.6 restore
    // race. Presumably that is the case where the directory vanishes before a
    // restore; confirm against the caller.
    #[test]
    fn ensure_uds_parent_simulates_fc_v16_restore_race() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_in(&tmp);
        let vm = mgr.attach("seed").unwrap();
        let parent = vm.uds_path.parent().unwrap();
        assert!(parent.exists());
        std::fs::remove_dir_all(parent).unwrap();
        assert!(!parent.exists());
        mgr.ensure_uds_parent(&vm.uds_path).unwrap();
        assert!(parent.exists());
    }

    #[test]
    fn attach_creates_parent_dir() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_in(&tmp);
        let vm = mgr.attach("vm-fresh").unwrap();
        assert!(vm.uds_path.parent().unwrap().exists());
        assert!(!vm.uds_path.exists());
    }

    #[test]
    fn cid_range_validation() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_with_range(&tmp, 100, 50);
        let err = mgr.attach("vm").unwrap_err();
        assert!(matches!(err, VmRuntimeError::Unsupported(_)));
    }

    #[test]
    fn cid_lands_inside_configured_range() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_with_range(&tmp, 100, 200);
        for i in 0..50 {
            let vm = mgr.attach(&format!("vm-{i}")).unwrap();
            assert!(
                (100..=200).contains(&vm.cid),
                "cid {} outside [100, 200]",
                vm.cid
            );
        }
    }

    #[test]
    fn linear_probe_when_hash_collides_directly() {
        let tmp = TempDir::new().unwrap();
        let mgr = manager_with_range(&tmp, 10, 13);
        let mut seen = HashSet::new();
        for i in 0..4 {
            let vm = mgr.attach(&format!("probe-{i}")).unwrap();
            assert!(seen.insert(vm.cid), "duplicate CID {}", vm.cid);
        }
        assert_eq!(seen, HashSet::from([10, 11, 12, 13]));
    }
}