use core::ops::{Deref, DerefMut};
use ax_memory_addr::VirtAddr;
use x86_64::{
registers::{
control::Cr2,
model_specific::{Efer, EferFlags, KernelGsBase, LStar, SFMask, Star},
rflags::RFlags,
},
structures::idt::ExceptionVector,
};
use super::{
TrapFrame,
asm::{read_thread_pointer, write_thread_pointer},
gdt,
trap::{IRQ_VECTOR_END, IRQ_VECTOR_START, LEGACY_SYSCALL_VECTOR, err_code_to_flags},
};
pub use crate::uspace_common::{ExceptionKind, ReturnReason};
/// User-mode register state saved and restored around each trip into user space.
///
/// The trap-time registers live in the embedded [`TrapFrame`], which is reachable
/// directly on a `UserContext` through its `Deref`/`DerefMut` impls. The FS and GS
/// base addresses are kept alongside it because [`UserContext::run`] switches them
/// itself (via the thread-pointer helpers and the `KERNEL_GS_BASE` MSR) rather than
/// through the trap frame.
///
/// NOTE(review): `#[repr(C)]` fixes the field order, and `run` passes `&mut UserContext`
/// to the `extern "C"` `enter_user` routine. Presumably that assembly depends on this
/// exact layout, so confirm before adding, removing, or reordering fields.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct UserContext {
    /// Trap-time register state: `rip`, `rsp`, `rflags`, `cs`/`ss`, argument registers
    /// such as `rdi`, plus the trap `vector` and `error_code`.
    tf: TrapFrame,
    /// User FS base, i.e. the user thread pointer exposed through `tls`/`set_tls`.
    pub fs_base: u64,
    /// User GS base. `run` writes it to `KERNEL_GS_BASE` before entering user mode
    /// and reads it back from there afterwards.
    pub gs_base: u64,
}
impl UserContext {
pub fn new(entry: usize, ustack_top: VirtAddr, arg0: usize) -> Self {
use x86_64::registers::rflags::RFlags;
Self {
tf: TrapFrame {
rdi: arg0 as _,
rip: entry as _,
cs: gdt::UCODE64.0 as _,
rflags: RFlags::INTERRUPT_FLAG.bits(), rsp: ustack_top.as_usize() as _,
ss: gdt::UDATA.0 as _,
..Default::default()
},
fs_base: 0,
gs_base: 0,
}
}
pub const fn tls(&self) -> usize {
self.fs_base as _
}
pub const fn set_tls(&mut self, tls_area: usize) {
self.fs_base = tls_area as _;
}
pub fn run(&mut self) -> ReturnReason {
extern "C" {
fn enter_user(uctx: &mut UserContext);
}
assert_eq!(self.cs, gdt::UCODE64.0 as _);
assert_eq!(self.ss, gdt::UDATA.0 as _);
crate::asm::disable_irqs();
let kernel_fs_base = read_thread_pointer();
unsafe { write_thread_pointer(self.fs_base as _) };
KernelGsBase::write(x86_64::VirtAddr::new_truncate(self.gs_base));
unsafe { enter_user(self) };
self.gs_base = KernelGsBase::read().as_u64();
self.fs_base = read_thread_pointer() as _;
unsafe { write_thread_pointer(kernel_fs_base) };
let cr2 = Cr2::read().unwrap().as_u64() as usize;
let vector = self.vector as u8;
const PAGE_FAULT_VECTOR: u8 = ExceptionVector::Page as u8;
let ret = match (vector, err_code_to_flags(self.error_code)) {
(PAGE_FAULT_VECTOR, Ok(flags)) => ReturnReason::PageFault(va!(cr2), flags),
(LEGACY_SYSCALL_VECTOR, _) => ReturnReason::Syscall,
(IRQ_VECTOR_START..=IRQ_VECTOR_END, _) => {
crate::trap::irq_handler(vector as _);
ReturnReason::Interrupt
}
_ => ReturnReason::Exception(ExceptionInfo {
vector,
error_code: self.error_code,
cr2,
}),
};
crate::asm::enable_irqs();
ret
}
}
/// Exposes the saved [`TrapFrame`] registers directly on a [`UserContext`]
/// (e.g. `uctx.rip`), so callers do not need to reach through a private field.
impl Deref for UserContext {
    type Target = TrapFrame;

