use core::ffi::c_void;
use core::ptr::NonNull;
use super::device::Device;
use super::error::{status, KmError, KmResult, NtStatus};
use super::irp::{Irp, IrpMajorFunction};
use super::string::UnicodeString;
/// Raw FFI view of a kernel driver object.
///
/// Field order is intended to mirror the WDK `DRIVER_OBJECT` structure field-for-field.
/// `#[repr(C)]` keeps the fields in declaration order, so they must not be reordered, added, or
/// removed.
#[repr(C)]
pub struct DriverObjectRaw {
/// Object type tag (named `type_` because `type` is a Rust keyword).
pub type_: i16,
/// Size of the structure as recorded by the object manager.
pub size: i16,
/// First device object created for this driver, kept as an untyped pointer here.
pub device_object: *mut c_void,
pub flags: u32,
/// Base address of the loaded driver image (exposed via `Driver::start_address`).
pub driver_start: *mut c_void,
/// Size of the loaded driver image (exposed via `Driver::size`).
pub driver_size: u32,
pub driver_section: *mut c_void,
pub driver_extension: *mut DriverExtensionRaw,
/// Driver name (read by `Driver::name`).
pub driver_name: UnicodeStringRaw,
pub hardware_database: *mut UnicodeStringRaw,
pub fast_io_dispatch: *mut c_void,
pub driver_init: *mut c_void,
pub driver_start_io: *mut c_void,
/// Unload routine slot; `None` is represented as a null pointer (fn-pointer niche).
pub driver_unload: Option<DriverUnload>,
/// IRP dispatch table indexed by major function code. The 28 slots correspond to
/// `IRP_MJ_MAXIMUM_FUNCTION + 1` in the WDK headers; `None` entries are null pointers.
pub major_function: [Option<DriverDispatch>; 28],
}
/// Raw FFI view of the driver extension pointed to by `DriverObjectRaw::driver_extension`.
///
/// Intended to mirror the publicly documented prefix of the WDK `DRIVER_EXTENSION` structure;
/// `#[repr(C)]` keeps the declaration order, so fields must not be reordered.
#[repr(C)]
pub struct DriverExtensionRaw {
/// Back-pointer to the owning driver object.
pub driver_object: *mut DriverObjectRaw,
/// AddDevice routine, kept as an untyped pointer here.
pub add_device: *mut c_void,
pub count: u32,
/// Name of the driver's service key.
pub service_key_name: UnicodeStringRaw,
}
/// Raw FFI view of a counted UTF-16 string (WDK `UNICODE_STRING` layout).
///
/// `length` is treated as a byte count: `Driver::name` divides it by 2 to get the number of
/// `u16` units. The buffer is not assumed to be NUL-terminated.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct UnicodeStringRaw {
/// Number of bytes of valid string data in `buffer`.
pub length: u16,
/// Capacity of `buffer` in bytes.
pub maximum_length: u16,
/// Pointer to the UTF-16 code units; may be null.
pub buffer: *mut u16,
}
/// Signature of the `DriverEntry` entry point (the same signature `driver_entry!` generates).
pub type DriverEntry = unsafe extern "system" fn(
driver_object: *mut DriverObjectRaw,
registry_path: *const UnicodeStringRaw,
) -> NtStatus;
/// Signature of the routine stored in `DriverObjectRaw::driver_unload`.
pub type DriverUnload = unsafe extern "system" fn(driver_object: *mut DriverObjectRaw);
/// Signature of an IRP dispatch routine stored in `DriverObjectRaw::major_function`.
/// The device object and IRP are passed as untyped pointers.
pub type DriverDispatch = unsafe extern "system" fn(
device_object: *mut c_void,
irp: *mut c_void,
) -> NtStatus;
/// Thin wrapper around a kernel-supplied [`DriverObjectRaw`].
///
/// It holds a non-null pointer only. It does not own the driver object and never frees it.
pub struct Driver {
    raw: NonNull<DriverObjectRaw>,
}

