#![expect(missing_docs)]
use std::borrow::Cow;
use std::ffi::OsString;
use std::fs;
use std::fs::File;
use std::io;
use std::io::ErrorKind;
use std::io::Read;
use std::io::Write;
use std::path::Component;
use std::path::Path;
use std::path::PathBuf;
use std::pin::Pin;
use std::task::Poll;
use tempfile::NamedTempFile;
use tempfile::PersistError;
use thiserror::Error;
use tokio::io::AsyncRead;
use tokio::io::AsyncReadExt as _;
use tokio::io::ReadBuf;
#[cfg(unix)]
pub use self::platform::check_executable_bit_support;
pub use self::platform::check_symlink_support;
pub use self::platform::symlink_dir;
pub use self::platform::symlink_file;
/// I/O error annotated with the filesystem path it relates to.
///
/// `path` is the path recorded alongside the failure and `source` is the
/// original [`io::Error`]. The display message is `Cannot access {path}`.
#[derive(Debug, Error)]
#[error("Cannot access {path}")]
pub struct PathError {
pub path: PathBuf,
pub source: io::Error,
}
/// Extension trait for attaching a filesystem path to the error of a result.
pub trait IoResultExt<T> {
/// Turns `self` into a `Result` whose error type is [`PathError`], using
/// `path` as the path associated with the failure.
fn context(self, path: impl AsRef<Path>) -> Result<T, PathError>;
}
impl<T> IoResultExt<T> for io::Result<T> {
    /// Wraps an [`io::Error`] into a [`PathError`] carrying an owned copy of
    /// `path`. Successful values pass through untouched, and `path` is only
    /// copied on the error path.
    fn context(self, path: impl AsRef<Path>) -> Result<T, PathError> {
        match self {
            Ok(value) => Ok(value),
            Err(source) => Err(PathError {
                path: path.as_ref().to_path_buf(),
                source,
            }),
        }
    }
}
/// Creates the directory `dirname`, or does nothing if it already exists as
/// a directory.
///
/// Only the final path component is created (`fs::create_dir`), so missing
/// intermediate directories are an error.
///
/// # Errors
///
/// Returns the error from `fs::create_dir` when creation fails and `dirname`
/// is not a directory afterwards.
pub fn create_or_reuse_dir(dirname: &Path) -> io::Result<()> {
    fs::create_dir(dirname).or_else(|create_error| {
        // A failed creation is fine as long as the directory is there anyway.
        if dirname.is_dir() {
            Ok(())
        } else {
            Err(create_error)
        }
    })
}
pub fn remove_dir_contents(dirname: &Path) -> Result<(), PathError> {
for entry in dirname.read_dir().context(dirname)? {
let entry = entry.context(dirname)?;
let path = entry.path();
fs::remove_file(&path).context(&path)?;
}
Ok(())
}
pub fn is_empty_dir(path: &Path) -> Result<bool, PathError> {
match path.read_dir() {
Ok(mut entries) => Ok(entries.next().is_none()),
Err(error) => match error.kind() {
ErrorKind::NotADirectory => Ok(false),
ErrorKind::NotFound => Ok(false),
_ => Err(error).context(path)?,
},
}
}
/// Error returned when a path cannot be converted to or from raw bytes.
///
/// Wraps the platform-specific `platform::BadOsStrEncoding` and forwards its
/// message unchanged (`#[error(transparent)]`).
#[derive(Debug, Error)]
#[error(transparent)]
pub struct BadPathEncoding(platform::BadOsStrEncoding);
/// Reinterprets `bytes` as a borrowed [`Path`] using the platform's
/// `os_str_from_bytes` conversion.
///
/// # Errors
///
/// Returns [`BadPathEncoding`] if the platform conversion rejects `bytes`.
pub fn path_from_bytes(bytes: &[u8]) -> Result<&Path, BadPathEncoding> {
    platform::os_str_from_bytes(bytes)
        .map(Path::new)
        .map_err(BadPathEncoding)
}
/// Returns the bytes of `path` as produced by the platform's
/// `os_str_to_bytes` conversion, without copying.
///
/// # Errors
///
/// Returns [`BadPathEncoding`] if the platform conversion rejects `path`.
pub fn path_to_bytes(path: &Path) -> Result<&[u8], BadPathEncoding> {
    let os_str = path.as_os_str();
    platform::os_str_to_bytes(os_str).map_err(BadPathEncoding)
}
pub fn expand_home_path(path_str: &str) -> PathBuf {
if let Some(remainder) = path_str.strip_prefix("~/")
&& let Ok(home_dir) = etcetera::home_dir()
{
return home_dir.join(remainder);
}
PathBuf::from(path_str)
}
/// Expresses `to` as a path relative to `from`.
///
/// The nearest ancestor of `from` (including `from` itself) that is a prefix
/// of `to` is used as the common base. The result climbs from `from` to that
/// base with `..` components and then appends the remainder of `to`. When
/// both paths are the same, `.` is returned. If no ancestor of `from` is a
/// prefix of `to`, `to` is returned unchanged.
///
/// The comparison is purely lexical, so both paths should be written in the
/// same form (e.g. both absolute and normalized).
pub fn relative_path(from: &Path, to: &Path) -> PathBuf {
    // Walk up from `from` until an ancestor is a prefix of `to`; remember how
    // many levels were climbed and what is left of `to` below that ancestor.
    let common_base = from
        .ancestors()
        .enumerate()
        .find_map(|(levels_up, base)| Some((levels_up, to.strip_prefix(base).ok()?)));

    match common_base {
        None => to.to_owned(),
        Some((0, rest)) if rest.as_os_str().is_empty() => PathBuf::from("."),
        Some((levels_up, rest)) => {
            let mut relative = PathBuf::new();
            for _ in 0..levels_up {
                relative.push("..");
            }
            relative.push(rest);
            relative
        }
    }
}
/// Lexically normalizes `path` without touching the filesystem.
///
/// `.` components are dropped, and a `..` component cancels the preceding
/// normal component. A `..` that has nothing normal to cancel (at the start,
/// after another `..`, or after a root) is kept. An empty result is returned
/// as `.`.
pub fn normalize_path(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        if component == Component::CurDir {
            continue;
        }
        // Only a normal (named) component may be cancelled by "..".
        let cancels_previous = component == Component::ParentDir
            && matches!(
                normalized.components().next_back(),
                Some(Component::Normal(_))
            );
        if cancels_previous {
            let popped = normalized.pop();
            assert!(popped);
        } else {
            normalized.push(component);
        }
    }

    if normalized.as_os_str().is_empty() {
        PathBuf::from(".")
    } else {
        normalized
    }
}
pub fn slash_path(path: &Path) -> Cow<'_, Path> {
if cfg!(windows) {
Cow::Owned(to_slash_separated(path).into())
} else {
Cow::Borrowed(path)
}
}
/// Rebuilds `path` from its components using `/` as the only separator
/// between them.
///
/// A root component is written as a single `/`, and no extra separator is
/// inserted directly after a prefix (e.g. a Windows drive) or a root.
/// Previously a separator was always inserted between components, so
/// `/foo` became `//foo` and `C:\foo` became `C:/\/foo`.
///
/// A prefix component is copied verbatim. An empty path yields an empty
/// string.
fn to_slash_separated(path: &Path) -> OsString {
    let mut buf = OsString::with_capacity(path.as_os_str().len());
    // True once the last thing written was a component that must be followed
    // by "/" before the next one.
    let mut needs_separator = false;
    for component in path.components() {
        match component {
            Component::Prefix(prefix) => {
                buf.push(prefix.as_os_str());
                needs_separator = false;
            }
            Component::RootDir => {
                // The root already acts as the separator.
                buf.push("/");
                needs_separator = false;
            }
            Component::CurDir | Component::ParentDir | Component::Normal(_) => {
                if needs_separator {
                    buf.push("/");
                }
                buf.push(component);
                needs_separator = true;
            }
        }
    }
    buf
}
pub fn persist_temp_file<P: AsRef<Path>>(
temp_file: NamedTempFile,
new_path: P,
) -> io::Result<File> {
temp_file.as_file().sync_data()?;
temp_file
.persist(new_path)
.map_err(|PersistError { error, file: _ }| error)
}
pub fn persist_content_addressed_temp_file<P: AsRef<Path>>(
temp_file: NamedTempFile,
new_path: P,
) -> io::Result<File> {
temp_file.as_file().sync_data()?;
if cfg!(windows) {
match temp_file.persist_noclobber(&new_path) {
Ok(file) => Ok(file),
Err(PersistError { error, file: _ }) => {
if let Ok(existing_file) = File::open(new_path) {
Ok(existing_file)
} else {
Err(error)
}
}
}
} else {
temp_file
.persist(new_path)
.map_err(|PersistError { error, file: _ }| error)
}
}
/// Comparable and hashable identity of a file, wrapping the
/// platform-specific `platform::FileIdentity`.
///
/// On Unix the platform type is the pair of device and inode numbers; on
/// Windows it is a `same_file::Handle`.
#[derive(Debug, Eq, Hash, PartialEq)]
pub struct FileIdentity(platform::FileIdentity);
impl FileIdentity {
    /// Obtains the identity for `path` through the platform's
    /// `file_identity_from_symlink_path`.
    ///
    /// On Unix this uses `symlink_metadata`, i.e. a symlink at `path` is
    /// identified itself rather than its target.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the platform lookup.
    pub fn from_symlink_path(path: impl AsRef<Path>) -> io::Result<Self> {
        let identity = platform::file_identity_from_symlink_path(path.as_ref())?;
        Ok(Self(identity))
    }

