use std::collections::HashMap;
use std::io::Write;
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use windows::Win32::System::Hypervisor::*;
use windows::Win32::System::Memory::{
VirtualAlloc, VirtualFree, MEM_COMMIT, MEM_RELEASE, MEM_RESERVE, PAGE_READWRITE,
};
use super::boot::{
self, BOOT_PARAMS_ADDR, GDT_ADDR, GDT_CODE_SELECTOR, GDT_DATA_SELECTOR, PML4_ADDR,
};
use crate::config::{VmConfig, VmHandle, VmState};
use crate::driver::{VmDriver, VmError};
/// I/O port of the COM1 data register; guest writes to it are captured as serial output.
const COM1_DATA: u16 = 0x3F8;
/// I/O port of the COM1 line status register; guest reads from it get a fixed status value.
const COM1_LSR: u16 = 0x3FD;
/// Host memory region that backs the guest's physical address space.
///
/// Obtained from `VirtualAlloc` in `allocate` and released with `VirtualFree` on drop.
struct GuestMemory {
/// Base address of the allocation.
ptr: *mut u8,
/// Length of the allocation in bytes.
size: usize,
}
// NOTE(review): these impls assert that the raw pointer may be moved to and shared between
// threads; presumably sound because the struct exclusively owns the allocation — confirm.
unsafe impl Send for GuestMemory {}
unsafe impl Sync for GuestMemory {}
impl GuestMemory {
    /// Reserves and commits `size` bytes of read/write host memory.
    ///
    /// # Errors
    ///
    /// Returns [`VmError::Hypervisor`] when `VirtualAlloc` returns a null pointer.
    fn allocate(size: usize) -> Result<Self, VmError> {
        let base = unsafe { VirtualAlloc(None, size, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE) };
        if base.is_null() {
            let size_mb = size / (1024 * 1024);
            return Err(VmError::Hypervisor(format!(
                "VirtualAlloc failed for {} MB",
                size_mb
            )));
        }
        let ptr = base.cast::<u8>();
        Ok(Self { ptr, size })
    }

    /// Views the whole allocation as a mutable byte slice.
    ///
    /// # Safety
    ///
    /// The caller must ensure nothing else reads or writes the region while the
    /// returned slice is alive.
    unsafe fn as_mut_slice(&mut self) -> &mut [u8] {
        let (base, len) = (self.ptr, self.size);
        std::slice::from_raw_parts_mut(base, len)
    }

    /// Returns the base address of the allocation as a read-only pointer.
    fn as_ptr(&self) -> *const u8 {
        self.ptr.cast_const()
    }
}
impl Drop for GuestMemory {
    /// Releases the allocation with `VirtualFree`; a null pointer is left alone.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        // Best-effort: a failed release cannot be reported from `drop`, so the result is ignored.
        let _ = unsafe { VirtualFree(self.ptr as *mut _, 0, MEM_RELEASE) };
    }
}
/// Bookkeeping for one booted VM, stored in `WhpDriver::vms` under the VM name.
struct WhpVm {
/// WHP partition handle returned by `WHvCreatePartition`.
partition: WHV_PARTITION_HANDLE,
/// Host allocation mapped into the guest at GPA 0; held here so it is freed with the VM.
_memory: GuestMemory,
/// Current VM state, shared with the vCPU thread, which also updates it.
state: Arc<RwLock<VmState>>,
/// State saved by `pause` and restored by `resume`; `None` while not paused.
resume_state: Option<VmState>,
/// Set to `true` to ask the vCPU thread to leave its run loop.
stop_flag: Arc<AtomicBool>,
/// Join handle of the vCPU thread; `None` once the thread has been joined.
vcpu_thread: Option<std::thread::JoinHandle<()>>,
/// Path of the file that receives the guest's serial output.
serial_log: std::path::PathBuf,
}
/// Wrapper that lets a partition handle be moved into the vCPU thread's closure.
#[derive(Clone, Copy)]
struct SendablePartition(WHV_PARTITION_HANDLE);
// NOTE(review): asserts the handle may be used from another thread; presumably fine because it
// is an opaque handle value — confirm against the WHP documentation.
unsafe impl Send for SendablePartition {}
/// VM driver backed by the Windows Hypervisor Platform (WHP).
pub struct WhpDriver {
/// Registry of VMs created by this driver, keyed by VM name.
vms: Mutex<HashMap<String, WhpVm>>,
}
impl Default for WhpDriver {
/// Equivalent to [`WhpDriver::new`].
fn default() -> Self {
Self::new()
}
}
impl WhpDriver {
    /// Creates a driver with an empty VM registry.
    pub fn new() -> Self {
        let vms = Mutex::new(HashMap::new());
        Self { vms }
    }

