//! moon-driver-utils 0.1.0
//!
//! Windows kernel utilities: `UNICODE_STRING`, UTF-16 and C-string conversion helpers.
extern crate alloc;

use core::{
    ffi::{c_void, CStr},
    slice,
};

use alloc::{string::String, vec::Vec};
use wdk::println;
use wdk_sys::{
    ntddk::{wcslen, wcsrchr},
    UNICODE_STRING,
};

/// A `UNICODE_STRING` descriptor bundled with the UTF-16 storage it refers to.
///
/// When built by `str_to_unicode_string`, `unicode_string.Buffer` points into the
/// heap allocation owned by `_buffer`. Moving this struct does not move that heap
/// allocation, so the descriptor stays valid for as long as the value is alive.
pub struct SafeUnicodeString {
    /// Counted-string descriptor; its `Buffer` is expected to point into `_buffer`.
    pub unicode_string: UNICODE_STRING,
    /// Backing storage, held only so the descriptor's `Buffer` does not dangle.
    _buffer: Vec<u16>,
}

impl SafeUnicodeString {
    /// Returns a raw mutable pointer to the inner descriptor.
    ///
    /// The pointer is only valid while `self` is alive and has not been moved.
    pub fn as_ptr(&mut self) -> *mut UNICODE_STRING {
        let descriptor = &mut self.unicode_string;
        descriptor as *mut UNICODE_STRING
    }

    /// Borrows the inner descriptor.
    pub fn as_ref(&self) -> &UNICODE_STRING {
        let Self { unicode_string, .. } = self;
        unicode_string
    }
}

/// Encodes `input` as UTF-16 code units.
///
/// No NUL terminator is appended. Despite the name, an owned `Vec` is returned.
pub fn string_to_u16_slice(input: &str) -> Vec<u16> {
    input.encode_utf16().collect()
}

/// Builds a `UNICODE_STRING` descriptor over the UTF-16 data in `s`.
///
/// If the last element of `s` is a NUL, it is excluded from `Length` but still
/// counted in `MaximumLength`.
///
/// The descriptor is not lifetime-checked: its `Buffer` is a raw pointer into `s`.
/// The caller must keep `s` alive for as long as the descriptor is used. Because
/// the pointer was derived from a shared slice, nothing may write through `Buffer`.
///
/// `UNICODE_STRING` lengths are 16-bit byte counts, so at most 32767 UTF-16 units
/// can be described. Longer input is capped at that many units rather than
/// letting the byte counts wrap around to small or inconsistent values.
pub fn u16_slice_to_unicode_string(s: &[u16]) -> UNICODE_STRING {
    // Largest number of UTF-16 units whose byte count fits in a u16 (0xFFFE bytes).
    const MAX_UNITS: usize = (u16::MAX / 2) as usize;

    let capacity_units = s.len();
    let content_units = match s.last() {
        Some(&0) => capacity_units - 1,
        _ => capacity_units,
    };

    // Saturate instead of wrapping; keep Length <= MaximumLength.
    let capacity_units = capacity_units.min(MAX_UNITS);
    let content_units = content_units.min(capacity_units);

    UNICODE_STRING {
        // Both values are <= MAX_UNITS * 2 == 0xFFFE, so these casts are lossless.
        Length: (content_units * 2) as u16,
        MaximumLength: (capacity_units * 2) as u16,
        Buffer: s.as_ptr() as _,
    }
}

/// Decodes UTF-16 code units into an owned `String`.
///
/// If `s` is not valid UTF-16 (e.g. it contains an unpaired surrogate), the
/// message "from utf16 error" is logged and an empty string is returned.
pub fn u16_slice_to_string(s: &[u16]) -> String {
    String::from_utf16(s).unwrap_or_else(|_| {
        println!("from utf16 error");
        String::new()
    })
}

pub fn copy_unicode_string(dst: &mut UNICODE_STRING, src: &UNICODE_STRING) {
    dst.Buffer = src.Buffer;
    dst.Length = src.Length;
    dst.MaximumLength = src.MaximumLength;
}

