use std::path::Path;
use std::ptr;
use windows_sys::Win32::Foundation::{CloseHandle, HANDLE, LPARAM, WAIT_OBJECT_0};
use windows_sys::Win32::System::Threading::{CreateMutexW, ReleaseMutex, WaitForSingleObject};
use windows_sys::Win32::UI::WindowsAndMessaging::{
SendMessageTimeoutA, HWND_BROADCAST, SMTO_ABORTIFHUNG, WM_SETTINGCHANGE,
};
use winreg::enums::{RegType, HKEY_CURRENT_USER, KEY_READ, KEY_WRITE};
use winreg::{RegKey, RegValue};
use crate::config::Position;
use crate::error::{RegistryError, Result};
use crate::report::Action;
/// Maximum time, in milliseconds, that [`RegistryLock::acquire`] waits for
/// another holder of the registry lock before failing with
/// `RegistryError::LockTimeout`.
const LOCK_TIMEOUT_MS: u32 = 10_000;

/// Guard for the named mutex that serializes read-modify-write cycles on the
/// user `PATH`.
///
/// The mutex is released and its handle closed when the guard is dropped, so
/// the guard must be bound to a variable (e.g. `let _lock = ...`) for the whole
/// critical section; `#[must_use]` makes an accidentally discarded guard a
/// compiler warning instead of a silently unlocked update.
#[must_use = "the registry lock is released as soon as the guard is dropped"]
struct RegistryLock {
    /// Mutex handle returned by `CreateMutexW`; owned by this guard.
    handle: HANDLE,
}
impl RegistryLock {
#[allow(unsafe_code)]
fn acquire() -> Result<Self> {
let name: Vec<u16> = "Local\\onpath-registry-lock\0".encode_utf16().collect();
let handle = unsafe { CreateMutexW(ptr::null(), 0, name.as_ptr()) };
if handle.is_null() {
return Err(RegistryError::LockFailed(std::io::Error::last_os_error()).into());
}
let wait_result = unsafe { WaitForSingleObject(handle, LOCK_TIMEOUT_MS) };
if wait_result != WAIT_OBJECT_0 {
unsafe { CloseHandle(handle) };
if wait_result == windows_sys::Win32::Foundation::WAIT_TIMEOUT {
return Err(RegistryError::LockTimeout.into());
}
return Err(RegistryError::LockFailed(std::io::Error::last_os_error()).into());
}
Ok(Self { handle })
}
}
impl Drop for RegistryLock {
    /// Gives up ownership of the mutex, then closes the guard's handle.
    #[allow(unsafe_code)]
    fn drop(&mut self) {
        let mutex = self.handle;
        // SAFETY: `mutex` is the open handle this guard received from
        // `acquire`; releasing it only relinquishes ownership of the mutex.
        unsafe { ReleaseMutex(mutex) };
        // SAFETY: the handle is still open and is closed exactly once, here.
        unsafe { CloseHandle(mutex) };
    }
}
/// Adds `dir` to the current user's `PATH` (`HKCU\Environment`).
///
/// Each existing entry is passed through
/// `crate::normalize::normalize_windows_path_str` and compared with the
/// normalized `dir`, ASCII case-insensitively. If any entry matches, nothing is
/// written and [`Action::RegistryAlreadyContains`] is returned. Otherwise `dir`
/// is placed at the front ([`Position::Prepend`]) or the back
/// ([`Position::Append`]), the value is written back with its original registry
/// type, and a `WM_SETTINGCHANGE` broadcast is sent.
///
/// # Errors
/// Propagates failures from acquiring the lock and from reading or writing the
/// registry value.
pub fn add_to_path(dir: &Path, position: Position) -> Result<Action> {
    // Hold the lock across the whole read-modify-write so concurrent runs
    // cannot overwrite each other's changes.
    let _lock = RegistryLock::acquire()?;
    let (current, vtype) = read_user_path()?;
    let dir_str = dir.to_string_lossy();
    let dir_normalized = crate::normalize::normalize_windows_path_str(&dir_str);
    let already_present = current.split(';').any(|entry| {
        crate::normalize::normalize_windows_path_str(entry).eq_ignore_ascii_case(&dir_normalized)
    });
    if already_present {
        return Ok(Action::RegistryAlreadyContains);
    }
    // Trim separators only at the join point, so a PATH that already ends
    // (Append) or starts (Prepend) with `;` does not gain an empty `;;` segment.
    // A value consisting solely of separators becomes just `dir`.
    let new_path = match position {
        Position::Prepend => match current.trim_start_matches(';') {
            "" => dir_str.to_string(),
            rest => format!("{dir_str};{rest}"),
        },
        Position::Append => match current.trim_end_matches(';') {
            "" => dir_str.to_string(),
            rest => format!("{rest};{dir_str}"),
        },
    };
    write_user_path(&new_path, vtype)?;
    broadcast_settings_change();
    Ok(Action::RegistryModified {
        old_value: current,
        new_value: new_path,
    })
}
/// Removes every entry equal to `dir` from the current user's `PATH`.
///
/// Entries are compared after `crate::normalize::normalize_windows_path_str`,
/// ASCII case-insensitively, so all duplicates of `dir` are dropped in a single
/// call. The remaining entries keep their original order. When at least one
/// entry is removed, the value is written back with its original registry type
/// and a `WM_SETTINGCHANGE` broadcast is sent.
///
/// NOTE(review): when no entry matches, this returns
/// [`Action::RegistryAlreadyContains`], the same variant `add_to_path` uses for
/// "already present". Confirm whether a dedicated "not present" variant is
/// intended here.
///
/// # Errors
/// Propagates failures from acquiring the lock and from reading or writing the
/// registry value.
pub fn remove_from_path(dir: &Path) -> Result<Action> {
    // Serialize the read-modify-write against other instances of this tool.
    let _lock = RegistryLock::acquire()?;
    let (current, vtype) = read_user_path()?;

    let dir_str = dir.to_string_lossy();
    let target = crate::normalize::normalize_windows_path_str(&dir_str);
    let is_target = |entry: &str| {
        crate::normalize::normalize_windows_path_str(entry).eq_ignore_ascii_case(&target)
    };

    let original_count = current.split(';').count();
    let kept: Vec<&str> = current
        .split(';')
        .filter(|&entry| !is_target(entry))
        .collect();
    if kept.len() == original_count {
        // Nothing matched: leave the registry untouched.
        return Ok(Action::RegistryAlreadyContains);
    }

    let new_path = kept.join(";");
    write_user_path(&new_path, vtype)?;
    broadcast_settings_change();
    Ok(Action::RegistryEntryRemoved {
        old_value: current,
        new_value: new_path,
    })
}
/// Reads the current user's `PATH` value from `HKCU\Environment`.
///
/// Returns the decoded string together with the value's registry type, so the
/// caller can write it back with the same type via `write_user_path`. The raw
/// data is decoded as UTF-16LE up to the first NUL unit; invalid UTF-16 is
/// replaced with U+FFFD, and a trailing odd byte is ignored.
///
/// If either the `Environment` key or the `PATH` value does not exist, an empty
/// string with type `REG_EXPAND_SZ` is returned.
///
/// # Errors
/// - `RegistryError::OpenKey` if the key exists but cannot be opened for reading.
/// - `RegistryError::ReadPath` if the value exists but cannot be read.
fn read_user_path() -> Result<(String, RegType)> {
    let hkcu = RegKey::predef(HKEY_CURRENT_USER);
    let env = match hkcu.open_subkey_with_flags("Environment", KEY_READ) {
        Ok(key) => key,
        // `write_user_path` creates the key on demand, so a missing key only
        // means the user has no PATH yet. Treat it like a missing value.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            return Ok((String::new(), RegType::REG_EXPAND_SZ));
        }
        Err(e) => return Err(RegistryError::OpenKey(e).into()),
    };
    match env.get_raw_value("PATH") {
        Ok(raw) => {
            let vtype = raw.vtype;
            // String registry data is consumed as a NUL-terminated string, so
            // stop at the first NUL instead of only trimming trailing ones.
            let utf16: Vec<u16> = raw
                .bytes
                .chunks_exact(2)
                .map(|c| u16::from_le_bytes([c[0], c[1]]))
                .take_while(|&unit| unit != 0)
                .collect();
            Ok((String::from_utf16_lossy(&utf16), vtype))
        }
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            Ok((String::new(), RegType::REG_EXPAND_SZ))
        }
        Err(e) => Err(RegistryError::ReadPath(e).into()),
    }
}
/// Writes `value` as the current user's `PATH` with registry type `vtype`.
///
/// The string is encoded as UTF-16LE followed by a single NUL terminator.
/// `HKCU\Environment` is created if it does not exist yet.
///
/// # Errors
/// - `RegistryError::OpenKey` if the key cannot be opened or created for writing.
/// - `RegistryError::WritePath` if the value cannot be set.
fn write_user_path(value: &str, vtype: RegType) -> Result<()> {
    let hkcu = RegKey::predef(HKEY_CURRENT_USER);
    let (env, _) = hkcu
        .create_subkey_with_flags("Environment", KEY_WRITE)
        .map_err(RegistryError::OpenKey)?;
    let reg_value = RegValue {
        vtype,
        bytes: value
            .encode_utf16()
            .chain(std::iter::once(0))
            .flat_map(u16::to_le_bytes)
            .collect(),
    };
    env.set_raw_value("PATH", &reg_value)
        .map_err(RegistryError::WritePath)?;
    Ok(())
}
/// Broadcasts `WM_SETTINGCHANGE` with the `"Environment"` area so that
/// top-level windows can pick up the new `PATH`.
///
/// Uses `SMTO_ABORTIFHUNG` with a 5000 ms timeout. The result is deliberately
/// ignored: the registry has already been updated, and the notification is
/// best-effort.
#[allow(unsafe_code)]
fn broadcast_settings_change() {
    // NUL-terminated ANSI name of the settings area that changed.
    let area: &[u8] = b"Environment\0";
    // SAFETY: `area` is a 'static NUL-terminated buffer that outlives the call;
    // passing a null result pointer is permitted by `SendMessageTimeoutA`.
    let _ = unsafe {
        SendMessageTimeoutA(
            HWND_BROADCAST,
            WM_SETTINGCHANGE,
            0,
            area.as_ptr() as LPARAM,
            SMTO_ABORTIFHUNG,
            5000,
            ptr::null_mut(),
        )
    };
}
#[cfg(test)]
mod tests {
    //! String-level checks of the `;`-separated PATH splitting, matching and
    //! filtering rules. None of these tests touch the registry.

