#![doc = include_str!("../README.md")]
use std::ops::Deref;
use std::sync::Arc;
#[cfg(all(target_arch = "x86", not(target_env = "sgx"), target_feature = "sse"))]
use ::core::arch::x86 as arch;
#[cfg(all(target_arch = "x86_64", not(target_env = "sgx")))]
use ::core::arch::x86_64 as arch;
/// Owner of an optional memory protection key.
///
/// The default value holds no key; in that state `set` does nothing and
/// regions created from this instance are ordinary read/write mappings.
#[derive(Default)]
pub struct ProtectionKeys {
// Key number returned by the `pkey_alloc` syscall, or `None` when no key
// was allocated (unsupported platform or tolerated allocation failure).
handle: Option<libc::c_int>,
}
impl ProtectionKeys {
pub fn is_supported() -> bool {
is_ospke_supported()
}
pub fn new(require_protected: bool) -> Result<Arc<Self>, ProtectionError> {
#[inline(always)]
fn stub(require_protected: bool) -> Result<Arc<ProtectionKeys>, ProtectionError> {
if require_protected {
Err(ProtectionError::Unsupported)
} else {
log::error!(
"Protection keys are not supported by this CPU or OS. \
Skipping keystore memory protection"
);
Ok(Arc::new(ProtectionKeys { handle: None }))
}
}
#[cfg(not(target_os = "linux"))]
#[allow(clippy::needless_return)]
return stub(require_protected);
#[cfg(target_os = "linux")]
{
if !is_ospke_supported() {
return stub(require_protected);
}
let pkey = unsafe { libc::syscall(libc::SYS_pkey_alloc, 0usize, PKEY_DISABLE_ACCESS) };
if pkey < 0 && !require_protected {
log::error!("Protection keys allocation failed");
Ok(Arc::new(Self { handle: None }))
} else if pkey < 0 {
Err(ProtectionError::PkeyAllocationFailed(
std::io::Error::last_os_error(),
))
} else {
Ok(Arc::new(Self {
handle: Some(pkey as libc::c_int),
}))
}
}
}
pub fn make_region<T>(
self: &Arc<Self>,
initial: T,
) -> Result<Arc<ProtectedRegion<T>>, ProtectionError>
where
T: Sized,
{
ProtectedRegion::new(self, initial)
}
pub fn is_empty(&self) -> bool {
self.handle.is_none()
}
fn set(&self, rights: usize) {
#[cfg(any(not(target_arch = "x86_64"), not(target_os = "linux")))]
let _unused = rights;
#[cfg(all(target_arch = "x86_64", target_os = "linux"))]
if let Some(handle) = self.handle {
unsafe {
let eax = (rights << (2 * handle as usize)) as u32;
std::arch::asm!(
".byte 0x0f, 0x01, 0xef",
in("eax") eax,
in("ecx") 0,
in("edx") 0,
options(nomem, preserves_flags, nostack)
)
}
}
}
}
/// Returns the key to the kernel once the last owner goes away.
#[cfg(target_os = "linux")]
impl Drop for ProtectionKeys {
    fn drop(&mut self) {
        // Nothing to release when no key was ever allocated.
        if let Some(key) = self.handle {
            // SAFETY: the syscall only receives plain integer arguments.
            let status = unsafe { libc::syscall(libc::SYS_pkey_free, key as usize) };
            if status < 0 {
                // A destructor cannot report errors, so the failure is only logged.
                log::error!("failed to free pkey: {}", std::io::Error::last_os_error());
            }
        }
    }
}
/// A single anonymous memory page that stores one value of type `T` and is
/// tagged with the protection key of `pkey`.
///
/// The value is reachable only through the guard returned by `lock`.
pub struct ProtectedRegion<T> {
// Key whose rights are toggled around every access to the page.
pkey: Arc<ProtectionKeys>,
// Start of the `PAGE_SIZE`-byte mapping; the `T` lives at offset 0.
ptr: *mut libc::c_void,
// Records the payload type, since the mapping is held as an untyped pointer.
_marker: std::marker::PhantomData<T>,
}
impl<T> ProtectedRegion<T> {
const _ASSERT: () = assert!(std::mem::size_of::<T>() <= PAGE_SIZE);
fn new(pkey: &Arc<ProtectionKeys>, initial: T) -> Result<Arc<Self>, ProtectionError>
where
T: Sized,
{
let ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
PAGE_SIZE,
libc::PROT_NONE,
libc::MAP_ANON | libc::MAP_PRIVATE,
-1,
0,
)
};
if ptr == libc::MAP_FAILED {
return Err(ProtectionError::MMapFailed(std::io::Error::last_os_error()));
}
#[cfg(not(target_os = "linux"))]
{
let res = unsafe { libc::mprotect(ptr, PAGE_SIZE, libc::PROT_READ | libc::PROT_WRITE) };
if res < 0 {
return Err(ProtectionError::MProtectFailed(
std::io::Error::last_os_error(),
));
}
}
#[cfg(target_os = "linux")]
{
let res = unsafe {
libc::syscall(
libc::SYS_pkey_mprotect,
ptr as usize,
PAGE_SIZE,
libc::PROT_READ | libc::PROT_WRITE,
pkey.handle.unwrap_or(-1),
)
};
if res < 0 {
return Err(ProtectionError::MProtectFailed(
std::io::Error::last_os_error(),
));
}
}
pkey.set(0);
unsafe { (ptr as *mut T).write(initial) };
pkey.set(PKEY_DISABLE_ACCESS);
Ok(Arc::new(Self {
pkey: pkey.clone(),
ptr,
_marker: std::marker::PhantomData::default(),
}))
}
pub fn lock(&'_ self) -> ProtectedRegionGuard<'_, T> {
ProtectedRegionGuard::new(self)
}
}
/// Runs the payload's destructor in place and releases the backing page.
impl<T> Drop for ProtectedRegion<T> {
    fn drop(&mut self) {
        let payload = self.ptr.cast::<T>();

        // The destructor needs to touch the page, so open the key around it.
        self.pkey.set(0);
        // SAFETY: `payload` points at the value written by `new`, and it is
        // dropped exactly once, here.
        unsafe { payload.drop_in_place() };
        self.pkey.set(PKEY_DISABLE_ACCESS);

        // SAFETY: `ptr`/`PAGE_SIZE` describe the mapping created in `new`.
        let status = unsafe { libc::munmap(self.ptr, PAGE_SIZE) };
        if status < 0 {
            // A destructor cannot report errors, so the failure is only logged.
            log::error!("failed to unmap file: {}", std::io::Error::last_os_error());
        }
    }
}
// SAFETY: a shared `ProtectedRegion<T>` only ever hands out `&T` (through the
// guard's `Deref`), so sharing it between threads is sound whenever `T: Sync`.
unsafe impl<T: Sync> Sync for ProtectedRegion<T> {}
// SAFETY: the region owns its `T` and runs the payload's destructor on
// whichever thread drops the region, so moving it to another thread is only
// sound when `T` itself may be moved there, hence the `T: Send` bound.
unsafe impl<T: Send> Send for ProtectedRegion<T> {}
/// Access token for a `ProtectedRegion`: dereferences to the stored value and
/// sets the key's rights back to `PKEY_DISABLE_ACCESS` when dropped.
pub struct ProtectedRegionGuard<'a, T> {
// Region whose value this guard exposes.
region: &'a ProtectedRegion<T>,
// The raw-pointer marker makes the guard neither `Send` nor `Sync`.
// NOTE(review): presumably because key rights are per-thread state, so the
// guard must be dropped on the thread that created it - confirm.
_marker: std::marker::PhantomData<*const u8>,
}
impl<'a, T> ProtectedRegionGuard<'a, T> {
    /// Opens the region's key (rights `0`) and wraps the region in a guard
    /// that revokes access again when it is dropped.
    fn new(region: &'a ProtectedRegion<T>) -> Self {
        let key = &region.pkey;
        key.set(0);

        Self {
            region,
            _marker: std::marker::PhantomData,
        }
    }
}
/// Read-only view of the protected value for as long as the guard lives.
impl<T> Deref for ProtectedRegionGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let payload = self.region.ptr.cast::<T>();
        // SAFETY: the region's constructor stored a valid `T` at the start of
        // the mapping, and the mapping outlives the borrow held by this guard.
        unsafe { &*payload }
    }
}
/// Revokes access to the region's key once the guard goes out of scope.
impl<T> Drop for ProtectedRegionGuard<'_, T> {
    fn drop(&mut self) {
        let key = &self.region.pkey;
        key.set(PKEY_DISABLE_ACCESS);
    }
}
/// Probes CPUID for the OSPKE flag (leaf 7, sub-leaf 0, ECX bit 4).
///
/// Returns `false` when the CPU does not expose leaf 7 at all.
#[cfg(target_arch = "x86_64")]
fn is_ospke_supported() -> bool {
    const LEAF_VENDOR_INFO: u32 = 0x0;
    const LEAF_EXTENDED_FEATURES: u32 = 0x7;
    const OSPKE_MASK: u32 = 1 << 4;

    // Runs CPUID for `leaf` (sub-leaf 0) and keeps the two registers used here.
    let query = |leaf: u32| -> (u32, u32) {
        let regs = unsafe { arch::__cpuid_count(leaf, 0) };
        (regs.eax, regs.ecx)
    };

    // EAX of leaf 0 is the highest basic leaf the CPU implements; leaf 7 is
    // only queried when it exists.
    let (max_basic_leaf, _) = query(LEAF_VENDOR_INFO);
    max_basic_leaf >= LEAF_EXTENDED_FEATURES
        && (query(LEAF_EXTENDED_FEATURES).1 & OSPKE_MASK) != 0
}
/// Fallback for targets other than x86_64: always reports that protection
/// keys are unavailable.
#[cfg(not(target_arch = "x86_64"))]
fn is_ospke_supported() -> bool {
false
}
/// Rights value that revokes access to memory tagged with a key. It is passed
/// to `pkey_alloc` as the initial rights and to `ProtectionKeys::set`
/// whenever a region is closed again.
const PKEY_DISABLE_ACCESS: usize = 1;
/// Size in bytes of the mapping that backs every `ProtectedRegion`.
// NOTE(review): hard-coded; assumes the system page size is at least this
// value - confirm for targets with other page sizes.
const PAGE_SIZE: usize = 4096;
/// Failures reported while allocating a protection key or creating a
/// protected region. Variants that wrap an `std::io::Error` carry the OS
/// error captured right after the failing call.
#[derive(Debug, thiserror::Error)]
pub enum ProtectionError {
/// Protection was required, but the target is not Linux or the CPUID probe
/// reported no support.
#[error("Protection keys are not supported by this CPU")]
Unsupported,
/// The `pkey_alloc` syscall failed while protection was required.
#[error("Failed to allocate protection keys")]
PkeyAllocationFailed(#[source] std::io::Error),
/// `mmap` could not provide the page backing a region.
#[error("Failed to map memory")]
MMapFailed(#[source] std::io::Error),
/// The `mprotect` / `pkey_mprotect` call on a freshly mapped page failed.
#[error("Failed to protect memory")]
MProtectFailed(#[source] std::io::Error),
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Counts how many times a `TestStruct` destructor has run.
    static DROPS: AtomicUsize = AtomicUsize::new(0);

    /// Payload whose destructor is observable.
    struct TestStruct {
        test: bool,
        value: u32,
    }

    impl Drop for TestStruct {
        fn drop(&mut self) {
            DROPS.fetch_add(1, Ordering::SeqCst);
            println!("dropped {}", self.value);
        }
    }

    /// Stores a value in a region, reads it back through two consecutive
    /// guards and checks that the payload is destroyed exactly once when the
    /// region goes away. Uses `new(false)` so the test also runs on machines
    /// without protection-key support.
    #[test]
    fn test_protected_region() {
        let pkey = ProtectionKeys::new(false).unwrap();
        {
            let region = pkey
                .make_region(TestStruct {
                    test: true,
                    value: 123,
                })
                .unwrap();

            {
                let guard = region.lock();
                assert!(guard.test);
                assert_eq!(guard.value, 123);
            }

            // The region must stay usable after a guard has been released.
            let guard = region.lock();
            assert!(guard.test);
            assert_eq!(guard.value, 123);

            // Moving the payload into the region must not run its destructor.
            assert_eq!(DROPS.load(Ordering::SeqCst), 0);
        }
        assert_eq!(DROPS.load(Ordering::SeqCst), 1);
    }
}