use std::fs::{File, OpenOptions};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
use sochdb_core::SochDBError;
/// Ways acquiring a database lock can fail.
#[derive(Debug)]
pub enum LockError {
    /// The lock is currently held by someone else.
    DatabaseLocked {
        /// PID of the current holder, when the acquirer could determine it.
        holder_pid: Option<u32>,
        /// Lock file that is held.
        lock_path: PathBuf,
    },
    /// The lock was still held when the retry budget ran out.
    Timeout {
        /// How long the acquirer reports having waited.
        elapsed: Duration,
        /// The configured maximum wait.
        timeout: Duration,
    },
    /// A lock attributed to a process that appears to have crashed.
    StaleLock {
        /// PID of that process.
        stale_pid: u32,
    },
    /// Underlying I/O failure while creating, opening, locking or writing the lock file.
    Io(std::io::Error),
}
impl std::fmt::Display for LockError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LockError::DatabaseLocked { holder_pid, lock_path } => {
if let Some(pid) = holder_pid {
write!(f, "Database is locked by process {} (lock file: {})",
pid, lock_path.display())
} else {
write!(f, "Database is locked (lock file: {})", lock_path.display())
}
}
LockError::Timeout { elapsed, timeout } => {
write!(f, "Lock acquisition timed out after {:?} (timeout: {:?})",
elapsed, timeout)
}
LockError::StaleLock { stale_pid } => {
write!(f, "Stale lock detected from crashed process {}", stale_pid)
}
LockError::Io(e) => write!(f, "Lock I/O error: {}", e),
}
}
}
impl std::error::Error for LockError {
    /// Exposes the wrapped I/O error as the cause; every other variant has none.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let LockError::Io(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
impl From<std::io::Error> for LockError {
fn from(e: std::io::Error) -> Self {
LockError::Io(e)
}
}
impl From<LockError> for SochDBError {
    /// Converts a lock failure into the crate-wide error type.
    ///
    /// `LockError::Io` maps to `SochDBError::Io`. Every other variant becomes a
    /// `SochDBError::LockError` carrying a human-readable message.
    fn from(e: LockError) -> Self {
        match e {
            LockError::DatabaseLocked { holder_pid, lock_path } => {
                // Print the PID itself. Formatting the Option with `{:?}` leaked
                // "Some(..)" / "None" into a user-facing message.
                let message = match holder_pid {
                    Some(pid) => format!(
                        "Database locked by PID {} (lock: {})",
                        pid,
                        lock_path.display()
                    ),
                    None => format!("Database locked (lock: {})", lock_path.display()),
                };
                SochDBError::LockError(message)
            }
            LockError::Timeout { elapsed, timeout } => SochDBError::LockError(format!(
                "Lock timeout after {:?} (max: {:?})",
                elapsed, timeout
            )),
            LockError::StaleLock { stale_pid } => SochDBError::LockError(format!(
                "Stale lock from crashed process {}",
                stale_pid
            )),
            LockError::Io(e) => SochDBError::Io(e),
        }
    }
}
/// Tunables controlling how a database lock is acquired.
#[derive(Debug, Clone)]
pub struct LockConfig {
    /// Maximum time to keep retrying. `None` gives up after the first failed attempt.
    pub timeout: Option<Duration>,
    /// Delay between acquisition attempts while waiting.
    pub retry_interval: Duration,
    /// Opt-in flag for stale-lock handling; see `DatabaseLock::acquire_with_config`
    /// for its current effect.
    pub detect_stale_locks: bool,
    /// Name of the lock file, created inside the database directory.
    pub lock_file_name: String,
}
impl Default for LockConfig {
    /// Wait up to 5 seconds, retry every 100 ms, `detect_stale_locks` set, and
    /// use a lock file named `.lock`.
    fn default() -> Self {
        let timeout = Some(Duration::from_secs(5));
        let retry_interval = Duration::from_millis(100);
        Self {
            timeout,
            retry_interval,
            detect_stale_locks: true,
            lock_file_name: String::from(".lock"),
        }
    }
}
impl LockConfig {
    /// Default configuration that fails immediately instead of waiting.
    pub fn no_wait() -> Self {
        Self::defaults_with(None)
    }

    /// Default configuration with a custom maximum wait.
    pub fn with_timeout(timeout: Duration) -> Self {
        Self::defaults_with(Some(timeout))
    }

