1#![allow(missing_docs)]
16
17use std::borrow::Cow;
18use std::ffi::OsString;
19use std::fs;
20use std::fs::File;
21use std::io;
22use std::io::Read;
23use std::io::Write;
24use std::path::Component;
25use std::path::Path;
26use std::path::PathBuf;
27use std::pin::Pin;
28use std::task::Poll;
29
30use tempfile::NamedTempFile;
31use tempfile::PersistError;
32use thiserror::Error;
33use tokio::io::AsyncRead;
34use tokio::io::AsyncReadExt as _;
35use tokio::io::ReadBuf;
36
37pub use self::platform::check_symlink_support;
38pub use self::platform::try_symlink;
39
40#[derive(Debug, Error)]
41#[error("Cannot access {path}")]
42pub struct PathError {
43 pub path: PathBuf,
44 #[source]
45 pub error: io::Error,
46}
47
48pub trait IoResultExt<T> {
49 fn context(self, path: impl AsRef<Path>) -> Result<T, PathError>;
50}
51
52impl<T> IoResultExt<T> for io::Result<T> {
53 fn context(self, path: impl AsRef<Path>) -> Result<T, PathError> {
54 self.map_err(|error| PathError {
55 path: path.as_ref().to_path_buf(),
56 error,
57 })
58 }
59}
60
/// Creates the directory `dirname`, treating "it already exists as a
/// directory" as success.
///
/// # Errors
///
/// Returns the error from `fs::create_dir` unless `dirname` is a directory
/// after the failed attempt (for example, when a file occupies the path or a
/// parent directory is missing).
pub fn create_or_reuse_dir(dirname: &Path) -> io::Result<()> {
    fs::create_dir(dirname).or_else(|create_err| {
        // Creation failed; accept it only if a directory is now there.
        if dirname.is_dir() {
            Ok(())
        } else {
            Err(create_err)
        }
    })
}
73
74pub fn remove_dir_contents(dirname: &Path) -> Result<(), PathError> {
78 for entry in dirname.read_dir().context(dirname)? {
79 let entry = entry.context(dirname)?;
80 let path = entry.path();
81 fs::remove_file(&path).context(&path)?;
82 }
83 Ok(())
84}
85
86#[derive(Debug, Error)]
87#[error(transparent)]
88pub struct BadPathEncoding(platform::BadOsStrEncoding);
89
90pub fn path_from_bytes(bytes: &[u8]) -> Result<&Path, BadPathEncoding> {
95 let s = platform::os_str_from_bytes(bytes).map_err(BadPathEncoding)?;
96 Ok(Path::new(s))
97}
98
99pub fn path_to_bytes(path: &Path) -> Result<&[u8], BadPathEncoding> {
107 platform::os_str_to_bytes(path.as_ref()).map_err(BadPathEncoding)
108}
109
/// Expands a leading `~/` in `path_str` to the value of `$HOME`.
///
/// Only the exact `~/` prefix is expanded. Other strings (including `~user/…`
/// and a `~` that is not at the start) are returned unchanged as a `PathBuf`.
/// If `HOME` is not set, the input is also returned unchanged.
///
/// `HOME` is read with `var_os`, so a home directory whose name is not valid
/// UTF-8 is still expanded. `env::var` would fail in that case and silently
/// skip the expansion.
pub fn expand_home_path(path_str: &str) -> PathBuf {
    if let Some(remainder) = path_str.strip_prefix("~/") {
        if let Some(home_dir) = std::env::var_os("HOME") {
            return PathBuf::from(home_dir).join(remainder);
        }
    }
    PathBuf::from(path_str)
}
119
/// Expresses `to` relative to the directory `from`.
///
/// Walks up the ancestors of `from` until one is a prefix of `to`, then
/// emits one `..` per level climbed followed by the remaining suffix of `to`.
/// Identical paths yield `"."`. When `to` is an ancestor of `from`, only the
/// `..` components are produced (e.g. `".."`, not `"../"`). If no ancestor
/// of `from` is a prefix of `to`, `to` is returned unchanged.
///
/// The comparison is purely lexical; presumably both inputs should be
/// normalized in the same way (e.g. both absolute) — confirm at call sites.
pub fn relative_path(from: &Path, to: &Path) -> PathBuf {
    for (depth, base) in from.ancestors().enumerate() {
        if let Ok(suffix) = to.strip_prefix(base) {
            let suffix_is_empty = suffix.as_os_str().is_empty();
            if depth == 0 && suffix_is_empty {
                return ".".into();
            }
            let mut result: PathBuf = std::iter::repeat_n("..", depth).collect();
            // Pushing an empty path would append a trailing separator ("../").
            if !suffix_is_empty {
                result.push(suffix);
            }
            return result;
        }
    }

    // No common prefix: fall back to the original path.
    to.to_owned()
}
141
/// Lexically removes `.` components and resolves `..` components where
/// possible, without touching the filesystem (symlinks are not considered).
///
/// A `..` cancels a preceding normal component. A `..` directly after the
/// root is dropped, because the parent of the root is the root itself
/// (`/../foo` becomes `/foo`). Leading `..` components of a relative path
/// are kept (`foo/../..` becomes `..`). An empty result is returned as `"."`.
pub fn normalize_path(path: &Path) -> PathBuf {
    let mut result = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                let last = result.components().next_back();
                match last {
                    Some(Component::Normal(_)) => {
                        let popped = result.pop();
                        // The last component is Normal, so there is always something to pop.
                        assert!(popped);
                    }
                    // `..` at the root refers to the root itself.
                    Some(Component::RootDir) => {}
                    // Empty result, a leading `..`, or a bare prefix: keep the `..`.
                    _ => result.push(component),
                }
            }
            _ => result.push(component),
        }
    }

    if result.as_os_str().is_empty() {
        ".".into()
    } else {
        result
    }
}
167
168pub fn slash_path(path: &Path) -> Cow<'_, Path> {
173 if cfg!(windows) {
174 Cow::Owned(to_slash_separated(path).into())
175 } else {
176 Cow::Borrowed(path)
177 }
178}
179
/// Re-joins the components of `path` using `/` as the separator.
///
/// A root component is written as `/`, and no extra separator follows a root
/// or a prefix. This keeps `/a/b` as `/a/b` (not `//a/b`). On Windows it maps
/// `C:\a` to `C:/a` and keeps the drive-relative `C:a` drive-relative.
/// Normal, `.` and `..` components are separated by a single `/`.
fn to_slash_separated(path: &Path) -> OsString {
    let mut buf = OsString::with_capacity(path.as_os_str().len());
    // True once a component that must be followed by a separator was written.
    let mut needs_separator = false;
    for component in path.components() {
        match component {
            Component::Prefix(_) => {
                buf.push(component);
                needs_separator = false;
            }
            Component::RootDir => {
                buf.push("/");
                needs_separator = false;
            }
            Component::CurDir | Component::ParentDir | Component::Normal(_) => {
                if needs_separator {
                    buf.push("/");
                }
                buf.push(component);
                needs_separator = true;
            }
        }
    }
    buf
}
193
194pub fn persist_content_addressed_temp_file<P: AsRef<Path>>(
197 temp_file: NamedTempFile,
198 new_path: P,
199) -> io::Result<File> {
200 if cfg!(windows) {
201 match temp_file.persist_noclobber(&new_path) {
205 Ok(file) => Ok(file),
206 Err(PersistError { error, file: _ }) => {
207 if let Ok(existing_file) = File::open(new_path) {
208 Ok(existing_file)
210 } else {
211 Err(error)
212 }
213 }
214 }
215 } else {
216 temp_file
220 .persist(new_path)
221 .map_err(|PersistError { error, file: _ }| error)
222 }
223}
224
225pub async fn copy_async_to_sync<R: AsyncRead, W: Write + ?Sized>(
228 reader: R,
229 writer: &mut W,
230) -> io::Result<usize> {
231 let mut buf = vec![0; 16 << 10];
232 let mut total_written_bytes = 0;
233
234 let mut reader = std::pin::pin!(reader);
235 loop {
236 let written_bytes = reader.read(&mut buf).await?;
237 if written_bytes == 0 {
238 return Ok(total_written_bytes);
239 }
240 writer.write_all(&buf[0..written_bytes])?;
241 total_written_bytes += written_bytes;
242 }
243}
244
/// Adapter that exposes a synchronous reader through the `AsyncRead` trait.
///
/// Reads are performed inline inside `poll_read`, which never returns
/// `Pending`, so the calling task blocks for the duration of each read.
pub struct BlockingAsyncReader<R> {
    /// The wrapped synchronous reader.
    reader: R,
}
251
252impl<R: Read + Unpin> BlockingAsyncReader<R> {
253 pub fn new(reader: R) -> Self {
255 Self { reader }
256 }
257}
258
259impl<R: Read + Unpin> AsyncRead for BlockingAsyncReader<R> {
260 fn poll_read(
261 mut self: Pin<&mut Self>,
262 _cx: &mut std::task::Context<'_>,
263 buf: &mut ReadBuf<'_>,
264 ) -> Poll<io::Result<()>> {
265 let num_bytes_read = self.reader.read(buf.initialize_unfilled())?;
266 buf.advance(num_bytes_read);
267 Poll::Ready(Ok(()))
268 }
269}
270
#[cfg(unix)]
mod platform {
    use std::convert::Infallible;
    use std::ffi::OsStr;
    use std::io;
    use std::os::unix::ffi::OsStrExt as _;
    use std::os::unix::fs::symlink;
    use std::path::Path;

