1use std::{
19 io,
20 path::{Path, PathBuf},
21};
22
23use thiserror::Error;
24use tokio::{sync::mpsc, task::JoinHandle};
25use walkdir::{DirEntry, IntoIter, WalkDir};
26
27use super::SymlinkPolicy;
28use crate::name::NameValidation;
29
/// Maximum number of entries the traversal producer collects into one batch
/// before handing it to the consumer.
pub(crate) const DIRECTORY_TRAVERSAL_BATCH_ENTRIES: usize = 256;
/// Capacity, counted in batches, of the channel that connects the traversal
/// task to the consumer of the stream.
const DIRECTORY_TRAVERSAL_BUFFER_BATCHES: usize = 1;
39
/// One filesystem entry discovered while walking the source directory.
#[derive(Debug)]
pub(crate) struct TraversalEntry {
    /// Filesystem path of the entry, as reported by the directory walker.
    pub(crate) source: PathBuf,
    /// Path of the entry inside the archive: the source root's file name,
    /// followed by the entry's components relative to the root, joined with `/`.
    pub(crate) archive_path: String,
    /// Kind of filesystem object this entry refers to.
    pub(crate) kind: TraversalKind,
}
50
/// Filesystem object kinds that the traversal can emit.
#[derive(Debug)]
pub(crate) enum TraversalKind {
    /// The entry is a directory.
    Directory,
    /// The entry is a regular file.
    Regular,
    /// The entry is a symbolic link; `target` is the link target as read from
    /// the filesystem, which must be valid UTF-8 and pass name validation.
    SymbolicLink { target: String },
}
61
/// Consumer-side handle to a running directory traversal.
///
/// Batches of entries arrive through `entries`; `task` is the spawned task
/// that produces them and carries the traversal's final result.
pub(crate) struct TraversalStream {
    /// Receiving end of the bounded channel carrying batches of entries.
    entries: mpsc::Receiver<Vec<TraversalEntry>>,
    /// Handle of the spawned task that drives the traversal.
    task: JoinHandle<Result<(), TraversalError>>,
}
70
71impl TraversalStream {
72 pub(crate) async fn recv(&mut self) -> Option<Vec<TraversalEntry>> {
74 self.entries.recv().await
75 }
76
77 pub(crate) async fn finish(self) -> Result<(), TraversalError> {
82 drop(self.entries);
83 self.task.await?
84 }
85}
86
/// Errors produced while turning a source directory into archive entries.
#[derive(Debug, Error)]
pub enum TraversalError {
    /// An entry could not be mapped to a path inside the archive.
    #[error("invalid archive path {path:?}: {reason}")]
    InvalidArchivePath {
        /// Filesystem path that could not be mapped.
        path: PathBuf,
        /// Short explanation of why the path is invalid.
        reason: &'static str,
    },
    /// A generated name was not accepted by the configured name validation.
    #[error("archive {context} rejected by builder policy: {value:?}")]
    NameRejected {
        /// What the name was used for (for example `"member path"`).
        context: &'static str,
        /// The rejected name.
        value: String,
    },
    /// A source path, or one of its components, is not valid UTF-8.
    #[error("source path is not valid UTF-8: {path}")]
    NonUtf8SourcePath {
        /// The offending filesystem path.
        path: PathBuf,
    },
    /// The target of a symbolic link is not valid UTF-8.
    #[error("symbolic-link target is not valid UTF-8: {path}")]
    NonUtf8LinkTarget {
        /// Path of the symbolic link whose target could not be converted.
        path: PathBuf,
    },
    /// The traversal root turned out not to be a directory.
    #[error("source directory is not a real directory: {path}")]
    SourceNotDirectory {
        /// The traversal root that was rejected.
        path: PathBuf,
    },
    /// A symbolic link was found while the symlink policy rejects links.
    #[error("symbolic link rejected by builder policy: {path}")]
    SymbolicLinkRejected {
        /// Path of the rejected symbolic link.
        path: PathBuf,
    },
    /// The entry is neither a directory, a regular file, nor a symbolic link.
    #[error("unsupported filesystem entry type: {path}")]
    UnsupportedFilesystemType {
        /// Path of the unsupported entry.
        path: PathBuf,
    },
    /// A filesystem operation failed.
    #[error("failed to {operation} {path}: {source}")]
    Filesystem {
        /// Verb phrase describing the operation that failed.
        operation: &'static str,
        /// Path the operation was applied to.
        path: PathBuf,
        /// Underlying I/O error.
        #[source]
        source: io::Error,
    },
    /// A spawned traversal task could not be joined.
    #[error("failed to complete blocking directory traversal: {0}")]
    BlockingTask(#[from] tokio::task::JoinError),
}
151
152pub(crate) fn stream_directory_entries(
157 source: PathBuf,
158 validation: NameValidation,
159 symlink_policy: SymlinkPolicy,
160) -> Result<TraversalStream, TraversalError> {
161 let Some(archive_path) = source
162 .file_name()
163 .and_then(|name| name.to_str())
164 .map(str::to_owned)
165 else {
166 return Err(TraversalError::NonUtf8SourcePath {
167 path: source.to_path_buf(),
168 });
169 };
170 validate_name(&archive_path, validation, "member path")?;
171 let (sender, receiver) = mpsc::channel(DIRECTORY_TRAVERSAL_BUFFER_BATCHES);
172 let task = tokio::spawn(async move {
175 let mut producer = TraversalProducer::new(source, archive_path, validation, symlink_policy);
176 loop {
177 let (next_producer, entries) =
178 tokio::task::spawn_blocking(move || producer.next_batch()).await??;
179 producer = next_producer;
180 let Some(entries) = entries else {
181 return Ok(());
182 };
183 if sender.send(entries).await.is_err() {
184 return Ok(());
185 }
186 }
187 });
188 Ok(TraversalStream {
189 entries: receiver,
190 task,
191 })
192}
193
/// State of an in-progress directory walk that yields entries batch by batch.
struct TraversalProducer {
    /// Root directory being traversed.
    source: PathBuf,
    /// Archive path assigned to the root directory itself.
    archive_path: String,
    /// Name validation applied to generated archive paths and link targets.
    validation: NameValidation,
    /// Policy deciding how symbolic links are treated.
    symlink_policy: SymlinkPolicy,
    /// Directory walker positioned at the next entry that has not been emitted.
    entries: IntoIter,
}
202
203impl TraversalProducer {
204 fn new(
205 source: PathBuf,
206 archive_path: String,
207 validation: NameValidation,
208 symlink_policy: SymlinkPolicy,
209 ) -> Self {
210 let entries = WalkDir::new(&source)
211 .follow_links(false)
212 .follow_root_links(false)
213 .sort_by_file_name()
214 .into_iter();
215 Self {
216 source,
217 archive_path,
218 validation,
219 symlink_policy,
220 entries,
221 }
222 }
223
224 fn next_batch(mut self) -> Result<(Self, Option<Vec<TraversalEntry>>), TraversalError> {
225 let mut entries = Vec::with_capacity(DIRECTORY_TRAVERSAL_BATCH_ENTRIES);
226 while entries.len() < DIRECTORY_TRAVERSAL_BATCH_ENTRIES {
227 let Some(entry) = self.entries.next() else {
228 break;
229 };
230 let entry = entry.map_err(|error| {
231 let path = error.path().unwrap_or(&self.source).to_path_buf();
232 TraversalError::Filesystem {
233 operation: "traverse source directory",
234 path,
235 source: error.into(),
236 }
237 })?;
238 entries.push(traversal_entry(
239 &self.source,
240 &self.archive_path,
241 self.validation,
242 self.symlink_policy,
243 entry,
244 )?);
245 }
246 let entries = if entries.is_empty() {
247 None
248 } else {
249 Some(entries)
250 };
251 Ok((self, entries))
252 }
253}
254
255fn traversal_entry(
260 source: &Path,
261 archive_path: &str,
262 validation: NameValidation,
263 symlink_policy: SymlinkPolicy,
264 entry: DirEntry,
265) -> Result<TraversalEntry, TraversalError> {
266 let path = entry.path();
267 let file_type = entry.file_type();
268 let kind = if file_type.is_dir() {
269 TraversalKind::Directory
270 } else if file_type.is_file() {
271 TraversalKind::Regular
272 } else if file_type.is_symlink() {
273 if symlink_policy == SymlinkPolicy::Reject {
274 return Err(TraversalError::SymbolicLinkRejected {
275 path: path.to_path_buf(),
276 });
277 }
278 let target = std::fs::read_link(path).map_err(|source| TraversalError::Filesystem {
279 operation: "read symbolic link",
280 path: path.to_path_buf(),
281 source,
282 })?;
283 let Some(target) = target.to_str().map(str::to_owned) else {
284 return Err(TraversalError::NonUtf8LinkTarget {
285 path: path.to_path_buf(),
286 });
287 };
288 validate_name(&target, validation, "symbolic-link target")?;
289 TraversalKind::SymbolicLink { target }
290 } else {
291 return Err(TraversalError::UnsupportedFilesystemType {
292 path: path.to_path_buf(),
293 });
294 };
295 if entry.depth() == 0 && !matches!(&kind, TraversalKind::Directory) {
296 return Err(TraversalError::SourceNotDirectory {
297 path: source.to_path_buf(),
298 });
299 }
300 let relative = path
301 .strip_prefix(source)
302 .map_err(|_| TraversalError::InvalidArchivePath {
303 path: path.to_path_buf(),
304 reason: "source entry is outside recursive root",
305 })?;
306 let archive_path = if relative.as_os_str().is_empty() {
307 archive_path.to_owned()
308 } else {
309 join_archive_path(archive_path, relative, path, validation)?
310 };
311 Ok(TraversalEntry {
312 source: entry.into_path(),
313 archive_path,
314 kind,
315 })
316}
317
318fn join_archive_path(
319 archive_path: &str,
320 relative: &Path,
321 source_path: &Path,
322 validation: NameValidation,
323) -> Result<String, TraversalError> {
324 let mut joined = archive_path.to_owned();
325 for component in relative {
326 let Some(component) = component.to_str() else {
327 return Err(TraversalError::NonUtf8SourcePath {
328 path: source_path.to_path_buf(),
329 });
330 };
331 joined.push('/');
332 joined.push_str(component);
333 }
334 validate_name(&joined, validation, "member path")?;
335 Ok(joined)
336}
337
338fn validate_name(
339 name: &str,
340 validation: NameValidation,
341 context: &'static str,
342) -> Result<(), TraversalError> {
343 if validation.accepts(name) {
344 Ok(())
345 } else {
346 Err(TraversalError::NameRejected {
347 context,
348 value: name.to_owned(),
349 })
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Components of a native relative path are joined with `/`, regardless
    /// of the platform's own separator.
    #[test]
    fn joins_native_relative_paths_with_archive_separators() {
        let relative = Path::new("nested").join("file");
        let joined = join_archive_path("tree", &relative, &relative, NameValidation::Default);
        assert!(matches!(joined.as_deref(), Ok("tree/nested/file")));
    }

    /// On Unix a backslash is an ordinary file-name character, so it must be
    /// kept inside the component instead of being treated as a separator.
    #[cfg(unix)]
    #[test]
    fn preserves_backslashes_in_source_path_components() {
        let relative = Path::new("nested\\file");
        let joined = join_archive_path("tree", relative, relative, NameValidation::Default);
        assert!(matches!(joined.as_deref(), Ok(r"tree/nested\file")));
    }
}