jj_lib/
file_util.rs

1// Copyright 2021 The Jujutsu Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![allow(missing_docs)]
16
17use std::borrow::Cow;
18use std::ffi::OsString;
19use std::fs;
20use std::fs::File;
21use std::io;
22use std::io::Read;
23use std::io::Write;
24use std::path::Component;
25use std::path::Path;
26use std::path::PathBuf;
27use std::pin::Pin;
28use std::task::Poll;
29
30use tempfile::NamedTempFile;
31use tempfile::PersistError;
32use thiserror::Error;
33use tokio::io::AsyncRead;
34use tokio::io::AsyncReadExt as _;
35use tokio::io::ReadBuf;
36
37pub use self::platform::check_symlink_support;
38pub use self::platform::try_symlink;
39
40#[derive(Debug, Error)]
41#[error("Cannot access {path}")]
42pub struct PathError {
43    pub path: PathBuf,
44    #[source]
45    pub error: io::Error,
46}
47
48pub trait IoResultExt<T> {
49    fn context(self, path: impl AsRef<Path>) -> Result<T, PathError>;
50}
51
52impl<T> IoResultExt<T> for io::Result<T> {
53    fn context(self, path: impl AsRef<Path>) -> Result<T, PathError> {
54        self.map_err(|error| PathError {
55            path: path.as_ref().to_path_buf(),
56            error,
57        })
58    }
59}
60
/// Makes sure `dirname` exists as a directory, creating it if necessary.
///
/// Only the last path component is created; missing intermediate
/// directories are an error. If creation fails and `dirname` is not already
/// a directory, the error from the creation attempt is returned.
pub fn create_or_reuse_dir(dirname: &Path) -> io::Result<()> {
    fs::create_dir(dirname).or_else(|create_error| {
        // A failed creation is fine as long as a directory is there anyway.
        if dirname.is_dir() {
            Ok(())
        } else {
            Err(create_error)
        }
    })
}
73
74/// Removes all files in the directory, but not the directory itself.
75///
76/// The directory must exist, and there should be no sub directories.
77pub fn remove_dir_contents(dirname: &Path) -> Result<(), PathError> {
78    for entry in dirname.read_dir().context(dirname)? {
79        let entry = entry.context(dirname)?;
80        let path = entry.path();
81        fs::remove_file(&path).context(&path)?;
82    }
83    Ok(())
84}
85
86#[derive(Debug, Error)]
87#[error(transparent)]
88pub struct BadPathEncoding(platform::BadOsStrEncoding);
89
90/// Constructs [`Path`] from `bytes` in platform-specific manner.
91///
92/// On Unix, this function never fails because paths are just bytes. On Windows,
93/// this may return error if the input wasn't well-formed UTF-8.
94pub fn path_from_bytes(bytes: &[u8]) -> Result<&Path, BadPathEncoding> {
95    let s = platform::os_str_from_bytes(bytes).map_err(BadPathEncoding)?;
96    Ok(Path::new(s))
97}
98
99/// Converts `path` to bytes in platform-specific manner.
100///
101/// On Unix, this function never fails because paths are just bytes. On Windows,
102/// this may return error if the input wasn't well-formed UTF-8.
103///
104/// The returned byte sequence can be considered a superset of ASCII (such as
105/// UTF-8 bytes.)
106pub fn path_to_bytes(path: &Path) -> Result<&[u8], BadPathEncoding> {
107    platform::os_str_to_bytes(path.as_ref()).map_err(BadPathEncoding)
108}
109
/// Expands "~/" to "$HOME/".
///
/// Only a leading `~/` is expanded; any other input (including a bare `~` or
/// `~user/...`) is returned unchanged. If the `HOME` environment variable is
/// not set, the input is also returned unchanged.
///
/// `HOME` is read as an OS string, so a home directory whose name is not
/// valid Unicode is still expanded.
pub fn expand_home_path(path_str: &str) -> PathBuf {
    if let Some(remainder) = path_str.strip_prefix("~/") {
        // `var_os()` rather than `var()`: the latter reports a non-Unicode
        // value as an error, which would silently disable the expansion.
        if let Some(home_dir) = std::env::var_os("HOME") {
            return PathBuf::from(home_dir).join(remainder);
        }
    }
    PathBuf::from(path_str)
}
119
/// Expresses `to` as a path relative to `from`.
///
/// Both `from` and `to` paths are supposed to be absolute and normalized in the
/// same manner. If the two paths are identical, `"."` is returned. If no
/// ancestor of `from` is a prefix of `to`, `to` is returned as is.
pub fn relative_path(from: &Path, to: &Path) -> PathBuf {
    // Walk up from `from` until we reach a directory that contains `to`. The
    // index of that ancestor is the number of ".." steps needed to reach it.
    let shared = from
        .ancestors()
        .enumerate()
        .find_map(|(depth, ancestor)| to.strip_prefix(ancestor).ok().map(|rest| (depth, rest)));

    match shared {
        Some((0, rest)) if rest.as_os_str().is_empty() => PathBuf::from("."),
        Some((depth, rest)) => {
            let mut relative: PathBuf = std::iter::repeat_n("..", depth).collect();
            relative.push(rest);
            relative
        }
        // No common prefix found. Return the original (absolute) path.
        None => to.to_owned(),
    }
}
141
/// Resolves `.` and `..` components lexically, without touching the file
/// system (so symlinks are not taken into account).
///
/// A `..` is cancelled against the preceding component only if that component
/// is a normal name; leading `..`s, and a `..` directly after a root or
/// prefix, are kept. An empty result is returned as `"."`.
pub fn normalize_path(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        if component == Component::CurDir {
            continue;
        }
        // Only a normal name can be cancelled out. In particular, an
        // accumulated ".." must never be popped by a following "..".
        let cancels_previous = component == Component::ParentDir
            && matches!(
                normalized.components().next_back(),
                Some(Component::Normal(_))
            );
        if cancels_previous {
            let removed = normalized.pop();
            assert!(removed);
        } else {
            normalized.push(component);
        }
    }

    if normalized.as_os_str().is_empty() {
        PathBuf::from(".")
    } else {
        normalized
    }
}
167
168/// Converts the given `path` to Unix-like path separated by "/".
169///
170/// The returned path might not work on Windows if it was canonicalized. On
171/// Unix, this function is noop.
172pub fn slash_path(path: &Path) -> Cow<'_, Path> {
173    if cfg!(windows) {
174        Cow::Owned(to_slash_separated(path).into())
175    } else {
176        Cow::Borrowed(path)
177    }
178}
179
/// Joins the components of `path` with "/".
///
/// A root directory is written as a single "/", and no separator is inserted
/// right after a root directory or a Windows path prefix (such as `C:`), so
/// absolute paths keep their meaning: `/foo/bar` stays `/foo/bar`, and
/// `C:\foo` becomes `C:/foo`. The prefix itself is copied verbatim.
fn to_slash_separated(path: &Path) -> OsString {
    let mut buf = OsString::with_capacity(path.as_os_str().len());
    // Whether a "/" must be written before the next ordinary component.
    let mut needs_separator = false;
    for c in path.components() {
        match c {
            Component::Prefix(prefix) => {
                buf.push(prefix.as_os_str());
                needs_separator = false;
            }
            Component::RootDir => {
                // The root already acts as the separator for what follows.
                buf.push("/");
                needs_separator = false;
            }
            Component::CurDir | Component::ParentDir | Component::Normal(_) => {
                if needs_separator {
                    buf.push("/");
                }
                buf.push(c);
                needs_separator = true;
            }
        }
    }
    buf
}
193
194/// Like `NamedTempFile::persist()`, but doesn't try to overwrite the existing
195/// target on Windows.
196pub fn persist_content_addressed_temp_file<P: AsRef<Path>>(
197    temp_file: NamedTempFile,
198    new_path: P,
199) -> io::Result<File> {
200    if cfg!(windows) {
201        // On Windows, overwriting file can fail if the file is opened without
202        // FILE_SHARE_DELETE for example. We don't need to take a risk if the
203        // file already exists.
204        match temp_file.persist_noclobber(&new_path) {
205            Ok(file) => Ok(file),
206            Err(PersistError { error, file: _ }) => {
207                if let Ok(existing_file) = File::open(new_path) {
208                    // TODO: Update mtime to help GC keep this file
209                    Ok(existing_file)
210                } else {
211                    Err(error)
212                }
213            }
214        }
215    } else {
216        // On Unix, rename() is atomic and should succeed even if the
217        // destination file exists. Checking if the target exists might involve
218        // non-atomic operation, so don't use persist_noclobber().
219        temp_file
220            .persist(new_path)
221            .map_err(|PersistError { error, file: _ }| error)
222    }
223}
224
225/// Reads from an async source and writes to a sync destination. Does not spawn
226/// a task, so writes will block.
227pub async fn copy_async_to_sync<R: AsyncRead, W: Write + ?Sized>(
228    reader: R,
229    writer: &mut W,
230) -> io::Result<usize> {
231    let mut buf = vec![0; 16 << 10];
232    let mut total_written_bytes = 0;
233
234    let mut reader = std::pin::pin!(reader);
235    loop {
236        let written_bytes = reader.read(&mut buf).await?;
237        if written_bytes == 0 {
238            return Ok(total_written_bytes);
239        }
240        writer.write_all(&buf[0..written_bytes])?;
241        total_written_bytes += written_bytes;
242    }
243}
244
245/// `AsyncRead`` implementation backed by a `Read`. It is not actually async;
246/// the goal is simply to avoid reading the full contents from the `Read` into
247/// memory.
248pub struct BlockingAsyncReader<R> {
249    reader: R,
250}
251
252impl<R: Read + Unpin> BlockingAsyncReader<R> {
253    /// Creates a new `BlockingAsyncReader`
254    pub fn new(reader: R) -> Self {
255        Self { reader }
256    }
257}
258
259impl<R: Read + Unpin> AsyncRead for BlockingAsyncReader<R> {
260    fn poll_read(
261        mut self: Pin<&mut Self>,
262        _cx: &mut std::task::Context<'_>,
263        buf: &mut ReadBuf<'_>,
264    ) -> Poll<io::Result<()>> {
265        let num_bytes_read = self.reader.read(buf.initialize_unfilled())?;
266        buf.advance(num_bytes_read);
267        Poll::Ready(Ok(()))
268    }
269}
270
#[cfg(unix)]
mod platform {
    use std::convert::Infallible;
    use std::ffi::OsStr;
    use std::io;
    use std::os::unix::ffi::OsStrExt as _;
    use std::os::unix::fs;
    use std::path::Path;

