use crate::{Error, Result};
use std::ffi::{CString, OsString};
use std::os::windows::ffi::OsStrExt;
use std::ptr;
const CP_ACP: u32 = 0;
#[link(name = "kernel32")]
extern "system" {
fn WideCharToMultiByte(
CodePage: u32,
dwFlags: u32,
lpWideCharStr: *const u16,
cchWideChar: i32,
lpMultiByteStr: *mut u8,
cbMultiByte: i32,
lpDefaultChar: *mut u8,
lpUsedDefaultChar: *mut i32,
) -> i32;
fn GetACP() -> u32;
}
#[rustversion::since(1.74)]
mod util {
use std::ffi::{CString, OsStr, OsString};
use std::iter::Map;
use std::slice;
pub fn os_str_iter<F>(s: &OsStr, f: F) -> Map<slice::Iter<u8>, F>
where
F: Fn(&u8) -> u8,
{
s.as_encoded_bytes().iter().map(f)
}
pub unsafe fn ascii_os_string_into_c_string(s: OsString) -> CString {
CString::from_vec_unchecked(s.into_encoded_bytes())
}
pub fn vec_from_os_string(capacity: usize, s: OsString) -> Vec<u8> {
let mut v = s.into_encoded_bytes();
if v.capacity() < capacity {
v.reserve_exact(capacity - v.capacity());
}
v
}
}
#[rustversion::before(1.74)]
mod util {
use std::ffi::{CString, OsStr, OsString};
use std::os::windows::ffi::OsStrExt;
pub fn os_str_iter<F>(s: &OsStr, _: F) -> std::os::windows::ffi::EncodeWide
where
F: Fn(&u8) -> u8,
{
s.encode_wide()
}
pub unsafe fn ascii_os_string_into_c_string(s: OsString) -> CString {
let mut vec = Vec::with_capacity(s.len() + 1);
for c in s.encode_wide() {
vec.push(c as u8);
}
CString::from_vec_unchecked(vec)
}
pub fn vec_from_os_string(capacity: usize, _: OsString) -> Vec<u8> {
Vec::with_capacity(capacity)
}
}
pub fn os_string_into_ansi_c_string(s: OsString, name: &str) -> Result<CString> {
let mut contains_nul = false;
let mut ascii_only = true;
for c in util::os_str_iter(&s, |b| *b) {
if c == 0 {
contains_nul = true;
} else if c > 127 {
ascii_only = false;
}
}
if contains_nul {
return Err(Error::invalid_argument(format!(
"{} cannot contain nul characters",
name
)));
}
if ascii_only {
return Ok(unsafe { util::ascii_os_string_into_c_string(s) });
}
let wide_chars: Vec<u16> = s.as_os_str().encode_wide().collect();
let mut used_default_char = 0;
let len = unsafe {
WideCharToMultiByte(
CP_ACP,
0,
wide_chars.as_ptr(),
wide_chars.len() as i32,
ptr::null_mut(),
0,
ptr::null_mut(),
&mut used_default_char,
)
};
if used_default_char != 0 {
return Err(Error::invalid_argument(format!(
"{} cannot contain characters incompatible with the Windows ANSI code page {}",
name,
unsafe { GetACP() }
)));
}
let mut vec = util::vec_from_os_string(len as usize + 1, s);
unsafe {
let len = WideCharToMultiByte(
CP_ACP,
0,
wide_chars.as_ptr(),
wide_chars.len() as i32,
vec.as_mut_ptr(),
len,
ptr::null_mut(),
ptr::null_mut(),
);
vec.set_len(len as usize);
}
Ok(unsafe { CString::from_vec_unchecked(vec) })
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_os_string_into_ansi_c_string() {
let mut data = Vec::<(&str, &[u8])>::new();
data.push(("Hello", b"Hello"));
match unsafe { GetACP() } {
932 => {
data.push(("こんにちは", b"\x82\xb1\x82\xf1\x82\xc9\x82\xbf\x82\xcd"));
}
1252 => {
data.push(("Grüß Gott", b"Gr\xfc\xdf Gott"));
}
_ => (),
}
for data in data {
assert_eq!(
os_string_into_ansi_c_string(data.0.into(), "dummy").unwrap(),
CString::new(data.1).unwrap()
);
}
os_string_into_ansi_c_string("crab 🦀".into(), "dummy").unwrap_err();
}
}