impl Driver {
    /// Wraps `ptr`, returning `None` if it is null.
    ///
    /// # Safety
    ///
    /// `ptr` must be null, or point to a valid, properly aligned `DriverObjectRaw`. That object
    /// must stay readable and writable for as long as the returned `Driver` (or anything
    /// borrowed from it) is used.
    ///
    /// If `driver_name.buffer` is non-null, it must point to at least `driver_name.length / 2`
    /// initialized `u16`s.
    pub unsafe fn from_raw(ptr: *mut DriverObjectRaw) -> Option<Self> {
        NonNull::new(ptr).map(|raw| Self { raw })
    }

    /// Returns the underlying raw driver object pointer (never null).
    pub fn as_raw(&self) -> *mut DriverObjectRaw {
        self.raw.as_ptr()
    }

    /// Installs `unload` as the driver's unload routine.
    pub fn set_unload(&mut self, unload: DriverUnload) {
        // SAFETY: `raw` is valid for writes per the `from_raw` contract.
        unsafe {
            (*self.raw.as_ptr()).driver_unload = Some(unload);
        }
    }

    /// Installs `handler` for a single IRP major function.
    ///
    /// A `function` whose numeric value falls outside the dispatch table is silently ignored.
    pub fn set_major_function(&mut self, function: IrpMajorFunction, handler: DriverDispatch) {
        let index = function as usize;
        // SAFETY: `raw` is valid for writes per the `from_raw` contract, and `&mut self`
        // prevents overlapping borrows through this wrapper.
        let table = unsafe { &mut (*self.raw.as_ptr()).major_function };
        // Bound by the table itself rather than a repeated literal, so the check can never
        // drift from the array length.
        if let Some(slot) = table.get_mut(index) {
            *slot = Some(handler);
        }
    }

    /// Installs `handler` for every IRP major function slot.
    pub fn set_all_major_functions(&mut self, handler: DriverDispatch) {
        // SAFETY: `raw` is valid for writes per the `from_raw` contract.
        let table = unsafe { &mut (*self.raw.as_ptr()).major_function };
        table.fill(Some(handler));
    }

    /// Returns the driver name as UTF-16 code units, or an empty slice if it is unset.
    ///
    /// The length is taken from `driver_name.length` (bytes) divided by 2. An odd trailing byte
    /// is dropped.
    pub fn name(&self) -> &[u16] {
        // SAFETY: `raw` is valid for reads per the `from_raw` contract.
        let name = unsafe { &(*self.raw.as_ptr()).driver_name };
        if name.buffer.is_null() || name.length == 0 {
            return &[];
        }
        let units = usize::from(name.length / 2);
        // SAFETY: `buffer` is non-null, and the `from_raw` contract guarantees it holds at least
        // `length / 2` initialized `u16`s.
        unsafe { core::slice::from_raw_parts(name.buffer, units) }
    }

    /// Returns the `driver_start` field (base address of the driver image).
    pub fn start_address(&self) -> *mut c_void {
        // SAFETY: `raw` is valid for reads per the `from_raw` contract.
        unsafe { (*self.raw.as_ptr()).driver_start }
    }

    /// Returns the `driver_size` field (size of the driver image).
    pub fn size(&self) -> u32 {
        // SAFETY: `raw` is valid for reads per the `from_raw` contract.
        unsafe { (*self.raw.as_ptr()).driver_size }
    }

