use alloc::boxed::Box;
use core::arch::x86_64::{_fxrstor64, _fxsave64, _xrstor64, _xsave64};
use bitflags::bitflags;
use cfg_if::cfg_if;
use ostd_pod::{FromZeros, IntoBytes};
use spin::Once;
use x86::bits64::segmentation::wrfsbase;
use x86_64::registers::{
control::{Cr0, Cr0Flags},
rflags::RFlags,
xcontrol::XCr0,
};
use crate::{
arch::{
irq::HwIrqLine,
trap::{RawUserContext, TrapFrame},
},
cpu::PrivilegeLevel,
debug,
irq::call_irq_callback_functions,
mm::Vaddr,
user::{ReturnReason, UserContextApi, UserContextApiInternal},
};
// Confidential-VM guest support: the `tdx` module provides the handler used
// for `CpuException::VirtualizationException` (#VE) raised while running user
// code; see the matching `cvm_guest`-gated arm in `execute`.
cfg_if! {
    if #[cfg(feature = "cvm_guest")] {
        mod tdx;
        use tdx::VirtualizationExceptionHandler;
    }
}
/// The CPU state of a user-mode execution context.
///
/// Holds the raw register state that is entered and left through
/// `RawUserContext::run`, plus the CPU exception (if any) that caused the
/// most recent return to the kernel.
#[repr(C)]
#[derive(Clone, Debug, Default)]
pub struct UserContext {
    // Register state (general registers, trap number, error code) swapped
    // in/out by `RawUserContext::run`.
    user_context: RawUserContext,
    // Set by `execute` when a fault/trap-class exception ends a user run;
    // consumed by `take_exception`.
    exception: Option<CpuException>,
}
/// The general-purpose registers of a user context.
///
/// `#[repr(C)]`: the field order is part of the layout. NOTE(review): the
/// register save/restore code for `RawUserContext` presumably depends on this
/// exact order — confirm before reordering.
#[expect(missing_docs)]
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct GeneralRegs {
    pub rax: usize,
    pub rbx: usize,
    pub rcx: usize,
    pub rdx: usize,
    pub rsi: usize,
    pub rdi: usize,
    pub rbp: usize,
    pub rsp: usize,
    pub r8: usize,
    pub r9: usize,
    pub r10: usize,
    pub r11: usize,
    pub r12: usize,
    pub r13: usize,
    pub r14: usize,
    pub r15: usize,
    // User instruction pointer.
    pub rip: usize,
    // User RFLAGS; `execute` forces IF and ID on before entering user mode.
    pub rflags: usize,
    // FS segment base; also serves as the user TLS pointer.
    pub fsbase: usize,
    // GS segment base.
    pub gsbase: usize,
}
/// An x86 CPU exception, decoded from a trap number (vector 0-31) and the
/// error code by `CpuException::new`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CpuException {
    /// Vector 0.
    DivisionError,
    /// Vector 1.
    Debug,
    /// Vector 2.
    NonMaskableInterrupt,
    /// Vector 3.
    BreakPoint,
    /// Vector 4.
    Overflow,
    /// Vector 5.
    BoundRangeExceeded,
    /// Vector 6.
    InvalidOpcode,
    /// Vector 7.
    DeviceNotAvailable,
    /// Vector 8.
    DoubleFault,
    /// Vector 9.
    CoprocessorSegmentOverrun,
    /// Vector 10, with its selector error code.
    InvalidTss(SelectorErrorCode),
    /// Vector 11, with its selector error code.
    SegmentNotPresent(SelectorErrorCode),
    /// Vector 12, with its selector error code.
    StackSegmentFault(SelectorErrorCode),
    /// Vector 13; `None` when the pushed error code is zero.
    GeneralProtectionFault(Option<SelectorErrorCode>),
    /// Vector 14, with the error code and the faulting address read from CR2.
    PageFault(RawPageFaultInfo),
    /// Vector 16.
    X87FloatingPointException,
    /// Vector 17.
    AlignmentCheck,
    /// Vector 18.
    MachineCheck,
    /// Vector 19.
    SIMDFloatingPointException,
    /// Vector 20.
    VirtualizationException,
    /// Vector 21.
    ControlProtectionException,
    /// Vector 28.
    HypervisorInjectionException,
    /// Vector 29.
    VMMCommunicationException,
    /// Vector 30.
    SecurityException,
    /// Any reserved vector: 15, 22-27 or 31.
    Reserved,
}
impl CpuException {
    /// Decodes a trap number and error code into a CPU exception.
    ///
    /// Returns `None` when `trap_num` is not an exception vector (i.e. it is
    /// greater than 31). For vector 14 (page fault) this reads CR2 to obtain
    /// the faulting address, so it should be called before anything else can
    /// page-fault and overwrite CR2.
    pub(crate) fn new(trap_num: usize, error_code: usize) -> Option<Self> {
        let exception = match trap_num {
            0 => Self::DivisionError,
            1 => Self::Debug,
            2 => Self::NonMaskableInterrupt,
            3 => Self::BreakPoint,
            4 => Self::Overflow,
            5 => Self::BoundRangeExceeded,
            6 => Self::InvalidOpcode,
            7 => Self::DeviceNotAvailable,
            8 => {
                // A double fault always pushes a zero error code.
                debug_assert_eq!(error_code, 0);
                Self::DoubleFault
            }
            9 => Self::CoprocessorSegmentOverrun,
            10 => Self::InvalidTss(SelectorErrorCode(error_code)),
            11 => Self::SegmentNotPresent(SelectorErrorCode(error_code)),
            12 => Self::StackSegmentFault(SelectorErrorCode(error_code)),
            13 => {
                // A zero error code means the #GP was not selector-related.
                let error_code = if error_code == 0 {
                    None
                } else {
                    Some(SelectorErrorCode(error_code))
                };
                Self::GeneralProtectionFault(error_code)
            }
            14 => {
                let page_fault_addr = x86_64::registers::control::Cr2::read_raw() as usize;
                Self::PageFault(RawPageFaultInfo {
                    // The CPU may report error-code bits this type does not
                    // model (reserved or vendor-specific ones). Dropping them is
                    // preferable to `from_bits(..).unwrap()`, which would let a
                    // user-triggered page fault panic the kernel.
                    error_code: PageFaultErrorCode::from_bits_truncate(error_code),
                    addr: page_fault_addr,
                })
            }
            16 => Self::X87FloatingPointException,
            17 => Self::AlignmentCheck,
            18 => Self::MachineCheck,
            19 => Self::SIMDFloatingPointException,
            20 => Self::VirtualizationException,
            21 => Self::ControlProtectionException,
            28 => Self::HypervisorInjectionException,
            29 => Self::VMMCommunicationException,
            30 => Self::SecurityException,
            15 | 22..=27 | 31 => Self::Reserved,
            _ => return None,
        };
        Some(exception)
    }

