ostd 0.17.2

Rust OS framework that facilitates the development of and innovation in OS kernels
Documentation
// SPDX-License-Identifier: MPL-2.0 OR MIT
//
// The original source code is from [trapframe-rs](https://github.com/rcore-os/trapframe-rs),
// which is released under the following license:
//
// SPDX-License-Identifier: MIT
//
// Copyright (c) 2020 - 2024 Runji Wang
//
// We make the following new changes:
// * Implement the `trap_handler` of Asterinas.
//
// These changes are released under the following license:
//
// SPDX-License-Identifier: MPL-2.0

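//! Handles traps on x86_64: CPU-local GDT/TSS setup, the IDT, `syscall`
//! enabling, and the handler for traps taken in kernel mode.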

pub(super) mod gdt;
mod idt;
mod syscall;

use cfg_if::cfg_if;
use spin::Once;

use super::cpu::context::GeneralRegs;
use crate::{
    arch::{
        cpu::context::CpuException,
        irq::{HwIrqLine, disable_local, enable_local},
    },
    cpu::PrivilegeLevel,
    ex_table::ExTable,
    irq::call_irq_callback_functions,
    mm::MAX_USERSPACE_VADDR,
};

cfg_if! {
    if #[cfg(feature = "cvm_guest")] {
        use tdx_guest::{tdcall, handle_virtual_exception};
        use crate::arch::tdx_guest::TrapFrameWrapper;
    }
}

/// Trap frame of kernel interrupt
///
/// Register state saved when a trap is taken while running in the kernel.
/// The struct is `#[repr(C)]`. The first group of fields is pushed by
/// `trap.S`, and the last group (`rip`, `cs`, `rflags`) by the CPU itself.
/// The field order is therefore tied to that stack layout.
///
/// # Trap handler
///
/// The trap entry code is expected to pass this frame to an unmangled
/// `extern "sysv64" fn trap_handler`, and this module defines that function.
/// A minimal standalone handler would look like this:
///
/// ```
/// // SAFETY: The name does not collide with other symbols.
/// #[unsafe(no_mangle)]
/// extern "sysv64" fn trap_handler(tf: &mut TrapFrame) {
///     match tf.trap_num {
///         3 => {
///             // `int3` is a trap: the saved `rip` already points to the
///             // instruction after it, so no adjustment is needed to resume.
///             println!("TRAP: BreakPoint");
///         }
///         _ => panic!("TRAP: {:#x?}", tf),
///     }
/// }
/// ```
#[expect(missing_docs)]
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
pub struct TrapFrame {
    // Pushed by 'trap.S'
    pub rax: usize,
    pub rbx: usize,
    pub rcx: usize,
    pub rdx: usize,
    pub rsi: usize,
    pub rdi: usize,
    pub rbp: usize,
    pub rsp: usize,
    pub r8: usize,
    pub r9: usize,
    pub r10: usize,
    pub r11: usize,
    pub r12: usize,
    pub r13: usize,
    pub r14: usize,
    pub r15: usize,
    // NOTE(review): presumably padding required by the stack layout that
    // `trap.S` builds; confirm there before changing.
    pub _pad: usize,

    // Trap vector number. `trap_handler` decodes it with `CpuException::new`,
    // or treats it as a hardware IRQ line when it is not a CPU exception.
    pub trap_num: usize,
    // Exception error code, passed to `CpuException::new` together with
    // `trap_num` (presumably 0 for vectors without one; confirm in `trap.S`).
    pub error_code: usize,

    // Pushed by CPU
    // `rip` is the resume address; `handle_user_page_fault` may rewrite it to
    // an exception-table recovery point.
    pub rip: usize,
    pub cs: usize,
    // `trap_handler` reads the IF bit of `rflags` to restore the pre-trap IRQ state.
    pub rflags: usize,
}

/// Initializes interrupt handling on x86_64.
///
/// This function will:
/// - Switch to a new, CPU-local [GDT].
/// - Switch to a new, CPU-local [TSS].
/// - Switch to a new, global [IDT].
/// - Enable the [`syscall`] instruction.
///
/// The steps are order-dependent. Enabling `syscall` requires the GDT to
/// have been installed first.
///
/// [GDT]: https://wiki.osdev.org/GDT
/// [IDT]: https://wiki.osdev.org/IDT
/// [TSS]: https://wiki.osdev.org/Task_State_Segment
/// [`syscall`]: https://www.felixcloutier.com/x86/syscall
///
/// # Safety
///
/// On the current CPU, this function must be called
/// - only once and
/// - before any trap can occur.
pub(crate) unsafe fn init_on_cpu() {
    // SAFETY: The caller guarantees that no trap can occur yet on this CPU, so
    // no preemption can happen while the GDT and TSS are being switched.
    unsafe { gdt::init_on_cpu() };

    // Load the global IDT on the current CPU.
    idt::init_on_cpu();

    // SAFETY: `gdt::init_on_cpu` has been called above, which is the
    // precondition for setting up `syscall`.
    unsafe { syscall::init_on_cpu() };
}

/// Userspace context.
///
/// Raw register state of a user-space context: the general-purpose registers
/// plus the trap number and error code. The struct is `#[repr(C)]`, so its
/// field order is part of its layout. That layout is presumably relied on by
/// code outside this view; confirm before reordering fields.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(super) struct RawUserContext {
    /// General-purpose registers.
    pub(super) general: GeneralRegs,
    /// Trap vector number (presumably recorded when control left user space).
    pub(super) trap_num: usize,
    /// Error code accompanying `trap_num`, if the trap has one.
    pub(super) error_code: usize,
}

/// Handle traps (only from kernel).
///
/// The function is exported unmangled as `trap_handler` with the `sysv64` ABI,
/// so the trap entry assembly (presumably `trap.S`) can call it with the saved
/// [`TrapFrame`].
///
/// It dispatches on [`CpuException::new`]`(f.trap_num, f.error_code)`:
/// - `#VE` (only with the `cvm_guest` feature): forwarded to the TDX guest
///   `#VE` handler.
/// - Page fault on an address in `0..MAX_USERSPACE_VADDR`: handed to
///   `handle_user_page_fault`, which may redirect `f.rip` to an
///   exception-table recovery point.
/// - Any other CPU exception, including page faults on kernel addresses:
///   panics.
/// - Not a CPU exception: treated as a hardware IRQ and passed to
///   [`call_irq_callback_functions`].
///
/// While an exception is handled, local IRQs are enabled only if they were
/// enabled when the trap occurred (the IF bit of the saved `rflags`). They are
/// disabled again before returning.
// SAFETY: The name does not collide with other symbols.
#[unsafe(no_mangle)]
unsafe extern "sysv64" fn trap_handler(f: &mut TrapFrame) {
    /// Enables local IRQs if `cond` is true; otherwise does nothing.
    fn enable_local_if(cond: bool) {
        if cond {
            enable_local();
        }
    }

    /// Disables local IRQs if `cond` is true; otherwise does nothing.
    fn disable_local_if(cond: bool) {
        if cond {
            disable_local();
        }
    }

    // The IRQ state before trapping. We need to ensure that the IRQ state
    // during exception handling is consistent with the state before the trap.
    // In Rust, `&` binds tighter than `>`, so this tests the IF bit of the
    // saved `rflags`.
    let was_irq_enabled =
        f.rflags as u64 & x86_64::registers::rflags::RFlags::INTERRUPT_FLAG.bits() > 0;

    // `None` means the vector is not a CPU exception; it is handled as an IRQ below.
    let cpu_exception = CpuException::new(f.trap_num, f.error_code);
    match cpu_exception {
        #[cfg(feature = "cvm_guest")]
        Some(CpuException::VirtualizationException) => {
            let ve_info = tdcall::get_veinfo().expect("#VE handler: fail to get VE info\n");
            // We need to enable interrupts only after `tdcall::get_veinfo` is called
            // to avoid nested `#VE`s.
            enable_local_if(was_irq_enabled);
            let mut trapframe_wrapper = TrapFrameWrapper(&mut *f);
            handle_virtual_exception(&mut trapframe_wrapper, &ve_info);
            // NOTE(review): `trapframe_wrapper.0` reborrows `f`, so this looks like
            // a self-assignment. Confirm whether it is still needed.
            *f = *trapframe_wrapper.0;
            disable_local_if(was_irq_enabled);
        }
        Some(CpuException::PageFault(raw_page_fault_info)) => {
            enable_local_if(was_irq_enabled);
            // The actual user space implementation should be responsible
            // for providing mechanism to treat the 0 virtual address.
            if (0..MAX_USERSPACE_VADDR).contains(&raw_page_fault_info.addr) {
                // `cpu_exception` is `Some` in this arm, so `unwrap` cannot fail.
                handle_user_page_fault(f, cpu_exception.as_ref().unwrap());
            } else {
                panic!(
                    "Cannot handle kernel page fault: {:#x?}; trapframe: {:#x?}",
                    raw_page_fault_info, f
                );
            }
            disable_local_if(was_irq_enabled);
        }
        Some(exception) => {
            // Restore the pre-trap IRQ state before panicking.
            enable_local_if(was_irq_enabled);
            panic!(
                "Cannot handle kernel CPU exception: {:#x?}; trapframe: {:#x?}",
                exception, f
            );
        }
        None => {
            // Not a CPU exception: dispatch as a hardware IRQ. The IRQ state is not
            // changed here. `trap_num` is presumably an IDT vector (< 256), so the
            // `u8` cast should not truncate.
            call_irq_callback_functions(
                f,
                &HwIrqLine::new(f.trap_num as u8),
                PrivilegeLevel::Kernel,
            );
        }
    }
}

/// The handler installed by [`inject_user_page_fault_handler`].
///
/// `handle_user_page_fault` consults it for page faults taken in kernel mode
/// on user-space addresses. `Ok(())` means the fault was handled and the
/// faulting access can simply be retried. `Err(())` makes the caller fall back
/// to the exception table.
#[expect(clippy::type_complexity)]
static USER_PAGE_FAULT_HANDLER: Once<fn(&CpuException) -> core::result::Result<(), ()>> =
    Once::new();

/// Injects a custom handler for page faults that occur in the kernel while
/// accessing a user-space address.
///
/// The handler receives the [`CpuException`] describing the fault:
/// - Returning `Ok(())` reports that the fault was resolved, so the faulting
///   access is retried.
/// - Returning `Err(())` makes the kernel try to recover through the exception
///   table, and panic if no recovery point is registered.
///
/// The handler is stored in a write-once [`Once`]. Only the first injection
/// takes effect; later calls are silently ignored.
pub fn inject_user_page_fault_handler(
    handler: fn(info: &CpuException) -> core::result::Result<(), ()>,
) {
    // The closure runs only if no handler has been installed yet.
    let _ = USER_PAGE_FAULT_HANDLER.call_once(move || handler);
}

/// Handles a kernel-mode page fault on a user-space address.
///
/// The injected handler is asked to resolve the fault first. On success, the
/// trap returns and the faulting instruction runs again; for instance,
/// byte-by-byte user copies simply retry the access. Otherwise, the saved `rip`
/// is redirected to the recovery instruction the exception table records for it.
///
/// # Panics
///
/// Panics in either of these cases:
/// - No handler has been injected via [`inject_user_page_fault_handler`].
/// - The handler fails and the faulting instruction has no exception-table entry.
fn handle_user_page_fault(f: &mut TrapFrame, exception: &CpuException) {
    let handler = USER_PAGE_FAULT_HANDLER
        .get()
        .expect("a page fault handler is missing");

    if handler(exception).is_ok() {
        return;
    }

    // The handler could not resolve the fault. Resume at the registered
    // recovery point instead, if the faulting instruction has one.
    let Some(recovery_addr) = ExTable::find_recovery_inst_addr(f.rip) else {
        panic!("Cannot handle user page fault; trapframe: {:#x?}", f);
    };
    f.rip = recovery_addr;
}