ddup_bak/chunks/
lock.rs

1use atomicwrites::{AllowOverwrite, AtomicFile};
2use std::{
3    fs::File,
4    io::{Read, Seek, SeekFrom, Write},
5    path::Path,
6    sync::{
7        Arc, Mutex,
8        atomic::{AtomicU64, Ordering},
9    },
10    thread::{self, JoinHandle},
11    time::Duration,
12};
13
/// The kind of operation a lock holder intends to perform.
///
/// Stored as a single byte in the lock file; `None` doubles as the
/// "no writer" sentinel.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LockMode {
    None = 0,
    Destructive = 1,
    NonDestructive = 2,
}

impl LockMode {
    /// Decodes a raw byte read from the lock file; any unknown value
    /// degrades to `None`.
    fn from_u8(value: u8) -> Self {
        match value {
            1 => Self::Destructive,
            2 => Self::NonDestructive,
            _ => Self::None,
        }
    }

    /// Encodes this mode as the byte persisted in the lock file.
    fn as_u8(self) -> u8 {
        self as u8
    }
}
35
/// A multi-process reader/writer lock backed by a small binary state file.
///
/// Each handle keeps an in-memory mirror of the file state in atomics; a
/// background thread re-reads the file every 100ms to pick up changes made
/// by other processes. Cloning is cheap (every field is an `Arc`), and all
/// clones share the same per-process bookkeeping counters.
#[derive(Debug, Clone)]
pub struct RwLock {
    // Path to the on-disk lock state file.
    path: Arc<String>,
    // Mirror of the file's writer-mode byte (a `LockMode` widened to u64).
    writer_mode: Arc<AtomicU64>,
    // Mirror of the file's writer-present flag (0 or 1).
    writer_present: Arc<AtomicU64>,
    // Mirror of the PID recorded for the current writer.
    writer_pid: Arc<AtomicU64>,
    // Mirror of the file's three reader counters, indexed by `mode as usize`.
    reader_counts: Arc<Vec<AtomicU64>>,
    // Join handle of the background refresh thread; taken (and joined) on drop.
    refresh: Arc<Mutex<Option<JoinHandle<()>>>>,
    // Shutdown flag for the refresh thread: 1 = keep polling, 0 = stop.
    running: Arc<AtomicU64>,
    // Per-process reader counts (not persisted), used for re-entrant reads.
    process_reader_counts: Arc<Vec<AtomicU64>>,
    // Per-process writer re-entrancy count; > 0 means this process holds the writer.
    process_has_writer: Arc<AtomicU64>,
}
48
/// In-memory image of the lock file contents.
///
/// On-disk layout (little-endian, 48 bytes — see `write_state`):
///   byte 0       writer_mode (`LockMode` as u8)
///   bytes 1-7    padding
///   byte 8       writer_present (0 or 1)
///   bytes 9-15   padding
///   bytes 16-23  writer_pid (u64 LE)
///   bytes 24-47  reader_counts[0..3] (u64 LE each, indexed by `LockMode as usize`)
#[derive(Debug, Clone)]
struct LockState {
    writer_mode: u8,
    writer_present: u8,
    writer_pid: u64,
    reader_counts: [u64; 3],
}
56
57impl RwLock {
58    pub fn new<P: AsRef<Path>>(path: P) -> std::io::Result<Self> {
59        let path_str = path.as_ref().to_string_lossy().to_string();
60        let path_arc = Arc::new(path_str.clone());
61
62        let state = if !path.as_ref().exists() {
63            let initial_state = LockState {
64                writer_mode: LockMode::None.as_u8(),
65                writer_present: 0,
66                writer_pid: 0,
67                reader_counts: [0; 3],
68            };
69            Self::write_state(&path_str, &initial_state)?;
70            initial_state
71        } else {
72            Self::read_state(&path_str)?
73        };
74
75        let reader_counts = Arc::new(
76            (0..3)
77                .map(|i| AtomicU64::new(state.reader_counts[i]))
78                .collect::<Vec<_>>(),
79        );
80
81        let writer_mode = Arc::new(AtomicU64::new(state.writer_mode as u64));
82        let writer_present = Arc::new(AtomicU64::new(state.writer_present as u64));
83        let writer_pid = Arc::new(AtomicU64::new(state.writer_pid));
84
85        let process_reader_counts = Arc::new((0..3).map(|_| AtomicU64::new(0)).collect::<Vec<_>>());
86        let process_has_writer = Arc::new(AtomicU64::new(0));
87
88        let running = Arc::new(AtomicU64::new(1));
89        let running_clone = Arc::clone(&running);
90        let path_clone = Arc::clone(&path_arc);
91        let writer_mode_clone = Arc::clone(&writer_mode);
92        let writer_present_clone = Arc::clone(&writer_present);
93        let writer_pid_clone = Arc::clone(&writer_pid);
94        let reader_counts_clone = Arc::clone(&reader_counts);
95
96        let refresh = thread::spawn(move || {
97            while running_clone.load(Ordering::SeqCst) == 1 {
98                thread::sleep(Duration::from_millis(100));
99
100                match Self::read_state(&path_clone) {
101                    Ok(state) => {
102                        writer_mode_clone.store(state.writer_mode as u64, Ordering::SeqCst);
103                        writer_present_clone.store(state.writer_present as u64, Ordering::SeqCst);
104                        writer_pid_clone.store(state.writer_pid, Ordering::SeqCst);
105
106                        for (i, count) in state.reader_counts.iter().enumerate() {
107                            if i < reader_counts_clone.len() {
108                                reader_counts_clone[i].store(*count, Ordering::SeqCst);
109                            }
110                        }
111                    }
112                    Err(e) => {
113                        eprintln!("Error in refresh thread: {e}");
114                    }
115                }
116            }
117        });
118
119        Ok(Self {
120            path: path_arc,
121            writer_mode,
122            writer_present,
123            writer_pid,
124            reader_counts,
125            refresh: Arc::new(Mutex::new(Some(refresh))),
126            running,
127            process_reader_counts,
128            process_has_writer,
129        })
130    }
131
132    fn read_state(path: &str) -> std::io::Result<LockState> {
133        let mut file = File::open(path)?;
134        let mut reader_counts = [0u64; 3];
135
136        file.seek(SeekFrom::Start(0))?;
137        let mut writer_mode_buf = [0; 1];
138        file.read_exact(&mut writer_mode_buf)?;
139        let writer_mode = writer_mode_buf[0];
140
141        file.seek(SeekFrom::Current(7))?;
142
143        let mut writer_present_buf = [0; 1];
144        file.read_exact(&mut writer_present_buf)?;
145        let writer_present = writer_present_buf[0];
146
147        file.seek(SeekFrom::Current(7))?;
148
149        let mut writer_pid_buf = [0; 8];
150        file.read_exact(&mut writer_pid_buf)?;
151        let writer_pid = u64::from_le_bytes(writer_pid_buf);
152
153        for reader_count in reader_counts.iter_mut() {
154            let mut count_buf = [0; 8];
155            if file.read_exact(&mut count_buf).is_ok() {
156                *reader_count = u64::from_le_bytes(count_buf);
157            } else {
158                break;
159            }
160        }
161
162        Ok(LockState {
163            writer_mode,
164            writer_present,
165            writer_pid,
166            reader_counts,
167        })
168    }
169
170    fn write_state(path: &str, state: &LockState) -> std::io::Result<()> {
171        let atomic_file = AtomicFile::new(path, AllowOverwrite);
172
173        atomic_file.write(|f| {
174            f.seek(SeekFrom::Start(0))?;
175
176            f.write_all(&[state.writer_mode])?;
177            f.write_all(&[0; 7])?; // Padding
178
179            f.write_all(&[state.writer_present])?;
180            f.write_all(&[0; 7])?; // Padding
181
182            f.write_all(&state.writer_pid.to_le_bytes())?;
183
184            for count in &state.reader_counts {
185                f.write_all(&count.to_le_bytes())?;
186            }
187
188            Ok(())
189        })?;
190
191        Ok(())
192    }
193
194    fn update_state<F>(&self, update_fn: F) -> std::io::Result<()>
195    where
196        F: FnOnce(LockState) -> LockState,
197    {
198        let atomic_file = AtomicFile::new(&*self.path, AllowOverwrite);
199
200        let current_state = Self::read_state(&self.path)?;
201        let new_state = update_fn(current_state);
202
203        self.writer_mode
204            .store(new_state.writer_mode as u64, Ordering::SeqCst);
205        self.writer_present
206            .store(new_state.writer_present as u64, Ordering::SeqCst);
207        self.writer_pid
208            .store(new_state.writer_pid, Ordering::SeqCst);
209
210        for (i, count) in new_state.reader_counts.iter().enumerate() {
211            if i < self.reader_counts.len() {
212                self.reader_counts[i].store(*count, Ordering::SeqCst);
213            }
214        }
215
216        atomic_file.write(|f| {
217            f.seek(SeekFrom::Start(0))?;
218
219            f.write_all(&[new_state.writer_mode])?;
220            f.write_all(&[0; 7])?;
221
222            f.write_all(&[new_state.writer_present])?;
223            f.write_all(&[0; 7])?;
224
225            f.write_all(&new_state.writer_pid.to_le_bytes())?;
226
227            for count in &new_state.reader_counts {
228                f.write_all(&count.to_le_bytes())?;
229            }
230
231            Ok(())
232        })?;
233
234        Ok(())
235    }
236
237    fn current_pid() -> u64 {
238        std::process::id() as u64
239    }
240
241    fn process_owns_writer(&self) -> bool {
242        self.process_has_writer.load(Ordering::SeqCst) > 0
243    }
244
245    pub fn read_lock(&self, mode: LockMode) -> std::io::Result<ReadGuard> {
246        if mode == LockMode::None {
247            return Err(std::io::Error::new(
248                std::io::ErrorKind::InvalidInput,
249                "Cannot acquire read lock with None mode",
250            ));
251        }
252
253        if self.process_owns_writer() {
254            self.process_reader_counts[mode as usize].fetch_add(1, Ordering::SeqCst);
255
256            return Ok(ReadGuard {
257                lock: self.clone(),
258                mode,
259                active: true,
260            });
261        }
262
263        let mut backoff = Duration::from_millis(1);
264        let max_backoff = Duration::from_secs(1);
265
266        loop {
267            let current_writer_mode =
268                LockMode::from_u8(self.writer_mode.load(Ordering::SeqCst) as u8);
269            let writer_present = self.writer_present.load(Ordering::SeqCst) != 0;
270            let writer_pid = self.writer_pid.load(Ordering::SeqCst);
271            let current_pid = Self::current_pid();
272
273            if !writer_present || current_writer_mode == mode || writer_pid == current_pid {
274                match self.update_state(|mut state| {
275                    if state.writer_present != 0
276                        && LockMode::from_u8(state.writer_mode) != mode
277                        && state.writer_pid != current_pid
278                    {
279                        return state;
280                    }
281
282                    state.reader_counts[mode as usize] += 1;
283                    state
284                }) {
285                    Ok(()) => {
286                        self.process_reader_counts[mode as usize].fetch_add(1, Ordering::SeqCst);
287
288                        return Ok(ReadGuard {
289                            lock: self.clone(),
290                            mode,
291                            active: true,
292                        });
293                    }
294                    Err(e) => {
295                        if e.kind() == std::io::ErrorKind::WouldBlock {
296                            thread::sleep(backoff);
297                            backoff = std::cmp::min(backoff * 2, max_backoff);
298                            continue;
299                        }
300                        return Err(e);
301                    }
302                }
303            }
304
305            thread::sleep(backoff);
306            backoff = std::cmp::min(backoff * 2, max_backoff);
307        }
308    }
309
310    pub fn write_lock(&self, mode: LockMode) -> std::io::Result<WriteGuard> {
311        if mode == LockMode::None {
312            return Err(std::io::Error::new(
313                std::io::ErrorKind::InvalidInput,
314                "Cannot acquire write lock with None mode",
315            ));
316        }
317
318        if self.process_owns_writer() {
319            self.process_has_writer.fetch_add(1, Ordering::SeqCst);
320
321            return Ok(WriteGuard {
322                lock: self.clone(),
323                mode,
324                active: true,
325            });
326        }
327
328        let mut backoff = Duration::from_millis(1);
329        let max_backoff = Duration::from_secs(1);
330        let current_pid = Self::current_pid();
331
332        loop {
333            let writer_present = self.writer_present.load(Ordering::SeqCst) != 0;
334            let writer_pid = self.writer_pid.load(Ordering::SeqCst);
335
336            let incompatible_readers = (0..3).any(|i| {
337                if i == mode as usize {
338                    false
339                } else {
340                    self.reader_counts[i].load(Ordering::SeqCst) > 0
341                }
342            });
343
344            if (writer_present && writer_pid != current_pid) || incompatible_readers {
345                thread::sleep(backoff);
346                backoff = std::cmp::min(backoff * 2, max_backoff);
347                continue;
348            }
349
350            match self.update_state(|mut state| {
351                let incompatible_readers = (0..3).any(|i| {
352                    if i == mode as usize {
353                        false
354                    } else {
355                        state.reader_counts[i] > 0
356                    }
357                });
358
359                if (state.writer_present != 0 && state.writer_pid != current_pid)
360                    || incompatible_readers
361                {
362                    return state;
363                }
364
365                state.writer_mode = mode.as_u8();
366                state.writer_present = 1;
367                state.writer_pid = current_pid;
368                state
369            }) {
370                Ok(()) => {
371                    self.process_has_writer.store(1, Ordering::SeqCst);
372
373                    return Ok(WriteGuard {
374                        lock: self.clone(),
375                        mode,
376                        active: true,
377                    });
378                }
379                Err(e) => {
380                    if e.kind() == std::io::ErrorKind::WouldBlock {
381                        thread::sleep(backoff);
382                        backoff = std::cmp::min(backoff * 2, max_backoff);
383                        continue;
384                    }
385
386                    return Err(e);
387                }
388            }
389        }
390    }
391
392    pub fn try_read_lock(&self, mode: LockMode) -> std::io::Result<Option<ReadGuard>> {
393        if mode == LockMode::None {
394            return Err(std::io::Error::new(
395                std::io::ErrorKind::InvalidInput,
396                "Cannot acquire read lock with None mode",
397            ));
398        }
399
400        if self.process_owns_writer() {
401            self.process_reader_counts[mode as usize].fetch_add(1, Ordering::SeqCst);
402
403            return Ok(Some(ReadGuard {
404                lock: self.clone(),
405                mode,
406                active: true,
407            }));
408        }
409
410        let current_writer_mode = LockMode::from_u8(self.writer_mode.load(Ordering::SeqCst) as u8);
411        let writer_present = self.writer_present.load(Ordering::SeqCst) != 0;
412        let writer_pid = self.writer_pid.load(Ordering::SeqCst);
413        let current_pid = Self::current_pid();
414
415        if !writer_present || current_writer_mode == mode || writer_pid == current_pid {
416            match self.update_state(|mut state| {
417                if state.writer_present != 0
418                    && LockMode::from_u8(state.writer_mode) != mode
419                    && state.writer_pid != current_pid
420                {
421                    return state;
422                }
423
424                state.reader_counts[mode as usize] += 1;
425                state
426            }) {
427                Ok(()) => {
428                    self.process_reader_counts[mode as usize].fetch_add(1, Ordering::SeqCst);
429
430                    return Ok(Some(ReadGuard {
431                        lock: self.clone(),
432                        mode,
433                        active: true,
434                    }));
435                }
436                Err(e) => return Err(e),
437            }
438        }
439
440        Ok(None)
441    }
442
443    pub fn try_write_lock(&self, mode: LockMode) -> std::io::Result<Option<WriteGuard>> {
444        if mode == LockMode::None {
445            return Err(std::io::Error::new(
446                std::io::ErrorKind::InvalidInput,
447                "Cannot acquire write lock with None mode",
448            ));
449        }
450
451        if self.process_owns_writer() {
452            self.process_has_writer.fetch_add(1, Ordering::SeqCst);
453
454            return Ok(Some(WriteGuard {
455                lock: self.clone(),
456                mode,
457                active: true,
458            }));
459        }
460
461        let writer_present = self.writer_present.load(Ordering::SeqCst) != 0;
462        let writer_pid = self.writer_pid.load(Ordering::SeqCst);
463        let current_pid = Self::current_pid();
464
465        let incompatible_readers = (0..3).any(|i| {
466            if i == mode as usize {
467                false
468            } else {
469                self.reader_counts[i].load(Ordering::SeqCst) > 0
470            }
471        });
472
473        if (writer_present && writer_pid != current_pid) || incompatible_readers {
474            return Ok(None);
475        }
476
477        match self.update_state(|mut state| {
478            let incompatible_readers = (0..3).any(|i| {
479                if i == mode as usize {
480                    false
481                } else {
482                    state.reader_counts[i] > 0
483                }
484            });
485
486            if (state.writer_present != 0 && state.writer_pid != current_pid)
487                || incompatible_readers
488            {
489                return state;
490            }
491
492            state.writer_mode = mode.as_u8();
493            state.writer_present = 1;
494            state.writer_pid = current_pid;
495            state
496        }) {
497            Ok(()) => {
498                self.process_has_writer.store(1, Ordering::SeqCst);
499
500                Ok(Some(WriteGuard {
501                    lock: self.clone(),
502                    mode,
503                    active: true,
504                }))
505            }
506            Err(e) => Err(e),
507        }
508    }
509
510    pub fn reader_count(&self, mode: LockMode) -> u64 {
511        self.reader_counts[mode as usize].load(Ordering::SeqCst)
512    }
513
514    pub fn total_reader_count(&self) -> u64 {
515        (0..3)
516            .map(|i| self.reader_counts[i].load(Ordering::SeqCst))
517            .sum()
518    }
519
520    pub fn has_writer(&self) -> bool {
521        self.writer_present.load(Ordering::SeqCst) != 0
522    }
523
524    pub fn writer_mode(&self) -> Option<LockMode> {
525        if self.has_writer() {
526            Some(LockMode::from_u8(
527                self.writer_mode.load(Ordering::SeqCst) as u8
528            ))
529        } else {
530            None
531        }
532    }
533
534    pub fn writer_pid(&self) -> Option<u64> {
535        if self.has_writer() {
536            Some(self.writer_pid.load(Ordering::SeqCst))
537        } else {
538            None
539        }
540    }
541}
542
/// RAII guard for a read lock; releases the lock on drop (or via `unlock`).
pub struct ReadGuard {
    // Handle to the lock this guard was acquired from.
    lock: RwLock,
    // Mode the read lock was acquired with.
    mode: LockMode,
    // False once the guard has been released (prevents double-unlock).
    active: bool,
}
548
impl ReadGuard {
    /// The mode this read lock was acquired with.
    pub fn mode(&self) -> LockMode {
        self.mode
    }

