jj_lib/
file_util.rs

1// Copyright 2021 The Jujutsu Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![allow(missing_docs)]
16
17use std::fs;
18use std::fs::File;
19use std::io;
20use std::io::Read;
21use std::io::Write;
22use std::path::Component;
23use std::path::Path;
24use std::path::PathBuf;
25use std::pin::Pin;
26use std::task::Poll;
27
28use tempfile::NamedTempFile;
29use tempfile::PersistError;
30use thiserror::Error;
31use tokio::io::AsyncRead;
32use tokio::io::AsyncReadExt as _;
33use tokio::io::ReadBuf;
34
35pub use self::platform::*;
36
/// An I/O error paired with the filesystem path that was being accessed.
///
/// The display message is `Cannot access {path}`; the wrapped [`io::Error`]
/// is exposed as the error source.
#[derive(Debug, Error)]
#[error("Cannot access {path}")]
pub struct PathError {
    /// Path that the failed operation was associated with.
    pub path: PathBuf,
    /// Underlying I/O error.
    #[source]
    pub error: io::Error,
}
44
45pub trait IoResultExt<T> {
46    fn context(self, path: impl AsRef<Path>) -> Result<T, PathError>;
47}
48
49impl<T> IoResultExt<T> for io::Result<T> {
50    fn context(self, path: impl AsRef<Path>) -> Result<T, PathError> {
51        self.map_err(|error| PathError {
52            path: path.as_ref().to_path_buf(),
53            error,
54        })
55    }
56}
57
/// Ensures that `dirname` exists as a directory, creating it if necessary.
///
/// Only the final path component is created; missing intermediate
/// directories are not.
///
/// # Errors
///
/// Returns the error from the creation attempt if the directory could not be
/// created and `dirname` is not already a directory afterwards.
pub fn create_or_reuse_dir(dirname: &Path) -> io::Result<()> {
    fs::create_dir(dirname).or_else(|create_err| {
        // Any creation failure is acceptable as long as the directory is
        // there now (e.g. it already existed).
        if dirname.is_dir() {
            Ok(())
        } else {
            Err(create_err)
        }
    })
}
70
71/// Removes all files in the directory, but not the directory itself.
72///
73/// The directory must exist, and there should be no sub directories.
74pub fn remove_dir_contents(dirname: &Path) -> Result<(), PathError> {
75    for entry in dirname.read_dir().context(dirname)? {
76        let entry = entry.context(dirname)?;
77        let path = entry.path();
78        fs::remove_file(&path).context(&path)?;
79    }
80    Ok(())
81}
82
/// Expands a leading "~/" to "$HOME/".
///
/// The input is returned unchanged if it does not start with "~/" (a bare
/// "~" or "~user/..." is not expanded) or if the `HOME` environment variable
/// is not set.
///
/// `HOME` is read as an `OsString`, so a home directory that is not valid
/// Unicode is still expanded.
pub fn expand_home_path(path_str: &str) -> PathBuf {
    let expanded = path_str.strip_prefix("~/").and_then(|remainder| {
        // `var_os` rather than `var`: the latter fails on non-Unicode values,
        // which are perfectly valid paths.
        let home_dir = std::env::var_os("HOME")?;
        Some(PathBuf::from(home_dir).join(remainder))
    });
    expanded.unwrap_or_else(|| PathBuf::from(path_str))
}
92
/// Expresses `to` as a path relative to `from`.
///
/// Both paths are expected to be absolute and normalized the same way. The
/// result climbs with `..` from `from` up to the nearest ancestor that is a
/// prefix of `to`, then descends into the rest of `to`. If the two paths are
/// equal, `.` is returned. If no ancestor of `from` is a prefix of `to`, `to`
/// is returned as is.
pub fn relative_path(from: &Path, to: &Path) -> PathBuf {
    // Nearest ancestor of `from` (counting `from` itself as depth 0) that is
    // a prefix of `to`, together with what remains of `to` below it.
    let nearest_common = from
        .ancestors()
        .enumerate()
        .find_map(|(depth, ancestor)| Some((depth, to.strip_prefix(ancestor).ok()?)));

    match nearest_common {
        Some((0, tail)) if tail.as_os_str().is_empty() => PathBuf::from("."),
        Some((depth, tail)) => {
            let mut climbed: PathBuf = std::iter::repeat_n("..", depth).collect();
            climbed.push(tail);
            climbed
        }
        // Nothing in common: fall back to the original (absolute) path.
        None => to.to_owned(),
    }
}
114
/// Resolves `.` and `..` components lexically, without consulting the
/// filesystem (so symlinks are not taken into account).
///
/// A `..` that has no preceding named component to cancel is kept. An empty
/// result is returned as `.`.
pub fn normalize_path(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        if matches!(component, Component::CurDir) {
            continue;
        }
        // A `..` cancels the previous component only when that component is a
        // plain name. An earlier `..` (or a root/prefix) must not be popped.
        let cancels_previous = matches!(component, Component::ParentDir)
            && matches!(
                normalized.components().next_back(),
                Some(Component::Normal(_))
            );
        if cancels_previous {
            let removed = normalized.pop();
            assert!(removed);
        } else {
            normalized.push(component);
        }
    }

    if normalized.as_os_str().is_empty() {
        return PathBuf::from(".");
    }
    normalized
}
140
141/// Like `NamedTempFile::persist()`, but doesn't try to overwrite the existing
142/// target on Windows.
143pub fn persist_content_addressed_temp_file<P: AsRef<Path>>(
144    temp_file: NamedTempFile,
145    new_path: P,
146) -> io::Result<File> {
147    if cfg!(windows) {
148        // On Windows, overwriting file can fail if the file is opened without
149        // FILE_SHARE_DELETE for example. We don't need to take a risk if the
150        // file already exists.
151        match temp_file.persist_noclobber(&new_path) {
152            Ok(file) => Ok(file),
153            Err(PersistError { error, file: _ }) => {
154                if let Ok(existing_file) = File::open(new_path) {
155                    // TODO: Update mtime to help GC keep this file
156                    Ok(existing_file)
157                } else {
158                    Err(error)
159                }
160            }
161        }
162    } else {
163        // On Unix, rename() is atomic and should succeed even if the
164        // destination file exists. Checking if the target exists might involve
165        // non-atomic operation, so don't use persist_noclobber().
166        temp_file
167            .persist(new_path)
168            .map_err(|PersistError { error, file: _ }| error)
169    }
170}
171
172/// Reads from an async source and writes to a sync destination. Does not spawn
173/// a task, so writes will block.
174pub async fn copy_async_to_sync<R: AsyncRead, W: Write + ?Sized>(
175    reader: R,
176    writer: &mut W,
177) -> io::Result<usize> {
178    let mut buf = vec![0; 16 << 10];
179    let mut total_written_bytes = 0;
180
181    let mut reader = std::pin::pin!(reader);
182    loop {
183        let written_bytes = reader.read(&mut buf).await?;
184        if written_bytes == 0 {
185            return Ok(total_written_bytes);
186        }
187        writer.write_all(&buf[0..written_bytes])?;
188        total_written_bytes += written_bytes;
189    }
190}
191
192/// `AsyncRead`` implementation backed by a `Read`. It is not actually async;
193/// the goal is simply to avoid reading the full contents from the `Read` into
194/// memory.
195pub struct BlockingAsyncReader<R> {
196    reader: R,
197}
198
199impl<R: Read + Unpin> BlockingAsyncReader<R> {
200    /// Creates a new `BlockingAsyncReader`
201    pub fn new(reader: R) -> Self {
202        Self { reader }
203    }
204}
205
206impl<R: Read + Unpin> AsyncRead for BlockingAsyncReader<R> {
207    fn poll_read(
208        mut self: Pin<&mut Self>,
209        _cx: &mut std::task::Context<'_>,
210        buf: &mut ReadBuf<'_>,
211    ) -> Poll<io::Result<()>> {
212        let num_bytes_read = self.reader.read(buf.initialize_unfilled())?;
213        buf.advance(num_bytes_read);
214        Poll::Ready(Ok(()))
215    }
216}
217
#[cfg(unix)]
mod platform {
    use std::io;
    use std::os::unix::fs as unix_fs;
    use std::path::Path;