    /// Reports whether the hypervisor is present, as answered by `WHvGetCapability`.
    ///
    /// Returns `false` when the capability query itself fails.
    pub fn is_available() -> bool {
        let mut capability: WHV_CAPABILITY = unsafe { std::mem::zeroed() };
        let buffer_len = std::mem::size_of::<WHV_CAPABILITY>() as u32;
        let buffer = &mut capability as *mut _ as *mut std::ffi::c_void;
        let query = unsafe {
            WHvGetCapability(WHvCapabilityCodeHypervisorPresent, buffer, buffer_len, None)
        };
        if query.is_err() {
            return false;
        }
        // Only read the union member once the query has succeeded and filled it in.
        unsafe { capability.HypervisorPresent.as_bool() }
    }
}
impl VmDriver for WhpDriver {
    /// Creates a WHP partition, loads the kernel into guest memory and starts vCPU 0 on a
    /// dedicated thread.
    ///
    /// On success the VM is stored in the driver's registry under `config.name` and a handle
    /// in the `Running` state is returned.
    ///
    /// # Errors
    ///
    /// Returns [`VmError::Hypervisor`] when WHP is unavailable, the guest memory cannot be
    /// allocated, or the registry lock is poisoned, and [`VmError::BootFailed`] when any
    /// partition, kernel-loading, register or thread setup step fails. Resources created
    /// before the failing step are released.
    fn boot(&self, config: &VmConfig) -> Result<VmHandle, VmError> {
        let name = &config.name;
        let memory_bytes = config.memory_mb * 1024 * 1024;
        tracing::info!(
            vm = %name,
            cpus = config.cpus,
            memory_mb = config.memory_mb,
            "booting VM via WHP"
        );
        if !Self::is_available() {
            return Err(VmError::Hypervisor(
                "Windows Hypervisor Platform is not available. \
                 Enable Hyper-V in Windows Features."
                    .into(),
            ));
        }

        let partition = unsafe { WHvCreatePartition() }.map_err(|e| VmError::BootFailed {
            name: name.clone(),
            detail: format!("WHvCreatePartition failed: {e}"),
        })?;

        // From here on every failure path must delete the partition before returning.
        let mut prop: WHV_PARTITION_PROPERTY = unsafe { std::mem::zeroed() };
        prop.ProcessorCount = config.cpus as u32;
        unsafe {
            WHvSetPartitionProperty(
                partition,
                WHvPartitionPropertyCodeProcessorCount,
                &prop as *const _ as *const std::ffi::c_void,
                std::mem::size_of::<WHV_PARTITION_PROPERTY>() as u32,
            )
        }
        .map_err(|e| {
            unsafe {
                let _ = WHvDeletePartition(partition);
            }
            VmError::BootFailed {
                name: name.clone(),
                detail: format!("WHvSetPartitionProperty(ProcessorCount) failed: {e}"),
            }
        })?;

        unsafe { WHvSetupPartition(partition) }.map_err(|e| {
            unsafe {
                let _ = WHvDeletePartition(partition);
            }
            VmError::BootFailed {
                name: name.clone(),
                detail: format!("WHvSetupPartition failed: {e}"),
            }
        })?;

        // `memory` frees itself on drop, so later error paths only need to delete the partition.
        let mut memory = GuestMemory::allocate(memory_bytes).inspect_err(|_| unsafe {
            let _ = WHvDeletePartition(partition);
        })?;

        // Map the whole allocation at guest physical address 0 with read/write/execute access.
        unsafe {
            WHvMapGpaRange(
                partition,
                memory.as_ptr() as *const std::ffi::c_void,
                0,
                memory_bytes as u64,
                WHvMapGpaRangeFlagRead | WHvMapGpaRangeFlagWrite | WHvMapGpaRangeFlagExecute,
            )
        }
        .map_err(|e| {
            unsafe {
                let _ = WHvDeletePartition(partition);
            }
            VmError::BootFailed {
                name: name.clone(),
                detail: format!("WHvMapGpaRange failed: {e}"),
            }
        })?;

        let mem = unsafe { memory.as_mut_slice() };
        boot::setup_page_tables(mem, config.memory_mb);
        boot::setup_gdt(mem);

        let default_cmdline = if config.root_disk.is_some() {
            "console=ttyS0 root=/dev/vda1 rw"
        } else {
            "console=ttyS0"
        };
        let cmdline = config.cmdline.as_deref().unwrap_or(default_cmdline);
        let entry_point = boot::load_kernel(
            mem,
            &config.kernel,
            config.initramfs.as_deref(),
            cmdline,
            config.memory_mb,
        )
        .map_err(|mut e| {
            // The loader does not know the VM name; fill it in for the caller.
            if let VmError::BootFailed { ref mut name, .. } = e {
                *name = config.name.clone();
            }
            unsafe {
                let _ = WHvDeletePartition(partition);
            }
            e
        })?;

        unsafe { WHvCreateVirtualProcessor(partition, 0, 0) }.map_err(|e| {
            unsafe {
                let _ = WHvDeletePartition(partition);
            }
            VmError::BootFailed {
                name: name.clone(),
                detail: format!("WHvCreateVirtualProcessor failed: {e}"),
            }
        })?;

        setup_initial_registers(partition, entry_point).map_err(|mut e| {
            // The register helpers report an empty VM name; fill it in the same way as for
            // kernel-loading errors so callers always see which VM failed.
            if let VmError::BootFailed { ref mut name, .. } = e {
                *name = config.name.clone();
            }
            unsafe {
                let _ = WHvDeleteVirtualProcessor(partition, 0);
                let _ = WHvDeletePartition(partition);
            }
            e
        })?;

        let state = Arc::new(RwLock::new(VmState::Running));
        let stop_flag = Arc::new(AtomicBool::new(false));
        let serial_log = config.serial_log.clone();

        let sendable = SendablePartition(partition);
        let state_clone = Arc::clone(&state);
        let stop_clone = Arc::clone(&stop_flag);
        let log_path = serial_log.clone();
        let vm_name = name.clone();
        let vcpu_thread = std::thread::Builder::new()
            .name(format!("vcpu-{}", name))
            .spawn(move || {
                vcpu_loop(sendable, state_clone, stop_clone, &log_path, &vm_name);
            })
            .map_err(|e| {
                unsafe {
                    let _ = WHvDeleteVirtualProcessor(partition, 0);
                    let _ = WHvDeletePartition(partition);
                }
                VmError::BootFailed {
                    name: name.clone(),
                    detail: format!("failed to spawn vCPU thread: {e}"),
                }
            })?;

        // From this point the `WhpVm` owns the partition; dropping it tears everything down.
        let vm = WhpVm {
            partition,
            _memory: memory,
            state: Arc::clone(&state),
            resume_state: None,
            stop_flag: Arc::clone(&stop_flag),
            vcpu_thread: Some(vcpu_thread),
            serial_log: serial_log.clone(),
        };
        {
            let mut vms = self
                .vms
                .lock()
                .map_err(|e| VmError::Hypervisor(format!("VM lock poisoned: {e}")))?;
            vms.insert(name.clone(), vm);
        }

        Ok(VmHandle {
            name: name.clone(),
            namespace: config.namespace.clone(),
            state: VmState::Running,
            process: None,
            serial_log,
            machine_id: None,
        })
    }

