jj_lib/
file_util.rs

1// Copyright 2021 The Jujutsu Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![allow(missing_docs)]
16
17use std::borrow::Cow;
18use std::ffi::OsString;
19use std::fs;
20use std::fs::File;
21use std::io;
22use std::io::Read;
23use std::io::Write;
24use std::path::Component;
25use std::path::Path;
26use std::path::PathBuf;
27use std::pin::Pin;
28use std::task::Poll;
29
30use tempfile::NamedTempFile;
31use tempfile::PersistError;
32use thiserror::Error;
33use tokio::io::AsyncRead;
34use tokio::io::AsyncReadExt as _;
35use tokio::io::ReadBuf;
36
37pub use self::platform::check_symlink_support;
38pub use self::platform::try_symlink;
39
/// An I/O error paired with the filesystem path the failed operation was
/// accessing.
///
/// Typically created via [`IoResultExt::context()`].
#[derive(Debug, Error)]
#[error("Cannot access {path}")]
pub struct PathError {
    /// Path that the failed I/O operation was accessing.
    pub path: PathBuf,
    /// The underlying I/O error.
    pub source: io::Error,
}
46
47pub trait IoResultExt<T> {
48    fn context(self, path: impl AsRef<Path>) -> Result<T, PathError>;
49}
50
51impl<T> IoResultExt<T> for io::Result<T> {
52    fn context(self, path: impl AsRef<Path>) -> Result<T, PathError> {
53        self.map_err(|error| PathError {
54            path: path.as_ref().to_path_buf(),
55            source: error,
56        })
57    }
58}
59
/// Creates the directory `dirname`, treating an already-existing directory as
/// success.
///
/// Missing parent directories are not created. If `fs::create_dir()` fails
/// and `dirname` is not a directory afterwards, that error is returned. This
/// covers a missing intermediate directory and a non-directory occupying the
/// path.
pub fn create_or_reuse_dir(dirname: &Path) -> io::Result<()> {
    fs::create_dir(dirname).or_else(|err| {
        // Any failure is forgiven as long as a directory is now in place.
        if dirname.is_dir() { Ok(()) } else { Err(err) }
    })
}
72
73/// Removes all files in the directory, but not the directory itself.
74///
75/// The directory must exist, and there should be no sub directories.
76pub fn remove_dir_contents(dirname: &Path) -> Result<(), PathError> {
77    for entry in dirname.read_dir().context(dirname)? {
78        let entry = entry.context(dirname)?;
79        let path = entry.path();
80        fs::remove_file(&path).context(&path)?;
81    }
82    Ok(())
83}
84
/// Error returned when a path can't be converted from or to bytes on the
/// current platform.
///
/// On Unix this can never occur (the inner type is uninhabited). On Windows it
/// means the data was not valid UTF-8 / Unicode.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct BadPathEncoding(platform::BadOsStrEncoding);
88
89/// Constructs [`Path`] from `bytes` in platform-specific manner.
90///
91/// On Unix, this function never fails because paths are just bytes. On Windows,
92/// this may return error if the input wasn't well-formed UTF-8.
93pub fn path_from_bytes(bytes: &[u8]) -> Result<&Path, BadPathEncoding> {
94    let s = platform::os_str_from_bytes(bytes).map_err(BadPathEncoding)?;
95    Ok(Path::new(s))
96}
97
98/// Converts `path` to bytes in platform-specific manner.
99///
100/// On Unix, this function never fails because paths are just bytes. On Windows,
101/// this may return error if the input wasn't well-formed UTF-8.
102///
103/// The returned byte sequence can be considered a superset of ASCII (such as
104/// UTF-8 bytes.)
105pub fn path_to_bytes(path: &Path) -> Result<&[u8], BadPathEncoding> {
106    platform::os_str_to_bytes(path.as_ref()).map_err(BadPathEncoding)
107}
108
/// Expands a leading "~/" to "$HOME/".
///
/// Only the exact "~/" prefix is expanded. A bare "~", "~user/...", or a "~"
/// elsewhere in the string is returned unchanged. If `HOME` is unset, the
/// input is also returned unchanged.
///
/// `HOME` is read as a raw OS string, so a home directory whose path is not
/// valid UTF-8 is still expanded.
pub fn expand_home_path(path_str: &str) -> PathBuf {
    if let Some(remainder) = path_str.strip_prefix("~/") {
        // `var_os()` rather than `var()`: the latter returns `Err` for
        // non-UTF-8 values, which would silently leave "~/" unexpanded.
        if let Some(home_dir) = std::env::var_os("HOME") {
            return PathBuf::from(home_dir).join(remainder);
        }
    }
    PathBuf::from(path_str)
}
118
/// Expresses `to` as a path relative to `from`.
///
/// Both paths are expected to be absolute and normalized the same way. The
/// nearest ancestor of `from` that is a prefix of `to` is located. The result
/// is one ".." per level climbed, followed by the rest of `to`; "." is
/// returned when the paths are equal. If no ancestor matches, `to` is returned
/// unchanged.
pub fn relative_path(from: &Path, to: &Path) -> PathBuf {
    let common = from
        .ancestors()
        .enumerate()
        .find_map(|(depth, base)| to.strip_prefix(base).ok().map(|rest| (depth, rest)));

    match common {
        Some((0, rest)) if rest.as_os_str().is_empty() => PathBuf::from("."),
        Some((depth, rest)) => {
            let mut climbed: PathBuf = std::iter::repeat_n("..", depth).collect();
            climbed.push(rest);
            climbed
        }
        // No shared prefix (e.g. different drives): keep the absolute path.
        None => to.to_owned(),
    }
}
140
/// Lexically normalizes `path`: drops `.` components and cancels each `..`
/// against the directly preceding normal component.
///
/// Symlinks are not considered. A `..` with no preceding normal component to
/// cancel is kept, for example at the start of a relative path or right after
/// the root. An empty result is returned as ".".
pub fn normalize_path(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        if component == Component::CurDir {
            continue;
        }
        let cancels_previous = component == Component::ParentDir
            && matches!(normalized.components().next_back(), Some(Component::Normal(_)));
        if cancels_previous {
            // Only a normal component is ever popped; leading ".." stay.
            let popped = normalized.pop();
            assert!(popped);
        } else {
            normalized.push(component);
        }
    }

    if normalized.as_os_str().is_empty() {
        PathBuf::from(".")
    } else {
        normalized
    }
}
166
/// Converts the given `path` to Unix-like path separated by "/".
///
/// The returned path might not work on Windows if it was canonicalized: a
/// verbatim (`\\?\`) or UNC prefix is copied as-is and still contains
/// backslashes. On Unix, this function is noop.
pub fn slash_path(path: &Path) -> Cow<'_, Path> {
    if cfg!(windows) {
        Cow::Owned(to_slash_separated(path).into())
    } else {
        Cow::Borrowed(path)
    }
}

