use crate::error::Result;
use lazy_static::lazy_static;
use rayon::ThreadPool;
use std::collections::HashMap;
use std::io::{BufRead, Seek, SeekFrom};
use std::sync::{Arc, Mutex};
/// A value that can absorb another value of the same type.
///
/// Implementors must be `Sized + Send`.
pub trait Mergeable: Sized + Send {
    /// Folds `other` into `self`, consuming `other`.
    fn merge(&mut self, other: Self);

    /// Merges every element of `stats` into the first one, in order.
    ///
    /// Returns `None` when `stats` is empty; otherwise returns the first
    /// element after each subsequent element has been merged into it.
    fn merge_all(stats: Vec<Self>) -> Option<Self> {
        stats.into_iter().reduce(|mut combined, next| {
            combined.merge(next);
            combined
        })
    }
}
/// Tuning knobs for splitting work into chunks and sizing the thread pool.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Explicit worker count; `None` means "ask `num_cpus::get()`" (see `threads`).
    pub num_threads: Option<usize>,
    /// Smallest chunk size, in bytes, that is worth handing to a worker.
    pub min_chunk_size: usize,
    /// Explicit upper bound on the chunk count; `None` means four chunks per
    /// thread (see `max_chunks`).
    pub max_chunks: Option<usize>,
}

impl Default for ParallelConfig {
    /// No explicit thread count, a 5 MiB minimum chunk size, no explicit chunk cap.
    fn default() -> Self {
        const FIVE_MIB: usize = 5 * 1024 * 1024;
        ParallelConfig {
            num_threads: None,
            min_chunk_size: FIVE_MIB,
            max_chunks: None,
        }
    }
}

impl ParallelConfig {
    /// Creates a configuration with the default settings.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with an explicit thread count.
    pub fn with_threads(self, threads: usize) -> Self {
        Self {
            num_threads: Some(threads),
            ..self
        }
    }

    /// Returns the configuration with a new minimum chunk size in bytes.
    pub fn with_chunk_size(self, size: usize) -> Self {
        Self {
            min_chunk_size: size,
            ..self
        }
    }

    /// Returns the configuration with an explicit cap on the chunk count.
    pub fn with_max_chunks(self, max: usize) -> Self {
        Self {
            max_chunks: Some(max),
            ..self
        }
    }

    /// Effective thread count: the configured value, or `num_cpus::get()` when unset.
    pub fn threads(&self) -> usize {
        match self.num_threads {
            Some(configured) => configured,
            None => num_cpus::get(),
        }
    }

    /// Effective chunk cap: the configured value, or `threads() * 4` when unset.
    pub fn max_chunks(&self) -> usize {
        match self.max_chunks {
            Some(configured) => configured,
            None => self.threads() * 4,
        }
    }
}
// Lazily-initialised cache of rayon thread pools behind a mutex.
// `get_thread_pool` inserts and looks up entries using the requested thread
// count as the key, so each distinct count gets at most one shared pool.
lazy_static! {
static ref THREAD_POOL_CACHE: Mutex<HashMap<usize, Arc<ThreadPool>>> = Mutex::new(HashMap::new());
}
pub fn get_thread_pool(num_threads: usize) -> Result<Arc<ThreadPool>> {
let mut cache = THREAD_POOL_CACHE
.lock()
.map_err(|e| crate::error::Error::InvalidInput(format!("Thread pool cache poisoned: {}", e)))?;
if let Some(pool) = cache.get(&num_threads) {
return Ok(Arc::clone(pool));
}
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.build()
.map_err(|e| crate::error::Error::InvalidInput(format!("Failed to create thread pool: {}", e)))?;
let cached_pool = Arc::new(pool);
cache.insert(num_threads, Arc::clone(&cached_pool));
Ok(cached_pool)
}
/// A byte range of a file assigned to one unit of work.
///
/// `calculate_chunks` produces these so that each chunk's `end` equals the
/// next chunk's `start`, i.e. the range is used as half-open `[start, end)`.
#[derive(Debug, Clone, Copy)]
pub struct FileChunk {
    /// Byte offset at which the chunk begins.
    pub start: u64,
    /// Byte offset at which the chunk stops (not included in `size`).
    pub end: u64,
    /// Position of this chunk in the sequence it was generated in.
    pub index: usize,
}