    /// Obtains the identity of an already opened `file`, consuming the
    /// handle.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the platform lookup.
    pub fn from_file(file: File) -> io::Result<Self> {
        let identity = platform::file_identity_from_file(file)?;
        Ok(Self(identity))
    }
}
/// Copies everything from the asynchronous `reader` into the synchronous
/// `writer` and returns the number of bytes copied.
///
/// Data is moved in chunks of up to 16 KiB; each chunk is written with
/// `write_all` before the next read is awaited. Copying ends when the reader
/// reports zero bytes.
///
/// # Errors
///
/// Returns the first error produced by either the reader or the writer.
pub async fn copy_async_to_sync<R: AsyncRead, W: Write + ?Sized>(
    reader: R,
    writer: &mut W,
) -> io::Result<usize> {
    const CHUNK_SIZE: usize = 16 << 10;

    let mut reader = std::pin::pin!(reader);
    let mut chunk = vec![0; CHUNK_SIZE];
    let mut copied = 0;
    loop {
        match reader.read(&mut chunk).await? {
            // A zero-length read signals end of input.
            0 => break Ok(copied),
            len => {
                writer.write_all(&chunk[..len])?;
                copied += len;
            }
        }
    }
}
/// Adapter that exposes a synchronous reader through tokio's `AsyncRead`.
///
/// The wrapped reader is read synchronously inside `poll_read`, so every
/// poll completes immediately with the result of that read.
pub struct BlockingAsyncReader<R> {
// The wrapped synchronous reader.
reader: R,
}
impl<R: Read + Unpin> BlockingAsyncReader<R> {
pub fn new(reader: R) -> Self {
Self { reader }
}
}
impl<R: Read + Unpin> AsyncRead for BlockingAsyncReader<R> {
    /// Performs one synchronous `read` into the unfilled part of `buf` and
    /// reports the outcome immediately. This never returns `Poll::Pending`,
    /// and the task context is not used.
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match this.reader.read(buf.initialize_unfilled()) {
            Ok(filled) => {
                // Mark the freshly written bytes as filled.
                buf.advance(filled);
                Poll::Ready(Ok(()))
            }
            Err(error) => Poll::Ready(Err(error)),
        }
    }
}
/// Unix implementation of the platform-specific helpers.
#[cfg(unix)]
mod platform {
    use std::convert::Infallible;
    use std::ffi::OsStr;
    use std::fs;
    use std::fs::File;
    use std::io;
    use std::os::unix::ffi::OsStrExt as _;
    use std::os::unix::fs::MetadataExt as _;
    use std::os::unix::fs::PermissionsExt;
    use std::os::unix::fs::symlink;
    use std::path::Path;

