1use std::fs;
9use std::io::Write;
10use std::path::{Path, PathBuf};
11
12use crate::error::WorktreeError;
13
14#[derive(Debug, serde::Serialize, serde::Deserialize)]
16struct LockPayload {
17 pid: u32,
18 start_time: u64,
19 uuid: String,
20 hostname: String,
21 acquired_at: String,
22}
23
/// Guard for the exclusive state lock.
///
/// The OS-level lock is attached to the open handle in `_file`. No explicit unlock
/// is performed: the lock goes away when this value is dropped and the handle is
/// closed. The lock file itself stays on disk.
pub struct StateLock {
    /// Open handle to the lock file; keeping it open keeps the lock.
    _file: fs::File,
    /// Path of the lock file as given to `acquire`.
    lock_path: PathBuf,
    /// Identifier written into the payload for this acquisition.
    uuid: String,
}
35
36impl StateLock {
37 const MAX_ATTEMPTS: u32 = 15;
39
40 pub fn acquire(
52 lock_path: &Path,
53 timeout_ms: u64,
54 ) -> Result<Self, WorktreeError> {
55 if let Some(parent) = lock_path.parent() {
57 fs::create_dir_all(parent)?;
58 }
59
60 if let Some(dir) = lock_path.parent() {
63 if is_network_filesystem(dir) {
64 eprintln!(
65 "[iso-code] WARNING: state directory appears to be on a network filesystem; \
66 advisory locking may be unreliable. Consider ISO_CODE_HOME on local storage."
67 );
68 }
69 }
70
71 let uuid = uuid::Uuid::new_v4().to_string();
72 let start = std::time::Instant::now();
73
74 for attempt in 0..Self::MAX_ATTEMPTS {
75 match Self::try_acquire(lock_path, &uuid) {
76 Ok(lock) => return Ok(lock),
77 Err(_) => {
78 if start.elapsed().as_millis() as u64 >= timeout_ms {
80 return Err(WorktreeError::StateLockContention { timeout_ms });
81 }
82
83 let cap_ms: u64 = 2000;
85 let base_ms = 10u64.saturating_mul(1u64 << attempt);
86 let max_sleep = cap_ms.min(base_ms);
87 let sleep_ms = rand::random::<u64>() % (max_sleep + 1);
88 std::thread::sleep(std::time::Duration::from_millis(sleep_ms));
89 }
90 }
91 }
92
93 Err(WorktreeError::StateLockContention { timeout_ms })
94 }
95
96 fn try_acquire(lock_path: &Path, uuid: &str) -> Result<Self, WorktreeError> {
106 let mut file = fs::OpenOptions::new()
107 .create(true)
108 .truncate(false)
109 .write(true)
110 .read(true)
111 .open(lock_path)?;
112
113 #[cfg(unix)]
115 {
116 use std::os::unix::io::AsRawFd;
117 let ret = unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) };
118 if ret != 0 {
119 return Err(WorktreeError::StateLockContention { timeout_ms: 0 });
120 }
121 }
122
123 #[cfg(windows)]
124 {
125 let rw = fd_lock::RwLock::new(file);
130 let write_guard = rw.try_write().map_err(|e| {
131 let _ = e.into_inner().into_inner();
133 WorktreeError::StateLockContention { timeout_ms: 0 }
134 })?;
135 file = write_guard.into_inner().into_inner();
137 }
138
139 use std::io::Seek;
141 file.set_len(0)?;
142 file.seek(std::io::SeekFrom::Start(0))?;
143
144 let payload = LockPayload {
145 pid: std::process::id(),
146 start_time: process_start_time(),
147 uuid: uuid.to_string(),
148 hostname: hostname(),
149 acquired_at: chrono::Utc::now().to_rfc3339(),
150 };
151 let json = serde_json::to_string(&payload).unwrap_or_default();
152 file.write_all(json.as_bytes())?;
153 file.flush()?;
154
155 Ok(StateLock {
156 _file: file,
157 lock_path: lock_path.to_path_buf(),
158 uuid: uuid.to_string(),
159 })
160 }
161
162 #[allow(dead_code)]
170 fn inspect_holder(lock_path: &Path) -> Option<LockPayload> {
171 let content = fs::read_to_string(lock_path).ok()?;
172 if content.is_empty() {
173 return None;
174 }
175 serde_json::from_str(&content).ok()
176 }
177
178 pub fn path(&self) -> &Path {
180 &self.lock_path
181 }
182
183 pub fn uuid(&self) -> &str {
185 &self.uuid
186 }
187}
188
/// Best-effort start time of the current process, recorded in the lock payload.
///
/// The units differ by platform, so values are only comparable on the same OS:
/// - macOS: `pbi_start_tvsec` from `proc_pidinfo(PROC_PIDTBSDINFO)`, the seconds
///   part of the process start timeval.
/// - Linux: field 22 (`starttime`) of `/proc/self/stat`, which proc(5) documents
///   as clock ticks since boot.
///
/// Returns 0 when the value cannot be determined, and on all other platforms.
fn process_start_time() -> u64 {
    #[cfg(target_os = "macos")]
    {
        use std::mem;
        // SAFETY: `proc_bsdinfo` is a plain C struct; all-zero bytes is a valid value.
        let mut info: libc::proc_bsdinfo = unsafe { mem::zeroed() };
        let size = mem::size_of::<libc::proc_bsdinfo>() as libc::c_int;
        // SAFETY: `info` is a live, writable buffer of exactly `size` bytes.
        let ret = unsafe {
            libc::proc_pidinfo(
                libc::getpid(),
                libc::PROC_PIDTBSDINFO,
                0,
                &mut info as *mut _ as *mut libc::c_void,
                size,
            )
        };
        if ret > 0 {
            return info.pbi_start_tvsec;
        }
        0
    }

    #[cfg(target_os = "linux")]
    {
        std::fs::read_to_string("/proc/self/stat")
            .ok()
            .and_then(|stat| parse_proc_stat_start_time(&stat))
            .unwrap_or(0)
    }

    #[cfg(not(any(target_os = "macos", target_os = "linux")))]
    {
        0
    }
}

