use super::{JvmError, char_encoding_generic::*};
use crate::windows_sys;
use std::{
borrow::Cow,
convert::TryInto,
ffi::{CStr, c_int, c_uint},
io,
mem::MaybeUninit,
ptr,
};
/// Win32 length/size parameter type (`int` in the Win32 function signatures).
type WSize = c_int;
/// Win32 code page identifier type (`UINT` in the Win32 function signatures).
type WCodepage = c_uint;
/// Maximum accepted input length for `str_to_cstr_win32`, in UTF-8 bytes (1 MiB).
///
/// The UTF-16 form of a string never has more code units than its UTF-8 form has bytes, so this
/// limit also keeps the UTF-16 length comfortably within `WSize`'s range.
const MAX_INPUT_LEN: usize = 1048576;
pub(super) fn str_to_cstr_win32(
s: Cow<str>,
needed_codepage: WCodepage,
) -> Result<Cow<'static, CStr>, JvmError> {
if s.len() > MAX_INPUT_LEN {
return Err(JvmError::OptStringTooLong {
opt_string: s.into_owned(),
});
}
fn convert_error(s: Cow<str>) -> JvmError {
JvmError::OptStringTranscodeFailure {
opt_string: s.into_owned(),
error: io::Error::last_os_error(),
}
}
let s_utf16: Vec<u16> = s.encode_utf16().collect();
let s_utf16_len: WSize = s_utf16
.len()
.try_into()
.expect("UTF-16 form of input string is too long");
let conversion_flags = match needed_codepage {
42
| 50220
| 50221
| 50222
| 50225
| 50227
| 50229
| 54936
| 57002..=57011
| 65000
| 65001 => 0,
_ => windows_sys::WC_COMPOSITECHECK | windows_sys::WC_NO_BEST_FIT_CHARS,
};
let mut is_non_representable: Option<MaybeUninit<_>> = match needed_codepage {
windows_sys::CP_UTF7 | windows_sys::CP_UTF8 => None,
_ => Some(MaybeUninit::uninit()),
};
let required_buffer_space = unsafe {
windows_sys::WideCharToMultiByte(
needed_codepage,
conversion_flags,
s_utf16.as_ptr(),
s_utf16_len,
ptr::null_mut(),
0,
ptr::null(),
match &mut is_non_representable {
Some(x) => x.as_mut_ptr(),
None => ptr::null_mut(),
},
)
};
if required_buffer_space == 0 {
drop(s_utf16);
return Err(convert_error(s));
}
if let Some(is_non_representable) = is_non_representable {
let is_non_representable = unsafe { is_non_representable.assume_init() };
if is_non_representable != 0 {
drop(s_utf16);
return Err(JvmError::OptStringNotRepresentable {
opt_string: s.into_owned(),
});
}
}
let required_buffer_space_usize: usize = required_buffer_space as _;
let required_buffer_space_usize_with_nul: usize = required_buffer_space_usize + 1;
let mut output = Vec::<u8>::with_capacity(required_buffer_space_usize_with_nul);
let used_buffer_space = unsafe {
windows_sys::WideCharToMultiByte(
needed_codepage,
conversion_flags,
s_utf16.as_ptr(),
s_utf16_len,
output.as_mut_ptr(),
required_buffer_space,
ptr::null(),
ptr::null_mut(),
)
};
drop(s_utf16);
if used_buffer_space == 0 {
drop(output);
return Err(convert_error(s));
}
let used_buffer_space_usize: usize = used_buffer_space as usize;
unsafe {
output.set_len(used_buffer_space_usize);
}
unsafe { bytes_to_cstr(Cow::Owned(output), Some(s)) }
}
/// Transcodes `s` into the process's active ANSI code page (as reported by `GetACP`).
///
/// When that code page is UTF-8, the string is already in the right encoding, so it is handed to
/// `utf8_to_cstr` instead of being round-tripped through UTF-16. Any other code page goes through
/// `str_to_cstr_win32`.
pub(super) fn str_to_cstr_win32_default_codepage(s: Cow<str>) -> Result<Cow<CStr>, JvmError> {
    // SAFETY: `GetACP` takes no arguments.
    let ansi_codepage = unsafe { windows_sys::GetACP() };

    match ansi_codepage {
        windows_sys::CP_UTF8 => utf8_to_cstr(s),
        other_codepage => str_to_cstr_win32(s, other_codepage),
    }
}
/// Test helper: decodes `codepage_string`, encoded in the Windows code page `codepage`, into a
/// Rust [`String`] using `MultiByteToWideChar`.
///
/// `max_expected_utf16_len` is the size of the UTF-16 output buffer, in code units. Decoding
/// fails if the text does not fit.
///
/// An empty input decodes to an empty string without calling into Win32, because
/// `MultiByteToWideChar` rejects a zero-length input.
///
/// # Errors
///
/// * An [`io::ErrorKind::InvalidInput`] error if the input is non-empty but
///   `max_expected_utf16_len` is 0.
/// * The OS error if `MultiByteToWideChar` fails.
///
/// # Panics
///
/// Panics in any of these cases:
///
/// * the input length does not fit in a `WSize`;
/// * `max_expected_utf16_len` is negative;
/// * the decoded UTF-16 is invalid.
#[cfg(test)]
fn codepage_to_string_win32(
    codepage_string: impl AsRef<[u8]>,
    codepage: WCodepage,
    max_expected_utf16_len: WSize,
) -> io::Result<String> {
    let codepage_string_slice = codepage_string.as_ref();
    if codepage_string_slice.is_empty() {
        return Ok(String::new());
    }

    let codepage_string_slice_len: WSize = codepage_string_slice
        .len()
        .try_into()
        .expect("`codepage_string`'s length is too large to transcode with Win32");
    let buf_capacity: usize = max_expected_utf16_len
        .try_into()
        .expect("expected_utf16_len is negative or exceeds address space");

    // With an output size of 0, `MultiByteToWideChar` only reports the required size and
    // writes nothing. `set_len` below would then expose memory that was never allocated.
    if buf_capacity == 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "`max_expected_utf16_len` must be positive for non-empty input",
        ));
    }

    let mut buf = Vec::<u16>::with_capacity(buf_capacity);

    // SAFETY: the input pointer and length describe `codepage_string_slice`. `buf` has
    // capacity for at least `max_expected_utf16_len` (> 0) units, which is the size passed as
    // `cchWideChar`.
    let utf16_units_transcoded = unsafe {
        windows_sys::MultiByteToWideChar(
            codepage,
            0,
            codepage_string_slice.as_ptr() as *const _,
            codepage_string_slice_len,
            buf.as_mut_ptr(),
            max_expected_utf16_len,
        )
    };
    if utf16_units_transcoded == 0 {
        return Err(io::Error::last_os_error());
    }

    // A successful call returns a positive count, so this cast is lossless.
    let utf16_units_transcoded = utf16_units_transcoded as usize;
    assert!(
        utf16_units_transcoded <= buf_capacity,
        "`MultiByteToWideChar` reported more UTF-16 units than the buffer holds"
    );
    // SAFETY: on success, the function initialized exactly `utf16_units_transcoded` units.
    // That count fits within `buf`'s capacity, as checked above.
    unsafe {
        buf.set_len(utf16_units_transcoded);
    }

    let string =
        String::from_utf16(buf.as_slice()).expect("`MultiByteToWideChar` generated invalid UTF-16");
    Ok(string)
}
/// Basic transcoding checks for `str_to_cstr_win32`. It covers UTF-8 output with and without an
/// existing NUL terminator, a Windows-1252 rejection of an unrepresentable character, and a
/// successful Windows-1252 conversion.
#[test]
fn test() {
    use assert_matches::assert_matches;

    // UTF-8: the emoji becomes four bytes, a NUL terminator is added, and the result is owned.
    let utf8_plain = str_to_cstr_win32("Hello, world 😎".into(), windows_sys::CP_UTF8).unwrap();
    assert_eq!(
        utf8_plain.to_bytes_with_nul(),
        b"Hello, world \xf0\x9f\x98\x8e\0"
    );
    assert_matches!(utf8_plain, Cow::Owned(_));

    // UTF-8 input that already ends in NUL: the output still has exactly one terminator.
    let utf8_terminated =
        str_to_cstr_win32("Hello, world 😎\0".into(), windows_sys::CP_UTF8).unwrap();
    assert_eq!(
        utf8_terminated.to_bytes_with_nul(),
        b"Hello, world \xf0\x9f\x98\x8e\0"
    );

    // Windows-1252 cannot encode the emoji. The error hands back the original string.
    let rejection = str_to_cstr_win32("Hello, world 😎".into(), 1252).unwrap_err();
    let rejected_text = assert_matches!(rejection, JvmError::OptStringNotRepresentable { opt_string } => opt_string);
    assert_eq!(rejected_text, "Hello, world 😎");

    // Windows-1252 encodes U+2122 TRADE MARK SIGN as the single byte 0x99.
    let cp1252 = str_to_cstr_win32("Hello, world™".into(), 1252).unwrap();
    assert_eq!(cp1252.to_bytes_with_nul(), b"Hello, world\x99\0");
    assert_matches!(cp1252, Cow::Owned(_));
}
/// Exercises the `MAX_INPUT_LEN` limit of `str_to_cstr_win32`:
///
/// * A string one byte over the limit must be rejected with `OptStringTooLong`.
/// * Strings exactly at the limit must transcode correctly to UTF-8, to UTF-7 (checked by
///   decoding back with `codepage_to_string_win32`), and to Windows-1252.
#[test]
fn test_overflow() {
    use assert_matches::assert_matches;

    // If `error` carries an `opt_string`, asserts that it equals `expected_opt_string` (i.e. the
    // input survived being moved into the error). It then replaces it with an empty `String`.
    // NOTE(review): clearing presumably keeps a later panic message from printing a ~1 MiB
    // string. Confirm.
    #[track_caller]
    fn check_and_clear_error_opt_string(expected_opt_string: &str, error: &mut JvmError) {
        if let Some(actual_opt_string) = error.opt_string_mut() {
            if actual_opt_string != expected_opt_string {
                panic!("opt_string was mangled in moving it to an error");
            }
            *actual_opt_string = String::new();
        }
    }

    // Unwraps a successful transcoding result. On failure, it validates and clears the error's
    // `opt_string`, then panics with the error's `Display` output.
    #[track_caller]
    fn expect_success(
        expected_opt_string: &str,
        result: Result<Cow<'static, CStr>, JvmError>,
    ) -> Cow<'static, CStr> {
        match result {
            Ok(ok) => ok,
            Err(mut error) => {
                check_and_clear_error_opt_string(expected_opt_string, &mut error);
                panic!("unexpected transcoding failure: {}", error)
            }
        }
    }

    // Like `expect_success`, but also asserts that the output bytes (without the NUL
    // terminator) equal the UTF-8 bytes of `expected_opt_string`. It is used below only with
    // CP_UTF8, where the output should be the input itself.
    #[track_caller]
    fn expect_successful_roundtrip(
        expected_opt_string: &str,
        result: Result<Cow<'static, CStr>, JvmError>,
    ) -> Cow<'static, CStr> {
        let string = expect_success(expected_opt_string, result);
        assert!(
            expected_opt_string.as_bytes() == string.to_bytes(),
            "opt_string was transcoded successfully but mangled"
        );
        string
    }

    // Asserts that `result` is a `JvmError::OptStringTooLong` whose `opt_string` (if any) is
    // `expected_opt_string`. If transcoding unexpectedly succeeded, it first checks whether the
    // output was also mangled, then panics.
    #[track_caller]
    fn expect_opt_string_too_long(
        expected_opt_string: &str,
        result: Result<Cow<'static, CStr>, JvmError>,
    ) {
        let mut error = match result {
            Err(err) => err,
            Ok(ok) => {
                assert!(
                    expected_opt_string.as_bytes() == ok.to_bytes(),
                    "transcoding unexpectedly succeeded and resulted in mangled output"
                );
                panic!("transcoding unexpectedly succeeded")
            }
        };
        check_and_clear_error_opt_string(expected_opt_string, &mut error);
        assert_matches!(error, JvmError::OptStringTooLong { .. });
    }

    // ASCII input of `MAX_INPUT_LEN + 1` bytes is rejected. Removing one byte puts it exactly
    // at the limit, which must succeed and round-trip through UTF-8 unchanged.
    {
        let string = vec![b'H'; MAX_INPUT_LEN.checked_add(1).unwrap()];
        let mut string = String::from_utf8(string).unwrap();
        expect_opt_string_too_long(
            &string,
            str_to_cstr_win32(string.as_str().into(), windows_sys::CP_UTF8),
        );
        assert_eq!(string.pop(), Some('H'));
        expect_successful_roundtrip(
            &string,
            str_to_cstr_win32(string.as_str().into(), windows_sys::CP_UTF8),
        );
    }

    // Input made entirely of U+07FF, whose UTF-8 form is the two bytes DF BF. `u16::from_be`
    // makes each `u16`'s in-memory byte order DF BF on any host. The cast below therefore yields
    // exactly `MAX_INPUT_LEN` bytes of valid UTF-8: `MAX_INPUT_LEN / 2` characters, each a
    // single UTF-16 unit.
    {
        let string_byte_pairs = vec![u16::from_be(0xdfbf); MAX_INPUT_LEN / 2];
        let string: &str =
            std::str::from_utf8(bytemuck::cast_slice(string_byte_pairs.as_slice())).unwrap();
        expect_successful_roundtrip(
            string,
            str_to_cstr_win32(string.into(), windows_sys::CP_UTF8),
        );
        // UTF-7 output differs from the input bytes, so decode it back and compare instead.
        // Each character is one UTF-16 unit, so `string.len() / 2` is the exact decoded length.
        {
            let result = expect_success(
                string,
                str_to_cstr_win32(string.into(), windows_sys::CP_UTF7),
            );
            let result: String = codepage_to_string_win32(
                result.to_bytes(),
                windows_sys::CP_UTF7,
                (string.len() / 2).try_into().unwrap(),
            )
            .unwrap();
            assert!(result == string, "didn't roundtrip via UTF-7");
        }
    }

    // Input made entirely of U+00AE REGISTERED SIGN (UTF-8 C2 AE, built with the same
    // `from_be` trick). Windows-1252 encodes it as the single byte AE, so every output byte
    // must be AE.
    {
        let string_byte_pairs = vec![u16::from_be(0xc2ae); MAX_INPUT_LEN / 2];
        let string: &str =
            std::str::from_utf8(bytemuck::cast_slice(string_byte_pairs.as_slice())).unwrap();
        let result = expect_success(string, str_to_cstr_win32(string.into(), 1252));
        assert!(
            result.to_bytes().iter().all(|byte| *byte == 0xae),
            "string didn't transcode to Windows-1252 properly"
        );
    }
}