use super::{pkru, sys};
use anyhow::{Context, Result};
use std::sync::OnceLock;
/// Reports whether memory protection keys can be used in this process.
///
/// This requires a Linux x86_64 build (checked at compile time via `cfg!`)
/// and a CPU whose CPUID advertises the feature (checked at runtime via
/// `pkru::has_cpuid_bit_set`, which is only consulted when the platform
/// check passes).
pub fn is_supported() -> bool {
    let platform_matches = cfg!(all(target_os = "linux", target_arch = "x86_64"));
    platform_matches && pkru::has_cpuid_bit_set()
}
/// Process-wide cache of the protection keys allocated by [`keys`]; filled
/// exactly once and never cleared.
static KEYS: OnceLock<Vec<ProtectionKey>> = OnceLock::new();

/// Returns up to `max` protection keys, allocating them on the first call.
///
/// The first invocation allocates keys with `sys::pkey_alloc` until either
/// `max` keys have been obtained or an allocation fails. Each key's stripe is
/// its position in that allocation order. The result is cached in `KEYS`.
/// Later calls only slice into that cache: asking for more keys than the
/// first call's `max` does NOT allocate additional ones.
///
/// An empty slice is returned when MPK is not supported or when no key could
/// be allocated.
pub fn keys(max: usize) -> &'static [ProtectionKey] {
    let cached = KEYS.get_or_init(|| {
        let mut allocated = Vec::new();
        if !is_supported() {
            return allocated;
        }
        while allocated.len() < max {
            // Stop at the first allocation failure and keep whatever was
            // obtained so far.
            let Ok(key_id) = sys::pkey_alloc(0, 0) else {
                break;
            };
            debug_assert!(key_id < 16);
            let stripe: u32 = allocated.len().try_into().unwrap();
            allocated.push(ProtectionKey { id: key_id, stripe });
        }
        allocated
    });
    let count = cached.len().min(max);
    &cached[..count]
}
/// Writes `mask`'s raw bits with `pkru::write`.
///
/// When trace logging is enabled, the PKRU value is read before and after
/// the write so the transition can be logged. Otherwise the "previous" value
/// is never read.
pub fn allow(mask: ProtectionMask) {
    let tracing = log::log_enabled!(log::Level::Trace);
    let previous = if tracing { pkru::read() } else { 0 };
    pkru::write(mask.0);
    log::trace!("PKRU change: {:#034b} => {:#034b}", previous, pkru::read());
}
/// Reads the current raw PKRU bits via `pkru::read` and wraps them as a
/// [`ProtectionMask`].
pub fn current_mask() -> ProtectionMask {
    let raw_bits = pkru::read();
    ProtectionMask(raw_bits)
}
/// A protection key obtained from `sys::pkey_alloc` by [`keys`].
#[derive(Debug, Clone, Copy)]
pub struct ProtectionKey {
    /// Key number returned by `sys::pkey_alloc`.
    id: u32,
    /// Zero-based position of this key in the order [`keys`] allocated it.
    stripe: u32,
}
impl ProtectionKey {
    /// Tags every page of `region` with this key by calling
    /// `sys::pkey_mprotect`, passing `sys::PROT_NONE` as the protection bits.
    ///
    /// # Errors
    ///
    /// Fails if `sys::pkey_mprotect` fails, for example for a region that is
    /// not page-aligned. The error's context message records the address,
    /// length and protection bits that were passed.
    pub fn protect(&self, region: &mut [u8]) -> Result<()> {
        let prot = sys::PROT_NONE;
        let len = region.len();
        let addr = region.as_mut_ptr() as usize;
        let outcome = sys::pkey_mprotect(addr, len, prot, self.id);
        outcome.with_context(|| {
            format!(
                "failed to mark region with pkey (addr = {addr:#x}, len = {len}, prot = {prot:#b})"
            )
        })
    }

