use crate::format::streams::ResourceLimits;
use crate::progress::ProgressReporter;
#[cfg(feature = "aes")]
use crate::Password;
pub use crate::safety::PathSafety;
/// Overwrite setting stored in the `overwrite` field of [`ExtractOptions`].
///
/// [`OverwritePolicy::Error`] is the default variant. The per-variant notes
/// below are inferred from the variant names; the code acting on them lives
/// outside this file.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum OverwritePolicy {
    /// Default. Presumably reports an error when the destination already
    /// exists (TODO confirm against the extraction code).
    #[default]
    Error,
    /// Presumably leaves an existing destination untouched (TODO confirm).
    Skip,
    /// Presumably replaces an existing destination (TODO confirm).
    Overwrite,
}
/// Link-handling setting stored in the `link_policy` field of
/// [`ExtractOptions`].
///
/// [`LinkPolicy::Forbid`] is the default variant. The per-variant notes below
/// are inferred from the variant names; the code acting on them lives outside
/// this file.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LinkPolicy {
    /// Default. Presumably rejects link entries entirely (TODO confirm).
    #[default]
    Forbid,
    /// Presumably permits links only after their targets have been checked
    /// (TODO confirm what the check covers).
    ValidateTargets,
    /// Presumably permits links without target validation (TODO confirm).
    Allow,
}
/// Chooses how [`FilterPolicy::apply`] interprets a match result.
///
/// [`FilterPolicy::Include`] is the default variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FilterPolicy {
    /// `apply` returns the match flag unchanged.
    #[default]
    Include,
    /// `apply` returns the match flag inverted.
    Exclude,
}

impl FilterPolicy {
    /// Returns `true` when the policy is [`FilterPolicy::Include`].
    pub fn is_include(&self) -> bool {
        *self == Self::Include
    }

    /// Returns `true` when the policy is [`FilterPolicy::Exclude`].
    pub fn is_exclude(&self) -> bool {
        *self == Self::Exclude
    }

    /// Combines the policy with a match result.
    ///
    /// Under `Include` the result equals `matched`; under `Exclude` it equals
    /// `!matched`.
    pub fn apply(&self, matched: bool) -> bool {
        if self.is_exclude() {
            !matched
        } else {
            matched
        }
    }
}
/// Thread-count setting; [`Threads::count`] resolves it to a concrete number.
///
/// [`Threads::Auto`] is the default variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Threads {
    /// Resolved at call time from `std::thread::available_parallelism`.
    #[default]
    Auto,
    /// An explicit, non-zero number of threads.
    Count(std::num::NonZeroUsize),
    /// Exactly one thread.
    Single,
}

impl Threads {
    /// Builds a setting from a plain integer.
    ///
    /// Zero maps to [`Threads::Single`]; every other value, including `1`,
    /// becomes [`Threads::Count`].
    pub fn count_or_single(n: usize) -> Self {
        std::num::NonZeroUsize::new(n).map_or(Self::Single, Self::Count)
    }

    /// Resolves the setting to a thread count that is always at least 1.
    ///
    /// `Auto` queries `std::thread::available_parallelism` and falls back to
    /// 1 when that query returns an error.
    pub fn count(&self) -> usize {
        match *self {
            Self::Single => 1,
            Self::Count(explicit) => explicit.get(),
            Self::Auto => match std::thread::available_parallelism() {
                Ok(detected) => detected.get(),
                Err(_) => 1,
            },
        }
    }
}
/// Set of flags naming the metadata kinds to preserve.
///
/// The derived `Default` leaves every flag `false`, which is the same value
/// [`PreserveMetadata::none`] returns. How the flags are acted upon is decided
/// by the code that consumes these options, outside this file.
///
/// `PartialEq`/`Eq` are derived, like on the other option value types in this
/// module, so presets can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PreserveMetadata {
    /// Flag for the modification time.
    pub modification_time: bool,
    /// Flag for the creation time.
    pub creation_time: bool,
    /// Flag for attributes.
    pub attributes: bool,
}

