pub use crate::registers::segmentation::SegmentSelector;
use crate::structures::tss::{InvalidIoMap, TaskStateSegment};
use crate::PrivilegeLevel;
use bit_field::BitField;
use bitflags::bitflags;
use core::{cmp, fmt, mem};
#[cfg(doc)]
use crate::registers::segmentation::{Segment, CS, SS};
#[cfg(all(feature = "instructions", target_arch = "x86_64"))]
use core::sync::atomic::{AtomicU64 as EntryValue, Ordering};
#[cfg(not(all(feature = "instructions", target_arch = "x86_64")))]
use u64 as EntryValue;
/// A single 64-bit slot of a `GlobalDescriptorTable`.
///
/// The wrapped `EntryValue` is chosen by the imports at the top of this file: it is
/// `AtomicU64` when the `instructions` feature is enabled on `x86_64`, and a plain `u64`
/// otherwise. Use `raw` to read the value in either configuration.
// `repr(transparent)` gives `Entry` the same layout as the wrapped `EntryValue`.
#[repr(transparent)]
pub struct Entry(EntryValue);
impl Entry {
    /// Wraps the raw 64-bit descriptor value `raw` in an `Entry`.
    const fn new(raw: u64) -> Self {
        // Exactly one of the two blocks below survives `cfg` stripping and becomes the
        // function's tail expression.
        #[cfg(all(feature = "instructions", target_arch = "x86_64"))]
        {
            // The backing store is atomic in this configuration and must be constructed.
            Self(EntryValue::new(raw))
        }
        #[cfg(not(all(feature = "instructions", target_arch = "x86_64")))]
        {
            // The backing store is a plain `u64`.
            Self(raw)
        }
    }

