use crate::{Archive, Error, Result};
use rayon::prelude::*;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Thin, thread-friendly wrapper over an archive on disk that enables
/// parallel extraction by reopening a fresh `Archive` handle per task
/// or per batch (handles themselves are never shared across threads).
#[derive(Debug)]
pub struct ParallelArchive {
    /// Path to the archive; each worker reopens it independently.
    path: PathBuf,
    /// File names cached at open time, shared cheaply via `Arc`.
    file_list: Arc<Vec<String>>,
}
impl ParallelArchive {
    /// Opens the archive at `path` and caches its complete file listing.
    ///
    /// Only the path and name list are retained; worker methods reopen
    /// the archive on demand, so no handle is ever shared across threads.
    ///
    /// # Errors
    /// Returns an error if the archive cannot be opened or listed.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path = path.as_ref().to_path_buf();
        let mut archive = Archive::open(&path)?;
        let entries = archive.list()?;
        let file_list = Arc::new(entries.into_iter().map(|e| e.name).collect());
        Ok(Self { path, file_list })
    }

    /// Extracts the named files in parallel, one archive handle per file.
    ///
    /// # Errors
    /// Fails fast on the first file that cannot be read.
    pub fn extract_files_parallel(&self, filenames: &[&str]) -> Result<Vec<(String, Vec<u8>)>> {
        filenames
            .par_iter()
            .map(|&filename| {
                let data = self.read_file_with_new_handle(filename)?;
                Ok((filename.to_string(), data))
            })
            .collect()
    }

    /// Extracts, in parallel, every cached file whose name satisfies
    /// `predicate`.
    ///
    /// # Errors
    /// Fails fast on the first matching file that cannot be read.
    pub fn extract_matching_parallel<F>(&self, predicate: F) -> Result<Vec<(String, Vec<u8>)>>
    where
        F: Fn(&str) -> bool + Sync,
    {
        let files = self.list_files();
        files
            .par_iter()
            .filter(|name| predicate(name))
            .map(|filename| {
                let data = self.read_file_with_new_handle(filename)?;
                Ok((filename.clone(), data))
            })
            .collect()
    }

    /// Reads the named files in parallel and maps each one's contents
    /// through `processor`, collecting the results in input order.
    ///
    /// # Errors
    /// Fails fast on the first read or processing error.
    pub fn process_files_parallel<F, T>(&self, filenames: &[&str], processor: F) -> Result<Vec<T>>
    where
        F: Fn(&str, Vec<u8>) -> Result<T> + Sync,
        T: Send,
    {
        filenames
            .par_iter()
            .map(|&filename| {
                let data = self.read_file_with_new_handle(filename)?;
                processor(filename, data)
            })
            .collect()
    }

    /// Reads a single file through a freshly opened archive handle,
    /// making the call safe from any thread.
    ///
    /// # Errors
    /// Fails if the archive cannot be reopened or the file is unreadable.
    pub fn read_file_with_new_handle(&self, filename: &str) -> Result<Vec<u8>> {
        let mut archive = Archive::open(&self.path)?;
        archive.read_file(filename)
    }

    /// Returns the file listing captured when the archive was opened.
    pub fn list_files(&self) -> &[String] {
        &self.file_list
    }

    /// Number of threads in rayon's current (global or installed) pool.
    pub fn thread_count(&self) -> usize {
        rayon::current_num_threads()
    }

    /// Extracts the named files in parallel, `batch_size` files per
    /// archive handle, amortizing the cost of reopening the archive.
    ///
    /// A `batch_size` of 0 is treated as 1 (it previously panicked
    /// inside `slice::chunks`).
    ///
    /// # Errors
    /// Fails fast on the first batch that cannot be opened or read.
    pub fn extract_files_batched(
        &self,
        filenames: &[&str],
        batch_size: usize,
    ) -> Result<Vec<(String, Vec<u8>)>> {
        // `chunks` panics when the chunk size is 0, so clamp upward.
        let chunks: Vec<_> = filenames.chunks(batch_size.max(1)).collect();
        let results: Result<Vec<_>> = chunks
            .par_iter()
            .map(|chunk| {
                // One handle per chunk instead of one per file.
                let mut archive = Archive::open(&self.path)?;
                let mut batch_results = Vec::with_capacity(chunk.len());
                for &filename in chunk.iter() {
                    let data = archive.read_file(filename)?;
                    batch_results.push((filename.to_string(), data));
                }
                Ok(batch_results)
            })
            .collect();
        Ok(results?.into_iter().flatten().collect())
    }
}
/// Tuning knobs for `extract_with_config`.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Worker-thread count; `None` uses rayon's default pool sizing.
    pub num_threads: Option<usize>,
    /// Files handled per archive handle in the batched strategy.
    pub batch_size: usize,
    /// When true, per-file read failures are reported in the result list
    /// instead of aborting the whole extraction.
    pub skip_errors: bool,
}
impl Default for ParallelConfig {
fn default() -> Self {
Self {
num_threads: None,
batch_size: 10,
skip_errors: false,
}
}
}
impl ParallelConfig {
    /// Creates a configuration with default settings; equivalent to
    /// `ParallelConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter: pin the worker-thread count to `num`.
    pub fn threads(self, num: usize) -> Self {
        Self {
            num_threads: Some(num),
            ..self
        }
    }

