moon-driver-core 0.1.3

Windows WDK driver core helpers: device, IOCTL, IRP, symlink
use crate::device::{Device, DeviceOperations};
use crate::io_request::IoRequest;
use alloc::{boxed::Box, vec::Vec};
use moon_wire_codec::{parse_header, Codec, CodecError};
use wdk_sys::{
    NTSTATUS, STATUS_INVALID_BUFFER_SIZE, STATUS_INVALID_DEVICE_REQUEST, STATUS_INVALID_PARAMETER,
};

/// Zero-sized payload for IOCTL operations that carry no request or response data.
///
/// Encoding emits nothing, and decoding succeeds on any input while consuming zero bytes.
pub struct Empty;

impl Codec for Empty {
    /// Writes no bytes: `Empty` has no wire representation.
    fn encode(&self, _buf: &mut Vec<u8>) {}

    /// Always yields `Empty` and reports that zero bytes of `_bytes` were consumed.
    fn decode(_bytes: &[u8]) -> Result<(Self, usize), CodecError> {
        let consumed = 0;
        Ok((Self, consumed))
    }
}

/// Object-safe handler for one IOCTL operation id, stored type-erased in a dispatch table.
///
/// An implementation receives the target device and the in-flight request. `Ok(())` means the
/// handler dealt with the request; `Err(status)` hands an `NTSTATUS` back to whoever invoked it.
pub trait IoctlHandler {
    /// Handles a single IOCTL request addressed to `device`.
    fn on_ioctl(&self, device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS>;
}

/// Dispatcher that multiplexes many operations over a single IOCTL control code.
///
/// Requests carrying `control_code` start with a wire header whose operation id selects an
/// entry from `handlers`.
pub struct IoctlOpMux {
    /// `(operation id, handler)` pairs, searched front to back.
    pub handlers: Vec<(u16, Box<dyn IoctlHandler>)>,
    /// The only IOCTL control code this dispatcher accepts.
    pub control_code: u32,
}
impl IoctlOpMux {
    pub fn new(control_code: u32) -> Self {
        Self {
            control_code,
            handlers: Vec::new(),
        }
    }
    pub fn add_handler(&mut self, op: u16, h: Box<dyn IoctlHandler>) {
        self.handlers.push((op, h));
    }
    pub fn add_typed_handler<Req, Resp, F>(&mut self, op: u16, handler: F)
    where
        Req: Codec + 'static,
        Resp: Codec + 'static,
        F: Fn(Req) -> Result<Resp, NTSTATUS> + 'static,
    {
        self.handlers
            .push((op, Box::new(TypedIoctlHandler::new(op, handler))));
    }
}

impl DeviceOperations for IoctlOpMux {
    /// Completes the create request immediately with `Ok(0)`.
    fn create(&self, _device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS> {
        request.complete(Ok(0));
        Ok(())
    }

    /// Completes the close request immediately with `Ok(0)`.
    fn close(&self, _device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS> {
        request.complete(Ok(0));
        Ok(())
    }

    /// Completes the cleanup request immediately with `Ok(0)`.
    fn cleanup(&self, _device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS> {
        request.complete(Ok(0));
        Ok(())
    }

    /// Routes a control request to the handler registered for the operation id in its header.
    ///
    /// # Errors
    /// - `STATUS_INVALID_DEVICE_REQUEST` in three cases:
    ///   - the control code is not `self.control_code`;
    ///   - the header's operation id does not fit in `u16`;
    ///   - no handler is registered for the id.
    /// - `STATUS_INVALID_PARAMETER` in three cases:
    ///   - the input is shorter than `HEADER_LEN`;
    ///   - there is no system buffer;
    ///   - the header fails to parse.
    /// - Whatever the selected handler returns.
    ///
    /// The request is not completed on this function's own error paths.
    /// NOTE(review): presumably the caller completes it with the returned status; confirm.
    fn others(&self, device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS> {
        // Minimum input length accepted before parsing.
        // NOTE(review): assumed to equal the moon_wire_codec header size.
        const HEADER_LEN: usize = 12;

        let code = request.control_code();
        if code != self.control_code {
            return Err(STATUS_INVALID_DEVICE_REQUEST);
        }
        let in_len = request.input_buffer_length() as usize;
        if in_len < HEADER_LEN {
            return Err(STATUS_INVALID_PARAMETER);
        }

        let buff = request.system_buffer() as *const u8;
        // Without a system buffer (e.g. a non-buffered transfer type) there is nothing to
        // read, and a slice built from a null pointer is undefined behavior.
        if buff.is_null() {
            return Err(STATUS_INVALID_PARAMETER);
        }

        let op = {
            // SAFETY: `buff` is non-null and, for a buffered IOCTL, points to at least `in_len`
            // bytes of caller input. The slice ends with this block, before any handler can
            // write into the same memory.
            let in_slice = unsafe { core::slice::from_raw_parts(buff, in_len) };
            let Ok((op, _)) = parse_header(in_slice) else {
                return Err(STATUS_INVALID_PARAMETER);
            };
            op
        };

        // Handler keys are `u16`. Reject wider ids instead of truncating them: with `as u16`,
        // an id such as 0x1_0001 would silently alias handler 1.
        let op = u16::try_from(op).map_err(|_| STATUS_INVALID_DEVICE_REQUEST)?;

        // First registration for an id wins.
        match self.handlers.iter().find(|(key, _)| *key == op) {
            Some((_, handler)) => handler.on_ioctl(device, request),
            None => Err(STATUS_INVALID_DEVICE_REQUEST),
        }
    }
}

/// Adapter that turns a `Fn(Req) -> Result<Resp, NTSTATUS>` closure into an [`IoctlHandler`].
///
/// `Req` and `Resp` are only type markers here; no values of either are stored.
pub struct TypedIoctlHandler<Req, Resp, F> {
    /// Operation id passed to `encode_frame` for every response.
    opcode: u16,
    /// User closure mapping a decoded request to a response or an `NTSTATUS` error.
    handler: F,
    _req: core::marker::PhantomData<Req>,
    _resp: core::marker::PhantomData<Resp>,
}

impl<Req, Resp, F> TypedIoctlHandler<Req, Resp, F> {
    /// Wraps `handler` so that its responses are framed with operation id `opcode`.
    pub fn new(opcode: u16, handler: F) -> Self {
        use core::marker::PhantomData;

        Self {
            opcode,
            handler,
            _resp: PhantomData,
            _req: PhantomData,
        }
    }
}

impl<Req, Resp, F> IoctlHandler for TypedIoctlHandler<Req, Resp, F>
where
    Req: Codec,
    Resp: Codec,
    F: Fn(Req) -> Result<Resp, NTSTATUS>,
{
    fn on_ioctl(&self, _device: &Device, request: &mut IoRequest) -> Result<(), NTSTATUS> {
        let buff = request.system_buffer();
        let in_len = request.input_buffer_length() as usize;
        let out_len = request.output_buffer_length() as usize;

        if in_len < 12 {
            return Err(STATUS_INVALID_PARAMETER);
        }
        let in_slice = unsafe { core::slice::from_raw_parts(buff as *const u8, in_len) };
        let Ok((_, req)) = moon_wire_codec::decode_frame::<Req>(in_slice) else {
            return Err(STATUS_INVALID_PARAMETER);
        };

        let resp = (self.handler)(req)?;

        let frame = moon_wire_codec::encode_frame(self.opcode, &resp);
        if frame.len() > out_len {
            return Err(STATUS_INVALID_BUFFER_SIZE);
        }
        unsafe {
            core::ptr::copy_nonoverlapping(frame.as_ptr(), buff as *mut u8, frame.len());
        }
        request.complete(Ok(frame.len()));
        Ok(())
    }
}