    /// Returns the raw 64-bit value stored in this entry.
    ///
    /// When the backing store is atomic, the value is read with a sequentially
    /// consistent load.
    pub fn raw(&self) -> u64 {
        #[cfg(all(feature = "instructions", target_arch = "x86_64"))]
        {
            self.0.load(Ordering::SeqCst)
        }
        #[cfg(not(all(feature = "instructions", target_arch = "x86_64")))]
        {
            self.0
        }
    }
}
impl Clone for Entry {
fn clone(&self) -> Self {
Self::new(self.raw())
}
}
impl PartialEq for Entry {
    /// Two entries are equal exactly when their raw 64-bit values are equal.
    fn eq(&self, other: &Self) -> bool {
        // Read both sides first (left, then right) and compare the plain integers.
        let (lhs, rhs) = (self.raw(), other.raw());
        lhs == rhs
    }
}
// `PartialEq` for `Entry` compares the raw `u64` values, which is a total equivalence
// relation, so the `Eq` marker is sound.
impl Eq for Entry {}
impl fmt::Debug for Entry {
    /// Formats the entry as `Entry(0x…)`, with the raw value zero-padded to 16 hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bits = self.raw();
        // Width 18 = the `0x` prefix plus 16 hexadecimal digits.
        f.write_fmt(format_args!("Entry({:#018x})", bits))
    }
}
/// A Global Descriptor Table with room for up to `MAX` 64-bit entries (8 by default).
///
/// Only the first `len` slots of the backing array are part of the table; `entries`
/// returns exactly that prefix and `limit` is computed from it.
#[derive(Debug, Clone)]
pub struct GlobalDescriptorTable<const MAX: usize = 8> {
// Backing storage. Slots at index `len` and above are not part of the table.
table: [Entry; MAX],
// Number of slots currently in use. A table built by `empty` starts at 1, and
// `from_raw_entries` requires at least one entry.
len: usize,
}
impl GlobalDescriptorTable {
/// Creates an empty GDT with the default capacity (`MAX = 8`).
///
/// This simply delegates to `empty`, so the new table has a single used slot.
pub const fn new() -> Self {
Self::empty()
}
}
impl Default for GlobalDescriptorTable {
/// Returns the same table as `GlobalDescriptorTable::new()`.
#[inline]
fn default() -> Self {
Self::new()
}
}
impl<const MAX: usize> GlobalDescriptorTable<MAX> {
#[inline]
pub const fn empty() -> Self {
assert!(MAX > 0, "A GDT cannot have 0 entries");
assert!(MAX <= (1 << 13), "A GDT can only have at most 2^13 entries");
#[allow(clippy::declare_interior_mutable_const)]
const NULL: Entry = Entry::new(0);
Self {
table: [NULL; MAX],
len: 1,
}
}
#[cfg_attr(
not(all(feature = "instructions", target_arch = "x86_64")),
allow(rustdoc::broken_intra_doc_links)
)]
#[inline]
pub const fn from_raw_entries(slice: &[u64]) -> Self {
let len = slice.len();
let mut table = Self::empty().table;
let mut idx = 0;
assert!(len > 0, "cannot initialize GDT with empty slice");
assert!(slice[0] == 0, "first GDT entry must be zero");
assert!(
len <= MAX,
"cannot initialize GDT with slice exceeding the maximum length"
);
while idx < len {
table[idx] = Entry::new(slice[idx]);
idx += 1;
}
Self { table, len }
}
#[inline]
pub fn entries(&self) -> &[Entry] {
&self.table[..self.len]
}
#[inline]
#[rustversion::attr(since(1.83), const)]
pub fn append(&mut self, entry: Descriptor) -> SegmentSelector {
let index = match entry {
Descriptor::UserSegment(value) => {
if self.len > self.table.len().saturating_sub(1) {
panic!("GDT full")
}
self.push(value)
}
Descriptor::SystemSegment(value_low, value_high) => {
if self.len > self.table.len().saturating_sub(2) {
panic!("GDT requires two free spaces to hold a SystemSegment")
}
let index = self.push(value_low);
self.push(value_high);
index
}
};
SegmentSelector::new(index as u16, entry.dpl())
}
#[cfg(all(feature = "instructions", target_arch = "x86_64"))]
#[inline]
pub fn load(&'static self) {
unsafe { self.load_unsafe() };
}
#[cfg(all(feature = "instructions", target_arch = "x86_64"))]
#[inline]
pub unsafe fn load_unsafe(&self) {
use crate::instructions::tables::lgdt;
unsafe {
lgdt(&self.pointer());
}
}
#[inline]
#[rustversion::attr(since(1.83), const)]
fn push(&mut self, value: u64) -> usize {
let index = self.len;
self.table[index] = Entry::new(value);
self.len += 1;
index
}
pub const fn limit(&self) -> u16 {
use core::mem::size_of;
(self.len * size_of::<u64>() - 1) as u16
}
#[cfg(all(feature = "instructions", target_arch = "x86_64"))]
fn pointer(&self) -> super::DescriptorTablePointer {
super::DescriptorTablePointer {
base: crate::VirtAddr::new(self.table.as_ptr() as u64),
limit: self.limit(),
}
}
}
/// A descriptor that can be appended to a `GlobalDescriptorTable`.
///
/// `GlobalDescriptorTable::append` stores a `UserSegment` in one 64-bit slot and a
/// `SystemSegment` in two consecutive slots.
#[derive(Debug, Clone, Copy)]
pub enum Descriptor {
// The raw 64-bit descriptor value.
UserSegment(u64),
// The low and high 64-bit halves of the descriptor, in that order.
SystemSegment(u64, u64),
}
bitflags! {
/// Flag bits and field masks of a 64-bit segment descriptor.
///
/// The single-bit constants are individual flags; the `LIMIT_*` and `BASE_*` constants
/// are multi-bit masks covering the descriptor's limit and base fields.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone, Copy)]
pub struct DescriptorFlags: u64 {
/// Bit 40: the accessed flag.
const ACCESSED = 1 << 40;
/// Bit 41: the writable flag.
const WRITABLE = 1 << 41;
/// Bit 42: the conforming flag.
const CONFORMING = 1 << 42;
/// Bit 43: the executable flag; set in the predefined code-segment flag sets.
const EXECUTABLE = 1 << 43;
/// Bit 44: marks the descriptor as a user (code or data) segment.
const USER_SEGMENT = 1 << 44;
/// Bits 45-46, both set: descriptor privilege level 3.
const DPL_RING_3 = 3 << 45;
/// Bit 47: the present flag.
const PRESENT = 1 << 47;
/// Bit 52: the available flag.
const AVAILABLE = 1 << 52;
/// Bit 53: the long-mode flag; set in the predefined 64-bit code-segment flag sets.
const LONG_MODE = 1 << 53;
/// Bit 54: the default-size flag; set in the predefined data and 32-bit code flag sets.
const DEFAULT_SIZE = 1 << 54;
/// Bit 55: the granularity flag.
const GRANULARITY = 1 << 55;
/// Mask for bits 0-15 of the descriptor: bits 0-15 of the limit field.
const LIMIT_0_15 = 0xFFFF;
/// Mask for bits 48-51 of the descriptor: bits 16-19 of the limit field.
const LIMIT_16_19 = 0xF << 48;
/// Mask for bits 16-39 of the descriptor: bits 0-23 of the base field.
const BASE_0_23 = 0xFF_FFFF << 16;
/// Mask for bits 56-63 of the descriptor: bits 24-31 of the base field.
const BASE_24_31 = 0xFF << 56;
}
}
impl DescriptorFlags {
    /// Flags shared by every predefined segment: a present, writable, accessed user
    /// segment with all limit bits set and the granularity flag enabled.
    const COMMON: Self = Self::USER_SEGMENT
        .union(Self::PRESENT)
        .union(Self::WRITABLE)
        .union(Self::ACCESSED)
        .union(Self::LIMIT_0_15)
        .union(Self::LIMIT_16_19)
        .union(Self::GRANULARITY);

