1#![allow(missing_docs)]
16
17use std::borrow::Cow;
18use std::ffi::OsString;
19use std::fs;
20use std::fs::File;
21use std::io;
22use std::io::Read;
23use std::io::Write;
24use std::path::Component;
25use std::path::Path;
26use std::path::PathBuf;
27use std::pin::Pin;
28use std::task::Poll;
29
30use tempfile::NamedTempFile;
31use tempfile::PersistError;
32use thiserror::Error;
33use tokio::io::AsyncRead;
34use tokio::io::AsyncReadExt as _;
35use tokio::io::ReadBuf;
36
37pub use self::platform::check_symlink_support;
38pub use self::platform::try_symlink;
39
/// An I/O error annotated with the path whose access failed.
#[derive(Debug, Error)]
#[error("Cannot access {path}")]
pub struct PathError {
    /// The path that was being accessed when the error occurred.
    pub path: PathBuf,
    /// The underlying I/O error. `thiserror` treats a field named `source`
    /// as the error's `source()`.
    pub source: io::Error,
}
46
/// Extension trait for attaching a path to the error of an I/O result.
pub trait IoResultExt<T> {
    /// Converts the error, if any, into a [`PathError`] that records `path`.
    fn context(self, path: impl AsRef<Path>) -> Result<T, PathError>;
}
50
51impl<T> IoResultExt<T> for io::Result<T> {
52 fn context(self, path: impl AsRef<Path>) -> Result<T, PathError> {
53 self.map_err(|error| PathError {
54 path: path.as_ref().to_path_buf(),
55 source: error,
56 })
57 }
58}
59
/// Creates the directory `dirname`, or succeeds if it already is a directory.
///
/// Parent directories are not created. If `fs::create_dir` fails for any
/// reason but `dirname` is a directory when checked afterwards, the error is
/// discarded. Otherwise the original creation error is returned.
pub fn create_or_reuse_dir(dirname: &Path) -> io::Result<()> {
    fs::create_dir(dirname).or_else(|err| {
        if dirname.is_dir() {
            Ok(())
        } else {
            Err(err)
        }
    })
}
72
73pub fn remove_dir_contents(dirname: &Path) -> Result<(), PathError> {
77 for entry in dirname.read_dir().context(dirname)? {
78 let entry = entry.context(dirname)?;
79 let path = entry.path();
80 fs::remove_file(&path).context(&path)?;
81 }
82 Ok(())
83}
84
/// Error returned when a path cannot be converted to or from raw bytes on
/// the current platform.
///
/// On Unix the wrapped type is `Infallible`, so this error cannot occur
/// there. On Windows, conversions require valid UTF-8.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct BadPathEncoding(platform::BadOsStrEncoding);
88
89pub fn path_from_bytes(bytes: &[u8]) -> Result<&Path, BadPathEncoding> {
94 let s = platform::os_str_from_bytes(bytes).map_err(BadPathEncoding)?;
95 Ok(Path::new(s))
96}
97
98pub fn path_to_bytes(path: &Path) -> Result<&[u8], BadPathEncoding> {
106 platform::os_str_to_bytes(path.as_ref()).map_err(BadPathEncoding)
107}
108
/// Expands a leading `~/` to the value of `$HOME`.
///
/// The input is returned unchanged (as a `PathBuf`) if it does not start with
/// `~/` or if `$HOME` is unset or not valid Unicode. A bare `~` or `~user` is
/// not expanded.
pub fn expand_home_path(path_str: &str) -> PathBuf {
    let expanded = path_str.strip_prefix("~/").and_then(|rest| {
        std::env::var("HOME")
            .ok()
            .map(|home| PathBuf::from(home).join(rest))
    });
    expanded.unwrap_or_else(|| PathBuf::from(path_str))
}
118
/// Computes a path to `to` that is relative to the directory `from`.
///
/// Walks up the ancestors of `from` until one of them is a prefix of `to`,
/// then emits one `..` per level climbed followed by the remaining suffix of
/// `to`. Returns `"."` when `from` and `to` are the same path, and `to`
/// unchanged if no ancestor of `from` is a prefix of `to`. Symlinks are not
/// resolved.
pub fn relative_path(from: &Path, to: &Path) -> PathBuf {
    from.ancestors()
        .enumerate()
        .find_map(|(depth, base)| {
            let suffix = to.strip_prefix(base).ok()?;
            if depth == 0 && suffix.as_os_str().is_empty() {
                Some(PathBuf::from("."))
            } else {
                let mut relative: PathBuf = std::iter::repeat_n("..", depth).collect();
                relative.push(suffix);
                Some(relative)
            }
        })
        .unwrap_or_else(|| to.to_owned())
}
140
/// Lexically removes `.` components and cancels `..` against a preceding
/// normal component, without consulting the filesystem (symlinks are not
/// considered).
///
/// A `..` that has no normal component to cancel (leading, or after the
/// root) is kept. An empty result is returned as `"."`.
pub fn normalize_path(path: &Path) -> PathBuf {
    let normalized = path
        .components()
        .fold(PathBuf::new(), |mut acc, component| {
            match component {
                Component::CurDir => {}
                Component::ParentDir
                    if matches!(acc.components().next_back(), Some(Component::Normal(_))) =>
                {
                    // The last component is a normal one, so there is always
                    // something to pop.
                    let popped = acc.pop();
                    assert!(popped);
                }
                other => acc.push(other),
            }
            acc
        });
    if normalized.as_os_str().is_empty() {
        PathBuf::from(".")
    } else {
        normalized
    }
}
166
167pub fn slash_path(path: &Path) -> Cow<'_, Path> {
172 if cfg!(windows) {
173 Cow::Owned(to_slash_separated(path).into())
174 } else {
175 Cow::Borrowed(path)
176 }
177}
178
/// Joins the components of `path` with `/`, regardless of the platform's
/// native separator.
///
/// A prefix component (e.g. `C:` on Windows) is emitted verbatim, and the
/// root directory is emitted as a single `/`. Neither is followed by an
/// extra separator, so `C:\foo\bar` becomes `C:/foo/bar` and `/foo` stays
/// `/foo`. Previously the output was `C:/\/foo/bar` and `//foo`.
fn to_slash_separated(path: &Path) -> OsString {
    let mut buf = OsString::with_capacity(path.as_os_str().len());
    // Whether the next `.`/`..`/normal component must be preceded by `/`.
    let mut needs_separator = false;
    for component in path.components() {
        match component {
            Component::Prefix(_) => {
                buf.push(component);
                needs_separator = false;
            }
            Component::RootDir => {
                // The root already acts as the separator; its native spelling
                // (`\` on Windows) is replaced by `/`.
                buf.push("/");
                needs_separator = false;
            }
            Component::CurDir | Component::ParentDir | Component::Normal(_) => {
                if needs_separator {
                    buf.push("/");
                }
                buf.push(component);
                needs_separator = true;
            }
        }
    }
    buf
}
192
193pub fn persist_temp_file<P: AsRef<Path>>(
201 temp_file: NamedTempFile,
202 new_path: P,
203) -> io::Result<File> {
204 temp_file.as_file().sync_data()?;
206 temp_file
207 .persist(new_path)
208 .map_err(|PersistError { error, file: _ }| error)
209}
210
211pub fn persist_content_addressed_temp_file<P: AsRef<Path>>(
214 temp_file: NamedTempFile,
215 new_path: P,
216) -> io::Result<File> {
217 temp_file.as_file().sync_data()?;
220 if cfg!(windows) {
221 match temp_file.persist_noclobber(&new_path) {
225 Ok(file) => Ok(file),
226 Err(PersistError { error, file: _ }) => {
227 if let Ok(existing_file) = File::open(new_path) {
228 Ok(existing_file)
230 } else {
231 Err(error)
232 }
233 }
234 }
235 } else {
236 temp_file
240 .persist(new_path)
241 .map_err(|PersistError { error, file: _ }| error)
242 }
243}
244
245pub async fn copy_async_to_sync<R: AsyncRead, W: Write + ?Sized>(
248 reader: R,
249 writer: &mut W,
250) -> io::Result<usize> {
251 let mut buf = vec![0; 16 << 10];
252 let mut total_written_bytes = 0;
253
254 let mut reader = std::pin::pin!(reader);
255 loop {
256 let written_bytes = reader.read(&mut buf).await?;
257 if written_bytes == 0 {
258 return Ok(total_written_bytes);
259 }
260 writer.write_all(&buf[0..written_bytes])?;
261 total_written_bytes += written_bytes;
262 }
263}
264
/// Adapts a blocking [`Read`] into an [`AsyncRead`].
///
/// Every poll performs the read synchronously on the calling thread, so this
/// is only appropriate where blocking the executor is acceptable.
pub struct BlockingAsyncReader<R> {
    /// The wrapped blocking reader.
    reader: R,
}
271
272impl<R: Read + Unpin> BlockingAsyncReader<R> {
273 pub fn new(reader: R) -> Self {
275 Self { reader }
276 }
277}
278
279impl<R: Read + Unpin> AsyncRead for BlockingAsyncReader<R> {
280 fn poll_read(
281 mut self: Pin<&mut Self>,
282 _cx: &mut std::task::Context<'_>,
283 buf: &mut ReadBuf<'_>,
284 ) -> Poll<io::Result<()>> {
285 let num_bytes_read = self.reader.read(buf.initialize_unfilled())?;
286 buf.advance(num_bytes_read);
287 Poll::Ready(Ok(()))
288 }
289}
290
#[cfg(unix)]
mod platform {
    use std::convert::Infallible;
    use std::ffi::OsStr;
    use std::io;
    use std::os::unix::ffi::OsStrExt as _;
    use std::os::unix::fs::symlink;
    use std::path::Path;

