use bit_field::BitField;
use crate::{
instructions::segmentation::{Segment, CS},
structures::paging::{
page::{NotGiantPageSize, PageRange},
Page, PageSize, Size2MiB, Size4KiB,
},
PrivilegeLevel, VirtAddr,
};
use core::{arch::asm, cmp, convert::TryFrom, fmt};
/// Invalidates the TLB entry for `addr` on the executing processor using the
/// `invlpg` instruction.
#[inline]
pub fn flush(addr: VirtAddr) {
    let target = addr.as_u64();
    // SAFETY: `invlpg` only drops a cached translation; it neither uses the
    // stack nor modifies RFLAGS, matching the options below.
    unsafe {
        asm!("invlpg [{}]", in(reg) target, options(nostack, preserves_flags));
    }
}
/// Flushes the TLB by writing the current CR3 contents straight back into CR3.
///
/// NOTE(review): architecturally a CR3 reload is not expected to drop
/// global-page translations — confirm callers do not rely on that.
#[inline]
pub fn flush_all() {
    use crate::registers::control::Cr3;

    let (root_table, cr3_flags) = Cr3::read();
    // SAFETY: the value written is exactly the one just read, so the active
    // address space does not change; only the reload side effect is wanted.
    unsafe { Cr3::write(root_table, cr3_flags) }
}
/// Selects which TLB entries [`flush_pcid`] invalidates.
///
/// Each variant maps to one `invpcid` type operand (0 through 3, in
/// declaration order).
#[derive(Debug)]
pub enum InvPcidCommand {
    /// Type 0: invalidate the mapping for one linear address tagged with the
    /// given PCID.
    Address(VirtAddr, Pcid),
    /// Type 1: invalidate the mappings tagged with the given PCID.
    Single(Pcid),
    /// Type 2: invalidate mappings for all PCIDs.
    All,
    /// Type 3: invalidate mappings for all PCIDs, excluding global ones.
    AllExceptGlobal,
}
/// Misspelled former name of [`InvPcidCommand`], kept only for backward
/// compatibility.
#[deprecated = "please use `InvPcidCommand` instead"]
#[doc(hidden)]
pub type InvPicdCommand = InvPcidCommand;
/// In-memory operand for `invpcid`: a 64-bit PCID slot followed by a 64-bit
/// linear address.
///
/// `repr(C)` keeps the field order and layout fixed; do not reorder fields.
#[repr(C)]
#[derive(Debug)]
struct InvpcidDescriptor {
    // PCID value, zero-extended to 64 bits.
    pcid: u64,
    // Linear address; only filled in for the single-address command.
    address: u64,
}
/// A process-context identifier.
///
/// PCIDs are 12 bits wide, so only values in `0..4096` are accepted by
/// [`Pcid::new`].
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Pcid(u16);

impl Pcid {
    /// Exclusive upper bound for a PCID value.
    const LIMIT: u16 = 4096;

    /// Wraps `pcid` as a [`Pcid`].
    ///
    /// # Errors
    ///
    /// Returns [`PcidTooBig`] carrying the rejected value when `pcid >= 4096`.
    pub const fn new(pcid: u16) -> Result<Pcid, PcidTooBig> {
        if pcid < Self::LIMIT {
            Ok(Pcid(pcid))
        } else {
            Err(PcidTooBig(pcid))
        }
    }

    /// Returns the raw PCID value (always below 4096).
    pub const fn value(&self) -> u16 {
        let Pcid(raw) = *self;
        raw
    }
}

/// Error returned by [`Pcid::new`] when the value does not fit in 12 bits
/// (i.e. it is 4096 or larger).
#[derive(Debug)]
pub struct PcidTooBig(u16);

