#![cfg_attr(not(test), no_std)]
extern crate alloc;
use core::slice;
use alloc::vec::Vec;
use core::char::decode_utf16;
use core::fmt;
use core::mem::size_of;
use core::ops::Add;
use windows_sys::core::{PCWSTR, PWSTR};
use windows_sys::Win32::Foundation::UNICODE_STRING;
/// An owned UTF-16 string paired with a `UNICODE_STRING` header that describes it.
///
/// `buffer` owns the UTF-16 code units. `unicode_string` is the header view: its `Buffer`
/// field is set to point into `buffer`, and its `Length` / `MaximumLength` byte counts are
/// derived from `buffer` by `compute_size`.
pub struct OwnedUnicodeString {
    /// Owned UTF-16 code units. May end with NUL code unit(s) that are not counted in
    /// the header's `Length`.
    buffer: Vec<u16>,
    /// Header describing `buffer`. It holds a raw pointer into `buffer`, so it is only
    /// meaningful while `buffer` is alive and has not been reallocated since the pointer
    /// was last set.
    unicode_string: UNICODE_STRING,
}
impl OwnedUnicodeString {
    /// Returns `true` when the last code unit of `buffer` is NUL (`0`).
    ///
    /// An empty buffer is not considered NUL-terminated.
    fn is_null_terminated(&self) -> bool {
        self.buffer.last() == Some(&0)
    }

    /// Appends a NUL code unit unless `buffer` already ends with one.
    ///
    /// `Length` is left unchanged because the terminator is not part of the string's
    /// contents. `MaximumLength` is updated to cover the terminator.
    ///
    /// # Panics
    ///
    /// Panics if the terminated buffer exceeds `u16::MAX` bytes, which a
    /// `UNICODE_STRING` cannot describe.
    fn ensure_is_null_terminated(&mut self) {
        if !self.is_null_terminated() {
            self.buffer.push(0u16);
            self.unicode_string.MaximumLength = Self::byte_count(self.buffer.len());
            // `push` may have reallocated `buffer`; keep `Buffer` pointing at the live
            // allocation instead of the freed one.
            self.unicode_string.Buffer = self.buffer.as_mut_ptr();
        }
    }

    /// Re-synchronises the `UNICODE_STRING` header with `buffer`.
    ///
    /// * `Length` becomes the byte size of `buffer` without its trailing NUL code units.
    /// * `MaximumLength` becomes the byte size of the whole `buffer`.
    /// * `Buffer` is re-pointed at `buffer`'s current allocation, which moves whenever
    ///   `buffer` is grown past its capacity.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is larger than `u16::MAX` bytes, since a `UNICODE_STRING`
    /// cannot describe it (a silent truncation would yield a wrong, inconsistent header).
    fn compute_size(&mut self) {
        let trailing_nuls = self
            .buffer
            .iter()
            .rev()
            .take_while(|&&unit| unit == 0)
            .count();
        let content_units = self.buffer.len() - trailing_nuls;
        self.unicode_string.Length = Self::byte_count(content_units);
        self.unicode_string.MaximumLength = Self::byte_count(self.buffer.len());
        self.unicode_string.Buffer = self.buffer.as_mut_ptr();
    }