pub fn get_file_name_from_path(path: &UNICODE_STRING, file_name: &mut UNICODE_STRING) {
    unsafe {
        let last_back_slash = wcsrchr(path.Buffer, 0x5c) as *mut c_void;

        if !last_back_slash.is_null() {
            file_name.Buffer = last_back_slash.add(2) as _;
            file_name.Length = (wcslen(last_back_slash.add(1) as _) * 2) as _;
            file_name.MaximumLength = file_name.Length + 2;
        } else {
            copy_unicode_string(file_name, path);
        }
    }
}

/// Converts `s` into an owned, NUL-terminated UTF-16 buffer plus a
/// `UNICODE_STRING` describing it.
///
/// `Length` excludes the terminating NUL and `MaximumLength` includes it. The
/// returned `SafeUnicodeString` owns the buffer, so the descriptor stays valid as
/// long as that value is alive.
///
/// `UNICODE_STRING` byte counts are 16-bit, so at most 32767 UTF-16 units can be
/// described. For longer input both counts are capped at 0xFFFE bytes: the
/// descriptor covers a truncated prefix instead of wrapping to bogus lengths.
pub fn str_to_unicode_string(s: &str) -> SafeUnicodeString {
    // Largest number of UTF-16 units whose byte count fits in a u16 (0xFFFE bytes).
    const MAX_UNITS: usize = (u16::MAX / 2) as usize;

    // Includes the NUL terminator, so `wide.len() >= 1`.
    let mut wide: Vec<u16> = s.encode_utf16().chain(Some(0)).collect();

    let capacity_units = wide.len().min(MAX_UNITS); // includes NUL when it fits
    let length_units = (wide.len() - 1).min(capacity_units); // excludes NUL

    let unicode_string = UNICODE_STRING {
        // Both values are <= MAX_UNITS * 2 == 0xFFFE, so these casts are lossless.
        Length: (length_units * 2) as u16,
        MaximumLength: (capacity_units * 2) as u16,
        // Moving `wide` below does not move its heap allocation, so this stays valid.
        Buffer: wide.as_mut_ptr(),
    };

    SafeUnicodeString {
        unicode_string,
        _buffer: wide, // keeps the buffer behind `Buffer` alive
    }
}

/// Decodes the UTF-16 text described by `s` into an owned `String`.
///
/// Only `s.Length` bytes are read (an odd trailing byte is ignored), so the
/// buffer need not be NUL-terminated. A descriptor with a null `Buffer` or zero
/// length yields an empty string. Decoding itself is delegated to
/// `u16_slice_to_string`.
///
/// `s.Buffer`, when non-null, must be readable for `s.Length` bytes.
pub fn unicode_string_to_string(s: &UNICODE_STRING) -> String {
    let unit_count = usize::from(s.Length) / 2;
    if s.Buffer.is_null() || unit_count == 0 {
        // `slice::from_raw_parts` requires a non-null pointer even for an empty
        // slice, and an empty UNICODE_STRING may carry a null `Buffer`.
        return String::new();
    }

    // SAFETY: `Buffer` is non-null and, per the UNICODE_STRING contract, readable
    // for `Length` bytes (>= `unit_count` UTF-16 units).
    let units = unsafe { slice::from_raw_parts(s.Buffer, unit_count) };
    u16_slice_to_string(units)
}

/// Copies a NUL-terminated byte string into an owned `String`.
///
/// Invalid UTF-8 sequences are replaced with U+FFFD. A null pointer yields an
/// empty string instead of being dereferenced.
///
/// This function is not marked `unsafe`, but a non-null `cstr_ptr` must point to
/// a readable, NUL-terminated byte sequence. Otherwise the behavior is undefined.
pub fn cstr_to_rust_str(cstr_ptr: *mut u8) -> String {
    if cstr_ptr.is_null() {
        return String::new();
    }

    // SAFETY: `cstr_ptr` is non-null. The caller guarantees it points to a
    // NUL-terminated byte string that stays valid for the duration of this call.
    let c_str = unsafe { CStr::from_ptr(cstr_ptr as _) };
    c_str.to_string_lossy().into_owned()
}