impl FileChunk {
    /// Creates a chunk covering `start..end` with the given sequence index.
    pub fn new(start: u64, end: u64, index: usize) -> Self {
        Self { start, end, index }
    }

    /// Number of bytes covered by the chunk (`end - start`).
    ///
    /// The fields are public, so a caller may end up with `start > end`
    /// (for example after moving `start` forward to a record boundary).
    /// Such an inverted range reports a size of 0 instead of panicking in
    /// debug builds or wrapping to a huge value in release builds.
    pub fn size(&self) -> u64 {
        self.end.saturating_sub(self.start)
    }
}
/// Splits a file of `file_size` bytes into contiguous, non-overlapping chunks.
///
/// The chunk count is the smallest of:
/// * `config.threads()`,
/// * `config.max_chunks()`,
/// * `file_size / config.min_chunk_size` (so chunks are not smaller than the
///   configured minimum unless the whole file is).
///
/// Each of those limits is clamped to at least 1, so a zero thread count,
/// zero chunk cap or zero minimum chunk size yields a valid split instead of
/// a division-by-zero panic. All chunks have equal size except the last,
/// which also absorbs the remainder and always ends at `file_size`.
///
/// Returns an empty vector when `file_size` is 0.
pub fn calculate_chunks(file_size: u64, config: &ParallelConfig) -> Vec<FileChunk> {
    if file_size == 0 {
        return Vec::new();
    }

    // A zero minimum would divide by zero below; treat it as "one byte".
    let min_chunk_size = (config.min_chunk_size as u64).max(1);

    // Limit imposed by the worker/chunk-count settings, never below one chunk.
    let count_limit = config.threads().min(config.max_chunks()).max(1) as u64;

    // Limit imposed by the minimum chunk size. Kept in u64 so a large file on
    // a 32-bit target is not truncated by a cast to usize.
    let size_limit = (file_size / min_chunk_size).max(1);

    // >= 1 by construction, and <= file_size, so `chunk_size` is >= 1.
    let num_chunks = count_limit.min(size_limit);
    let chunk_size = file_size / num_chunks;

    (0..num_chunks)
        .map(|i| {
            let start = i * chunk_size;
            // The last chunk takes whatever the integer division left over.
            let end = if i + 1 == num_chunks {
                file_size
            } else {
                start + chunk_size
            };
            // `i < num_chunks <= count_limit`, which originated as a usize.
            FileChunk::new(start, end, i as usize)
        })
        .collect()
}
/// Finds the byte offset of the first record header that starts after `offset`.
///
/// An `offset` of 0 is returned unchanged. Otherwise the reader is
/// seeked to `offset`, the rest of the (possibly partial) line there is
/// skipped, and the following lines are scanned for a record start.
///
/// A line is accepted as a record start only if it begins with `@` **and**
/// the line two below it begins with `+`. Checking for `@` alone is not
/// enough for FASTQ data, because `@` is also a legal quality character and
/// a quality line may therefore begin with it. This check assumes the
/// four-line record layout (header, sequence, `+` separator, quality) —
/// NOTE(review): confirm no callers feed line-wrapped FASTQ.
///
/// Returns `Ok(None)` when no such record start exists before end of input
/// (including when `offset` is at or past the end, or when the trailing
/// record is cut off before its `+` line).
///
/// The reader's position after the call is unspecified; callers should seek
/// to the returned offset before reading the record.
///
/// # Errors
///
/// Propagates any I/O error raised while seeking or reading.
pub fn find_record_boundary<R: BufRead + Seek>(reader: &mut R, offset: u64) -> Result<Option<u64>> {
    if offset == 0 {
        return Ok(Some(0));
    }
    reader.seek(SeekFrom::Start(offset))?;

    // One buffer is reused for every line instead of allocating per line.
    let mut line = Vec::new();

    // `offset` most likely falls mid-line; discard up to the next newline.
    let skipped = reader.read_until(b'\n', &mut line)?;
    if skipped == 0 {
        return Ok(None);
    }
    let mut line_start = offset + skipped as u64;

    // Start offsets of the previous line and the one before it, recorded only
    // when that line began with '@' (i.e. it is a header candidate).
    let mut candidate_one_back: Option<u64> = None;
    let mut candidate_two_back: Option<u64> = None;

    loop {
        line.clear();
        let bytes_read = reader.read_until(b'\n', &mut line)?;
        if bytes_read == 0 {
            return Ok(None);
        }
        let first_byte = line.first().copied();

        // A '+' line confirms the '@' line two above it as a real header.
        if first_byte == Some(b'+') {
            if let Some(header_start) = candidate_two_back {
                return Ok(Some(header_start));
            }
        }

        candidate_two_back = candidate_one_back;
        candidate_one_back = if first_byte == Some(b'@') {
            Some(line_start)
        } else {
            None
        };
        line_start += bytes_read as u64;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Two complete four-line records; the second header starts at byte 19.
    const TWO_RECORDS: &[u8] = b"@READ1\nACGT\n+\nIIII\n@READ2\nGGGG\n+\nIIII\n";

    /// Minimal `Mergeable` implementor used to observe merge order.
    struct Collected(Vec<u32>);

    impl Mergeable for Collected {
        fn merge(&mut self, other: Self) {
            self.0.extend(other.0);
        }
    }

    #[test]
    fn test_merge_all() {
        let merged = Collected::merge_all(vec![
            Collected(vec![1]),
            Collected(vec![2]),
            Collected(vec![3]),
        ]);
        assert_eq!(merged.map(|c| c.0), Some(vec![1, 2, 3]));
        assert!(Collected::merge_all(Vec::new()).is_none());
    }

    #[test]
    fn test_parallel_config_defaults() {
        let config = ParallelConfig::default();
        assert!(config.threads() > 0);
        assert_eq!(config.min_chunk_size, 5 * 1024 * 1024);
    }

    #[test]
    fn test_parallel_config_builder() {
        let config = ParallelConfig::new()
            .with_threads(4)
            .with_chunk_size(5_000_000)
            .with_max_chunks(16);
        assert_eq!(config.threads(), 4);
        assert_eq!(config.min_chunk_size, 5_000_000);
        assert_eq!(config.max_chunks(), 16);
    }

    #[test]
    fn test_calculate_chunks_small_file() {
        // 1 MB is below the default 5 MiB minimum, so it must stay in one chunk.
        let config = ParallelConfig::new().with_threads(4);
        let chunks = calculate_chunks(1_000_000, &config);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].start, 0);
        assert_eq!(chunks[0].end, 1_000_000);
    }

    #[test]
    fn test_calculate_chunks_large_file() {
        let config = ParallelConfig::new()
            .with_threads(4)
            .with_chunk_size(10_000_000);
        let chunks = calculate_chunks(100_000_000, &config);
        assert_eq!(chunks.len(), 4);
        // 100 MB divides evenly by 4, so every chunk (including the last) is 25 MB
        // and the chunks tile the file without gaps.
        for (i, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.index, i);
            assert_eq!(chunk.size(), 25_000_000);
            assert_eq!(chunk.start, i as u64 * 25_000_000);
        }
        assert_eq!(chunks[3].end, 100_000_000);
    }

    #[test]
    fn test_calculate_chunks_empty_file() {
        let config = ParallelConfig::default();
        let chunks = calculate_chunks(0, &config);
        assert_eq!(chunks.len(), 0);
    }

    #[test]
    fn test_find_record_boundary_at_start() {
        let mut cursor = Cursor::new(TWO_RECORDS);
        let boundary = find_record_boundary(&mut cursor, 0).unwrap();
        assert_eq!(boundary, Some(0));
    }

    #[test]
    fn test_find_record_boundary_mid_record() {
        // Offset 10 is inside the first sequence line; the next header is "@READ2".
        let mut cursor = Cursor::new(TWO_RECORDS);
        let boundary = find_record_boundary(&mut cursor, 10).unwrap();
        assert_eq!(boundary, Some(19));
    }

    #[test]
    fn test_find_record_boundary_at_eof() {
        let data: &[u8] = b"@READ1\nACGT\n+\nIIII\n";
        let mut cursor = Cursor::new(data);
        let boundary = find_record_boundary(&mut cursor, data.len() as u64).unwrap();
        assert_eq!(boundary, None);
    }

    #[test]
    fn test_file_chunk_size() {
        let chunk = FileChunk::new(0, 1000, 0);
        assert_eq!(chunk.size(), 1000);
        let chunk = FileChunk::new(500, 1500, 1);
        assert_eq!(chunk.size(), 1000);
    }
}