/// Joins the components of `path` with "/" separators.
///
/// A root component is emitted as a single "/" without an extra separator
/// around it. So `/foo` stays `/foo` (not `//foo`), and on Windows `C:\foo`
/// becomes `C:/foo` (not `C:/\/foo`).
///
/// A prefix component (Windows only) is copied verbatim and is not followed by
/// a separator unless a root follows. This keeps drive-relative paths like
/// `C:foo` from turning into absolute-looking `C:/foo`.
fn to_slash_separated(path: &Path) -> OsString {
    let mut buf = OsString::with_capacity(path.as_os_str().len());
    // Whether the next `.`/`..`/normal component must be preceded by "/".
    let mut needs_separator = false;
    for component in path.components() {
        match component {
            Component::Prefix(prefix) => {
                buf.push(prefix.as_os_str());
                needs_separator = false;
            }
            Component::RootDir => {
                buf.push("/");
                needs_separator = false;
            }
            Component::CurDir | Component::ParentDir | Component::Normal(_) => {
                if needs_separator {
                    buf.push("/");
                }
                buf.push(component);
                needs_separator = true;
            }
        }
    }
    buf
}
192
193/// Persists the temporary file after synchronizing the content.
194///
195/// After system crash, the persisted file should have a valid content if
196/// existed. However, the persisted file name (or directory entry) could be
197/// lost. It's up to caller to synchronize the directory entries.
198///
199/// See also <https://lwn.net/Articles/457667/> for the behavior on Linux.
200pub fn persist_temp_file<P: AsRef<Path>>(
201    temp_file: NamedTempFile,
202    new_path: P,
203) -> io::Result<File> {
204    // Ensure persisted file content is flushed to disk.
205    temp_file.as_file().sync_data()?;
206    temp_file
207        .persist(new_path)
208        .map_err(|PersistError { error, file: _ }| error)
209}
210
211/// Like [`persist_temp_file()`], but doesn't try to overwrite the existing
212/// target on Windows.
213pub fn persist_content_addressed_temp_file<P: AsRef<Path>>(
214    temp_file: NamedTempFile,
215    new_path: P,
216) -> io::Result<File> {
217    // Ensure new file content is flushed to disk, so the old file content
218    // wouldn't be lost if existed at the same location.
219    temp_file.as_file().sync_data()?;
220    if cfg!(windows) {
221        // On Windows, overwriting file can fail if the file is opened without
222        // FILE_SHARE_DELETE for example. We don't need to take a risk if the
223        // file already exists.
224        match temp_file.persist_noclobber(&new_path) {
225            Ok(file) => Ok(file),
226            Err(PersistError { error, file: _ }) => {
227                if let Ok(existing_file) = File::open(new_path) {
228                    // TODO: Update mtime to help GC keep this file
229                    Ok(existing_file)
230                } else {
231                    Err(error)
232                }
233            }
234        }
235    } else {
236        // On Unix, rename() is atomic and should succeed even if the
237        // destination file exists. Checking if the target exists might involve
238        // non-atomic operation, so don't use persist_noclobber().
239        temp_file
240            .persist(new_path)
241            .map_err(|PersistError { error, file: _ }| error)
242    }
243}
244
245/// Reads from an async source and writes to a sync destination. Does not spawn
246/// a task, so writes will block.
247pub async fn copy_async_to_sync<R: AsyncRead, W: Write + ?Sized>(
248    reader: R,
249    writer: &mut W,
250) -> io::Result<usize> {
251    let mut buf = vec![0; 16 << 10];
252    let mut total_written_bytes = 0;
253
254    let mut reader = std::pin::pin!(reader);
255    loop {
256        let written_bytes = reader.read(&mut buf).await?;
257        if written_bytes == 0 {
258            return Ok(total_written_bytes);
259        }
260        writer.write_all(&buf[0..written_bytes])?;
261        total_written_bytes += written_bytes;
262    }
263}
264
265/// `AsyncRead`` implementation backed by a `Read`. It is not actually async;
266/// the goal is simply to avoid reading the full contents from the `Read` into
267/// memory.
268pub struct BlockingAsyncReader<R> {
269    reader: R,
270}
271
272impl<R: Read + Unpin> BlockingAsyncReader<R> {
273    /// Creates a new `BlockingAsyncReader`
274    pub fn new(reader: R) -> Self {
275        Self { reader }
276    }
277}
278
279impl<R: Read + Unpin> AsyncRead for BlockingAsyncReader<R> {
280    fn poll_read(
281        mut self: Pin<&mut Self>,
282        _cx: &mut std::task::Context<'_>,
283        buf: &mut ReadBuf<'_>,
284    ) -> Poll<io::Result<()>> {
285        let num_bytes_read = self.reader.read(buf.initialize_unfilled())?;
286        buf.advance(num_bytes_read);
287        Poll::Ready(Ok(()))
288    }
289}
290
#[cfg(unix)]
mod platform {
    use std::convert::Infallible;
    use std::ffi::OsStr;
    use std::io;
    use std::os::unix::ffi::OsStrExt;
    use std::path::Path;