impl PreserveMetadata {
    /// Preset with every flag enabled.
    pub fn all() -> Self {
        Self {
            modification_time: true,
            creation_time: true,
            attributes: true,
        }
    }

    /// Preset with every flag disabled; identical to `Default::default()`.
    pub fn none() -> Self {
        Self::default()
    }

    /// Preset enabling both time flags while leaving `attributes` disabled.
    pub fn times() -> Self {
        Self {
            modification_time: true,
            creation_time: true,
            attributes: false,
        }
    }

    /// Preset enabling only `modification_time`.
    pub fn modification_time_only() -> Self {
        Self {
            modification_time: true,
            creation_time: false,
            attributes: false,
        }
    }
}
/// Options bundle for extraction, normally assembled through the chained
/// setter methods on this type.
///
/// `Default` is derived, so every field starts at its own type's default.
/// `Debug` is implemented by hand and prints only a subset of the fields.
#[derive(Default)]
pub struct ExtractOptions {
/// Overwrite policy; [`OverwritePolicy::Error`] unless changed.
pub overwrite: OverwritePolicy,
/// Path-safety policy (type re-exported from `crate::safety`).
pub path_safety: PathSafety,
/// Link policy; [`LinkPolicy::Forbid`] unless changed.
pub link_policy: LinkPolicy,
/// Resource limits (type defined in `crate::format::streams`).
pub limits: ResourceLimits,
/// Thread-count setting; [`Threads::Auto`] unless changed.
pub threads: Threads,
/// Metadata preservation flags; all `false` unless changed.
pub preserve_metadata: PreserveMetadata,
/// Optional password. The field exists only when the `aes` feature is
/// enabled and is not included in the `Debug` output.
#[cfg(feature = "aes")]
pub password: Option<Password>,
/// Optional boxed progress reporter. `clone_settings` leaves this as `None`
/// in the copy it returns.
pub progress: Option<Box<dyn ProgressReporter>>,
}
impl std::fmt::Debug for ExtractOptions {
    /// Prints the policy, thread and metadata fields, then marks the output
    /// as non-exhaustive because the remaining fields are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            overwrite,
            path_safety,
            link_policy,
            threads,
            preserve_metadata,
            ..
        } = self;

        let mut out = f.debug_struct("ExtractOptions");
        out.field("overwrite", overwrite);
        out.field("path_safety", path_safety);
        out.field("link_policy", link_policy);
        out.field("threads", threads);
        out.field("preserve_metadata", preserve_metadata);
        out.finish_non_exhaustive()
    }
}
impl ExtractOptions {
    /// Creates options with every field at its default value.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the options with the overwrite policy replaced.
    pub fn overwrite(self, policy: OverwritePolicy) -> Self {
        Self {
            overwrite: policy,
            ..self
        }
    }

    /// Returns the options with the path-safety policy replaced.
    pub fn path_safety(self, policy: PathSafety) -> Self {
        Self {
            path_safety: policy,
            ..self
        }
    }

    /// Returns the options with the link policy replaced.
    pub fn link_policy(self, policy: LinkPolicy) -> Self {
        Self {
            link_policy: policy,
            ..self
        }
    }

    /// Returns the options with the resource limits replaced.
    pub fn limits(self, limits: ResourceLimits) -> Self {
        Self { limits, ..self }
    }

    /// Returns the options with the thread setting replaced.
    pub fn threads(self, threads: Threads) -> Self {
        Self { threads, ..self }
    }

    /// Returns the options with the metadata preservation flags replaced.
    pub fn preserve_metadata(self, preserve: PreserveMetadata) -> Self {
        Self {
            preserve_metadata: preserve,
            ..self
        }
    }