    /// Symlinks are always available on UNIX, so this always returns
    /// `Ok(true)`.
    pub fn check_symlink_support() -> io::Result<bool> {
        Ok(true)
    }

    /// Creates a symlink at `link` pointing to `original`.
    pub fn try_symlink<P, Q>(original: P, link: Q) -> io::Result<()>
    where
        P: AsRef<Path>,
        Q: AsRef<Path>,
    {
        unix_fs::symlink(original, link)
    }
}
233
#[cfg(windows)]
mod platform {
    use std::io;
    use std::os::windows::fs::symlink_file;
    use std::path::Path;

    use winreg::enums::HKEY_LOCAL_MACHINE;
    use winreg::RegKey;

    /// Symlinks may or may not be enabled on Windows. They require the
    /// Developer Mode setting, which is stored in the registry key below.
    ///
    /// Returns `Ok(true)` only if that value is `1`; failing to open the key
    /// or read the value is reported as an error.
    pub fn check_symlink_support() -> io::Result<bool> {
        let developer_mode: u32 = RegKey::predef(HKEY_LOCAL_MACHINE)
            .open_subkey("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\AppModelUnlock")?
            .get_value("AllowDevelopmentWithoutDevLicense")?;
        Ok(developer_mode == 1)
    }

    /// Creates a file symlink at `link` pointing to `original`.
    pub fn try_symlink<P, Q>(original: P, link: Q) -> io::Result<()>
    where
        P: AsRef<Path>,
        Q: AsRef<Path>,
    {
        // This always creates a *file* symlink, which is nonfunctional when
        // the target is a directory. At the moment we don't have enough
        // information in the tree to determine whether the symlink target is
        // a file or a directory.
        //
        // Note: if developer mode is not enabled the error code will be 1314,
        // ERROR_PRIVILEGE_NOT_HELD.
        symlink_file(original, link)
    }
}
263
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::io::Write as _;

    use itertools::Itertools as _;
    use pollster::FutureExt as _;
    use test_case::test_case;

    use super::*;
    use crate::tests::new_temp_dir;

    /// `..` components that climb above the start of a relative path are kept.
    #[test]
    fn normalize_too_many_dot_dot() {
        let cases = [
            ("foo/..", "."),
            ("foo/../..", ".."),
            ("foo/../../..", "../.."),
            ("foo/../../../bar/baz/..", "../../bar"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_path(Path::new(input)), Path::new(expected));
        }
    }

    /// Persisting succeeds when nothing exists at the destination yet.
    #[test]
    fn test_persist_no_existing_file() {
        let dir = new_temp_dir();
        let destination = dir.path().join("file");
        let mut pending = NamedTempFile::new_in(&dir).unwrap();
        pending.write_all(b"contents").unwrap();
        assert!(persist_content_addressed_temp_file(pending, destination).is_ok());
    }

    /// Persisting succeeds when the destination already exists, whether or
    /// not a handle to it is still open.
    #[test_case(false ; "existing file open")]
    #[test_case(true ; "existing file closed")]
    fn test_persist_target_exists(existing_file_closed: bool) {
        let dir = new_temp_dir();
        let destination = dir.path().join("file");
        let mut pending = NamedTempFile::new_in(&dir).unwrap();
        pending.write_all(b"contents").unwrap();

        let mut existing = File::create(&destination).unwrap();
        existing.write_all(b"contents").unwrap();
        // Keep the handle alive until the end of the test unless the case
        // asks for the file to be closed first.
        let _still_open = (!existing_file_closed).then_some(existing);

        assert!(persist_content_addressed_temp_file(pending, &destination).is_ok());
    }

    /// Input smaller than the internal copy buffer.
    #[test]
    fn test_copy_async_to_sync_small() {
        let input = b"hello";
        let mut output = vec![];

        let copied = copy_async_to_sync(Cursor::new(&input), &mut output)
            .block_on()
            .unwrap();
        assert_eq!(copied, 5);
        assert_eq!(output, input);
    }

    /// Input spanning more than one buffer worth of data.
    #[test]
    fn test_copy_async_to_sync_large() {
        let input = (0..100u8).cycle().take(40000).collect_vec();
        let mut output = vec![];

        let copied = copy_async_to_sync(Cursor::new(&input), &mut output)
            .block_on()
            .unwrap();
        assert_eq!(copied, 40000);
        assert_eq!(output, input);
    }

    /// Successive reads continue where the previous one stopped.
    #[test]
    fn test_blocking_async_reader() {
        let input = b"hello";
        let mut async_reader = BlockingAsyncReader::new(Cursor::new(&input));

        let mut chunk = [0u8; 3];
        let first_len = async_reader.read(&mut chunk).block_on().unwrap();
        assert_eq!(first_len, 3);
        assert_eq!(&chunk, &input[0..3]);

        let second_len = async_reader.read(&mut chunk).block_on().unwrap();
        assert_eq!(second_len, 2);
        assert_eq!(&chunk[0..2], &input[3..5]);
    }

    /// `read_to_end` drains the whole underlying reader.
    #[test]
    fn test_blocking_async_reader_read_to_end() {
        let input = b"hello";
        let mut async_reader = BlockingAsyncReader::new(Cursor::new(&input));

        let mut collected = vec![];
        let total_len = async_reader
            .read_to_end(&mut collected)
            .block_on()
            .unwrap();
        assert_eq!(total_len, input.len());
        assert_eq!(&collected, &input);
    }
}