impl fmt::Display for PcidTooBig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let PcidTooBig(rejected) = self;
        write!(f, "PCID should be < 4096, got {}", rejected)
    }
}
#[inline]
pub unsafe fn flush_pcid(command: InvPcidCommand) {
let mut desc = InvpcidDescriptor {
pcid: 0,
address: 0,
};
let kind: u64;
match command {
InvPcidCommand::Address(addr, pcid) => {
kind = 0;
desc.pcid = pcid.value().into();
desc.address = addr.as_u64()
}
InvPcidCommand::Single(pcid) => {
kind = 1;
desc.pcid = pcid.0.into()
}
InvPcidCommand::All => kind = 2,
InvPcidCommand::AllExceptGlobal => kind = 3,
}
unsafe {
asm!("invpcid {0}, [{1}]", in(reg) kind, in(reg) &desc, options(nostack, preserves_flags));
}
}
/// Handle proving that `invlpgb`/`tlbsync` are available. It records the
/// processor's limits for broadcasting TLB invalidations to other logical
/// processors.
///
/// Obtain one with [`Invlpgb::new`], which returns `None` when CPUID does not
/// advertise the instructions.
#[derive(Debug, Clone, Copy)]
pub struct Invlpgb {
    // Raw CPUID 0x8000_0008 EDX[15:0]; per the AMD APM, the largest value
    // accepted in ECX[15:0] of a single `invlpgb`.
    invlpgb_count_max: u16,
    // CPUID 0x8000_0008 EBX bit 21; gates `include_nested_translations`.
    tlb_flush_nested: bool,
    // CPUID 0x8000_000A EBX; number of ASIDs, used to validate `asid()`.
    nasid: u32,
}
impl Invlpgb {
    /// Checks that `invlpgb` and `tlbsync` are supported and queries their
    /// limits.
    ///
    /// Returns `None` if CPUID leaf `0x8000_0008` does not set EBX bit 3.
    ///
    /// # Panics
    ///
    /// Panics if the requested privilege level of the current code segment is
    /// not ring 0.
    pub fn new() -> Option<Self> {
        let cs = CS::get_reg();
        assert_eq!(cs.rpl(), PrivilegeLevel::Ring0);

        // Extended feature leaf: EBX bit 3 = INVLPGB/TLBSYNC support,
        // EBX bit 21 = nested-translation flushing, EDX[15:0] = max count.
        #[allow(unused_unsafe)]
        let cpuid = unsafe { core::arch::x86_64::__cpuid(0x8000_0008) };
        if !cpuid.ebx.get_bit(3) {
            return None;
        }

        let tlb_flush_nested = cpuid.ebx.get_bit(21);
        // Lossless: a 16-bit field always fits in `u16`.
        let invlpgb_count_max = cpuid.edx.get_bits(0..=15) as u16;

        // SVM leaf: EBX reports the number of available ASIDs.
        #[allow(unused_unsafe)]
        let cpuid = unsafe { core::arch::x86_64::__cpuid(0x8000_000a) };
        let nasid = cpuid.ebx;

        Some(Self {
            tlb_flush_nested,
            invlpgb_count_max,
            nasid,
        })
    }

    /// Returns the raw maximum page count reported by CPUID
    /// (`0x8000_0008` EDX[15:0]).
    #[inline]
    pub fn invlpgb_count_max(&self) -> u16 {
        self.invlpgb_count_max
    }

    /// Returns whether the processor supports flushing nested translations.
    #[inline]
    pub fn tlb_flush_nested(&self) -> bool {
        self.tlb_flush_nested
    }

    /// Returns the number of available address space identifiers.
    #[inline]
    pub fn nasid(&self) -> u32 {
        self.nasid
    }