    /// Creates a device object named `name` for this driver via `IoCreateDevice`.
    ///
    /// No device extension is requested (extension size 0).
    ///
    /// # Errors
    ///
    /// Returns `KmError::DeviceCreationFailed` in two cases:
    /// - `IoCreateDevice` reports a failure status. The specific status is not preserved.
    /// - The returned device pointer cannot be wrapped.
    pub fn create_device(
        &mut self,
        name: &UnicodeString,
        device_type: u32,
        characteristics: u32,
        exclusive: bool,
    ) -> KmResult<Device> {
        let mut device_object: *mut c_void = core::ptr::null_mut();
        // SAFETY: `raw` is a valid driver object per the `from_raw` contract. `name` is borrowed
        // for the duration of the call, and `device_object` is a valid out-pointer.
        let status = unsafe {
            IoCreateDevice(
                self.raw.as_ptr(),
                0,
                name.as_ptr() as *const _,
                device_type,
                characteristics,
                u8::from(exclusive),
                &mut device_object,
            )
        };
        if !status::nt_success(status) {
            return Err(KmError::DeviceCreationFailed {
                reason: "IoCreateDevice failed",
            });
        }
        // SAFETY: on success, `IoCreateDevice` wrote the new device object pointer into
        // `device_object`.
        unsafe { Device::from_raw(device_object) }.ok_or(KmError::DeviceCreationFailed {
            reason: "device object is null",
        })
    }
}
/// Fluent helper for installing an unload routine and IRP dispatch routines on a [`Driver`].
///
/// Every method writes into the underlying driver object immediately.
/// [`DriverBuilder::build`] simply returns the wrapped [`Driver`].
pub struct DriverBuilder {
    driver: Driver,
}

impl DriverBuilder {
    /// Starts a builder for the driver object at `ptr`, or returns `None` if `ptr` is null.
    ///
    /// # Safety
    ///
    /// `ptr` must be null, or point to a valid driver object. That object must remain readable
    /// and writable while the builder and the resulting [`Driver`] are in use.
    pub unsafe fn new(ptr: *mut DriverObjectRaw) -> Option<Self> {
        // SAFETY: the caller guarantees `ptr` satisfies `Driver::from_raw`'s requirements.
        let driver = unsafe { Driver::from_raw(ptr) }?;
        Some(Self { driver })
    }

    /// Sets the driver's unload routine.
    pub fn unload(mut self, handler: DriverUnload) -> Self {
        self.driver.set_unload(handler);
        self
    }

    /// Sets the `Create` dispatch routine.
    pub fn create(self, handler: DriverDispatch) -> Self {
        self.major_function(IrpMajorFunction::Create, handler)
    }

    /// Sets the `Close` dispatch routine.
    pub fn close(self, handler: DriverDispatch) -> Self {
        self.major_function(IrpMajorFunction::Close, handler)
    }

    /// Sets the `DeviceControl` dispatch routine.
    pub fn device_control(self, handler: DriverDispatch) -> Self {
        self.major_function(IrpMajorFunction::DeviceControl, handler)
    }

    /// Sets the `Read` dispatch routine.
    pub fn read(self, handler: DriverDispatch) -> Self {
        self.major_function(IrpMajorFunction::Read, handler)
    }

    /// Sets the `Write` dispatch routine.
    pub fn write(self, handler: DriverDispatch) -> Self {
        self.major_function(IrpMajorFunction::Write, handler)
    }

    /// Sets the dispatch routine for an arbitrary major function.
    pub fn major_function(mut self, function: IrpMajorFunction, handler: DriverDispatch) -> Self {
        self.driver.set_major_function(function, handler);
        self
    }

