#[cfg(target_os = "windows")]
use std::collections::HashSet;
#[cfg(kvm)]
use kvm_bindings::kvm_fpu;
#[cfg(mshv3)]
use mshv_bindings::FloatingPointUnit;
#[cfg(target_os = "windows")]
use super::Align16;
#[cfg(target_os = "windows")]
use crate::hypervisor::regs::FromWhpRegisterError;
pub(crate) const FP_CONTROL_WORD_DEFAULT: u16 = 0x37f; pub(crate) const MXCSR_DEFAULT: u32 = 0x1f80;
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct CommonFpu {
pub fpr: [[u8; 16]; 8],
pub fcw: u16,
pub fsw: u16,
pub ftwx: u8,
pub last_opcode: u16,
pub last_ip: u64,
pub last_dp: u64,
pub xmm: [[u8; 16]; 16],
pub mxcsr: u32,
}
impl Default for CommonFpu {
fn default() -> Self {
Self {
fpr: [[0u8; 16]; 8],
fcw: FP_CONTROL_WORD_DEFAULT,
fsw: 0,
ftwx: 0,
last_opcode: 0,
last_ip: 0,
last_dp: 0,
xmm: [[0u8; 16]; 16],
mxcsr: MXCSR_DEFAULT,
}
}
}
#[cfg(kvm)]
impl From<&CommonFpu> for kvm_fpu {
fn from(common_fpu: &CommonFpu) -> Self {
kvm_fpu {
fpr: common_fpu.fpr,
fcw: common_fpu.fcw,
fsw: common_fpu.fsw,
ftwx: common_fpu.ftwx,
pad1: 0,
last_opcode: common_fpu.last_opcode,
last_ip: common_fpu.last_ip,
last_dp: common_fpu.last_dp,
xmm: common_fpu.xmm,
mxcsr: common_fpu.mxcsr,
pad2: 0,
}
}
}
#[cfg(mshv3)]
impl From<&CommonFpu> for FloatingPointUnit {
fn from(common_fpu: &CommonFpu) -> FloatingPointUnit {
FloatingPointUnit {
fpr: common_fpu.fpr,
fcw: common_fpu.fcw,
fsw: common_fpu.fsw,
ftwx: common_fpu.ftwx,
pad1: 0,
last_opcode: common_fpu.last_opcode,
last_ip: common_fpu.last_ip,
last_dp: common_fpu.last_dp,
xmm: common_fpu.xmm,
mxcsr: common_fpu.mxcsr,
pad2: 0,
}
}
}
#[cfg(kvm)]
impl From<&kvm_fpu> for CommonFpu {
fn from(kvm_fpu: &kvm_fpu) -> Self {
Self {
fpr: kvm_fpu.fpr,
fcw: kvm_fpu.fcw,
fsw: kvm_fpu.fsw,
ftwx: kvm_fpu.ftwx,
last_opcode: kvm_fpu.last_opcode,
last_ip: kvm_fpu.last_ip,
last_dp: kvm_fpu.last_dp,
xmm: kvm_fpu.xmm,
mxcsr: kvm_fpu.mxcsr,
}
}
}
#[cfg(mshv3)]
impl From<&FloatingPointUnit> for CommonFpu {
fn from(mshv_fpu: &FloatingPointUnit) -> Self {
Self {
fpr: mshv_fpu.fpr,
fcw: mshv_fpu.fcw,
fsw: mshv_fpu.fsw,
ftwx: mshv_fpu.ftwx,
last_opcode: mshv_fpu.last_opcode,
last_ip: mshv_fpu.last_ip,
last_dp: mshv_fpu.last_dp,
xmm: mshv_fpu.xmm,
mxcsr: mshv_fpu.mxcsr,
}
}
}
#[cfg(target_os = "windows")]
use windows::Win32::System::Hypervisor::*;
#[cfg(target_os = "windows")]
impl From<&CommonFpu> for [(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>); WHP_FPU_NAMES_LEN] {
fn from(fpu: &CommonFpu) -> Self {
let mut regs: [(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>); WHP_FPU_NAMES_LEN] =
[Default::default(); WHP_FPU_NAMES_LEN];
let mut idx = 0;
for (i, reg) in fpu.fpr.iter().enumerate() {
let mut value = WHV_REGISTER_VALUE::default();
value.Reg128 = WHV_UINT128 {
Dword: [
u32::from_le_bytes([reg[0], reg[1], reg[2], reg[3]]),
u32::from_le_bytes([reg[4], reg[5], reg[6], reg[7]]),
u32::from_le_bytes([reg[8], reg[9], reg[10], reg[11]]),
u32::from_le_bytes([reg[12], reg[13], reg[14], reg[15]]),
],
};
regs[idx] = (
WHV_REGISTER_NAME(WHvX64RegisterFpMmx0.0 + i as i32),
Align16(value),
);
idx += 1;
}
let mut fp_control_status = WHV_REGISTER_VALUE::default();
fp_control_status.FpControlStatus = WHV_X64_FP_CONTROL_STATUS_REGISTER {
Anonymous: WHV_X64_FP_CONTROL_STATUS_REGISTER_0 {
FpControl: fpu.fcw,
FpStatus: fpu.fsw,
FpTag: fpu.ftwx,
Reserved: 0,
LastFpOp: fpu.last_opcode,
Anonymous: WHV_X64_FP_CONTROL_STATUS_REGISTER_0_0 {
LastFpRip: fpu.last_ip,
},
},
};
regs[idx] = (WHvX64RegisterFpControlStatus, Align16(fp_control_status));
idx += 1;
for (i, reg) in fpu.xmm.iter().enumerate() {
let mut value = WHV_REGISTER_VALUE::default();
value.Reg128 = WHV_UINT128 {
Dword: [
u32::from_le_bytes([reg[0], reg[1], reg[2], reg[3]]),
u32::from_le_bytes([reg[4], reg[5], reg[6], reg[7]]),
u32::from_le_bytes([reg[8], reg[9], reg[10], reg[11]]),
u32::from_le_bytes([reg[12], reg[13], reg[14], reg[15]]),
],
};
regs[idx] = (
WHV_REGISTER_NAME(WHvX64RegisterXmm0.0 + i as i32),
Align16(value),
);
idx += 1;
}
let mut xmm_control_status = WHV_REGISTER_VALUE::default();
xmm_control_status.XmmControlStatus = WHV_X64_XMM_CONTROL_STATUS_REGISTER {
Anonymous: WHV_X64_XMM_CONTROL_STATUS_REGISTER_0 {
XmmStatusControl: fpu.mxcsr,
XmmStatusControlMask: !0,
Anonymous: WHV_X64_XMM_CONTROL_STATUS_REGISTER_0_0 {
LastFpRdp: fpu.last_dp,
},
},
};
regs[idx] = (WHvX64RegisterXmmControlStatus, Align16(xmm_control_status));
regs
}
}
#[cfg(target_os = "windows")]
pub(crate) const WHP_FPU_NAMES_LEN: usize = 26;
#[cfg(target_os = "windows")]
pub(crate) const WHP_FPU_NAMES: [WHV_REGISTER_NAME; WHP_FPU_NAMES_LEN] = [
WHvX64RegisterFpMmx0,
WHvX64RegisterFpMmx1,
WHvX64RegisterFpMmx2,
WHvX64RegisterFpMmx3,
WHvX64RegisterFpMmx4,
WHvX64RegisterFpMmx5,
WHvX64RegisterFpMmx6,
WHvX64RegisterFpMmx7,
WHvX64RegisterFpControlStatus,
WHvX64RegisterXmm0,
WHvX64RegisterXmm1,
WHvX64RegisterXmm2,
WHvX64RegisterXmm3,
WHvX64RegisterXmm4,
WHvX64RegisterXmm5,
WHvX64RegisterXmm6,
WHvX64RegisterXmm7,
WHvX64RegisterXmm8,
WHvX64RegisterXmm9,
WHvX64RegisterXmm10,
WHvX64RegisterXmm11,
WHvX64RegisterXmm12,
WHvX64RegisterXmm13,
WHvX64RegisterXmm14,
WHvX64RegisterXmm15,
WHvX64RegisterXmmControlStatus,
];
#[cfg(target_os = "windows")]
impl TryFrom<&[(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>)]> for CommonFpu {
type Error = FromWhpRegisterError;
fn try_from(
regs: &[(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>)],
) -> Result<Self, Self::Error> {
if regs.len() != WHP_FPU_NAMES_LEN {
return Err(FromWhpRegisterError::InvalidLength(regs.len()));
}
let mut fpu = CommonFpu::default();
let mut seen_registers = HashSet::new();
for (name, value) in regs {
let name_id = name.0;
if !seen_registers.insert(name_id) {
return Err(FromWhpRegisterError::DuplicateRegister(name_id));
}
match name_id {
id if (WHvX64RegisterFpMmx0.0..WHvX64RegisterFpMmx0.0 + 8).contains(&id) => {
let idx = (id - WHvX64RegisterFpMmx0.0) as usize;
let dwords = unsafe { value.0.Reg128.Dword };
fpu.fpr[idx] = [
dwords[0].to_le_bytes(),
dwords[1].to_le_bytes(),
dwords[2].to_le_bytes(),
dwords[3].to_le_bytes(),
]
.concat()
.try_into()
.map_err(|_| FromWhpRegisterError::InvalidEncoding)?;
}
id if id == WHvX64RegisterFpControlStatus.0 => {
let control = unsafe { value.0.FpControlStatus.Anonymous };
fpu.fcw = control.FpControl;
fpu.fsw = control.FpStatus;
fpu.ftwx = control.FpTag;
fpu.last_opcode = control.LastFpOp;
fpu.last_ip = unsafe { control.Anonymous.LastFpRip };
}
id if (WHvX64RegisterXmm0.0..WHvX64RegisterXmm0.0 + 16).contains(&id) => {
let idx = (id - WHvX64RegisterXmm0.0) as usize;
let dwords = unsafe { value.0.Reg128.Dword };
fpu.xmm[idx] = [
dwords[0].to_le_bytes(),
dwords[1].to_le_bytes(),
dwords[2].to_le_bytes(),
dwords[3].to_le_bytes(),
]
.concat()
.try_into()
.map_err(|_| FromWhpRegisterError::InvalidEncoding)?;
}
id if id == WHvX64RegisterXmmControlStatus.0 => {
let control = unsafe { value.0.XmmControlStatus.Anonymous };
fpu.mxcsr = control.XmmStatusControl;
fpu.last_dp = unsafe { control.Anonymous.LastFpRdp };
}
_ => {
return Err(FromWhpRegisterError::InvalidRegister(name_id));
}
}
}
let expected_registers: HashSet<i32> = WHP_FPU_NAMES.iter().map(|reg| reg.0).collect();
let missing: HashSet<i32> = expected_registers
.difference(&seen_registers)
.cloned()
.collect();
if !missing.is_empty() {
return Err(FromWhpRegisterError::MissingRegister(missing));
}
Ok(fpu)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn sample_common_fpu() -> CommonFpu {
CommonFpu {
fpr: [
[1u8; 16], [2u8; 16], [3u8; 16], [4u8; 16], [5u8; 16], [6u8; 16], [7u8; 16],
[8u8; 16],
],
fcw: 0x1234,
fsw: 0x5678,
ftwx: 0x9a,
last_opcode: 0xdef0,
last_ip: 0xdeadbeefcafebabe,
last_dp: 0xabad1deaf00dbabe,
xmm: [
[8u8; 16], [9u8; 16], [10u8; 16], [11u8; 16], [12u8; 16], [13u8; 16], [14u8; 16],
[15u8; 16], [16u8; 16], [17u8; 16], [18u8; 16], [19u8; 16], [20u8; 16], [21u8; 16],
[22u8; 16], [23u8; 16],
],
mxcsr: 0x1f80,
}
}
#[cfg(kvm)]
#[test]
fn round_trip_kvm_fpu() {
use kvm_bindings::kvm_fpu;
let original = sample_common_fpu();
let kvm: kvm_fpu = (&original).into();
let round_tripped = CommonFpu::from(&kvm);
assert_eq!(original, round_tripped);
}
#[cfg(mshv3)]
#[test]
fn round_trip_mshv_fpu() {
use mshv_bindings::FloatingPointUnit;
let original = sample_common_fpu();
let mshv: FloatingPointUnit = (&original).into();
let round_tripped = CommonFpu::from(&mshv);
assert_eq!(original, round_tripped);
}
#[cfg(target_os = "windows")]
#[test]
fn round_trip_windows_fpu() {
use windows::Win32::System::Hypervisor::*;
let original = sample_common_fpu();
let windows: [(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>); WHP_FPU_NAMES_LEN] =
(&original).into();
let round_tripped = CommonFpu::try_from(windows.as_ref()).unwrap();
assert_eq!(original, round_tripped);
let original = sample_common_fpu();
let mut windows: [(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>); WHP_FPU_NAMES_LEN] =
(&original).into();
windows[0].0 = WHvX64RegisterFpMmx1;
let err = CommonFpu::try_from(windows.as_ref()).unwrap_err();
assert_eq!(
err,
FromWhpRegisterError::DuplicateRegister(WHvX64RegisterFpMmx1.0)
);
let original = sample_common_fpu();
let mut windows: [(WHV_REGISTER_NAME, Align16<WHV_REGISTER_VALUE>); WHP_FPU_NAMES_LEN] =
(&original).into();
windows[0] = (WHvX64RegisterRax, windows[0].1);
let err = CommonFpu::try_from(windows.as_ref()).unwrap_err();
assert_eq!(
err,
FromWhpRegisterError::InvalidRegister(WHvX64RegisterRax.0)
);
}
}