    /// Converts a number of UTF-16 code units into a `UNICODE_STRING` byte count.
    ///
    /// # Panics
    ///
    /// Panics if the byte count does not fit in a `u16`.
    fn byte_count(units: usize) -> u16 {
        // `units` is at most a `Vec<u16>` length, so the multiplication cannot overflow.
        let bytes = units * size_of::<u16>();
        assert!(
            bytes <= usize::from(u16::MAX),
            "UNICODE_STRING cannot describe {} bytes (limit is {})",
            bytes,
            u16::MAX
        );
        bytes as u16
    }
}
impl From<Vec<u16>> for OwnedUnicodeString {
    /// Takes ownership of `value` as the backing buffer, without copying it.
    ///
    /// The header's `Buffer` is pointed at the vector's heap allocation. Moving the `Vec`
    /// into the struct does not move that allocation, so the pointer remains valid.
    /// `Length` and `MaximumLength` are then filled in by `compute_size`; any trailing NUL
    /// code units stay in the buffer but are excluded from `Length`.
    fn from(value: Vec<u16>) -> Self {
        let mut buffer = value;
        let header = UNICODE_STRING {
            Length: 0,
            MaximumLength: 0,
            Buffer: buffer.as_mut_ptr(),
        };
        let mut owned = Self {
            buffer,
            unicode_string: header,
        };
        owned.compute_size();
        owned
    }
}
impl From<&str> for OwnedUnicodeString {
    /// Encodes `value` as UTF-16 and wraps the resulting code units.
    ///
    /// No NUL terminator is appended.
    fn from(value: &str) -> Self {
        let utf16_units: Vec<u16> = value.encode_utf16().collect();
        utf16_units.into()
    }
}
impl AsRef<UNICODE_STRING> for OwnedUnicodeString {
    /// Borrows the `UNICODE_STRING` header describing this string.
    ///
    /// The header's `Buffer` points into storage owned by `self`. Any copy of that pointer
    /// must not be used after `self` is dropped or its buffer is grown.
    fn as_ref(&self) -> &UNICODE_STRING {
        let Self { unicode_string, .. } = self;
        unicode_string
    }
}
impl From<&mut OwnedUnicodeString> for PCWSTR {
    /// Returns a NUL-terminated, read-only wide-string pointer to the string's buffer.
    ///
    /// A NUL code unit is appended first if the buffer does not already end with one.
    /// Implemented as `From` (rather than `Into`) so that `(&mut s).into()` keeps working
    /// through the standard blanket impl, and `PCWSTR::from(&mut s)` works as well.
    ///
    /// The returned raw pointer carries no lifetime. It must not be used after the
    /// `OwnedUnicodeString` is dropped or its buffer is grown again.
    fn from(value: &mut OwnedUnicodeString) -> Self {
        value.ensure_is_null_terminated();
        // Appending the terminator may have reallocated `buffer`; keep the header's
        // `Buffer` pointing at the live allocation.
        value.unicode_string.Buffer = value.buffer.as_mut_ptr();
        value.buffer.as_ptr()
    }
}
impl From<&mut OwnedUnicodeString> for PWSTR {
    /// Returns a NUL-terminated, mutable wide-string pointer to the string's buffer.
    ///
    /// A NUL code unit is appended first if the buffer does not already end with one.
    /// Implemented as `From` (rather than `Into`) so that `(&mut s).into()` keeps working
    /// through the standard blanket impl, and `PWSTR::from(&mut s)` works as well.
    ///
    /// The returned raw pointer carries no lifetime. It must not be used after the
    /// `OwnedUnicodeString` is dropped or its buffer is grown again.
    fn from(value: &mut OwnedUnicodeString) -> Self {
        value.ensure_is_null_terminated();
        // Appending the terminator may have reallocated `buffer`; keep the header's
        // `Buffer` pointing at the live allocation and hand out that same pointer.
        let pointer = value.buffer.as_mut_ptr();
        value.unicode_string.Buffer = pointer;
        pointer
    }
}
impl fmt::Display for OwnedUnicodeString {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let utf16_slice = unsafe {
slice::from_raw_parts(
self.unicode_string.Buffer,
(self.unicode_string.Length / size_of::<u16>() as u16) as usize
)
};
for utf16 in decode_utf16(utf16_slice.iter().copied()) {
match utf16 {
Ok(ch) => write!(f, "{}", ch)?,
Err(_) => write!(f, "{}", "�")?,
}
}
Ok(())
}
}
impl Add for OwnedUnicodeString {
type Output = OwnedUnicodeString;
fn add(mut self, rhs: Self) -> Self::Output {
let rhs_slice = unsafe {
slice::from_raw_parts(
rhs.unicode_string.Buffer,
(rhs.unicode_string.Length / size_of::<u16>() as u16) as usize
)
};
self.buffer.extend(rhs_slice);
self.compute_size();
self
}
}
impl Add<&str> for OwnedUnicodeString {
type Output = OwnedUnicodeString;
fn add(self, rhs: &str) -> Self::Output {
let other = OwnedUnicodeString::from(rhs);
self + other
}
}
impl PartialEq for OwnedUnicodeString {
fn eq(&self, other: &Self) -> bool {
let self_slice = &self.buffer[..(self.unicode_string.Length / size_of::<u16>() as u16) as usize];
let other_slice = &other.buffer[..(other.unicode_string.Length / size_of::<u16>() as u16) as usize];
self_slice == other_slice
}
}
#[cfg(test)]
mod test_krnlstring {
    use alloc::{format, vec};
    use super::*;