    /// OS strings are plain bytes on Unix, so conversion cannot fail.
    pub type BadOsStrEncoding = Infallible;

    /// Reinterprets `data` as an OS string without copying.
    pub fn os_str_from_bytes(data: &[u8]) -> Result<&OsStr, BadOsStrEncoding> {
        let os_str = OsStr::from_bytes(data);
        Ok(os_str)
    }

    /// Returns the raw bytes of `data` without copying.
    pub fn os_str_to_bytes(data: &OsStr) -> Result<&[u8], BadOsStrEncoding> {
        let bytes = data.as_bytes();
        Ok(bytes)
    }

    /// Symlinks are always available on UNIX
    pub fn check_symlink_support() -> io::Result<bool> {
        Ok(true)
    }

    /// Creates a symbolic link at `link` pointing to `original`.
    pub fn try_symlink<P, Q>(original: P, link: Q) -> io::Result<()>
    where
        P: AsRef<Path>,
        Q: AsRef<Path>,
    {
        fs::symlink(original, link)
    }
}
299
#[cfg(windows)]
mod platform {
    use std::io;
    use std::os::windows::fs::symlink_file;
    use std::path::Path;

    use winreg::enums::HKEY_LOCAL_MACHINE;
    use winreg::RegKey;

    pub use super::fallback::os_str_from_bytes;
    pub use super::fallback::os_str_to_bytes;
    pub use super::fallback::BadOsStrEncoding;

