use super::{FileReader, MatchResult, WorkUnit, Worker, WorkerStats};
use crossbeam_channel::{bounded, Sender};
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
const LARGE_FILE: u64 = 1024 * 1024 * 1024;
const HUGE_FILE: u64 = 10 * 1024 * 1024 * 1024;
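// Maximum queued whole-file units per worker; the routing loop blocks while
// the total queued depth is at or above `MAX_QUEUE_PER_WORKER * num_workers`.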
const MAX_QUEUE_PER_WORKER: usize = 2;
#[derive(Debug)]
struct SystemState {
worker_queue_depth: AtomicUsize,
reader_queue_depth: AtomicUsize,
files_completed_recent: AtomicUsize,
chunks_processed_recent: AtomicUsize,
last_completion_ns: AtomicU64,
num_workers: usize,
}
impl SystemState {
fn new(num_workers: usize, _num_readers: usize) -> Self {
Self {
worker_queue_depth: AtomicUsize::new(0),
reader_queue_depth: AtomicUsize::new(0),
files_completed_recent: AtomicUsize::new(0),
chunks_processed_recent: AtomicUsize::new(0),
last_completion_ns: AtomicU64::new(0),
num_workers,
}
}
fn has_routing_capacity(&self) -> bool {
let worker_depth = self.worker_queue_depth.load(Ordering::Relaxed);
worker_depth < (MAX_QUEUE_PER_WORKER * self.num_workers)
}
fn record_file_completion(&self) {
self.files_completed_recent.fetch_add(1, Ordering::Relaxed);
// Wall-clock nanoseconds since the Unix epoch; an `Instant` cannot be
// stored in an atomic directly, and `Instant::now().elapsed()` would
// always be ~0 ns.
let now_ns = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_or(u64::MAX, |d| u64::try_from(d.as_nanos()).unwrap_or(u64::MAX));
self.last_completion_ns.store(now_ns, Ordering::Relaxed);
}
fn record_chunk_processed(&self) {
self.chunks_processed_recent.fetch_add(1, Ordering::Relaxed);
}
fn inc_worker_queue(&self) {
self.worker_queue_depth.fetch_add(1, Ordering::Relaxed);
}
fn dec_worker_queue(&self) {
self.worker_queue_depth.fetch_sub(1, Ordering::Relaxed);
}
fn inc_reader_queue(&self) {
self.reader_queue_depth.fetch_add(1, Ordering::Relaxed);
}
fn dec_reader_queue(&self) {
self.reader_queue_depth.fetch_sub(1, Ordering::Relaxed);
}
}
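/// Selects a read-chunk size from the file size. Compressed inputs use much
/// larger chunks (4-32 MiB) than plain files (256 KiB-4 MiB), since the data
/// expands during decompression.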
fn chunk_size_for(file_size: u64, is_compressed: bool) -> usize {
if is_compressed {
match file_size {
s if s < LARGE_FILE => 4 * 1024 * 1024,
s if s < HUGE_FILE => 16 * 1024 * 1024,
_ => 32 * 1024 * 1024,
}
} else {
match file_size {
s if s < LARGE_FILE => 256 * 1024,
s if s < HUGE_FILE => 1024 * 1024,
_ => 4 * 1024 * 1024,
}
}
}
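/// Reads one file in chunks and forwards each batch over `work_sender`.
/// stdin ("-") cannot be stat'd, so it always uses a fixed 256 KiB chunk.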
fn reader_thread_chunker(file_path: &Path, work_sender: &Sender<WorkUnit>) -> Result<(), String> {
let is_stdin = file_path.to_str() == Some("-");
let chunk_size = if is_stdin {
256 * 1024
} else {
let metadata = fs::metadata(file_path)
.map_err(|e| format!("Failed to stat {}: {}", file_path.display(), e))?;
let file_size = metadata.len();
let is_compressed = is_file_compressed(file_path);
chunk_size_for(file_size, is_compressed)
};
let mut reader = FileReader::new(file_path, chunk_size)
.map_err(|e| format!("Failed to open {}: {}", file_path.display(), e))?;
while let Some(batch) = reader
.next_batch()
.map_err(|e| format!("Read error in {}: {}", file_path.display(), e))?
{
work_sender
.send(WorkUnit::Chunk { batch })
.map_err(|_| "Worker channel closed")?;
}
Ok(())
}
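/// Metadata collected once per input file so routing decisions don't re-stat.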
#[derive(Debug, Clone)]
struct FileInfo {
path: PathBuf,
size: u64,
is_stdin: bool,
is_compressed: bool,
}
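/// Size-distribution summary over the non-stdin inputs.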
#[derive(Debug, Clone)]
struct WorkloadStats {
median_size: u64,
p95_size: u64,
total_bytes: u64,
}
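/// How many files (and bytes) were routed to workers as whole files versus
/// to readers for chunking.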
#[derive(Debug, Clone, Default)]
pub struct RoutingStats {
pub files_to_workers: usize,
pub files_to_readers: usize,
pub bytes_to_workers: u64,
pub bytes_to_readers: u64,
}
impl RoutingStats {
#[must_use]
pub fn total_files(&self) -> usize {
self.files_to_workers + self.files_to_readers
}
#[must_use]
pub fn total_bytes(&self) -> u64 {
self.bytes_to_workers + self.bytes_to_readers
}
}
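/// Everything `process_files_parallel` produces: the combined matches,
/// routing and worker statistics, and the thread counts actually used.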
pub struct ParallelProcessingResult {
pub matches: Vec<MatchResult>,
pub routing_stats: RoutingStats,
pub worker_stats: WorkerStats,
pub actual_readers: usize,
pub actual_workers: usize,
}
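/// Detects compressed inputs by file extension; the string-suffix fallback
/// only fires when the path has no parseable final extension.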
fn is_file_compressed(path: &Path) -> bool {
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
matches!(
ext.to_lowercase().as_str(),
"gz" | "bz2" | "xz" | "zst" | "lz4" | "lzma" | "z"
)
} else {
path.to_str()
.map(|s| {
s.ends_with(".tar.gz")
|| s.ends_with(".tar.bz2")
|| s.ends_with(".tar.xz")
|| s.ends_with(".tar.zst")
})
.unwrap_or(false)
}
}
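/// Stats each input up front; stdin is recorded with size 0 because its
/// length is unknown until read.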
fn collect_file_metadata(files: &[PathBuf]) -> Result<Vec<FileInfo>, String> {
let mut file_infos = Vec::with_capacity(files.len());
for path in files {
let is_stdin = path.to_str() == Some("-");
let size = if is_stdin {
0
} else {
fs::metadata(path)
.map_err(|e| format!("Failed to stat {}: {}", path.display(), e))?
.len()
};
let is_compressed = !is_stdin && is_file_compressed(path);
file_infos.push(FileInfo {
path: path.clone(),
size,
is_stdin,
is_compressed,
});
}
Ok(file_infos)
}
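/// Computes the median and P95 of the non-stdin sizes: with `n` sorted sizes,
/// the median is `sizes[n / 2]` (the upper median for even `n`) and the P95
/// is `sizes[min(n * 95 / 100, n - 1)]`.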
fn compute_workload_stats(file_infos: &[FileInfo]) -> WorkloadStats {
let mut sizes: Vec<u64> = file_infos
.iter()
.filter(|f| !f.is_stdin)
.map(|f| f.size)
.collect();
if sizes.is_empty() {
return WorkloadStats {
median_size: 0,
p95_size: 0,
total_bytes: 0,
};
}
sizes.sort_unstable();
let median_size = sizes[sizes.len() / 2];
let p95_idx = sizes.len() * 95 / 100;
let p95_size = sizes[p95_idx.min(sizes.len() - 1)];
let total_bytes: u64 = sizes.iter().sum();
WorkloadStats {
median_size,
p95_size,
total_bytes,
}
}
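/// Returns `true` if the file should be chunked by a dedicated reader, or
/// `false` if a worker should consume it whole. The thresholds relax as the
/// queue drains, matching the four scenarios logged by the routing loop:
///
/// 1. Many files left (> 2x workers): whole files already saturate workers;
///    only chunk compressed files over 200 MB.
/// 2. Moderate (> workers): also chunk huge outliers (> 10x median and
///    > 500 MB); the compressed threshold drops to 100 MB.
/// 3. Few (> 3): chunk compressed files over 50 MB, plus large files worth
///    splitting (>= 100 MB and above the P95 or 5x the median).
/// 4. Last few: chunk stragglers (> 2x median and > 300 MB) and files over
///    1 GB when the median is under 1 GB, so the tail doesn't run serially.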
fn decide_routing(
files_remaining: usize,
num_workers: usize,
file_size: u64,
is_compressed: bool,
stats: &WorkloadStats,
) -> bool {
if files_remaining > num_workers * 2 {
if is_compressed && file_size > 200 * 1024 * 1024 {
return true;
}
return false;
}
if files_remaining > num_workers {
if is_compressed && file_size > 100 * 1024 * 1024 {
return true;
}
let is_huge_outlier =
file_size > stats.median_size.saturating_mul(10) && file_size > 500 * 1024 * 1024;
return is_huge_outlier;
}
if files_remaining > 3 {
if is_compressed && file_size > 50 * 1024 * 1024 {
return true;
}
let is_large =
file_size > stats.p95_size || file_size > stats.median_size.saturating_mul(5);
let is_worth_chunking = file_size >= 100 * 1024 * 1024;
return is_large && is_worth_chunking;
}
if is_compressed && file_size > 50 * 1024 * 1024 {
return true;
}
let is_straggler =
file_size > stats.median_size.saturating_mul(2) && file_size > 300 * 1024 * 1024;
let is_huge_with_small_median =
file_size > 1024 * 1024 * 1024 && stats.median_size < 1024 * 1024 * 1024;
is_straggler || is_huge_with_small_median
}
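/// Predicts how many files the routing loop will hand to readers by replaying
/// `decide_routing` over the file list; used to size the reader pool.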
fn count_files_to_chunk(
file_infos: &[FileInfo],
workload_stats: &WorkloadStats,
num_workers: usize,
) -> usize {
let file_count = file_infos.len();
let mut count = 0;
for (idx, file_info) in file_infos.iter().enumerate() {
if file_info.is_stdin {
count += 1;
continue;
}
let files_remaining = file_count - idx;
if decide_routing(
files_remaining,
num_workers,
file_info.size,
file_info.is_compressed,
workload_stats,
) {
count += 1;
}
}
count
}
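/// Executes one unit of work: whole files are read locally in 128 KiB
/// batches, while pre-chunked batches from reader threads are processed
/// directly.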
fn process_work_unit_with_worker(
unit: &WorkUnit,
worker: &mut Worker,
) -> Result<Vec<MatchResult>, String> {
match unit {
WorkUnit::WholeFile { path } => {
let mut reader = FileReader::new(path, 128 * 1024)
.map_err(|e| format!("Failed to open {}: {}", path.display(), e))?;
let mut all_matches = Vec::new();
while let Some(batch) = reader
.next_batch()
.map_err(|e| format!("Read error in {}: {}", path.display(), e))?
{
let matches = worker.process_batch(&batch)?;
all_matches.extend(matches);
}
Ok(all_matches)
}
WorkUnit::Chunk { batch } => worker.process_batch(batch),
}
}
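/// Processes `files` with a pool of worker threads, optionally assisted by
/// dedicated reader threads that chunk large or compressed files onto the
/// shared work queue. `num_readers` and `num_workers` default to heuristics
/// based on the predicted chunk count and available parallelism.
///
/// A minimal usage sketch, mirroring the unit tests below; the `threat.db`
/// path and single-file input are illustrative, and imports are elided:
///
/// ```ignore
/// let files = vec![PathBuf::from("access.log")];
/// let result = process_files_parallel(
///     &files,
///     None, // let the heuristic size the reader pool
///     None, // default to available parallelism
///     || {
///         let db = Database::from("threat.db").open().map_err(|e| e.to_string())?;
///         let extractor = Extractor::new().map_err(|e| e.to_string())?;
///         Ok(Worker::builder()
///             .extractor(extractor)
///             .add_database("threat", Arc::new(db))
///             .build())
///     },
///     None::<fn(&WorkerStats)>,
///     false, // debug_routing
/// )
/// .unwrap();
/// println!("{} matches", result.matches.len());
/// ```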
pub fn process_files_parallel<F, P>(
files: &[PathBuf],
num_readers: Option<usize>,
num_workers: Option<usize>,
create_worker: F,
progress_callback: Option<P>,
debug_routing: bool,
) -> Result<ParallelProcessingResult, String>
where
F: Fn() -> Result<Worker, String> + Sync + Send + 'static,
P: Fn(&WorkerStats) + Sync + Send + 'static,
{
let num_cpus = thread::available_parallelism()
.map(std::num::NonZero::get)
.unwrap_or(4);
let num_workers = num_workers.unwrap_or(num_cpus);
let file_infos = collect_file_metadata(files)?;
let workload_stats = compute_workload_stats(&file_infos);
let file_count = file_infos.len();
let files_to_chunk = count_files_to_chunk(&file_infos, &workload_stats, num_workers);
if debug_routing {
eprintln!("\n[DEBUG] === Routing Analysis ===");
eprintln!("[DEBUG] Workload statistics:");
eprintln!("[DEBUG] Total files: {file_count}");
eprintln!(
"[DEBUG] Median size: {} bytes ({:.2} MB)",
workload_stats.median_size,
workload_stats.median_size as f64 / (1024.0 * 1024.0)
);
eprintln!(
"[DEBUG] P95 size: {} bytes ({:.2} MB)",
workload_stats.p95_size,
workload_stats.p95_size as f64 / (1024.0 * 1024.0)
);
eprintln!(
"[DEBUG] Total bytes: {} ({:.2} GB)",
workload_stats.total_bytes,
workload_stats.total_bytes as f64 / (1024.0 * 1024.0 * 1024.0)
);
eprintln!("[DEBUG] Workers: {num_workers}");
eprintln!("[DEBUG] Predicted files to chunk: {files_to_chunk}");
eprintln!();
}
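// Reader pool sizing: no readers when nothing needs chunking, one for up to
// three chunked files, two for up to ten, then roughly one per ten chunked
// files, capped at a third of the worker count.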
let num_readers = num_readers.unwrap_or_else(|| {
if files_to_chunk == 0 {
0
} else if files_to_chunk <= 3 {
1
} else if files_to_chunk <= 10 {
2
} else {
(files_to_chunk / 10).max(2).min(num_workers / 3)
}
});
const MAX_WORK_QUEUE_SIZE: usize = 32;
const MAX_FILE_QUEUE_SIZE: usize = 16;
let file_queue_size = (num_readers.max(1) * 2).min(MAX_FILE_QUEUE_SIZE);
let work_queue_size = (num_workers * MAX_QUEUE_PER_WORKER).min(MAX_WORK_QUEUE_SIZE);
let (file_sender, file_receiver) = bounded::<PathBuf>(file_queue_size);
let (work_sender, work_receiver) = bounded::<WorkUnit>(work_queue_size);
let worker_factory = Arc::new(create_worker);
let progress_callback = progress_callback.map(Arc::new);
let worker_stats_map = Arc::new(Mutex::new(
std::collections::HashMap::<usize, WorkerStats>::new(),
));
let system_state = Arc::new(SystemState::new(num_workers, num_readers));
let mut reader_handles = Vec::new();
if num_readers > 0 {
for _reader_id in 0..num_readers {
let file_rx = file_receiver.clone();
let work_tx = work_sender.clone();
let state = Arc::clone(&system_state);
let handle = thread::spawn(move || {
while let Ok(file_path) = file_rx.recv() {
state.dec_reader_queue();
if let Err(e) = reader_thread_chunker(&file_path, &work_tx) {
eprintln!("Reader error: {e}");
}
state.record_chunk_processed();
}
});
reader_handles.push(handle);
}
}
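// Worker threads: each builds its own `Worker` from the factory and drains
// the shared work queue until all senders (router and readers) are dropped.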
let mut worker_handles = Vec::new();
for worker_id in 0..num_workers {
let receiver = work_receiver.clone();
let factory = Arc::clone(&worker_factory);
let state = Arc::clone(&system_state);
let progress_cb = progress_callback.clone();
let stats_map = Arc::clone(&worker_stats_map);
let handle = thread::spawn(move || -> (Vec<MatchResult>, WorkerStats) {
let mut worker = match factory() {
Ok(w) => w,
Err(e) => {
eprintln!("Worker creation failed: {e}");
return (Vec::new(), WorkerStats::default());
}
};
let mut local_matches = Vec::new();
let mut last_progress = Instant::now();
let progress_interval = Duration::from_millis(100);
while let Ok(unit) = receiver.recv() {
// Only `WholeFile` units increment the routing depth counter; chunks from
// reader threads do not, so decrementing for them would underflow the
// counter and permanently stall `has_routing_capacity`.
if matches!(unit, WorkUnit::WholeFile { .. }) {
state.dec_worker_queue();
}
match process_work_unit_with_worker(&unit, &mut worker) {
Ok(matches) => {
local_matches.extend(matches);
}
Err(e) => {
eprintln!("Processing error: {e}");
}
}
if matches!(unit, WorkUnit::WholeFile { .. }) {
state.record_file_completion();
}
if let Some(ref cb) = progress_cb {
let now = Instant::now();
if now.duration_since(last_progress) >= progress_interval {
stats_map
.lock()
.unwrap()
.insert(worker_id, worker.stats().clone());
let aggregated = {
let map = stats_map.lock().unwrap();
let mut agg = WorkerStats::default();
for stats in map.values() {
agg.lines_processed += stats.lines_processed;
agg.candidates_tested += stats.candidates_tested;
agg.matches_found += stats.matches_found;
agg.total_bytes += stats.total_bytes;
agg.extraction_time += stats.extraction_time;
agg.extraction_samples += stats.extraction_samples;
agg.lookup_time += stats.lookup_time;
agg.lookup_samples += stats.lookup_samples;
agg.ipv4_count += stats.ipv4_count;
agg.ipv6_count += stats.ipv6_count;
agg.domain_count += stats.domain_count;
agg.email_count += stats.email_count;
}
agg
};
cb(&aggregated);
last_progress = now;
}
}
}
let stats = worker.stats().clone();
(local_matches, stats)
});
worker_handles.push(handle);
}
let mut routing_stats = RoutingStats::default();
if debug_routing {
eprintln!("[DEBUG] === Per-File Routing Decisions ===");
}
for (idx, file_info) in file_infos.iter().enumerate() {
let files_remaining = file_count - idx;
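// Backpressure: hold off routing more files while worker queues are full.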
while !system_state.has_routing_capacity() {
thread::sleep(Duration::from_millis(10));
}
if file_info.is_stdin && num_readers > 0 {
// stdin must be chunked by a reader (its size is unknown); with no
// readers it falls through below and is handed to a worker whole.
routing_stats.files_to_readers += 1;
if debug_routing {
eprintln!(
"[DEBUG] File {}: {} (stdin) → READER (unknown size, always chunk)",
idx,
file_info.path.display()
);
}
system_state.inc_reader_queue();
file_sender
.send(file_info.path.clone())
.map_err(|_| "File queue closed unexpectedly")?;
} else {
let should_chunk = decide_routing(
files_remaining,
num_workers,
file_info.size,
file_info.is_compressed,
&workload_stats,
);
let scenario = if files_remaining > num_workers * 2 {
"Scenario 1: many files"
} else if files_remaining > num_workers {
"Scenario 2: moderate files"
} else if files_remaining > 3 {
"Scenario 3: few files"
} else {
"Scenario 4: last few files (straggler detection)"
};
if should_chunk && num_readers > 0 {
routing_stats.files_to_readers += 1;
routing_stats.bytes_to_readers += file_info.size;
if debug_routing {
let size_mb = file_info.size as f64 / (1024.0 * 1024.0);
let vs_median =
file_info.size as f64 / workload_stats.median_size.max(1) as f64;
eprintln!(
"[DEBUG] File {}: {} ({:.1} MB, {:.1}x median) → READER ({})",
idx,
file_info.path.display(),
size_mb,
vs_median,
scenario
);
}
system_state.inc_reader_queue();
file_sender
.send(file_info.path.clone())
.map_err(|_| "File queue closed unexpectedly")?;
} else {
routing_stats.files_to_workers += 1;
routing_stats.bytes_to_workers += file_info.size;
if debug_routing {
let size_mb = file_info.size as f64 / (1024.0 * 1024.0);
let vs_median =
file_info.size as f64 / workload_stats.median_size.max(1) as f64;
eprintln!(
"[DEBUG] File {}: {} ({:.1} MB, {:.1}x median) → WORKER ({})",
idx,
file_info.path.display(),
size_mb,
vs_median,
scenario
);
}
system_state.inc_worker_queue();
work_sender
.send(WorkUnit::WholeFile {
path: file_info.path.clone(),
})
.map_err(|_| "Work queue closed unexpectedly")?;
}
}
}
if debug_routing {
eprintln!("\n[DEBUG] === Routing Summary ===");
eprintln!("[DEBUG] Readers spawned: {num_readers}");
eprintln!(
"[DEBUG] Files to workers: {}",
routing_stats.files_to_workers
);
eprintln!(
"[DEBUG] Files to readers: {}",
routing_stats.files_to_readers
);
eprintln!();
}
drop(file_sender);
for handle in reader_handles {
if let Err(e) = handle.join() {
eprintln!("Reader thread panicked: {e:?}");
}
}
drop(work_sender);
let mut all_matches = Vec::new();
let mut aggregate_stats = WorkerStats::default();
for handle in worker_handles {
match handle.join() {
Ok((matches, stats)) => {
all_matches.extend(matches);
aggregate_stats.lines_processed += stats.lines_processed;
aggregate_stats.candidates_tested += stats.candidates_tested;
aggregate_stats.matches_found += stats.matches_found;
aggregate_stats.total_bytes += stats.total_bytes;
aggregate_stats.extraction_time += stats.extraction_time;
aggregate_stats.extraction_samples += stats.extraction_samples;
aggregate_stats.lookup_time += stats.lookup_time;
aggregate_stats.lookup_samples += stats.lookup_samples;
aggregate_stats.ipv4_count += stats.ipv4_count;
aggregate_stats.ipv6_count += stats.ipv6_count;
aggregate_stats.domain_count += stats.domain_count;
aggregate_stats.email_count += stats.email_count;
}
Err(e) => {
eprintln!("Worker thread panicked: {e:?}");
}
}
}
Ok(ParallelProcessingResult {
matches: all_matches,
routing_stats,
worker_stats: aggregate_stats,
actual_readers: num_readers,
actual_workers: num_workers,
})
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_chunk_size_selection() {
assert_eq!(chunk_size_for(500 * 1024 * 1024, false), 256 * 1024);
assert_eq!(chunk_size_for(5 * 1024 * 1024 * 1024, false), 1024 * 1024);
assert_eq!(
chunk_size_for(50 * 1024 * 1024 * 1024, false),
4 * 1024 * 1024
);
assert_eq!(chunk_size_for(500 * 1024 * 1024, true), 4 * 1024 * 1024);
assert_eq!(
chunk_size_for(5 * 1024 * 1024 * 1024, true),
16 * 1024 * 1024
);
assert_eq!(
chunk_size_for(50 * 1024 * 1024 * 1024, true),
32 * 1024 * 1024
);
}
#[test]
fn test_routing_scenario_many_huge_files() {
let num_workers = 8;
let stats = WorkloadStats {
median_size: 5 * 1024 * 1024 * 1024,
p95_size: 8 * 1024 * 1024 * 1024,
total_bytes: 5000 * 1024 * 1024 * 1024,
};
for i in 0..950 {
let files_remaining = 1000 - i;
let should_chunk = decide_routing(
files_remaining,
num_workers,
5 * 1024 * 1024 * 1024,
false,
&stats,
);
assert!(
!should_chunk,
"File {i} (remaining={files_remaining}) should NOT chunk with many files"
);
}
for i in 950..997 {
let files_remaining = 1000 - i;
let should_chunk = decide_routing(
files_remaining,
num_workers,
5 * 1024 * 1024 * 1024,
false,
&stats,
);
assert!(
!should_chunk,
"File {i} (remaining={files_remaining}) should NOT chunk (not an outlier)"
);
}
for i in 997..1000 {
let files_remaining = 1000 - i;
let should_chunk_normal = decide_routing(
files_remaining,
num_workers,
5 * 1024 * 1024 * 1024,
false,
&stats,
);
let should_chunk_large = decide_routing(
files_remaining,
num_workers,
12 * 1024 * 1024 * 1024,
false,
&stats,
);
assert!(
!should_chunk_normal,
"File {i} (remaining={files_remaining}, 5GB) should NOT chunk (< 2x median)"
);
assert!(
should_chunk_large,
"File {i} (remaining={files_remaining}, 12GB) SHOULD chunk (> 2x median)"
);
}
}
#[test]
fn test_routing_scenario_journal_logs_with_tarball() {
let num_workers = 8;
let stats = WorkloadStats {
median_size: 209_715_200, // 200 MiB
p95_size: 209_715_200,
total_bytes: 3_766_210_481, // ~3.5 GiB
};
for i in 0..15 {
let files_remaining = 16 - i;
let should_chunk = decide_routing(
files_remaining,
num_workers,
209_715_200,
false,
&stats,
);
assert!(
!should_chunk,
"File {i} (remaining={files_remaining}) should NOT chunk (many files)"
);
}
let should_chunk_tarball = decide_routing(
1,
num_workers,
616_354_689, // ~600 MB, ~3x median
false,
&stats,
);
assert!(
should_chunk_tarball,
"Last file (600MB, 3x median) SHOULD chunk to avoid straggler"
);
}
#[test]
fn test_routing_scenario_five_large_files_with_outlier() {
let num_workers = 16;
let stats = WorkloadStats {
median_size: 120 * 1024 * 1024,
p95_size: 130 * 1024 * 1024,
total_bytes: 1800 * 1024 * 1024,
};
for i in 0..4 {
let files_remaining = 5 - i;
let should_chunk = decide_routing(
files_remaining,
num_workers,
120 * 1024 * 1024,
false,
&stats,
);
assert!(
!should_chunk,
"File {i} (remaining={files_remaining}, 120MB) should NOT chunk"
);
}
let should_chunk_outlier = decide_routing(
1,
num_workers,
1346 * 1024 * 1024,
false,
&stats,
);
assert!(
should_chunk_outlier,
"Last file (1.3GB outlier) SHOULD chunk (> 2x median)"
);
}
#[test]
fn test_routing_scenario_single_massive_file() {
let num_workers = 16;
let stats = WorkloadStats {
median_size: 100 * 1024 * 1024 * 1024,
p95_size: 100 * 1024 * 1024 * 1024,
total_bytes: 100 * 1024 * 1024 * 1024,
};
let should_chunk = decide_routing(
1,
num_workers,
100 * 1024 * 1024 * 1024,
false,
&stats,
);
assert!(
!should_chunk,
"Single file where median=size doesn't chunk (use --readers to override)"
);
}
#[test]
fn test_routing_scenario_many_small_files() {
let num_workers = 16;
let stats = WorkloadStats {
median_size: 1024 * 1024,
p95_size: 1024 * 1024,
total_bytes: 10000 * 1024 * 1024,
};
for i in 0..10000 {
let files_remaining = 10000 - i;
let should_chunk = decide_routing(
files_remaining,
num_workers,
1024 * 1024,
false,
&stats,
);
assert!(
!should_chunk,
"File {i} should NOT chunk (many small files)"
);
}
}
#[test]
fn test_routing_scenario_moderate_outlier_in_middle() {
let num_workers = 8;
let stats = WorkloadStats {
median_size: 100 * 1024 * 1024,
p95_size: 100 * 1024 * 1024,
total_bytes: 5000 * 1024 * 1024,
};
for i in 0..33 {
let files_remaining = 50 - i;
let should_chunk = decide_routing(
files_remaining,
num_workers,
100 * 1024 * 1024,
false,
&stats,
);
assert!(
!should_chunk,
"File {i} (remaining={files_remaining}) should NOT chunk (many remaining)"
);
}
let should_chunk_outlier = decide_routing(
25,
num_workers,
5 * 1024 * 1024 * 1024,
false,
&stats,
);
assert!(
!should_chunk_outlier,
"Outlier with many files remaining (25 > 16) should NOT chunk (Scenario 1)"
);
}
#[test]
fn test_routing_count_files_to_chunk() {
let num_workers = 8;
let mut file_infos = Vec::new();
for _ in 0..50 {
file_infos.push(FileInfo {
path: PathBuf::from("file.log"),
size: 200 * 1024 * 1024,
is_stdin: false,
is_compressed: false,
});
}
file_infos.push(FileInfo {
path: PathBuf::from("huge.tar.gz"),
size: 2 * 1024 * 1024 * 1024,
is_stdin: false,
is_compressed: true,
});
let workload_stats = compute_workload_stats(&file_infos);
let files_to_chunk = count_files_to_chunk(&file_infos, &workload_stats, num_workers);
assert_eq!(
files_to_chunk, 1,
"Should chunk exactly 1 file (the outlier)"
);
}
#[test]
fn test_process_files_parallel_basic() {
use crate::extractor::Extractor;
use crate::{DatabaseBuilder, MatchMode};
use std::collections::HashMap;
let mut builder = DatabaseBuilder::new(MatchMode::CaseSensitive);
let data = HashMap::new();
builder.add_ip("192.168.1.1", data).unwrap();
let db_bytes = builder.build().unwrap();
let mut db_file = NamedTempFile::new().unwrap();
db_file.write_all(&db_bytes).unwrap();
db_file.flush().unwrap();
let db_path = db_file.path().to_path_buf();
let mut test_file = NamedTempFile::new().unwrap();
writeln!(test_file, "Connection from 192.168.1.1").unwrap();
writeln!(test_file, "Another line").unwrap();
test_file.flush().unwrap();
let files = vec![test_file.path().to_path_buf()];
let result = process_files_parallel(
&files,
Some(1), // one reader thread
Some(2), // two worker threads
move || {
let db = crate::Database::from(db_path.to_str().unwrap())
.open()
.map_err(|e| e.to_string())?;
let extractor = Extractor::new().map_err(|e| e.to_string())?;
Ok(Worker::builder()
.extractor(extractor)
.add_database("test", Arc::new(db))
.build())
},
None::<fn(&WorkerStats)>,
false,
)
.unwrap();
assert_eq!(result.matches.len(), 1);
assert_eq!(result.matches[0].matched_text, "192.168.1.1");
}
#[test]
fn test_process_files_parallel_multiple_files() {
use crate::extractor::Extractor;
use crate::{DatabaseBuilder, MatchMode};
use std::collections::HashMap;
let mut builder = DatabaseBuilder::new(MatchMode::CaseSensitive);
let data = HashMap::new();
builder.add_ip("10.0.0.1", data.clone()).unwrap();
builder.add_ip("10.0.0.2", data).unwrap();
let db_bytes = builder.build().unwrap();
let mut db_file = NamedTempFile::new().unwrap();
db_file.write_all(&db_bytes).unwrap();
db_file.flush().unwrap();
let db_path = db_file.path().to_path_buf();
let mut file1 = NamedTempFile::new().unwrap();
writeln!(file1, "IP: 10.0.0.1").unwrap();
file1.flush().unwrap();
let mut file2 = NamedTempFile::new().unwrap();
writeln!(file2, "IP: 10.0.0.2").unwrap();
file2.flush().unwrap();
let files = vec![file1.path().to_path_buf(), file2.path().to_path_buf()];
let result = process_files_parallel(
&files,
Some(0), // no reader threads; everything routes to workers
Some(4), // four worker threads
move || {
let db = crate::Database::from(db_path.to_str().unwrap())
.open()
.map_err(|e| e.to_string())?;
let extractor = Extractor::new().map_err(|e| e.to_string())?;
Ok(Worker::builder()
.extractor(extractor)
.add_database("test", Arc::new(db))
.build())
},
None::<fn(&WorkerStats)>,
false,
)
.unwrap();
assert_eq!(result.matches.len(), 2);
let matched_texts: Vec<&str> = result
.matches
.iter()
.map(|m| m.matched_text.as_str())
.collect();
assert!(matched_texts.contains(&"10.0.0.1"));
assert!(matched_texts.contains(&"10.0.0.2"));
assert_eq!(result.routing_stats.total_files(), 2);
}
#[test]
fn test_bounded_channel_backpressure() {
use crossbeam_channel::bounded;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
use std::time::Duration;
let channel_capacity = 8;
let (sender, receiver) = bounded::<Vec<u8>>(channel_capacity);
let max_channel_len = Arc::new(AtomicUsize::new(0));
let total_items = 200;
let num_consumers = 2;
let mut consumer_handles = Vec::new();
for _ in 0..num_consumers {
let rx = receiver.clone();
let max_len = Arc::clone(&max_channel_len);
let handle = thread::spawn(move || {
let mut count = 0;
while let Ok(_item) = rx.recv() {
let len = rx.len();
let mut max = max_len.load(Ordering::Relaxed);
while len > max {
match max_len.compare_exchange_weak(
max,
len,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(m) => max = m,
}
}
thread::sleep(Duration::from_millis(2));
count += 1;
}
count
});
consumer_handles.push(handle);
}
drop(receiver);
let num_producers = 4;
let items_per_producer = total_items / num_producers;
let mut producer_handles = Vec::new();
for _ in 0..num_producers {
let tx = sender.clone();
let max_len = Arc::clone(&max_channel_len);
let handle = thread::spawn(move || {
for _ in 0..items_per_producer {
let chunk = vec![0u8; 1024];
let len = tx.len();
let mut max = max_len.load(Ordering::Relaxed);
while len > max {
match max_len.compare_exchange_weak(
max,
len,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(m) => max = m,
}
}
if tx.send(chunk).is_err() {
break;
}
}
});
producer_handles.push(handle);
}
drop(sender);
for handle in producer_handles {
handle.join().unwrap();
}
let total_consumed: usize = consumer_handles
.into_iter()
.map(|h| h.join().unwrap())
.sum();
assert_eq!(total_consumed, total_items);
let observed_max_len = max_channel_len.load(Ordering::Relaxed);
assert!(
observed_max_len <= channel_capacity,
"Channel exceeded capacity: max observed length ({observed_max_len}) > capacity ({channel_capacity}). \
With unbounded channels this would grow to {total_items}."
);
assert!(
observed_max_len >= channel_capacity / 2,
"Test may not have applied enough pressure: max channel length ({observed_max_len}) \
was less than half capacity ({channel_capacity}). Increase total_items or slow down consumers."
);
}
}