/// Extracts `starttime` (field 22) from the text of a `/proc/<pid>/stat` file.
///
/// Field 2 (`comm`) is parenthesised and may itself contain spaces or `)`, so
/// parsing starts after the LAST `)`. The first token after it is field 3 (`state`),
/// which puts field 22 at index 19. Malformed input yields `None` instead of a panic.
#[cfg_attr(not(target_os = "linux"), allow(dead_code))]
fn parse_proc_stat_start_time(stat: &str) -> Option<u64> {
    let close_paren = stat.rfind(')')?;
    // `)` is one byte, so `close_paren + 1` is always a char boundary.
    stat.get(close_paren + 1..)?
        .split_whitespace()
        .nth(19)?
        .parse()
        .ok()
}
231
/// Heuristically reports whether `path` lives on a network filesystem, where
/// advisory file locking may be unreliable. Used only to decide whether to warn.
///
/// - macOS: the `statfs(2)` type name is compared case-insensitively with known
///   network filesystem types.
/// - Linux: the innermost mount covering `path` is looked up in `/proc/mounts`.
///   `path` is canonicalized first when possible, so a symlink into a network mount
///   is detected.
/// - Elsewhere, or when the lookup fails: `false`.
fn is_network_filesystem(path: &Path) -> bool {
    #[cfg(target_os = "macos")]
    {
        let path_cstr = match path.to_str() {
            Some(s) => match std::ffi::CString::new(s) {
                Ok(c) => c,
                Err(_) => return false,
            },
            None => return false,
        };
        // SAFETY: `path_cstr` is a valid NUL-terminated string and `stat` is a
        // writable, zero-initialised `statfs` for the call to fill in. Reading
        // `f_fstypename` as a C string assumes statfs NUL-terminates it on success.
        unsafe {
            let mut stat: libc::statfs = std::mem::zeroed();
            if libc::statfs(path_cstr.as_ptr(), &mut stat) == 0 {
                let fstype =
                    std::ffi::CStr::from_ptr(stat.f_fstypename.as_ptr()).to_string_lossy();
                let network_types = ["nfs", "smbfs", "afpfs", "cifs", "webdav"];
                return network_types.iter().any(|t| fstype.eq_ignore_ascii_case(t));
            }
        }
    }

    #[cfg(target_os = "linux")]
    {
        if let Ok(mounts) = fs::read_to_string("/proc/mounts") {
            let resolved = fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf());
            return mounts_indicate_network_fs(&mounts, &resolved);
        }
    }

    let _ = path;
    false
}

/// Returns `true` when, according to `mounts` (the text of `/proc/mounts`), the
/// innermost mount containing `path` has a network filesystem type.
///
/// Mount points are compared component-wise, so `/mnt/data` does not cover
/// `/mnt/datafiles`. When several entries share a mount point, the one listed last
/// (mounted on top) wins.
#[cfg_attr(not(target_os = "linux"), allow(dead_code))]
fn mounts_indicate_network_fs(mounts: &str, path: &Path) -> bool {
    const NETWORK_TYPES: [&str; 6] = ["nfs", "nfs4", "cifs", "smbfs", "fuse.sshfs", "9p"];

    let mut best: Option<(PathBuf, &str)> = None;
    for line in mounts.lines() {
        // Fields: source, mount point, fs type, options, ...
        let mut fields = line.split_whitespace().skip(1);
        let (mount_point, fs_type) = match (fields.next(), fields.next()) {
            (Some(mp), Some(ty)) => (PathBuf::from(unescape_mount_field(mp)), ty),
            _ => continue,
        };
        // Every candidate is a prefix of `path`, so a longer one is a deeper mount;
        // `>=` lets a later overmount of the same point replace an earlier entry.
        let at_least_as_deep = best.as_ref().map_or(true, |(current, _)| {
            mount_point.as_os_str().len() >= current.as_os_str().len()
        });
        if path.starts_with(&mount_point) && at_least_as_deep {
            best = Some((mount_point, fs_type));
        }
    }
    best.map_or(false, |(_, fs_type)| NETWORK_TYPES.contains(&fs_type))
}