    /// Builder-style setter: use `size` files per batch.
    pub fn batch_size(self, size: usize) -> Self {
        Self {
            batch_size: size,
            ..self
        }
    }

    /// Builder-style setter: report per-file errors instead of aborting.
    pub fn skip_errors(self, skip: bool) -> Self {
        Self {
            skip_errors: skip,
            ..self
        }
    }
}
/// Extracts `filenames` from the archive at `archive_path` under `config`.
///
/// Requests larger than 1000 files use the batched strategy (one archive
/// handle per chunk); smaller ones open a fresh handle per file.
///
/// # Errors
/// Propagates archive-open, pool-creation, and (without `skip_errors`)
/// file-read failures from the chosen strategy.
pub fn extract_with_config<P: AsRef<Path>>(
    archive_path: P,
    filenames: &[&str],
    config: ParallelConfig,
) -> Result<Vec<(String, Result<Vec<u8>>)>> {
    // Above this many files, per-file handle creation dominates.
    const BATCH_THRESHOLD: usize = 1000;
    if filenames.len() > BATCH_THRESHOLD {
        extract_with_config_batched(archive_path, filenames, config)
    } else {
        extract_with_config_unbatched(archive_path, filenames, config)
    }
}
/// Batched parallel extraction: files are split into chunks and each
/// worker opens a single archive handle per chunk, amortizing open cost.
///
/// Returns one `(name, per-file result)` pair per requested file. With
/// `skip_errors` set, read failures are captured per file; otherwise the
/// first failure aborts the whole extraction.
///
/// # Errors
/// Fails if the archive cannot be opened, the thread pool cannot be
/// built, or (without `skip_errors`) any file cannot be read.
fn extract_with_config_batched<P: AsRef<Path>>(
    archive_path: P,
    filenames: &[&str],
    config: ParallelConfig,
) -> Result<Vec<(String, Result<Vec<u8>>)>> {
    let archive = ParallelArchive::open(archive_path)?;
    let num_threads = config
        .num_threads
        .unwrap_or_else(rayon::current_num_threads);
    // For very large requests, widen batches so each thread handles at
    // most ~2 chunks; otherwise honor the configured batch size. Clamp
    // to at least 1: `chunks(0)` panics.
    let effective_batch_size = if filenames.len() > 5000 {
        std::cmp::max(config.batch_size, filenames.len() / (num_threads * 2))
    } else {
        config.batch_size
    }
    .max(1);
    // Build the pool uniformly; surface builder failures as errors
    // instead of panicking on the default-pool path.
    let mut builder = rayon::ThreadPoolBuilder::new();
    if let Some(threads) = config.num_threads {
        builder = builder.num_threads(threads);
    }
    let pool = builder.build().map_err(|e| {
        Error::Io(std::io::Error::other(format!(
            "Failed to create thread pool: {e}"
        )))
    })?;
    pool.install(|| {
        let chunks: Vec<_> = filenames.chunks(effective_batch_size).collect();
        let batch_results: Result<Vec<_>> = chunks
            .par_iter()
            .map(|chunk| {
                // One handle per chunk instead of one per file.
                let mut archive_handle = Archive::open(archive.path.as_path())?;
                let mut batch_results = Vec::with_capacity(chunk.len());
                for &filename in chunk.iter() {
                    let result = match archive_handle.read_file(filename) {
                        Ok(data) => Ok(data),
                        // skip_errors: record the failure, keep going.
                        Err(e) if config.skip_errors => Err(e),
                        // fail-fast: abort the whole extraction.
                        Err(e) => return Err(e),
                    };
                    batch_results.push((filename.to_string(), result));
                }
                Ok(batch_results)
            })
            .collect();
        Ok(batch_results?.into_iter().flatten().collect())
    })
}
/// Unbatched parallel extraction: each file is read through a freshly
/// opened archive handle on a rayon worker.
///
/// With `skip_errors` set, each file's read result is reported
/// individually; otherwise the first failure aborts the extraction.
///
/// # Errors
/// Fails if the archive cannot be opened, the thread pool cannot be
/// built, or (without `skip_errors`) any file cannot be read.
fn extract_with_config_unbatched<P: AsRef<Path>>(
    archive_path: P,
    filenames: &[&str],
    config: ParallelConfig,
) -> Result<Vec<(String, Result<Vec<u8>>)>> {
    let archive = ParallelArchive::open(archive_path)?;
    // Build the pool uniformly; surface builder failures as errors
    // instead of panicking on the default-pool path.
    let mut builder = rayon::ThreadPoolBuilder::new();
    if let Some(threads) = config.num_threads {
        builder = builder.num_threads(threads);
    }
    let pool = builder.build().map_err(|e| {
        Error::Io(std::io::Error::other(format!(
            "Failed to create thread pool: {e}"
        )))
    })?;
    pool.install(|| {
        if config.skip_errors {
            // Capture every per-file outcome; never abort.
            Ok(filenames
                .par_iter()
                .map(|&filename| {
                    let result = archive.read_file_with_new_handle(filename);
                    (filename.to_string(), result)
                })
                .collect())
        } else {
            // Fail fast: the first read error cancels the collection.
            let results: Result<Vec<_>> = filenames
                .par_iter()
                .map(|&filename| {
                    let data = archive.read_file_with_new_handle(filename)?;
                    Ok((filename.to_string(), Ok(data)))
                })
                .collect();
            results
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ArchiveBuilder;
    use tempfile::TempDir;

    /// Builds a throwaway archive containing 20 text files.
    fn create_test_archive() -> (TempDir, PathBuf) {
        let temp = TempDir::new().unwrap();
        let path = temp.path().join("test.mpq");
        let mut builder = ArchiveBuilder::new();
        for i in 0..20 {
            let content =
                format!("File {i} content with some data to make it larger").repeat(100);
            builder = builder.add_file_data(content.into_bytes(), &format!("file_{i:02}.txt"));
        }
        builder.build(&path).unwrap();
        (temp, path)
    }

    #[test]
    fn test_parallel_extraction() {
        let (_temp, archive_path) = create_test_archive();
        let archive = ParallelArchive::open(&archive_path).unwrap();
        let files = vec!["file_00.txt", "file_05.txt", "file_10.txt", "file_15.txt"];
        let results = archive.extract_files_parallel(&files).unwrap();
        assert_eq!(results.len(), 4);
        for (filename, data) in results {
            assert!(!data.is_empty());
            assert!(files.contains(&filename.as_str()));
        }
    }

    #[test]
    fn test_extract_matching() {
        let (_temp, archive_path) = create_test_archive();
        let archive = ParallelArchive::open(&archive_path).unwrap();
        let results = archive
            .extract_matching_parallel(|name| name.ends_with("5.txt"))
            .unwrap();
        // file_05.txt and file_15.txt.
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_batched_extraction() {
        let (_temp, archive_path) = create_test_archive();
        let archive = ParallelArchive::open(&archive_path).unwrap();
        // Own the names, then borrow. The previous version used
        // Box::leak, leaking every name for the process lifetime.
        let names: Vec<String> = (0..10).map(|i| format!("file_{i:02}.txt")).collect();
        let files: Vec<&str> = names.iter().map(String::as_str).collect();
        let results = archive.extract_files_batched(&files, 3).unwrap();
        assert_eq!(results.len(), 10);
    }

    #[test]
    fn test_custom_processing() {
        let (_temp, archive_path) = create_test_archive();
        let archive = ParallelArchive::open(&archive_path).unwrap();
        let files = vec!["file_00.txt", "file_01.txt"];
        let sizes = archive
            .process_files_parallel(&files, |_name, data| Ok(data.len()))
            .unwrap();
        assert_eq!(sizes.len(), 2);
        for size in sizes {
            assert!(size > 0);
        }
    }

    #[test]
    fn test_with_config() {
        let (_temp, archive_path) = create_test_archive();
        let config = ParallelConfig::new()
            .threads(2)
            .batch_size(5)
            .skip_errors(true);
        let files = vec!["file_00.txt", "nonexistent.txt", "file_01.txt"];
        let results = extract_with_config(&archive_path, &files, config).unwrap();
        assert_eq!(results.len(), 3);
        assert!(results[0].1.is_ok());
        assert!(results[1].1.is_err());
        assert!(results[2].1.is_ok());
    }
}