    /// Releases this read lock. Idempotent: a second call is a no-op.
    ///
    /// The per-process reader count for this mode is always decremented;
    /// the shared on-disk count is only decremented when this guard was the
    /// last process-local reader of the mode and the process does not hold
    /// the writer (writer-held re-entrant reads never touched the file).
    ///
    /// NOTE(review): a guard acquired while this process held the writer
    /// never incremented the on-disk count, yet if the writer is released
    /// before that guard drops, this path decrements the on-disk count
    /// anyway (clamped at zero) — confirm this bookkeeping is intended.
    pub fn unlock(&mut self) -> std::io::Result<()> {
        if self.active {
            // fetch_sub returns the previous value; 1 means we were the
            // last process-local reader of this mode.
            let prev_count =
                self.lock.process_reader_counts[self.mode as usize].fetch_sub(1, Ordering::SeqCst);

            if prev_count == 1 && !self.lock.process_owns_writer() {
                self.lock.update_state(|mut state| {
                    // Clamp at zero in case the file count is already stale.
                    if state.reader_counts[self.mode as usize] > 0 {
                        state.reader_counts[self.mode as usize] -= 1;
                    }
                    state
                })?;
            }

            self.active = false;
        }
        Ok(())
    }
}
573
574impl Drop for ReadGuard {
575    fn drop(&mut self) {
576        if self.active {
577            if let Err(e) = self.unlock() {
578                eprintln!("Error releasing read lock in drop: {e}");
579            }
580        }
581    }
582}
583
/// RAII guard for the write lock; releases the lock on drop (or via `unlock`).
pub struct WriteGuard {
    // Handle to the lock this guard was acquired from.
    lock: RwLock,
    // Mode the write lock was acquired with.
    mode: LockMode,
    // False once the guard has been released (prevents double-unlock).
    active: bool,
}
589
impl WriteGuard {
    /// The mode this write lock was acquired with.
    pub fn mode(&self) -> LockMode {
        self.mode
    }