    /// Starts building an `invlpgb` flush with no filters applied.
    pub fn build(&self) -> InvlpgbFlushBuilder<'_> {
        InvlpgbFlushBuilder {
            invlpgb: self,
            page_range: None,
            pcid: None,
            asid: None,
            include_global: false,
            final_translation_only: false,
            include_nested_translations: false,
        }
    }

    /// Executes `tlbsync`, waiting for earlier `invlpgb` instructions issued by
    /// this logical processor to complete.
    #[inline]
    pub fn tlbsync(&self) {
        // SAFETY: support for `tlbsync` and ring 0 were checked in `new`.
        //
        // `nomem` is deliberately NOT specified. `tlbsync` is a synchronization
        // point, and memory accesses that follow it (e.g. reusing a page whose
        // mapping was just invalidated) must not be reordered before it by the
        // compiler. Omitting `nomem` makes this asm a compiler memory barrier,
        // like the other asm blocks in this module.
        unsafe {
            asm!("tlbsync", options(nostack, preserves_flags));
        }
    }
}
/// Builder collecting the filters for a broadcast `invlpgb` flush.
///
/// Created by [`Invlpgb::build`]; the flush is issued by
/// [`InvlpgbFlushBuilder::flush`]. `S` is the page size of the optional page
/// range (4 KiB by default, 2 MiB also allowed).
#[derive(Debug, Clone)]
#[must_use]
pub struct InvlpgbFlushBuilder<'a, S = Size4KiB>
where
    S: NotGiantPageSize,
{
    // Capabilities/limits of the processor this flush targets.
    invlpgb: &'a Invlpgb,
    // `None` means the flush is not restricted to particular addresses.
    page_range: Option<PageRange<S>>,
    // Optional PCID filter.
    pcid: Option<Pcid>,
    // Optional ASID filter; validated against `Invlpgb::nasid`.
    asid: Option<u16>,
    include_global: bool,
    final_translation_only: bool,
    include_nested_translations: bool,
}
impl<'a, S> InvlpgbFlushBuilder<'a, S>
where
    S: NotGiantPageSize,
{
    /// Restricts the flush to the pages in `page_range`.
    ///
    /// The builder's page size becomes `T`. If the range is larger than a
    /// single `invlpgb` can cover, [`flush`](Self::flush) issues several.
    pub fn pages<T>(self, page_range: PageRange<T>) -> InvlpgbFlushBuilder<'a, T>
    where
        T: NotGiantPageSize,
    {
        InvlpgbFlushBuilder {
            invlpgb: self.invlpgb,
            page_range: Some(page_range),
            pcid: self.pcid,
            asid: self.asid,
            include_global: self.include_global,
            final_translation_only: self.final_translation_only,
            include_nested_translations: self.include_nested_translations,
        }
    }

    /// Restricts the flush to TLB entries tagged with `pcid`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller must ensure PCIDs are in use when the flush
    /// executes (presumably CR4.PCIDE = 1) — confirm against the AMD APM.
    pub unsafe fn pcid(&mut self, pcid: Pcid) -> &mut Self {
        self.pcid = Some(pcid);
        self
    }

    /// Restricts the flush to TLB entries tagged with `asid`.
    ///
    /// # Errors
    ///
    /// Returns [`AsidOutOfRangeError`] if `asid` is not below
    /// [`Invlpgb::nasid`]. The builder is left unchanged in that case.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller must ensure ASID tagging is active when the
    /// flush executes (presumably EFER.SVME = 1) — confirm against the AMD APM.
    pub unsafe fn asid(&mut self, asid: u16) -> Result<&mut Self, AsidOutOfRangeError> {
        if u32::from(asid) >= self.invlpgb.nasid {
            return Err(AsidOutOfRangeError {
                asid,
                nasid: self.invlpgb.nasid,
            });
        }
        self.asid = Some(asid);
        Ok(self)
    }

    /// Also flushes global translations (RAX bit 3 of `invlpgb`).
    pub fn include_global(&mut self) -> &mut Self {
        self.include_global = true;
        self
    }

    /// Requests that only final translations be flushed (RAX bit 4 of
    /// `invlpgb`).
    pub fn final_translation_only(&mut self) -> &mut Self {
        self.final_translation_only = true;
        self
    }

    /// Also flushes nested translations (RAX bit 5 of `invlpgb`).
    ///
    /// # Panics
    ///
    /// Panics if the processor does not report support for it
    /// (see [`Invlpgb::tlb_flush_nested`]).
    pub fn include_nested_translations(mut self) -> Self {
        assert!(
            self.invlpgb.tlb_flush_nested,
            "flushing all nested translations is not supported"
        );
        self.include_nested_translations = true;
        self
    }

    /// Issues the flush.
    ///
    /// Without a page range, one `invlpgb` is issued that is not restricted to
    /// any virtual address. With a page range, the range is split into chunks.
    /// Each chunk (a) stays entirely below or entirely above the start of the
    /// upper canonical half and (b) never exceeds the per-instruction page limit
    /// reported by the processor. One `invlpgb` is issued per chunk, and each
    /// page of the range is flushed exactly once.
    pub fn flush(&self) {
        if let Some(mut pages) = self.page_range {
            // ECX[15:0] of `invlpgb` holds "number of pages minus one": a count
            // of 0 still flushes one page. `invlpgb_count_max` is the largest
            // value the processor accepts there, so one instruction covers at
            // most `invlpgb_count_max + 1` pages (at most 65536).
            let max_pages_per_flush = usize::from(self.invlpgb.invlpgb_count_max) + 1;

            // First page of the upper canonical half; a single flush must never
            // reach across the non-canonical hole below it.
            let second_half_start =
                Page::<S>::containing_address(VirtAddr::new(0xffff_8000_0000_0000));

            while !pages.is_empty() {
                let remaining = Page::<S>::steps_between_impl(&pages.start, &pages.end).0;
                let remaining = if pages.start < second_half_start {
                    let to_second_half =
                        Page::<S>::steps_between_impl(&pages.start, &second_half_start).0;
                    cmp::min(remaining, to_second_half)
                } else {
                    remaining
                };

                // Pages covered by this iteration; at least one so the loop
                // always makes progress.
                let chunk = cmp::max(cmp::min(remaining, max_pages_per_flush), 1);
                let additional_pages = u16::try_from(chunk - 1)
                    .expect("chunk is at most invlpgb_count_max + 1");

                // SAFETY: `Invlpgb::new` verified `invlpgb` support and ring 0;
                // PCID/ASID filters can only be set through `unsafe` setters whose
                // callers accepted their preconditions.
                unsafe {
                    flush_broadcast(
                        Some((pages.start, additional_pages)),
                        self.pcid,
                        self.asid,
                        self.include_global,
                        self.final_translation_only,
                        self.include_nested_translations,
                    );
                }

                // Advance past exactly the pages that were just flushed.
                pages.start = Page::forward_checked_impl(pages.start, chunk)
                    .expect("advancing within the page range cannot overflow");
            }
        } else {
            // SAFETY: `Invlpgb::new` verified `invlpgb` support and ring 0;
            // PCID/ASID filters can only be set through `unsafe` setters whose
            // callers accepted their preconditions.
            unsafe {
                flush_broadcast::<S>(
                    None,
                    self.pcid,
                    self.asid,
                    self.include_global,
                    self.final_translation_only,
                    self.include_nested_translations,
                );
            }
        }
    }
}
/// Error returned by [`InvlpgbFlushBuilder::asid`] when the requested ASID is
/// not below the number of ASIDs the processor supports.
#[derive(Debug)]
pub struct AsidOutOfRangeError {
    /// The ASID that was rejected.
    pub asid: u16,
    /// The number of ASIDs reported by the processor.
    pub nasid: u32,
}

