use crate::{Error, RefCounter, Result};
use std::ffi::OsStr;
use std::os::windows::ffi::OsStrExt;
use std::sync::atomic::{AtomicBool, Ordering};
use windows_sys::Win32::Foundation::{
    CloseHandle, GetLastError, HANDLE, WAIT_ABANDONED, WAIT_FAILED, WAIT_OBJECT_0, WAIT_TIMEOUT,
};
use windows_sys::Win32::System::Threading::{
    CreateMutexW, CreateSemaphoreW, OpenMutexW, OpenSemaphoreW, ReleaseMutex, ReleaseSemaphore,
    WaitForSingleObject, MUTEX_ALL_ACCESS, SEMAPHORE_ALL_ACCESS,
};
/// Cross-process reference counter and lock backed by named Win32 kernel objects.
///
/// An instance holds at most one reference on the shared counter and at most one hold on
/// the shared mutex; the two flags below record which of those it currently has.
pub struct WindowsRefCounter {
    /// Name the counter was opened with.
    name: String,
    /// Handle to the named semaphore; its count is the number of references held.
    sem_handle: HANDLE,
    /// Handle to the named mutex used by `try_lock` / `unlock`.
    mutex_handle: HANDLE,
    /// `true` while this instance holds a reference on the semaphore.
    has_reference: AtomicBool,
    /// `true` while this instance holds the mutex.
    has_lock: AtomicBool,
}

// SAFETY: Win32 handles are process-wide values that are not tied to the thread that
// created them, and the semaphore/mutex calls made through them may be issued from any
// thread. The remaining fields (`String`, `AtomicBool`) are already `Send`.
unsafe impl Send for WindowsRefCounter {}

// SAFETY: every method takes `&self` and only touches atomics or calls into the kernel on
// the handles. NOTE(review): Win32 mutex ownership is per thread, which is a logic-level
// caveat for `try_lock`/`unlock`, not a memory-safety one.
unsafe impl Sync for WindowsRefCounter {}
impl WindowsRefCounter {
    /// Opens the cross-process counter identified by `name`, creating its kernel objects if needed.
    ///
    /// Two named objects back the counter:
    /// - the semaphore `Global\procref-sem-{name}` (initial count 0, maximum `i32::MAX`),
    ///   whose count is raised by `acquire` and lowered by `release`;
    /// - the mutex `Global\procref-mutex-{name}` (not initially owned), used by
    ///   `try_lock` / `unlock`.
    ///
    /// Each object is first requested with `CreateSemaphoreW` / `CreateMutexW`. These also
    /// return a handle when the object already exists. If creation fails (for example, when
    /// the process may not create objects in the `Global\` namespace), the object is opened
    /// by name instead.
    ///
    /// # Errors
    ///
    /// Returns [`Error::RefCounterInit`] with the Win32 error code when an object can be
    /// neither created nor opened. A semaphore handle obtained before a mutex failure is
    /// closed.
    pub fn new(name: &str) -> Result<Self> {
        let sem_name = Self::wide_nul(&format!("Global\\procref-sem-{}", name));
        let mutex_name = Self::wide_nul(&format!("Global\\procref-mutex-{}", name));

        // SAFETY: `sem_name` is a NUL-terminated UTF-16 buffer that outlives the call, and a
        // null security-attributes pointer is accepted by the API.
        let created =
            unsafe { CreateSemaphoreW(std::ptr::null(), 0, i32::MAX, sem_name.as_ptr()) };
        // The fallback's handle must become *the* `sem_handle`; binding it to a new variable
        // inside the branch would leak it and leave 0 in the struct.
        let sem_handle = if created != 0 {
            created
        } else {
            // SAFETY: as above; `0` requests a non-inheritable handle.
            let opened = unsafe { OpenSemaphoreW(SEMAPHORE_ALL_ACCESS, 0, sem_name.as_ptr()) };
            if opened == 0 {
                return Err(Error::RefCounterInit(format!(
                    "Failed to create/open semaphore: {}",
                    unsafe { GetLastError() }
                )));
            }
            opened
        };

        // SAFETY: `mutex_name` is a NUL-terminated UTF-16 buffer that outlives the call;
        // `0` = the mutex is not initially owned by this thread.
        let created = unsafe { CreateMutexW(std::ptr::null(), 0, mutex_name.as_ptr()) };
        let mutex_handle = if created != 0 {
            created
        } else {
            // SAFETY: as above; `0` requests a non-inheritable handle.
            let opened = unsafe { OpenMutexW(MUTEX_ALL_ACCESS, 0, mutex_name.as_ptr()) };
            if opened == 0 {
                // Capture the error before `CloseHandle` gets a chance to overwrite it.
                let err = unsafe { GetLastError() };
                // SAFETY: `sem_handle` is a valid handle obtained above and not stored anywhere.
                unsafe { CloseHandle(sem_handle) };
                return Err(Error::RefCounterInit(format!(
                    "Failed to create/open mutex: {}",
                    err
                )));
            }
            opened
        };

        Ok(Self {
            sem_handle,
            mutex_handle,
            has_reference: AtomicBool::new(false),
            has_lock: AtomicBool::new(false),
            name: name.to_string(),
        })
    }

