use core::cell::UnsafeCell;
use core::ffi::c_void;
use uefi::Status;
use uefi_raw::protocol::block::{BlockIoProtocol, Lba};
use vck_common::{VckError, VckResult};
use crate::hook::BlockIoHookEngine;
/// Function-pointer type of the `read_blocks` slot of a [`BlockIoProtocol`].
///
/// This is the type of the pointer saved from the protocol before it is
/// patched, and the signature implemented by `hooked_read_blocks`.
type ReadBlocksFn = unsafe extern "efiapi" fn(
this: *const BlockIoProtocol,
media_id: u32,
lba: Lba,
buffer_size: usize,
buffer: *mut c_void,
) -> Status;
/// Bookkeeping for one hooked protocol instance.
///
/// Entries are copied out of the global table by `find_entry`, hence
/// `Clone + Copy`.
#[derive(Clone, Copy)]
struct HookEntry {
/// Protocol instance whose `read_blocks` pointer was replaced; also the
/// key used to look the entry up.
protocol: *const BlockIoProtocol,
/// The `read_blocks` pointer the protocol held when the hook was installed;
/// called by the hook and written back on uninstall.
original: ReadBlocksFn,
/// Engine whose `decrypt_after_read` is applied to successfully read data.
/// A null pointer makes the hook skip decryption.
engine: *const BlockIoHookEngine,
}
/// Number of slots in the global hook table, i.e. the maximum number of
/// protocol instances that can be hooked at the same time.
const MAX_HOOKS: usize = 8;
/// Fixed-capacity table of installed hooks.
///
/// The slots live in an `UnsafeCell` so they can be mutated through the
/// shared `HOOK_TABLE` static; an empty slot is `None`.
struct HookTable {
entries: UnsafeCell<[Option<HookEntry>; MAX_HOOKS]>,
}
// Required so that `HookTable` can be placed in a `static`.
// SAFETY: NOTE(review): nothing in this file synchronises access to the
// table, so soundness depends on it never being touched concurrently.
// Presumably this relies on the firmware environment being single-threaded
// while hooks are in use — confirm.
unsafe impl Sync for HookTable {}
/// Global hook table, initially empty.
///
/// `hooked_read_blocks` receives only the protocol pointer, so it uses this
/// table (through `find_entry`) to recover the original `read_blocks`
/// pointer and the engine for that protocol.
static HOOK_TABLE: HookTable = HookTable {
entries: UnsafeCell::new([None; MAX_HOOKS]),
};
/// Looks up the hook entry registered for `protocol`.
///
/// Returns a copy of the first occupied slot whose protocol pointer equals
/// `protocol`, or `None` when no such slot exists.
///
/// # Safety
///
/// Reads the global hook table through its `UnsafeCell`; the caller must
/// ensure no mutable access to the table is live at the same time.
unsafe fn find_entry(protocol: *const BlockIoProtocol) -> Option<HookEntry> {
    let table = &*HOOK_TABLE.entries.get();
    for hook in table.iter().flatten() {
        if core::ptr::eq(hook.protocol, protocol) {
            return Some(*hook);
        }
    }
    None
}
/// Handle for an installed `read_blocks` hook.
///
/// Created by [`BlockIoHook::install`]; consuming it with
/// [`BlockIoHook::uninstall`] restores the protocol's original pointer.
pub struct BlockIoHook {
// Protocol instance that was patched by `install`.
protocol: *mut BlockIoProtocol,
}
impl BlockIoHook {
    /// Replaces `protocol`'s `read_blocks` pointer with `hooked_read_blocks`
    /// and records the original pointer together with `engine` in the global
    /// hook table.
    ///
    /// `engine` may be null, in which case the hook forwards reads without
    /// decrypting them.
    ///
    /// # Errors
    ///
    /// Returns `VckError::Io` when
    /// * `protocol` is null,
    /// * `protocol` is already hooked (hooking it again would capture
    ///   `hooked_read_blocks` as the "original" and could make the hook call
    ///   itself forever), or
    /// * all `MAX_HOOKS` table slots are in use.
    ///
    /// In every error case the protocol is left untouched.
    ///
    /// A non-null `protocol` must point to a live `BlockIoProtocol`, and
    /// `engine` (if non-null) must stay valid until the hook is uninstalled.
    // The function stays safe for existing callers, so the lint that asks for
    // `unsafe fn` when a raw-pointer argument is dereferenced is silenced.
    #[allow(clippy::not_unsafe_ptr_arg_deref)]
    pub fn install(
        protocol: *mut BlockIoProtocol,
        engine: *const BlockIoHookEngine,
    ) -> VckResult<Self> {
        if protocol.is_null() {
            return Err(VckError::Io(alloc::string::String::from(
                "Block IO protocol pointer is null",
            )));
        }
        // SAFETY: `protocol` is non-null and the caller guarantees it points
        // to a live protocol instance; the table is only accessed from this
        // module and no other reference to it is held here.
        unsafe {
            let entries = &mut *HOOK_TABLE.entries.get();
            let already_hooked = entries
                .iter()
                .flatten()
                .any(|e| core::ptr::eq(e.protocol, protocol));
            if already_hooked {
                return Err(VckError::Io(alloc::string::String::from(
                    "Block IO protocol is already hooked",
                )));
            }
            // Build the error lazily so the success path does not allocate.
            let slot = entries
                .iter_mut()
                .find(|e| e.is_none())
                .ok_or_else(|| {
                    VckError::Io(alloc::string::String::from(
                        "Block IO hook table is full",
                    ))
                })?;
            // Register the entry before patching so the hook can always find
            // its original pointer once it becomes reachable.
            let original = (*protocol).read_blocks;
            *slot = Some(HookEntry {
                protocol: protocol as *const BlockIoProtocol,
                original,
                engine,
            });
            (*protocol).read_blocks = hooked_read_blocks;
        }
        Ok(Self { protocol })
    }