    /// Returns this key's stripe index (its allocation order) as a `usize`.
    pub fn as_stripe(&self) -> usize {
        let stripe = self.stripe;
        stripe as usize
    }
}
/// A raw PKRU value describing which protection keys may be accessed.
///
/// Build one with [`ProtectionMask::all`], [`ProtectionMask::zero`] and
/// [`ProtectionMask::or`], then install it with [`allow`]. The derives let
/// masks be logged, copied and compared (e.g. against [`current_mask`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ProtectionMask(u32);
impl ProtectionMask {
#[inline]
pub fn all() -> Self {
Self(pkru::ALLOW_ACCESS)
}
#[inline]
pub fn zero() -> Self {
Self(pkru::DISABLE_ACCESS ^ 0b11)
}
#[inline]
pub fn or(self, pkey: ProtectionKey) -> Self {
let mask = pkru::DISABLE_ACCESS ^ 0b11 << (pkey.id * 2);
Self(self.0 & mask)
}
}
/// Prints a notice and early-returns from the enclosing function when
/// `crate::mpk::is_supported()` reports that MPK is unavailable.
///
/// Placed at the top of MPK-dependent tests (e.g. `check_invalid_mark`).
/// The injected bare `return;` exits the *calling* function, so that
/// function must return `()`.
#[cfg(test)]
macro_rules! skip_if_mpk_unavailable {
    () => {
        if !crate::mpk::is_supported() {
            println!("> mpk is not supported: ignoring test");
            return;
        }
    };
}
// Re-export the macro so test code elsewhere in the crate can invoke it by path.
#[cfg(test)]
pub(crate) use skip_if_mpk_unavailable;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn check_is_supported() {
        println!("is pku supported = {}", is_supported());
        // Setting this variable turns "MPK unavailable" into a failure instead
        // of a silent skip of the MPK-dependent tests.
        if std::env::var("WASMTIME_TEST_FORCE_MPK").is_ok() {
            assert!(is_supported());
        }
    }

    #[test]
    fn check_initialized_keys() {
        if is_supported() {
            assert!(!keys(15).is_empty())
        }
    }

    #[test]
    fn check_invalid_mark() {
        skip_if_mpk_unavailable!();
        let pkey = keys(15)[0];

        // Use real, owned memory instead of fabricating a slice at a bogus
        // address: `slice::from_raw_parts_mut` requires the pointer to be
        // valid for `len` bytes, so a slice at address 0x1 is undefined
        // behavior even if it is never dereferenced. Choosing the byte with an
        // odd address guarantees the region is not page-aligned, which is the
        // condition this test exercises.
        let mut buffer = vec![0u8; 2];
        let offset = if buffer.as_ptr() as usize % 2 == 0 { 1 } else { 0 };
        let unaligned_region = &mut buffer[offset..offset + 1];
        let addr = unaligned_region.as_ptr() as usize;

        let result = pkey.protect(unaligned_region);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            format!("failed to mark region with pkey (addr = {addr:#x}, len = 1, prot = 0b0)")
        );
    }

    #[test]
    fn check_masking() {
        skip_if_mpk_unavailable!();
        let original = pkru::read();

        // Allowing everything yields an all-zero PKRU, with or without extra keys.
        allow(ProtectionMask::all());
        assert_eq!(0, pkru::read());
        allow(ProtectionMask::all().or(ProtectionKey { id: 5, stripe: 0 }));
        assert_eq!(0, pkru::read());

        // Only key 0 allowed: its two low bits are clear.
        allow(ProtectionMask::zero());
        assert_eq!(0b11111111_11111111_11111111_11111100, pkru::read());

        // Adding key 5 additionally clears bits 10 and 11.
        allow(ProtectionMask::zero().or(ProtectionKey { id: 5, stripe: 0 }));
        assert_eq!(0b11111111_11111111_11110011_11111100, pkru::read());

        // Restore the PKRU value this test started with.
        pkru::write(original);
    }
}