use crate::device::{Device, DeviceOperations};
use crate::io_request::IoRequest;
use alloc::{boxed::Box, vec::Vec};
use moon_wire_codec::{parse_header, Codec, CodecError};
use wdk_sys::{
NTSTATUS, STATUS_INVALID_BUFFER_SIZE, STATUS_INVALID_DEVICE_REQUEST, STATUS_INVALID_PARAMETER,
};
/// A payload-less message type.
///
/// Its codec writes no bytes and decodes from any input, so it can serve as the
/// `Req` or `Resp` type of a typed IOCTL handler that needs no data.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Empty;
impl Codec for Empty {
    /// `Empty` has no fields, so nothing is appended to `_out`.
    fn encode(&self, _out: &mut Vec<u8>) {}

    /// Always succeeds without looking at `_input`.
    ///
    /// The length component of the result is `0` (presumably the number of
    /// bytes consumed — confirm against the `Codec` trait contract).
    fn decode(_input: &[u8]) -> Result<(Self, usize), CodecError> {
        let consumed: usize = 0;
        Ok((Self, consumed))
    }
}
/// A handler for one IOCTL operation, invoked by `IoctlOpMux` after the
/// request's opcode has been matched against its registration.
pub trait IoctlHandler {
    /// Handles `request` on behalf of `device`.
    ///
    /// The implementation in this file (`TypedIoctlHandler`) completes the
    /// request itself on success and returns `Err` without completing it on
    /// failure. NOTE(review): presumably the framework that calls
    /// `DeviceOperations::others` completes failed requests with the returned
    /// status — confirm.
    fn on_ioctl(&self, device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS>;
}
/// Routes requests for a single IOCTL control code to per-operation handlers.
///
/// Requests carrying `control_code` are dispatched by the operation number that
/// `parse_header` extracts from the start of the input buffer.
pub struct IoctlOpMux {
    /// The only control code this mux accepts; any other code is rejected with
    /// `STATUS_INVALID_DEVICE_REQUEST`.
    pub control_code: u32,
    /// Registered `(opcode, handler)` pairs, searched in insertion order; the
    /// first entry whose opcode matches handles the request.
    pub handlers: Vec<(u16, Box<dyn IoctlHandler>)>,
}
impl IoctlOpMux {
    /// Creates a mux that serves `control_code` and has no handlers yet.
    pub fn new(control_code: u32) -> Self {
        let handlers = Vec::new();
        Self {
            control_code,
            handlers,
        }
    }

    /// Registers a type-erased handler for operation `op`.
    ///
    /// Entries are appended, not replaced: if `op` was already registered, the
    /// earlier entry is the one dispatch will find.
    pub fn add_handler(&mut self, op: u16, h: Box<dyn IoctlHandler>) {
        let entry = (op, h);
        self.handlers.push(entry);
    }