    /// Byte/`OsStr` conversions cannot fail on Unix.
    pub type BadOsStrEncoding = Infallible;

    /// Views `data` as an `OsStr` without copying or validation.
    pub fn os_str_from_bytes(data: &[u8]) -> Result<&OsStr, BadOsStrEncoding> {
        let os_str = OsStr::from_bytes(data);
        Ok(os_str)
    }

    /// Returns the raw bytes of `data` without copying.
    pub fn os_str_to_bytes(data: &OsStr) -> Result<&[u8], BadOsStrEncoding> {
        let bytes = data.as_bytes();
        Ok(bytes)
    }

    /// Probes whether the owner-executable bit can be changed for files in
    /// the directory `path`.
    ///
    /// A temporary file is created in `path`, its `0o100` mode bit is
    /// flipped, and the mode is read back. Returns `Ok(false)` when changing
    /// the permissions is denied or when the mode read back differs from the
    /// requested one.
    ///
    /// # Errors
    ///
    /// Returns any other I/O error from creating the temporary file, reading
    /// its metadata, or setting its permissions.
    pub fn check_executable_bit_support(path: impl AsRef<Path>) -> io::Result<bool> {
        let probe = tempfile::tempfile_in(path)?;
        let mode_before = probe.metadata()?.permissions().mode();
        let flipped_mode = mode_before ^ 0o100;

        if let Err(err) = probe.set_permissions(fs::Permissions::from_mode(flipped_mode)) {
            return if err.kind() == io::ErrorKind::PermissionDenied {
                Ok(false)
            } else {
                Err(err)
            };
        }

        // The change may have been accepted but silently ignored, so compare
        // against what the file reports now.
        let mode_after = probe.metadata()?.permissions().mode();
        Ok(mode_after == flipped_mode)
    }

