1#![allow(missing_docs)]
16
17use std::fs;
18use std::fs::File;
19use std::io;
20use std::io::Read;
21use std::io::Write;
22use std::path::Component;
23use std::path::Path;
24use std::path::PathBuf;
25use std::pin::Pin;
26use std::task::Poll;
27
28use tempfile::NamedTempFile;
29use tempfile::PersistError;
30use thiserror::Error;
31use tokio::io::AsyncRead;
32use tokio::io::AsyncReadExt as _;
33use tokio::io::ReadBuf;
34
35pub use self::platform::*;
36
/// Error produced when an I/O operation on a specific path fails.
///
/// Pairs the underlying [`io::Error`] (reported as this error's source) with
/// the path that was being accessed when it occurred.
#[derive(Debug, Error)]
#[error("Cannot access {path}")]
pub struct PathError {
    /// The path whose access failed.
    pub path: PathBuf,
    /// The underlying I/O error.
    #[source]
    pub error: io::Error,
}
44
/// Extension trait that attaches the accessed path to an I/O error.
pub trait IoResultExt<T> {
    /// Converts an error into a [`PathError`] recording `path`; `Ok` values
    /// pass through unchanged.
    fn context(self, path: impl AsRef<Path>) -> Result<T, PathError>;
}
48
49impl<T> IoResultExt<T> for io::Result<T> {
50 fn context(self, path: impl AsRef<Path>) -> Result<T, PathError> {
51 self.map_err(|error| PathError {
52 path: path.as_ref().to_path_buf(),
53 error,
54 })
55 }
56}
57
/// Creates the directory `dirname`, like [`fs::create_dir`], but also
/// succeeds when `dirname` already exists as a directory.
///
/// # Errors
///
/// If creation fails and `dirname` is not a directory (for example it is a
/// regular file, or its parent does not exist), the error from
/// `fs::create_dir` is returned unchanged.
pub fn create_or_reuse_dir(dirname: &Path) -> io::Result<()> {
    // The directory check runs only after a failed create, so a directory
    // created concurrently by someone else is also accepted.
    fs::create_dir(dirname).or_else(|err| {
        if dirname.is_dir() {
            Ok(())
        } else {
            Err(err)
        }
    })
}
70
71pub fn remove_dir_contents(dirname: &Path) -> Result<(), PathError> {
75 for entry in dirname.read_dir().context(dirname)? {
76 let entry = entry.context(dirname)?;
77 let path = entry.path();
78 fs::remove_file(&path).context(&path)?;
79 }
80 Ok(())
81}
82
/// Expands a leading `~/` in `path_str` to the value of `$HOME`.
///
/// The input is returned unchanged as a [`PathBuf`] when it does not start
/// with `~/` (a bare `~` or `~user` is not expanded), or when `$HOME` cannot
/// be read.
pub fn expand_home_path(path_str: &str) -> PathBuf {
    path_str
        .strip_prefix("~/")
        .and_then(|rest| {
            std::env::var("HOME")
                .ok()
                .map(|home| PathBuf::from(home).join(rest))
        })
        .unwrap_or_else(|| PathBuf::from(path_str))
}
92
/// Expresses `to` as a path relative to `from`.
///
/// The comparison is purely lexical (component-wise [`Path::strip_prefix`]).
/// Symlinks and `..` are not resolved, so both paths should be normalized in
/// the same way (e.g. both absolute) for the result to be meaningful.
///
/// Returns `"."` when the paths are equal, and `".."` components (without a
/// trailing separator) when `to` is an ancestor of `from`. If no ancestor of
/// `from` (including `from` itself) is a prefix of `to`, `to` is returned
/// unchanged.
pub fn relative_path(from: &Path, to: &Path) -> PathBuf {
    // Walk up from `from` until reaching an ancestor that is a prefix of `to`;
    // `depth` is how many `..` steps that ancestor is above `from`.
    for (depth, base) in from.ancestors().enumerate() {
        let Ok(suffix) = to.strip_prefix(base) else {
            continue;
        };
        if depth == 0 && suffix.as_os_str().is_empty() {
            return ".".into();
        }
        let mut result = PathBuf::from_iter(std::iter::repeat_n("..", depth));
        // `PathBuf::push("")` appends a trailing separator ("../"), so only
        // push when there is something left to descend into.
        if !suffix.as_os_str().is_empty() {
            result.push(suffix);
        }
        return result;
    }

    // No common ancestor: return the original path.
    to.to_owned()
}
114
/// Lexically simplifies `path` without touching the filesystem.
///
/// `.` components are dropped, and each `..` cancels the normal component
/// right before it. A `..` that has nothing to cancel is kept, whether it is
/// leading or follows a root/prefix. Symlinks are not consulted. An empty
/// result is returned as `"."`.
pub fn normalize_path(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => continue,
            Component::ParentDir => {
                let ends_with_name = matches!(
                    normalized.components().next_back(),
                    Some(Component::Normal(_))
                );
                if ends_with_name {
                    // The last component is a normal name, so there is
                    // always a parent to pop back to.
                    let removed = normalized.pop();
                    assert!(removed);
                } else {
                    normalized.push(component);
                }
            }
            other => normalized.push(other),
        }
    }

    if normalized.as_os_str().is_empty() {
        PathBuf::from(".")
    } else {
        normalized
    }
}
140
141pub fn persist_content_addressed_temp_file<P: AsRef<Path>>(
144 temp_file: NamedTempFile,
145 new_path: P,
146) -> io::Result<File> {
147 if cfg!(windows) {
148 match temp_file.persist_noclobber(&new_path) {
152 Ok(file) => Ok(file),
153 Err(PersistError { error, file: _ }) => {
154 if let Ok(existing_file) = File::open(new_path) {
155 Ok(existing_file)
157 } else {
158 Err(error)
159 }
160 }
161 }
162 } else {
163 temp_file
167 .persist(new_path)
168 .map_err(|PersistError { error, file: _ }| error)
169 }
170}
171
172pub async fn copy_async_to_sync<R: AsyncRead, W: Write + ?Sized>(
175 reader: R,
176 writer: &mut W,
177) -> io::Result<usize> {
178 let mut buf = vec![0; 16 << 10];
179 let mut total_written_bytes = 0;
180
181 let mut reader = std::pin::pin!(reader);
182 loop {
183 let written_bytes = reader.read(&mut buf).await?;
184 if written_bytes == 0 {
185 return Ok(total_written_bytes);
186 }
187 writer.write_all(&buf[0..written_bytes])?;
188 total_written_bytes += written_bytes;
189 }
190}
191
/// Adapts a synchronous [`Read`] into a tokio [`AsyncRead`].
///
/// Each poll performs the blocking read inline and completes immediately. It
/// never returns `Poll::Pending`, so the polling task or thread is blocked for
/// the duration of the read.
pub struct BlockingAsyncReader<R> {
    // The wrapped synchronous reader.
    reader: R,
}
198
199impl<R: Read + Unpin> BlockingAsyncReader<R> {
200 pub fn new(reader: R) -> Self {
202 Self { reader }
203 }
204}
205
206impl<R: Read + Unpin> AsyncRead for BlockingAsyncReader<R> {
207 fn poll_read(
208 mut self: Pin<&mut Self>,
209 _cx: &mut std::task::Context<'_>,
210 buf: &mut ReadBuf<'_>,
211 ) -> Poll<io::Result<()>> {
212 let num_bytes_read = self.reader.read(buf.initialize_unfilled())?;
213 buf.advance(num_bytes_read);
214 Poll::Ready(Ok(()))
215 }
216}
217
#[cfg(unix)]
mod platform {
    use std::io;
    use std::path::Path;