    /// Classifies the exception as fault, trap, interrupt, abort or reserved.
    const fn type_(&self) -> CpuExceptionType {
        match self {
            Self::Debug => CpuExceptionType::FaultOrTrap,
            Self::NonMaskableInterrupt => CpuExceptionType::Interrupt,
            Self::BreakPoint | Self::Overflow => CpuExceptionType::Trap,
            Self::DoubleFault | Self::MachineCheck => CpuExceptionType::Abort,
            Self::Reserved => CpuExceptionType::Reserved,
            _ => CpuExceptionType::Fault,
        }
    }

    /// Returns whether `trap_num` falls in the CPU exception range (0-31).
    pub(crate) const fn is_cpu_exception(trap_num: usize) -> bool {
        trap_num <= 31
    }
}
/// The raw error code pushed by the CPU for selector-related exceptions
/// (invalid TSS, segment not present, stack-segment fault and, when
/// non-zero, general protection fault).
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct SelectorErrorCode(usize);
impl UserContext {
    /// Borrows the saved general-purpose registers.
    pub fn general_regs(&self) -> &GeneralRegs {
        let raw = &self.user_context;
        &raw.general
    }

    /// Mutably borrows the saved general-purpose registers.
    pub fn general_regs_mut(&mut self) -> &mut GeneralRegs {
        let raw = &mut self.user_context;
        &mut raw.general
    }

