jujutsu_lib/
lock.rs

// Copyright 2020 The Jujutsu Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fs::{File, OpenOptions};
16use std::path::PathBuf;
17use std::time::Duration;
18
19use backoff::{retry, ExponentialBackoff};
20
/// An exclusive lock represented by the existence of a file on disk.
///
/// The lock is held for as long as this value is alive; dropping it deletes
/// the lock file and thereby releases the lock.
#[derive(Debug)]
#[must_use = "the lock is released as soon as the `FileLock` is dropped"]
pub struct FileLock {
    /// Path of the lock file; it is deleted when the guard is dropped.
    path: PathBuf,
    /// Handle opened when the lock was acquired. It is not otherwise read;
    /// it is only kept so the file stays open while the lock is held.
    _file: File,
}
25
26impl FileLock {
27    pub fn lock(path: PathBuf) -> FileLock {
28        let mut options = OpenOptions::new();
29        options.create_new(true);
30        options.write(true);
31        let try_write_lock_file = || match options.open(&path) {
32            Ok(file) => Ok(FileLock {
33                path: path.clone(),
34                _file: file,
35            }),
36            Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
37                Err(backoff::Error::Transient {
38                    err,
39                    retry_after: None,
40                })
41            }
42            Err(err) if cfg!(windows) && err.kind() == std::io::ErrorKind::PermissionDenied => {
43                Err(backoff::Error::Transient {
44                    err,
45                    retry_after: None,
46                })
47            }
48            Err(err) => Err(backoff::Error::Permanent(err)),
49        };
50        let backoff = ExponentialBackoff {
51            initial_interval: Duration::from_millis(1),
52            max_elapsed_time: Some(Duration::from_secs(10)),
53            ..Default::default()
54        };
55        match retry(backoff, try_write_lock_file) {
56            Err(err) => panic!(
57                "failed to create lock file {}: {}",
58                path.to_string_lossy(),
59                err
60            ),
61            Ok(file_lock) => file_lock,
62        }
63    }
64}
65
66impl Drop for FileLock {
67    fn drop(&mut self) {
68        std::fs::remove_file(&self.path).expect("failed to delete lock file");
69    }
70}
71
#[cfg(test)]
mod tests {
    use std::cmp::max;
    use std::thread;

    use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};

    use super::*;

    /// The lock file exists exactly while a `FileLock` guard is alive.
    #[test]
    fn lock_basic() {
        let temp_dir = testutils::new_temp_dir();
        let lock_path = temp_dir.path().join("test.lock");
        assert!(!lock_path.exists());

        let guard = FileLock::lock(lock_path.clone());
        assert!(lock_path.exists());
        drop(guard);

        assert!(!lock_path.exists());
    }

    /// Many threads do a deliberately racy read-sleep-write increment of a
    /// shared counter file, each while holding the same lock. The final count
    /// equals the number of threads only if the lock serialized them.
    #[test]
    fn lock_concurrent() {
        let temp_dir = testutils::new_temp_dir();
        let counter_path = temp_dir.path().join("test");
        let lock_path = temp_dir.path().join("test.lock");

        // Seed the counter with 0; this handle stays open until the end of the test.
        let mut seed_file = OpenOptions::new()
            .create(true)
            .write(true)
            .open(&counter_path)
            .unwrap();
        seed_file.write_u32::<LittleEndian>(0).unwrap();

        let num_threads = max(num_cpus::get(), 4);
        // Spawn every worker before joining any of them, so they actually contend.
        let workers: Vec<_> = (0..num_threads)
            .map(|_| {
                let counter_path = counter_path.clone();
                let lock_path = lock_path.clone();
                thread::spawn(move || {
                    let _lock = FileLock::lock(lock_path);
                    let mut reader = OpenOptions::new().read(true).open(&counter_path).unwrap();
                    let current = reader.read_u32::<LittleEndian>().unwrap();
                    // Widen the window between read and write so a broken lock would lose updates.
                    thread::sleep(Duration::from_millis(1));
                    let mut writer = OpenOptions::new().write(true).open(&counter_path).unwrap();
                    writer.write_u32::<LittleEndian>(current + 1).unwrap();
                })
            })
            .collect();

        for worker in workers {
            worker.join().ok().unwrap();
        }

        let mut reader = OpenOptions::new().read(true).open(&counter_path).unwrap();
        let total = reader.read_u32::<LittleEndian>().unwrap();
        assert_eq!(total, num_threads as u32);
    }
}