    /// Encodes `s` as UTF-16 with a trailing NUL, the form the `*W` Win32 functions expect.
    fn wide_nul(s: &str) -> Vec<u16> {
        OsStr::new(s)
            .encode_wide()
            .chain(std::iter::once(0))
            .collect()
    }
}
impl RefCounter for WindowsRefCounter {
fn acquire(&self) -> Result<u32> {
if self.has_reference.load(Ordering::SeqCst) {
return self.count();
}
let mut prev_count: i32 = 0;
let result = unsafe { ReleaseSemaphore(self.sem_handle, 1, &mut prev_count) };
if result == 0 {
return Err(Error::Acquire(format!(
"ReleaseSemaphore failed: {}",
unsafe { GetLastError() }
)));
}
self.has_reference.store(true, Ordering::SeqCst);
Ok((prev_count + 1) as u32)
}
fn release(&self) -> Result<u32> {
if !self.has_reference.load(Ordering::SeqCst) {
return self.count();
}
let result = unsafe { WaitForSingleObject(self.sem_handle, 0) };
if result != WAIT_OBJECT_0 {
return Err(Error::Release(format!(
"WaitForSingleObject failed: {}",
unsafe { GetLastError() }
)));
}
self.has_reference.store(false, Ordering::SeqCst);
self.count()
}
fn count(&self) -> Result<u32> {
let mut prev_count: i32 = 0;
let result = unsafe { ReleaseSemaphore(self.sem_handle, 0, &mut prev_count) };
if result == 0 {
return Ok(if self.has_reference.load(Ordering::SeqCst) {
1
} else {
0
});
}
Ok(prev_count as u32)
}
fn try_lock(&self) -> Result<bool> {
if self.has_lock.load(Ordering::SeqCst) {
return Ok(true);
}
let result = unsafe { WaitForSingleObject(self.mutex_handle, 0) };
if result == WAIT_OBJECT_0 {
self.has_lock.store(true, Ordering::SeqCst);
Ok(true)
} else {
Ok(false)
}
}
fn unlock(&self) -> Result<()> {
if !self.has_lock.load(Ordering::SeqCst) {
return Ok(());
}
let result = unsafe { ReleaseMutex(self.mutex_handle) };
if result == 0 {
return Err(Error::Lock(format!(
"ReleaseMutex failed: {}",
unsafe { GetLastError() }
)));
}
self.has_lock.store(false, Ordering::SeqCst);
Ok(())
}
}
impl Drop for WindowsRefCounter {
    /// Gives back whatever this instance still holds, then closes both kernel handles.
    ///
    /// A held reference is released here so that dropping a counter lowers the shared
    /// count, just as a held lock is unlocked. Without this, the increment would outlive
    /// the instance for as long as any other opener keeps the semaphore alive.
    fn drop(&mut self) {
        // Errors cannot be propagated out of `drop`, so cleanup is best-effort.
        if self.has_lock.load(Ordering::SeqCst) {
            let _ = self.unlock();
        }
        if self.has_reference.load(Ordering::SeqCst) {
            let _ = self.release();
        }
        // SAFETY: both handles were obtained in `new`, are private to `self`, and are not
        // used after this point.
        unsafe {
            CloseHandle(self.sem_handle);
            CloseHandle(self.mutex_handle);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each test uses its own counter name so that tests running in parallel never share
    // semaphore or mutex state.

    /// A fresh counter can take a reference and give it back.
    #[test]
    fn test_windows_ref_counter_basic() {
        let counter = WindowsRefCounter::new("test-procref-windows").unwrap();
        let count = counter.acquire().unwrap();
        assert!(count >= 1);
        let _ = counter.release().unwrap();
    }

    /// A repeated `acquire` while holding a reference succeeds, and a repeated `release`
    /// after the reference is gone is harmless.
    #[test]
    fn test_windows_ref_counter_idempotent() {
        let counter = WindowsRefCounter::new("test-procref-windows-idem").unwrap();
        assert!(counter.acquire().unwrap() >= 1);
        assert!(counter.acquire().unwrap() >= 1);
        counter.release().unwrap();
        counter.release().unwrap();
    }

    /// The named mutex can be taken, re-queried while held, and released, and a second
    /// `unlock` is a no-op.
    #[test]
    fn test_windows_ref_counter_lock_roundtrip() {
        let counter = WindowsRefCounter::new("test-procref-windows-lock").unwrap();
        assert!(counter.try_lock().unwrap());
        assert!(counter.try_lock().unwrap());
        counter.unlock().unwrap();
        counter.unlock().unwrap();
    }
}