use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use crate::error::WorktreeError;
/// Holder metadata serialized as JSON into the lock file by `StateLock::try_acquire`
/// and read back by `StateLock::inspect_holder`.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct LockPayload {
/// Process id of the holder, taken from `std::process::id()` at acquisition time.
pid: u32,
/// Platform-dependent value returned by `process_start_time()`; `0` when it is unavailable.
start_time: u64,
/// Random UUID generated once per `StateLock::acquire` call.
uuid: String,
/// Host name as reported by `hostname()` (`"unknown"` when it cannot be determined).
hostname: String,
/// UTC acquisition timestamp formatted with `to_rfc3339()`.
acquired_at: String,
}
/// Handle for a held state lock, produced by `StateLock::acquire`.
pub struct StateLock {
// Open handle to the lock file; never read after construction, only kept alive.
// NOTE(review): no explicit unlock exists in this file, so the OS lock is presumably
// released when this handle is closed on drop — confirm.
_file: fs::File,
/// Path of the lock file this handle was acquired on.
lock_path: PathBuf,
/// UUID written into the lock file payload for this acquisition.
uuid: String,
}
impl StateLock {
    /// Maximum number of acquisition attempts made by [`StateLock::acquire`]; the call gives
    /// up after this many tries even if the caller's timeout has not elapsed yet.
    const MAX_ATTEMPTS: u32 = 15;

    /// Acquires the exclusive lock at `lock_path`, retrying with jittered exponential backoff.
    ///
    /// The parent directory is created if missing, and a warning is printed to stderr when it
    /// appears to be on a network filesystem. Each attempt is non-blocking; between attempts
    /// the thread sleeps for a random duration of at most `min(2000, 10 * 2^attempt)` ms,
    /// never longer than what is left of `timeout_ms`.
    ///
    /// # Errors
    ///
    /// * [`WorktreeError::StateLockContention`] when the lock is still held by someone else
    ///   after `timeout_ms` milliseconds or after `MAX_ATTEMPTS` attempts.
    /// * Any other [`WorktreeError`] (e.g. converted from an I/O error while creating,
    ///   opening or writing the lock file) is returned immediately, without retrying.
    pub fn acquire(
        lock_path: &Path,
        timeout_ms: u64,
    ) -> Result<Self, WorktreeError> {
        if let Some(dir) = lock_path.parent() {
            fs::create_dir_all(dir)?;
            if is_network_filesystem(dir) {
                eprintln!(
                    "[iso-code] WARNING: state directory appears to be on a network filesystem; \
                     advisory locking may be unreliable. Consider ISO_CODE_HOME on local storage."
                );
            }
        }

        let uuid = uuid::Uuid::new_v4().to_string();
        let timeout = std::time::Duration::from_millis(timeout_ms);
        let start = std::time::Instant::now();

        for attempt in 0..Self::MAX_ATTEMPTS {
            match Self::try_acquire(lock_path, &uuid) {
                Ok(lock) => return Ok(lock),
                // Another holder has the lock: fall through to the backoff below.
                Err(WorktreeError::StateLockContention { .. }) => {}
                // Retrying cannot fix a non-contention failure, and reporting it as
                // contention would hide the real cause from the caller.
                Err(other) => return Err(other),
            }

            // Time left in the caller's budget; stop as soon as it is used up.
            let remaining = match timeout.checked_sub(start.elapsed()) {
                Some(left) if left > std::time::Duration::from_millis(0) => left,
                _ => return Err(WorktreeError::StateLockContention { timeout_ms }),
            };

            // Full-jitter exponential backoff, capped at two seconds per sleep.
            let cap_ms: u64 = 2000;
            let base_ms = 10u64.saturating_mul(1u64 << attempt);
            let max_sleep = cap_ms.min(base_ms);
            let sleep_ms = rand::random::<u64>() % (max_sleep + 1);
            // Never sleep past the deadline the caller asked for.
            let sleep_for = std::time::Duration::from_millis(sleep_ms).min(remaining);
            std::thread::sleep(sleep_for);
        }
        Err(WorktreeError::StateLockContention { timeout_ms })
    }