impl fmt::Display for AsidOutOfRangeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { asid, nasid } = self;
        write!(
            f,
            "{} is out of the range of available ASIDS ({})",
            asid, nasid
        )
    }
}
/// Issues one `invlpgb` whose RAX/ECX/EDX operands encode the given filters.
///
/// * `va_and_count`: when `Some((page, count))`, RAX bit 0 is set, RAX[63:12]
///   takes the page's start address and ECX[15:0] takes `count`. ECX bit 31
///   selects a 2 MiB stride when `S` is `Size2MiB`. Per the AMD APM, the
///   hardware invalidates `count + 1` pages.
/// * `pcid`: sets RAX bit 1 and EDX[27:16].
/// * `asid`: sets RAX bit 2 and EDX[15:0].
/// * `include_global`, `final_translation_only`, `include_nested_translations`
///   map to RAX bits 3, 4 and 5.
///
/// # Safety
///
/// The caller must ensure `invlpgb` is supported and may be executed (see
/// `Invlpgb::new`). NOTE(review): any further PCID/ASID preconditions should
/// be confirmed against the AMD APM.
#[inline]
unsafe fn flush_broadcast<S>(
    va_and_count: Option<(Page<S>, u16)>,
    pcid: Option<Pcid>,
    asid: Option<u16>,
    include_global: bool,
    final_translation_only: bool,
    include_nested_translations: bool,
) where
    S: NotGiantPageSize,
{
    // RAX: "filter valid" bits and flags in the low bits, page address above.
    let mut rax: u64 = 0;
    // ECX: page count and stride.
    let mut ecx: u32 = 0;
    // EDX: PCID (high half) and ASID (low half).
    let mut edx: u32 = 0;

    if let Some((first_page, count)) = va_and_count {
        let page_bits = first_page.start_address().as_u64().get_bits(12..);
        let two_mib_stride = S::SIZE == Size2MiB::SIZE;
        rax.set_bit(0, true).set_bits(12.., page_bits);
        ecx.set_bits(0..=15, u32::from(count))
            .set_bit(31, two_mib_stride);
    }

    if let Some(pcid) = pcid {
        rax.set_bit(1, true);
        edx.set_bits(16..=27, u32::from(pcid.value()));
    }

    if let Some(asid) = asid {
        rax.set_bit(2, true);
        edx.set_bits(0..=15, u32::from(asid));
    }

    rax.set_bit(3, include_global)
        .set_bit(4, final_translation_only)
        .set_bit(5, include_nested_translations);

    // SAFETY: upheld by the caller per this function's contract; `invlpgb`
    // neither uses the stack nor modifies RFLAGS.
    unsafe {
        asm!(
            "invlpgb",
            in("rax") rax,
            in("ecx") ecx,
            in("edx") edx,
            options(nostack, preserves_flags),
        );
    }
}