    /// Registers a closure for operation `op` that receives the decoded `Req`
    /// and returns a `Resp` to be encoded into the response frame.
    pub fn add_typed_handler<Req, Resp, F>(&mut self, op: u16, handler: F)
    where
        Req: Codec + 'static,
        Resp: Codec + 'static,
        F: Fn(Req) -> Result<Resp, NTSTATUS> + 'static,
    {
        let typed = TypedIoctlHandler::new(op, handler);
        self.add_handler(op, Box::new(typed));
    }
}
impl DeviceOperations for IoctlOpMux {
    /// Accepts every open by completing the request immediately with `Ok(0)`.
    fn create(&self, _device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS> {
        request.complete(Ok(0));
        Ok(())
    }

    /// Completes close requests immediately with `Ok(0)`.
    fn close(&self, _device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS> {
        request.complete(Ok(0));
        Ok(())
    }

    /// Completes cleanup requests immediately with `Ok(0)`.
    fn cleanup(&self, _device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS> {
        request.complete(Ok(0));
        Ok(())
    }

    /// Dispatches a device-control request to the handler registered for the
    /// opcode found in its frame header.
    ///
    /// # Errors
    /// - `STATUS_INVALID_DEVICE_REQUEST` if the control code is not
    ///   `self.control_code`, or no handler is registered for the opcode.
    /// - `STATUS_INVALID_PARAMETER` if the input is shorter than the minimum
    ///   frame length, the system buffer is null, the header does not parse, or
    ///   the parsed opcode does not fit in a `u16`.
    /// - Any error returned by the selected handler, unchanged.
    fn others(&self, device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS> {
        // Smallest input accepted as a frame; `TypedIoctlHandler` enforces the same bound.
        const MIN_FRAME_LEN: usize = 12;

        if request.control_code() != self.control_code {
            return Err(STATUS_INVALID_DEVICE_REQUEST);
        }

        let in_len = request.input_buffer_length() as usize;
        if in_len < MIN_FRAME_LEN {
            return Err(STATUS_INVALID_PARAMETER);
        }

        let buff = request.system_buffer() as *const u8;
        // `slice::from_raw_parts` on a null pointer is UB, even for a slice that is never read.
        if buff.is_null() {
            return Err(STATUS_INVALID_PARAMETER);
        }
        // SAFETY: `buff` is non-null. This assumes METHOD_BUFFERED, where the system
        // buffer holds at least `in_len` input bytes. The slice is only used to parse
        // the header and is dead before the handler can write into the same buffer.
        let in_slice = unsafe { core::slice::from_raw_parts(buff, in_len) };

        let Ok((op, _)) = parse_header(in_slice) else {
            return Err(STATUS_INVALID_PARAMETER);
        };
        // An `as` cast would truncate: an out-of-range opcode would then be routed to
        // the handler registered for its low 16 bits. Reject such opcodes instead.
        let Ok(op) = u16::try_from(op) else {
            return Err(STATUS_INVALID_PARAMETER);
        };

        // The first registration for an opcode wins.
        match self.handlers.iter().find(|(k, _)| *k == op) {
            Some((_, h)) => h.on_ioctl(device, request),
            None => Err(STATUS_INVALID_DEVICE_REQUEST),
        }
    }
}
/// Adapts a closure `Fn(Req) -> Result<Resp, NTSTATUS>` into an `IoctlHandler`.
///
/// The input frame is decoded into `Req` and passed to `handler`. The returned
/// `Resp` is encoded with `opcode` and copied back into the request's system buffer.
pub struct TypedIoctlHandler<Req, Resp, F> {
    /// Callback invoked with the decoded request value.
    handler: F,
    /// Opcode passed to `encode_frame` when building the response frame.
    opcode: u16,
    // Zero-sized markers that tie the otherwise-unused `Req` and `Resp`
    // type parameters to the struct.
    _req: core::marker::PhantomData<Req>,
    _resp: core::marker::PhantomData<Resp>,
}
impl<Req, Resp, F> TypedIoctlHandler<Req, Resp, F> {
    /// Wraps `handler` so that the responses it produces are framed with `opcode`.
    pub fn new(opcode: u16, handler: F) -> Self {
        use core::marker::PhantomData;

        let (_req, _resp) = (PhantomData, PhantomData);
        Self {
            handler,
            opcode,
            _req,
            _resp,
        }
    }
}
impl<Req, Resp, F> IoctlHandler for TypedIoctlHandler<Req, Resp, F>
where
    Req: Codec,
    Resp: Codec,
    F: Fn(Req) -> Result<Resp, NTSTATUS>,
{
    /// Decodes the input frame into `Req` and runs the handler.
    ///
    /// On success, writes the encoded `Resp` frame back into the system buffer
    /// and completes the request with the frame length.
    ///
    /// # Errors
    /// - `STATUS_INVALID_PARAMETER` if the input is shorter than the minimum
    ///   frame length, the system buffer is null, or the frame fails to decode.
    /// - `STATUS_INVALID_BUFFER_SIZE` if the encoded response does not fit in the
    ///   caller's output buffer.
    /// - Any error returned by the handler closure, unchanged.
    ///
    /// The request is completed only on success; on error it is left for the caller.
    fn on_ioctl(&self, _device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS> {
        // Smallest input accepted as a frame; `IoctlOpMux::others` enforces the same bound.
        const MIN_FRAME_LEN: usize = 12;

        let in_len = request.input_buffer_length() as usize;
        let out_len = request.output_buffer_length() as usize;
        if in_len < MIN_FRAME_LEN {
            return Err(STATUS_INVALID_PARAMETER);
        }

        let buff = request.system_buffer() as *mut u8;
        // Building a slice from, or copying into, a null pointer is UB; reject it up front.
        if buff.is_null() {
            return Err(STATUS_INVALID_PARAMETER);
        }
        // SAFETY: `buff` is non-null. This assumes METHOD_BUFFERED, where the system
        // buffer holds at least `in_len` input bytes. The slice's last use is the
        // decode below, before anything writes through `buff`.
        let in_slice = unsafe { core::slice::from_raw_parts(buff as *const u8, in_len) };
        let Ok((_, req)) = moon_wire_codec::decode_frame::<Req>(in_slice) else {
            return Err(STATUS_INVALID_PARAMETER);
        };

        let resp = (self.handler)(req)?;
        let frame = moon_wire_codec::encode_frame(self.opcode, &resp);
        if frame.len() > out_len {
            return Err(STATUS_INVALID_BUFFER_SIZE);
        }
        // SAFETY: `buff` is non-null and `frame.len() <= out_len`. This assumes
        // METHOD_BUFFERED, where the system buffer is at least `out_len` bytes long.
        // `frame` is a buffer owned by this function, so it cannot overlap the system buffer.
        unsafe {
            core::ptr::copy_nonoverlapping(frame.as_ptr(), buff, frame.len());
        }
        request.complete(Ok(frame.len()));
        Ok(())
    }
}