use crate::error::{Error, Result};
use crate::string::{from_wide, to_wide, WideString};
use windows::Win32::Foundation::{
ERROR_MORE_DATA, ERROR_NO_MORE_ITEMS, ERROR_SUCCESS, WIN32_ERROR,
};
use windows::Win32::System::Registry::{
RegCloseKey, RegCreateKeyExW, RegDeleteKeyW, RegDeleteValueW, RegEnumKeyExW, RegEnumValueW,
RegOpenKeyExW, RegQueryValueExW, RegSetValueExW, HKEY, HKEY_CLASSES_ROOT, HKEY_CURRENT_CONFIG,
HKEY_CURRENT_USER, HKEY_LOCAL_MACHINE, HKEY_USERS, KEY_ALL_ACCESS, KEY_CREATE_SUB_KEY,
KEY_ENUMERATE_SUB_KEYS, KEY_QUERY_VALUE, KEY_READ, KEY_SET_VALUE, KEY_WOW64_32KEY,
KEY_WOW64_64KEY, KEY_WRITE, REG_BINARY, REG_DWORD, REG_EXPAND_SZ, REG_MULTI_SZ,
REG_OPTION_NON_VOLATILE, REG_QWORD, REG_SAM_FLAGS, REG_SZ, REG_VALUE_TYPE,
};
fn check_error(err: WIN32_ERROR) -> Result<()> {
if err == ERROR_SUCCESS {
Ok(())
} else {
Err(Error::Windows(windows::core::Error::from(err)))
}
}
/// One of the predefined top-level registry keys that other keys are opened or created under.
///
/// The wrapped `HKEY` is public so callers can pass any other predefined handle as well.
#[derive(Debug, Clone, Copy)]
pub struct RootKey(pub HKEY);

impl RootKey {
    /// `HKEY_CURRENT_USER`.
    pub const CURRENT_USER: RootKey = RootKey(HKEY_CURRENT_USER);
    /// `HKEY_LOCAL_MACHINE`.
    pub const LOCAL_MACHINE: RootKey = RootKey(HKEY_LOCAL_MACHINE);
    /// `HKEY_CLASSES_ROOT`.
    pub const CLASSES_ROOT: RootKey = RootKey(HKEY_CLASSES_ROOT);
    /// `HKEY_USERS`.
    pub const USERS: RootKey = RootKey(HKEY_USERS);
    /// `HKEY_CURRENT_CONFIG`.
    pub const CURRENT_CONFIG: RootKey = RootKey(HKEY_CURRENT_CONFIG);
}
/// Access rights (`REG_SAM_FLAGS`) requested when opening or creating a key.
///
/// Flags are combined with [`Access::with`], e.g. `Access::READ.with(Access::WOW64_64)`.
#[derive(Clone, Copy, Debug)]
pub struct Access(pub REG_SAM_FLAGS);
impl Access {
    /// `KEY_READ`.
    pub const READ: Self = Self(KEY_READ);
    /// `KEY_WRITE`.
    pub const WRITE: Self = Self(KEY_WRITE);
    /// `KEY_ALL_ACCESS`.
    pub const ALL: Self = Self(KEY_ALL_ACCESS);
    /// `KEY_QUERY_VALUE`.
    pub const QUERY_VALUE: Self = Self(KEY_QUERY_VALUE);
    /// `KEY_SET_VALUE`.
    pub const SET_VALUE: Self = Self(KEY_SET_VALUE);
    /// `KEY_CREATE_SUB_KEY`.
    pub const CREATE_SUB_KEY: Self = Self(KEY_CREATE_SUB_KEY);
    /// `KEY_ENUMERATE_SUB_KEYS`.
    pub const ENUMERATE_SUB_KEYS: Self = Self(KEY_ENUMERATE_SUB_KEYS);
    /// `KEY_WOW64_32KEY` (32-bit registry view).
    pub const WOW64_32: Self = Self(KEY_WOW64_32KEY);
    /// `KEY_WOW64_64KEY` (64-bit registry view).
    pub const WOW64_64: Self = Self(KEY_WOW64_64KEY);