    /// Default configuration with only `timeout` overridden.
    fn defaults_with(timeout: Option<Duration>) -> Self {
        Self {
            timeout,
            ..Self::default()
        }
    }
}
/// Exclusive lock on a database directory, backed by an OS-level file lock.
///
/// The lock stays in effect while this value is alive.
pub struct DatabaseLock {
    /// Open handle to the lock file; the OS lock is taken on this handle.
    lock_file: File,
    /// Full path of the lock file.
    path: PathBuf,
    /// PID of this process, as written into the lock file.
    our_pid: u32,
}
impl DatabaseLock {
    /// Acquires the exclusive lock for `db_path` using [`LockConfig::default`].
    ///
    /// # Errors
    /// Same as [`DatabaseLock::acquire_with_config`].
    pub fn acquire<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
        Self::acquire_with_config(db_path, &LockConfig::default())
    }

    /// Makes a single attempt to acquire the exclusive lock, without waiting.
    ///
    /// # Errors
    /// Returns [`LockError::DatabaseLocked`] at once if the lock is held;
    /// otherwise the same as [`DatabaseLock::acquire_with_config`].
    pub fn acquire_no_wait<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
        Self::acquire_with_config(db_path, &LockConfig::no_wait())
    }

    /// Acquires the exclusive lock, retrying for at most `timeout`.
    ///
    /// # Errors
    /// Returns [`LockError::Timeout`] if the lock is still held when `timeout`
    /// expires; otherwise the same as [`DatabaseLock::acquire_with_config`].
    pub fn acquire_with_timeout<P: AsRef<Path>>(
        db_path: P,
        timeout: Duration,
    ) -> std::result::Result<Self, LockError> {
        Self::acquire_with_config(db_path, &LockConfig::with_timeout(timeout))
    }

    /// Acquires the exclusive lock on `db_path` according to `config`.
    ///
    /// The database directory is created if it does not exist. The lock file
    /// is `config.lock_file_name` inside that directory.
    ///
    /// The file is re-opened on every attempt. After the OS lock is obtained,
    /// the locked handle is checked to still be the file at the lock path.
    /// A dropped `DatabaseLock` unlinks the file, and a lock on an unlinked
    /// inode would exclude nobody.
    ///
    /// `config.detect_stale_locks` has no effect. `flock` / `LockFileEx` locks
    /// are released by the OS when the holding process exits, so a failed
    /// attempt always means a live holder. Deleting the lock file in that case
    /// would let two processes hold the lock at the same time.
    ///
    /// On success, this process's PID is written into the lock file.
    ///
    /// # Errors
    /// * [`LockError::DatabaseLocked`]: the lock is held and `config.timeout` is
    ///   `None`. `holder_pid` is the PID recorded in the lock file, if readable.
    /// * [`LockError::Timeout`]: the lock was still held when `config.timeout`
    ///   expired. `elapsed` is the actual time spent.
    /// * [`LockError::Io`]: creating the directory, opening or stat-ing the lock
    ///   file, locking, or writing the PID failed.
    pub fn acquire_with_config<P: AsRef<Path>>(
        db_path: P,
        config: &LockConfig,
    ) -> std::result::Result<Self, LockError> {
        let db_path = db_path.as_ref();
        let lock_path = db_path.join(&config.lock_file_name);
        if !db_path.exists() {
            std::fs::create_dir_all(db_path)?;
        }
        let started = Instant::now();
        let our_pid = std::process::id();
        loop {
            // Re-open on each attempt: the holder we waited on may have unlinked
            // the file, and a fresh one may now live at `lock_path`.
            let file = OpenOptions::new()
                .create(true)
                .read(true)
                .write(true)
                .open(&lock_path)?;
            match Self::try_flock(&file, false) {
                Ok(()) => {
                    if !Self::is_same_file(&file, &lock_path)? {
                        // We locked an inode that was unlinked after we opened it.
                        // That lock excludes nobody; `file` is dropped and we retry.
                        continue;
                    }
                    Self::write_pid(&file, our_pid)?;
                    return Ok(Self {
                        lock_file: file,
                        path: lock_path,
                        our_pid,
                    });
                }
                Err(LockError::DatabaseLocked { .. }) => {
                    let timeout = match config.timeout {
                        Some(timeout) => timeout,
                        None => {
                            return Err(LockError::DatabaseLocked {
                                holder_pid: Self::read_pid(&file),
                                lock_path,
                            });
                        }
                    };
                    let elapsed = started.elapsed();
                    if elapsed >= timeout {
                        return Err(LockError::Timeout { elapsed, timeout });
                    }
                    drop(file);
                    // Never sleep past the deadline.
                    std::thread::sleep(config.retry_interval.min(timeout - elapsed));
                }
                Err(e) => return Err(e),
            }
        }
    }

    /// Path of the lock file held by this lock.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// PID of this process, as written into the lock file.
    pub fn pid(&self) -> u32 {
        self.our_pid
    }

    /// Returns the PID recorded in `db_path`'s lock file, if any.
    ///
    /// Only the default lock file name (from [`LockConfig::default`]) is
    /// consulted. The file's contents are not cross-checked against the OS
    /// lock state.
    pub fn get_lock_holder<P: AsRef<Path>>(db_path: P) -> Option<u32> {
        let lock_path = db_path
            .as_ref()
            .join(LockConfig::default().lock_file_name);
        let file = File::open(&lock_path).ok()?;
        Self::read_pid(&file)
    }

    /// Reports whether `file` is still the file currently at `lock_path`.
    ///
    /// Returns `Ok(false)` if the path no longer exists or refers to a
    /// different (device, inode) pair.
    #[cfg(unix)]
    fn is_same_file(file: &File, lock_path: &Path) -> std::result::Result<bool, LockError> {
        use std::os::unix::fs::MetadataExt;
        let locked = file.metadata()?;
        match std::fs::metadata(lock_path) {
            Ok(on_disk) => Ok(locked.dev() == on_disk.dev() && locked.ino() == on_disk.ino()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(false),
            Err(e) => Err(LockError::Io(e)),
        }
    }

    /// Non-Unix fallback: std has no stable file-identity API here, so the
    /// handle is assumed to match the path.
    #[cfg(not(unix))]
    fn is_same_file(_file: &File, _lock_path: &Path) -> std::result::Result<bool, LockError> {
        Ok(true)
    }

    /// Truncates the lock file and writes `pid` followed by a newline, then
    /// syncs it to disk.
    fn write_pid(file: &File, pid: u32) -> std::result::Result<(), LockError> {
        use std::io::Seek;
        let mut file = file;
        file.seek(std::io::SeekFrom::Start(0))?;
        file.set_len(0)?;
        writeln!(file, "{}", pid)?;
        file.sync_all()?;
        Ok(())
    }

    /// Reads the whole lock file from the start and parses its trimmed contents
    /// as a PID.
    ///
    /// Returns `None` if the file can't be read or doesn't hold a single `u32`.
    fn read_pid(file: &File) -> Option<u32> {
        use std::io::Seek;
        let mut file = file;
        let _ = file.seek(std::io::SeekFrom::Start(0));
        let mut contents = String::new();
        file.read_to_string(&mut contents).ok()?;
        contents.trim().parse().ok()
    }

    /// Takes an exclusive `flock` on `file`. It is non-blocking unless `blocking`.
    ///
    /// Returns `DatabaseLocked` (with a placeholder path) on `EWOULDBLOCK`.
    #[cfg(unix)]
    fn try_flock(file: &File, blocking: bool) -> std::result::Result<(), LockError> {
        use std::os::unix::io::AsRawFd;
        let fd = file.as_raw_fd();
        let operation = if blocking {
            libc::LOCK_EX
        } else {
            libc::LOCK_EX | libc::LOCK_NB
        };
        // SAFETY: `fd` is a valid descriptor owned by `file`, which outlives this call.
        let result = unsafe { libc::flock(fd, operation) };
        if result == 0 {
            Ok(())
        } else {
            let err = std::io::Error::last_os_error();
            if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
                Err(LockError::DatabaseLocked {
                    holder_pid: None,
                    lock_path: PathBuf::new(),
                })
            } else {
                Err(LockError::Io(err))
            }
        }
    }

    /// Takes an exclusive `LockFileEx` lock on the first byte of `file`. It is
    /// non-blocking unless `blocking`.
    ///
    /// Returns `DatabaseLocked` (with a placeholder path) on `ERROR_LOCK_VIOLATION`.
    #[cfg(windows)]
    fn try_flock(file: &File, blocking: bool) -> std::result::Result<(), LockError> {
        use std::os::windows::io::AsRawHandle;
        let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
        let flags = windows_sys::Win32::Storage::FileSystem::LOCKFILE_EXCLUSIVE_LOCK
            | if blocking {
                0
            } else {
                windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY
            };
        // SAFETY: OVERLAPPED is a plain C struct; all-zero means offset 0, no event.
        let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED =
            unsafe { std::mem::zeroed() };
        // SAFETY: `handle` belongs to `file`, which outlives the call, and
        // `overlapped` is a valid, exclusively borrowed OVERLAPPED.
        let result = unsafe {
            windows_sys::Win32::Storage::FileSystem::LockFileEx(
                handle,
                flags,
                0,
                1,
                0,
                &mut overlapped,
            )
        };
        if result != 0 {
            Ok(())
        } else {
            let err = std::io::Error::last_os_error();
            if err.raw_os_error()
                == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32)
            {
                Err(LockError::DatabaseLocked {
                    holder_pid: None,
                    lock_path: PathBuf::new(),
                })
            } else {
                Err(LockError::Io(err))
            }
        }
    }

    /// Fallback for other platforms: no OS locking is performed.
    #[cfg(not(any(unix, windows)))]
    fn try_flock(_file: &File, _blocking: bool) -> std::result::Result<(), LockError> {
        Ok(())
    }

    /// Releases the `flock` held on the lock file. Errors are ignored.
    #[cfg(unix)]
    fn release(&self) {
        use std::os::unix::io::AsRawFd;
        let fd = self.lock_file.as_raw_fd();
        // SAFETY: `fd` is owned by `self.lock_file`, which is still open here.
        unsafe { libc::flock(fd, libc::LOCK_UN) };
    }

    /// Unlocks the first byte of the lock file. Errors are ignored.
    #[cfg(windows)]
    fn release(&self) {
        use std::os::windows::io::AsRawHandle;
        let handle = self.lock_file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
        // SAFETY: OVERLAPPED is a plain C struct; all-zero means offset 0, no event.
        let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED =
            unsafe { std::mem::zeroed() };
        // SAFETY: `handle` is owned by `self.lock_file`, which is still open here.
        unsafe {
            windows_sys::Win32::Storage::FileSystem::UnlockFileEx(
                handle,
                0,
                1,
                0,
                &mut overlapped,
            );
        }
    }

    /// Fallback for other platforms: nothing to release.
    #[cfg(not(any(unix, windows)))]
    fn release(&self) {}
}