    /// Any byte sequence is a valid `OsStr` on Unix, so conversions can't fail.
    pub type BadOsStrEncoding = Infallible;

    /// Reinterprets `data` as an `OsStr`; never fails on Unix.
    pub fn os_str_from_bytes(data: &[u8]) -> Result<&OsStr, BadOsStrEncoding> {
        let os_str = <OsStr as OsStrExt>::from_bytes(data);
        Ok(os_str)
    }

    /// Returns the raw bytes of `data`; never fails on Unix.
    pub fn os_str_to_bytes(data: &OsStr) -> Result<&[u8], BadOsStrEncoding> {
        let bytes = <OsStr as OsStrExt>::as_bytes(data);
        Ok(bytes)
    }

    /// Symlinks are always available on UNIX.
    pub fn check_symlink_support() -> io::Result<bool> {
        let always_supported = true;
        Ok(always_supported)
    }

    /// Creates a symbolic link at `link` pointing to `original`.
    pub fn try_symlink<P: AsRef<Path>, Q: AsRef<Path>>(original: P, link: Q) -> io::Result<()> {
        std::os::unix::fs::symlink(original.as_ref(), link.as_ref())
    }
}
319
#[cfg(windows)]
mod platform {
    use std::io;
    use std::os::windows::fs::symlink_file;
    use std::path::Path;

    use winreg::RegKey;
    use winreg::enums::HKEY_LOCAL_MACHINE;

    pub use super::fallback::BadOsStrEncoding;
    pub use super::fallback::os_str_from_bytes;
    pub use super::fallback::os_str_to_bytes;