    /// Flags for a kernel data segment: the common flags plus `DEFAULT_SIZE`.
    pub const KERNEL_DATA: Self = Self::COMMON.union(Self::DEFAULT_SIZE);

    /// Flags for a 32-bit kernel code segment: the common flags plus `EXECUTABLE` and
    /// `DEFAULT_SIZE`.
    pub const KERNEL_CODE32: Self = Self::COMMON
        .union(Self::EXECUTABLE)
        .union(Self::DEFAULT_SIZE);

    /// Flags for a 64-bit kernel code segment: the common flags plus `EXECUTABLE` and
    /// `LONG_MODE`.
    pub const KERNEL_CODE64: Self = Self::COMMON
        .union(Self::EXECUTABLE)
        .union(Self::LONG_MODE);

    /// Flags for a user data segment: `KERNEL_DATA` with privilege level 3.
    pub const USER_DATA: Self = Self::KERNEL_DATA.union(Self::DPL_RING_3);

    /// Flags for a 32-bit user code segment: `KERNEL_CODE32` with privilege level 3.
    pub const USER_CODE32: Self = Self::KERNEL_CODE32.union(Self::DPL_RING_3);

    /// Flags for a 64-bit user code segment: `KERNEL_CODE64` with privilege level 3.
    pub const USER_CODE64: Self = Self::KERNEL_CODE64.union(Self::DPL_RING_3);
}
impl Descriptor {
#[inline]
pub const fn dpl(self) -> PrivilegeLevel {
let value_low = match self {
Descriptor::UserSegment(v) => v,
Descriptor::SystemSegment(v, _) => v,
};
let dpl = (value_low & DescriptorFlags::DPL_RING_3.bits()) >> 45;
PrivilegeLevel::from_u16(dpl as u16)
}
#[inline]
pub const fn kernel_code_segment() -> Descriptor {
Descriptor::UserSegment(DescriptorFlags::KERNEL_CODE64.bits())
}
#[inline]
pub const fn kernel_data_segment() -> Descriptor {
Descriptor::UserSegment(DescriptorFlags::KERNEL_DATA.bits())
}
#[inline]
pub const fn user_data_segment() -> Descriptor {
Descriptor::UserSegment(DescriptorFlags::USER_DATA.bits())
}
#[inline]
pub const fn user_code_segment() -> Descriptor {
Descriptor::UserSegment(DescriptorFlags::USER_CODE64.bits())
}
#[inline]
pub fn tss_segment(tss: &'static TaskStateSegment) -> Descriptor {
unsafe { Self::tss_segment_unchecked(tss) }
}
#[inline]
pub unsafe fn tss_segment_unchecked(tss: *const TaskStateSegment) -> Descriptor {
unsafe { Self::tss_segment_raw(tss, 0) }
}
pub fn tss_segment_with_iomap(
tss: &'static TaskStateSegment,
iomap: &'static [u8],
) -> Result<Descriptor, InvalidIoMap> {
if iomap.len() > 8193 {
return Err(InvalidIoMap::TooLong { len: iomap.len() });
}
let iomap_addr = iomap.as_ptr() as usize;
let tss_addr = tss as *const _ as usize;
if tss_addr > iomap_addr {
return Err(InvalidIoMap::IoMapBeforeTss);
}
let base = iomap_addr - tss_addr;
if base > 0xdfff {
return Err(InvalidIoMap::TooFarFromTss { distance: base });
}
let last_byte = *iomap.last().unwrap_or(&0xff);
if last_byte != 0xff {
return Err(InvalidIoMap::InvalidTerminatingByte { byte: last_byte });
}
if tss.iomap_base != base as u16 {
return Err(InvalidIoMap::InvalidBase {
expected: base as u16,
got: tss.iomap_base,
});
}
Ok(unsafe { Self::tss_segment_raw(tss, iomap.len() as u16) })
}
unsafe fn tss_segment_raw(tss: *const TaskStateSegment, iomap_size: u16) -> Descriptor {
use self::DescriptorFlags as Flags;
let ptr = tss as u64;
let mut low = Flags::PRESENT.bits();
low.set_bits(16..40, ptr.get_bits(0..24));
low.set_bits(56..64, ptr.get_bits(24..32));
let iomap_limit = u64::from(unsafe { (*tss).iomap_base }) + u64::from(iomap_size);
low.set_bits(
0..16,
cmp::max(mem::size_of::<TaskStateSegment>() as u64, iomap_limit) - 1,
);
low.set_bits(40..44, 0b1001);
let mut high = 0;
high.set_bits(0..32, ptr.get_bits(32..64));
Descriptor::SystemSegment(low, high)
}
}
#[cfg(test)]
mod tests {
    use super::DescriptorFlags as Flags;
    use super::*;

