use super::TrapCode;
use core::fmt;
use core::num::NonZeroU8;
use core::str::FromStr;
#[cfg(feature = "enable-serde")]
use serde_derive::{Deserialize, Serialize};
/// Byte order of a memory access, as recorded in [`MemFlags`].
///
/// A [`MemFlags`] value may carry no explicit byte order at all; see
/// `MemFlags::explicit_endianness` and `MemFlags::endianness`.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub enum Endianness {
    /// Little-endian: least significant byte first.
    Little,
    /// Big-endian: most significant byte first.
    Big,
}
/// A named alias region that can be recorded in [`MemFlags`].
///
/// NOTE(review): presumably consumed by alias analysis to separate accesses
/// to different regions — confirm against users of `MemFlags::alias_region`.
///
/// Discriminants are the raw 2-bit field values stored in `MemFlags`; the
/// value `0b00` is reserved to mean "no alias region".
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
#[repr(u8)]
#[expect(missing_docs, reason = "self-describing variants")]
#[rustfmt::skip]
pub enum AliasRegion {
    Heap  = 0b01,
    Table = 0b10,
    Vmctx = 0b11,
}

impl AliasRegion {
    /// Decodes a raw 2-bit alias-region field.
    ///
    /// `0b00` decodes to `None`; `0b01`, `0b10` and `0b11` decode to
    /// `Heap`, `Table` and `Vmctx` respectively.
    ///
    /// # Panics
    ///
    /// Panics with "invalid alias region bits" if `bits` is larger than `0b11`.
    const fn from_bits(bits: u8) -> Option<Self> {
        // The empty field value means "no region recorded".
        if bits == 0b00 {
            return None;
        }
        let decoded = match bits {
            0b01 => Self::Heap,
            0b10 => Self::Table,
            0b11 => Self::Vmctx,
            _ => panic!("invalid alias region bits"),
        };
        Some(decoded)
    }

    /// Encodes an optional region into its raw 2-bit field value.
    ///
    /// `None` encodes as `0b00`; a region encodes as its `repr(u8)` discriminant.
    const fn to_bits(region: Option<Self>) -> u8 {
        if let Some(present) = region {
            present as u8
        } else {
            0b00
        }
    }
}
/// A set of flags packed into a single `u16`.
///
/// The flags are: `aligned`, `readonly`, an optional explicit byte order
/// (`little` / `big`), `checked`, `can_move`, an optional [`AliasRegion`],
/// and an optional [`TrapCode`] (its absence is spelled `notrap`).
///
/// NOTE(review): presumably attached to memory operations (the type name
/// suggests so) — confirm at the use sites.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "enable-serde", derive(Serialize, Deserialize))]
pub struct MemFlags {
    // Packed representation; the bit layout is defined by the `BIT_*`,
    // `MASK_*` and `*_OFFSET` constants in this file. With `enable-serde`
    // the derived impls (de)serialize this raw field, so deserialization
    // does not validate it (e.g. both endianness bits could be set).
    bits: u16,
}
// Layout of `MemFlags::bits` (bit 0 is the least significant bit):
//
//   bit  0       aligned
//   bit  1       readonly
//   bit  2       explicit little endian
//   bit  3       explicit big endian
//   bit  4       checked
//   bits 5..=6   alias region (2-bit field, 0 = no region)
//   bits 7..=14  trap code    (8-bit field, 0 = no trap code / notrap)
//   bit  15      can_move
const BIT_ALIGNED: u16 = 0x0001;
const BIT_READONLY: u16 = 0x0002;
const BIT_LITTLE_ENDIAN: u16 = 0x0004;
const BIT_BIG_ENDIAN: u16 = 0x0008;
const BIT_CHECKED: u16 = 0x0010;
const ALIAS_REGION_OFFSET: u16 = 5;
const MASK_ALIAS_REGION: u16 = 0b11 << ALIAS_REGION_OFFSET;
const TRAP_CODE_OFFSET: u16 = 7;
const MASK_TRAP_CODE: u16 = 0xff << TRAP_CODE_OFFSET;
const BIT_CAN_MOVE: u16 = 0x8000;
impl MemFlags {
    /// Creates flags with every flag clear except the trap code, which is set
    /// to [`TrapCode::HEAP_OUT_OF_BOUNDS`] (the default that the `Display`
    /// impl does not print).
    pub const fn new() -> Self {
        Self { bits: 0 }.with_trap_code(Some(TrapCode::HEAP_OUT_OF_BOUNDS))
    }

    /// Creates flags with no trap code (`notrap`) and the `aligned` flag set.
    pub const fn trusted() -> Self {
        Self::new().with_notrap().with_aligned()
    }

    /// Returns `true` if any bit of `bit` is set.
    const fn read_bit(self, bit: u16) -> bool {
        self.bits & bit != 0
    }

