use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use crate::format::streams::ResourceLimits;
use crate::read::{LinkPolicy, OverwritePolicy, PathSafety, PreserveMetadata, Threads};
#[cfg(feature = "aes")]
use crate::async_password::AsyncPasswordProvider;
/// Asynchronous observer for per-entry progress notifications.
///
/// Each hook returns a boxed future so implementations may `.await` inside it
/// (for example, sending on a channel). Implementors must be `Send + Sync`.
///
/// Every hook has a no-op default that completes immediately. An implementor only
/// needs to override the notifications it actually cares about.
pub trait AsyncProgressCallback: Send + Sync {
    /// Called when processing of the entry named `entry_name` begins.
    ///
    /// NOTE(review): `entry_size` is whatever size the caller reports — presumably the
    /// uncompressed size; confirm against the extractor.
    fn on_entry_start(
        &self,
        entry_name: &str,
        entry_size: u64,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        // Default: ignore the notification.
        let _ = (entry_name, entry_size);
        Box::pin(std::future::ready(()))
    }

    /// Called with a running byte count (`bytes_extracted` out of `total_bytes`).
    fn on_progress(
        &self,
        bytes_extracted: u64,
        total_bytes: u64,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        // Default: ignore the notification.
        let _ = (bytes_extracted, total_bytes);
        Box::pin(std::future::ready(()))
    }

    /// Called when the entry named `entry_name` finishes; `success` reports the outcome.
    fn on_entry_complete(
        &self,
        entry_name: &str,
        success: bool,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        // Default: ignore the notification.
        let _ = (entry_name, success);
        Box::pin(std::future::ready(()))
    }
}
/// Options for an asynchronous extraction.
///
/// Construct with [`AsyncExtractOptions::new`] (or `Default::default()`) and the
/// chained setters. `Default` is derived:
/// - every policy field takes its own type's `Default`;
/// - the optional hooks (`cancel_token`, `password_provider`, `progress`) start as `None`.
#[derive(Default)]
pub struct AsyncExtractOptions {
    /// See [`OverwritePolicy`].
    pub overwrite: OverwritePolicy,
    /// See [`PathSafety`].
    pub path_safety: PathSafety,
    /// See [`LinkPolicy`].
    pub link_policy: LinkPolicy,
    /// See [`ResourceLimits`].
    pub limits: ResourceLimits,
    /// See [`Threads`].
    pub threads: Threads,
    /// See [`PreserveMetadata`].
    pub preserve_metadata: PreserveMetadata,
    /// Optional cancellation token; [`AsyncExtractOptions::is_cancelled`] reports its state.
    pub cancel_token: Option<CancellationToken>,
    /// Optional asynchronous password source (only with the `aes` feature).
    #[cfg(feature = "aes")]
    pub password_provider: Option<Arc<dyn AsyncPasswordProvider>>,
    /// Optional progress observer.
    pub progress: Option<Arc<dyn AsyncProgressCallback>>,
}
// Manual `Debug`: the hook fields hold trait objects without a `Debug` bound, so their
// presence is shown as `has_*` booleans instead. `limits` is not printed, which is why
// the output ends with `finish_non_exhaustive`.
impl std::fmt::Debug for AsyncExtractOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AsyncExtractOptions");
        out.field("overwrite", &self.overwrite)
            .field("path_safety", &self.path_safety)
            .field("link_policy", &self.link_policy)
            .field("threads", &self.threads)
            .field("preserve_metadata", &self.preserve_metadata)
            .field("has_cancel_token", &self.cancel_token.is_some());
        #[cfg(feature = "aes")]
        out.field("has_password_provider", &self.password_provider.is_some());
        out.field("has_progress", &self.progress.is_some())
            .finish_non_exhaustive()
    }
}
impl AsyncExtractOptions {
    /// Creates options with every field at its default; same as `Default::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the overwrite policy.
    #[must_use]
    pub fn overwrite(mut self, policy: OverwritePolicy) -> Self {
        self.overwrite = policy;
        self
    }

    /// Sets the path-safety policy.
    #[must_use]
    pub fn path_safety(mut self, policy: PathSafety) -> Self {
        self.path_safety = policy;
        self
    }

    /// Sets the link policy.
    #[must_use]
    pub fn link_policy(mut self, policy: LinkPolicy) -> Self {
        self.link_policy = policy;
        self
    }

    /// Sets the resource limits.
    #[must_use]
    pub fn limits(mut self, limits: ResourceLimits) -> Self {
        self.limits = limits;
        self
    }

    /// Sets the thread configuration.
    #[must_use]
    pub fn threads(mut self, threads: Threads) -> Self {
        self.threads = threads;
        self
    }

    /// Sets which metadata to preserve.
    #[must_use]
    pub fn preserve_metadata(mut self, preserve: PreserveMetadata) -> Self {
        self.preserve_metadata = preserve;
        self
    }

    /// Attaches a cancellation token, replacing any previous one.
    #[must_use]
    pub fn cancel_token(mut self, token: CancellationToken) -> Self {
        self.cancel_token = Some(token);
        self
    }

