use std::path::Path;
use super::error::{StorageError, StorageResult};
#[cfg(target_arch = "wasm32")]
mod imp {
    //! No-op lock implementation for `wasm32` targets.
    //!
    //! It mirrors the public API of the native implementation, but every call
    //! succeeds immediately without touching the filesystem, so it provides no
    //! mutual exclusion whatsoever.
    // NOTE(review): presumably acceptable because a wasm host runs a single
    // instance of the storage — confirm before relying on this with wasm threads.

    use super::*;

    /// Stateless stand-in for the file-backed lock.
    #[derive(Debug, Clone)]
    pub struct StorageLock;

    /// Stateless stand-in for the lock guard; dropping it does nothing.
    ///
    /// Marked `#[must_use]` so that `lock.lock()?;` is flagged on wasm builds
    /// too: the same code on native targets would release the lock as soon as
    /// the discarded guard is dropped.
    #[must_use = "dropping the guard releases the lock immediately"]
    #[derive(Debug)]
    pub struct StorageLockGuard;

    impl StorageLock {
        /// Ignores `path` and always succeeds; nothing is created on disk.
        pub fn open(_path: &Path) -> StorageResult<Self> {
            Ok(Self)
        }

        /// Returns a guard immediately; there is nothing to wait for.
        pub fn lock(&self) -> StorageResult<StorageLockGuard> {
            Ok(StorageLockGuard)
        }

