use std::ffi::OsStr;
use std::mem::MaybeUninit;
use std::os::windows::ffi::OsStrExt;
/// A NUL-terminated UTF-16 copy of an `OsStr`, intended for passing to WinAPI.
///
/// Short strings are kept inline in `stack`; longer ones are kept in `heap`.
pub struct WindowsString<const STACK_BUFFER_SIZE: usize> {
    /// Inline storage. Its contents are only meaningful while `heap` is `None`.
    stack: MaybeUninit<[u16; STACK_BUFFER_SIZE]>,
    /// Heap storage, used when the string was not stored inline.
    heap: Option<Vec<u16>>,
}
impl<const STACK_BUFFER_SIZE: usize> WindowsString<STACK_BUFFER_SIZE> {
    /// Encodes `s` as a NUL-terminated UTF-16 string for passing to WinAPI.
    ///
    /// Short strings are stored inline (at most `STACK_BUFFER_SIZE - 1` units
    /// plus the terminator); anything longer is stored in a heap buffer.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`std::io::ErrorKind::InvalidInput`] if `s`
    /// contains an interior NUL. When the `skip_null_check` feature is enabled
    /// this check is compiled out and interior NULs are copied through.
    pub fn new<S>(s: S) -> std::io::Result<Self>
    where
        S: AsRef<OsStr>,
    {
        let mut rv = Self {
            heap: None,
            stack: MaybeUninit::uninit(),
        };
        rv.convert_and_store(s.as_ref())?;
        Ok(rv)
    }

    /// Returns a pointer to the first unit of the NUL-terminated UTF-16 string.
    ///
    /// The raw pointer carries no lifetime: the caller must not use it after
    /// `self` is dropped *or moved*, because for short strings it points into
    /// `self`'s inline buffer, which moves together with `self`.
    pub fn as_wide(&self) -> *const u16 {
        match &self.heap {
            Some(buffer) => buffer.as_ptr(),
            None => self.stack.as_ptr().cast::<u16>(),
        }
    }

    /// Stores `s` inline when it may fit, otherwise on the heap.
    fn convert_and_store(&mut self, s: &OsStr) -> std::io::Result<()> {
        // Cheap pre-check using the length of the `OsStr`'s internal encoding.
        // It is only a heuristic: `use_stack` itself falls back to the heap if
        // the UTF-16 form (plus terminator) turns out not to fit.
        // (`>=` is the overflow-free form of `len + 1 > STACK_BUFFER_SIZE`.)
        if s.len() >= STACK_BUFFER_SIZE {
            return self.use_heap(s);
        }
        self.use_stack(s)
    }

    /// Encodes `s` into a freshly allocated heap buffer, followed by a NUL.
    fn use_heap(&mut self, s: &OsStr) -> std::io::Result<()> {
        // `s.len() + 1` is the expected upper bound; `push` simply grows the
        // buffer if it is ever exceeded, so no retry loop or unsafe is needed.
        let mut buffer: Vec<u16> = Vec::with_capacity(s.len() + 1);
        for c in s.encode_wide() {
            #[cfg(not(feature = "skip_null_check"))]
            {
                if c == 0 {
                    return Err(Self::no_nuls());
                }
            }
            buffer.push(c);
        }
        buffer.push(0);
        self.heap = Some(buffer);
        Ok(())
    }

    /// Encodes `s` into the inline buffer, falling back to the heap when the
    /// string plus its terminator does not fit.
    fn use_stack(&mut self, s: &OsStr) -> std::io::Result<()> {
        if self.try_fill_stack(s)? {
            Ok(())
        } else {
            self.use_heap(s)
        }
    }

    /// Writes `s` and a NUL terminator into the inline buffer.
    ///
    /// Returns `Ok(false)` (leaving `heap` untouched) if more than
    /// `STACK_BUFFER_SIZE` units would be needed.
    fn try_fill_stack(&mut self, s: &OsStr) -> std::io::Result<bool> {
        // SAFETY: `MaybeUninit<T>` has the same layout as `T`, so
        // `MaybeUninit<[u16; N]>` and `[MaybeUninit<u16>; N]` share a layout.
        // A `MaybeUninit<u16>` may hold uninitialized contents, so this view
        // is valid and each slot can be written without being read first.
        let slots: &mut [MaybeUninit<u16>; STACK_BUFFER_SIZE] = unsafe {
            &mut *self
                .stack
                .as_mut_ptr()
                .cast::<[MaybeUninit<u16>; STACK_BUFFER_SIZE]>()
        };
        let mut slots = slots.iter_mut();
        for c in s.encode_wide() {
            #[cfg(not(feature = "skip_null_check"))]
            {
                if c == 0 {
                    return Err(Self::no_nuls());
                }
            }
            match slots.next() {
                Some(slot) => {
                    slot.write(c);
                }
                None => return Ok(false),
            }
        }
        // One slot must remain for the terminator.
        match slots.next() {
            Some(slot) => {
                slot.write(0);
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Builds the error returned when the input contains an interior NUL.
    #[cfg(not(feature = "skip_null_check"))]
    fn no_nuls() -> std::io::Error {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "strings passed to WinAPI cannot contain NULs",
        )
    }
}