    /// Returns the options with the password set to the converted argument.
    ///
    /// Available only when the `aes` feature is enabled.
    #[cfg(feature = "aes")]
    pub fn password(self, password: impl Into<Password>) -> Self {
        Self {
            password: Some(password.into()),
            ..self
        }
    }

    /// Returns the options with `reporter` boxed and installed as the
    /// progress reporter.
    pub fn progress(self, reporter: impl ProgressReporter + 'static) -> Self {
        Self {
            progress: Some(Box::new(reporter)),
            ..self
        }
    }

    /// Produces a copy of every setting except the progress reporter, which
    /// is `None` in the returned value.
    pub fn clone_settings(&self) -> Self {
        let Self {
            overwrite,
            path_safety,
            link_policy,
            limits,
            threads,
            preserve_metadata,
            ..
        } = self;

        Self {
            overwrite: *overwrite,
            path_safety: *path_safety,
            link_policy: *link_policy,
            limits: limits.clone(),
            threads: *threads,
            preserve_metadata: preserve_metadata.clone(),
            #[cfg(feature = "aes")]
            password: self.password.clone(),
            progress: None,
        }
    }
}
/// Options bundle for test runs, normally assembled through the chained
/// setter methods on this type.
///
/// `Default` is derived, so every field starts at its own type's default.
/// `Debug` is implemented by hand and prints only the `threads` field.
#[derive(Default)]
pub struct TestOptions {
/// Resource limits (type defined in `crate::format::streams`).
pub limits: ResourceLimits,
/// Thread-count setting; [`Threads::Auto`] unless changed.
pub threads: Threads,
/// Optional password. The field exists only when the `aes` feature is
/// enabled and is not included in the `Debug` output.
#[cfg(feature = "aes")]
pub password: Option<Password>,
/// Optional boxed progress reporter.
pub progress: Option<Box<dyn ProgressReporter>>,
}
impl std::fmt::Debug for TestOptions {
    /// Prints the `threads` field, then marks the output as non-exhaustive
    /// because the remaining fields are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TestOptions");
        out.field("threads", &self.threads);
        out.finish_non_exhaustive()
    }
}
impl TestOptions {
    /// Creates options with every field at its default value.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the options with the resource limits replaced.
    pub fn limits(self, limits: ResourceLimits) -> Self {
        Self { limits, ..self }
    }

    /// Returns the options with the thread setting replaced.
    pub fn threads(self, threads: Threads) -> Self {
        Self { threads, ..self }
    }

    /// Returns the options with the password set to the converted argument.
    ///
    /// Available only when the `aes` feature is enabled.
    #[cfg(feature = "aes")]
    pub fn password(self, password: impl Into<Password>) -> Self {
        Self {
            password: Some(password.into()),
            ..self
        }
    }