        /// Always reports success (`Some`), because this implementation never
        /// tracks holders.
        pub fn try_lock(&self) -> StorageResult<Option<StorageLockGuard>> {
            Ok(Some(StorageLockGuard))
        }
    }
}
#[cfg(not(target_arch = "wasm32"))]
mod imp {
use super::{Path, StorageError, StorageResult};
use std::fs::{self, File, OpenOptions};
use std::sync::Arc;
/// Cross-process exclusive lock backed by an OS lock on a lock file.
///
/// Cloning is cheap and shares the same open file handle.
// NOTE(review): on Unix the lock is taken with `flock(2)` on that shared
// descriptor. `flock` converts (rather than rejects) a repeated lock through the
// same open file description, so clones of one `StorageLock` do NOT exclude each
// other. Call `StorageLock::open` again to get a handle that does.
#[derive(Debug, Clone)]
pub struct StorageLock {
    /// Lock-file handle, shared by every clone and every guard.
    file: Arc<File>,
}

/// RAII guard that holds the exclusive lock until it is dropped.
///
/// `#[must_use]` because discarding it (e.g. `lock.lock()?;`) would release the
/// lock immediately.
#[must_use = "dropping the guard releases the lock immediately"]
#[derive(Debug)]
pub struct StorageLockGuard {
    /// Keeps the locked file handle alive for as long as the guard exists.
    file: Arc<File>,
}
impl StorageLock {
pub fn open(path: &Path) -> StorageResult<Self> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(|err| map_io_err(&err))?;
}
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(path)
.map_err(|err| map_io_err(&err))?;
Ok(Self {
file: Arc::new(file),
})
}
pub fn lock(&self) -> StorageResult<StorageLockGuard> {
lock_exclusive(&self.file).map_err(|err| map_io_err(&err))?;
Ok(StorageLockGuard {
file: Arc::clone(&self.file),
})
}
pub fn try_lock(&self) -> StorageResult<Option<StorageLockGuard>> {
if try_lock_exclusive(&self.file).map_err(|err| map_io_err(&err))? {
Ok(Some(StorageLockGuard {
file: Arc::clone(&self.file),
}))
} else {
Ok(None)
}
}
}
impl Drop for StorageLockGuard {
    /// Releases the OS lock when the guard goes out of scope.
    fn drop(&mut self) {
        match unlock(&self.file) {
            Ok(()) => {}
            // `drop` has no way to surface an error and the guard is being
            // discarded either way, so a failed unlock is deliberately ignored.
            Err(_) => {}
        }
    }
}
/// Wraps an I/O failure from a lock operation as `StorageError::Lock`,
/// keeping only the error's display text.
fn map_io_err(err: &std::io::Error) -> StorageError {
    let message = err.to_string();
    StorageError::Lock(message)
}
/// Blocks until an exclusive `flock(2)` lock is held on `file`.
///
/// A blocking `flock` can fail with `EINTR` when a signal handler runs while it
/// waits. That is not a real failure, so the call is retried; any other error
/// is returned.
#[cfg(unix)]
fn lock_exclusive(file: &File) -> std::io::Result<()> {
    let fd = std::os::unix::io::AsRawFd::as_raw_fd(file);
    loop {
        // SAFETY: `fd` comes from `file`, which stays borrowed (and therefore
        // open) for the whole call; `flock` touches no Rust-managed memory.
        if unsafe { flock(fd, LOCK_EX) } == 0 {
            return Ok(());
        }
        let err = std::io::Error::last_os_error();
        if err.kind() != std::io::ErrorKind::Interrupted {
            return Err(err);
        }
    }
}
/// Attempts a non-blocking exclusive `flock(2)` on `file`.
///
/// Returns `Ok(true)` when the lock was taken, `Ok(false)` when another open
/// file description already holds it (`EWOULDBLOCK`), and `Err` otherwise.
#[cfg(unix)]
fn try_lock_exclusive(file: &File) -> std::io::Result<bool> {
    use std::os::unix::io::AsRawFd;

    // SAFETY: the descriptor belongs to `file`, which stays borrowed (and thus
    // open) for the duration of the call.
    if unsafe { flock(file.as_raw_fd(), LOCK_EX | LOCK_NB) } == 0 {
        return Ok(true);
    }
    let err = std::io::Error::last_os_error();
    match err.kind() {
        std::io::ErrorKind::WouldBlock => Ok(false),
        _ => Err(err),
    }
}
/// Releases the `flock(2)` lock held through `file`'s descriptor.
#[cfg(unix)]
fn unlock(file: &File) -> std::io::Result<()> {
    use std::os::unix::io::AsRawFd;

    // SAFETY: the descriptor belongs to `file`, which stays borrowed (and thus
    // open) for the duration of the call.
    match unsafe { flock(file.as_raw_fd(), LOCK_UN) } {
        0 => Ok(()),
        _ => Err(std::io::Error::last_os_error()),
    }
}
// Hand-declared `flock(2)` binding. The operation constants mirror
// `<sys/file.h>`, where these values are shared by Linux, macOS and the BSDs.
#[cfg(unix)]
use std::os::raw::c_int;
#[cfg(unix)]
extern "C" {
    /// Applies (`LOCK_EX`) or removes (`LOCK_UN`) an advisory lock on the open
    /// file description behind `fd`; returns 0 on success, -1 with `errno` set
    /// on failure.
    fn flock(fd: c_int, operation: c_int) -> c_int;
}
/// `flock` operation: take an exclusive lock.
#[cfg(unix)]
const LOCK_EX: c_int = 2;
/// `flock` flag OR-ed into an operation: fail with `EWOULDBLOCK` instead of
/// waiting.
#[cfg(unix)]
const LOCK_NB: c_int = 4;
/// `flock` operation: release the lock held through this descriptor.
#[cfg(unix)]
const LOCK_UN: c_int = 8;
/// Blocks until an exclusive `LockFileEx` lock is held on `file`.
#[cfg(windows)]
fn lock_exclusive(file: &File) -> std::io::Result<()> {
    // No extra flags: without `LOCKFILE_FAIL_IMMEDIATELY` the call waits.
    let blocking_flags: u32 = 0;
    lock_file(file, blocking_flags)
}
/// Attempts to take the exclusive lock on `file` without waiting.
///
/// Returns `Ok(false)` when the region is already locked
/// (`ERROR_LOCK_VIOLATION`) and propagates any other OS error.
#[cfg(windows)]
fn try_lock_exclusive(file: &File) -> std::io::Result<bool> {
    match lock_file(file, LOCKFILE_FAIL_IMMEDIATELY) {
        Ok(()) => Ok(true),
        Err(err) if err.raw_os_error() == Some(ERROR_LOCK_VIOLATION) => Ok(false),
        Err(err) => Err(err),
    }
}
/// Releases the one-byte region at offset 0 of `file` via `UnlockFileEx`.
#[cfg(windows)]
fn unlock(file: &File) -> std::io::Result<()> {
    use std::os::windows::io::AsRawHandle;

    let handle = file.as_raw_handle() as HANDLE;
    // SAFETY: OVERLAPPED consists only of integers and a raw pointer, so the
    // all-zero bit pattern is valid; zero offsets select byte 0 of the file.
    let mut overlapped = unsafe { std::mem::zeroed::<OVERLAPPED>() };
    // SAFETY: `handle` belongs to `file`, which stays borrowed for the call, and
    // `overlapped` is a live, exclusively borrowed local.
    let ok = unsafe { UnlockFileEx(handle, 0, 1, 0, &mut overlapped) };
    if ok == 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(())
    }
}
/// Takes an exclusive `LockFileEx` lock on the first byte of `file`.
///
/// `flags` is OR-ed into `LOCKFILE_EXCLUSIVE_LOCK`. Pass `0` to wait for the
/// lock, or `LOCKFILE_FAIL_IMMEDIATELY` to fail at once when it is held.
#[cfg(windows)]
fn lock_file(file: &File, flags: u32) -> std::io::Result<()> {
    use std::os::windows::io::AsRawHandle;

    let handle = file.as_raw_handle() as HANDLE;
    let mode = LOCKFILE_EXCLUSIVE_LOCK | flags;
    // SAFETY: OVERLAPPED consists only of integers and a raw pointer, so the
    // all-zero bit pattern is valid; zero offsets select byte 0 of the file.
    let mut overlapped = unsafe { std::mem::zeroed::<OVERLAPPED>() };
    // SAFETY: `handle` belongs to `file`, which stays borrowed for the call, and
    // `overlapped` is a live, exclusively borrowed local.
    let ok = unsafe { LockFileEx(handle, mode, 0, 1, 0, &mut overlapped) };
    if ok == 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(())
    }
}
// Hand-declared Win32 bindings used by `lock_file` / `unlock`.
/// Win32 `HANDLE`: an opaque pointer-sized value.
#[cfg(windows)]
type HANDLE = *mut std::ffi::c_void;
/// Mirror of the Win32 `OVERLAPPED` structure passed to `LockFileEx` /
/// `UnlockFileEx`.
///
/// The layout depends on `#[repr(C)]` plus this exact field order, so do not
/// reorder or retype fields. `offset` / `offset_high` stand in for the
/// anonymous union (`Offset`/`OffsetHigh` vs `Pointer`), which occupies 8 bytes
/// on both 32- and 64-bit targets. The callers in this file zero the whole
/// struct, which selects file offset 0.
#[cfg(windows)]
#[repr(C)]
struct OVERLAPPED {
internal: usize,
internal_high: usize,
// Low/high halves of the byte offset where the locked region starts.
offset: u32,
offset_high: u32,
h_event: HANDLE,
}
/// `LockFileEx` flag: request an exclusive (rather than shared) lock.
#[cfg(windows)]
const LOCKFILE_EXCLUSIVE_LOCK: u32 = 0x2;
/// `LockFileEx` flag: return immediately instead of waiting when the region is
/// already locked.
#[cfg(windows)]
const LOCKFILE_FAIL_IMMEDIATELY: u32 = 0x1;
/// Win32 error code reported when a non-waiting lock request hits an existing
/// lock; compared against `io::Error::raw_os_error`.
#[cfg(windows)]
const ERROR_LOCK_VIOLATION: i32 = 33;
// kernel32 functions. Both return a Win32 BOOL: nonzero on success, 0 on
// failure, with details available from the thread's last OS error.
#[cfg(windows)]
extern "system" {
fn LockFileEx(
h_file: HANDLE,
flags: u32,
reserved: u32,
bytes_to_lock_low: u32,
bytes_to_lock_high: u32,
overlapped: *mut OVERLAPPED,
) -> i32;
fn UnlockFileEx(
h_file: HANDLE,
reserved: u32,
bytes_to_unlock_low: u32,
bytes_to_unlock_high: u32,
overlapped: *mut OVERLAPPED,
) -> i32;
}
}
pub use imp::{StorageLock, StorageLockGuard};
// These tests assert real exclusion between separately opened handles and spawn
// threads. The no-op `wasm32` implementation provides neither, so the tests are
// compiled only for native targets.
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Returns a fresh, uniquely named lock-file path in the system temp dir.
    fn temp_lock_path() -> std::path::PathBuf {
        let mut path = std::env::temp_dir();
        path.push(format!("walletkit-lock-{}.lock", Uuid::new_v4()));
        path
    }

    #[test]
    fn test_lock_is_exclusive() {
        let path = temp_lock_path();
        let lock_a = StorageLock::open(&path).expect("open lock");
        let guard_a = lock_a.lock().expect("acquire lock");

        let lock_b = StorageLock::open(&path).expect("open lock");
        let blocked = lock_b.try_lock().expect("try lock");
        assert!(blocked.is_none());

        drop(guard_a);
        let guard_b = lock_b.try_lock().expect("try lock");
        assert!(guard_b.is_some());

        // Close every handle before deleting: Windows refuses to remove a file
        // that is still open, which would otherwise leak the temp file.
        drop(guard_b);
        drop(lock_b);
        drop(lock_a);
        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_lock_serializes_across_threads() {
        let path = temp_lock_path();
        let lock = StorageLock::open(&path).expect("open lock");
        let (locked_tx, locked_rx) = std::sync::mpsc::channel();
        let (release_tx, release_rx) = std::sync::mpsc::channel();
        let (released_tx, released_rx) = std::sync::mpsc::channel();

        let thread_a = std::thread::spawn(move || {
            let guard = lock.lock().expect("lock in thread");
            locked_tx.send(()).expect("signal locked");
            release_rx.recv().expect("wait release");
            drop(guard);
            released_tx.send(()).expect("signal released");
        });

        locked_rx.recv().expect("wait locked");
        let lock_b = StorageLock::open(&path).expect("open lock");
        let blocked = lock_b.try_lock().expect("try lock");
        assert!(blocked.is_none());

        release_tx.send(()).expect("release");
        released_rx.recv().expect("wait released");
        let guard = lock_b.try_lock().expect("try lock");
        assert!(guard.is_some());
        thread_a.join().expect("thread join");

        // Remove the file only after every handle (including the one moved into
        // thread A) is closed, so deletion also succeeds on Windows.
        drop(guard);
        drop(lock_b);
        let _ = std::fs::remove_file(path);
    }
}