    /// Removes and returns the exception recorded by the last user run, if any.
    pub fn take_exception(&mut self) -> Option<CpuException> {
        core::mem::take(&mut self.exception)
    }

    /// Stores `tls` as the thread-local-storage pointer (the saved FS base).
    pub fn set_tls_pointer(&mut self, tls: usize) {
        self.user_context.general.fsbase = tls;
    }

    /// Returns the thread-local-storage pointer (the saved FS base).
    pub fn tls_pointer(&self) -> usize {
        self.user_context.general.fsbase
    }

    /// Loads the saved FS base into the current CPU's FS base register.
    pub fn activate_tls_pointer(&self) {
        let fs_base = self.user_context.general.fsbase as u64;
        // SAFETY: WRFSBASE only changes the FS segment base of this CPU.
        // NOTE(review): it requires CR4.FSGSBASE, which is assumed to be
        // enabled during CPU initialization elsewhere — confirm.
        unsafe { wrfsbase(fs_base) }
    }
}
impl UserContextApiInternal for UserContext {
    /// Runs the context in user mode until the kernel has to act.
    ///
    /// Repeatedly enters user space through `RawUserContext::run` and decodes
    /// why it returned:
    ///
    /// - a fault/trap-class CPU exception is stored (see
    ///   `UserContext::take_exception`) and `ReturnReason::UserException` is
    ///   returned;
    /// - the syscall pseudo-vector (`0x100`) returns `ReturnReason::UserSyscall`;
    /// - any other vector is dispatched to the IRQ callbacks;
    /// - with `cvm_guest`, a virtualization exception is handled in place.
    ///
    /// After an interrupt or a handled virtualization exception,
    /// `has_kernel_event` is polled and `ReturnReason::KernelEvent` is returned
    /// if it reports `true`; otherwise user execution resumes.
    ///
    /// # Panics
    ///
    /// Panics on a CPU exception that is neither a fault nor a trap (NMI,
    /// double fault, machine check, or a reserved vector).
    fn execute<F>(&mut self, mut has_kernel_event: F) -> ReturnReason
    where
        F: FnMut() -> bool,
    {
        // User code always runs with IF set (interrupts enabled); ID is forced on too.
        self.user_context.general.rflags |= (RFlags::INTERRUPT_FLAG | RFlags::ID).bits() as usize;
        const SYSCALL_TRAPNUM: usize = 0x100;
        loop {
            crate::task::scheduler::might_preempt();
            self.user_context.run();
            // NOTE(review): every arm below re-enables local IRQs, which
            // presumably means `run` returns with them disabled — confirm.
            let exception =
                CpuException::new(self.user_context.trap_num, self.user_context.error_code);
            match exception {
                #[cfg(feature = "cvm_guest")]
                Some(CpuException::VirtualizationException) => {
                    let ve_handler = VirtualizationExceptionHandler::new();
                    crate::arch::irq::enable_local();
                    ve_handler.handle(self);
                }
                Some(exception) if exception.type_().is_fault_or_trap() => {
                    crate::arch::irq::enable_local();
                    self.exception = Some(exception);
                    return ReturnReason::UserException;
                }
                Some(exception) => {
                    panic!(
                        "cannot handle user CPU exception: {:?}, trapframe: {:?}",
                        exception,
                        self.as_trap_frame()
                    );
                }
                None if self.user_context.trap_num == SYSCALL_TRAPNUM => {
                    crate::arch::irq::enable_local();
                    return ReturnReason::UserSyscall;
                }
                None => {
                    // Build the trap frame once and derive the IRQ line from it.
                    // x86 vectors fit in a byte; the only larger value, the
                    // syscall pseudo-vector, was handled by the arm above.
                    let trap_frame = self.as_trap_frame();
                    let irq_line = HwIrqLine::new(trap_frame.trap_num as u8);
                    call_irq_callback_functions(&trap_frame, &irq_line, PrivilegeLevel::User);
                    crate::arch::irq::enable_local();
                }
            }
            if has_kernel_event() {
                break ReturnReason::KernelEvent;
            }
        }
    }