    /// Symlinks may or may not be enabled on Windows. They require the
    /// Developer Mode setting, which is stored in the registry key below.
    ///
    /// Returns `Ok(true)` only when `AllowDevelopmentWithoutDevLicense` is 1.
    /// An error is returned if the key or the value can't be read.
    pub fn check_symlink_support() -> io::Result<bool> {
        let developer_mode: u32 = RegKey::predef(HKEY_LOCAL_MACHINE)
            .open_subkey("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\AppModelUnlock")?
            .get_value("AllowDevelopmentWithoutDevLicense")?;
        Ok(developer_mode == 1)
    }

    /// Creates a file symlink at `link` pointing to `original`.
    pub fn try_symlink<P: AsRef<Path>, Q: AsRef<Path>>(original: P, link: Q) -> io::Result<()> {
        // A file symlink is always created, which is nonfunctional when the
        // target is a directory. The tree currently doesn't tell us whether
        // the target is a file or a directory.
        // Note: without developer mode the error code will be 1314,
        // ERROR_PRIVILEGE_NOT_HELD.
        symlink_file(original, link)
    }
}
353
354#[cfg_attr(unix, allow(dead_code))]
355mod fallback {
356    use std::ffi::OsStr;
357    use std::str;
358
359    use thiserror::Error;
360
361    // Define error per platform so we can explicitly say UTF-8 is expected.
362    #[derive(Debug, Error)]
363    #[error("Invalid UTF-8 sequence")]
364    pub struct BadOsStrEncoding;
365
366    pub fn os_str_from_bytes(data: &[u8]) -> Result<&OsStr, BadOsStrEncoding> {
367        Ok(str::from_utf8(data).map_err(|_| BadOsStrEncoding)?.as_ref())
368    }
369
370    pub fn os_str_to_bytes(data: &OsStr) -> Result<&[u8], BadOsStrEncoding> {
371        Ok(data.to_str().ok_or(BadOsStrEncoding)?.as_ref())
372    }
373}
374
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::io::Write as _;

    use itertools::Itertools as _;
    use pollster::FutureExt as _;
    use test_case::test_case;

    use super::*;
    use crate::tests::new_temp_dir;

    #[test]
    fn test_path_bytes_roundtrip() {
        let bytes = b"ascii";
        let path = path_from_bytes(bytes).unwrap();
        assert_eq!(path_to_bytes(path).unwrap(), bytes);

        let bytes = b"utf-8.\xc3\xa0";
        let path = path_from_bytes(bytes).unwrap();
        assert_eq!(path_to_bytes(path).unwrap(), bytes);

        let bytes = b"latin1.\xe0";
        if cfg!(unix) {
            let path = path_from_bytes(bytes).unwrap();
            assert_eq!(path_to_bytes(path).unwrap(), bytes);
        } else {
            assert!(path_from_bytes(bytes).is_err());
        }
    }

    #[test]
    fn test_create_or_reuse_dir() {
        let temp_dir = new_temp_dir();
        let dir = temp_dir.path().join("dir");
        create_or_reuse_dir(&dir).unwrap();
        assert!(dir.is_dir());
        // An existing directory is reused.
        create_or_reuse_dir(&dir).unwrap();
        // Intermediate directories are not created.
        assert!(create_or_reuse_dir(&temp_dir.path().join("missing").join("child")).is_err());
        // A regular file at the path is not accepted as a directory.
        let file = temp_dir.path().join("file");
        File::create(&file).unwrap();
        assert!(create_or_reuse_dir(&file).is_err());
    }

    #[test]
    fn test_remove_dir_contents() {
        let temp_dir = new_temp_dir();
        let dir = temp_dir.path().join("dir");
        fs::create_dir(&dir).unwrap();
        fs::write(dir.join("a"), b"a").unwrap();
        fs::write(dir.join("b"), b"b").unwrap();
        remove_dir_contents(&dir).unwrap();
        assert!(dir.is_dir());
        assert_eq!(dir.read_dir().unwrap().count(), 0);
    }

    #[test]
    fn test_relative_path() {
        assert_eq!(relative_path(Path::new("/a/b"), Path::new("/a/b")), Path::new("."));
        assert_eq!(
            relative_path(Path::new("/a/b"), Path::new("/a/b/c/d")),
            Path::new("c/d")
        );
        assert_eq!(relative_path(Path::new("/a/b"), Path::new("/a/c")), Path::new("../c"));
        assert_eq!(
            relative_path(Path::new("/a/b/c"), Path::new("/d")),
            Path::new("../../../d")
        );
        assert_eq!(relative_path(Path::new("/a/b"), Path::new("/a")), Path::new(".."));
    }

    #[test]
    fn test_normalize_path() {
        assert_eq!(normalize_path(Path::new("")), Path::new("."));
        assert_eq!(normalize_path(Path::new(".")), Path::new("."));
        assert_eq!(normalize_path(Path::new("./foo/./bar")), Path::new("foo/bar"));
        assert_eq!(normalize_path(Path::new("foo/bar/../baz")), Path::new("foo/baz"));
        assert_eq!(normalize_path(Path::new("/foo/../bar")), Path::new("/bar"));
    }

    #[test]
    fn normalize_too_many_dot_dot() {
        assert_eq!(normalize_path(Path::new("foo/..")), Path::new("."));
        assert_eq!(normalize_path(Path::new("foo/../..")), Path::new(".."));
        assert_eq!(
            normalize_path(Path::new("foo/../../..")),
            Path::new("../..")
        );
        assert_eq!(
            normalize_path(Path::new("foo/../../../bar/baz/..")),
            Path::new("../../bar")
        );
    }

    #[test]
    fn test_slash_path() {
        assert_eq!(slash_path(Path::new("")), Path::new(""));
        assert_eq!(slash_path(Path::new("foo")), Path::new("foo"));
        assert_eq!(slash_path(Path::new("foo/bar")), Path::new("foo/bar"));
        assert_eq!(slash_path(Path::new("foo/bar/..")), Path::new("foo/bar/.."));
        assert_eq!(
            slash_path(Path::new(r"foo\bar")),
            if cfg!(windows) {
                Path::new("foo/bar")
            } else {
                Path::new(r"foo\bar")
            }
        );
        assert_eq!(
            slash_path(Path::new(r"..\foo\bar")),
            if cfg!(windows) {
                Path::new("../foo/bar")
            } else {
                Path::new(r"..\foo\bar")
            }
        );
    }

    #[test]
    fn test_persist_no_existing_file() {
        let temp_dir = new_temp_dir();
        let target = temp_dir.path().join("file");
        let mut temp_file = NamedTempFile::new_in(&temp_dir).unwrap();
        temp_file.write_all(b"contents").unwrap();
        assert!(persist_content_addressed_temp_file(temp_file, target).is_ok());
    }

    #[test_case(false ; "existing file open")]
    #[test_case(true ; "existing file closed")]
    fn test_persist_target_exists(existing_file_closed: bool) {
        let temp_dir = new_temp_dir();
        let target = temp_dir.path().join("file");
        let mut temp_file = NamedTempFile::new_in(&temp_dir).unwrap();
        temp_file.write_all(b"contents").unwrap();

        let mut file = File::create(&target).unwrap();
        file.write_all(b"contents").unwrap();
        if existing_file_closed {
            drop(file);
        }

        assert!(persist_content_addressed_temp_file(temp_file, &target).is_ok());
    }

    #[test]
    fn test_copy_async_to_sync_small() {
        let input = b"hello";
        let mut output = vec![];

        let result = copy_async_to_sync(Cursor::new(&input), &mut output).block_on();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 5);
        assert_eq!(output, input);
    }

    #[test]
    fn test_copy_async_to_sync_large() {
        // More than 1 buffer worth of data
        let input = (0..100u8).cycle().take(40000).collect_vec();
        let mut output = vec![];

        let result = copy_async_to_sync(Cursor::new(&input), &mut output).block_on();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 40000);
        assert_eq!(output, input);
    }

    #[test]
    fn test_blocking_async_reader() {
        let input = b"hello";
        let sync_reader = Cursor::new(&input);
        let mut async_reader = BlockingAsyncReader::new(sync_reader);

        let mut buf = [0u8; 3];
        let num_bytes_read = async_reader.read(&mut buf).block_on().unwrap();
        assert_eq!(num_bytes_read, 3);
        assert_eq!(&buf, &input[0..3]);

        let num_bytes_read = async_reader.read(&mut buf).block_on().unwrap();
        assert_eq!(num_bytes_read, 2);
        assert_eq!(&buf[0..2], &input[3..5]);
    }

    #[test]
    fn test_blocking_async_reader_read_to_end() {
        let input = b"hello";
        let sync_reader = Cursor::new(&input);
        let mut async_reader = BlockingAsyncReader::new(sync_reader);

        let mut buf = vec![];
        let num_bytes_read = async_reader.read_to_end(&mut buf).block_on().unwrap();
        assert_eq!(num_bytes_read, input.len());
        assert_eq!(&buf, &input);
    }
}