    /// Pins the raw bit patterns of the predefined flag sets (per the test name, these
    /// are the values used by the Linux kernel).
    #[test]
    #[rustfmt::skip]
    pub fn linux_kernel_defaults() {
        assert_eq!(Flags::KERNEL_CODE64.bits(), 0x00af9b000000ffff);
        assert_eq!(Flags::KERNEL_CODE32.bits(), 0x00cf9b000000ffff);
        assert_eq!(Flags::KERNEL_DATA.bits(), 0x00cf93000000ffff);
        assert_eq!(Flags::USER_CODE64.bits(), 0x00affb000000ffff);
        assert_eq!(Flags::USER_CODE32.bits(), 0x00cffb000000ffff);
        assert_eq!(Flags::USER_DATA.bits(), 0x00cff3000000ffff);
    }

    /// Builds a default-sized GDT holding the initial zero entry plus five user segments.
    fn make_six_entry_gdt() -> GlobalDescriptorTable {
        let mut gdt = GlobalDescriptorTable::new();
        gdt.append(Descriptor::kernel_code_segment());
        gdt.append(Descriptor::kernel_data_segment());
        gdt.append(Descriptor::UserSegment(DescriptorFlags::USER_CODE32.bits()));
        gdt.append(Descriptor::user_data_segment());
        gdt.append(Descriptor::user_code_segment());
        assert_eq!(gdt.len, 6);
        gdt
    }

    static TSS: TaskStateSegment = TaskStateSegment::new();

    /// Fills the two remaining slots of the six-entry GDT with a TSS system segment.
    fn make_full_gdt() -> GlobalDescriptorTable {
        let mut gdt = make_six_entry_gdt();
        gdt.append(Descriptor::tss_segment(&TSS));
        assert_eq!(gdt.len, 8);
        gdt
    }

    /// The table can be filled to capacity with user segments or with a system segment.
    #[test]
    pub fn push_max_segments() {
        // Make sure we don't panic with user segments
        let mut gdt = make_six_entry_gdt();
        gdt.append(Descriptor::user_data_segment());
        assert_eq!(gdt.len, 7);
        gdt.append(Descriptor::user_data_segment());
        assert_eq!(gdt.len, 8);
        // Make sure we don't panic with system segments
        let _ = make_full_gdt();
    }

    /// Appending a user segment to a full table must panic with the "full" message.
    #[test]
    #[should_panic(expected = "GDT full")]
    pub fn panic_user_segment() {
        let mut gdt = make_full_gdt();
        gdt.append(Descriptor::user_data_segment());
    }

    /// Appending a system segment with only one free slot left must panic.
    #[test]
    #[should_panic(expected = "GDT requires two free spaces to hold a SystemSegment")]
    pub fn panic_system_segment() {
        let mut gdt = make_six_entry_gdt();
        gdt.append(Descriptor::user_data_segment());
        // We have one free slot, but the GDT requires two
        gdt.append(Descriptor::tss_segment(&TSS));
    }

    /// A table built from raw entries reports exactly those entries as used.
    #[test]
    pub fn from_entries() {
        let raw = [0, Flags::KERNEL_CODE64.bits(), Flags::KERNEL_DATA.bits()];
        let gdt = GlobalDescriptorTable::<3>::from_raw_entries(&raw);
        assert_eq!(gdt.table.len(), 3);
        assert_eq!(gdt.entries().len(), 3);
    }

    /// `dpl` reports ring 0 for kernel/system descriptors and ring 3 for user descriptors.
    #[test]
    pub fn descriptor_dpl() {
        assert_eq!(
            Descriptor::kernel_code_segment().dpl(),
            PrivilegeLevel::Ring0
        );
        assert_eq!(
            Descriptor::kernel_data_segment().dpl(),
            PrivilegeLevel::Ring0
        );
        assert_eq!(Descriptor::user_code_segment().dpl(), PrivilegeLevel::Ring3);
        // Previously this line re-checked `user_code_segment`, leaving the user data
        // segment untested.
        assert_eq!(Descriptor::user_data_segment().dpl(), PrivilegeLevel::Ring3);
        // A TSS descriptor only sets the present bit, base, limit and type, so the DPL
        // read from its low half is ring 0.
        assert_eq!(Descriptor::tss_segment(&TSS).dpl(), PrivilegeLevel::Ring0);
    }
}