    /// Reports whether symlinks can be created. This is always `true` on Unix.
    pub fn check_symlink_support() -> io::Result<bool> {
        Ok(true)
    }

    /// Creates a symbolic link at `link` that points to `original`.
    pub fn try_symlink<P: AsRef<Path>, Q: AsRef<Path>>(original: P, link: Q) -> io::Result<()> {
        std::os::unix::fs::symlink(original, link)
    }
}
233
#[cfg(windows)]
mod platform {
    use std::io;
    use std::os::windows::fs::symlink_file;
    use std::path::Path;

    use winreg::enums::HKEY_LOCAL_MACHINE;
    use winreg::RegKey;

    /// Reports whether symlinks can be created on this machine.
    ///
    /// Reads `AllowDevelopmentWithoutDevLicense` under `AppModelUnlock` in
    /// `HKEY_LOCAL_MACHINE` and returns `true` only when it equals 1.
    /// NOTE(review): this value presumably reflects Windows Developer Mode;
    /// confirm.
    ///
    /// # Errors
    ///
    /// Propagates the registry error if the key or value cannot be read.
    pub fn check_symlink_support() -> io::Result<bool> {
        let developer_mode: u32 = RegKey::predef(HKEY_LOCAL_MACHINE)
            .open_subkey("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\AppModelUnlock")?
            .get_value("AllowDevelopmentWithoutDevLicense")?;
        Ok(developer_mode == 1)
    }

