use std::{
io,
path::{Path, PathBuf},
};
use thiserror::Error;
use tokio::{sync::mpsc, task::JoinHandle};
use walkdir::{DirEntry, IntoIter, WalkDir};
use super::SymlinkPolicy;
use crate::name::NameValidation;
/// Capacity, in batches, of the channel between the traversal task and its consumer.
///
/// Because the producer awaits channel space before walking further, this (plus the batch
/// currently being assembled) bounds how far the filesystem walk can run ahead.
const DIRECTORY_TRAVERSAL_BUFFER_BATCHES: usize = 1;
/// Maximum number of entries gathered per blocking-pool round trip and per channel message.
pub(crate) const DIRECTORY_TRAVERSAL_BATCH_ENTRIES: usize = 256;
/// One filesystem object discovered while walking a source directory.
#[derive(Debug)]
pub(crate) struct TraversalEntry {
    /// Filesystem path of the walked object, as yielded by the directory walker.
    pub(crate) source: PathBuf,
    /// `/`-separated archive path: the source directory's final component, followed by this
    /// object's path relative to that directory. It has already been checked against the
    /// builder's `NameValidation` as a "member path".
    pub(crate) archive_path: String,
    /// What kind of object this is, including the link target for symbolic links.
    pub(crate) kind: TraversalKind,
}
#[derive(Debug)]
pub(crate) enum TraversalKind {
Directory,
Regular,
SymbolicLink { target: String },
}
pub(crate) struct TraversalStream {
entries: mpsc::Receiver<Vec<TraversalEntry>>,
task: JoinHandle<Result<(), TraversalError>>,
}
impl TraversalStream {
pub(crate) async fn recv(&mut self) -> Option<Vec<TraversalEntry>> {
self.entries.recv().await
}
pub(crate) async fn finish(self) -> Result<(), TraversalError> {
drop(self.entries);
self.task.await?
}
}
/// Errors produced while preparing or performing a recursive directory traversal.
#[derive(Debug, Error)]
pub enum TraversalError {
    /// A path could not be mapped to an archive path; `reason` explains why.
    #[error("invalid archive path {path:?}: {reason}")]
    InvalidArchivePath {
        path: PathBuf,
        reason: &'static str,
    },
    /// The builder's `NameValidation` rejected a name. `context` says what kind of name it
    /// was (for example "member path" or "symbolic-link target"), and `value` holds the
    /// rejected text.
    #[error("archive {context} rejected by builder policy: {value:?}")]
    NameRejected {
        context: &'static str,
        value: String,
    },
    /// A source path, or a component of one, could not be used because it is not valid
    /// UTF-8.
    #[error("source path is not valid UTF-8: {path}")]
    NonUtf8SourcePath {
        path: PathBuf,
    },
    /// The target of the symbolic link at `path` is not valid UTF-8.
    #[error("symbolic-link target is not valid UTF-8: {path}")]
    NonUtf8LinkTarget {
        path: PathBuf,
    },
    /// The traversal root is not a directory. The root is not followed if it is a symlink,
    /// so a symlink to a directory also lands here.
    #[error("source directory is not a real directory: {path}")]
    SourceNotDirectory {
        path: PathBuf,
    },
    /// A symbolic link was found while the builder's symlink policy rejects links.
    #[error("symbolic link rejected by builder policy: {path}")]
    SymbolicLinkRejected {
        path: PathBuf,
    },
    /// The entry is neither a directory, a regular file, nor a symbolic link.
    #[error("unsupported filesystem entry type: {path}")]
    UnsupportedFilesystemType {
        path: PathBuf,
    },
    /// An I/O operation failed. `operation` names the attempted action, and `path` names the
    /// object involved.
    // NOTE(review): the message embeds `{source}` while `#[source]` also exposes it through
    // `Error::source`, so reporters that print the full cause chain show the I/O error twice.
    // Consider keeping only one of the two.
    #[error("failed to {operation} {path}: {source}")]
    Filesystem {
        operation: &'static str,
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// A Tokio task driving the traversal could not be joined (it panicked or was
    /// cancelled). This covers both a per-batch blocking task and the outer traversal task.
    // NOTE(review): `#[from]` already makes the `JoinError` the `source()`, so `{0}` in the
    // message duplicates it in cause-chain output, as with `Filesystem` above.
    #[error("failed to complete blocking directory traversal: {0}")]
    BlockingTask(#[from] tokio::task::JoinError),
}
/// Starts a background traversal of `source` and returns a stream of entry batches.
///
/// Every entry's archive path is rooted at the final component of `source`. For example,
/// walking `/data/tree` yields `tree`, then paths such as `tree/a` and `tree/a/b`. The walk
/// runs on Tokio's blocking pool one batch of up to [`DIRECTORY_TRAVERSAL_BATCH_ENTRIES`]
/// entries at a time, and at most `DIRECTORY_TRAVERSAL_BUFFER_BATCHES` finished batches wait
/// in the channel. This keeps the walk from running far ahead of the consumer.
///
/// Must be called from within a Tokio runtime, because it spawns a task.
///
/// # Errors
///
/// Fails immediately in three cases:
///
/// - `source` has no final component (for example `.`, `..`, or `/`):
///   [`TraversalError::InvalidArchivePath`].
/// - That component is not valid UTF-8: [`TraversalError::NonUtf8SourcePath`].
/// - `validation` rejects that component: [`TraversalError::NameRejected`].
///
/// Errors that occur during the walk itself are reported by [`TraversalStream::finish`].
pub(crate) fn stream_directory_entries(
    source: PathBuf,
    validation: NameValidation,
    symlink_policy: SymlinkPolicy,
) -> Result<TraversalStream, TraversalError> {
    // Paths like `.`, `..` or `/` have no final component to root the archive at. Report
    // that directly instead of misdiagnosing it as a UTF-8 problem.
    let Some(file_name) = source.file_name() else {
        return Err(TraversalError::InvalidArchivePath {
            path: source.clone(),
            reason: "source path has no final component",
        });
    };
    let Some(archive_path) = file_name.to_str().map(str::to_owned) else {
        return Err(TraversalError::NonUtf8SourcePath {
            path: source.clone(),
        });
    };
    validate_name(&archive_path, validation, "member path")?;
    let (sender, receiver) = mpsc::channel(DIRECTORY_TRAVERSAL_BUFFER_BATCHES);
    let task = tokio::spawn(async move {
        let mut producer = TraversalProducer::new(source, archive_path, validation, symlink_policy);
        loop {
            // The producer is moved into the blocking closure and handed back with the
            // batch, so the walker state needs no shared ownership or locking. The first `?`
            // turns a failed join into `BlockingTask`; the second propagates the batch error.
            let (next_producer, entries) =
                tokio::task::spawn_blocking(move || producer.next_batch()).await??;
            producer = next_producer;
            let Some(entries) = entries else {
                return Ok(());
            };
            // A send error means the `TraversalStream` receiver is gone (dropped, or
            // `finish` was called early). Stop walking without reporting an error.
            if sender.send(entries).await.is_err() {
                return Ok(());
            }
        }
    });
    Ok(TraversalStream {
        entries: receiver,
        task,
    })
}
struct TraversalProducer {
source: PathBuf,
archive_path: String,
validation: NameValidation,
symlink_policy: SymlinkPolicy,
entries: IntoIter,
}
impl TraversalProducer {
fn new(
source: PathBuf,
archive_path: String,
validation: NameValidation,
symlink_policy: SymlinkPolicy,
) -> Self {
let entries = WalkDir::new(&source)
.follow_links(false)
.follow_root_links(false)
.sort_by_file_name()
.into_iter();
Self {
source,
archive_path,
validation,
symlink_policy,
entries,
}
}
fn next_batch(mut self) -> Result<(Self, Option<Vec<TraversalEntry>>), TraversalError> {
let mut entries = Vec::with_capacity(DIRECTORY_TRAVERSAL_BATCH_ENTRIES);
while entries.len() < DIRECTORY_TRAVERSAL_BATCH_ENTRIES {
let Some(entry) = self.entries.next() else {
break;
};
let entry = entry.map_err(|error| {
let path = error.path().unwrap_or(&self.source).to_path_buf();
filesystem_error("traverse source directory", &path, error.into())
})?;
entries.push(traversal_entry(
&self.source,
&self.archive_path,
self.validation,
self.symlink_policy,
entry,
)?);
}
let entries = if entries.is_empty() {
None
} else {
Some(entries)
};
Ok((self, entries))
}
}
/// Converts one walked filesystem entry into a [`TraversalEntry`].
///
/// `source` is the traversal root and `archive_path` is the archive name chosen for it. The
/// entry's archive path is `archive_path` followed by its path relative to `source`, joined
/// with `/`. For a symbolic link, the target text is read with `read_link` and stored after
/// validation; the link is not resolved here.
///
/// # Errors
///
/// - [`TraversalError::SourceNotDirectory`] if the root entry (depth 0) is not a directory.
/// - [`TraversalError::SymbolicLinkRejected`] for a symlink under `SymlinkPolicy::Reject`.
/// - [`TraversalError::Filesystem`] if a link target cannot be read.
/// - [`TraversalError::NonUtf8LinkTarget`] or [`TraversalError::NonUtf8SourcePath`] for
///   names that are not valid UTF-8.
/// - [`TraversalError::NameRejected`] if `validation` rejects a link target or member path.
/// - [`TraversalError::UnsupportedFilesystemType`] for anything that is not a directory,
///   regular file, or symbolic link.
/// - [`TraversalError::InvalidArchivePath`] if the entry does not lie under `source`.
fn traversal_entry(
    source: &Path,
    archive_path: &str,
    validation: NameValidation,
    symlink_policy: SymlinkPolicy,
    entry: DirEntry,
) -> Result<TraversalEntry, TraversalError> {
    let path = entry.path();
    let file_type = entry.file_type();
    // Check the root before classifying it. A root that is a symlink or special file
    // should report `SourceNotDirectory`, not a symlink-policy, link-target, or
    // unsupported-type error. It must also not trigger a `read_link` first.
    if entry.depth() == 0 && !file_type.is_dir() {
        return Err(TraversalError::SourceNotDirectory {
            path: source.to_path_buf(),
        });
    }
    let kind = if file_type.is_dir() {
        TraversalKind::Directory
    } else if file_type.is_file() {
        TraversalKind::Regular
    } else if file_type.is_symlink() {
        if symlink_policy == SymlinkPolicy::Reject {
            return Err(TraversalError::SymbolicLinkRejected {
                path: path.to_path_buf(),
            });
        }
        let target = std::fs::read_link(path)
            .map_err(|error| filesystem_error("read symbolic link", path, error))?;
        let Some(target) = target.to_str().map(str::to_owned) else {
            return Err(TraversalError::NonUtf8LinkTarget {
                path: path.to_path_buf(),
            });
        };
        validate_name(&target, validation, "symbolic-link target")?;
        TraversalKind::SymbolicLink { target }
    } else {
        return Err(TraversalError::UnsupportedFilesystemType {
            path: path.to_path_buf(),
        });
    };
    // Entries come from walking `source`, so their paths are expected to start with it.
    // This error only fires if that expectation is broken.
    let relative = path
        .strip_prefix(source)
        .map_err(|_| TraversalError::InvalidArchivePath {
            path: path.to_path_buf(),
            reason: "source entry is outside recursive root",
        })?;
    // The root maps to the bare archive name; every other entry gets `/`-joined components.
    let archive_path = if relative.as_os_str().is_empty() {
        archive_path.to_owned()
    } else {
        join_archive_path(archive_path, relative, path, validation)?
    };
    Ok(TraversalEntry {
        source: entry.into_path(),
        archive_path,
        kind,
    })
}
/// Appends each component of `relative` to `archive_path`, separated by `/`, then validates
/// the result as a member path.
///
/// Components are copied verbatim, so a backslash inside a Unix file name stays part of
/// that name rather than becoming a separator. `source_path` is used only to label the
/// error when a component is not valid UTF-8.
fn join_archive_path(
    archive_path: &str,
    relative: &Path,
    source_path: &Path,
    validation: NameValidation,
) -> Result<String, TraversalError> {
    let joined = relative.iter().try_fold(
        archive_path.to_owned(),
        |mut acc, component| -> Result<String, TraversalError> {
            let component = component.to_str().ok_or_else(|| TraversalError::NonUtf8SourcePath {
                path: source_path.to_path_buf(),
            })?;
            acc.push('/');
            acc.push_str(component);
            Ok(acc)
        },
    )?;
    validate_name(&joined, validation, "member path")?;
    Ok(joined)
}
/// Checks `name` against `validation`. Any rejection is labelled with `context`, for
/// example `"member path"` or `"symbolic-link target"`.
fn validate_name(
    name: &str,
    validation: NameValidation,
    context: &'static str,
) -> Result<(), TraversalError> {
    if !validation.accepts(name) {
        return Err(TraversalError::NameRejected {
            context,
            value: name.to_owned(),
        });
    }
    Ok(())
}
/// Wraps an I/O failure with the attempted `operation` and the `path` it concerned.
fn filesystem_error(operation: &'static str, path: &Path, source: io::Error) -> TraversalError {
    let path = path.to_owned();
    TraversalError::Filesystem {
        operation,
        path,
        source,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn joins_native_relative_paths_with_archive_separators() {
        // Build the relative path from parts so the platform's native separator is used.
        let relative: PathBuf = ["nested", "file"].iter().collect();
        let joined = join_archive_path("tree", &relative, &relative, NameValidation::Default)
            .expect("relative path should join and validate");
        assert_eq!(joined, "tree/nested/file");
    }

    #[cfg(unix)]
    #[test]
    fn preserves_backslashes_in_source_path_components() {
        // On Unix a backslash is an ordinary file-name character, so it must survive the join.
        let relative = Path::new(r"nested\file");
        let joined = join_archive_path("tree", relative, relative, NameValidation::Default)
            .expect("backslash component should join and validate");
        assert_eq!(joined, r"tree/nested\file");
    }
}