    fn deref(&self) -> &TrapFrame {
        let Self { tf, .. } = self;
        tf
    }
}
impl DerefMut for UserContext {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.tf
}
}
/// Details of a CPU trap taken while running user code.
///
/// `UserContext::run` produces this for traps that are not a recognised page fault,
/// the legacy syscall vector, or an IRQ.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ExceptionInfo {
    /// Interrupt vector number recorded in the trap frame.
    pub vector: u8,
    /// Error code recorded in the trap frame. NOTE(review): presumably zero for
    /// vectors that do not push an error code; confirm in the trap entry asm.
    pub error_code: u64,
    /// Value of CR2 when control returned to the kernel; only meaningful for page faults.
    pub cr2: usize,
}
impl ExceptionInfo {
    /// Classifies the raw vector number into an architecture-neutral [`ExceptionKind`].
    ///
    /// Only `#BP` (breakpoint) and `#UD` (invalid opcode) get dedicated kinds. Every
    /// other vector, including numbers that are not CPU exception vectors at all,
    /// maps to [`ExceptionKind::Other`].
    pub fn kind(&self) -> ExceptionKind {
        let Ok(exception) = ExceptionVector::try_from(self.vector) else {
            return ExceptionKind::Other;
        };
        match exception {
            ExceptionVector::Breakpoint => ExceptionKind::Breakpoint,
            ExceptionVector::InvalidOpcode => ExceptionKind::IllegalInstruction,
            _ => ExceptionKind::Other,
        }
    }
}
/// Configures the `syscall`/`sysret` fast system-call path on the current CPU.
///
/// - `LSTAR` is pointed at the assembly `syscall_entry` stub.
/// - `STAR` receives the user and kernel segment selectors that the CPU loads on
///   `sysret` and `syscall`.
/// - `SFMASK` lists the RFLAGS bits cleared on `syscall` entry: TF, IF, DF, IOPL, NT
///   and AC. The entry stub therefore starts with interrupts off, no single-stepping,
///   a forward string direction, and no user-controlled I/O privilege or
///   alignment checking.
/// - `EFER.SCE` is set last, enabling the `syscall` instruction only once the
///   other MSRs are programmed.
///
/// NOTE(review): these MSRs are per-CPU, so presumably this runs once on every CPU
/// during trap initialisation. Confirm at the caller.
///
/// # Panics
///
/// Panics if the `gdt` selectors do not satisfy the layout the `STAR` MSR requires.
pub(super) fn init_syscall() {
    extern "C" {
        fn syscall_entry();
    }

    LStar::write(x86_64::VirtAddr::new_truncate(
        syscall_entry as *const () as usize as _,
    ));
    Star::write(gdt::UCODE64, gdt::UDATA, gdt::KCODE64, gdt::KDATA)
        .expect("GDT selectors do not match the layout required by the STAR MSR");
    SFMask::write(
        RFlags::TRAP_FLAG
            | RFlags::INTERRUPT_FLAG
            | RFlags::DIRECTION_FLAG
            | RFlags::IOPL_LOW
            | RFlags::IOPL_HIGH
            | RFlags::NESTED_TASK
            | RFlags::ALIGNMENT_CHECK,
    );

    // SAFETY: setting SCE only enables `syscall`/`sysret`. LSTAR, STAR and SFMASK were
    // all programmed above, so a `syscall` from now on enters `syscall_entry` with
    // the intended segments and masked flags.
    unsafe {
        Efer::update(|efer| *efer |= EferFlags::SYSTEM_CALL_EXTENSIONS);
    }
}