use core::cell::UnsafeCell;
use core::ffi::c_void;
use uefi::Status;
use uefi_raw::protocol::block::{BlockIoProtocol, Lba};
use vck_common::{VckError, VckResult};
use crate::hook::BlockIoHookEngine;
/// Signature of the `read_blocks` function pointer stored on a
/// [`BlockIoProtocol`]; this module saves the pointer of this type at install
/// time and replaces it with [`hooked_read_blocks`].
type ReadBlocksFn = unsafe extern "efiapi" fn(
this: *const BlockIoProtocol,
media_id: u32,
lba: Lba,
buffer_size: usize,
buffer: *mut c_void,
) -> Status;
/// Signature of the `write_blocks` function pointer stored on a
/// [`BlockIoProtocol`]; this module saves the pointer of this type at install
/// time and replaces it with [`hooked_write_blocks`].
type WriteBlocksFn = unsafe extern "efiapi" fn(
this: *mut BlockIoProtocol,
media_id: u32,
lba: Lba,
buffer_size: usize,
buffer: *const c_void,
) -> Status;
/// One registered hook: the protocol instance it applies to, the function
/// pointers that were on the protocol before hooking, and the engine used to
/// transform data.
#[derive(Clone, Copy)]
struct HookEntry {
// Protocol instance this entry belongs to; used as the lookup key.
protocol: *const BlockIoProtocol,
// `read_blocks` pointer captured before the hook was installed.
original_read: ReadBlocksFn,
// `write_blocks` pointer captured before the hook was installed.
original_write: WriteBlocksFn,
// Engine that decrypts after reads / encrypts before writes. May be null,
// in which case the hooks forward data without transforming it.
engine: *const BlockIoHookEngine,
}
/// Capacity of the global hook table, i.e. the maximum number of hooks that
/// can be installed at the same time.
const MAX_HOOKS: usize = 8;
/// Fixed-size table of hook entries with interior mutability, so that it can
/// be mutated through a shared `static`.
struct HookTable {
// `None` marks a free slot.
entries: UnsafeCell<[Option<HookEntry>; MAX_HOOKS]>,
}
// Needed so that `HookTable` can be stored in a `static`.
// SAFETY: NOTE(review): nothing visible in this file synchronizes access to
// the table; presumably soundness relies on the callers running in a
// single-threaded environment — confirm.
unsafe impl Sync for HookTable {}
/// Global registry of installed hooks; every slot starts out empty.
static HOOK_TABLE: HookTable = HookTable {
entries: UnsafeCell::new([None; MAX_HOOKS]),
};
/// Looks up the hook registered for `protocol` in the global table.
///
/// Returns a copy of the first entry whose protocol pointer equals
/// `protocol`, or `None` when no such entry exists.
///
/// # Safety
///
/// The caller must ensure that no mutable reference to the table contents is
/// live while this function runs, because a shared reference to the table is
/// created here.
unsafe fn find_entry(protocol: *const BlockIoProtocol) -> Option<HookEntry> {
    let table = &*HOOK_TABLE.entries.get();
    for slot in table {
        match slot {
            Some(hook) if hook.protocol == protocol => return Some(*hook),
            _ => {}
        }
    }
    None
}
/// Handle for a hook installed on a Block I/O protocol instance.
///
/// Created by [`BlockIoHook::install`] and consumed by
/// [`BlockIoHook::uninstall`].
pub struct BlockIoHook {
// Protocol instance whose read/write function pointers were replaced.
protocol: *mut BlockIoProtocol,
}
impl BlockIoHook {
    /// Hooks `protocol` so that its reads and writes go through
    /// [`hooked_read_blocks`] / [`hooked_write_blocks`].
    ///
    /// The protocol's current `read_blocks` / `write_blocks` pointers are
    /// saved in the global hook table together with `engine`, and are then
    /// replaced by the hook functions. `engine` may be null; the hooks then
    /// forward data untransformed.
    ///
    /// # Errors
    ///
    /// Returns `VckError::Io` when:
    /// * `protocol` is null,
    /// * `protocol` already has a hook installed (a second entry would record
    ///   the hook functions themselves as the "originals", which can make the
    ///   hooks call themselves recursively),
    /// * the hook table has no free slot.
    ///
    /// A non-null `protocol` must point to a valid, live `BlockIoProtocol`;
    /// this cannot be checked here.
    // The pointer is dereferenced, but the function is kept safe so that the
    // existing public signature stays unchanged for callers.
    #[allow(clippy::not_unsafe_ptr_arg_deref)]
    pub fn install(
        protocol: *mut BlockIoProtocol,
        engine: *const BlockIoHookEngine,
    ) -> VckResult<Self> {
        if protocol.is_null() {
            return Err(VckError::Io(alloc::string::String::from(
                "Block IO protocol pointer is null",
            )));
        }
        // SAFETY: `protocol` is non-null (checked above) and is assumed to
        // point to a valid protocol instance; the table reference does not
        // outlive this block.
        unsafe {
            let entries = &mut *HOOK_TABLE.entries.get();

            // Refuse to hook the same protocol twice: the saved "original"
            // pointers would be the hook functions themselves.
            let already_hooked = entries
                .iter()
                .flatten()
                .any(|e| core::ptr::eq(e.protocol, protocol));
            if already_hooked {
                return Err(VckError::Io(alloc::string::String::from(
                    "Block IO protocol is already hooked",
                )));
            }

            // Build the error lazily so the success path does not allocate.
            let slot = entries
                .iter_mut()
                .find(|e| e.is_none())
                .ok_or_else(|| {
                    VckError::Io(alloc::string::String::from(
                        "Block IO hook table is full",
                    ))
                })?;

            // Register the entry before patching the protocol so the hook
            // functions can always find their entry once they are reachable.
            *slot = Some(HookEntry {
                protocol: protocol as *const BlockIoProtocol,
                original_read: (*protocol).read_blocks,
                original_write: (*protocol).write_blocks,
                engine,
            });
            (*protocol).read_blocks = hooked_read_blocks;
            (*protocol).write_blocks = hooked_write_blocks;
        }
        Ok(Self { protocol })
    }

