use std::{
    convert::{Infallible, TryFrom, TryInto},
    fmt::{Debug, Display},
    io,
};
use utfx::U16CString;
use windows::{
    core::PCWSTR,
    Win32::{
        Foundation::ERROR_MORE_DATA,
        System::Registry::{
            RegDeleteValueW, RegQueryValueExW, RegSetValueExW, HKEY, REG_BINARY, REG_DWORD,
            REG_DWORD_BIG_ENDIAN, REG_EXPAND_SZ, REG_FULL_RESOURCE_DESCRIPTOR, REG_LINK,
            REG_MULTI_SZ, REG_NONE, REG_QWORD, REG_RESOURCE_LIST, REG_RESOURCE_REQUIREMENTS_LIST,
            REG_SZ, REG_VALUE_TYPE,
        },
    },
};
/// Errors produced while reading, writing or deleting registry values.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
/// A registry call failed with an OS error whose `io::ErrorKind` is `NotFound`
/// (see `Error::from_code`); holds the value name and the underlying OS error.
#[error("Data not found for value with name '{0}'")]
NotFound(String, #[source] io::Error),
/// A registry call failed with an OS error whose `io::ErrorKind` is `PermissionDenied`.
#[error("Permission denied for given value name: '{0}'")]
PermissionDenied(String, #[source] io::Error),
/// The raw registry type number is not one of the `REG_*` types this module knows.
#[error("Unhandled type: 0x{0:x}")]
UnhandledType(u32),
/// A UTF-16 string contained an interior nul (converted from `utfx::NulError`).
#[error("Invalid null found in string")]
InvalidNul(#[from] utfx::NulError<u16>),
/// UTF-16 string data had no nul terminator (converted from `utfx::MissingNulError`).
#[error("Missing null terminator in string")]
MissingNul(#[from] utfx::MissingNulError<u16>),
/// Multi-string data did not end with the two `u16` nul units that terminate the list.
#[error("Missing null terminator in multi string")]
MissingMultiNul,
/// UTF-16 data could not be decoded (converted from `std::string::FromUtf16Error`).
#[error("Invalid UTF-16")]
InvalidUtf16(#[from] std::string::FromUtf16Error),
/// A registry call failed with an OS error of any other kind.
#[error("An unknown IO error occurred for given value name: '{0}'")]
Unknown(String, #[source] io::Error),
/// Deprecated and not constructed (per its `note`); presumably kept so existing
/// matches on this variant keep compiling.
#[deprecated(note = "not used")]
#[error("Error determining required buffer size for value '{0}'")]
BufferSize(String, #[source] io::Error),
/// Deprecated and not constructed (per its `note`); presumably kept for API compatibility.
#[deprecated(note = "not used")]
#[error("Invalid buffer size for UTF-16 string: {0}")]
InvalidBufferSize(usize),
}
impl Error {
#[cfg(test)]
pub(crate) fn is_not_found(&self) -> bool {
match self {
Error::NotFound(_, _) => true,
_ => false,
}
}
fn from_code(code: i32, value_name: String) -> Self {
let err = std::io::Error::from_raw_os_error(code);
return match err.kind() {
io::ErrorKind::NotFound => Error::NotFound(value_name, err),
io::ErrorKind::PermissionDenied => Error::PermissionDenied(value_name, err),
_ => Error::Unknown(value_name, err),
};
}
}
impl From<Infallible> for Error {
    /// Allows an `Infallible` conversion error to flow into [`Error`] via `Into`/`?`
    /// (presumably so an already-built `U16CString` can be passed where
    /// `S: TryInto<U16CString>, S::Error: Into<Error>` is required).
    ///
    /// `Infallible` has no values, so an empty `match` proves this is never reached,
    /// without needing `unsafe`.
    fn from(never: Infallible) -> Self {
        match never {}
    }
}
/// Registry value types known to this module.
///
/// Discriminants are explicit; the raw-number conversion in `TryFrom<u32>` assumes they form
/// the contiguous range `0..=Type::MAX`, so keep them that way when adding variants.
#[repr(u32)]
#[derive(Debug, Copy, Clone)]
pub(crate) enum Type {
/// Maps to `REG_NONE`.
None = 0,
/// Maps to `REG_SZ`.
String = 1,
/// Maps to `REG_EXPAND_SZ`.
ExpandString = 2,
/// Maps to `REG_BINARY`.
Binary = 3,
/// Maps to `REG_DWORD`.
U32 = 4,
/// Maps to `REG_DWORD_BIG_ENDIAN`.
U32BE = 5,
/// Maps to `REG_LINK`.
Link = 6,
/// Maps to `REG_MULTI_SZ`.
MultiString = 7,
/// Maps to `REG_RESOURCE_LIST`.
ResourceList = 8,
/// Maps to `REG_FULL_RESOURCE_DESCRIPTOR`.
FullResourceDescriptor = 9,
/// Maps to `REG_RESOURCE_REQUIREMENTS_LIST`.
ResourceRequirementsList = 10,
/// Maps to `REG_QWORD`.
U64 = 11,
}
impl Type {
const MAX: u32 = 11;
}
impl From<Type> for REG_VALUE_TYPE {
    /// Returns the Win32 `REG_*` constant corresponding to `ty`.
    fn from(ty: Type) -> Self {
        match ty {
            // Placeholder / link types.
            Type::None => REG_NONE,
            Type::Link => REG_LINK,
            // Textual types.
            Type::String => REG_SZ,
            Type::ExpandString => REG_EXPAND_SZ,
            Type::MultiString => REG_MULTI_SZ,
            // Integer types.
            Type::U32 => REG_DWORD,
            Type::U32BE => REG_DWORD_BIG_ENDIAN,
            Type::U64 => REG_QWORD,
            // Raw and resource blobs.
            Type::Binary => REG_BINARY,
            Type::ResourceList => REG_RESOURCE_LIST,
            Type::FullResourceDescriptor => REG_FULL_RESOURCE_DESCRIPTOR,
            Type::ResourceRequirementsList => REG_RESOURCE_REQUIREMENTS_LIST,
        }
    }
}
impl TryFrom<REG_VALUE_TYPE> for Type {
type Error = Error;
fn try_from(value: REG_VALUE_TYPE) -> Result<Self, Self::Error> {
match value {
REG_NONE => Ok(Type::None),
REG_SZ => Ok(Type::String),
REG_EXPAND_SZ => Ok(Type::ExpandString),
REG_BINARY => Ok(Type::Binary),
REG_DWORD => Ok(Type::U32),
REG_DWORD_BIG_ENDIAN => Ok(Type::U32BE),
REG_LINK => Ok(Type::Link),
REG_MULTI_SZ => Ok(Type::MultiString),
REG_RESOURCE_LIST => Ok(Type::ResourceList),
REG_FULL_RESOURCE_DESCRIPTOR => Ok(Type::FullResourceDescriptor),
REG_RESOURCE_REQUIREMENTS_LIST => Ok(Type::ResourceRequirementsList),
REG_QWORD => Ok(Type::U64),
ty => Err(Error::UnhandledType(ty.0)),
}
}
}
/// A registry value's data, with one variant per registry type.
///
/// `Link` and the three resource variants carry no payload: `parse_value_type_data` discards
/// their bytes when reading, and `Data::to_bytes` produces an empty byte buffer for them.
#[derive(Clone)]
pub enum Data {
/// `REG_NONE`.
None,
/// `REG_SZ`: a UTF-16 string.
String(U16CString),
/// `REG_EXPAND_SZ`: a UTF-16 string, stored as-is (no expansion is done here).
ExpandString(U16CString),
/// `REG_BINARY`: raw bytes.
Binary(Vec<u8>),
/// `REG_DWORD`: read and written little-endian.
U32(u32),
/// `REG_DWORD_BIG_ENDIAN`: held as a decoded `u32`, read and written big-endian.
U32BE(u32),
/// `REG_LINK` (payload not modelled).
Link,
/// `REG_MULTI_SZ`: a list of UTF-16 strings.
MultiString(Vec<U16CString>),
/// `REG_RESOURCE_LIST` (payload not modelled).
ResourceList,
/// `REG_FULL_RESOURCE_DESCRIPTOR` (payload not modelled).
FullResourceDescriptor,
/// `REG_RESOURCE_REQUIREMENTS_LIST` (payload not modelled).
ResourceRequirementsList,
/// `REG_QWORD`: read and written little-endian.
U64(u64),
}
impl Debug for Data {
    /// Formats as `Variant(payload)`.
    ///
    /// Strings are shown lossily decoded and quoted. Multi-strings reuse the `Display` list
    /// rendering. Payload-less variants print just their name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Data::String(text) => write!(f, "String({:?})", text.to_string_lossy()),
            Data::ExpandString(text) => write!(f, "ExpandString({:?})", text.to_string_lossy()),
            Data::Binary(bytes) => write!(f, "Binary({:?})", bytes),
            Data::U32(n) => write!(f, "U32({})", n),
            Data::U32BE(n) => write!(f, "U32BE({})", n),
            Data::U64(n) => write!(f, "U64({})", n),
            // Display already renders a multi-string as a list; write it straight through
            // instead of allocating an intermediate String.
            Data::MultiString(_) => write!(f, "MultiString({})", self),
            Data::None => f.write_str("None"),
            Data::Link => f.write_str("Link"),
            Data::ResourceList => f.write_str("ResourceList"),
            Data::FullResourceDescriptor => f.write_str("FullResourceDescriptor"),
            Data::ResourceRequirementsList => f.write_str("ResourceRequirementsList"),
        }
    }
}
impl Display for Data {
    /// Renders the value for humans.
    ///
    /// - Strings are decoded lossily.
    /// - Binary data is shown as space-separated hex bytes inside `<…>`.
    /// - Integers are shown as `0x`-prefixed hex, zero-padded to their width: 8 digits for the
    ///   32-bit types and 16 digits for `U64`.
    /// - Multi-strings are shown as a debug list.
    /// - Payload-less variants are shown as `<Name>` placeholders.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Data::None => f.write_str("<None>"),
            Data::String(s) | Data::ExpandString(s) => f.write_str(&s.to_string_lossy()),
            Data::Binary(bytes) => {
                // Stream bytes directly rather than building a String per byte.
                f.write_str("<")?;
                for (i, byte) in bytes.iter().enumerate() {
                    if i > 0 {
                        f.write_str(" ")?;
                    }
                    write!(f, "{:02x}", byte)?;
                }
                f.write_str(">")
            }
            // A u32 has 8 hex digits; the previous `{:016x}` padded it like a 64-bit value.
            Data::U32(x) | Data::U32BE(x) => write!(f, "0x{:08x}", x),
            Data::Link => f.write_str("<Link>"),
            Data::MultiString(x) => f
                .debug_list()
                .entries(x.iter().map(|x| x.to_string_lossy()))
                .finish(),
            Data::ResourceList => f.write_str("<Resource List>"),
            Data::FullResourceDescriptor => f.write_str("<Full Resource Descriptor>"),
            Data::ResourceRequirementsList => f.write_str("<Resource Requirements List>"),
            // A u64 has 16 hex digits; the previous `{:032x}` doubled that.
            Data::U64(x) => write!(f, "0x{:016x}", x),
        }
    }
}
impl Data {
    /// Returns the registry [`Type`] that this data is stored as.
    fn as_type(&self) -> Type {
        match self {
            Data::None => Type::None,
            Data::Link => Type::Link,
            Data::String(_) => Type::String,
            Data::ExpandString(_) => Type::ExpandString,
            Data::MultiString(_) => Type::MultiString,
            Data::U32(_) => Type::U32,
            Data::U32BE(_) => Type::U32BE,
            Data::U64(_) => Type::U64,
            Data::Binary(_) => Type::Binary,
            Data::ResourceList => Type::ResourceList,
            Data::FullResourceDescriptor => Type::FullResourceDescriptor,
            Data::ResourceRequirementsList => Type::ResourceRequirementsList,
        }
    }

    /// Serialises the data into the raw bytes passed to `RegSetValueExW`.
    ///
    /// Strings become nul-terminated UTF-16LE. Integers use their type's byte order.
    /// Payload-less variants produce an empty buffer.
    fn to_bytes(&self) -> Vec<u8> {
        match self {
            Data::String(s) | Data::ExpandString(s) => string_to_utf16_byte_vec(s),
            Data::MultiString(strings) => multi_string_bytes(strings),
            Data::Binary(bytes) => bytes.clone(),
            Data::U32(n) => n.to_le_bytes().to_vec(),
            Data::U32BE(n) => n.to_be_bytes().to_vec(),
            Data::U64(n) => n.to_le_bytes().to_vec(),
            Data::None
            | Data::Link
            | Data::ResourceList
            | Data::FullResourceDescriptor
            | Data::ResourceRequirementsList => Vec::new(),
        }
    }
}
/// Encodes `s` in the multi-string layout.
///
/// Each string is emitted as UTF-16LE including its own nul. The list then ends with one
/// extra `u16` nul (two zero bytes).
#[inline]
fn multi_string_bytes(s: &[U16CString]) -> Vec<u8> {
    let mut bytes = Vec::new();
    for item in s {
        bytes.extend(string_to_utf16_byte_vec(item));
    }
    // Terminating empty string.
    bytes.extend_from_slice(&[0, 0]);
    bytes
}
/// Encodes `s`, including its nul terminator, as little-endian UTF-16 bytes.
///
/// Borrows the code units instead of cloning the whole string first, and writes each unit's
/// two bytes into the output with no per-unit `Vec` allocation.
#[inline]
fn string_to_utf16_byte_vec(s: &U16CString) -> Vec<u8> {
    let units = s.as_slice_with_nul();
    let mut bytes = Vec::with_capacity(units.len() * 2);
    for unit in units {
        bytes.extend_from_slice(&unit.to_le_bytes());
    }
    bytes
}
fn parse_wide_string_nul(vec: Vec<u16>) -> Result<U16CString, Error> {
Ok(U16CString::from_vec_with_nul(vec)?)
}
fn parse_wide_multi_string(vec: Vec<u16>) -> Result<Vec<U16CString>, Error> {
let len = vec.len();
if len < 2 || vec[len - 1] != 0 || vec[len - 2] != 0 {
return Err(Error::MissingMultiNul);
}
(&vec[0..vec.len() - 2])
.split(|x| *x == 0)
.map(U16CString::new)
.collect::<Result<Vec<_>, _>>()
.map_err(Error::InvalidNul)
}
#[inline]
pub(crate) fn set_value<S>(base: HKEY, value_name: S, data: &Data) -> Result<(), Error>
where
S: TryInto<U16CString>,
S::Error: Into<Error>,
{
let value_name = value_name.try_into().map_err(Into::into)?;
let raw_ty = data.as_type();
let vec = data.to_bytes();
let result = unsafe {
RegSetValueExW(
base,
PCWSTR(value_name.as_ptr()),
0,
raw_ty.into(),
Some(&vec),
)
};
if result.is_err() {
return Err(Error::from_code(result.0 as i32, value_name.to_string_lossy()));
}
Ok(())
}
#[inline]
pub(crate) fn delete_value<S>(base: HKEY, value_name: S) -> Result<(), Error>
where
S: TryInto<U16CString>,
S::Error: Into<Error>,
{
let value_name = value_name.try_into().map_err(Into::into)?;
let result = unsafe { RegDeleteValueW(base, PCWSTR(value_name.as_ptr())) };
if result.is_err() {
return Err(Error::from_code(result.0 as i32, value_name.to_string_lossy()));
}
Ok(())
}
#[inline]
pub(crate) fn query_value<S>(base: HKEY, value_name: S) -> Result<Data, Error>
where
S: TryInto<U16CString>,
S::Error: Into<Error>,
{
let value_name = value_name.try_into().map_err(Into::into)?;
let mut sz: u32 = 0;
let result = unsafe {
RegQueryValueExW(
base,
PCWSTR(value_name.as_ptr()),
None,
None,
None,
Some(&mut sz),
)
};
if result.is_err() {
return Err(Error::from_code(result.0 as i32, value_name.to_string_lossy()));
}
let mut buf: Vec<u16> = vec![0u16; (sz / 2 + sz % 2) as usize];
let mut ty = REG_VALUE_TYPE::default();
let result = unsafe {
RegQueryValueExW(
base,
PCWSTR(value_name.as_ptr()),
None,
Some(&mut ty),
Some(buf.as_mut_ptr() as *mut u8),
Some(&mut sz),
)
};
if result.is_err() {
return Err(Error::from_code(result.0 as i32, value_name.to_string_lossy()));
}
parse_value_type_data(ty, buf)
}
/// Converts a buffer of `u16` units into its bytes, in native byte order.
///
/// The output has `2 * vec.len()` bytes, laid out exactly as the units sit in memory. This is
/// the same result a raw reinterpretation of the buffer would give.
///
/// The previous implementation rebuilt the allocation as a `Vec<u8>` via `Vec::from_raw_parts`.
/// That violates its safety contract: the memory was allocated with `u16` alignment but would
/// be freed with `u8` alignment, which is undefined behaviour. This version copies instead.
pub fn u16_to_u8_vec(vec: Vec<u16>) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(vec.len() * 2);
    for unit in &vec {
        bytes.extend_from_slice(&unit.to_ne_bytes());
    }
    bytes
}
/// Decodes raw registry data `buf` of registry type `ty` into a [`Data`].
///
/// `buf` presumably holds the raw value bytes packed into `u16` units, as filled by
/// `query_value`.
/// - String types are decoded as UTF-16.
/// - `Binary` and the integer types are reinterpreted as bytes via `u16_to_u8_vec`.
/// - Types whose payload is not modelled (`Link` and the resource types) discard `buf`.
///
/// The registry does not enforce data sizes. Integer data shorter than its type's width is
/// therefore read as if the missing trailing bytes were zero, where the old code panicked on
/// out-of-bounds indexing. Bytes beyond the width are ignored.
///
/// # Errors
/// Returns [`Error::UnhandledType`] for an unknown `ty`, plus any error from the string parsers.
#[inline]
pub(crate) fn parse_value_type_data(ty: REG_VALUE_TYPE, buf: Vec<u16>) -> Result<Data, Error> {
    /// Copies the first `N` bytes of `bytes` into an array, zero-filling any that are missing.
    fn zero_padded<const N: usize>(bytes: &[u8]) -> [u8; N] {
        let mut out = [0u8; N];
        let len = bytes.len().min(N);
        out[..len].copy_from_slice(&bytes[..len]);
        out
    }

    let data = match Type::try_from(ty)? {
        Type::None => Data::None,
        Type::String => Data::String(parse_wide_string_nul(buf)?),
        Type::ExpandString => Data::ExpandString(parse_wide_string_nul(buf)?),
        Type::MultiString => Data::MultiString(parse_wide_multi_string(buf)?),
        Type::Link => Data::Link,
        Type::ResourceList => Data::ResourceList,
        Type::FullResourceDescriptor => Data::FullResourceDescriptor,
        Type::ResourceRequirementsList => Data::ResourceRequirementsList,
        Type::Binary => Data::Binary(u16_to_u8_vec(buf)),
        Type::U32 => Data::U32(u32::from_le_bytes(zero_padded(&u16_to_u8_vec(buf)))),
        Type::U32BE => Data::U32BE(u32::from_be_bytes(zero_padded(&u16_to_u8_vec(buf)))),
        Type::U64 => Data::U64(u64::from_le_bytes(zero_padded(&u16_to_u8_vec(buf)))),
    };
    Ok(data)
}
/// Error returned by `Type::try_from(u32)` when the number is not a known registry type.
///
/// Holds the rejected number, which is shown in hex in the message.
#[derive(Debug, thiserror::Error)]
#[error("Invalid or unknown type value: {0:#x}")]
pub struct TryIntoTypeError(u32);
impl TryFrom<u32> for Type {
type Error = TryIntoTypeError;
fn try_from(ty: u32) -> Result<Self, Self::Error> {
if ty > Type::MAX {
return Err(TryIntoTypeError(ty));
}
Ok(unsafe { std::mem::transmute::<u32, Type>(ty) })
}
}