/// Decodes the `\ooo` octal escapes (e.g. `\040` for a space) that `/proc/mounts`
/// uses for whitespace and backslashes inside a field. Invalid sequences are kept
/// verbatim, and non-UTF-8 bytes are replaced lossily.
#[cfg_attr(not(target_os = "linux"), allow(dead_code))]
fn unescape_mount_field(field: &str) -> String {
    let bytes = field.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'\\' {
            if let Some(digits) = bytes.get(i + 1..i + 4) {
                if digits.iter().all(|d| (b'0'..=b'7').contains(d)) {
                    let value = digits
                        .iter()
                        .fold(0u32, |acc, d| acc * 8 + u32::from(d - b'0'));
                    if let Ok(byte) = u8::try_from(value) {
                        out.push(byte);
                        i += 4;
                        continue;
                    }
                }
            }
        }
        out.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}
283
284fn hostname() -> String {
286 #[cfg(unix)]
287 {
288 let mut buf = [0u8; 256];
289 let ret = unsafe { libc::gethostname(buf.as_mut_ptr() as *mut libc::c_char, buf.len()) };
290 if ret == 0 {
291 let end = buf.iter().position(|&b| b == 0).unwrap_or(buf.len());
292 return String::from_utf8_lossy(&buf[..end]).to_string();
293 }
294 }
295 "unknown".to_string()
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Fresh temporary directory per test; removed when the guard drops.
    fn setup_dir() -> TempDir {
        TempDir::new().unwrap()
    }

    #[test]
    fn test_acquire_and_drop() {
        let dir = setup_dir();
        let lock_path = dir.path().join("state.lock");

        let lock = StateLock::acquire(&lock_path, 5000).unwrap();
        assert!(lock_path.exists());
        assert!(!lock.uuid().is_empty());

        // The holder's payload is readable while the lock is held.
        let content = fs::read_to_string(&lock_path).unwrap();
        let payload: LockPayload = serde_json::from_str(&content).unwrap();
        assert_eq!(payload.pid, std::process::id());
        assert!(!payload.uuid.is_empty());
        assert!(!payload.hostname.is_empty());

        // Dropping releases the lock but leaves the file in place.
        drop(lock);
        assert!(lock_path.exists());
    }

    #[test]
    fn test_sequential_acquire() {
        let dir = setup_dir();
        let lock_path = dir.path().join("state.lock");

        let lock1 = StateLock::acquire(&lock_path, 5000).unwrap();
        drop(lock1);

        let lock2 = StateLock::acquire(&lock_path, 5000).unwrap();
        drop(lock2);
    }

    #[test]
    fn test_dead_pid_holder_yields_to_new_acquirer() {
        let dir = setup_dir();
        // A leftover payload from a process that no longer exists must not block:
        // only a live OS lock on the file does.
        let lock_path = dir.path().join("state.lock");

        let payload = LockPayload {
            pid: 99_999_999,
            start_time: 1,
            uuid: "stale-uuid".to_string(),
            hostname: "test".to_string(),
            acquired_at: "2020-01-01T00:00:00Z".to_string(),
        };
        fs::write(&lock_path, serde_json::to_string(&payload).unwrap()).unwrap();

        let lock = StateLock::acquire(&lock_path, 5_000).unwrap();
        assert_eq!(lock.path(), lock_path);
    }

    #[cfg(unix)]
    #[test]
    fn test_held_lock_causes_contention_timeout() {
        let dir = setup_dir();
        let lock_path = dir.path().join("state.lock");
        let held = StateLock::acquire(&lock_path, 5_000).unwrap();

        // flock(2) locks belong to the open file description, so a second open of
        // the same path in this process conflicts just as another process would.
        let contended = StateLock::acquire(&lock_path, 50);
        assert!(matches!(
            contended,
            Err(WorktreeError::StateLockContention { timeout_ms: 50, .. })
        ));

        drop(held);
        assert!(StateLock::acquire(&lock_path, 5_000).is_ok());
    }

    #[test]
    fn test_inspect_holder_returns_payload() {
        let dir = setup_dir();
        let lock_path = dir.path().join("state.lock");
        let _lock = StateLock::acquire(&lock_path, 5_000).unwrap();
        let holder = StateLock::inspect_holder(&lock_path).unwrap();
        assert_eq!(holder.pid, std::process::id());
    }

    #[test]
    fn test_inspect_holder_missing_file_is_none() {
        let dir = setup_dir();
        assert!(StateLock::inspect_holder(&dir.path().join("absent.lock")).is_none());
    }

    #[test]
    fn test_hostname_nonempty() {
        let h = hostname();
        assert!(!h.is_empty());
    }
}