    /// Builds a `TrapFrame` snapshot of this context.
    ///
    /// The code segment is not tracked by `UserContext` and is reported as 0.
    fn as_trap_frame(&self) -> TrapFrame {
        TrapFrame {
            rax: self.user_context.general.rax,
            rbx: self.user_context.general.rbx,
            rcx: self.user_context.general.rcx,
            rdx: self.user_context.general.rdx,
            rsi: self.user_context.general.rsi,
            rdi: self.user_context.general.rdi,
            rbp: self.user_context.general.rbp,
            rsp: self.user_context.general.rsp,
            r8: self.user_context.general.r8,
            r9: self.user_context.general.r9,
            r10: self.user_context.general.r10,
            r11: self.user_context.general.r11,
            r12: self.user_context.general.r12,
            r13: self.user_context.general.r13,
            r14: self.user_context.general.r14,
            r15: self.user_context.general.r15,
            _pad: 0,
            trap_num: self.user_context.trap_num,
            error_code: self.user_context.error_code,
            rip: self.user_context.general.rip,
            cs: 0,
            rflags: self.user_context.general.rflags,
        }
    }
}
/// How a CPU exception is delivered, which determines whether it can be
/// handed back to user-space handling.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CpuExceptionType {
    /// Reported before the faulting instruction completes.
    Fault,
    /// Reported after the trapping instruction completes.
    Trap,
    /// May be either a fault or a trap depending on the cause.
    FaultOrTrap,
    /// An interrupt delivered through an exception vector.
    Interrupt,
    /// An unrecoverable error.
    Abort,
    /// A reserved vector.
    Reserved,
}

impl CpuExceptionType {
    /// Returns `true` for `Fault`, `Trap` and `FaultOrTrap`; `false` otherwise.
    pub fn is_fault_or_trap(self) -> bool {
        match self {
            Self::Abort | Self::Interrupt | Self::Reserved => false,
            Self::Fault | Self::Trap | Self::FaultOrTrap => true,
        }
    }
}
/// Raw information about a page fault, as captured by `CpuException::new`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RawPageFaultInfo {
    /// The error code pushed by the CPU.
    pub error_code: PageFaultErrorCode,
    /// The faulting virtual address, read from CR2.
    pub addr: Vaddr,
}
bitflags! {
    /// Bits of the x86 page-fault (#PF, vector 14) error code.
    ///
    /// NOTE(review): bit meanings below follow the Intel SDM page-fault
    /// error-code description; verify against the manual when extending.
    pub struct PageFaultErrorCode : usize{
        /// Set: protection violation on a present page; clear: page not present.
        const PRESENT = 1 << 0;
        /// Set when the faulting access was a write.
        const WRITE = 1 << 1;
        /// Set when the access was made in user mode.
        const USER = 1 << 2;
        /// Set when a reserved bit was set in a paging-structure entry.
        const RESERVED = 1 << 3;
        /// Set when the fault was caused by an instruction fetch.
        const INSTRUCTION = 1 << 4;
        /// Set on a protection-key violation (the SDM's "PK" bit).
        const PROTECTION = 1 << 5;
        /// Set when the fault was caused by a shadow-stack access.
        const SHADOW_STACK = 1 << 6;
        /// Set when the fault occurred during HLAT paging.
        const HLAT = 1 << 7;
        /// Set when the fault was related to SGX access control.
        const SGX = 1 << 15;
    }
}
impl UserContextApi for UserContext {
    /// Returns the trap number recorded by the last user run.
    fn trap_number(&self) -> usize {
        let raw = &self.user_context;
        raw.trap_num
    }

    /// Returns the error code recorded by the last user run.
    fn trap_error_code(&self) -> usize {
        let raw = &self.user_context;
        raw.error_code
    }

    /// Returns the saved user instruction pointer (RIP).
    fn instruction_pointer(&self) -> usize {
        self.user_context.general.rip
    }

