1use atomicwrites::{AllowOverwrite, AtomicFile};
2use std::{
3 fs::File,
4 io::{Read, Seek, SeekFrom, Write},
5 path::Path,
6 sync::{
7 Arc, Mutex,
8 atomic::{AtomicU64, Ordering},
9 },
10 thread::{self, JoinHandle},
11 time::Duration,
12};
13
/// Compatibility class of a lock holder.
///
/// A reader may coexist with a writer of the same mode (or with a writer owned by its own
/// process), while a writer is blocked by readers of any other mode. `None` is stored on
/// disk to mean "no writer" and is rejected by the acquire methods. The discriminant is
/// both the byte written to the lock file and the index into the per-mode reader counts.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LockMode {
    /// Absence of a mode; marks "no writer" in the lock file.
    None = 0,
    /// The `Destructive` access mode.
    Destructive = 1,
    /// The `NonDestructive` access mode.
    NonDestructive = 2,
}
21
22impl LockMode {
23 fn from_u8(value: u8) -> Self {
24 match value {
25 1 => LockMode::Destructive,
26 2 => LockMode::NonDestructive,
27 _ => LockMode::None,
28 }
29 }
30
31 fn as_u8(self) -> u8 {
32 self as u8
33 }
34}
35
/// Reader/writer lock whose shared state is kept in a small file, with the writer
/// identified by process id.
///
/// Every field is an `Arc`, so cloning is cheap and all clones share the same cached
/// state, per-process counters and background refresh thread.
#[derive(Debug, Clone)]
pub struct RwLock {
    /// Path of the lock file.
    path: Arc<String>,
    /// Cached writer-mode byte from the lock file.
    writer_mode: Arc<AtomicU64>,
    /// Cached writer-present flag (non-zero while a writer is recorded).
    writer_present: Arc<AtomicU64>,
    /// Cached process id of the recorded writer.
    writer_pid: Arc<AtomicU64>,
    /// Cached reader counts from the lock file, indexed by `LockMode` discriminant.
    reader_counts: Arc<Vec<AtomicU64>>,
    /// Join handle of the background thread that periodically refreshes the cache.
    refresh: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// `1` while the refresh thread should keep running.
    running: Arc<AtomicU64>,
    /// Number of read guards held by this process, per mode.
    process_reader_counts: Arc<Vec<AtomicU64>>,
    /// Number of write guards held by this process; non-zero means this process owns the writer.
    process_has_writer: Arc<AtomicU64>,
}
48
/// In-memory image of the lock file.
///
/// On disk the fields appear in this order: `writer_mode` (1 byte + 7 zero padding bytes),
/// `writer_present` (1 byte + 7 zero padding bytes), `writer_pid` (little-endian `u64`),
/// then the three `reader_counts` (little-endian `u64` each).
#[derive(Debug, Clone)]
struct LockState {
    /// `LockMode` discriminant of the recorded writer (`LockMode::None` after release).
    writer_mode: u8,
    /// Non-zero while a writer is recorded.
    writer_present: u8,
    /// Process id of the recorded writer (`0` after release).
    writer_pid: u64,
    /// Registered readers, indexed by `LockMode` discriminant.
    reader_counts: [u64; 3],
}
56
57impl RwLock {
58 pub fn new<P: AsRef<Path>>(path: P) -> std::io::Result<Self> {
59 let path_str = path.as_ref().to_string_lossy().to_string();
60 let path_arc = Arc::new(path_str.clone());
61
62 let state = if !path.as_ref().exists() {
63 let initial_state = LockState {
64 writer_mode: LockMode::None.as_u8(),
65 writer_present: 0,
66 writer_pid: 0,
67 reader_counts: [0; 3],
68 };
69 Self::write_state(&path_str, &initial_state)?;
70 initial_state
71 } else {
72 Self::read_state(&path_str)?
73 };
74
75 let reader_counts = Arc::new(
76 (0..3)
77 .map(|i| AtomicU64::new(state.reader_counts[i]))
78 .collect::<Vec<_>>(),
79 );
80
81 let writer_mode = Arc::new(AtomicU64::new(state.writer_mode as u64));
82 let writer_present = Arc::new(AtomicU64::new(state.writer_present as u64));
83 let writer_pid = Arc::new(AtomicU64::new(state.writer_pid));
84
85 let process_reader_counts = Arc::new((0..3).map(|_| AtomicU64::new(0)).collect::<Vec<_>>());
86 let process_has_writer = Arc::new(AtomicU64::new(0));
87
88 let running = Arc::new(AtomicU64::new(1));
89 let running_clone = Arc::clone(&running);
90 let path_clone = Arc::clone(&path_arc);
91 let writer_mode_clone = Arc::clone(&writer_mode);
92 let writer_present_clone = Arc::clone(&writer_present);
93 let writer_pid_clone = Arc::clone(&writer_pid);
94 let reader_counts_clone = Arc::clone(&reader_counts);
95
96 let refresh = thread::spawn(move || {
97 while running_clone.load(Ordering::SeqCst) == 1 {
98 thread::sleep(Duration::from_millis(100));
99
100 match Self::read_state(&path_clone) {
101 Ok(state) => {
102 writer_mode_clone.store(state.writer_mode as u64, Ordering::SeqCst);
103 writer_present_clone.store(state.writer_present as u64, Ordering::SeqCst);
104 writer_pid_clone.store(state.writer_pid, Ordering::SeqCst);
105
106 for (i, count) in state.reader_counts.iter().enumerate() {
107 if i < reader_counts_clone.len() {
108 reader_counts_clone[i].store(*count, Ordering::SeqCst);
109 }
110 }
111 }
112 Err(e) => {
113 eprintln!("Error in refresh thread: {e}");
114 }
115 }
116 }
117 });
118
119 Ok(Self {
120 path: path_arc,
121 writer_mode,
122 writer_present,
123 writer_pid,
124 reader_counts,
125 refresh: Arc::new(Mutex::new(Some(refresh))),
126 running,
127 process_reader_counts,
128 process_has_writer,
129 })
130 }
131
132 fn read_state(path: &str) -> std::io::Result<LockState> {
133 let mut file = File::open(path)?;
134 let mut reader_counts = [0u64; 3];
135
136 file.seek(SeekFrom::Start(0))?;
137 let mut writer_mode_buf = [0; 1];
138 file.read_exact(&mut writer_mode_buf)?;
139 let writer_mode = writer_mode_buf[0];
140
141 file.seek(SeekFrom::Current(7))?;
142
143 let mut writer_present_buf = [0; 1];
144 file.read_exact(&mut writer_present_buf)?;
145 let writer_present = writer_present_buf[0];
146
147 file.seek(SeekFrom::Current(7))?;
148
149 let mut writer_pid_buf = [0; 8];
150 file.read_exact(&mut writer_pid_buf)?;
151 let writer_pid = u64::from_le_bytes(writer_pid_buf);
152
153 for reader_count in reader_counts.iter_mut() {
154 let mut count_buf = [0; 8];
155 if file.read_exact(&mut count_buf).is_ok() {
156 *reader_count = u64::from_le_bytes(count_buf);
157 } else {
158 break;
159 }
160 }
161
162 Ok(LockState {
163 writer_mode,
164 writer_present,
165 writer_pid,
166 reader_counts,
167 })
168 }
169
170 fn write_state(path: &str, state: &LockState) -> std::io::Result<()> {
171 let atomic_file = AtomicFile::new(path, AllowOverwrite);
172
173 atomic_file.write(|f| {
174 f.seek(SeekFrom::Start(0))?;
175
176 f.write_all(&[state.writer_mode])?;
177 f.write_all(&[0; 7])?; f.write_all(&[state.writer_present])?;
180 f.write_all(&[0; 7])?; f.write_all(&state.writer_pid.to_le_bytes())?;
183
184 for count in &state.reader_counts {
185 f.write_all(&count.to_le_bytes())?;
186 }
187
188 Ok(())
189 })?;
190
191 Ok(())
192 }
193
194 fn update_state<F>(&self, update_fn: F) -> std::io::Result<()>
195 where
196 F: FnOnce(LockState) -> LockState,
197 {
198 let atomic_file = AtomicFile::new(&*self.path, AllowOverwrite);
199
200 let current_state = Self::read_state(&self.path)?;
201 let new_state = update_fn(current_state);
202
203 self.writer_mode
204 .store(new_state.writer_mode as u64, Ordering::SeqCst);
205 self.writer_present
206 .store(new_state.writer_present as u64, Ordering::SeqCst);
207 self.writer_pid
208 .store(new_state.writer_pid, Ordering::SeqCst);
209
210 for (i, count) in new_state.reader_counts.iter().enumerate() {
211 if i < self.reader_counts.len() {
212 self.reader_counts[i].store(*count, Ordering::SeqCst);
213 }
214 }
215
216 atomic_file.write(|f| {
217 f.seek(SeekFrom::Start(0))?;
218
219 f.write_all(&[new_state.writer_mode])?;
220 f.write_all(&[0; 7])?;
221
222 f.write_all(&[new_state.writer_present])?;
223 f.write_all(&[0; 7])?;
224
225 f.write_all(&new_state.writer_pid.to_le_bytes())?;
226
227 for count in &new_state.reader_counts {
228 f.write_all(&count.to_le_bytes())?;
229 }
230
231 Ok(())
232 })?;
233
234 Ok(())
235 }
236
237 fn current_pid() -> u64 {
238 std::process::id() as u64
239 }
240
241 fn process_owns_writer(&self) -> bool {
242 self.process_has_writer.load(Ordering::SeqCst) > 0
243 }
244
245 pub fn read_lock(&self, mode: LockMode) -> std::io::Result<ReadGuard> {
246 if mode == LockMode::None {
247 return Err(std::io::Error::new(
248 std::io::ErrorKind::InvalidInput,
249 "Cannot acquire read lock with None mode",
250 ));
251 }
252
253 if self.process_owns_writer() {
254 self.process_reader_counts[mode as usize].fetch_add(1, Ordering::SeqCst);
255
256 return Ok(ReadGuard {
257 lock: self.clone(),
258 mode,
259 active: true,
260 });
261 }
262
263 let mut backoff = Duration::from_millis(1);
264 let max_backoff = Duration::from_secs(1);
265
266 loop {
267 let current_writer_mode =
268 LockMode::from_u8(self.writer_mode.load(Ordering::SeqCst) as u8);
269 let writer_present = self.writer_present.load(Ordering::SeqCst) != 0;
270 let writer_pid = self.writer_pid.load(Ordering::SeqCst);
271 let current_pid = Self::current_pid();
272
273 if !writer_present || current_writer_mode == mode || writer_pid == current_pid {
274 match self.update_state(|mut state| {
275 if state.writer_present != 0
276 && LockMode::from_u8(state.writer_mode) != mode
277 && state.writer_pid != current_pid
278 {
279 return state;
280 }
281
282 state.reader_counts[mode as usize] += 1;
283 state
284 }) {
285 Ok(()) => {
286 self.process_reader_counts[mode as usize].fetch_add(1, Ordering::SeqCst);
287
288 return Ok(ReadGuard {
289 lock: self.clone(),
290 mode,
291 active: true,
292 });
293 }
294 Err(e) => {
295 if e.kind() == std::io::ErrorKind::WouldBlock {
296 thread::sleep(backoff);
297 backoff = std::cmp::min(backoff * 2, max_backoff);
298 continue;
299 }
300 return Err(e);
301 }
302 }
303 }
304
305 thread::sleep(backoff);
306 backoff = std::cmp::min(backoff * 2, max_backoff);
307 }
308 }
309
310 pub fn write_lock(&self, mode: LockMode) -> std::io::Result<WriteGuard> {
311 if mode == LockMode::None {
312 return Err(std::io::Error::new(
313 std::io::ErrorKind::InvalidInput,
314 "Cannot acquire write lock with None mode",
315 ));
316 }
317
318 if self.process_owns_writer() {
319 self.process_has_writer.fetch_add(1, Ordering::SeqCst);
320
321 return Ok(WriteGuard {
322 lock: self.clone(),
323 mode,
324 active: true,
325 });
326 }
327
328 let mut backoff = Duration::from_millis(1);
329 let max_backoff = Duration::from_secs(1);
330 let current_pid = Self::current_pid();
331
332 loop {
333 let writer_present = self.writer_present.load(Ordering::SeqCst) != 0;
334 let writer_pid = self.writer_pid.load(Ordering::SeqCst);
335
336 let incompatible_readers = (0..3).any(|i| {
337 if i == mode as usize {
338 false
339 } else {
340 self.reader_counts[i].load(Ordering::SeqCst) > 0
341 }
342 });
343
344 if (writer_present && writer_pid != current_pid) || incompatible_readers {
345 thread::sleep(backoff);
346 backoff = std::cmp::min(backoff * 2, max_backoff);
347 continue;
348 }
349
350 match self.update_state(|mut state| {
351 let incompatible_readers = (0..3).any(|i| {
352 if i == mode as usize {
353 false
354 } else {
355 state.reader_counts[i] > 0
356 }
357 });
358
359 if (state.writer_present != 0 && state.writer_pid != current_pid)
360 || incompatible_readers
361 {
362 return state;
363 }
364
365 state.writer_mode = mode.as_u8();
366 state.writer_present = 1;
367 state.writer_pid = current_pid;
368 state
369 }) {
370 Ok(()) => {
371 self.process_has_writer.store(1, Ordering::SeqCst);
372
373 return Ok(WriteGuard {
374 lock: self.clone(),
375 mode,
376 active: true,
377 });
378 }
379 Err(e) => {
380 if e.kind() == std::io::ErrorKind::WouldBlock {
381 thread::sleep(backoff);
382 backoff = std::cmp::min(backoff * 2, max_backoff);
383 continue;
384 }
385
386 return Err(e);
387 }
388 }
389 }
390 }
391
392 pub fn try_read_lock(&self, mode: LockMode) -> std::io::Result<Option<ReadGuard>> {
393 if mode == LockMode::None {
394 return Err(std::io::Error::new(
395 std::io::ErrorKind::InvalidInput,
396 "Cannot acquire read lock with None mode",
397 ));
398 }
399
400 if self.process_owns_writer() {
401 self.process_reader_counts[mode as usize].fetch_add(1, Ordering::SeqCst);
402
403 return Ok(Some(ReadGuard {
404 lock: self.clone(),
405 mode,
406 active: true,
407 }));
408 }
409
410 let current_writer_mode = LockMode::from_u8(self.writer_mode.load(Ordering::SeqCst) as u8);
411 let writer_present = self.writer_present.load(Ordering::SeqCst) != 0;
412 let writer_pid = self.writer_pid.load(Ordering::SeqCst);
413 let current_pid = Self::current_pid();
414
415 if !writer_present || current_writer_mode == mode || writer_pid == current_pid {
416 match self.update_state(|mut state| {
417 if state.writer_present != 0
418 && LockMode::from_u8(state.writer_mode) != mode
419 && state.writer_pid != current_pid
420 {
421 return state;
422 }
423
424 state.reader_counts[mode as usize] += 1;
425 state
426 }) {
427 Ok(()) => {
428 self.process_reader_counts[mode as usize].fetch_add(1, Ordering::SeqCst);
429
430 return Ok(Some(ReadGuard {
431 lock: self.clone(),
432 mode,
433 active: true,
434 }));
435 }
436 Err(e) => return Err(e),
437 }
438 }
439
440 Ok(None)
441 }
442
443 pub fn try_write_lock(&self, mode: LockMode) -> std::io::Result<Option<WriteGuard>> {
444 if mode == LockMode::None {
445 return Err(std::io::Error::new(
446 std::io::ErrorKind::InvalidInput,
447 "Cannot acquire write lock with None mode",
448 ));
449 }
450
451 if self.process_owns_writer() {
452 self.process_has_writer.fetch_add(1, Ordering::SeqCst);
453
454 return Ok(Some(WriteGuard {
455 lock: self.clone(),
456 mode,
457 active: true,
458 }));
459 }
460
461 let writer_present = self.writer_present.load(Ordering::SeqCst) != 0;
462 let writer_pid = self.writer_pid.load(Ordering::SeqCst);
463 let current_pid = Self::current_pid();
464
465 let incompatible_readers = (0..3).any(|i| {
466 if i == mode as usize {
467 false
468 } else {
469 self.reader_counts[i].load(Ordering::SeqCst) > 0
470 }
471 });
472
473 if (writer_present && writer_pid != current_pid) || incompatible_readers {
474 return Ok(None);
475 }
476
477 match self.update_state(|mut state| {
478 let incompatible_readers = (0..3).any(|i| {
479 if i == mode as usize {
480 false
481 } else {
482 state.reader_counts[i] > 0
483 }
484 });
485
486 if (state.writer_present != 0 && state.writer_pid != current_pid)
487 || incompatible_readers
488 {
489 return state;
490 }
491
492 state.writer_mode = mode.as_u8();
493 state.writer_present = 1;
494 state.writer_pid = current_pid;
495 state
496 }) {
497 Ok(()) => {
498 self.process_has_writer.store(1, Ordering::SeqCst);
499
500 Ok(Some(WriteGuard {
501 lock: self.clone(),
502 mode,
503 active: true,
504 }))
505 }
506 Err(e) => Err(e),
507 }
508 }
509
510 pub fn reader_count(&self, mode: LockMode) -> u64 {
511 self.reader_counts[mode as usize].load(Ordering::SeqCst)
512 }
513
514 pub fn total_reader_count(&self) -> u64 {
515 (0..3)
516 .map(|i| self.reader_counts[i].load(Ordering::SeqCst))
517 .sum()
518 }
519
520 pub fn has_writer(&self) -> bool {
521 self.writer_present.load(Ordering::SeqCst) != 0
522 }
523
524 pub fn writer_mode(&self) -> Option<LockMode> {
525 if self.has_writer() {
526 Some(LockMode::from_u8(
527 self.writer_mode.load(Ordering::SeqCst) as u8
528 ))
529 } else {
530 None
531 }
532 }
533
534 pub fn writer_pid(&self) -> Option<u64> {
535 if self.has_writer() {
536 Some(self.writer_pid.load(Ordering::SeqCst))
537 } else {
538 None
539 }
540 }
541}
542
543pub struct ReadGuard {
544 lock: RwLock,
545 mode: LockMode,
546 active: bool,
547}
548
549impl ReadGuard {
550 pub fn mode(&self) -> LockMode {
551 self.mode
552 }
553
554 pub fn unlock(&mut self) -> std::io::Result<()> {
555 if self.active {
556 let prev_count =
557 self.lock.process_reader_counts[self.mode as usize].fetch_sub(1, Ordering::SeqCst);
558
559 if prev_count == 1 && !self.lock.process_owns_writer() {
560 self.lock.update_state(|mut state| {
561 if state.reader_counts[self.mode as usize] > 0 {
562 state.reader_counts[self.mode as usize] -= 1;
563 }
564 state
565 })?;
566 }
567
568 self.active = false;
569 }
570 Ok(())
571 }
572}
573
574impl Drop for ReadGuard {
575 fn drop(&mut self) {
576 if self.active {
577 if let Err(e) = self.unlock() {
578 eprintln!("Error releasing read lock in drop: {e}");
579 }
580 }
581 }
582}
583
584pub struct WriteGuard {
585 lock: RwLock,
586 mode: LockMode,
587 active: bool,
588}
589
590impl WriteGuard {
591 pub fn mode(&self) -> LockMode {
592 self.mode
593 }
594
595 pub fn unlock(&mut self) -> std::io::Result<()> {
596 if self.active {
597 let prev_count = self.lock.process_has_writer.fetch_sub(1, Ordering::SeqCst);
598
599 if prev_count == 1 {
600 self.lock.update_state(|mut state| {
601 let current_pid = RwLock::current_pid();
602 if state.writer_present != 0 && state.writer_pid == current_pid {
603 state.writer_present = 0;
604 state.writer_mode = LockMode::None.as_u8();
605 state.writer_pid = 0;
606 }
607 state
608 })?;
609 }
610
611 self.active = false;
612 }
613 Ok(())
614 }
615}
616
617impl Drop for WriteGuard {
618 fn drop(&mut self) {
619 if self.active {
620 if let Err(e) = self.unlock() {
621 eprintln!("Error releasing write lock in drop: {e}");
622 }
623 }
624 }
625}
626
627impl Drop for RwLock {
628 fn drop(&mut self) {
629 self.running.store(0, Ordering::SeqCst);
630
631 if let Ok(mut refresh_guard) = self.refresh.lock() {
632 if let Some(handle) = refresh_guard.take() {
633 let _ = handle.join();
634 }
635 }
636 }
637}