    /// Stops the vCPU thread, marks the VM as `Stopped` and releases its resources.
    ///
    /// # Errors
    ///
    /// Returns [`VmError::NotFound`] when no VM with the handle's name is registered and
    /// [`VmError::Hypervisor`] when the registry lock is poisoned.
    fn stop(&self, handle: &VmHandle) -> Result<(), VmError> {
        // Take the VM out of the registry first. This releases the registry lock before the
        // blocking join below, and makes this function the sole owner of the partition.
        let removed = {
            let mut vms = self
                .vms
                .lock()
                .map_err(|e| VmError::Hypervisor(format!("VM lock poisoned: {e}")))?;
            vms.remove(&handle.name)
        };
        let mut vm = removed.ok_or_else(|| VmError::NotFound {
            name: handle.name.clone(),
        })?;

        vm.stop_flag.store(true, Ordering::Release);
        let _ = unsafe { WHvCancelRunVirtualProcessor(vm.partition, 0, 0) };
        if let Some(thread) = vm.vcpu_thread.take() {
            let _ = thread.join();
        }
        if let Ok(mut state) = vm.state.write() {
            *state = VmState::Stopped;
        }

        // `WhpVm::drop` deletes the vCPU and the partition. Deleting them here as well would
        // release the same partition handle twice.
        drop(vm);

        tracing::info!(vm = %handle.name, "VM stopped");
        Ok(())
    }

