use std::fs::{self, File, OpenOptions};
#[cfg(not(target_os = "wasi"))]
use std::io::Write;
use std::path::{Path, PathBuf};
use crate::core::{Error, Result};
/// Guard for a `db.lock` file inside a database directory, created by
/// `FileLock::acquire`.
///
/// No explicit unlock is performed anywhere; the OS lock is tied to the open
/// `file` handle and is released when that handle is closed as this value is
/// dropped. On platforms where `acquire_lock` is the no-op fallback, no OS
/// lock is held at all.
#[derive(Debug)]
pub struct FileLock {
    // Never read after construction: it is kept only so the open handle (and
    // the lock taken on it) lives exactly as long as the `FileLock`.
    #[allow(dead_code)]
    file: File,
    // Full path of the lock file (`<db_path>/db.lock`), returned by `path()`.
    path: PathBuf,
}
impl FileLock {
pub fn acquire(db_path: impl AsRef<Path>) -> Result<Self> {
let db_path = db_path.as_ref();
fs::create_dir_all(db_path)
.map_err(|e| Error::internal(format!("failed to create database directory: {}", e)))?;
let lock_file_path = db_path.join("db.lock");
#[allow(unused_mut)]
let mut file = OpenOptions::new()
.create(true)
.truncate(false)
.read(true)
.write(true)
.open(&lock_file_path)
.map_err(|e| Error::internal(format!("failed to open lock file: {}", e)))?;
acquire_lock(&file)?;
#[cfg(not(target_os = "wasi"))]
{
file.set_len(0)
.map_err(|e| Error::internal(format!("failed to truncate lock file: {}", e)))?;
let pid = std::process::id();
write!(file, "{}", pid).ok();
file.sync_all().ok();
}
Ok(Self {
file,
path: lock_file_path,
})
}
pub fn path(&self) -> &Path {
&self.path
}
}
impl Drop for FileLock {
    /// Intentionally does nothing: no explicit unlock call is issued.
    ///
    /// Release relies on the OS dropping the lock when the `file` field's
    /// handle is closed, which happens as part of dropping this value.
    fn drop(&mut self) {}
}
#[cfg(unix)]
fn acquire_lock(file: &File) -> Result<()> {
use std::os::unix::io::AsRawFd;
let fd = file.as_raw_fd();
let result = unsafe { libc::flock(fd, libc::LOCK_EX | libc::LOCK_NB) };
if result != 0 {
let errno = std::io::Error::last_os_error();
if errno.raw_os_error() == Some(libc::EWOULDBLOCK) {
return Err(Error::DatabaseLocked);
}
return Err(Error::internal(format!(
"failed to acquire lock: {}",
errno
)));
}
Ok(())
}
/// Requests an exclusive, non-blocking `LockFileEx` lock on the first byte of
/// `file`.
///
/// # Errors
///
/// Returns `Error::DatabaseLocked` when the call fails with
/// `ERROR_LOCK_VIOLATION` (the range is already locked and
/// `LOCKFILE_FAIL_IMMEDIATELY` forbids waiting), or an internal error wrapping
/// the OS error for any other failure.
#[cfg(windows)]
fn acquire_lock(file: &File) -> Result<()> {
    use std::os::windows::io::AsRawHandle;
    use windows_sys::Win32::Foundation::{ERROR_LOCK_VIOLATION, HANDLE};
    use windows_sys::Win32::Storage::FileSystem::{
        LockFileEx, LOCKFILE_EXCLUSIVE_LOCK, LOCKFILE_FAIL_IMMEDIATELY,
    };
    use windows_sys::Win32::System::IO::OVERLAPPED;

    let raw_handle = file.as_raw_handle() as HANDLE;
    let lock_flags = LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY;

    // SAFETY: OVERLAPPED is a plain C struct for which all-zero bytes are a
    // valid value; zero offsets make the locked range start at byte 0.
    let mut overlapped: OVERLAPPED = unsafe { std::mem::zeroed() };

    // SAFETY: `raw_handle` comes from `file`, which stays borrowed (and open)
    // for the duration of the call, and `overlapped` is a live local.
    let succeeded = unsafe {
        LockFileEx(
            raw_handle,
            lock_flags,
            0, // dwReserved: must be zero
            1, // nNumberOfBytesToLockLow: lock a single byte
            0, // nNumberOfBytesToLockHigh
            &mut overlapped,
        )
    };
    if succeeded != 0 {
        return Ok(());
    }

    let os_err = std::io::Error::last_os_error();
    if os_err.raw_os_error() == Some(ERROR_LOCK_VIOLATION as i32) {
        Err(Error::DatabaseLocked)
    } else {
        Err(Error::internal(format!(
            "failed to acquire lock: {}",
            os_err
        )))
    }
}
/// Fallback for targets that are neither Unix nor Windows: no OS lock is taken.
///
/// A warning is printed to stderr and the call always succeeds, so a
/// `FileLock` obtained on such a target provides no mutual exclusion.
#[cfg(not(any(unix, windows)))]
fn acquire_lock(_file: &File) -> Result<()> {
    const UNSUPPORTED_WARNING: &str = "Warning: File locking not supported on this platform";
    eprintln!("{}", UNSUPPORTED_WARNING);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Acquiring creates `<db_path>/db.lock`, reports that path via `path()`,
    /// and (on Unix) leaves the current PID as the file's contents.
    #[test]
    fn test_acquire_lock() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_db");
        let lock_path = db_path.join("db.lock");

        let lock = FileLock::acquire(&db_path).unwrap();
        assert!(lock_path.exists());
        assert_eq!(lock.path(), lock_path.as_path());

        #[cfg(unix)]
        {
            let contents = fs::read_to_string(&lock_path).unwrap();
            assert_eq!(contents, std::process::id().to_string());
        }
        drop(lock);
    }

    /// A second acquisition fails with `DatabaseLocked` while the first guard
    /// is alive. Gated to Unix/Windows because the fallback `acquire_lock`
    /// takes no OS lock, so this expectation cannot hold elsewhere.
    #[cfg(any(unix, windows))]
    #[test]
    fn test_lock_prevents_second_acquisition() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_db");
        let _lock1 = FileLock::acquire(&db_path).unwrap();

        let err = FileLock::acquire(&db_path)
            .expect_err("second acquisition must fail while the first lock is held");
        assert!(
            matches!(err, Error::DatabaseLocked),
            "unexpected error: {}",
            err
        );
        assert!(err.to_string().contains("locked by another process"));
    }

    /// Dropping the guard releases the lock, so the directory can be locked again.
    #[test]
    fn test_lock_released_on_drop() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_db");
        {
            let _lock = FileLock::acquire(&db_path).unwrap();
        }
        let _lock2 = FileLock::acquire(&db_path).unwrap();
    }
}