    /// Returns the union of `self` and `other` (bitwise OR of the raw flags).
    ///
    /// This does not modify `self`; use the returned value. It is `const` so combined rights
    /// can be declared as constants, e.g. `const RW: Access = Access::READ.with(Access::WRITE);`.
    #[must_use]
    pub const fn with(self, other: Self) -> Self {
        Self(REG_SAM_FLAGS(self.0 .0 | other.0 .0))
    }
}
/// A registry value together with its data type.
///
/// Variants map to registry types as follows: `String` = `REG_SZ`,
/// `ExpandString` = `REG_EXPAND_SZ` (environment references are not expanded),
/// `MultiString` = `REG_MULTI_SZ`, `Dword` = `REG_DWORD`, `Qword` = `REG_QWORD`,
/// `Binary` = `REG_BINARY`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Value {
    String(String),
    ExpandString(String),
    MultiString(Vec<String>),
    Dword(u32),
    Qword(u64),
    Binary(Vec<u8>),
}
impl Value {
    /// Creates a `Value::String`.
    pub fn string(s: impl Into<String>) -> Self {
        Value::String(s.into())
    }
    /// Creates a `Value::ExpandString`.
    pub fn expand_string(s: impl Into<String>) -> Self {
        Value::ExpandString(s.into())
    }
    /// Creates a `Value::MultiString` from any sequence of string-like items.
    pub fn multi_string<I, S>(items: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Value::MultiString(items.into_iter().map(|item| item.into()).collect())
    }
    /// Creates a `Value::Dword`.
    pub fn dword(v: u32) -> Self {
        Value::Dword(v)
    }
    /// Creates a `Value::Qword`.
    pub fn qword(v: u64) -> Self {
        Value::Qword(v)
    }
    /// Creates a `Value::Binary`.
    pub fn binary(data: impl Into<Vec<u8>>) -> Self {
        Value::Binary(data.into())
    }
    /// Returns the text of a `String` or `ExpandString` value, otherwise `None`.
    pub fn as_string(&self) -> Option<&str> {
        match self {
            Value::String(s) | Value::ExpandString(s) => Some(s),
            _ => None,
        }
    }
    /// Returns the entries of a `MultiString` value, otherwise `None`.
    pub fn as_multi_string(&self) -> Option<&[String]> {
        match self {
            Value::MultiString(v) => Some(v),
            _ => None,
        }
    }
    /// Returns the number stored in a `Dword` value, otherwise `None`.
    pub fn as_dword(&self) -> Option<u32> {
        match self {
            Value::Dword(v) => Some(*v),
            _ => None,
        }
    }
    /// Returns the number stored in a `Qword` value, otherwise `None`.
    pub fn as_qword(&self) -> Option<u64> {
        match self {
            Value::Qword(v) => Some(*v),
            _ => None,
        }
    }
    /// Returns the bytes of a `Binary` value, otherwise `None`.
    pub fn as_binary(&self) -> Option<&[u8]> {
        match self {
            Value::Binary(v) => Some(v),
            _ => None,
        }
    }
}
pub struct Key {
hkey: HKEY,
owned: bool,
}
impl Key {
pub fn open(root: RootKey, path: &str, access: Access) -> Result<Self> {
let path_wide = WideString::new(path);
let mut hkey = HKEY::default();
let err = unsafe { RegOpenKeyExW(root.0, path_wide.as_pcwstr(), 0, access.0, &mut hkey) };
check_error(err)?;
Ok(Self { hkey, owned: true })
}
pub fn create(root: RootKey, path: &str, access: Access) -> Result<Self> {
let path_wide = WideString::new(path);
let mut hkey = HKEY::default();
let err = unsafe {
RegCreateKeyExW(
root.0,
path_wide.as_pcwstr(),
0,
None,
REG_OPTION_NON_VOLATILE,
access.0,
None,
&mut hkey,
None,
)
};
check_error(err)?;
Ok(Self { hkey, owned: true })
}
pub fn open_subkey(&self, path: &str, access: Access) -> Result<Self> {
let path_wide = WideString::new(path);
let mut hkey = HKEY::default();
let err =
unsafe { RegOpenKeyExW(self.hkey, path_wide.as_pcwstr(), 0, access.0, &mut hkey) };
check_error(err)?;
Ok(Self { hkey, owned: true })
}
pub fn create_subkey(&self, path: &str, access: Access) -> Result<Self> {
let path_wide = WideString::new(path);
let mut hkey = HKEY::default();
let err = unsafe {
RegCreateKeyExW(
self.hkey,
path_wide.as_pcwstr(),
0,
None,
REG_OPTION_NON_VOLATILE,
access.0,
None,
&mut hkey,
None,
)
};
check_error(err)?;
Ok(Self { hkey, owned: true })
}
pub fn delete_subkey(&self, name: &str) -> Result<()> {
let name_wide = WideString::new(name);
let err = unsafe { RegDeleteKeyW(self.hkey, name_wide.as_pcwstr()) };
check_error(err)
}
pub fn get_value(&self, name: &str) -> Result<Value> {
let name_wide = WideString::new(name);
let mut value_type = REG_VALUE_TYPE::default();
let mut size = 0u32;
let err = unsafe {
RegQueryValueExW(
self.hkey,
name_wide.as_pcwstr(),
None,
Some(&mut value_type),
None,
Some(&mut size),
)
};
if err != ERROR_SUCCESS && err != ERROR_MORE_DATA {
return Err(Error::Windows(windows::core::Error::from(err)));
}
let mut buffer = vec![0u8; size as usize];
let err = unsafe {
RegQueryValueExW(
self.hkey,
name_wide.as_pcwstr(),
None,
Some(&mut value_type),
Some(buffer.as_mut_ptr()),
Some(&mut size),
)
};
check_error(err)?;
buffer.truncate(size as usize);
buffer.shrink_to_fit();
match value_type {
REG_SZ | REG_EXPAND_SZ => {
let wide: Vec<u16> = buffer
.chunks_exact(2)
.map(|c| u16::from_le_bytes([c[0], c[1]]))
.collect();
let s = from_wide(&wide)?;
if value_type == REG_SZ {
Ok(Value::String(s))
} else {
Ok(Value::ExpandString(s))
}
}
REG_MULTI_SZ => {
let wide: Vec<u16> = buffer
.chunks_exact(2)
.map(|c| u16::from_le_bytes([c[0], c[1]]))
.collect();
let mut strings = Vec::new();
let mut start = 0;
for (i, &c) in wide.iter().enumerate() {
if c == 0 {
if i > start {
strings.push(from_wide(&wide[start..i])?);
}
start = i + 1;
}
}
Ok(Value::MultiString(strings))
}
REG_DWORD => {
if buffer.len() >= 4 {
let value = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
Ok(Value::Dword(value))
} else {
Err(Error::custom("Invalid DWORD size"))
}
}
REG_QWORD => {
if buffer.len() >= 8 {
let value = u64::from_le_bytes([
buffer[0], buffer[1], buffer[2], buffer[3], buffer[4], buffer[5],
buffer[6], buffer[7],
]);
Ok(Value::Qword(value))
} else {
Err(Error::custom("Invalid QWORD size"))
}
}
REG_BINARY => Ok(Value::Binary(buffer)),
_ => Err(Error::custom(format!(
"Unsupported registry type: {:?}",
value_type
))),
}
}
pub fn set_value(&self, name: &str, value: &Value) -> Result<()> {
let name_wide = WideString::new(name);
let (value_type, data) = match value {
Value::String(s) => {
let wide = to_wide(s);
let bytes: Vec<u8> = wide.iter().flat_map(|&w| w.to_le_bytes()).collect();
(REG_SZ, bytes)
}
Value::ExpandString(s) => {
let wide = to_wide(s);
let bytes: Vec<u8> = wide.iter().flat_map(|&w| w.to_le_bytes()).collect();
(REG_EXPAND_SZ, bytes)
}
Value::MultiString(strings) => {
let mut wide = Vec::new();
for s in strings {
wide.extend(s.encode_utf16());
wide.push(0);
}
wide.push(0); let bytes: Vec<u8> = wide.iter().flat_map(|&w| w.to_le_bytes()).collect();
(REG_MULTI_SZ, bytes)
}
Value::Dword(v) => (REG_DWORD, v.to_le_bytes().to_vec()),
Value::Qword(v) => (REG_QWORD, v.to_le_bytes().to_vec()),
Value::Binary(data) => (REG_BINARY, data.clone()),
};
let err =
unsafe { RegSetValueExW(self.hkey, name_wide.as_pcwstr(), 0, value_type, Some(&data)) };
check_error(err)
}
pub fn delete_value(&self, name: &str) -> Result<()> {
let name_wide = WideString::new(name);
let err = unsafe { RegDeleteValueW(self.hkey, name_wide.as_pcwstr()) };
check_error(err)
}
pub fn subkeys(&self) -> Result<Vec<String>> {
let mut result = Vec::new();
let mut index = 0u32;
let mut name_buffer = vec![0u16; 256];
loop {
let mut name_len = name_buffer.len() as u32;
let err = unsafe {
RegEnumKeyExW(
self.hkey,
index,
windows::core::PWSTR(name_buffer.as_mut_ptr()),
&mut name_len,
None,
windows::core::PWSTR::null(),
None,
None,
)
};
if err == ERROR_SUCCESS {
let name = from_wide(&name_buffer[..name_len as usize])?;
result.push(name);
index += 1;
} else if err == ERROR_NO_MORE_ITEMS {
break;
} else {
return Err(Error::Windows(windows::core::Error::from(err)));
}
}
Ok(result)
}
pub fn values(&self) -> Result<Vec<String>> {
let mut result = Vec::new();
let mut index = 0u32;
let mut name_buffer = vec![0u16; 256];
loop {
let mut name_len = name_buffer.len() as u32;
let err = unsafe {
RegEnumValueW(
self.hkey,
index,
windows::core::PWSTR(name_buffer.as_mut_ptr()),
&mut name_len,
None,
None,
None,
None,
)
};
if err == ERROR_SUCCESS {
let name = from_wide(&name_buffer[..name_len as usize])?;
result.push(name);
index += 1;
} else if err == ERROR_NO_MORE_ITEMS {
break;
} else {
return Err(Error::Windows(windows::core::Error::from(err)));
}
}
Ok(result)
}
pub fn as_raw(&self) -> HKEY {
self.hkey
}
}
impl Drop for Key {
fn drop(&mut self) {
if self.owned {
unsafe {
let _ = RegCloseKey(self.hkey);
}
}
}
}
pub fn get_string(root: RootKey, path: &str, name: &str) -> Result<String> {
let key = Key::open(root, path, Access::READ)?;
match key.get_value(name)? {
Value::String(s) | Value::ExpandString(s) => Ok(s),
_ => Err(Error::custom("Value is not a string")),
}
}
pub fn get_dword(root: RootKey, path: &str, name: &str) -> Result<u32> {
let key = Key::open(root, path, Access::READ)?;
match key.get_value(name)? {
Value::Dword(v) => Ok(v),
_ => Err(Error::custom("Value is not a DWORD")),
}
}
/// Writes `value` as the `REG_SZ` value `name` of `root\path`, creating the key if needed.
///
/// # Errors
/// Fails if the key cannot be created/opened for writing or the value cannot be set.
pub fn set_string(root: RootKey, path: &str, name: &str, value: &str) -> Result<()> {
    Key::create(root, path, Access::WRITE)?.set_value(name, &Value::string(value))
}
/// Writes `value` as the `REG_DWORD` value `name` of `root\path`, creating the key if needed.
///
/// # Errors
/// Fails if the key cannot be created/opened for writing or the value cannot be set.
pub fn set_dword(root: RootKey, path: &str, name: &str, value: u32) -> Result<()> {
    Key::create(root, path, Access::WRITE)?.set_value(name, &Value::dword(value))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);
    const TEST_KEY_BASE: &str = "Software\\ErgonomicWindowsTest";

    /// Returns an `HKCU`-relative key path unique to this thread and call, so tests running
    /// in parallel never share a key.
    fn get_unique_test_key() -> String {
        let id = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
        let thread_id = std::thread::current().id();
        format!("{}_{:?}_{}", TEST_KEY_BASE, thread_id, id)
    }

    /// Deletes `name` below `parent` together with everything beneath it, ignoring errors.
    ///
    /// Children are removed first because `Key::delete_subkey` (`RegDeleteKeyW`) cannot delete
    /// a key that still has subkeys.
    fn delete_tree(parent: &Key, name: &str) {
        if let Ok(child) = parent.open_subkey(name, Access::ALL) {
            for grandchild in child.subkeys().unwrap_or_default() {
                delete_tree(&child, &grandchild);
            }
        }
        let _ = parent.delete_subkey(name);
    }

    /// Best-effort removal of a `Software\...` test key under `HKCU`, including its subkeys.
    fn cleanup_test_key_path(path: &str) {
        if let Some(subkey) = path.strip_prefix("Software\\") {
            if let Ok(software) = Key::open(
                RootKey::CURRENT_USER,
                "Software",
                Access::READ.with(Access::WRITE),
            ) {
                delete_tree(&software, subkey);
            }
        }
    }

    // Not called by the tests (each test cleans up through `TestKey`); kept as a manual helper
    // for removing every `HKCU\Software\ErgonomicWindowsTest*` key left by interrupted runs.
    #[allow(dead_code)]
    fn cleanup_test_key() {
        let base = TEST_KEY_BASE
            .strip_prefix("Software\\")
            .unwrap_or(TEST_KEY_BASE);
        if let Ok(software) = Key::open(
            RootKey::CURRENT_USER,
            "Software",
            Access::READ.with(Access::WRITE),
        ) {
            for name in software.subkeys().unwrap_or_default() {
                if name.starts_with(base) {
                    delete_tree(&software, &name);
                }
            }
        }
    }

    /// Owns a unique test key path and deletes the key tree on drop, so keys are removed
    /// even when an assertion panics mid-test.
    struct TestKey {
        path: String,
    }

    impl TestKey {
        fn new() -> Self {
            let path = get_unique_test_key();
            cleanup_test_key_path(&path);
            Self { path }
        }

        /// Creates (or opens) the test key with full access; panics on failure so a broken
        /// registry setup fails the test instead of letting it pass vacuously.
        fn create(&self) -> Key {
            Key::create(RootKey::CURRENT_USER, &self.path, Access::ALL)
                .expect("Failed to create test registry key")
        }
    }

    impl Drop for TestKey {
        fn drop(&mut self) {
            cleanup_test_key_path(&self.path);
        }
    }

    #[test]
    fn test_empty_string_value() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let result = key.set_value("empty_string", &Value::String(String::new()));
        assert!(result.is_ok(), "Failed to set empty_string: {:?}", result);
        let value = key.get_value("empty_string");
        assert!(value.is_ok(), "Failed to get empty_string: {:?}", value);
        match value.unwrap() {
            Value::String(s) => assert!(s.is_empty(), "Expected empty string, got: {:?}", s),
            other => panic!("Expected String, got: {:?}", other),
        }
    }

    #[test]
    fn test_empty_expand_string_value() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let result = key.set_value("empty_expand", &Value::ExpandString(String::new()));
        assert!(result.is_ok(), "Failed to set empty_expand: {:?}", result);
        let value = key.get_value("empty_expand");
        assert!(value.is_ok(), "Failed to get empty_expand: {:?}", value);
        match value.unwrap() {
            Value::ExpandString(s) => assert!(s.is_empty()),
            other => panic!("Expected ExpandString, got: {:?}", other),
        }
    }

    #[test]
    fn test_empty_binary_value() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let result = key.set_value("single_byte_binary", &Value::Binary(vec![0]));
        assert!(
            result.is_ok(),
            "Failed to set single_byte_binary: {:?}",
            result
        );
        let value = key.get_value("single_byte_binary");
        assert!(
            value.is_ok(),
            "Failed to get single_byte_binary: {:?}",
            value
        );
        match value.unwrap() {
            Value::Binary(b) => assert_eq!(b, vec![0]),
            other => panic!("Expected Binary, got: {:?}", other),
        }
        // Only checks that writing a zero-length payload does not panic; the result is not asserted.
        let _ = key.set_value("empty_binary", &Value::Binary(vec![]));
    }

    #[test]
    fn test_empty_multi_string_value() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let result = key.set_value("empty_multi", &Value::MultiString(vec![]));
        assert!(result.is_ok(), "Failed to set empty_multi: {:?}", result);
        let value = key.get_value("empty_multi");
        assert!(value.is_ok(), "Failed to get empty_multi: {:?}", value);
        match value.unwrap() {
            Value::MultiString(v) => assert!(v.is_empty()),
            other => panic!("Expected MultiString, got: {:?}", other),
        }
    }

    #[test]
    fn test_multi_string_skips_empty_entries() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let strings = vec!["a".to_string(), String::new(), "b".to_string()];
        key.set_value("multi_with_empty", &Value::MultiString(strings))
            .expect("Failed to set multi_with_empty");
        match key
            .get_value("multi_with_empty")
            .expect("Failed to get multi_with_empty")
        {
            Value::MultiString(v) => assert_eq!(v, vec!["a".to_string(), "b".to_string()]),
            other => panic!("Expected MultiString, got: {:?}", other),
        }
    }

    #[test]
    fn test_dword_value() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let test_values = [0u32, 1, 42, u32::MAX];
        for &val in &test_values {
            let result = key.set_value("dword_test", &Value::Dword(val));
            assert!(
                result.is_ok(),
                "Failed to set DWORD value {}: {:?}",
                val,
                result
            );
            let read = key.get_value("dword_test");
            assert!(read.is_ok(), "Failed to read DWORD value: {:?}", read);
            assert_eq!(read.unwrap().as_dword(), Some(val));
        }
    }

    #[test]
    fn test_qword_value() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let test_values = [0u64, 1, 42, u64::MAX];
        for &val in &test_values {
            let result = key.set_value("qword_test", &Value::Qword(val));
            assert!(result.is_ok(), "Failed to set QWORD value {}: {:?}", val, result);
            let read = key.get_value("qword_test");
            assert!(read.is_ok(), "Failed to read QWORD value: {:?}", read);
            assert_eq!(read.unwrap().as_qword(), Some(val));
        }
    }

    #[test]
    fn test_string_with_special_characters() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let test_strings = [
            "Simple ASCII",
            "With\ttab",
            "With\nnewline",
            "Unicode: 日本語",
            "Emoji: 🎉",
            "Path: C:\\Windows\\System32",
        ];
        for (i, &s) in test_strings.iter().enumerate() {
            let name = format!("string_test_{}", i);
            let result = key.set_value(&name, &Value::String(s.to_string()));
            assert!(result.is_ok(), "Failed to set: {}", s);
            let read = key.get_value(&name);
            assert!(read.is_ok(), "Failed to read: {}", s);
            assert_eq!(read.unwrap().as_string(), Some(s));
        }
    }

    #[test]
    fn test_multi_string_value() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let strings = vec![
            "First".to_string(),
            "Second".to_string(),
            "Third with space".to_string(),
        ];
        let result = key.set_value("multi_test", &Value::MultiString(strings.clone()));
        assert!(result.is_ok(), "Failed to set multi_test: {:?}", result);
        let read = key.get_value("multi_test");
        assert!(read.is_ok(), "Failed to read multi_test: {:?}", read);
        match read.unwrap() {
            Value::MultiString(v) => assert_eq!(v, strings),
            other => panic!("Expected MultiString, got: {:?}", other),
        }
    }

    #[test]
    fn test_binary_value() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let data: Vec<u8> = (0..=255).collect();
        let result = key.set_value("binary_test", &Value::Binary(data.clone()));
        assert!(result.is_ok(), "Failed to set binary_test: {:?}", result);
        let read = key.get_value("binary_test");
        assert!(read.is_ok(), "Failed to read binary_test: {:?}", read);
        match read.unwrap() {
            Value::Binary(v) => assert_eq!(v, data),
            other => panic!("Expected Binary, got: {:?}", other),
        }
    }

    #[test]
    fn test_create_and_delete_subkey() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let result = key.create_subkey("SubKey", Access::ALL);
        assert!(result.is_ok(), "Failed to create SubKey: {:?}", result.err());
        // Close the subkey handle before deleting it.
        drop(result);
        let result = key.delete_subkey("SubKey");
        assert!(result.is_ok(), "Failed to delete SubKey: {:?}", result);
    }

    #[test]
    fn test_enumerate_subkeys() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let _ = key.create_subkey("SubA", Access::ALL);
        let _ = key.create_subkey("SubB", Access::ALL);
        let _ = key.create_subkey("SubC", Access::ALL);
        let subkeys = key.subkeys();
        assert!(subkeys.is_ok(), "Failed to enumerate subkeys: {:?}", subkeys);
        let subkeys = subkeys.unwrap();
        assert!(subkeys.contains(&"SubA".to_string()));
        assert!(subkeys.contains(&"SubB".to_string()));
        assert!(subkeys.contains(&"SubC".to_string()));
    }

    #[test]
    fn test_enumerate_values() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let _ = key.set_value("ValA", &Value::Dword(1));
        let _ = key.set_value("ValB", &Value::String("test".to_string()));
        let _ = key.set_value("ValC", &Value::Binary(vec![1, 2, 3]));
        let values = key.values();
        assert!(values.is_ok(), "Failed to enumerate values: {:?}", values);
        let values = values.unwrap();
        assert!(values.contains(&"ValA".to_string()));
        assert!(values.contains(&"ValB".to_string()));
        assert!(values.contains(&"ValC".to_string()));
    }

    #[test]
    fn test_delete_value() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let _ = key.set_value("ToDelete", &Value::Dword(42));
        assert!(key.get_value("ToDelete").is_ok());
        let result = key.delete_value("ToDelete");
        assert!(result.is_ok(), "Failed to delete ToDelete: {:?}", result);
        assert!(key.get_value("ToDelete").is_err());
    }

    #[test]
    fn test_nonexistent_value() {
        let test_key = TestKey::new();
        let key = test_key.create();
        let result = key.get_value("NonExistent");
        assert!(result.is_err());
    }

    #[test]
    fn test_convenience_functions() {
        let test_key = TestKey::new();
        let result = set_string(RootKey::CURRENT_USER, &test_key.path, "conv_string", "hello");
        assert!(result.is_ok(), "set_string failed: {:?}", result);
        let value = get_string(RootKey::CURRENT_USER, &test_key.path, "conv_string");
        assert!(value.is_ok(), "get_string failed: {:?}", value);
        assert_eq!(value.unwrap(), "hello");
        let result = set_dword(RootKey::CURRENT_USER, &test_key.path, "conv_dword", 12345);
        assert!(result.is_ok(), "set_dword failed: {:?}", result);
        let value = get_dword(RootKey::CURRENT_USER, &test_key.path, "conv_dword");
        assert!(value.is_ok(), "get_dword failed: {:?}", value);
        assert_eq!(value.unwrap(), 12345);
    }

    #[test]
    fn test_access_flags_combination() {
        let combined = Access::READ.with(Access::WRITE);
        assert!((combined.0 .0 & KEY_READ.0) != 0);
        assert!((combined.0 .0 & KEY_WRITE.0) != 0);
        let with_32bit = Access::READ.with(Access::WOW64_32);
        assert!((with_32bit.0 .0 & KEY_WOW64_32KEY.0) != 0);
    }

    #[test]
    fn test_value_constructors() {
        let s = Value::string("test");
        assert_eq!(s.as_string(), Some("test"));
        let d = Value::dword(42);
        assert_eq!(d.as_dword(), Some(42));
        let q = Value::qword(1234567890);
        assert_eq!(q.as_qword(), Some(1234567890));
        let b = Value::binary(vec![1, 2, 3]);
        assert_eq!(b.as_binary(), Some(&[1u8, 2, 3][..]));
    }
}