    /// Returns the saved user stack pointer (RSP).
    fn stack_pointer(&self) -> usize {
        self.user_context.general.rsp
    }

    /// Overwrites the saved user instruction pointer (RIP).
    fn set_instruction_pointer(&mut self, ip: usize) {
        self.user_context.general.rip = ip;
    }

    /// Overwrites the saved user stack pointer (RSP).
    fn set_stack_pointer(&mut self, sp: usize) {
        self.user_context.general.rsp = sp;
    }
}
/// Generates a getter and a setter on `UserContext` for each listed register.
///
/// Each `[field, setter]` pair expands to `fn field(&self) -> usize` and
/// `fn setter(&mut self, field: usize)`. Both access
/// `self.user_context.general.field`. A trailing comma after the last pair
/// is accepted.
macro_rules! cpu_context_impl_getter_setter {
    ( $( [ $field: ident, $setter_name: ident] ),* $(,)? ) => {
        impl UserContext {
            $(
                #[doc = concat!("Gets the value of ", stringify!($field))]
                #[inline(always)]
                pub fn $field(&self) -> usize {
                    self.user_context.general.$field
                }

                // `$field` (not the literal `field`) so each setter's doc names
                // its own register.
                #[doc = concat!("Sets the value of ", stringify!($field))]
                #[inline(always)]
                pub fn $setter_name(&mut self, $field: usize) {
                    self.user_context.general.$field = $field;
                }
            )*
        }
    };
}
// Accessors for every field of `GeneralRegs`: e.g. `rax()` / `set_rax()`.
cpu_context_impl_getter_setter!(
    [rax, set_rax],
    [rbx, set_rbx],
    [rcx, set_rcx],
    [rdx, set_rdx],
    [rsi, set_rsi],
    [rdi, set_rdi],
    [rbp, set_rbp],
    [rsp, set_rsp],
    [r8, set_r8],
    [r9, set_r9],
    [r10, set_r10],
    [r11, set_r11],
    [r12, set_r12],
    [r13, set_r13],
    [r14, set_r14],
    [r15, set_r15],
    [rip, set_rip],
    [rflags, set_rflags],
    [fsbase, set_fsbase],
    [gsbase, set_gsbase]
);
/// A saved FPU/SIMD register context.
///
/// The state lives in a heap-allocated `XSaveArea`. Only its first
/// `area_size` bytes are used by save/load.
#[derive(Debug)]
pub struct FpuContext {
    // Boxed save area; `XSaveArea` is 64-byte aligned as XSAVE/XRSTOR need.
    xsave_area: Box<XSaveArea>,
    // Number of leading bytes of `xsave_area` in use: the legacy FXSAVE size
    // or the XSAVE area size, whichever is larger.
    area_size: usize,
}
impl FpuContext {
    /// Creates a context holding the initial FPU state.
    ///
    /// The used area length is the legacy FXSAVE size (512 bytes) or, when
    /// XSAVE is available, the XSAVE area size recorded at boot if larger.
    pub fn new() -> Self {
        let legacy_size = size_of::<FxSaveArea>();
        let area_size = match XSAVE_AREA_SIZE.get() {
            Some(&xsave_size) => legacy_size.max(xsave_size),
            None => legacy_size,
        };
        Self {
            xsave_area: Box::new(XSaveArea::new()),
            area_size,
        }
    }

    /// Saves the current CPU's FPU state into this context.
    ///
    /// Uses XSAVE with `XFEATURE_MASK_USER_RESTORE` when XSAVE was detected at
    /// boot, and FXSAVE otherwise.
    pub fn save(&mut self) {
        let area = self.as_bytes_mut().as_mut_ptr();
        let xsave_available = XSTATE_MAX_FEATURES.is_completed();
        if xsave_available {
            // SAFETY: `area` is the start of the boxed, 64-byte-aligned
            // `XSaveArea`, sized for the XSAVE area reported at boot.
            unsafe { _xsave64(area, XFEATURE_MASK_USER_RESTORE) };
        } else {
            // SAFETY: `area` is 64-byte aligned (FXSAVE needs 16) and spans
            // at least the 512-byte legacy region.
            unsafe { _fxsave64(area) };
        }
        debug!("Save FPU context");
    }