    /// Always reports `true` on Unix.
    pub fn check_symlink_support() -> io::Result<bool> {
        Ok(true)
    }

    /// Creates a symlink at `link` pointing to `original`. On Unix this is
    /// the same operation as `symlink_file`.
    pub fn symlink_dir<P: AsRef<Path>, Q: AsRef<Path>>(original: P, link: Q) -> io::Result<()> {
        symlink(original, link)
    }

    /// Creates a symlink at `link` pointing to `original`. On Unix this is
    /// the same operation as `symlink_dir`.
    pub fn symlink_file<P: AsRef<Path>, Q: AsRef<Path>>(original: P, link: Q) -> io::Result<()> {
        symlink(original, link)
    }

    /// File identity made of the device and inode numbers from the file's
    /// metadata.
    #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
    pub struct FileIdentity {
        dev: u64,
        ino: u64,
    }

    impl FileIdentity {
        /// Extracts the device and inode numbers from `metadata`.
        fn from_metadata(metadata: fs::Metadata) -> Self {
            let dev = metadata.dev();
            let ino = metadata.ino();
            Self { dev, ino }
        }
    }

    /// Identifies `path` from its `symlink_metadata`, i.e. without following
    /// a symlink at `path`.
    pub fn file_identity_from_symlink_path(path: &Path) -> io::Result<FileIdentity> {
        let metadata = fs::symlink_metadata(path)?;
        Ok(FileIdentity::from_metadata(metadata))
    }