    /// On Unix every byte sequence is a valid `OsStr`, so the conversion error
    /// type is uninhabited.
    pub type BadOsStrEncoding = Infallible;

    /// Views `data` as an `OsStr` without validation or copying.
    pub fn os_str_from_bytes(data: &[u8]) -> Result<&OsStr, BadOsStrEncoding> {
        let os_str = OsStr::from_bytes(data);
        Ok(os_str)
    }

    /// Views `data` as its underlying bytes without copying.
    pub fn os_str_to_bytes(data: &OsStr) -> Result<&[u8], BadOsStrEncoding> {
        let bytes = data.as_bytes();
        Ok(bytes)
    }

    /// Symlinks are always reported as supported on Unix.
    pub fn check_symlink_support() -> io::Result<bool> {
        Ok(true)
    }

    /// Creates a symbolic link at `link` pointing to `original`.
    pub fn try_symlink<P: AsRef<Path>, Q: AsRef<Path>>(original: P, link: Q) -> io::Result<()> {
        symlink(original, link)
    }
}
299
#[cfg(windows)]
mod platform {
    use std::io;
    use std::os::windows::fs::symlink_file;
    use std::path::Path;

    use winreg::enums::HKEY_LOCAL_MACHINE;
    use winreg::RegKey;

    pub use super::fallback::os_str_from_bytes;
    pub use super::fallback::os_str_to_bytes;
    pub use super::fallback::BadOsStrEncoding;