    /// Returns the options with `reporter` boxed and installed as the
    /// progress reporter.
    pub fn progress(self, reporter: impl ProgressReporter + 'static) -> Self {
        Self {
            progress: Some(Box::new(reporter)),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::num::NonZeroUsize;

    /// Builds an explicit thread setting; `n` must be non-zero.
    fn explicit(n: usize) -> Threads {
        Threads::Count(NonZeroUsize::new(n).unwrap())
    }

    #[test]
    fn test_overwrite_policy_default() {
        let policy = OverwritePolicy::default();
        assert_eq!(policy, OverwritePolicy::Error);
    }

    #[test]
    fn test_path_safety_default() {
        let policy = PathSafety::default();
        assert_eq!(policy, PathSafety::Strict);
    }

    #[test]
    fn test_link_policy_default() {
        let policy = LinkPolicy::default();
        assert_eq!(policy, LinkPolicy::Forbid);
    }

    #[test]
    fn test_threads_count() {
        let single = Threads::Single.count();
        let four = explicit(4).count();
        let auto = Threads::Auto.count();
        assert_eq!(single, 1);
        assert_eq!(four, 4);
        assert!(auto >= 1);
    }

    #[test]
    fn test_threads_count_or_single() {
        // Zero has no NonZeroUsize representation and falls back to Single.
        assert_eq!(Threads::count_or_single(0), Threads::Single);
        let four = Threads::count_or_single(4);
        assert_eq!(four.count(), 4);
        let one = Threads::count_or_single(1);
        assert_eq!(one.count(), 1);
    }

    #[test]
    fn test_preserve_metadata() {
        let PreserveMetadata {
            modification_time,
            creation_time,
            attributes,
        } = PreserveMetadata::all();
        assert!(modification_time);
        assert!(creation_time);
        assert!(attributes);

        let PreserveMetadata {
            modification_time,
            creation_time,
            attributes,
        } = PreserveMetadata::none();
        assert!(!modification_time);
        assert!(!creation_time);
        assert!(!attributes);
    }

    #[test]
    fn test_preserve_metadata_times() {
        let preset = PreserveMetadata::times();
        assert!(preset.modification_time);
        assert!(preset.creation_time);
        assert!(!preset.attributes);
    }

    #[test]
    fn test_preserve_metadata_modification_time_only() {
        let preset = PreserveMetadata::modification_time_only();
        assert!(preset.modification_time);
        assert!(!preset.creation_time);
        assert!(!preset.attributes);
    }

    #[test]
    fn test_extract_options_builder() {
        let opts = ExtractOptions::new().overwrite(OverwritePolicy::Skip);
        let opts = opts.path_safety(PathSafety::Relaxed);
        let opts = opts.threads(Threads::count_or_single(2));
        assert_eq!(opts.overwrite, OverwritePolicy::Skip);
        assert_eq!(opts.path_safety, PathSafety::Relaxed);
        assert_eq!(opts.threads.count(), 2);
    }

    #[test]
    fn test_filter_policy_default() {
        let policy = FilterPolicy::default();
        assert_eq!(policy, FilterPolicy::Include);
    }

    #[test]
    fn test_filter_policy_is_include_exclude() {
        let include = FilterPolicy::Include;
        let exclude = FilterPolicy::Exclude;
        assert!(include.is_include());
        assert!(!include.is_exclude());
        assert!(!exclude.is_include());
        assert!(exclude.is_exclude());
    }

    #[test]
    fn test_filter_policy_apply() {
        let include = FilterPolicy::Include;
        let exclude = FilterPolicy::Exclude;
        assert!(include.apply(true));
        assert!(!include.apply(false));
        assert!(!exclude.apply(true));
        assert!(exclude.apply(false));
    }

    #[test]
    fn test_threads_count_always_positive() {
        assert!(Threads::Single.count() >= 1);
        assert!(explicit(1).count() >= 1);
        for n in [1, 2, 4, 8, 16, 100] {
            let resolved = explicit(n).count();
            assert!(resolved >= 1, "count() should always be >= 1");
            assert_eq!(resolved, n);
        }
        assert!(Threads::Auto.count() >= 1);
    }

    #[test]
    fn test_threads_count_or_single_invariants() {
        for requested in 0..=100 {
            let resolved = Threads::count_or_single(requested).count();
            assert!(
                resolved >= 1,
                "count_or_single({}) produced count() = {}",
                requested,
                resolved
            );
        }
    }

    #[test]
    fn test_extract_options_clone_settings() {
        let source = ExtractOptions::new()
            .overwrite(OverwritePolicy::Skip)
            .path_safety(PathSafety::Relaxed)
            .link_policy(LinkPolicy::Allow)
            .threads(Threads::count_or_single(4));
        let ExtractOptions {
            overwrite,
            path_safety,
            link_policy,
            threads,
            progress,
            ..
        } = source.clone_settings();
        assert_eq!(overwrite, OverwritePolicy::Skip);
        assert_eq!(path_safety, PathSafety::Relaxed);
        assert_eq!(link_policy, LinkPolicy::Allow);
        assert_eq!(threads.count(), 4);
        assert!(progress.is_none());
    }
}