use crate::extractor::{ExtractedItem, Extractor, HashType};
use crate::{Database, QueryResult};
use std::io::{self, BufRead, Read};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;
#[cfg(not(target_family = "wasm"))]
mod parallel;
#[cfg(not(target_family = "wasm"))]
pub use parallel::{process_files_parallel, ParallelProcessingResult, RoutingStats};
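/// A unit of work for the processing pipeline: either a whole file the
/// worker should read itself, or a chunk of bytes that was read elsewhere.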
#[derive(Clone)]
pub enum WorkUnit {
WholeFile {
path: PathBuf,
},
Chunk {
batch: DataBatch,
},
}
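/// A newline-aligned run of bytes read from a source file. The payload is
/// reference-counted so batches can be shared across threads without copying.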
#[derive(Clone)]
pub struct DataBatch {
pub source: PathBuf,
pub data: Arc<Vec<u8>>,
}
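/// Counters and timing data accumulated by a [`Worker`]. The per-type counts
/// track how many candidates of each indicator kind were extracted; the
/// timing fields hold sampled measurements, not exhaustive ones.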
#[derive(Default, Clone, Debug)]
pub struct WorkerStats {
pub lines_processed: usize,
pub candidates_tested: usize,
pub matches_found: usize,
pub total_bytes: usize,
pub extraction_time: std::time::Duration,
pub extraction_samples: usize,
pub lookup_time: std::time::Duration,
pub lookup_samples: usize,
pub ipv4_count: usize,
pub ipv6_count: usize,
pub domain_count: usize,
pub email_count: usize,
pub md5_count: usize,
pub sha1_count: usize,
pub sha256_count: usize,
pub sha384_count: usize,
pub sha512_count: usize,
pub bitcoin_count: usize,
pub ethereum_count: usize,
pub monero_count: usize,
}
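/// A single database hit: the text that matched, the indicator type, the
/// query result, which database produced it, and where the match occurred.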
#[derive(Clone, Debug)]
pub struct MatchResult {
pub matched_text: String,
pub match_type: String,
pub result: QueryResult,
pub database_id: String,
pub source: PathBuf,
pub byte_offset: usize,
}
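/// Reads a file in fixed-size chunks, yielding [`DataBatch`]es that end on a
/// newline boundary so no line is ever split across two batches.
///
/// A minimal usage sketch (the path and the 64 KiB chunk size here are
/// illustrative, not values mandated by this module):
///
/// ```ignore
/// let reader = FileReader::new("input.log", 64 * 1024)?;
/// for batch in reader.batches() {
///     let batch = batch?;
///     println!("{}: {} bytes", batch.source.display(), batch.data.len());
/// }
/// ```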
pub struct FileReader {
source_path: PathBuf,
reader: Box<dyn BufRead + Send>,
read_buffer: Vec<u8>,
eof: bool,
    leftover: Vec<u8>,
}
impl FileReader {
pub fn new<P: AsRef<Path>>(path: P, chunk_size: usize) -> io::Result<Self> {
let path = path.as_ref();
let reader = crate::file_reader::open(path)?;
Ok(Self {
source_path: path.to_path_buf(),
reader,
read_buffer: vec![0u8; chunk_size],
eof: false,
leftover: Vec::with_capacity(chunk_size),
})
}
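    /// Returns the next newline-aligned batch, or `Ok(None)` at end of file.
    ///
    /// Bytes after the last newline of a read are held back and prepended to
    /// the next read; a final unterminated line is emitted as its own batch.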
pub fn next_batch(&mut self) -> io::Result<Option<DataBatch>> {
if self.eof {
return Ok(None);
}
let bytes_read = self.reader.read(&mut self.read_buffer)?;
if bytes_read == 0 {
self.eof = true;
if !self.leftover.is_empty() {
let chunk = std::mem::take(&mut self.leftover);
return Ok(Some(DataBatch {
source: self.source_path.clone(),
data: Arc::new(chunk),
}));
}
return Ok(None);
}
let mut combined = std::mem::take(&mut self.leftover);
combined.extend_from_slice(&self.read_buffer[..bytes_read]);
        // Cut at the last newline so no line straddles two batches. If the
        // combined buffer holds no newline yet, stash it and keep reading.
        let chunk_end = if let Some(pos) = memchr::memrchr(b'\n', &combined) {
            pos + 1
        } else {
            self.leftover = combined;
            return self.next_batch();
        };
        let mut chunk = combined;
        if chunk_end < chunk.len() {
            // Bytes past the final newline carry over into the next batch.
            self.leftover = chunk.split_off(chunk_end);
        }
        chunk.truncate(chunk_end);
Ok(Some(DataBatch {
source: self.source_path.clone(),
data: Arc::new(chunk),
}))
}
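    /// Consumes the reader, returning an iterator over its batches.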
#[must_use]
pub fn batches(self) -> DataBatchIter {
DataBatchIter { reader: self }
}
}
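/// Iterator adapter over [`FileReader::next_batch`], yielding
/// `io::Result<DataBatch>` items until end of file.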
pub struct DataBatchIter {
reader: FileReader,
}
impl Iterator for DataBatchIter {
type Item = io::Result<DataBatch>;
fn next(&mut self) -> Option<Self::Item> {
match self.reader.next_batch() {
Ok(Some(batch)) => Some(Ok(batch)),
Ok(None) => None,
Err(e) => Some(Err(e)),
}
}
}
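/// Scans byte chunks for candidate indicators (IPs, domains, emails, hashes,
/// cryptocurrency addresses) and looks each one up in every attached database.
///
/// A minimal construction sketch (`extractor` and `db` stand in for values
/// built elsewhere):
///
/// ```ignore
/// let mut worker = Worker::builder()
///     .extractor(extractor)
///     .add_database("threat-intel", Arc::new(db))
///     .build();
/// let matches = worker.process_bytes(b"connection from 1.2.3.4")?;
/// ```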
pub struct Worker {
extractor: Extractor,
    databases: Vec<(String, Arc<Database>)>,
    stats: WorkerStats,
}
impl Worker {
#[must_use]
pub fn builder() -> WorkerBuilder {
WorkerBuilder::new()
}
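    /// Extracts candidates from `data`, queries every attached database, and
    /// returns all matches. The `source` field of each result is left empty;
    /// use [`Worker::process_batch`] to have it filled in.
    ///
    /// # Errors
    ///
    /// Returns an error string if a database lookup fails.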
pub fn process_bytes(&mut self, data: &[u8]) -> Result<Vec<MatchResult>, String> {
let mut results = Vec::new();
self.stats.lines_processed += memchr::memchr_iter(b'\n', data).count();
self.stats.total_bytes += data.len();
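        // Sample extraction timing only when the running candidate count is a
        // multiple of 1,000 (capped at 100,000 samples) to keep timer overhead
        // negligible.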
let should_sample_extraction = self.stats.extraction_samples < 100_000
&& self.stats.candidates_tested.is_multiple_of(1000);
let extraction_start = if should_sample_extraction {
Some(Instant::now())
} else {
None
};
let extracted = self.extractor.extract_from_chunk(data);
if let Some(start) = extraction_start {
self.stats.extraction_time += start.elapsed();
self.stats.extraction_samples += 1;
}
for item in extracted {
self.stats.candidates_tested += 1;
match &item.item {
ExtractedItem::Ipv4(_) => self.stats.ipv4_count += 1,
ExtractedItem::Ipv6(_) => self.stats.ipv6_count += 1,
ExtractedItem::Domain(_) => self.stats.domain_count += 1,
ExtractedItem::Email(_) => self.stats.email_count += 1,
ExtractedItem::Hash(hash_type, _) => match hash_type {
HashType::Md5 => self.stats.md5_count += 1,
HashType::Sha1 => self.stats.sha1_count += 1,
HashType::Sha256 => self.stats.sha256_count += 1,
HashType::Sha384 => self.stats.sha384_count += 1,
HashType::Sha512 => self.stats.sha512_count += 1,
},
ExtractedItem::Bitcoin(_) => self.stats.bitcoin_count += 1,
ExtractedItem::Ethereum(_) => self.stats.ethereum_count += 1,
ExtractedItem::Monero(_) => self.stats.monero_count += 1,
}
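            // Lookup timing is sampled on every 100th candidate, likewise
            // capped at 100,000 samples.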
let should_sample_lookup = self.stats.lookup_samples < 100_000
&& self.stats.candidates_tested.is_multiple_of(100);
for (database_id, database) in &self.databases {
let lookup_start = if should_sample_lookup {
Some(Instant::now())
} else {
None
};
let result_opt = database
.lookup_extracted(&item, data)
.map_err(|e| e.to_string())?;
if let Some(start) = lookup_start {
self.stats.lookup_time += start.elapsed();
self.stats.lookup_samples += 1;
}
if let Some(query_result) = result_opt {
if matches!(query_result, crate::QueryResult::NotFound) {
continue;
}
self.stats.matches_found += 1;
let matched_text = item.as_str(data).to_string();
results.push(MatchResult {
matched_text,
match_type: item.item.type_name().to_string(),
result: query_result,
database_id: database_id.clone(),
source: PathBuf::from(""), byte_offset: item.span.0,
});
}
}
}
Ok(results)
}
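    /// Processes a [`DataBatch`] and stamps each match with the batch's
    /// source path.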
pub fn process_batch(&mut self, batch: &DataBatch) -> Result<Vec<MatchResult>, String> {
let mut match_results = self.process_bytes(&batch.data)?;
for m in &mut match_results {
m.source = batch.source.clone();
}
Ok(match_results)
}
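    /// Returns the statistics accumulated since construction or the last
    /// call to [`Worker::reset_stats`].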
#[must_use]
pub fn stats(&self) -> &WorkerStats {
&self.stats
}
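    /// Resets all counters and timing samples to their defaults.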
pub fn reset_stats(&mut self) {
self.stats = WorkerStats::default();
}
}
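/// Builder for [`Worker`]. An extractor and at least one database are
/// required; [`WorkerBuilder::build`] panics if either is missing.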
pub struct WorkerBuilder {
extractor: Option<Extractor>,
databases: Vec<(String, Arc<Database>)>,
}
impl WorkerBuilder {
#[must_use]
pub fn new() -> Self {
Self {
extractor: None,
databases: Vec::new(),
}
}
#[must_use]
pub fn extractor(mut self, extractor: Extractor) -> Self {
self.extractor = Some(extractor);
self
}
#[must_use]
pub fn add_database(mut self, id: impl Into<String>, database: Arc<Database>) -> Self {
self.databases.push((id.into(), database));
self
}
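    /// Builds the configured [`Worker`].
    ///
    /// # Panics
    ///
    /// Panics if no extractor was set or no database was added.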
#[must_use]
pub fn build(self) -> Worker {
let extractor = self
.extractor
.expect("Extractor not set - call .extractor()");
assert!(
!self.databases.is_empty(),
"No databases added - call .add_database() at least once"
);
Worker {
extractor,
databases: self.databases,
stats: WorkerStats::default(),
}
}
}
impl Default for WorkerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_file_reader_basic() {
let mut file = NamedTempFile::new().unwrap();
writeln!(file, "line 1").unwrap();
writeln!(file, "line 2").unwrap();
writeln!(file, "line 3").unwrap();
file.flush().unwrap();
let mut reader = FileReader::new(file.path(), 1024).unwrap();
let batch = reader.next_batch().unwrap().unwrap();
assert!(!batch.data.is_empty());
assert_eq!(batch.source, file.path());
}
#[test]
fn test_batch_iter() {
let mut file = NamedTempFile::new().unwrap();
for i in 1..=10 {
writeln!(file, "line {i}").unwrap();
}
file.flush().unwrap();
let reader = FileReader::new(file.path(), 1024).unwrap();
let batches: Vec<_> = reader.batches().collect::<io::Result<Vec<_>>>().unwrap();
assert!(!batches.is_empty());
let total_bytes: usize = batches.iter().map(|b| b.data.len()).sum();
assert!(total_bytes > 0);
}
#[test]
fn test_worker_process_bytes() {
use crate::extractor::Extractor;
use crate::{DatabaseBuilder, MatchMode};
use std::collections::HashMap;
let mut builder = DatabaseBuilder::new(MatchMode::CaseSensitive);
let mut data = HashMap::new();
data.insert(
"type".to_string(),
crate::DataValue::String("threat".to_string()),
);
builder.add_ip("1.2.3.4", data).unwrap();
let db_bytes = builder.build().unwrap();
let mut tmpfile = NamedTempFile::new().unwrap();
tmpfile.write_all(&db_bytes).unwrap();
tmpfile.flush().unwrap();
let db = crate::Database::from(tmpfile.path().to_str().unwrap())
.open()
.unwrap();
let extractor = Extractor::new().unwrap();
let mut worker = Worker::builder()
.extractor(extractor)
.add_database("test", Arc::new(db))
.build();
let input = b"Connection from 1.2.3.4 detected";
let matches = worker.process_bytes(input).unwrap();
assert_eq!(matches.len(), 1);
assert_eq!(matches[0].matched_text, "1.2.3.4");
assert_eq!(matches[0].match_type, "IPv4");
let stats = worker.stats();
assert_eq!(stats.matches_found, 1);
assert!(stats.candidates_tested > 0);
}
#[test]
fn test_worker_process_batch() {
use crate::extractor::Extractor;
use crate::{DatabaseBuilder, MatchMode};
use std::collections::HashMap;
let mut builder = DatabaseBuilder::new(MatchMode::CaseSensitive);
let data = HashMap::new();
builder.add_ip("8.8.8.8", data.clone()).unwrap();
builder.add_literal("evil.com", data).unwrap();
let db_bytes = builder.build().unwrap();
let mut tmpfile = NamedTempFile::new().unwrap();
tmpfile.write_all(&db_bytes).unwrap();
tmpfile.flush().unwrap();
let db = crate::Database::from(tmpfile.path().to_str().unwrap())
.open()
.unwrap();
let extractor = Extractor::new().unwrap();
let mut worker = Worker::builder()
.extractor(extractor)
.add_database("test", Arc::new(db))
.build();
let batch = DataBatch {
source: PathBuf::from("test.log"),
data: Arc::new(b"DNS query to evil.com from 8.8.8.8".to_vec()),
};
let matches = worker.process_batch(&batch).unwrap();
assert_eq!(matches.len(), 2);
assert!(matches.iter().any(|m| m.matched_text == "8.8.8.8"));
assert!(matches.iter().any(|m| m.matched_text == "evil.com"));
for m in &matches {
assert_eq!(m.source, PathBuf::from("test.log"));
}
}
}