    /// Finishes building and returns the configured [`Driver`].
    pub fn build(self) -> Driver {
        let Self { driver } = self;
        driver
    }
}
/// Driver behaviour plugged into the `driver_entry!` macro.
///
/// `init` and `unload` are required. The IRP handlers have defaults. The dispatch shims
/// generated by `driver_entry!` complete the IRP with whatever status a handler returns.
pub trait DriverImpl {
/// Called from the generated `DriverEntry` once the unload and dispatch routines are installed.
/// An `Err` is converted with `to_ntstatus` and returned to the kernel.
/// NOTE(review): `driver_entry!` currently passes an empty string here rather than the
/// kernel-supplied registry path.
fn init(driver: &mut Driver, registry_path: &UnicodeString) -> KmResult<()>;
/// Called from the generated unload routine.
fn unload(driver: &Driver);
/// `Create` handler; defaults to `STATUS_SUCCESS`.
fn create(_device: *mut c_void, _irp: &mut Irp) -> NtStatus {
status::STATUS_SUCCESS
}
/// `Close` handler; defaults to `STATUS_SUCCESS`.
fn close(_device: *mut c_void, _irp: &mut Irp) -> NtStatus {
status::STATUS_SUCCESS
}
/// `DeviceControl` handler; defaults to `STATUS_NOT_IMPLEMENTED`.
fn device_control(_device: *mut c_void, _irp: &mut Irp) -> NtStatus {
status::STATUS_NOT_IMPLEMENTED
}
/// `Read` handler; defaults to `STATUS_NOT_IMPLEMENTED`.
fn read(_device: *mut c_void, _irp: &mut Irp) -> NtStatus {
status::STATUS_NOT_IMPLEMENTED
}
/// `Write` handler; defaults to `STATUS_NOT_IMPLEMENTED`.
fn write(_device: *mut c_void, _irp: &mut Irp) -> NtStatus {
status::STATUS_NOT_IMPLEMENTED
}
}
/// Generates the `DriverEntry` symbol plus unload and IRP dispatch shims for a [`DriverImpl`] type.
///
/// The generated `DriverEntry` does three things:
/// - Validates its pointer arguments.
/// - Installs the unload routine and the `Create`/`Close`/`DeviceControl`/`Read`/`Write` dispatch
///   routines.
/// - Calls `init`.
///
/// Each dispatch shim wraps the raw IRP and calls the matching trait method. It then completes
/// the IRP with the returned status. A null IRP yields `STATUS_INVALID_PARAMETER`.
#[macro_export]
macro_rules! driver_entry {
    ($impl_type:ty) => {
        #[no_mangle]
        pub unsafe extern "system" fn DriverEntry(
            driver_object: *mut $crate::km::driver::DriverObjectRaw,
            registry_path: *const $crate::km::driver::UnicodeStringRaw,
        ) -> $crate::km::error::NtStatus {
            // SAFETY: arguments are forwarded unchanged from the caller of `DriverEntry`.
            unsafe { __driver_entry_impl::<$impl_type>(driver_object, registry_path) }
        }

        unsafe fn __driver_entry_impl<T: $crate::km::driver::DriverImpl>(
            driver_object: *mut $crate::km::driver::DriverObjectRaw,
            registry_path: *const $crate::km::driver::UnicodeStringRaw,
        ) -> $crate::km::error::NtStatus {
            use $crate::km::error::status;
            use $crate::km::IrpMajorFunction;

            // Validate both inputs before mutating the driver object, so a rejected call leaves
            // the object untouched.
            if registry_path.is_null() {
                return status::STATUS_INVALID_PARAMETER;
            }
            // SAFETY: a non-null `driver_object` comes straight from the `DriverEntry` caller.
            let Some(mut driver) = (unsafe { $crate::km::Driver::from_raw(driver_object) }) else {
                return status::STATUS_INVALID_PARAMETER;
            };

            driver.set_unload(__driver_unload::<T>);
            driver.set_major_function(IrpMajorFunction::Create, __dispatch_create::<T>);
            driver.set_major_function(IrpMajorFunction::Close, __dispatch_close::<T>);
            driver.set_major_function(IrpMajorFunction::DeviceControl, __dispatch_device_control::<T>);
            driver.set_major_function(IrpMajorFunction::Read, __dispatch_read::<T>);
            driver.set_major_function(IrpMajorFunction::Write, __dispatch_write::<T>);

            // TODO(review): the kernel-supplied `registry_path` is not forwarded yet. `init`
            // receives an empty string until a `UnicodeString` constructor from
            // `UnicodeStringRaw` is wired in here.
            let init_path = $crate::km::UnicodeString::empty();
            match T::init(&mut driver, &init_path) {
                Ok(()) => status::STATUS_SUCCESS,
                Err(e) => e.to_ntstatus(),
            }
        }

        unsafe extern "system" fn __driver_unload<T: $crate::km::driver::DriverImpl>(
            driver_object: *mut $crate::km::driver::DriverObjectRaw,
        ) {
            // SAFETY: the pointer is supplied by the caller of the unload routine.
            if let Some(driver) = unsafe { $crate::km::Driver::from_raw(driver_object) } {
                T::unload(&driver);
            }
        }

        /// Shared body of every dispatch shim: wrap the IRP, run `handler`, complete the IRP.
        ///
        /// NOTE(review): the IRP is always completed with the handler's status. A handler must
        /// therefore not return `STATUS_PENDING` for an IRP it has queued. Confirm this against
        /// `Irp::complete`'s semantics.
        unsafe fn __dispatch_irp(
            device: *mut core::ffi::c_void,
            irp: *mut core::ffi::c_void,
            handler: fn(*mut core::ffi::c_void, &mut $crate::km::Irp) -> $crate::km::error::NtStatus,
        ) -> $crate::km::error::NtStatus {
            // SAFETY: `irp` is the pointer handed to the dispatch routine by its caller.
            let Some(mut irp_wrapper) = (unsafe { $crate::km::Irp::from_raw(irp) }) else {
                return $crate::km::error::status::STATUS_INVALID_PARAMETER;
            };
            let status = handler(device, &mut irp_wrapper);
            irp_wrapper.complete(status);
            status
        }

        unsafe extern "system" fn __dispatch_create<T: $crate::km::driver::DriverImpl>(
            device: *mut core::ffi::c_void,
            irp: *mut core::ffi::c_void,
        ) -> $crate::km::error::NtStatus {
            // SAFETY: arguments are forwarded unchanged from the dispatch caller.
            unsafe { __dispatch_irp(device, irp, T::create) }
        }

        unsafe extern "system" fn __dispatch_close<T: $crate::km::driver::DriverImpl>(
            device: *mut core::ffi::c_void,
            irp: *mut core::ffi::c_void,
        ) -> $crate::km::error::NtStatus {
            // SAFETY: arguments are forwarded unchanged from the dispatch caller.
            unsafe { __dispatch_irp(device, irp, T::close) }
        }

        unsafe extern "system" fn __dispatch_device_control<T: $crate::km::driver::DriverImpl>(
            device: *mut core::ffi::c_void,
            irp: *mut core::ffi::c_void,
        ) -> $crate::km::error::NtStatus {
            // SAFETY: arguments are forwarded unchanged from the dispatch caller.
            unsafe { __dispatch_irp(device, irp, T::device_control) }
        }

        unsafe extern "system" fn __dispatch_read<T: $crate::km::driver::DriverImpl>(
            device: *mut core::ffi::c_void,
            irp: *mut core::ffi::c_void,
        ) -> $crate::km::error::NtStatus {
            // SAFETY: arguments are forwarded unchanged from the dispatch caller.
            unsafe { __dispatch_irp(device, irp, T::read) }
        }

        unsafe extern "system" fn __dispatch_write<T: $crate::km::driver::DriverImpl>(
            device: *mut core::ffi::c_void,
            irp: *mut core::ffi::c_void,
        ) -> $crate::km::error::NtStatus {
            // SAFETY: arguments are forwarded unchanged from the dispatch caller.
            unsafe { __dispatch_irp(device, irp, T::write) }
        }
    };
}
// Kernel routine called directly by `Driver::create_device`.
// NOTE(review): linkage is not shown in this file; confirm the build links the kernel import
// library that exports it.
extern "system" {
/// Creates a device object for `DriverObject` and writes it through `DeviceObject`
/// (WDK `IoCreateDevice`).
fn IoCreateDevice(
DriverObject: *mut DriverObjectRaw,
// Bytes of per-device extension to allocate (`create_device` passes 0).
DeviceExtensionSize: u32,
// Pointer to the device name as a counted UTF-16 string, passed untyped.
DeviceName: *const c_void,
DeviceType: u32,
DeviceCharacteristics: u32,
// BOOLEAN: nonzero requests an exclusive device.
Exclusive: u8,
// Out-parameter receiving the created device object pointer.
DeviceObject: *mut *mut c_void,
) -> NtStatus;
}