use std::{cell::Cell, env};
use encoding_rs::{Encoding, WINDOWS_1252};
use tracing::instrument;
use crate::{EncodingConversionError, NativeEncodingError, UnknownEncodingError};
// Both settings live in thread-local storage: every thread has its own copy,
// so changing an encoding on one thread is not observed by any other thread.
thread_local! {
/// Encoding used for NWN text on the current thread; starts as Windows-1252.
static NWNRS_ENCODING: Cell<&'static Encoding> = Cell::new(WINDOWS_1252);
/// Native encoding for the current thread. `None` means it has neither been
/// detected nor set explicitly yet (or was cleared).
static NATIVE_ENCODING: Cell<Option<&'static Encoding>> = const { Cell::new(None) };
}
pub fn get_nwnrs_encoding() -> &'static Encoding {
NWNRS_ENCODING.with(Cell::get)
}
/// Returns the name of the NWN text encoding selected on the calling thread.
#[must_use]
pub fn get_nwnrs_encoding_name() -> &'static str {
    let current = get_nwnrs_encoding();
    current.name()
}
/// Selects the NWN text encoding for the calling thread by label.
///
/// # Errors
///
/// Returns [`UnknownEncodingError`] when `Encoding::for_label` does not
/// recognise `label`; the previously selected encoding is left untouched.
#[instrument(level = "debug", skip_all, err, fields(label = %label))]
pub fn set_nwnrs_encoding(label: &str) -> Result<(), UnknownEncodingError> {
    let Some(resolved) = Encoding::for_label(label.as_bytes()) else {
        return Err(UnknownEncodingError::new(label));
    };
    NWNRS_ENCODING.set(resolved);
    Ok(())
}
#[instrument(level = "debug", err)]
pub fn get_native_encoding() -> Result<&'static Encoding, NativeEncodingError> {
if let Some(encoding) = NATIVE_ENCODING.with(Cell::get) {
return Ok(encoding);
}
let encoding = detect_system_native_encoding()?;
NATIVE_ENCODING.with(|slot| slot.set(Some(encoding)));
Ok(encoding)
}
#[instrument(level = "debug", err)]
pub fn get_native_encoding_name() -> Result<&'static str, NativeEncodingError> {
Ok(get_native_encoding()?.name())
}
/// Overrides the native encoding for the calling thread by label.
///
/// After a successful call, [`get_native_encoding`] returns this encoding
/// instead of probing the system.
///
/// # Errors
///
/// Returns [`UnknownEncodingError`] when `Encoding::for_label` does not
/// recognise `label`; the stored value is left untouched.
#[instrument(level = "debug", skip_all, err, fields(label = %label))]
pub fn set_native_encoding(label: &str) -> Result<(), UnknownEncodingError> {
    let Some(resolved) = Encoding::for_label(label.as_bytes()) else {
        return Err(UnknownEncodingError::new(label));
    };
    NATIVE_ENCODING.set(Some(resolved));
    Ok(())
}
/// Forgets the native encoding stored for the calling thread.
///
/// The next [`get_native_encoding`] call on this thread runs detection again.
pub fn clear_native_encoding() {
    NATIVE_ENCODING.set(None);
}
/// Probes the operating system for its native text encoding.
///
/// This always performs a fresh probe: it neither reads nor updates the
/// per-thread value used by `get_native_encoding`.
///
/// # Errors
///
/// Returns the `NativeEncodingError` produced by the platform-specific
/// detector when no encoding can be determined.
#[instrument(level = "debug", err)]
pub fn detect_system_native_encoding() -> Result<&'static Encoding, NativeEncodingError> {
// Exactly one of the two cfg-gated blocks survives compilation and becomes
// the function's tail expression.
#[cfg(windows)]
{
detect_windows_native_encoding()
}
#[cfg(not(windows))]
{
detect_unix_native_encoding()
}
}
/// Encodes `value` into bytes using the calling thread's NWN encoding.
///
/// # Errors
///
/// Returns the [`EncodingConversionError`] reported by `encode_with`.
#[instrument(level = "debug", skip_all, err, fields(input_len = value.len()))]
pub fn to_nwnrs_encoding(value: &str) -> Result<Vec<u8>, EncodingConversionError> {
    let target = get_nwnrs_encoding();
    encode_with(target, value, "encode text for NWN")
}
/// Decodes `bytes` into a `String` using the calling thread's NWN encoding.
///
/// # Errors
///
/// Returns the [`EncodingConversionError`] reported by `decode_with`.
#[instrument(level = "debug", skip_all, err, fields(input_len = bytes.len()))]
pub fn from_nwnrs_encoding(bytes: &[u8]) -> Result<String, EncodingConversionError> {
    let source = get_nwnrs_encoding();
    decode_with(source, bytes, "decode text from NWN")
}
#[instrument(level = "debug", skip_all, err, fields(input_len = value.len()))]
pub fn to_native_encoding(value: &str) -> Result<Vec<u8>, EncodingConversionError> {
let encoding = get_native_encoding().map_err(|error| {
EncodingConversionError::new(error.to_string(), "encode text for native output")
})?;
encode_with(encoding, value, "encode text for native output")
}
#[instrument(level = "debug", skip_all, err, fields(input_len = bytes.len()))]
pub fn from_native_encoding(bytes: &[u8]) -> Result<String, EncodingConversionError> {
let encoding = get_native_encoding().map_err(|error| {
EncodingConversionError::new(error.to_string(), "decode text from native input")
})?;
decode_with(encoding, bytes, "decode text from native input")
}
/// Encodes `value` into bytes in exactly `encoding`.
///
/// `operation` is a short description of what the caller was doing; it is
/// only used to build the error.
///
/// # Errors
///
/// Returns [`EncodingConversionError`] (carrying `encoding`'s name and
/// `operation`) when:
///
/// * `value` contains characters the encoding cannot represent, or
/// * the encoder did not actually produce `encoding`.
pub(crate) fn encode_with(
    encoding: &'static Encoding,
    value: &str,
    operation: &'static str,
) -> Result<Vec<u8>, EncodingConversionError> {
    let (encoded, used, had_errors) = encoding.encode(value);
    // `Encoding::encode` has no encoder for UTF-16LE, UTF-16BE or the
    // replacement encoding and quietly emits UTF-8 instead, reporting that
    // through `used`. Handing those bytes back as if they were `encoding`
    // would be silent corruption (they would not survive `decode_with` with
    // the same encoding), so treat the substitution as a failure.
    if had_errors || used != encoding {
        return Err(EncodingConversionError::new(encoding.name(), operation));
    }
    Ok(encoded.into_owned())
}
/// Decodes `bytes` into a `String`, interpreting them strictly as `encoding`.
///
/// A byte order mark belonging to `encoding` itself is stripped; a BOM of any
/// other encoding is treated as ordinary data.
///
/// `operation` is a short description of what the caller was doing; it is
/// only used to build the error.
///
/// # Errors
///
/// Returns [`EncodingConversionError`] (carrying `encoding`'s name and
/// `operation`) when `bytes` contains sequences that are malformed in
/// `encoding`.
pub(crate) fn decode_with(
    encoding: &'static Encoding,
    bytes: &[u8],
    operation: &'static str,
) -> Result<String, EncodingConversionError> {
    // `Encoding::decode` performs BOM sniffing: input that merely starts with
    // the bytes of a UTF-8 or UTF-16 BOM (e.g. Windows-1252 text beginning
    // with "ÿþ" or "ï»¿") would be decoded as that other encoding rather than
    // the requested one. `decode_with_bom_removal` never switches encodings.
    let (decoded, had_errors) = encoding.decode_with_bom_removal(bytes);
    if had_errors {
        return Err(EncodingConversionError::new(encoding.name(), operation));
    }
    Ok(decoded.into_owned())
}
/// Derives the native encoding from the locale environment variables.
///
/// `LC_ALL`, `LC_CTYPE` and `LANG` are consulted in that order; the first one
/// whose value yields an encoding through `parse_locale_encoding` wins.
/// Variables that are unset, not valid Unicode, or carry no recognisable
/// encoding are skipped.
///
/// # Errors
///
/// Returns [`NativeEncodingError`] when none of the variables yields an
/// encoding.
#[cfg(not(windows))]
fn detect_unix_native_encoding() -> Result<&'static Encoding, NativeEncodingError> {
    ["LC_ALL", "LC_CTYPE", "LANG"]
        .into_iter()
        .find_map(|key| {
            env::var(key)
                .ok()
                .and_then(|value| parse_locale_encoding(&value))
        })
        .ok_or_else(|| {
            NativeEncodingError::new(
                "unable to determine native encoding from LC_ALL, LC_CTYPE, or LANG",
            )
        })
}
/// Derives the native encoding from the active code page reported by `chcp`.
///
/// Falls back to code page 1252 when `chcp` cannot be run, its output cannot
/// be parsed, or the reported code page has no known encoding.
///
/// NOTE(review): `chcp` reports the *console* code page, which may differ from
/// the system ANSI code page — confirm this is the intended source.
///
/// # Errors
///
/// Returns [`NativeEncodingError`] only if even code page 1252 cannot be
/// mapped to an encoding.
#[cfg(windows)]
fn detect_windows_native_encoding() -> Result<&'static Encoding, NativeEncodingError> {
    // The tool is `chcp.com`; `Command` only lets the extension be omitted for
    // `.exe` files, so the bare name "chcp" would not be found.
    if let Ok(output) = std::process::Command::new("chcp.com").output() {
        // The surrounding message is localized and emitted in the console
        // code page, so decode leniently and do not match on its wording.
        let text = String::from_utf8_lossy(&output.stdout);
        // Take the number following the last ':' on a line, tolerating the
        // trailing '.' that some localizations append ("...: 850.").
        let code_page = text.lines().find_map(|line| {
            let (_, tail) = line.rsplit_once(':')?;
            tail.trim()
                .trim_end_matches('.')
                .trim_end()
                .parse::<u16>()
                .ok()
        });
        if let Some(encoding) = code_page.and_then(codepage::to_encoding) {
            return Ok(encoding);
        }
    }
    codepage::to_encoding(1252)
        .ok_or_else(|| NativeEncodingError::new("unable to determine Windows native encoding"))
}
/// Extracts the encoding named by a locale string such as
/// `language_TERRITORY.codeset@modifier`.
///
/// Anything from the first `@` onwards is dropped. If a `.` remains, the text
/// after the first `.` is used as the encoding label; otherwise the whole
/// remaining string is tried as a label. Returns `None` for blank input or
/// when `Encoding::for_label` does not recognise the label.
pub(crate) fn parse_locale_encoding(locale: &str) -> Option<&'static Encoding> {
    let locale = locale.trim();
    if locale.is_empty() {
        return None;
    }

    // Strip the optional "@modifier" suffix.
    let base = locale.split_once('@').map_or(locale, |(head, _)| head);
    // Keep only the codeset part when one is present.
    let label = base.split_once('.').map_or(base, |(_, codeset)| codeset);

    Encoding::for_label(label.trim().as_bytes())
}