    /// Identifies an opened `file` from its metadata.
    pub fn file_identity_from_file(file: File) -> io::Result<FileIdentity> {
        let metadata = file.metadata()?;
        Ok(FileIdentity::from_metadata(metadata))
    }
}
/// Windows implementation of the platform-specific helpers.
#[cfg(windows)]
mod platform {
use std::fs::File;
use std::io;
// Symlink creation is re-exported directly from the standard library.
pub use std::os::windows::fs::symlink_dir;
pub use std::os::windows::fs::symlink_file;
use std::path::Path;
use winreg::RegKey;
use winreg::enums::HKEY_LOCAL_MACHINE;
// Byte/`OsStr` conversions use the UTF-8 based implementation from the
// `fallback` module.
pub use super::fallback::BadOsStrEncoding;
pub use super::fallback::os_str_from_bytes;
pub use super::fallback::os_str_to_bytes;
/// Reports whether the `AllowDevelopmentWithoutDevLicense` registry value
/// under `HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\Windows\CurrentVersion\AppModelUnlock`
/// equals 1.
///
/// NOTE(review): presumably this value reflects the Windows "Developer Mode"
/// setting, which allows creating symlinks without elevated privileges.
/// Confirm against Windows documentation.
///
/// # Errors
///
/// Returns an error if the registry key cannot be opened or the value cannot
/// be read as a `u32`.
pub fn check_symlink_support() -> io::Result<bool> {
let hklm = RegKey::predef(HKEY_LOCAL_MACHINE);
let sideloading =
hklm.open_subkey("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\AppModelUnlock")?;
let developer_mode: u32 = sideloading.get_value("AllowDevelopmentWithoutDevLicense")?;
Ok(developer_mode == 1)
}
/// File identity backed by a `same_file::Handle`.
pub type FileIdentity = same_file::Handle;
/// Obtains the identity for `path` via `same_file::Handle::from_path`.
pub fn file_identity_from_symlink_path(path: &Path) -> io::Result<FileIdentity> {
same_file::Handle::from_path(path)
}
/// Obtains the identity of an opened `file` via
/// `same_file::Handle::from_file`, consuming the handle.
pub fn file_identity_from_file(file: File) -> io::Result<FileIdentity> {
same_file::Handle::from_file(file)
}
}
/// Portable byte/`OsStr` conversions that only accept valid UTF-8.
///
/// Unused on Unix, where the platform module converts bytes directly.
#[cfg_attr(unix, expect(dead_code))]
mod fallback {
    use std::ffi::OsStr;

    use thiserror::Error;

    /// Error for data that is not valid UTF-8.
    #[derive(Debug, Error)]
    #[error("Invalid UTF-8 sequence")]
    pub struct BadOsStrEncoding;

    /// Interprets `data` as UTF-8 and borrows it as an `OsStr`.
    ///
    /// Fails with [`BadOsStrEncoding`] if `data` is not valid UTF-8.
    pub fn os_str_from_bytes(data: &[u8]) -> Result<&OsStr, BadOsStrEncoding> {
        match str::from_utf8(data) {
            Ok(text) => Ok(OsStr::new(text)),
            Err(_) => Err(BadOsStrEncoding),
        }
    }