    use crate::normalize::normalize_windows_path_str as normalize;

    /// True when some `;`-separated entry of `path` equals `dir`, ignoring ASCII case.
    fn has_entry(path: &str, dir: &str) -> bool {
        path.split(';').any(|entry| entry.eq_ignore_ascii_case(dir))
    }

    /// The `;`-separated entries of `path`, minus those equal to `dir` (ASCII case-insensitive).
    fn without_entry<'a>(path: &'a str, dir: &str) -> Vec<&'a str> {
        path.split(';')
            .filter(|entry| !entry.eq_ignore_ascii_case(dir))
            .collect()
    }

    #[test]
    fn path_split_and_case_insensitive_check() {
        assert!(has_entry(
            r"C:\Windows\system32;C:\Users\test\.myapp\bin",
            r"c:\users\test\.myapp\bin",
        ));
    }

    #[test]
    fn path_prepend_format() {
        let existing = r"C:\Windows\system32";
        let added = r"C:\Users\test\bin";
        assert_eq!(
            format!("{added};{existing}"),
            r"C:\Users\test\bin;C:\Windows\system32"
        );
    }

    #[test]
    fn path_filter_removes_entry() {
        assert_eq!(without_entry(r"C:\a;C:\b;C:\c", r"C:\b").join(";"), r"C:\a;C:\c");
    }

    #[test]
    fn path_with_trailing_semicolons() {
        let path = r"C:\a;C:\b;";
        let parts: Vec<&str> = path.split(';').collect();
        assert_eq!(parts, [r"C:\a", r"C:\b", ""]);
        assert_eq!(without_entry(path, r"C:\a").join(";"), r"C:\b;");
    }

    #[test]
    fn path_with_empty_segments() {
        let path = r"C:\a;;C:\b";
        let parts: Vec<&str> = path.split(';').collect();
        assert_eq!(parts, [r"C:\a", "", r"C:\b"]);
        assert_eq!(without_entry(path, r"C:\a").join(";"), r";C:\b");
    }

    #[test]
    fn path_entry_with_spaces() {
        assert!(has_entry(r"C:\Program Files\bin;C:\a", r"C:\Program Files\bin"));
    }

    #[test]
    fn path_with_trailing_backslash_now_matches() {
        assert_eq!(
            normalize(r"C:\foo\"),
            normalize(r"C:\foo"),
            "trailing backslash should normalize away"
        );
    }

    #[test]
    fn path_with_forward_slashes_now_matches() {
        assert_eq!(
            normalize(r"C:\foo"),
            normalize("C:/foo"),
            "forward vs backslash should normalize to same"
        );
    }

    #[test]
    fn path_dedup_exact_match() {
        assert!(has_entry(r"C:\a;C:\b;C:\c", r"C:\b"));
    }

    #[test]
    fn path_dedup_case_only_differs() {
        assert!(has_entry(r"C:\Users\Test\.myapp\bin", r"c:\users\test\.myapp\bin"));
    }

    #[test]
    fn path_no_substring_match() {
        assert!(!has_entry(r"C:\foobar;C:\baz", r"C:\foo"));
    }

    #[test]
    fn path_single_entry_add_and_remove() {
        let existing = r"C:\existing";
        let added = r"C:\new";
        let combined = format!("{added};{existing}");
        assert_eq!(combined, r"C:\new;C:\existing");
        assert_eq!(without_entry(&combined, added).join(";"), r"C:\existing");
    }

    #[test]
    fn path_empty_string_entries_not_matched() {
        let path = r"C:\a;;C:\b";
        let parts: Vec<&str> = path.split(';').collect();
        assert_eq!(parts, [r"C:\a", "", r"C:\b"]);
        assert!(!has_entry(path, r"C:\nonexistent"));
    }

    #[test]
    fn remove_nonexistent_is_noop() {
        let path = r"C:\a;C:\b;C:\c";
        let total = path.split(';').count();
        let kept = without_entry(path, r"C:\nonexistent");
        assert_eq!(kept.len(), total);
        assert_eq!(kept.join(";"), path);
    }
}