    /// On Unix an `OsStr` is an arbitrary byte sequence, so byte conversions
    /// cannot fail.
    pub type BadOsStrEncoding = Infallible;

    /// Reinterprets `data` as an `OsStr` without copying.
    pub fn os_str_from_bytes(data: &[u8]) -> Result<&OsStr, BadOsStrEncoding> {
        let os_str = OsStr::from_bytes(data);
        Ok(os_str)
    }

    /// Exposes the raw bytes of `data` without copying.
    pub fn os_str_to_bytes(data: &OsStr) -> Result<&[u8], BadOsStrEncoding> {
        let bytes = data.as_bytes();
        Ok(bytes)
    }

    /// Reports whether symlinks can be created. Always `true` on Unix.
    pub fn check_symlink_support() -> io::Result<bool> {
        Ok(true)
    }

    /// Creates a symbolic link at `link` pointing to `original`.
    pub fn try_symlink<P: AsRef<Path>, Q: AsRef<Path>>(original: P, link: Q) -> io::Result<()> {
        symlink(original.as_ref(), link.as_ref())
    }
}
319
#[cfg(windows)]
mod platform {
    use std::io;
    use std::os::windows::fs::symlink_file;
    use std::path::Path;

    use winreg::RegKey;
    use winreg::enums::HKEY_LOCAL_MACHINE;