    /// Display renders the encoded text unchanged.
    #[test]
    fn test_fmt() {
        let text = OwnedUnicodeString::from("Hello, world !");
        assert_eq!(format!("{}", text), "Hello, world !");
    }

    /// Two strings built from the same text compare equal.
    #[test]
    fn test_eq() {
        let left = OwnedUnicodeString::from("Hello, world !");
        let right = OwnedUnicodeString::from("Hello, world !");
        assert!(left == right);
    }

    /// Concatenation works with both `&str` and `OwnedUnicodeString` right-hand sides.
    #[test]
    fn test_add() {
        let greeting = OwnedUnicodeString::from("Hello, world !");
        let suffix = OwnedUnicodeString::from(" !");
        let first = greeting + " Bye";
        assert!(first == OwnedUnicodeString::from("Hello, world ! Bye"));
        let second = first + suffix;
        assert!(second == OwnedUnicodeString::from("Hello, world ! Bye !"));
    }

    /// An empty `&str` and an empty vector produce equal strings.
    #[test]
    fn test_empty_string() {
        let from_str = OwnedUnicodeString::from("");
        let from_vec = OwnedUnicodeString::from(Vec::new());
        assert!(from_str == from_vec);
    }

    /// Non-ASCII text survives the UTF-16 round trip.
    #[test]
    fn test_unicode_characters() {
        let japanese = "こんにちは";
        let text = OwnedUnicodeString::from(japanese);
        assert_eq!(format!("{}", text), japanese);
    }

    /// Both pointer conversions expose the same NUL-terminated buffer.
    #[test]
    fn test_conversion_to_pcwstr_pwstr() {
        let mut text = OwnedUnicodeString::from("Hello, world!");
        let const_ptr: PCWSTR = (&mut text).into();
        let mut_ptr: PWSTR = (&mut text).into();
        // SAFETY: both pointers address the first code unit of `text.buffer`, which is
        // non-empty and still alive here.
        let (first_const, first_mut) = unsafe { (*const_ptr, *mut_ptr) };
        assert_eq!(first_const, first_mut);
        assert!(text.is_null_terminated());
    }

    /// Control characters are concatenated verbatim.
    #[test]
    fn test_add_special_characters() {
        let joined =
            OwnedUnicodeString::from("Line1\n") + OwnedUnicodeString::from("Line2\tEnd");
        assert!(joined == OwnedUnicodeString::from("Line1\nLine2\tEnd"));
    }

    /// `Length` never exceeds `MaximumLength` after a resize.
    #[test]
    fn test_buffer_overflow_protection() {
        let mut text = OwnedUnicodeString::from("Test");
        text.buffer.push(1);
        text.compute_size();
        let header = &text.unicode_string;
        assert!(header.Length <= header.MaximumLength);
    }

    /// Trailing NUL code units are excluded from `Length`.
    #[test]
    fn test_multiple_consecutive_null_characters() {
        let mut text = OwnedUnicodeString::from("Test");
        text.buffer.extend(vec![0, 0, 0]);
        text.compute_size();
        assert_eq!(text.unicode_string.Length, (4 * size_of::<u16>()) as u16);
    }

    /// A 10 000-character string still fits in the 16-bit length fields.
    #[test]
    fn test_large_input_handling() {
        let text = OwnedUnicodeString::from("A".repeat(10000).as_str());
        assert_eq!(text.unicode_string.Length, (10000 * size_of::<u16>()) as u16);
    }

    /// Equality is case-sensitive.
    #[test]
    fn test_equality_case_sensitivity() {
        let upper = OwnedUnicodeString::from("HELLO");
        let lower = OwnedUnicodeString::from("hello");
        assert!(upper != lower);
    }

    /// An unpaired surrogate is displayed as the replacement character.
    #[test]
    fn test_fmt_invalid_utf16_sequence() {
        let mut text = OwnedUnicodeString::from("Hello");
        text.buffer.push(0xD800);
        text.compute_size();
        assert_eq!(format!("{}", text), "Hello�");
    }
}