    /// Makes a single non-blocking attempt to take the lock and record the holder payload.
    ///
    /// Opens (creating if needed) the lock file without truncating it, takes the exclusive
    /// OS lock, and only then replaces the file contents with a fresh [`LockPayload`].
    ///
    /// # Errors
    ///
    /// Returns `StateLockContention { timeout_ms: 0 }` when the lock is currently held
    /// elsewhere; I/O failures are converted into [`WorktreeError`] and returned as such.
    fn try_acquire(lock_path: &Path, uuid: &str) -> Result<Self, WorktreeError> {
        use std::io::Seek;

        let mut file = fs::OpenOptions::new()
            .create(true)
            .truncate(false)
            .write(true)
            .read(true)
            .open(lock_path)?;
        #[cfg(unix)]
        {
            use std::os::unix::io::AsRawFd;
            // SAFETY: the descriptor comes from `file`, which is open for the whole call.
            let ret = unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) };
            if ret != 0 {
                // Only "would block" (or an interrupted call) means somebody else holds the
                // lock; any other errno is a real failure and must not look like contention.
                let err = std::io::Error::last_os_error();
                return match err.kind() {
                    std::io::ErrorKind::WouldBlock | std::io::ErrorKind::Interrupted => {
                        Err(WorktreeError::StateLockContention { timeout_ms: 0 })
                    }
                    _ => Err(WorktreeError::from(err)),
                };
            }
        }
        #[cfg(windows)]
        {
            // NOTE(review): this branch is kept exactly as it was; confirm against the
            // `fd_lock` version in use that the lock stays held once the guard is unwrapped.
            let rw = fd_lock::RwLock::new(file);
            let write_guard = rw.try_write().map_err(|e| {
                let _ = e.into_inner().into_inner();
                WorktreeError::StateLockContention { timeout_ms: 0 }
            })?;
            file = write_guard.into_inner().into_inner();
        }

        // The lock is held from here on, so it is safe to replace the previous payload.
        file.set_len(0)?;
        file.seek(std::io::SeekFrom::Start(0))?;
        let payload = LockPayload {
            pid: std::process::id(),
            start_time: process_start_time(),
            uuid: uuid.to_string(),
            hostname: hostname(),
            acquired_at: chrono::Utc::now().to_rfc3339(),
        };
        // A serialization failure results in an empty payload rather than an error.
        let json = serde_json::to_string(&payload).unwrap_or_default();
        file.write_all(json.as_bytes())?;
        file.flush()?;
        Ok(StateLock {
            _file: file,
            lock_path: lock_path.to_path_buf(),
            uuid: uuid.to_string(),
        })
    }

    /// Reads the payload currently stored in the lock file, if any.
    ///
    /// Returns `None` when the file cannot be read, is empty, or does not parse as a
    /// [`LockPayload`]. This does not take the lock.
    // Only the unit tests in this file call this helper.
    #[allow(dead_code)]
    fn inspect_holder(lock_path: &Path) -> Option<LockPayload> {
        let content = fs::read_to_string(lock_path).ok()?;
        if content.is_empty() {
            return None;
        }
        serde_json::from_str(&content).ok()
    }

    /// Path of the lock file backing this handle.
    pub fn path(&self) -> &Path {
        &self.lock_path
    }

    /// UUID that was written into the lock file for this acquisition.
    pub fn uuid(&self) -> &str {
        &self.uuid
    }
}
/// Returns a platform-specific start-time value for the current process, or `0` when it
/// cannot be determined.
///
/// * macOS: `pbi_start_tvsec` from `proc_pidinfo(PROC_PIDTBSDINFO)`.
/// * Linux: the 20th whitespace-separated field after the closing `)` of the command name
///   in `/proc/self/stat`.
/// * Other targets: always `0`.
fn process_start_time() -> u64 {
    #[cfg(target_os = "macos")]
    {
        let wanted = std::mem::size_of::<libc::proc_bsdinfo>() as libc::c_int;
        // SAFETY: `proc_bsdinfo` is a plain C struct that is filled in by `proc_pidinfo`.
        let mut bsd_info: libc::proc_bsdinfo = unsafe { std::mem::zeroed() };
        // SAFETY: the out-pointer refers to `bsd_info` and `wanted` is exactly its size.
        let written = unsafe {
            libc::proc_pidinfo(
                libc::getpid(),
                libc::PROC_PIDTBSDINFO,
                0,
                &mut bsd_info as *mut _ as *mut libc::c_void,
                wanted,
            )
        };
        if written > 0 {
            bsd_info.pbi_start_tvsec
        } else {
            0
        }
    }
    #[cfg(target_os = "linux")]
    {
        std::fs::read_to_string("/proc/self/stat")
            .ok()
            .and_then(|stat_line| {
                // The command name may itself contain ')' or spaces, so split after the
                // last closing parenthesis.
                let close = stat_line.rfind(')')?;
                let start_field = stat_line[close + 2..].split_whitespace().nth(19)?;
                Some(start_field.parse().unwrap_or(0))
            })
            .unwrap_or(0)
    }
    #[cfg(not(any(target_os = "macos", target_os = "linux")))]
    {
        0
    }
}
/// Best-effort check whether `path` lives on a network filesystem.
///
/// * macOS: asks `statfs` for the filesystem type name and compares it, ignoring ASCII case,
///   against a list of network filesystem names.
/// * Linux: finds the mount point in `/proc/mounts` that contains `path` and checks its
///   filesystem type against a list of network filesystem names.
/// * Other targets: always `false`.
///
/// Returns `false` whenever the information cannot be obtained.
fn is_network_filesystem(path: &Path) -> bool {
    #[cfg(target_os = "macos")]
    {
        let path_cstr = match path.to_str() {
            Some(s) => match std::ffi::CString::new(s) {
                Ok(c) => c,
                Err(_) => return false,
            },
            None => return false,
        };
        // SAFETY: `path_cstr` is a valid NUL-terminated string and `stat` is a zeroed
        // `statfs` buffer that the call fills in; the type name is only read on success.
        unsafe {
            let mut stat: libc::statfs = std::mem::zeroed();
            if libc::statfs(path_cstr.as_ptr(), &mut stat) == 0 {
                let fstype = std::ffi::CStr::from_ptr(stat.f_fstypename.as_ptr())
                    .to_string_lossy();
                let network_types = ["nfs", "smbfs", "afpfs", "cifs", "webdav"];
                return network_types.iter().any(|t| fstype.eq_ignore_ascii_case(t));
            }
        }
    }
    #[cfg(target_os = "linux")]
    {
        if let Ok(mounts) = std::fs::read_to_string("/proc/mounts") {
            // Mount points are absolute, so resolve relative paths and symlinks first.
            // If that fails, fall back to the path exactly as given.
            let resolved = std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf());
            let network_types = ["nfs", "nfs4", "cifs", "smbfs", "fuse.sshfs", "9p"];
            let mut best: Option<(&str, &str)> = None;
            for line in mounts.lines() {
                // Line layout: <device> <mount point> <fs type> ...
                // Mount points containing escaped characters (e.g. `\040`) are not decoded.
                let mut fields = line.split_whitespace();
                let (mp, fstype) = match (fields.nth(1), fields.next()) {
                    (Some(mp), Some(fstype)) => (mp, fstype),
                    _ => continue,
                };
                // `Path::starts_with` compares whole components, so `/mnt/nfs` does not
                // claim `/mnt/nfs-backup`. `>=` lets a later entry for the same mount point
                // win, because later mounts are stacked on top of earlier ones.
                if resolved.starts_with(mp)
                    && best.map_or(true, |(cur_mp, _)| mp.len() >= cur_mp.len())
                {
                    best = Some((mp, fstype));
                }
            }
            if let Some((_, fstype)) = best {
                return network_types.contains(&fstype);
            }
        }
    }
    // Keeps `path` used on targets where neither branch above is compiled.
    let _ = path;
    false
}
/// Returns the machine's host name, or `"unknown"` when it cannot be obtained.
///
/// On Unix the name comes from `gethostname`; invalid UTF-8 is replaced lossily. On other
/// targets the fallback string is always returned.
fn hostname() -> String {
    #[cfg(unix)]
    {
        let mut raw = [0u8; 256];
        // SAFETY: the pointer and length both describe `raw`, which lives for the whole call.
        let status =
            unsafe { libc::gethostname(raw.as_mut_ptr() as *mut libc::c_char, raw.len()) };
        if status == 0 {
            // The name ends at the first NUL byte; use the whole buffer if there is none.
            let name_len = raw.iter().take_while(|&&byte| byte != 0).count();
            return String::from_utf8_lossy(&raw[..name_len]).into_owned();
        }
    }
    String::from("unknown")
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a fresh temporary directory for a single test.
    fn setup_dir() -> TempDir {
        TempDir::new().unwrap()
    }

    /// Location of the lock file inside a test directory.
    fn lock_file_in(dir: &TempDir) -> std::path::PathBuf {
        dir.path().join("state.lock")
    }

    /// Acquiring writes a payload describing this process; dropping leaves the file behind.
    #[test]
    fn test_acquire_and_drop() {
        let tmp = setup_dir();
        let lock_path = lock_file_in(&tmp);

        let guard = StateLock::acquire(&lock_path, 5000).unwrap();
        assert!(lock_path.exists());
        assert!(!guard.uuid().is_empty());

        let raw = fs::read_to_string(&lock_path).unwrap();
        let written: LockPayload = serde_json::from_str(&raw).unwrap();
        assert_eq!(written.pid, std::process::id());
        assert!(!written.uuid.is_empty());
        assert!(!written.hostname.is_empty());

        drop(guard);
        assert!(lock_path.exists());
    }

    /// A released lock can be taken again on the same path.
    #[test]
    fn test_sequential_acquire() {
        let tmp = setup_dir();
        let lock_path = lock_file_in(&tmp);

        let first = StateLock::acquire(&lock_path, 5000).unwrap();
        drop(first);
        let second = StateLock::acquire(&lock_path, 5000).unwrap();
        drop(second);
    }

    /// A leftover payload file with no live lock on it does not block a new acquirer.
    #[test]
    fn test_dead_pid_holder_yields_to_new_acquirer() {
        let tmp = setup_dir();
        let lock_path = lock_file_in(&tmp);

        let stale = LockPayload {
            pid: 99_999_999,
            start_time: 1,
            uuid: "stale-uuid".to_string(),
            hostname: "test".to_string(),
            acquired_at: "2020-01-01T00:00:00Z".to_string(),
        };
        let stale_json = serde_json::to_string(&stale).unwrap();
        fs::write(&lock_path, stale_json).unwrap();

        let guard = StateLock::acquire(&lock_path, 5_000).unwrap();
        assert_eq!(guard.path(), lock_path);
    }

    /// `inspect_holder` reads back the payload written by the current holder.
    #[test]
    fn test_inspect_holder_returns_payload() {
        let tmp = setup_dir();
        let lock_path = lock_file_in(&tmp);

        let _guard = StateLock::acquire(&lock_path, 5_000).unwrap();
        let seen = StateLock::inspect_holder(&lock_path).unwrap();
        assert_eq!(seen.pid, std::process::id());
    }

    /// The host name helper never returns an empty string on the test machine.
    #[test]
    fn test_hostname_nonempty() {
        let name = hostname();
        assert!(!name.is_empty());
    }
}