    /// Creates a symbolic link at `link` pointing to `original`.
    ///
    /// Always uses `symlink_file`, so a file-type link is created even if
    /// `original` is a directory.
    pub fn try_symlink<P: AsRef<Path>, Q: AsRef<Path>>(original: P, link: Q) -> io::Result<()> {
        symlink_file(original, link)
    }
}
263
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::io::Write as _;

    use itertools::Itertools as _;
    use pollster::FutureExt as _;
    use test_case::test_case;

    use super::*;
    use crate::tests::new_temp_dir;

    /// Creates a temp file inside `dir` that already contains `b"contents"`.
    fn temp_file_with_contents(dir: &Path) -> NamedTempFile {
        let mut file = NamedTempFile::new_in(dir).unwrap();
        file.write_all(b"contents").unwrap();
        file
    }

    #[test]
    fn normalize_too_many_dot_dot() {
        // `..` that has no preceding name to cancel is preserved.
        let cases = [
            ("foo/..", "."),
            ("foo/../..", ".."),
            ("foo/../../..", "../.."),
            ("foo/../../../bar/baz/..", "../../bar"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_path(Path::new(input)), Path::new(expected));
        }
    }

    #[test]
    fn test_persist_no_existing_file() {
        let temp_dir = new_temp_dir();
        let target = temp_dir.path().join("file");
        let temp_file = temp_file_with_contents(temp_dir.path());
        assert!(persist_content_addressed_temp_file(temp_file, target).is_ok());
    }

    #[test_case(false ; "existing file open")]
    #[test_case(true ; "existing file closed")]
    fn test_persist_target_exists(existing_file_closed: bool) {
        let temp_dir = new_temp_dir();
        let target = temp_dir.path().join("file");
        let temp_file = temp_file_with_contents(temp_dir.path());

        let mut existing = File::create(&target).unwrap();
        existing.write_all(b"contents").unwrap();
        // In the "open" case the handle stays alive until the end of the test.
        if existing_file_closed {
            drop(existing);
        }

        assert!(persist_content_addressed_temp_file(temp_file, &target).is_ok());
    }

    #[test]
    fn test_copy_async_to_sync_small() {
        let input = b"hello";
        let mut output = vec![];

        let copied = copy_async_to_sync(Cursor::new(&input), &mut output).block_on();
        assert!(copied.is_ok());
        assert_eq!(copied.unwrap(), 5);
        assert_eq!(output, input);
    }

    #[test]
    fn test_copy_async_to_sync_large() {
        // Larger than the 16 KiB copy buffer, so several chunks are needed.
        let input = (0..100u8).cycle().take(40000).collect_vec();
        let mut output = vec![];

        let copied = copy_async_to_sync(Cursor::new(&input), &mut output).block_on();
        assert!(copied.is_ok());
        assert_eq!(copied.unwrap(), 40000);
        assert_eq!(output, input);
    }

    #[test]
    fn test_blocking_async_reader() {
        let input = b"hello";
        let mut async_reader = BlockingAsyncReader::new(Cursor::new(&input));
        let mut buf = [0u8; 3];

        // The first read fills the whole 3-byte buffer.
        let n = async_reader.read(&mut buf).block_on().unwrap();
        assert_eq!(n, 3);
        assert_eq!(&buf, &input[0..3]);

        // The second read returns the remaining 2 bytes.
        let n = async_reader.read(&mut buf).block_on().unwrap();
        assert_eq!(n, 2);
        assert_eq!(&buf[0..2], &input[3..5]);
    }

    #[test]
    fn test_blocking_async_reader_read_to_end() {
        let input = b"hello";
        let mut async_reader = BlockingAsyncReader::new(Cursor::new(&input));

        let mut collected = vec![];
        let n = async_reader.read_to_end(&mut collected).block_on().unwrap();
        assert_eq!(n, input.len());
        assert_eq!(&collected, &input);
    }
}