    /// Restores the protocol's original `read_blocks` pointer and frees the
    /// table slot that belonged to this hook.
    ///
    /// If no entry for the protocol exists the call does nothing. It always
    /// returns `Ok(())`.
    pub fn uninstall(self) -> VckResult<()> {
        // SAFETY: `self.protocol` was checked non-null by `install`; the
        // caller of `install` guaranteed it points to a live protocol.
        unsafe {
            let entries = &mut *HOOK_TABLE.entries.get();
            for slot in entries.iter_mut() {
                if let Some(entry) = *slot {
                    if core::ptr::eq(entry.protocol, self.protocol) {
                        // Un-patch first, then release the slot.
                        (*self.protocol).read_blocks = entry.original;
                        *slot = None;
                        // `install` allows at most one entry per protocol.
                        break;
                    }
                }
            }
        }
        Ok(())
    }
}
/// Replacement for a protocol's `read_blocks` that decrypts data after the
/// original implementation has read it.
///
/// Behaviour:
/// * no table entry for `this`: returns `Status::DEVICE_ERROR`;
/// * the original `read_blocks` returns anything other than
///   `Status::SUCCESS`: that status is returned unchanged;
/// * null media pointer, null `buffer`, zero `buffer_size`, zero block size
///   or null engine: the original status is returned and the buffer is left
///   as read;
/// * otherwise the buffer is passed to the engine's `decrypt_after_read`;
///   a decryption error is reported as `Status::DEVICE_ERROR`.
///
/// # Safety
///
/// `this` must point to a valid `BlockIoProtocol`, and `buffer` (when
/// non-null) must be valid for reads and writes of `buffer_size` bytes.
pub unsafe extern "efiapi" fn hooked_read_blocks(
    this: *const BlockIoProtocol,
    media_id: u32,
    lba: Lba,
    buffer_size: usize,
    buffer: *mut c_void,
) -> Status {
    let hook = match find_entry(this) {
        Some(found) => found,
        None => return Status::DEVICE_ERROR,
    };

    // Let the original implementation perform the actual read.
    let read_status = (hook.original)(this, media_id, lba, buffer_size, buffer);
    if read_status != Status::SUCCESS {
        return read_status;
    }

    let media_ptr = (*this).media;
    let nothing_to_process = media_ptr.is_null() || buffer.is_null() || buffer_size == 0;
    if nothing_to_process {
        return read_status;
    }

    let sector_len = (*media_ptr).block_size as usize;
    if sector_len == 0 || hook.engine.is_null() {
        return read_status;
    }

    // View the caller's buffer as bytes so the engine can decrypt in place.
    let bytes = core::slice::from_raw_parts_mut(buffer.cast::<u8>(), buffer_size);
    let engine = &*hook.engine;
    if engine.decrypt_after_read(lba, sector_len, bytes).is_err() {
        Status::DEVICE_ERROR
    } else {
        read_status
    }
}