    /// Forcefully terminates the VM; currently identical to [`stop`](Self::stop).
    fn kill(&self, handle: &VmHandle) -> Result<(), VmError> {
        self.stop(handle)
    }

    /// Returns the current state of the VM, or `Stopped` when the VM is not registered.
    ///
    /// # Errors
    ///
    /// Returns [`VmError::Hypervisor`] when a lock is poisoned.
    fn state(&self, handle: &VmHandle) -> Result<VmState, VmError> {
        let vms = self
            .vms
            .lock()
            .map_err(|e| VmError::Hypervisor(format!("VM lock poisoned: {e}")))?;
        match vms.get(&handle.name) {
            Some(vm) => {
                let state = vm
                    .state
                    .read()
                    .map_err(|e| VmError::Hypervisor(format!("state lock poisoned: {e}")))?;
                Ok(state.clone())
            }
            None => Ok(VmState::Stopped),
        }
    }

    /// Stops the vCPU thread while keeping the partition and guest memory intact.
    ///
    /// The state the VM was in is remembered so that `resume` can restore it.
    ///
    /// # Errors
    ///
    /// Returns [`VmError::NotFound`] for an unknown VM and [`VmError::Hypervisor`] when the
    /// VM is not running, exits on its own while being paused, or a lock is poisoned.
    fn pause(&self, handle: &VmHandle) -> Result<(), VmError> {
        let mut vms = self
            .vms
            .lock()
            .map_err(|e| VmError::Hypervisor(format!("VM lock poisoned: {e}")))?;
        let vm = vms.get_mut(&handle.name).ok_or_else(|| VmError::NotFound {
            name: handle.name.clone(),
        })?;

        let running = vm
            .state
            .read()
            .map_err(|e| VmError::Hypervisor(format!("state lock poisoned: {e}")))?
            .is_running();
        if !running {
            return Err(VmError::Hypervisor("can only pause a running VM".into()));
        }

        vm.stop_flag.store(true, Ordering::Release);
        let _ = unsafe { WHvCancelRunVirtualProcessor(vm.partition, 0, 0) };
        if let Some(thread) = vm.vcpu_thread.take() {
            let _ = thread.join();
        }

        // Snapshot the state only now that the vCPU thread has exited. The thread may have
        // changed it (for example to ready, failed or stopped) after the check above, and a
        // snapshot taken earlier would silently discard that transition.
        let mut state = vm
            .state
            .write()
            .map_err(|e| VmError::Hypervisor(format!("state lock poisoned: {e}")))?;
        if !state.is_running() {
            return Err(VmError::Hypervisor(format!(
                "VM '{}' exited before it could be paused",
                handle.name
            )));
        }
        vm.resume_state = Some(std::mem::replace(&mut *state, VmState::Paused));
        drop(state);

        tracing::info!(vm = %handle.name, "VM paused");
        Ok(())
    }

