1use std::{
19 io,
20 path::{Path, PathBuf},
21};
22
23use thiserror::Error;
24use tokio::{sync::mpsc, task::JoinHandle};
25use walkdir::{DirEntry, IntoIter, WalkDir};
26
27use super::SymlinkPolicy;
28use crate::name::NameValidation;
29
/// Maximum number of entries the traversal producer places in a single batch.
pub(crate) const DIRECTORY_TRAVERSAL_BATCH_ENTRIES: usize = 256;
/// Capacity, counted in batches, of the channel between the traversal task and its consumer.
const DIRECTORY_TRAVERSAL_BUFFER_BATCHES: usize = 1;
39
/// One filesystem entry discovered while walking a source directory.
#[derive(Debug)]
pub(crate) struct TraversalEntry {
    /// Filesystem path of the entry, as reported by the directory walker.
    pub(crate) source: PathBuf,
    /// Path of the entry inside the archive; components are joined with `/`.
    pub(crate) archive_path: String,
    /// Classification of the entry (directory, regular file, or symbolic link).
    pub(crate) kind: TraversalKind,
}
50
/// Filesystem entry types that the traversal can represent.
#[derive(Debug)]
pub(crate) enum TraversalKind {
    /// A directory.
    Directory,
    /// A regular file.
    Regular,
    /// A symbolic link; `target` is the link target read with `std::fs::read_link`,
    /// converted to UTF-8 and checked against the name validation policy.
    SymbolicLink { target: String },
}
61
/// Consumer side of a background directory traversal started by
/// [`stream_directory_entries`].
pub(crate) struct TraversalStream {
    /// Channel on which the traversal task delivers batches of entries.
    entries: mpsc::Receiver<Vec<TraversalEntry>>,
    /// Handle of the spawned task that produces the batches and carries its final result.
    task: JoinHandle<Result<(), TraversalError>>,
}
70
71impl TraversalStream {
72 pub(crate) async fn recv(&mut self) -> Option<Vec<TraversalEntry>> {
74 self.entries.recv().await
75 }
76
77 pub(crate) async fn finish(self) -> Result<(), TraversalError> {
82 drop(self.entries);
83 self.task.await?
84 }
85}
86
/// Errors reported while turning a source directory into archive entries.
#[derive(Debug, Error)]
pub enum TraversalError {
    /// A source path could not be mapped to a path inside the archive.
    #[error("invalid archive path {path:?}: {reason}")]
    InvalidArchivePath {
        /// Source path that could not be mapped.
        path: PathBuf,
        /// Static explanation of why the path was refused.
        reason: &'static str,
    },
    /// A name was refused by the configured `NameValidation`.
    #[error("archive {context} rejected by builder policy: {value:?}")]
    NameRejected {
        /// Role of the refused name, e.g. `"member path"` or `"symbolic-link target"`.
        context: &'static str,
        /// The refused name.
        value: String,
    },
    /// A source path could not be converted to UTF-8 for use as an archive path.
    #[error("source path is not valid UTF-8: {path}")]
    NonUtf8SourcePath {
        /// Source path that failed the conversion.
        path: PathBuf,
    },
    /// The target of a symbolic link could not be converted to UTF-8.
    #[error("symbolic-link target is not valid UTF-8: {path}")]
    NonUtf8LinkTarget {
        /// Path of the symbolic link whose target failed the conversion.
        path: PathBuf,
    },
    /// The traversal root (walk depth 0) was not classified as a directory.
    #[error("source directory is not a real directory: {path}")]
    SourceNotDirectory {
        /// Path of the traversal root.
        path: PathBuf,
    },
    /// A symbolic link was found while the symlink policy is `SymlinkPolicy::Reject`.
    #[error("symbolic link rejected by builder policy: {path}")]
    SymbolicLinkRejected {
        /// Path of the symbolic link.
        path: PathBuf,
    },
    /// An entry is neither a directory, a regular file, nor a symbolic link.
    #[error("unsupported filesystem entry type: {path}")]
    UnsupportedFilesystemType {
        /// Path of the unsupported entry.
        path: PathBuf,
    },
    /// A filesystem operation failed.
    #[error("failed to {operation} {path}: {source}")]
    Filesystem {
        /// Short description of the operation, interpolated after "failed to".
        operation: &'static str,
        /// Path the operation was applied to.
        path: PathBuf,
        /// Underlying I/O error.
        #[source]
        source: io::Error,
    },
    /// A spawned traversal task could not be joined.
    #[error("failed to complete blocking directory traversal: {0}")]
    BlockingTask(#[from] tokio::task::JoinError),
}
151
152pub(crate) fn stream_directory_entries(
157 source: PathBuf,
158 validation: NameValidation,
159 symlink_policy: SymlinkPolicy,
160) -> Result<TraversalStream, TraversalError> {
161 let Some(archive_path) = source
162 .file_name()
163 .and_then(|name| name.to_str())
164 .map(str::to_owned)
165 else {
166 return Err(TraversalError::NonUtf8SourcePath {
167 path: source.to_path_buf(),
168 });
169 };
170 validate_name(&archive_path, validation, "member path")?;
171 let (sender, receiver) = mpsc::channel(DIRECTORY_TRAVERSAL_BUFFER_BATCHES);
172 let task = tokio::spawn(async move {
175 let mut producer = TraversalProducer::new(source, archive_path, validation, symlink_policy);
176 loop {
177 let (next_producer, entries) =
178 tokio::task::spawn_blocking(move || producer.next_batch()).await??;
179 producer = next_producer;
180 let Some(entries) = entries else {
181 return Ok(());
182 };
183 if sender.send(entries).await.is_err() {
184 return Ok(());
185 }
186 }
187 });
188 Ok(TraversalStream {
189 entries: receiver,
190 task,
191 })
192}
193
/// State of an in-progress directory walk, moved between blocking batch computations.
struct TraversalProducer {
    /// Root directory being walked.
    source: PathBuf,
    /// Archive path assigned to the root directory.
    archive_path: String,
    /// Policy applied to archive member paths and symbolic-link targets.
    validation: NameValidation,
    /// Policy deciding whether symbolic links are accepted.
    symlink_policy: SymlinkPolicy,
    /// Walker over `source`, configured not to follow links and to sort by file name.
    entries: IntoIter,
}
202
203impl TraversalProducer {
204 fn new(
205 source: PathBuf,
206 archive_path: String,
207 validation: NameValidation,
208 symlink_policy: SymlinkPolicy,
209 ) -> Self {
210 let entries = WalkDir::new(&source)
211 .follow_links(false)
212 .follow_root_links(false)
213 .sort_by_file_name()
214 .into_iter();
215 Self {
216 source,
217 archive_path,
218 validation,
219 symlink_policy,
220 entries,
221 }
222 }
223
224 fn next_batch(mut self) -> Result<(Self, Option<Vec<TraversalEntry>>), TraversalError> {
225 let mut entries = Vec::with_capacity(DIRECTORY_TRAVERSAL_BATCH_ENTRIES);
226 while entries.len() < DIRECTORY_TRAVERSAL_BATCH_ENTRIES {
227 let Some(entry) = self.entries.next() else {
228 break;
229 };
230 let entry = entry.map_err(|error| {
231 let path = error.path().unwrap_or(&self.source).to_path_buf();
232 filesystem_error("traverse source directory", &path, error.into())
233 })?;
234 entries.push(traversal_entry(
235 &self.source,
236 &self.archive_path,
237 self.validation,
238 self.symlink_policy,
239 entry,
240 )?);
241 }
242 let entries = if entries.is_empty() {
243 None
244 } else {
245 Some(entries)
246 };
247 Ok((self, entries))
248 }
249}
250
251fn traversal_entry(
256 source: &Path,
257 archive_path: &str,
258 validation: NameValidation,
259 symlink_policy: SymlinkPolicy,
260 entry: DirEntry,
261) -> Result<TraversalEntry, TraversalError> {
262 let path = entry.path();
263 let file_type = entry.file_type();
264 let kind = if file_type.is_dir() {
265 TraversalKind::Directory
266 } else if file_type.is_file() {
267 TraversalKind::Regular
268 } else if file_type.is_symlink() {
269 if symlink_policy == SymlinkPolicy::Reject {
270 return Err(TraversalError::SymbolicLinkRejected {
271 path: path.to_path_buf(),
272 });
273 }
274 let target = std::fs::read_link(path)
275 .map_err(|error| filesystem_error("read symbolic link", path, error))?;
276 let Some(target) = target.to_str().map(str::to_owned) else {
277 return Err(TraversalError::NonUtf8LinkTarget {
278 path: path.to_path_buf(),
279 });
280 };
281 validate_name(&target, validation, "symbolic-link target")?;
282 TraversalKind::SymbolicLink { target }
283 } else {
284 return Err(TraversalError::UnsupportedFilesystemType {
285 path: path.to_path_buf(),
286 });
287 };
288 if entry.depth() == 0 && !matches!(&kind, TraversalKind::Directory) {
289 return Err(TraversalError::SourceNotDirectory {
290 path: source.to_path_buf(),
291 });
292 }
293 let relative = path
294 .strip_prefix(source)
295 .map_err(|_| TraversalError::InvalidArchivePath {
296 path: path.to_path_buf(),
297 reason: "source entry is outside recursive root",
298 })?;
299 let archive_path = if relative.as_os_str().is_empty() {
300 archive_path.to_owned()
301 } else {
302 join_archive_path(archive_path, relative, path, validation)?
303 };
304 Ok(TraversalEntry {
305 source: entry.into_path(),
306 archive_path,
307 kind,
308 })
309}
310
311fn join_archive_path(
312 archive_path: &str,
313 relative: &Path,
314 source_path: &Path,
315 validation: NameValidation,
316) -> Result<String, TraversalError> {
317 let mut joined = archive_path.to_owned();
318 for component in relative {
319 let Some(component) = component.to_str() else {
320 return Err(TraversalError::NonUtf8SourcePath {
321 path: source_path.to_path_buf(),
322 });
323 };
324 joined.push('/');
325 joined.push_str(component);
326 }
327 validate_name(&joined, validation, "member path")?;
328 Ok(joined)
329}
330
331fn validate_name(
332 name: &str,
333 validation: NameValidation,
334 context: &'static str,
335) -> Result<(), TraversalError> {
336 if validation.accepts(name) {
337 Ok(())
338 } else {
339 Err(TraversalError::NameRejected {
340 context,
341 value: name.to_owned(),
342 })
343 }
344}
345
346fn filesystem_error(operation: &'static str, path: &Path, source: io::Error) -> TraversalError {
347 TraversalError::Filesystem {
348 operation,
349 path: path.to_path_buf(),
350 source,
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Joins `relative` under the archive root `"tree"` with default validation,
    /// returning `None` when joining fails.
    fn joined_under_tree(relative: &Path) -> Option<String> {
        join_archive_path("tree", relative, relative, NameValidation::Default).ok()
    }

    /// Native path separators are replaced by `/` in the archive path.
    #[test]
    fn joins_native_relative_paths_with_archive_separators() {
        let relative = Path::new("nested").join("file");
        assert_eq!(
            joined_under_tree(&relative).as_deref(),
            Some("tree/nested/file")
        );
    }

    /// On Unix a backslash is an ordinary file-name character and must survive as-is.
    #[cfg(unix)]
    #[test]
    fn preserves_backslashes_in_source_path_components() {
        let relative = Path::new("nested\\file");
        assert_eq!(
            joined_under_tree(relative).as_deref(),
            Some(r"tree/nested\file")
        );
    }
}