    /// Returns the UTF-8 bytes of `data`.
    ///
    /// Fails with [`BadOsStrEncoding`] if `data` is not valid Unicode.
    pub fn os_str_to_bytes(data: &OsStr) -> Result<&[u8], BadOsStrEncoding> {
        let text = data.to_str().ok_or(BadOsStrEncoding)?;
        Ok(text.as_bytes())
    }
}
// Unit tests for the file utilities in this module.
#[cfg(test)]
mod tests {
use std::io::Cursor;
use std::io::Write as _;
use itertools::Itertools as _;
use pollster::FutureExt as _;
use test_case::test_case;
use super::*;
use crate::tests::TestResult;
use crate::tests::new_temp_dir;
// The executable-bit probe is expected to succeed in a fresh temp directory.
#[test]
#[cfg(unix)]
fn exec_bit_support_in_temp_dir() -> TestResult {
let dir = new_temp_dir();
let supported = check_executable_bit_support(dir.path())?;
assert!(supported);
Ok(())
}
// Bytes -> path -> bytes must round-trip for ASCII and UTF-8 everywhere.
#[test]
fn test_path_bytes_roundtrip() -> TestResult {
let bytes = b"ascii";
let path = path_from_bytes(bytes)?;
assert_eq!(path_to_bytes(path)?, bytes);
let bytes = b"utf-8.\xc3\xa0";
let path = path_from_bytes(bytes)?;
assert_eq!(path_to_bytes(path)?, bytes);
// Non-UTF-8 bytes round-trip on Unix and are rejected elsewhere.
let bytes = b"latin1.\xe0";
if cfg!(unix) {
let path = path_from_bytes(bytes)?;
assert_eq!(path_to_bytes(path)?, bytes);
} else {
assert!(path_from_bytes(bytes).is_err());
}
Ok(())
}
// ".." components that cannot cancel a normal component are preserved.
#[test]
fn normalize_too_many_dot_dot() {
assert_eq!(normalize_path(Path::new("foo/..")), Path::new("."));
assert_eq!(normalize_path(Path::new("foo/../..")), Path::new(".."));
assert_eq!(
normalize_path(Path::new("foo/../../..")),
Path::new("../..")
);
assert_eq!(
normalize_path(Path::new("foo/../../../bar/baz/..")),
Path::new("../../bar")
);
}
// Backslash-separated input is only rewritten on Windows; elsewhere the
// backslash is part of the file name and the path is returned as is.
#[test]
fn test_slash_path() {
assert_eq!(slash_path(Path::new("")), Path::new(""));
assert_eq!(slash_path(Path::new("foo")), Path::new("foo"));
assert_eq!(slash_path(Path::new("foo/bar")), Path::new("foo/bar"));
assert_eq!(slash_path(Path::new("foo/bar/..")), Path::new("foo/bar/.."));
assert_eq!(
slash_path(Path::new(r"foo\bar")),
if cfg!(windows) {
Path::new("foo/bar")
} else {
Path::new(r"foo\bar")
}
);
assert_eq!(
slash_path(Path::new(r"..\foo\bar")),
if cfg!(windows) {
Path::new("../foo/bar")
} else {
Path::new(r"..\foo\bar")
}
);
}
// Persisting to a path that does not exist yet succeeds.
#[test]
fn test_persist_no_existing_file() -> TestResult {
let temp_dir = new_temp_dir();
let target = temp_dir.path().join("file");
let mut temp_file = NamedTempFile::new_in(&temp_dir)?;
temp_file.write_all(b"contents")?;
assert!(persist_content_addressed_temp_file(temp_file, target).is_ok());
Ok(())
}
// Persisting succeeds when the target already exists, whether or not a
// handle to the existing file is still open.
#[test_case(false ; "existing file open")]
#[test_case(true ; "existing file closed")]
fn test_persist_target_exists(existing_file_closed: bool) -> TestResult {
let temp_dir = new_temp_dir();
let target = temp_dir.path().join("file");
let mut temp_file = NamedTempFile::new_in(&temp_dir)?;
temp_file.write_all(b"contents")?;
let mut file = File::create(&target)?;
file.write_all(b"contents")?;
if existing_file_closed {
drop(file);
}
assert!(persist_content_addressed_temp_file(temp_file, &target).is_ok());
Ok(())
}
// Hard links share an identity; unrelated files do not.
#[test]
fn test_file_identity_hard_link() -> TestResult {
let temp_dir = new_temp_dir();
let file_path = temp_dir.path().join("file");
let other_file_path = temp_dir.path().join("other_file");
let link_path = temp_dir.path().join("link");
fs::write(&file_path, "")?;
fs::write(&other_file_path, "")?;
fs::hard_link(&file_path, &link_path)?;
assert_eq!(
FileIdentity::from_symlink_path(&file_path)?,
FileIdentity::from_symlink_path(&link_path)?
);
assert_ne!(
FileIdentity::from_symlink_path(&other_file_path)?,
FileIdentity::from_symlink_path(&link_path)?
);
assert_eq!(
FileIdentity::from_symlink_path(&file_path)?,
FileIdentity::from_file(File::open(&link_path)?)?
);
Ok(())
}
// `from_symlink_path` identifies the symlink itself, while opening the
// symlink and using `from_file` identifies the directory it points to.
#[cfg(unix)]
#[test]
fn test_file_identity_unix_symlink_dir() -> TestResult {
let temp_dir = new_temp_dir();
let dir_path = temp_dir.path().join("dir");
let symlink_path = temp_dir.path().join("symlink");
fs::create_dir(&dir_path)?;
std::os::unix::fs::symlink("dir", &symlink_path)?;
assert_eq!(
FileIdentity::from_symlink_path(&symlink_path)?,
FileIdentity::from_symlink_path(&symlink_path)?
);
assert_ne!(
FileIdentity::from_symlink_path(&dir_path)?,
FileIdentity::from_symlink_path(&symlink_path)?
);
assert_eq!(
FileIdentity::from_symlink_path(&dir_path)?,
FileIdentity::from_file(File::open(&symlink_path)?)?
);
assert_ne!(
FileIdentity::from_symlink_path(&symlink_path)?,
FileIdentity::from_file(File::open(&symlink_path)?)?
);
Ok(())
}
// A symlink that points at itself can still be identified, because the
// link is not followed.
#[cfg(unix)]
#[test]
fn test_file_identity_unix_symlink_loop() -> TestResult {
let temp_dir = new_temp_dir();
let lower_file_path = temp_dir.path().join("file");
let upper_file_path = temp_dir.path().join("FILE");
let lower_symlink_path = temp_dir.path().join("symlink");
let upper_symlink_path = temp_dir.path().join("SYMLINK");
fs::write(&lower_file_path, "")?;
std::os::unix::fs::symlink("symlink", &lower_symlink_path)?;
// If "FILE" resolves although only "file" was written, the filesystem is
// case-insensitive.
let is_icase_fs = upper_file_path.try_exists()?;
assert_eq!(
FileIdentity::from_symlink_path(&lower_symlink_path)?,
FileIdentity::from_symlink_path(&lower_symlink_path)?
);
assert_ne!(
FileIdentity::from_symlink_path(&lower_symlink_path)?,
FileIdentity::from_symlink_path(&lower_file_path)?
);
if is_icase_fs {
assert_eq!(
FileIdentity::from_symlink_path(&lower_symlink_path)?,
FileIdentity::from_symlink_path(&upper_symlink_path)?
);
} else {
assert!(FileIdentity::from_symlink_path(&upper_symlink_path).is_err());
}
Ok(())
}
// Input smaller than the 16 KiB copy buffer.
#[test]
fn test_copy_async_to_sync_small() -> TestResult {
let input = b"hello";
let mut output = vec![];
let result = copy_async_to_sync(Cursor::new(&input), &mut output).block_on();
assert!(result.is_ok());
assert_eq!(result?, 5);
assert_eq!(output, input);
Ok(())
}
// Input larger than the 16 KiB copy buffer, requiring several reads.
#[test]
fn test_copy_async_to_sync_large() -> TestResult {
let input = (0..100u8).cycle().take(40000).collect_vec();
let mut output = vec![];
let result = copy_async_to_sync(Cursor::new(&input), &mut output).block_on();
assert!(result.is_ok());
assert_eq!(result?, 40000);
assert_eq!(output, input);
Ok(())
}
// Two reads through a 3-byte buffer: a full one, then the 2-byte remainder.
#[test]
fn test_blocking_async_reader() -> TestResult {
let input = b"hello";
let sync_reader = Cursor::new(&input);
let mut async_reader = BlockingAsyncReader::new(sync_reader);
let mut buf = [0u8; 3];
let num_bytes_read = async_reader.read(&mut buf).block_on()?;
assert_eq!(num_bytes_read, 3);
assert_eq!(&buf, &input[0..3]);
let num_bytes_read = async_reader.read(&mut buf).block_on()?;
assert_eq!(num_bytes_read, 2);
assert_eq!(&buf[0..2], &input[3..5]);
Ok(())
}
// `read_to_end` drains the wrapped reader completely.
#[test]
fn test_blocking_async_reader_read_to_end() -> TestResult {
let input = b"hello";
let sync_reader = Cursor::new(&input);
let mut async_reader = BlockingAsyncReader::new(sync_reader);
let mut buf = vec![];
let num_bytes_read = async_reader.read_to_end(&mut buf).block_on()?;
assert_eq!(num_bytes_read, input.len());
assert_eq!(&buf, &input);
Ok(())
}
}