    /// Restarts the vCPU thread of a paused VM and restores the state saved by `pause`.
    ///
    /// # Errors
    ///
    /// Returns [`VmError::NotFound`] for an unknown VM and [`VmError::Hypervisor`] when the
    /// VM is not paused, still has a vCPU thread, the thread cannot be spawned, or a lock is
    /// poisoned.
    fn resume(&self, handle: &VmHandle) -> Result<(), VmError> {
        let mut vms = self
            .vms
            .lock()
            .map_err(|e| VmError::Hypervisor(format!("VM lock poisoned: {e}")))?;
        let vm = vms.get_mut(&handle.name).ok_or_else(|| VmError::NotFound {
            name: handle.name.clone(),
        })?;

        let current_state = vm
            .state
            .read()
            .map_err(|e| VmError::Hypervisor(format!("state lock poisoned: {e}")))?
            .clone();
        if current_state != VmState::Paused {
            return Err(VmError::Hypervisor("can only resume a paused VM".into()));
        }
        if vm.vcpu_thread.is_some() {
            return Err(VmError::Hypervisor(format!(
                "VM '{}' already has an active vCPU thread",
                handle.name
            )));
        }

        vm.stop_flag.store(false, Ordering::Release);

        let sendable = SendablePartition(vm.partition);
        let state_clone = Arc::clone(&vm.state);
        let stop_clone = Arc::clone(&vm.stop_flag);
        let log_path = vm.serial_log.clone();
        let vm_name = handle.name.clone();
        let thread = std::thread::Builder::new()
            .name(format!("vcpu-{}", handle.name))
            .spawn(move || {
                vcpu_loop(sendable, state_clone, stop_clone, &log_path, &vm_name);
            })
            .map_err(|e| VmError::Hypervisor(format!("failed to spawn vCPU thread: {e}")))?;
        vm.vcpu_thread = Some(thread);

        // The saved state is consumed only after the thread has actually been started.
        let resumed_state = vm.resume_state.take().unwrap_or(VmState::Running);
        if let Ok(mut state) = vm.state.write() {
            *state = resumed_state;
        }

        tracing::info!(vm = %handle.name, "VM resumed");
        Ok(())
    }
}
impl Drop for WhpVm {
    /// Stops the vCPU thread (if any) and deletes the virtual processor and the partition.
    fn drop(&mut self) {
        let partition = self.partition;

        // Ask the run loop to exit and kick the vCPU out of its current run call.
        self.stop_flag.store(true, Ordering::Release);
        let _ = unsafe { WHvCancelRunVirtualProcessor(partition, 0, 0) };

        // The thread must be gone before the objects it uses are deleted.
        if let Some(worker) = self.vcpu_thread.take() {
            let _ = worker.join();
        }

        let _ = unsafe { WHvDeleteVirtualProcessor(partition, 0) };
        let _ = unsafe { WHvDeletePartition(partition) };
    }
}
fn setup_initial_registers(
partition: WHV_PARTITION_HANDLE,
entry_point: u64,
) -> Result<(), VmError> {
let reg_names = [
WHV_REGISTER_NAME(WHvX64RegisterRip.0),
WHV_REGISTER_NAME(WHvX64RegisterRsp.0),
WHV_REGISTER_NAME(WHvX64RegisterRflags.0),
WHV_REGISTER_NAME(WHvX64RegisterCr0.0),
WHV_REGISTER_NAME(WHvX64RegisterCr3.0),
WHV_REGISTER_NAME(WHvX64RegisterCr4.0),
WHV_REGISTER_NAME(WHvX64RegisterEfer.0),
WHV_REGISTER_NAME(WHvX64RegisterRsi.0), ];
let reg_values = [
WHV_REGISTER_VALUE { Reg64: entry_point },
WHV_REGISTER_VALUE {
Reg64: BOOT_PARAMS_ADDR - 0x10,
},
WHV_REGISTER_VALUE { Reg64: 0x2 },
WHV_REGISTER_VALUE { Reg64: 0x8001_0001 },
WHV_REGISTER_VALUE { Reg64: PML4_ADDR },
WHV_REGISTER_VALUE { Reg64: 0x20 },
WHV_REGISTER_VALUE { Reg64: 0xD00 },
WHV_REGISTER_VALUE {
Reg64: BOOT_PARAMS_ADDR,
},
];
unsafe {
WHvSetVirtualProcessorRegisters(
partition,
0, reg_names.as_ptr(),
reg_names.len() as u32,
reg_values.as_ptr(),
)
}
.map_err(|e| VmError::BootFailed {
name: String::new(),
detail: format!("WHvSetVirtualProcessorRegisters (general) failed: {e}"),
})?;
set_segment_registers(partition)?;
Ok(())
}
fn set_segment_registers(partition: WHV_PARTITION_HANDLE) -> Result<(), VmError> {
let code_segment = WHV_X64_SEGMENT_REGISTER {
Base: 0,
Limit: 0xFFFF_FFFF,
Selector: GDT_CODE_SELECTOR,
Anonymous: WHV_X64_SEGMENT_REGISTER_0 {
Attributes: 0xA09B, },
};
let data_segment = WHV_X64_SEGMENT_REGISTER {
Base: 0,
Limit: 0xFFFF_FFFF,
Selector: GDT_DATA_SELECTOR,
Anonymous: WHV_X64_SEGMENT_REGISTER_0 {
Attributes: 0xC093, },
};
let gdt_table = WHV_X64_TABLE_REGISTER {
Pad: [0; 3],
Base: GDT_ADDR,
Limit: 23, };
let names = [WHV_REGISTER_NAME(WHvX64RegisterCs.0)];
let values = [WHV_REGISTER_VALUE {
Segment: code_segment,
}];
unsafe { WHvSetVirtualProcessorRegisters(partition, 0, names.as_ptr(), 1, values.as_ptr()) }
.map_err(|e| VmError::BootFailed {
name: String::new(),
detail: format!("failed to set CS: {e}"),
})?;
for reg in [WHvX64RegisterDs, WHvX64RegisterEs, WHvX64RegisterSs] {
let names = [WHV_REGISTER_NAME(reg.0)];
let values = [WHV_REGISTER_VALUE {
Segment: data_segment,
}];
unsafe {
WHvSetVirtualProcessorRegisters(partition, 0, names.as_ptr(), 1, values.as_ptr())
}
.map_err(|e| VmError::BootFailed {
name: String::new(),
detail: format!("failed to set segment register: {e}"),
})?;
}
let names = [WHV_REGISTER_NAME(WHvX64RegisterGdtr.0)];
let values = [WHV_REGISTER_VALUE { Table: gdt_table }];
unsafe { WHvSetVirtualProcessorRegisters(partition, 0, names.as_ptr(), 1, values.as_ptr()) }
.map_err(|e| VmError::BootFailed {
name: String::new(),
detail: format!("failed to set GDTR: {e}"),
})?;
Ok(())
}
fn vcpu_loop(
partition: SendablePartition,
state: Arc<RwLock<VmState>>,
stop_flag: Arc<AtomicBool>,
serial_log_path: &Path,
vm_name: &str,
) {
let partition = partition.0;
let mut serial_file = match std::fs::File::create(serial_log_path) {
Ok(f) => f,
Err(e) => {
tracing::error!(vm = %vm_name, "failed to create serial log: {e}");
update_state(
&state,
VmState::Failed {
reason: format!("failed to create serial log: {e}"),
},
);
return;
}
};
let mut serial_buffer = String::new();
loop {
if stop_flag.load(Ordering::Acquire) {
break;
}
let mut exit_context: WHV_RUN_VP_EXIT_CONTEXT = unsafe { std::mem::zeroed() };
let result = unsafe {
WHvRunVirtualProcessor(
partition,
0,
&mut exit_context as *mut _ as *mut std::ffi::c_void,
std::mem::size_of::<WHV_RUN_VP_EXIT_CONTEXT>() as u32,
)
};
if let Err(e) = result {
tracing::error!(vm = %vm_name, "WHvRunVirtualProcessor failed: {e}");
update_state(
&state,
VmState::Failed {
reason: format!("vCPU execution failed: {e}"),
},
);
break;
}
let exit_reason = exit_context.ExitReason;
if exit_reason == WHvRunVpExitReasonX64IoPortAccess {
let io = unsafe { &exit_context.Anonymous.IoPortAccess };
handle_io_port(
partition,
&exit_context,
io,
&mut serial_file,
&mut serial_buffer,
&state,
vm_name,
);
} else if exit_reason == WHvRunVpExitReasonX64Halt {
let rflags = exit_context.VpContext.Rflags;
if rflags & 0x200 != 0 {
std::thread::sleep(std::time::Duration::from_millis(10));
advance_rip(partition, &exit_context);
} else {
tracing::info!(vm = %vm_name, "VM halted (interrupts disabled)");
update_state(&state, VmState::Stopped);
break;
}
} else if exit_reason == WHvRunVpExitReasonCanceled {
tracing::debug!(vm = %vm_name, "vCPU execution cancelled");
break;
} else if exit_reason == WHvRunVpExitReasonMemoryAccess {
let gpa = unsafe { exit_context.Anonymous.MemoryAccess.Gpa };
let reason = format!("unmapped memory access at GPA 0x{:x}", gpa);
tracing::error!(vm = %vm_name, "{}", reason);
update_state(&state, VmState::Failed { reason });
break;
} else {
tracing::debug!(
vm = %vm_name,
exit_reason = exit_reason.0,
"unhandled VM exit"
);
advance_rip(partition, &exit_context);
}
}
}
/// Emulates a single port I/O access and then steps the guest past the instruction.
///
/// - Writes to `COM1_DATA` are appended to the serial log. ASCII bytes are also collected
///   into `serial_buffer` one line at a time; when a completed line contains
///   `READY_MARKER`, the first whitespace-separated token after the marker is published as
///   the VM's IP through `VmState::Ready`.
/// - Reads from `COM1_LSR` return `0x60`; reads from any other port return `0`.
/// - Writes to any other port are ignored.
///
/// The marker is evaluated only when a line terminator (`\n` or `\r`) arrives. Checking
/// after every byte would publish the first character following the marker as the whole
/// IP, and would rescan an ever-growing buffer on each byte.
fn handle_io_port(
    partition: WHV_PARTITION_HANDLE,
    exit_context: &WHV_RUN_VP_EXIT_CONTEXT,
    io: &WHV_X64_IO_PORT_ACCESS_CONTEXT,
    serial_file: &mut std::fs::File,
    serial_buffer: &mut String,
    state: &Arc<RwLock<VmState>>,
    vm_name: &str,
) {
    // Upper bound on a buffered line, so a guest that never emits a newline cannot grow the
    // buffer without limit.
    const MAX_SERIAL_LINE: usize = 4096;

    let port = io.PortNumber;
    // Bit 0 of the access info distinguishes writes from reads.
    let is_write = unsafe { io.AccessInfo.AsUINT32 } & 1 != 0;

    if is_write && port == COM1_DATA {
        let byte = (io.Rax & 0xFF) as u8;
        // Logging is best-effort: a failing log file must not stop the guest.
        let _ = serial_file.write_all(&[byte]);
        let _ = serial_file.flush();

        if byte == b'\n' || byte == b'\r' {
            if let Some(ip) = extract_ready_ip(serial_buffer, crate::config::READY_MARKER) {
                tracing::info!(vm = %vm_name, ip = %ip, "VM ready");
                update_state(state, VmState::Ready { ip: ip.to_string() });
            }
            serial_buffer.clear();
        } else if byte.is_ascii() {
            if serial_buffer.len() >= MAX_SERIAL_LINE {
                serial_buffer.clear();
            }
            serial_buffer.push(byte as char);
        }
    } else if !is_write && port == COM1_LSR {
        set_rax(partition, 0x60);
    } else if !is_write {
        set_rax(partition, 0);
    }

    advance_rip(partition, exit_context);
}