    /// Symlinks may or may not be enabled on Windows. They require the
    /// Developer Mode setting, which is stored in the registry key below.
    ///
    /// Returns an error if the registry key or value cannot be read.
    pub fn check_symlink_support() -> io::Result<bool> {
        let developer_mode: u32 = RegKey::predef(HKEY_LOCAL_MACHINE)
            .open_subkey("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\AppModelUnlock")?
            .get_value("AllowDevelopmentWithoutDevLicense")?;
        Ok(developer_mode == 1)
    }

    /// Creates a file symlink at `link` pointing to `original`.
    pub fn try_symlink<P, Q>(original: P, link: Q) -> io::Result<()>
    where
        P: AsRef<Path>,
        Q: AsRef<Path>,
    {
        // This will create a nonfunctional link for directories, but at the
        // moment we don't have enough information in the tree to determine
        // whether the symlink target is a file or a directory.
        //
        // Note: if developer mode is not enabled the error code will be 1314,
        // ERROR_PRIVILEGE_NOT_HELD.
        symlink_file(original, link)
    }
}
333
334#[cfg_attr(unix, allow(dead_code))]
335mod fallback {
336    use std::ffi::OsStr;
337    use std::str;
338
339    use thiserror::Error;
340
341    // Define error per platform so we can explicitly say UTF-8 is expected.
342    #[derive(Debug, Error)]
343    #[error("Invalid UTF-8 sequence")]
344    pub struct BadOsStrEncoding;
345
346    pub fn os_str_from_bytes(data: &[u8]) -> Result<&OsStr, BadOsStrEncoding> {
347        Ok(str::from_utf8(data).map_err(|_| BadOsStrEncoding)?.as_ref())
348    }
349
350    pub fn os_str_to_bytes(data: &OsStr) -> Result<&[u8], BadOsStrEncoding> {
351        Ok(data.to_str().ok_or(BadOsStrEncoding)?.as_ref())
352    }
353}
354
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::io::Write as _;

    use itertools::Itertools as _;
    use pollster::FutureExt as _;
    use test_case::test_case;

    use super::*;
    use crate::tests::new_temp_dir;

    /// Bytes converted to a path and back must come out unchanged.
    #[test]
    fn test_path_bytes_roundtrip() {
        for bytes in [&b"ascii"[..], &b"utf-8.\xc3\xa0"[..]] {
            let path = path_from_bytes(bytes).unwrap();
            assert_eq!(path_to_bytes(path).unwrap(), bytes);
        }

        // Not valid UTF-8: accepted on Unix only.
        let latin1 = b"latin1.\xe0";
        if cfg!(unix) {
            let path = path_from_bytes(latin1).unwrap();
            assert_eq!(path_to_bytes(path).unwrap(), latin1);
        } else {
            assert!(path_from_bytes(latin1).is_err());
        }
    }

    /// Excess `..` components are kept at the front of the result.
    #[test]
    fn normalize_too_many_dot_dot() {
        let cases = [
            ("foo/..", "."),
            ("foo/../..", ".."),
            ("foo/../../..", "../.."),
            ("foo/../../../bar/baz/..", "../../bar"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_path(Path::new(input)), Path::new(expected));
        }
    }

    /// Backslashes are converted on Windows only; "/" paths are untouched.
    #[test]
    fn test_slash_path() {
        for unchanged in ["", "foo", "foo/bar", "foo/bar/.."] {
            assert_eq!(slash_path(Path::new(unchanged)), Path::new(unchanged));
        }

        let backslash_cases = [
            (r"foo\bar", "foo/bar"),
            (r"..\foo\bar", "../foo/bar"),
        ];
        for (input, converted) in backslash_cases {
            let expected = if cfg!(windows) {
                Path::new(converted)
            } else {
                Path::new(input)
            };
            assert_eq!(slash_path(Path::new(input)), expected);
        }
    }

    #[test]
    fn test_persist_no_existing_file() {
        let temp_dir = new_temp_dir();
        let destination = temp_dir.path().join("file");
        let mut source = NamedTempFile::new_in(&temp_dir).unwrap();
        source.write_all(b"contents").unwrap();

        let persisted = persist_content_addressed_temp_file(source, destination);
        assert!(persisted.is_ok());
    }

    /// Persisting succeeds whether or not the existing target is still open.
    #[test_case(false ; "existing file open")]
    #[test_case(true ; "existing file closed")]
    fn test_persist_target_exists(existing_file_closed: bool) {
        let temp_dir = new_temp_dir();
        let destination = temp_dir.path().join("file");
        let mut source = NamedTempFile::new_in(&temp_dir).unwrap();
        source.write_all(b"contents").unwrap();

        let mut existing = File::create(&destination).unwrap();
        existing.write_all(b"contents").unwrap();
        if existing_file_closed {
            drop(existing);
        }

        let persisted = persist_content_addressed_temp_file(source, &destination);
        assert!(persisted.is_ok());
    }

    #[test]
    fn test_copy_async_to_sync_small() {
        let input = b"hello";
        let mut output = vec![];

        let copied = copy_async_to_sync(Cursor::new(&input), &mut output)
            .block_on()
            .unwrap();
        assert_eq!(copied, 5);
        assert_eq!(output, input);
    }

    #[test]
    fn test_copy_async_to_sync_large() {
        // More than 1 buffer worth of data
        let input = (0..100u8).cycle().take(40000).collect_vec();
        let mut output = vec![];

        let copied = copy_async_to_sync(Cursor::new(&input), &mut output)
            .block_on()
            .unwrap();
        assert_eq!(copied, 40000);
        assert_eq!(output, input);
    }

    /// Reads through a buffer smaller than the input, in two steps.
    #[test]
    fn test_blocking_async_reader() {
        let input = b"hello";
        let mut async_reader = BlockingAsyncReader::new(Cursor::new(&input));

        let mut buf = [0u8; 3];
        let first_len = async_reader.read(&mut buf).block_on().unwrap();
        assert_eq!(first_len, 3);
        assert_eq!(&buf, &input[0..3]);

        let second_len = async_reader.read(&mut buf).block_on().unwrap();
        assert_eq!(second_len, 2);
        assert_eq!(&buf[0..2], &input[3..5]);
    }

    #[test]
    fn test_blocking_async_reader_read_to_end() {
        let input = b"hello";
        let mut async_reader = BlockingAsyncReader::new(Cursor::new(&input));

        let mut buf = vec![];
        let total_len = async_reader.read_to_end(&mut buf).block_on().unwrap();
        assert_eq!(total_len, input.len());
        assert_eq!(&buf, &input);
    }
}