    // Byte <-> OsStr conversions on Windows go through UTF-8 (the `fallback`
    // module).
    pub use super::fallback::BadOsStrEncoding;
    pub use super::fallback::os_str_from_bytes;
    pub use super::fallback::os_str_to_bytes;

    /// Reports whether symlinks can be created, based on the
    /// `AllowDevelopmentWithoutDevLicense` registry value under
    /// `HKLM\SOFTWARE\Microsoft\Windows\CurrentVersion\AppModelUnlock`.
    /// Returns `true` only when that value equals 1.
    ///
    /// NOTE(review): presumably relies on Developer Mode permitting
    /// unprivileged symlink creation. Confirm.
    ///
    /// # Errors
    ///
    /// Returns an error if the registry key or value cannot be read.
    pub fn check_symlink_support() -> io::Result<bool> {
        let hklm = RegKey::predef(HKEY_LOCAL_MACHINE);
        let sideloading =
            hklm.open_subkey("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\AppModelUnlock")?;
        let developer_mode: u32 = sideloading.get_value("AllowDevelopmentWithoutDevLicense")?;
        Ok(developer_mode == 1)
    }

    /// Creates a *file* symbolic link at `link` pointing to `original`.
    ///
    /// NOTE(review): always uses `symlink_file`, so a link whose target is a
    /// directory is still created as a file symlink. Confirm callers only
    /// link files.
    pub fn try_symlink<P: AsRef<Path>, Q: AsRef<Path>>(original: P, link: Q) -> io::Result<()> {
        symlink_file(original, link)
    }
}
353
354#[cfg_attr(unix, allow(dead_code))]
355mod fallback {
356 use std::ffi::OsStr;
357 use std::str;
358
359 use thiserror::Error;
360
361 #[derive(Debug, Error)]
363 #[error("Invalid UTF-8 sequence")]
364 pub struct BadOsStrEncoding;
365
366 pub fn os_str_from_bytes(data: &[u8]) -> Result<&OsStr, BadOsStrEncoding> {
367 Ok(str::from_utf8(data).map_err(|_| BadOsStrEncoding)?.as_ref())
368 }
369
370 pub fn os_str_to_bytes(data: &OsStr) -> Result<&[u8], BadOsStrEncoding> {
371 Ok(data.to_str().ok_or(BadOsStrEncoding)?.as_ref())
372 }
373}
374
// Unit tests for the helpers in this file. Async helpers are driven to
// completion with `pollster`'s `block_on`, so no async runtime is needed.
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::io::Write as _;

    use itertools::Itertools as _;
    use pollster::FutureExt as _;
    use test_case::test_case;

    use super::*;
    use crate::tests::new_temp_dir;

    // ASCII and UTF-8 bytes round-trip on every platform. A non-UTF-8 byte
    // (Latin-1 `\xe0`) round-trips only on Unix and is rejected elsewhere.
    #[test]
    fn test_path_bytes_roundtrip() {
        let bytes = b"ascii";
        let path = path_from_bytes(bytes).unwrap();
        assert_eq!(path_to_bytes(path).unwrap(), bytes);

        let bytes = b"utf-8.\xc3\xa0";
        let path = path_from_bytes(bytes).unwrap();
        assert_eq!(path_to_bytes(path).unwrap(), bytes);

        let bytes = b"latin1.\xe0";
        if cfg!(unix) {
            let path = path_from_bytes(bytes).unwrap();
            assert_eq!(path_to_bytes(path).unwrap(), bytes);
        } else {
            assert!(path_from_bytes(bytes).is_err());
        }
    }

    // `..` components with nothing left to cancel are preserved.
    #[test]
    fn normalize_too_many_dot_dot() {
        assert_eq!(normalize_path(Path::new("foo/..")), Path::new("."));
        assert_eq!(normalize_path(Path::new("foo/../..")), Path::new(".."));
        assert_eq!(
            normalize_path(Path::new("foo/../../..")),
            Path::new("../..")
        );
        assert_eq!(
            normalize_path(Path::new("foo/../../../bar/baz/..")),
            Path::new("../../bar")
        );
    }

    // Backslashes are separators (and get rewritten to `/`) only on Windows.
    // Elsewhere they are ordinary filename characters.
    #[test]
    fn test_slash_path() {
        assert_eq!(slash_path(Path::new("")), Path::new(""));
        assert_eq!(slash_path(Path::new("foo")), Path::new("foo"));
        assert_eq!(slash_path(Path::new("foo/bar")), Path::new("foo/bar"));
        assert_eq!(slash_path(Path::new("foo/bar/..")), Path::new("foo/bar/.."));
        assert_eq!(
            slash_path(Path::new(r"foo\bar")),
            if cfg!(windows) {
                Path::new("foo/bar")
            } else {
                Path::new(r"foo\bar")
            }
        );
        assert_eq!(
            slash_path(Path::new(r"..\foo\bar")),
            if cfg!(windows) {
                Path::new("../foo/bar")
            } else {
                Path::new(r"..\foo\bar")
            }
        );
    }

    #[test]
    fn test_persist_no_existing_file() {
        let temp_dir = new_temp_dir();
        let target = temp_dir.path().join("file");
        let mut temp_file = NamedTempFile::new_in(&temp_dir).unwrap();
        temp_file.write_all(b"contents").unwrap();
        assert!(persist_content_addressed_temp_file(temp_file, target).is_ok());
    }

    // The target already exists with identical contents. Persisting must
    // succeed whether or not a handle to the existing file is still open.
    #[test_case(false ; "existing file open")]
    #[test_case(true ; "existing file closed")]
    fn test_persist_target_exists(existing_file_closed: bool) {
        let temp_dir = new_temp_dir();
        let target = temp_dir.path().join("file");
        let mut temp_file = NamedTempFile::new_in(&temp_dir).unwrap();
        temp_file.write_all(b"contents").unwrap();

        let mut file = File::create(&target).unwrap();
        file.write_all(b"contents").unwrap();
        if existing_file_closed {
            drop(file);
        }

        assert!(persist_content_addressed_temp_file(temp_file, &target).is_ok());
    }

    #[test]
    fn test_copy_async_to_sync_small() {
        let input = b"hello";
        let mut output = vec![];

        let result = copy_async_to_sync(Cursor::new(&input), &mut output).block_on();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 5);
        assert_eq!(output, input);
    }

    // 40000 bytes exceeds the 16 KiB copy buffer, so several chunks are copied.
    #[test]
    fn test_copy_async_to_sync_large() {
        let input = (0..100u8).cycle().take(40000).collect_vec();
        let mut output = vec![];

        let result = copy_async_to_sync(Cursor::new(&input), &mut output).block_on();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 40000);
        assert_eq!(output, input);
    }

    // A read into a 3-byte buffer returns at most 3 bytes. The next read
    // returns the remaining 2.
    #[test]
    fn test_blocking_async_reader() {
        let input = b"hello";
        let sync_reader = Cursor::new(&input);
        let mut async_reader = BlockingAsyncReader::new(sync_reader);

        let mut buf = [0u8; 3];
        let num_bytes_read = async_reader.read(&mut buf).block_on().unwrap();
        assert_eq!(num_bytes_read, 3);
        assert_eq!(&buf, &input[0..3]);

        let num_bytes_read = async_reader.read(&mut buf).block_on().unwrap();
        assert_eq!(num_bytes_read, 2);
        assert_eq!(&buf[0..2], &input[3..5]);
    }

    #[test]
    fn test_blocking_async_reader_read_to_end() {
        let input = b"hello";
        let sync_reader = Cursor::new(&input);
        let mut async_reader = BlockingAsyncReader::new(sync_reader);

        let mut buf = vec![];
        let num_bytes_read = async_reader.read_to_end(&mut buf).block_on().unwrap();
        assert_eq!(num_bytes_read, input.len());
        assert_eq!(&buf, &input);
    }
}
520}