    /// Releases this write lock. Idempotent: a second call is a no-op.
    ///
    /// Decrements the per-process writer re-entrancy count; only the
    /// outermost guard (previous count 1) clears the writer fields in the
    /// shared file, and only if the file still records this process's PID
    /// as the writer (another process may have taken over a stale entry).
    pub fn unlock(&mut self) -> std::io::Result<()> {
        if self.active {
            // fetch_sub returns the previous value; 1 means this was the
            // outermost write guard held by this process.
            let prev_count = self.lock.process_has_writer.fetch_sub(1, Ordering::SeqCst);

            if prev_count == 1 {
                self.lock.update_state(|mut state| {
                    let current_pid = RwLock::current_pid();
                    // Only clear the writer slot if it is still ours.
                    if state.writer_present != 0 && state.writer_pid == current_pid {
                        state.writer_present = 0;
                        state.writer_mode = LockMode::None.as_u8();
                        state.writer_pid = 0;
                    }
                    state
                })?;
            }

            self.active = false;
        }
        Ok(())
    }
}
616
617impl Drop for WriteGuard {
618    fn drop(&mut self) {
619        if self.active {
620            if let Err(e) = self.unlock() {
621                eprintln!("Error releasing write lock in drop: {e}");
622            }
623        }
624    }
625}
626
627impl Drop for RwLock {
628    fn drop(&mut self) {
629        self.running.store(0, Ordering::SeqCst);
630
631        if let Ok(mut refresh_guard) = self.refresh.lock() {
632            if let Some(handle) = refresh_guard.take() {
633                let _ = handle.join();
634            }
635        }
636    }
637}