impl Drop for DatabaseLock {
    /// Removes the lock file, then releases the OS lock.
    ///
    /// Unlinking while the lock is still held means that anyone who later
    /// obtains the lock on the old handle fails the identity check in
    /// `acquire_with_config` and retries. Releasing first would let such a
    /// waiter pass the check just before the unlink, while a third process
    /// creates and locks a fresh file.
    fn drop(&mut self) {
        let _ = std::fs::remove_file(&self.path);
        self.release();
    }
}
/// Reader/writer bookkeeping record with a C-compatible layout of four `u32`s.
///
/// NOTE(review): not referenced by the locking code in this file. It is
/// presumably meant for a shared or persisted header; confirm with its users.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct RwLockState {
    /// Number of readers (semantics defined by whoever maintains this record).
    pub reader_count: u32,
    /// Writer-intent flag or counter (semantics not established here).
    pub writer_intent: u32,
    /// Active-writer flag or counter (semantics not established here).
    pub writer_active: u32,
    /// Unused; rounds the struct out to 16 bytes.
    pub _padding: u32,
}
/// Access mode requested when acquiring an [`RwDatabaseLock`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionMode {
    /// Shared lock: any number of read-only holders may coexist.
    ReadOnly,
    /// Exclusive lock: a single holder with write access.
    ReadWrite,
}
/// Shared or exclusive lock on a database directory, depending on [`ConnectionMode`].
pub struct RwDatabaseLock {
    /// Open handle to the lock file; the OS lock is taken on this handle.
    lock_file: File,
    /// Full path of the lock file.
    path: PathBuf,
    /// Mode the lock was acquired in.
    mode: ConnectionMode,
    /// PID of this process at acquisition time.
    our_pid: u32,
}
impl RwDatabaseLock {
    /// Acquires a shared (read-only) lock using [`LockConfig::default`].
    ///
    /// # Errors
    /// Same as [`RwDatabaseLock::acquire_with_mode`].
    pub fn acquire_shared<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
        Self::acquire_with_mode(db_path, ConnectionMode::ReadOnly, &LockConfig::default())
    }

    /// Acquires an exclusive (read-write) lock using [`LockConfig::default`].
    ///
    /// # Errors
    /// Same as [`RwDatabaseLock::acquire_with_mode`].
    pub fn acquire_exclusive<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
        Self::acquire_with_mode(db_path, ConnectionMode::ReadWrite, &LockConfig::default())
    }

    /// Acquires a lock on `db_path` in `mode`, retrying according to `config`.
    ///
    /// `ReadOnly` takes a shared lock and `ReadWrite` an exclusive one, on
    /// `config.lock_file_name` inside `db_path`. The directory is created if
    /// missing. Unlike `DatabaseLock`, no PID is written to the lock file.
    ///
    /// The lock file is re-opened on every attempt. After locking, the handle
    /// is checked to still be the file at the lock path. A dropped
    /// `DatabaseLock` unlinks the lock file, and a lock on the unlinked inode
    /// would exclude nobody.
    ///
    /// # Errors
    /// * [`LockError::DatabaseLocked`]: the lock is unavailable and
    ///   `config.timeout` is `None`. `holder_pid` is always `None`.
    /// * [`LockError::Timeout`]: still unavailable when `config.timeout` expired.
    ///   `elapsed` is the actual time spent.
    /// * [`LockError::Io`]: creating the directory, opening or stat-ing the lock
    ///   file, or locking failed.
    pub fn acquire_with_mode<P: AsRef<Path>>(
        db_path: P,
        mode: ConnectionMode,
        config: &LockConfig,
    ) -> std::result::Result<Self, LockError> {
        let db_path = db_path.as_ref();
        let lock_path = db_path.join(&config.lock_file_name);
        if !db_path.exists() {
            std::fs::create_dir_all(db_path)?;
        }
        let our_pid = std::process::id();
        let started = Instant::now();
        loop {
            let file = OpenOptions::new()
                .create(true)
                .read(true)
                .write(true)
                .open(&lock_path)?;
            let locked = match mode {
                ConnectionMode::ReadOnly => Self::try_shared_lock(&file)?,
                ConnectionMode::ReadWrite => Self::try_exclusive_lock(&file)?,
            };
            if locked {
                if Self::is_same_file(&file, &lock_path)? {
                    return Ok(Self {
                        lock_file: file,
                        path: lock_path,
                        mode,
                        our_pid,
                    });
                }
                // The inode was unlinked after we opened it. Dropping `file`
                // releases that useless lock, and we retry on the current file.
                continue;
            }
            let timeout = match config.timeout {
                Some(timeout) => timeout,
                None => {
                    return Err(LockError::DatabaseLocked {
                        holder_pid: None,
                        lock_path,
                    });
                }
            };
            let elapsed = started.elapsed();
            if elapsed >= timeout {
                return Err(LockError::Timeout { elapsed, timeout });
            }
            drop(file);
            // Never sleep past the deadline.
            std::thread::sleep(config.retry_interval.min(timeout - elapsed));
        }
    }

    /// Mode this lock was acquired in.
    pub fn mode(&self) -> ConnectionMode {
        self.mode
    }

    /// `true` if this is a shared, read-only lock.
    pub fn is_readonly(&self) -> bool {
        self.mode == ConnectionMode::ReadOnly
    }

    /// Path of the lock file held by this lock.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// PID of this process at acquisition time.
    pub fn pid(&self) -> u32 {
        self.our_pid
    }

    /// Reports whether `file` is still the file currently at `lock_path`.
    ///
    /// Returns `Ok(false)` if the path is gone or refers to a different
    /// (device, inode) pair.
    #[cfg(unix)]
    fn is_same_file(file: &File, lock_path: &Path) -> std::result::Result<bool, LockError> {
        use std::os::unix::fs::MetadataExt;
        let locked = file.metadata()?;
        match std::fs::metadata(lock_path) {
            Ok(on_disk) => Ok(locked.dev() == on_disk.dev() && locked.ino() == on_disk.ino()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(false),
            Err(e) => Err(LockError::Io(e)),
        }
    }

    /// Non-Unix fallback: std has no stable file-identity API here, so the
    /// handle is assumed to match the path.
    #[cfg(not(unix))]
    fn is_same_file(_file: &File, _lock_path: &Path) -> std::result::Result<bool, LockError> {
        Ok(true)
    }

    /// Non-blocking shared `flock`.
    ///
    /// Returns `Ok(false)` if it would block and `Err` for other failures.
    #[cfg(unix)]
    fn try_shared_lock(file: &File) -> std::result::Result<bool, LockError> {
        use std::os::unix::io::AsRawFd;
        let fd = file.as_raw_fd();
        // SAFETY: `fd` is a valid descriptor owned by `file`, which outlives this call.
        let result = unsafe { libc::flock(fd, libc::LOCK_SH | libc::LOCK_NB) };
        if result == 0 {
            Ok(true)
        } else {
            let err = std::io::Error::last_os_error();
            if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
                Ok(false)
            } else {
                Err(LockError::Io(err))
            }
        }
    }

    /// Non-blocking exclusive `flock`.
    ///
    /// Returns `Ok(false)` if it would block and `Err` for other failures.
    #[cfg(unix)]
    fn try_exclusive_lock(file: &File) -> std::result::Result<bool, LockError> {
        use std::os::unix::io::AsRawFd;
        let fd = file.as_raw_fd();
        // SAFETY: `fd` is a valid descriptor owned by `file`, which outlives this call.
        let result = unsafe { libc::flock(fd, libc::LOCK_EX | libc::LOCK_NB) };
        if result == 0 {
            Ok(true)
        } else {
            let err = std::io::Error::last_os_error();
            if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
                Ok(false)
            } else {
                Err(LockError::Io(err))
            }
        }
    }

    /// Non-blocking shared `LockFileEx` lock on the first byte of `file`.
    ///
    /// Returns `Ok(false)` on `ERROR_LOCK_VIOLATION`.
    #[cfg(windows)]
    fn try_shared_lock(file: &File) -> std::result::Result<bool, LockError> {
        use std::os::windows::io::AsRawHandle;
        let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
        // SAFETY: OVERLAPPED is a plain C struct; all-zero means offset 0, no event.
        let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED =
            unsafe { std::mem::zeroed() };
        // SAFETY: `handle` belongs to `file`, which outlives the call, and
        // `overlapped` is a valid, exclusively borrowed OVERLAPPED.
        let result = unsafe {
            windows_sys::Win32::Storage::FileSystem::LockFileEx(
                handle,
                windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY,
                0,
                1,
                0,
                &mut overlapped,
            )
        };
        if result != 0 {
            Ok(true)
        } else {
            let err = std::io::Error::last_os_error();
            if err.raw_os_error()
                == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32)
            {
                Ok(false)
            } else {
                Err(LockError::Io(err))
            }
        }
    }

    /// Non-blocking exclusive `LockFileEx` lock on the first byte of `file`.
    ///
    /// Returns `Ok(false)` on `ERROR_LOCK_VIOLATION`.
    #[cfg(windows)]
    fn try_exclusive_lock(file: &File) -> std::result::Result<bool, LockError> {
        use std::os::windows::io::AsRawHandle;
        let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
        // SAFETY: OVERLAPPED is a plain C struct; all-zero means offset 0, no event.
        let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED =
            unsafe { std::mem::zeroed() };
        // SAFETY: `handle` belongs to `file`, which outlives the call, and
        // `overlapped` is a valid, exclusively borrowed OVERLAPPED.
        let result = unsafe {
            windows_sys::Win32::Storage::FileSystem::LockFileEx(
                handle,
                windows_sys::Win32::Storage::FileSystem::LOCKFILE_EXCLUSIVE_LOCK
                    | windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY,
                0,
                1,
                0,
                &mut overlapped,
            )
        };
        if result != 0 {
            Ok(true)
        } else {
            let err = std::io::Error::last_os_error();
            if err.raw_os_error()
                == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32)
            {
                Ok(false)
            } else {
                Err(LockError::Io(err))
            }
        }
    }

    /// Fallback for other platforms: no OS locking is performed.
    #[cfg(not(any(unix, windows)))]
    fn try_shared_lock(_file: &File) -> std::result::Result<bool, LockError> {
        Ok(true)
    }

    /// Fallback for other platforms: no OS locking is performed.
    #[cfg(not(any(unix, windows)))]
    fn try_exclusive_lock(_file: &File) -> std::result::Result<bool, LockError> {
        Ok(true)
    }

    /// Releases the `flock` held on the lock file. Errors are ignored.
    #[cfg(unix)]
    fn release(&self) {
        use std::os::unix::io::AsRawFd;
        let fd = self.lock_file.as_raw_fd();
        // SAFETY: `fd` is owned by `self.lock_file`, which is still open here.
        unsafe { libc::flock(fd, libc::LOCK_UN) };
    }

    /// Unlocks the first byte of the lock file. Errors are ignored.
    #[cfg(windows)]
    fn release(&self) {
        use std::os::windows::io::AsRawHandle;
        let handle = self.lock_file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
        // SAFETY: OVERLAPPED is a plain C struct; all-zero means offset 0, no event.
        let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED =
            unsafe { std::mem::zeroed() };
        // SAFETY: `handle` is owned by `self.lock_file`, which is still open here.
        unsafe {
            windows_sys::Win32::Storage::FileSystem::UnlockFileEx(handle, 0, 1, 0, &mut overlapped);
        }
    }

    /// Fallback for other platforms: nothing to release.
    #[cfg(not(any(unix, windows)))]
    fn release(&self) {}
}
impl Drop for RwDatabaseLock {
    /// Releases the OS lock; the lock file itself is left on disk.
    fn drop(&mut self) {
        Self::release(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use tempfile::TempDir;

    /// Fresh temporary directory used as the database root for one test.
    fn scratch_dir() -> TempDir {
        TempDir::new().unwrap()
    }

    #[test]
    fn test_exclusive_lock_basic() {
        let dir = scratch_dir();
        let db_path = dir.path();

        let first = DatabaseLock::acquire(db_path);
        assert!(first.is_ok());

        // While `first` is alive, a non-waiting attempt must see the lock as held.
        let contender = DatabaseLock::acquire_no_wait(db_path);
        assert!(matches!(contender, Err(LockError::DatabaseLocked { .. })));

        // Dropping the holder frees the lock for the next caller.
        drop(first);
        let reacquired = DatabaseLock::acquire(db_path);
        assert!(reacquired.is_ok());
    }

    #[test]
    fn test_acquire_default_timeout() {
        let dir = scratch_dir();
        let db_path = dir.path().to_path_buf();
        let holder = DatabaseLock::acquire(&db_path).unwrap();

        // A background thread releases the lock after 200 ms.
        let releaser = thread::spawn(move || {
            thread::sleep(Duration::from_millis(200));
            drop(holder);
        });

        let started = Instant::now();
        let result = DatabaseLock::acquire(&db_path);
        let waited = started.elapsed();

        assert!(result.is_ok(), "acquire() should succeed after lock is released");
        assert!(waited >= Duration::from_millis(100), "should have waited for lock");
        assert!(waited < Duration::from_secs(2), "should not wait too long");
        releaser.join().unwrap();
    }

    #[test]
    fn test_lock_with_timeout() {
        let dir = scratch_dir();
        let db_path = dir.path().to_path_buf();
        let _held = DatabaseLock::acquire(&db_path).unwrap();

        let started = Instant::now();
        let result = DatabaseLock::acquire_with_timeout(&db_path, Duration::from_millis(100));
        let waited = started.elapsed();

        assert!(matches!(result, Err(LockError::Timeout { .. })));
        assert!(waited >= Duration::from_millis(100));
        assert!(waited < Duration::from_millis(500));
    }

    #[test]
    fn test_lock_pid_recorded() {
        let dir = scratch_dir();
        let db_path = dir.path();
        let lock = DatabaseLock::acquire(db_path).unwrap();
        let expected = std::process::id();

        assert_eq!(lock.pid(), expected);
        assert_eq!(DatabaseLock::get_lock_holder(db_path), Some(expected));
    }

    #[test]
    fn test_shared_lock_multiple_readers() {
        let dir = scratch_dir();
        let db_path = dir.path();
        let readers = [
            RwDatabaseLock::acquire_shared(db_path),
            RwDatabaseLock::acquire_shared(db_path),
        ];
        assert!(readers.iter().all(|reader| reader.is_ok()));
    }

    #[test]
    fn test_exclusive_blocks_shared() {
        let dir = scratch_dir();
        let db_path = dir.path();
        let _writer = RwDatabaseLock::acquire_exclusive(db_path).unwrap();

        let reader = RwDatabaseLock::acquire_with_mode(
            db_path,
            ConnectionMode::ReadOnly,
            &LockConfig::no_wait(),
        );
        assert!(matches!(reader, Err(LockError::DatabaseLocked { .. })));
    }
}