    /// Restores this context's FPU state onto the current CPU.
    ///
    /// Uses XRSTOR (masked to the boot-detected feature set) when XSAVE is
    /// available, and FXRSTOR otherwise.
    pub fn load(&mut self) {
        let area = self.as_bytes().as_ptr();
        match XSTATE_MAX_FEATURES.get() {
            Some(&max_features) => {
                let restore_mask = max_features & XFEATURE_MASK_USER_RESTORE;
                // SAFETY: `area` is the start of the boxed, 64-byte-aligned
                // `XSaveArea`, initialized by `XSaveArea::new` or a prior save.
                unsafe { _xrstor64(area, restore_mask) };
            }
            None => {
                // SAFETY: `area` is aligned and spans the legacy region.
                unsafe { _fxrstor64(area) };
            }
        }
        debug!("Load FPU context");
    }

    /// Returns the used part of the save area as bytes.
    pub fn as_bytes(&self) -> &[u8] {
        let used = self.area_size;
        &self.xsave_area.as_bytes()[..used]
    }

    /// Returns the used part of the save area as mutable bytes.
    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
        let used = self.area_size;
        &mut self.xsave_area.as_mut_bytes()[..used]
    }
}
impl Default for FpuContext {
fn default() -> Self {
Self::new()
}
}
impl Clone for FpuContext {
    /// Copies the saved FPU state into a freshly allocated area.
    ///
    /// Copies the legacy region, the `features` and `compaction` header words,
    /// and the used prefix of the extended region. The header's reserved words
    /// keep the zeroes from `XSaveArea::new`.
    fn clone(&self) -> Self {
        let legacy_size = size_of::<FxSaveArea>();
        // The XSAVE header (features + compaction + 6 reserved u64s) is 64 bytes.
        let header_size = 64;

        let src = &*self.xsave_area;
        let mut dst = Box::new(XSaveArea::new());
        dst.fxsave_area = src.fxsave_area;
        dst.features = src.features;
        dst.compaction = src.compaction;

        if self.area_size > legacy_size {
            let ext_len = self.area_size - legacy_size - header_size;
            dst.extended_state_area[..ext_len]
                .copy_from_slice(&src.extended_state_area[..ext_len]);
        }

        Self {
            xsave_area: dst,
            area_size: self.area_size,
        }
    }
}
/// In-memory save area for FXSAVE/XSAVE.
///
/// Layout: the 512-byte legacy region, the 64-byte XSAVE header
/// (`features`, `compaction`, 6 reserved words), then the extended region.
/// The whole struct is `MAX_XSAVE_AREA_SIZE` bytes. It is 64-byte aligned
/// because XSAVE/XRSTOR require a 64-byte-aligned operand.
#[repr(C)]
#[repr(align(64))]
#[derive(Clone, Copy, Debug, Pod)]
struct XSaveArea {
    fxsave_area: FxSaveArea,
    // Header bitmap of saved state components; `XSaveArea::new` sets it to
    // XCR0 masked by the boot-detected features.
    features: u64,
    // Header compaction word; left zero by `XSaveArea::new`.
    compaction: u64,
    // Reserved header words; kept zero (`XSaveArea::new`, and not copied by `Clone`).
    reserved: [u64; 6],
    // Extended-state components; only the prefix within `area_size` is used.
    extended_state_area: [u8; MAX_XSAVE_AREA_SIZE - size_of::<FxSaveArea>() - 64],
}
impl XSaveArea {
    /// Returns a zeroed save area describing the initial FPU state.
    ///
    /// Sets the x87 control word to 0x037F and MXCSR to 0x1F80, and the
    /// header's `features` to XCR0 masked by the boot-detected XSAVE features
    /// (0 when XSAVE is unavailable).
    fn new() -> Self {
        let features = XSTATE_MAX_FEATURES
            .get()
            .map_or(0, |&max_features| XCr0::read().bits() & max_features);

        let mut area = Self::new_zeroed();
        area.features = features;

        let legacy = &mut area.fxsave_area;
        legacy.control = 0x037F;
        legacy.tag = 0;
        legacy.mxcsr = 0x1F80;

        area
    }
}
/// The 512-byte legacy region written by FXSAVE (and by XSAVE for x87/SSE).
///
/// Field order and sizes are the hardware layout; do not reorder.
#[repr(C)]
#[repr(align(16))]
#[derive(Clone, Copy, Debug, Pod)]
struct FxSaveArea {
    // x87 control word.
    control: u16,
    // x87 status word.
    status: u16,
    // Abridged x87 tag word.
    tag: u8,
    reserved1: u8,
    // Last x87 opcode.
    op: u16,
    // NOTE(review): legacy 32-bit naming; with the 64-bit save forms these
    // words presumably hold the 64-bit x87 instruction/data pointers.
    ip: u32,
    cs: u32,
    dp: u32,
    ds: u32,
    // SSE control/status register and its writable-bit mask.
    mxcsr: u32,
    mxcsr_mask: u32,
    // x87/MMX registers, 16 bytes each.
    st_space: [u32; 32],
    // XMM registers, 16 bytes each.
    xmm_space: [u32; 64],
    reserved2: [u32; 12],
    reserved3: [u32; 12],
}
/// Bitmap of state components the CPU can manage with XSAVE, queried via
/// `cpuid::query_xstate_max_features` in `enable_essential_features`.
/// Initialized only when XSAVE is supported. `FpuContext` uses XSAVE/XRSTOR
/// when it is set and FXSAVE/FXRSTOR otherwise.
static XSTATE_MAX_FEATURES: Once<u64> = Once::new();
/// State components saved/restored for user contexts: bits 0-2 and 5-7.
/// Under the Intel SDM XCR0 bit assignments these are x87, SSE, AVX and the
/// three AVX-512 components; MPX (bits 3-4) is excluded.
const XFEATURE_MASK_USER_RESTORE: u64 = 0b1110_0111;
/// Size in bytes of the XSAVE area, queried via
/// `cpuid::query_xsave_area_size` when XSAVE is supported. It is asserted to
/// be at most `MAX_XSAVE_AREA_SIZE`.
static XSAVE_AREA_SIZE: Once<usize> = Once::new();
/// Total size of the fixed-size `XSaveArea` buffer.
const MAX_XSAVE_AREA_SIZE: usize = 4096;
/// Prepares the current CPU for FPU/SIMD use.
///
/// Records the XSAVE feature set and area size when the CPU supports XSAVE,
/// clears CR0.TS and CR0.EM, and resets the x87 unit with `fninit`.
///
/// # Panics
///
/// Panics if the XSAVE CPUID queries fail, or if the reported XSAVE area
/// exceeds `MAX_XSAVE_AREA_SIZE`.
pub(in crate::arch) fn enable_essential_features() {
    use super::extension::{IsaExtensions, has_extensions};

    // Without XSAVE these stay unset, and `FpuContext` falls back to
    // FXSAVE/FXRSTOR.
    if has_extensions(IsaExtensions::XSAVE) {
        XSTATE_MAX_FEATURES.call_once(|| super::cpuid::query_xstate_max_features().unwrap());
        XSAVE_AREA_SIZE.call_once(|| {
            let area_bytes = super::cpuid::query_xsave_area_size().unwrap() as usize;
            // `XSaveArea` is a fixed-size buffer; a larger area would not fit.
            assert!(area_bytes <= MAX_XSAVE_AREA_SIZE);
            area_bytes
        });
    }

    // Clearing TS and EM lets x87/SSE instructions execute natively rather
    // than raising #NM.
    let mut cr0 = Cr0::read();
    cr0.remove(Cr0Flags::EMULATE_COPROCESSOR | Cr0Flags::TASK_SWITCHED);
    // SAFETY: only the TS and EM bits are cleared, which enables FPU
    // instructions; `fninit` just resets x87 state. NOTE(review): assumes this
    // runs during CPU initialization, before any FPU state is live — confirm.
    unsafe {
        Cr0::write(cr0);
        core::arch::asm!("fninit");
    }
}