    /// Reports whether symlinks can be created by reading the
    /// `AllowDevelopmentWithoutDevLicense` value under the `AppModelUnlock`
    /// key in `HKEY_LOCAL_MACHINE` (presumably the Developer Mode setting).
    ///
    /// # Errors
    ///
    /// Fails if the registry key or value cannot be read.
    pub fn check_symlink_support() -> io::Result<bool> {
        let local_machine = RegKey::predef(HKEY_LOCAL_MACHINE);
        let app_model_unlock = local_machine
            .open_subkey("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\AppModelUnlock")?;
        let allow_dev: u32 = app_model_unlock.get_value("AllowDevelopmentWithoutDevLicense")?;
        Ok(allow_dev == 1)
    }

    /// Creates a file symlink at `link` pointing to `original`.
    ///
    /// Always uses `symlink_file`; directory targets are not treated
    /// specially.
    pub fn try_symlink<P: AsRef<Path>, Q: AsRef<Path>>(original: P, link: Q) -> io::Result<()> {
        symlink_file(original, link)
    }
}
333
334#[cfg_attr(unix, allow(dead_code))]
335mod fallback {
336 use std::ffi::OsStr;
337 use std::str;
338
339 use thiserror::Error;
340
341 #[derive(Debug, Error)]
343 #[error("Invalid UTF-8 sequence")]
344 pub struct BadOsStrEncoding;
345
346 pub fn os_str_from_bytes(data: &[u8]) -> Result<&OsStr, BadOsStrEncoding> {
347 Ok(str::from_utf8(data).map_err(|_| BadOsStrEncoding)?.as_ref())
348 }
349
350 pub fn os_str_to_bytes(data: &OsStr) -> Result<&[u8], BadOsStrEncoding> {
351 Ok(data.to_str().ok_or(BadOsStrEncoding)?.as_ref())
352 }
353}
354
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::io::Write as _;

    use itertools::Itertools as _;
    use pollster::FutureExt as _;
    use test_case::test_case;

    use super::*;
    use crate::tests::new_temp_dir;

    /// Converts `bytes` to a path and back, asserting the bytes are unchanged.
    fn assert_path_bytes_roundtrip(bytes: &[u8]) {
        let path = path_from_bytes(bytes).unwrap();
        assert_eq!(path_to_bytes(path).unwrap(), bytes);
    }

    /// Creates a temp file in `dir` that already contains `contents`.
    fn temp_file_with(dir: &Path, contents: &[u8]) -> NamedTempFile {
        let mut temp_file = NamedTempFile::new_in(dir).unwrap();
        temp_file.write_all(contents).unwrap();
        temp_file
    }

    #[test]
    fn test_path_bytes_roundtrip() {
        assert_path_bytes_roundtrip(b"ascii");
        assert_path_bytes_roundtrip(b"utf-8.\xc3\xa0");

        // Not valid UTF-8: accepted on Unix, rejected elsewhere.
        let latin1 = b"latin1.\xe0";
        if cfg!(unix) {
            assert_path_bytes_roundtrip(latin1);
        } else {
            assert!(path_from_bytes(latin1).is_err());
        }
    }

    #[test]
    fn normalize_too_many_dot_dot() {
        let cases = [
            ("foo/..", "."),
            ("foo/../..", ".."),
            ("foo/../../..", "../.."),
            ("foo/../../../bar/baz/..", "../../bar"),
        ];
        for (input, expected) in cases {
            assert_eq!(
                normalize_path(Path::new(input)),
                Path::new(expected),
                "input: {input}"
            );
        }
    }

    #[test]
    fn test_slash_path() {
        // Chooses the expectation depending on whether separators are rewritten.
        let expect = |on_windows: &'static str, elsewhere: &'static str| {
            Path::new(if cfg!(windows) { on_windows } else { elsewhere })
        };

        assert_eq!(slash_path(Path::new("")), Path::new(""));
        assert_eq!(slash_path(Path::new("foo")), Path::new("foo"));
        assert_eq!(slash_path(Path::new("foo/bar")), Path::new("foo/bar"));
        assert_eq!(slash_path(Path::new("foo/bar/..")), Path::new("foo/bar/.."));
        assert_eq!(
            slash_path(Path::new(r"foo\bar")),
            expect("foo/bar", r"foo\bar")
        );
        assert_eq!(
            slash_path(Path::new(r"..\foo\bar")),
            expect("../foo/bar", r"..\foo\bar")
        );
    }

    #[test]
    fn test_persist_no_existing_file() {
        let temp_dir = new_temp_dir();
        let target = temp_dir.path().join("file");
        let temp_file = temp_file_with(temp_dir.path(), b"contents");
        assert!(persist_content_addressed_temp_file(temp_file, target).is_ok());
    }

    #[test_case(false ; "existing file open")]
    #[test_case(true ; "existing file closed")]
    fn test_persist_target_exists(existing_file_closed: bool) {
        let temp_dir = new_temp_dir();
        let target = temp_dir.path().join("file");
        let temp_file = temp_file_with(temp_dir.path(), b"contents");

        let mut existing = File::create(&target).unwrap();
        existing.write_all(b"contents").unwrap();
        if existing_file_closed {
            drop(existing);
        }

        assert!(persist_content_addressed_temp_file(temp_file, &target).is_ok());
    }

    #[test]
    fn test_copy_async_to_sync_small() {
        let input = b"hello";
        let mut output = vec![];

        let copied = copy_async_to_sync(Cursor::new(&input), &mut output)
            .block_on()
            .expect("copy should succeed");
        assert_eq!(copied, 5);
        assert_eq!(output, input);
    }

    #[test]
    fn test_copy_async_to_sync_large() {
        // Bigger than the 16 KiB copy buffer, so several reads are required.
        let input = (0..100u8).cycle().take(40000).collect_vec();
        let mut output = vec![];

        let copied = copy_async_to_sync(Cursor::new(&input), &mut output)
            .block_on()
            .expect("copy should succeed");
        assert_eq!(copied, 40000);
        assert_eq!(output, input);
    }

    #[test]
    fn test_blocking_async_reader() {
        let input = b"hello";
        let mut async_reader = BlockingAsyncReader::new(Cursor::new(&input));

        let mut buf = [0u8; 3];
        let first = async_reader.read(&mut buf).block_on().unwrap();
        assert_eq!(first, 3);
        assert_eq!(&buf, &input[0..3]);

        let second = async_reader.read(&mut buf).block_on().unwrap();
        assert_eq!(second, 2);
        assert_eq!(&buf[0..2], &input[3..5]);
    }

    #[test]
    fn test_blocking_async_reader_read_to_end() {
        let input = b"hello";
        let mut async_reader = BlockingAsyncReader::new(Cursor::new(&input));

        let mut collected = vec![];
        let num_bytes_read = async_reader
            .read_to_end(&mut collected)
            .block_on()
            .unwrap();
        assert_eq!(num_bytes_read, input.len());
        assert_eq!(&collected, &input);
    }
}
500}