    /// Attaches an asynchronous password provider, replacing any previous one.
    #[cfg(feature = "aes")]
    #[must_use]
    pub fn password_provider(mut self, provider: Arc<dyn AsyncPasswordProvider>) -> Self {
        self.password_provider = Some(provider);
        self
    }

    /// Attaches a progress observer, replacing any previous one.
    #[must_use]
    pub fn progress(mut self, callback: Arc<dyn AsyncProgressCallback>) -> Self {
        self.progress = Some(callback);
        self
    }

    /// Returns `true` only when a cancel token is set and it has been cancelled.
    ///
    /// Cancelling a clone of the token passed to [`Self::cancel_token`] is also
    /// observed here, because `CancellationToken` clones share state.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        matches!(&self.cancel_token, Some(token) if token.is_cancelled())
    }
}
/// Options for the asynchronous test operation.
///
/// This is a subset of [`AsyncExtractOptions`]: it has no overwrite, path-safety,
/// link or metadata policies. `Default` is derived:
/// - `limits` and `threads` take their types' `Default`;
/// - the optional hooks start as `None`.
#[derive(Default)]
pub struct AsyncTestOptions {
    /// See [`ResourceLimits`].
    pub limits: ResourceLimits,
    /// See [`Threads`].
    pub threads: Threads,
    /// Optional cancellation token; [`AsyncTestOptions::is_cancelled`] reports its state.
    pub cancel_token: Option<CancellationToken>,
    /// Optional asynchronous password source (only with the `aes` feature).
    #[cfg(feature = "aes")]
    pub password_provider: Option<Arc<dyn AsyncPasswordProvider>>,
    /// Optional progress observer.
    pub progress: Option<Arc<dyn AsyncProgressCallback>>,
}
// Manual `Debug`: the hook fields hold trait objects without a `Debug` bound, so their
// presence is shown as `has_*` booleans. `limits` is not printed, hence
// `finish_non_exhaustive`.
impl std::fmt::Debug for AsyncTestOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AsyncTestOptions");
        out.field("threads", &self.threads)
            .field("has_cancel_token", &self.cancel_token.is_some());
        #[cfg(feature = "aes")]
        out.field("has_password_provider", &self.password_provider.is_some());
        out.field("has_progress", &self.progress.is_some())
            .finish_non_exhaustive()
    }
}
impl AsyncTestOptions {
    /// Creates options with every field at its default; same as `Default::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the resource limits.
    #[must_use]
    pub fn limits(mut self, limits: ResourceLimits) -> Self {
        self.limits = limits;
        self
    }

    /// Sets the thread configuration.
    #[must_use]
    pub fn threads(mut self, threads: Threads) -> Self {
        self.threads = threads;
        self
    }

    /// Attaches a cancellation token, replacing any previous one.
    #[must_use]
    pub fn cancel_token(mut self, token: CancellationToken) -> Self {
        self.cancel_token = Some(token);
        self
    }

    /// Attaches an asynchronous password provider, replacing any previous one.
    #[cfg(feature = "aes")]
    #[must_use]
    pub fn password_provider(mut self, provider: Arc<dyn AsyncPasswordProvider>) -> Self {
        self.password_provider = Some(provider);
        self
    }

    /// Attaches a progress observer, replacing any previous one.
    #[must_use]
    pub fn progress(mut self, callback: Arc<dyn AsyncProgressCallback>) -> Self {
        self.progress = Some(callback);
        self
    }

    /// Returns `true` only when a cancel token is set and it has been cancelled.
    ///
    /// Cancelling a clone of the token passed to [`Self::cancel_token`] is also
    /// observed here, because `CancellationToken` clones share state.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        matches!(&self.cancel_token, Some(token) if token.is_cancelled())
    }
}
/// [`AsyncProgressCallback`] that forwards every notification as a [`ProgressEvent`]
/// over a bounded tokio mpsc channel.
///
/// `Clone` yields another reporter feeding the same receiver, since tokio `Sender`
/// clones share the channel.
#[derive(Debug, Clone)]
pub struct ChannelProgressReporter {
    sender: tokio::sync::mpsc::Sender<ProgressEvent>,
}
/// A progress notification, as delivered by `ChannelProgressReporter`.
///
/// Each variant mirrors one `AsyncProgressCallback` hook and carries that hook's
/// arguments. `PartialEq`/`Eq` allow events to be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProgressEvent {
    /// Mirrors `on_entry_start`.
    EntryStart {
        name: String,
        size: u64,
    },
    /// Mirrors `on_progress`.
    Progress {
        bytes_extracted: u64,
        total_bytes: u64,
    },
    /// Mirrors `on_entry_complete`.
    EntryComplete {
        name: String,
        success: bool,
    },
}
impl ChannelProgressReporter {
    /// Creates a reporter together with the receiving half of its channel.
    ///
    /// `buffer_size` is the channel capacity, i.e. how many events may be queued
    /// before a hook's future waits for the receiver. `tokio::sync::mpsc::channel`
    /// panics when asked for a capacity of 0, so a requested size of 0 is raised to 1
    /// instead of panicking.
    #[must_use = "if the receiver is dropped, every progress event is discarded"]
    pub fn new(buffer_size: usize) -> (Self, tokio::sync::mpsc::Receiver<ProgressEvent>) {
        let (tx, rx) = tokio::sync::mpsc::channel(buffer_size.max(1));
        (Self { sender: tx }, rx)
    }
}
// Turns each hook into the matching `ProgressEvent` and sends it over the channel.
//
// - The event is built before the boxed future, so the future borrows only `self`
//   (via `tx`), never the caller's `entry_name`.
// - `Sender::send` waits for buffer space, so a full channel makes the returned
//   future wait for the receiver.
// - A send error (receiver gone) is ignored.
impl AsyncProgressCallback for ChannelProgressReporter {
    fn on_entry_start(
        &self,
        entry_name: &str,
        entry_size: u64,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        let msg = ProgressEvent::EntryStart {
            name: entry_name.to_owned(),
            size: entry_size,
        };
        let tx = &self.sender;
        Box::pin(async move {
            let _ = tx.send(msg).await;
        })
    }