/// Returns the first whitespace-separated token that follows `marker` in `line`, if any.
fn extract_ready_ip<'a>(line: &'a str, marker: &str) -> Option<&'a str> {
    let (_, after_marker) = line.split_once(marker)?;
    after_marker.split_whitespace().next()
}
/// Moves RIP of vCPU 0 past the instruction that caused the exit.
///
/// The instruction length is taken from the low four bits of the exit context's bitfield.
/// A failed register write is ignored.
fn advance_rip(partition: WHV_PARTITION_HANDLE, exit_context: &WHV_RUN_VP_EXIT_CONTEXT) {
    let current_rip = exit_context.VpContext.Rip;
    let instruction_len = (exit_context.VpContext._bitfield & 0xF) as u64;

    let register = [WHV_REGISTER_NAME(WHvX64RegisterRip.0)];
    let target = [WHV_REGISTER_VALUE {
        Reg64: current_rip + instruction_len,
    }];
    let _ = unsafe {
        WHvSetVirtualProcessorRegisters(partition, 0, register.as_ptr(), 1, target.as_ptr())
    };
}
/// Stores `value` in RAX of vCPU 0, e.g. to return the result of an emulated port read.
///
/// A failed register write is ignored.
fn set_rax(partition: WHV_PARTITION_HANDLE, value: u64) {
    let register = [WHV_REGISTER_NAME(WHvX64RegisterRax.0)];
    let contents = [WHV_REGISTER_VALUE { Reg64: value }];
    let _ = unsafe {
        WHvSetVirtualProcessorRegisters(partition, 0, register.as_ptr(), 1, contents.as_ptr())
    };
}
/// Replaces the shared VM state with `new_state`.
///
/// If the lock is poisoned the update is skipped silently.
fn update_state(state: &Arc<RwLock<VmState>>, new_state: VmState) {
    let Ok(mut current) = state.write() else {
        return;
    };
    *current = new_state;
}