    /// Removes the hook: restores the protocol's original `read_blocks` /
    /// `write_blocks` pointers and frees the table slot(s) registered for it.
    ///
    /// Does nothing (and still returns `Ok`) when no entry for this protocol
    /// is present in the table.
    pub fn uninstall(self) -> VckResult<()> {
        let target = self.protocol as *const BlockIoProtocol;
        // SAFETY: `self.protocol` was accepted by `install` and is assumed to
        // still point to a valid protocol instance; the table reference does
        // not outlive this block.
        unsafe {
            let entries = &mut *HOOK_TABLE.entries.get();
            let mut restored = false;
            for slot in entries.iter_mut() {
                let Some(entry) = *slot else { continue };
                if !core::ptr::eq(entry.protocol, target) {
                    continue;
                }
                // Restore from the first matching entry only, then clear
                // every slot that refers to this protocol.
                if !restored {
                    (*self.protocol).read_blocks = entry.original_read;
                    (*self.protocol).write_blocks = entry.original_write;
                    restored = true;
                }
                *slot = None;
            }
        }
        Ok(())
    }
}
/// Replacement for a hooked protocol's `read_blocks`: performs the original
/// read, then lets the registered engine decrypt the buffer in place.
///
/// Returns `Status::DEVICE_ERROR` when no hook entry exists for `this` or
/// when decryption fails. A non-success status from the original read is
/// returned unchanged. When the media pointer or `buffer` is null,
/// `buffer_size` or the media block size is zero, or the entry has no engine,
/// the original status is returned without touching the buffer.
///
/// # Safety
///
/// `this` must point to a valid `BlockIoProtocol`, and `buffer` must be valid
/// for reads and writes of `buffer_size` bytes when it is non-null.
pub unsafe extern "efiapi" fn hooked_read_blocks(
    this: *const BlockIoProtocol,
    media_id: u32,
    lba: Lba,
    buffer_size: usize,
    buffer: *mut c_void,
) -> Status {
    let hook = match find_entry(this) {
        Some(found) => found,
        None => return Status::DEVICE_ERROR,
    };

    // Always perform the real read first; only successful reads are decrypted.
    let read_status = (hook.original_read)(this, media_id, lba, buffer_size, buffer);
    if read_status != Status::SUCCESS {
        return read_status;
    }

    let media = (*this).media;
    let nothing_to_decrypt = media.is_null() || buffer.is_null() || buffer_size == 0;
    if nothing_to_decrypt {
        return read_status;
    }

    let block_size = (*media).block_size as usize;
    if block_size == 0 || hook.engine.is_null() {
        return read_status;
    }

    let data = core::slice::from_raw_parts_mut(buffer as *mut u8, buffer_size);
    let engine = &*hook.engine;
    match engine.decrypt_after_read(lba, block_size, data) {
        Ok(_) => read_status,
        Err(_) => Status::DEVICE_ERROR,
    }
}
/// Replacement for a hooked protocol's `write_blocks`: lets the registered
/// engine encrypt the caller's data, then writes the encrypted bytes with the
/// original function.
///
/// Returns `Status::DEVICE_ERROR` when no hook entry exists for `this` or
/// when encryption fails. When the media pointer or `buffer` is null,
/// `buffer_size` or the media block size is zero, or the entry has no engine,
/// the call is forwarded to the original write unchanged.
///
/// # Safety
///
/// `this` must point to a valid `BlockIoProtocol`, and `buffer` must be valid
/// for reads of `buffer_size` bytes when it is non-null.
pub unsafe extern "efiapi" fn hooked_write_blocks(
    this: *mut BlockIoProtocol,
    media_id: u32,
    lba: Lba,
    buffer_size: usize,
    buffer: *const c_void,
) -> Status {
    let hook = match find_entry(this as *const BlockIoProtocol) {
        Some(found) => found,
        None => return Status::DEVICE_ERROR,
    };

    // A block size of zero doubles as the "nothing to transform" marker, so
    // the media descriptor is only dereferenced when it is non-null.
    let media = (*this).media;
    let block_size = if media.is_null() || buffer.is_null() || buffer_size == 0 {
        0
    } else {
        (*media).block_size as usize
    };
    if block_size == 0 || hook.engine.is_null() {
        return (hook.original_write)(this, media_id, lba, buffer_size, buffer);
    }

    let plaintext = core::slice::from_raw_parts(buffer as *const u8, buffer_size);
    let engine = &*hook.engine;
    let Ok(ciphertext) = engine.encrypt_before_write(lba, block_size, plaintext) else {
        return Status::DEVICE_ERROR;
    };

    // `ciphertext` stays alive until the original write has returned.
    (hook.original_write)(
        this,
        media_id,
        lba,
        ciphertext.len(),
        ciphertext.as_ptr() as *const c_void,
    )
}