    /// Returns a copy of `self` with the bits in `bit` set (never cleared).
    const fn with_bit(mut self, bit: u16) -> Self {
        self.bits |= bit;
        self
    }

    /// Returns the recorded alias region, or `None` if none is recorded.
    pub const fn alias_region(self) -> Option<AliasRegion> {
        AliasRegion::from_bits(((self.bits & MASK_ALIAS_REGION) >> ALIAS_REGION_OFFSET) as u8)
    }

    /// Returns a copy with the alias-region field replaced by `region`.
    ///
    /// Any previously recorded region is overwritten; `None` clears the field.
    pub const fn with_alias_region(mut self, region: Option<AliasRegion>) -> Self {
        let bits = AliasRegion::to_bits(region);
        self.bits &= !MASK_ALIAS_REGION;
        self.bits |= (bits as u16) << ALIAS_REGION_OFFSET;
        self
    }

    /// In-place form of [`Self::with_alias_region`].
    pub const fn set_alias_region(&mut self, region: Option<AliasRegion>) {
        *self = self.with_alias_region(region);
    }

    /// Records `region`, failing if any alias region is already recorded
    /// (including the same one). Used by [`Self::set_by_name`].
    const fn try_with_alias_region(self, region: AliasRegion) -> Result<Self, &'static str> {
        if self.alias_region().is_some() {
            return Err("cannot set more than one alias region");
        }
        Ok(self.with_alias_region(Some(region)))
    }

    /// Like [`Self::with_endianness`], but reports a conflicting explicit
    /// byte order as an error instead of panicking. Used by
    /// [`Self::set_by_name`].
    const fn try_with_endianness(self, endianness: Endianness) -> Result<Self, &'static str> {
        // The bit that must NOT already be set for this request to be valid.
        let conflicting = match endianness {
            Endianness::Little => BIT_BIG_ENDIAN,
            Endianness::Big => BIT_LITTLE_ENDIAN,
        };
        if self.read_bit(conflicting) {
            return Err("cannot set both big and little endian bits");
        }
        Ok(self.with_endianness(endianness))
    }

    /// Applies the flag spelled `name`, using the same keywords the `Display`
    /// impl prints.
    ///
    /// Recognized keywords are `notrap`, `aligned`, `readonly`, `little`,
    /// `big`, `heap`, `table`, `vmctx`, `checked` and `can_move`; any other
    /// `name` is tried as a trap code via `TrapCode::from_str`.
    ///
    /// Returns `Ok(true)` if `name` was recognized and applied, and
    /// `Ok(false)` (leaving `self` unchanged) if it is neither a keyword nor a
    /// trap code.
    ///
    /// # Errors
    ///
    /// Returns `Err` (leaving `self` unchanged) when `name` requests the
    /// opposite of an already-set explicit endianness, or an alias region
    /// when one is already recorded.
    pub fn set_by_name(&mut self, name: &str) -> Result<bool, &'static str> {
        *self = match name {
            "notrap" => self.with_trap_code(None),
            "aligned" => self.with_aligned(),
            "readonly" => self.with_readonly(),
            "little" => self.try_with_endianness(Endianness::Little)?,
            "big" => self.try_with_endianness(Endianness::Big)?,
            "heap" => self.try_with_alias_region(AliasRegion::Heap)?,
            "table" => self.try_with_alias_region(AliasRegion::Table)?,
            "vmctx" => self.try_with_alias_region(AliasRegion::Vmctx)?,
            "checked" => self.with_checked(),
            "can_move" => self.with_can_move(),
            other => match TrapCode::from_str(other) {
                Ok(code) => self.with_trap_code(Some(code)),
                // Unknown word: report "not a flag" without modifying `self`.
                Err(()) => return Ok(false),
            },
        };
        Ok(true)
    }

    /// Returns the explicit byte order if one is set, otherwise
    /// `native_endianness`.
    pub const fn endianness(self, native_endianness: Endianness) -> Endianness {
        if self.read_bit(BIT_LITTLE_ENDIAN) {
            Endianness::Little
        } else if self.read_bit(BIT_BIG_ENDIAN) {
            Endianness::Big
        } else {
            native_endianness
        }
    }

    /// Returns the explicit byte order, or `None` if neither `little` nor
    /// `big` is set.
    pub const fn explicit_endianness(self) -> Option<Endianness> {
        if self.read_bit(BIT_LITTLE_ENDIAN) {
            Some(Endianness::Little)
        } else if self.read_bit(BIT_BIG_ENDIAN) {
            Some(Endianness::Big)
        } else {
            None
        }
    }

    /// In-place form of [`Self::with_endianness`].
    ///
    /// # Panics
    ///
    /// Panics if the opposite explicit byte order is already set.
    pub const fn set_endianness(&mut self, endianness: Endianness) {
        *self = self.with_endianness(endianness);
    }

    /// Returns a copy with the explicit byte order `endianness` set.
    ///
    /// # Panics
    ///
    /// Panics if the opposite explicit byte order is already set, since both
    /// bits may never be set at once.
    pub const fn with_endianness(self, endianness: Endianness) -> Self {
        let res = match endianness {
            Endianness::Little => self.with_bit(BIT_LITTLE_ENDIAN),
            Endianness::Big => self.with_bit(BIT_BIG_ENDIAN),
        };
        assert!(!(res.read_bit(BIT_LITTLE_ENDIAN) && res.read_bit(BIT_BIG_ENDIAN)));
        res
    }

    /// Returns `true` if no trap code is recorded (the `notrap` flag).
    pub const fn notrap(self) -> bool {
        self.trap_code().is_none()
    }

    /// In-place form of [`Self::with_notrap`].
    pub const fn set_notrap(&mut self) {
        *self = self.with_notrap();
    }

    /// Returns a copy with the trap code cleared (`notrap`).
    pub const fn with_notrap(self) -> Self {
        self.with_trap_code(None)
    }

    /// Returns whether the `can_move` flag is set.
    pub const fn can_move(self) -> bool {
        self.read_bit(BIT_CAN_MOVE)
    }

    /// In-place form of [`Self::with_can_move`].
    pub const fn set_can_move(&mut self) {
        *self = self.with_can_move();
    }

    /// Returns a copy with the `can_move` flag set.
    pub const fn with_can_move(self) -> Self {
        self.with_bit(BIT_CAN_MOVE)
    }

    /// Returns whether the `aligned` flag is set.
    pub const fn aligned(self) -> bool {
        self.read_bit(BIT_ALIGNED)
    }

    /// In-place form of [`Self::with_aligned`].
    pub const fn set_aligned(&mut self) {
        *self = self.with_aligned();
    }

    /// Returns a copy with the `aligned` flag set.
    pub const fn with_aligned(self) -> Self {
        self.with_bit(BIT_ALIGNED)
    }

    /// Returns whether the `readonly` flag is set.
    pub const fn readonly(self) -> bool {
        self.read_bit(BIT_READONLY)
    }

    /// In-place form of [`Self::with_readonly`].
    pub const fn set_readonly(&mut self) {
        *self = self.with_readonly();
    }

    /// Returns a copy with the `readonly` flag set.
    pub const fn with_readonly(self) -> Self {
        self.with_bit(BIT_READONLY)
    }

    /// Returns whether the `checked` flag is set.
    pub const fn checked(self) -> bool {
        self.read_bit(BIT_CHECKED)
    }

    /// In-place form of [`Self::with_checked`].
    pub const fn set_checked(&mut self) {
        *self = self.with_checked();
    }

    /// Returns a copy with the `checked` flag set.
    pub const fn with_checked(self) -> Self {
        self.with_bit(BIT_CHECKED)
    }

    /// Returns the recorded trap code, or `None` (`notrap`) when the 8-bit
    /// trap-code field is zero.
    pub const fn trap_code(self) -> Option<TrapCode> {
        let byte = ((self.bits & MASK_TRAP_CODE) >> TRAP_CODE_OFFSET) as u8;
        match NonZeroU8::new(byte) {
            Some(code) => Some(TrapCode::from_raw(code)),
            None => None,
        }
    }

    /// Returns a copy with the trap-code field replaced by `code`; `None`
    /// clears it (equivalent to `notrap`). Other fields are untouched.
    pub const fn with_trap_code(mut self, code: Option<TrapCode>) -> Self {
        // A raw code is at most 255, so `255 << 7` stays within bits 7..=14
        // and cannot spill into `BIT_CAN_MOVE` (bit 15).
        let bits = match code {
            Some(code) => code.as_raw().get() as u16,
            None => 0,
        };
        self.bits &= !MASK_TRAP_CODE;
        self.bits |= bits << TRAP_CODE_OFFSET;
        self
    }
}
impl fmt::Display for MemFlags {
    /// Writes every set flag as a space-prefixed keyword, using the spellings
    /// `MemFlags::set_by_name` accepts. The default trap code
    /// (`TrapCode::HEAP_OUT_OF_BOUNDS`) is not written; an absent trap code is
    /// written as `notrap`. The default `MemFlags::new()` therefore prints
    /// nothing.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.trap_code() {
            // Default value: implied, so nothing is printed for it.
            Some(TrapCode::HEAP_OUT_OF_BOUNDS) => {}
            Some(code) => write!(f, " {code}")?,
            None => f.write_str(" notrap")?,
        }

        // Boolean flags, written in this fixed order.
        let switches = [
            (self.aligned(), " aligned"),
            (self.readonly(), " readonly"),
            (self.can_move(), " can_move"),
            (self.read_bit(BIT_BIG_ENDIAN), " big"),
            (self.read_bit(BIT_LITTLE_ENDIAN), " little"),
            (self.checked(), " checked"),
        ];
        for (is_set, keyword) in switches {
            if is_set {
                f.write_str(keyword)?;
            }
        }

        if let Some(region) = self.alias_region() {
            f.write_str(match region {
                AliasRegion::Heap => " heap",
                AliasRegion::Table => " table",
                AliasRegion::Vmctx => " vmctx",
            })?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every non-user trap code survives a store/load through the bit field,
    /// and `None` (notrap) does too.
    #[test]
    fn roundtrip_traps() {
        for trap in TrapCode::non_user_traps().iter().copied() {
            let flags = MemFlags::new().with_trap_code(Some(trap));
            assert_eq!(flags.trap_code(), Some(trap));
        }
        let flags = MemFlags::new().with_trap_code(None);
        assert_eq!(flags.trap_code(), None);
    }

    /// `set_by_name` refuses to set the opposite explicit byte order.
    #[test]
    fn cannot_set_big_and_little() {
        let mut big = MemFlags::new().with_endianness(Endianness::Big);
        assert!(big.set_by_name("little").is_err());
        let mut little = MemFlags::new().with_endianness(Endianness::Little);
        assert!(little.set_by_name("big").is_err());
    }

    /// `set_by_name` refuses to record a second alias region.
    #[test]
    fn only_one_region() {
        let mut heap = MemFlags::new().with_alias_region(Some(AliasRegion::Heap));
        assert!(heap.set_by_name("table").is_err());
        assert!(heap.set_by_name("vmctx").is_err());
        let mut table = MemFlags::new().with_alias_region(Some(AliasRegion::Table));
        assert!(table.set_by_name("heap").is_err());
        assert!(table.set_by_name("vmctx").is_err());
        let mut vmctx = MemFlags::new().with_alias_region(Some(AliasRegion::Vmctx));
        assert!(vmctx.set_by_name("heap").is_err());
        assert!(vmctx.set_by_name("table").is_err());
    }

    /// `with_alias_region` stores every value and replaces (not ORs) any
    /// previously recorded region.
    #[test]
    fn alias_region_roundtrip() {
        let regions = [
            None,
            Some(AliasRegion::Heap),
            Some(AliasRegion::Table),
            Some(AliasRegion::Vmctx),
        ];
        for region in regions {
            let flags = MemFlags::new().with_alias_region(region);
            assert_eq!(flags.alias_region(), region);
            for other in regions {
                assert_eq!(flags.with_alias_region(other).alias_region(), other);
            }
        }
    }

    /// The trap-code field sits directly below `can_move` (bit 15); the
    /// largest raw code must not leak into, or be clobbered by, other fields.
    #[test]
    fn bit_fields_do_not_overlap() {
        let max_trap = TrapCode::from_raw(NonZeroU8::new(u8::MAX).unwrap());
        let flags = MemFlags::new().with_trap_code(Some(max_trap));
        assert_eq!(flags.trap_code(), Some(max_trap));
        assert!(!flags.can_move());
        assert!(!flags.aligned());
        assert!(!flags.readonly());
        assert!(!flags.checked());
        assert_eq!(flags.explicit_endianness(), None);
        assert_eq!(flags.alias_region(), None);

        let all = flags
            .with_aligned()
            .with_readonly()
            .with_checked()
            .with_can_move()
            .with_endianness(Endianness::Big)
            .with_alias_region(Some(AliasRegion::Vmctx));
        assert_eq!(all.trap_code(), Some(max_trap));

        // Clearing the trap code must leave every other field intact.
        let cleared = all.with_notrap();
        assert!(cleared.notrap());
        assert!(cleared.aligned());
        assert!(cleared.readonly());
        assert!(cleared.checked());
        assert!(cleared.can_move());
        assert_eq!(cleared.explicit_endianness(), Some(Endianness::Big));
        assert_eq!(cleared.alias_region(), Some(AliasRegion::Vmctx));
    }

    /// Each keyword applies its flag, and an unrecognized word is reported
    /// with `Ok(false)` without modifying the flags.
    #[test]
    fn set_by_name_keywords() {
        let mut flags = MemFlags::new();
        for name in [
            "aligned", "readonly", "checked", "can_move", "notrap", "little", "table",
        ] {
            assert_eq!(flags.set_by_name(name), Ok(true), "{name}");
        }
        let expected = MemFlags::new()
            .with_aligned()
            .with_readonly()
            .with_checked()
            .with_can_move()
            .with_notrap()
            .with_endianness(Endianness::Little)
            .with_alias_region(Some(AliasRegion::Table));
        assert_eq!(flags, expected);

        let before = flags;
        assert_eq!(flags.set_by_name("not_a_flag"), Ok(false));
        assert_eq!(flags, before);
    }
}