    fn on_progress(
        &self,
        bytes_extracted: u64,
        total_bytes: u64,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        let tx = &self.sender;
        let msg = ProgressEvent::Progress {
            bytes_extracted,
            total_bytes,
        };
        Box::pin(async move {
            let _ = tx.send(msg).await;
        })
    }

    fn on_entry_complete(
        &self,
        entry_name: &str,
        success: bool,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        let msg = ProgressEvent::EntryComplete {
            name: entry_name.to_owned(),
            success,
        };
        let tx = &self.sender;
        Box::pin(async move {
            let _ = tx.send(msg).await;
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_async_extract_options_default() {
        let opts = AsyncExtractOptions::default();
        assert_eq!(opts.overwrite, OverwritePolicy::Error);
        assert_eq!(opts.path_safety, PathSafety::Strict);
        assert!(opts.cancel_token.is_none());
        assert!(opts.progress.is_none());
        assert!(!opts.is_cancelled());
    }

    #[test]
    fn test_async_extract_options_builder() {
        let token = CancellationToken::new();
        let opts = AsyncExtractOptions::new()
            .overwrite(OverwritePolicy::Skip)
            .path_safety(PathSafety::Relaxed)
            .threads(Threads::count_or_single(4))
            .cancel_token(token.clone());
        assert_eq!(opts.overwrite, OverwritePolicy::Skip);
        assert_eq!(opts.path_safety, PathSafety::Relaxed);
        assert_eq!(opts.threads.count(), 4);
        assert!(!opts.is_cancelled());
        // Cancelling a clone must be visible through the stored token.
        token.cancel();
        assert!(opts.is_cancelled());
    }

    #[test]
    fn test_async_extract_options_progress_builder() {
        let (reporter, _rx) = ChannelProgressReporter::new(1);
        let opts = AsyncExtractOptions::new().progress(Arc::new(reporter));
        assert!(opts.progress.is_some());
    }

    #[test]
    fn test_async_test_options_default() {
        let opts = AsyncTestOptions::default();
        assert!(opts.cancel_token.is_none());
        assert!(opts.progress.is_none());
        assert!(!opts.is_cancelled());
    }

    #[test]
    fn test_async_test_options_builder() {
        let token = CancellationToken::new();
        let opts = AsyncTestOptions::new()
            .threads(Threads::count_or_single(4))
            .cancel_token(token.clone());
        assert_eq!(opts.threads.count(), 4);
        assert!(!opts.is_cancelled());
        token.cancel();
        assert!(opts.is_cancelled());
    }

    #[tokio::test]
    async fn test_channel_progress_reporter() {
        let (reporter, mut rx) = ChannelProgressReporter::new(10);
        reporter.on_entry_start("test.txt", 100).await;
        reporter.on_progress(50, 100).await;
        reporter.on_entry_complete("test.txt", true).await;

        // Check payloads, not just variants, so argument mix-ups are caught.
        match rx.recv().await {
            Some(ProgressEvent::EntryStart { name, size }) => {
                assert_eq!(name, "test.txt");
                assert_eq!(size, 100);
            }
            other => panic!("expected EntryStart, got {:?}", other),
        }
        match rx.recv().await {
            Some(ProgressEvent::Progress {
                bytes_extracted,
                total_bytes,
            }) => {
                assert_eq!(bytes_extracted, 50);
                assert_eq!(total_bytes, 100);
            }
            other => panic!("expected Progress, got {:?}", other),
        }
        match rx.recv().await {
            Some(ProgressEvent::EntryComplete { name, success }) => {
                assert_eq!(name, "test.txt");
                assert!(success);
            }
            other => panic!("expected EntryComplete, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_channel_progress_reporter_ignores_closed_receiver() {
        let (reporter, rx) = ChannelProgressReporter::new(1);
        drop(rx);
        // With nobody listening, the hooks must neither panic nor hang.
        reporter.on_entry_start("gone.txt", 1).await;
        reporter.on_progress(1, 1).await;
        reporter.on_entry_complete("gone.txt", false).await;
    }
}