// genomicframe_core/parallel.rs

use crate::error::Result;
7use lazy_static::lazy_static;
8use rayon::ThreadPool;
9use std::collections::HashMap;
10use std::io::{BufRead, Seek, SeekFrom};
11use std::sync::{Arc, Mutex};
12
/// A value that can absorb another value of the same type — used to combine
/// the per-chunk results produced by parallel workers into one result.
pub trait Mergeable: Sized + Send {
    /// Folds `other` into `self`.
    fn merge(&mut self, other: Self);

    /// Collapses every element of `stats` into a single value by repeatedly
    /// calling [`Mergeable::merge`] left to right; `None` when `stats` is empty.
    fn merge_all(stats: Vec<Self>) -> Option<Self> {
        stats.into_iter().reduce(|mut acc, next| {
            acc.merge(next);
            acc
        })
    }
}
42
/// Tuning knobs for splitting file work across a thread pool.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    // Worker thread count; `None` means "use every logical CPU"
    // (resolved via `num_cpus` in `ParallelConfig::threads`).
    pub num_threads: Option<usize>,

    // Smallest chunk, in bytes, worth handing to a worker; defaults to
    // 5 MiB (see the `Default` impl).
    pub min_chunk_size: usize,

    // Hard cap on the number of chunks; `None` means "4x the thread count"
    // (resolved in `ParallelConfig::max_chunks`).
    pub max_chunks: Option<usize>,
}
58
59impl Default for ParallelConfig {
60 fn default() -> Self {
61 Self {
62 num_threads: None, min_chunk_size: 5 * 1024 * 1024, max_chunks: None,
65 }
66 }
67}
68
69impl ParallelConfig {
70 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn with_threads(mut self, threads: usize) -> Self {
77 self.num_threads = Some(threads);
78 self
79 }
80
81 pub fn with_chunk_size(mut self, size: usize) -> Self {
83 self.min_chunk_size = size;
84 self
85 }
86
87 pub fn with_max_chunks(mut self, max: usize) -> Self {
89 self.max_chunks = Some(max);
90 self
91 }
92
93 pub fn threads(&self) -> usize {
95 self.num_threads.unwrap_or_else(num_cpus::get)
96 }
97
98 pub fn max_chunks(&self) -> usize {
100 self.max_chunks.unwrap_or_else(|| self.threads() * 4)
101 }
102}
103
104
lazy_static! {
    /// Process-wide cache of rayon pools keyed by thread count, so repeated
    /// `get_thread_pool` calls for the same size share one pool instead of
    /// spawning a fresh set of OS threads each time.
    static ref THREAD_POOL_CACHE: Mutex<HashMap<usize, Arc<ThreadPool>>> = Mutex::new(HashMap::new());
}
108
109pub fn get_thread_pool(num_threads: usize) -> Result<Arc<ThreadPool>> {
121 let mut cache = THREAD_POOL_CACHE
122 .lock()
123 .map_err(|e| crate::error::Error::InvalidInput(format!("Thread pool cache poisoned: {}", e)))?;
124
125 if let Some(pool) = cache.get(&num_threads) {
127 return Ok(Arc::clone(pool));
129 }
130
131 let pool = rayon::ThreadPoolBuilder::new()
133 .num_threads(num_threads)
134 .build()
135 .map_err(|e| crate::error::Error::InvalidInput(format!("Failed to create thread pool: {}", e)))?;
136
137 let cached_pool = Arc::new(pool);
139 cache.insert(num_threads, Arc::clone(&cached_pool));
140
141 Ok(cached_pool)
142}
143
/// A half-open byte range `[start, end)` of an input file, tagged with its
/// position in the overall split.
#[derive(Debug, Clone, Copy)]
pub struct FileChunk {
    /// First byte of the chunk (inclusive).
    pub start: u64,
    /// One past the last byte of the chunk (exclusive).
    pub end: u64,
    /// Zero-based index of this chunk within the split.
    pub index: usize,
}

impl FileChunk {
    /// Builds a chunk covering bytes `[start, end)` at position `index`.
    pub fn new(start: u64, end: u64, index: usize) -> Self {
        FileChunk { start, end, index }
    }

    /// Number of bytes covered by this chunk.
    pub fn size(&self) -> u64 {
        let FileChunk { start, end, .. } = *self;
        end - start
    }
}
166
167pub fn calculate_chunks(file_size: u64, config: &ParallelConfig) -> Vec<FileChunk> {
172 if file_size == 0 {
173 return vec![];
174 }
175
176 let num_threads = config.threads();
177 let min_chunk_size = config.min_chunk_size as u64;
178 let max_chunks = config.max_chunks();
179
180 let ideal_chunks = num_threads;
182 let max_chunks_by_size = (file_size / min_chunk_size).max(1) as usize;
183 let num_chunks = ideal_chunks.min(max_chunks_by_size).min(max_chunks);
184
185 let chunk_size = file_size / num_chunks as u64;
187
188 (0..num_chunks)
190 .map(|i| {
191 let start = i as u64 * chunk_size;
192 let end = if i == num_chunks - 1 {
193 file_size } else {
195 (i + 1) as u64 * chunk_size
196 };
197 FileChunk::new(start, end, i)
198 })
199 .collect()
200}
201
202
/// Finds the byte offset of the next record boundary at or after `offset`.
///
/// Offset 0 is always reported as a boundary. Otherwise the (possibly
/// partial) line containing `offset` is skipped, and the scan returns the
/// absolute offset of the first subsequent line that begins with `@`.
/// Returns `Ok(None)` when EOF is reached before such a line.
///
/// NOTE(review): in FASTQ data a quality line may also begin with `@`, so
/// this can mis-identify a quality line as a record header when `offset`
/// lands just before one — confirm callers tolerate or avoid that case.
pub fn find_record_boundary<R: BufRead + Seek>(reader: &mut R, offset: u64) -> Result<Option<u64>> {
    // The start of the file is always record-aligned.
    if offset == 0 {
        return Ok(Some(0));
    }

    reader.seek(SeekFrom::Start(offset))?;

    // Skip the rest of the line `offset` landed in — it may be partial.
    let mut discard = Vec::new();
    let first_skip = reader.read_until(b'\n', &mut discard)?;
    if first_skip == 0 {
        // `offset` was at or past EOF: nothing to scan.
        return Ok(None);
    }

    // Absolute offset of the first complete line after `offset`.
    let mut current_pos = offset + first_skip as u64;

    loop {
        let mut line = Vec::new();
        let bytes_read = reader.read_until(b'\n', &mut line)?;

        // EOF reached without finding a header line.
        if bytes_read == 0 {
            return Ok(None);
        }

        // A line starting with '@' marks a candidate record header; return
        // the offset of the start of that line (not updated yet below).
        if !line.is_empty() && line[0] == b'@' {
            return Ok(Some(current_pos));
        }

        current_pos += bytes_read as u64;
    }
}
259
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_parallel_config_defaults() {
        let config = ParallelConfig::default();
        assert!(config.threads() > 0);
        assert_eq!(config.min_chunk_size, 5 * 1024 * 1024);
    }

    #[test]
    fn test_parallel_config_builder() {
        let config = ParallelConfig::new()
            .with_threads(4)
            .with_chunk_size(5_000_000)
            .with_max_chunks(16);

        assert_eq!(config.threads(), 4);
        assert_eq!(config.min_chunk_size, 5_000_000);
        assert_eq!(config.max_chunks(), 16);
    }

    #[test]
    fn test_calculate_chunks_small_file() {
        // A 1 MB file is below the default 5 MB minimum chunk size, so it
        // must stay a single chunk even with 4 threads available.
        let config = ParallelConfig::new().with_threads(4);

        let chunks = calculate_chunks(1_000_000, &config);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].start, 0);
        assert_eq!(chunks[0].end, 1_000_000);
    }

    #[test]
    fn test_calculate_chunks_large_file() {
        let config = ParallelConfig::new()
            .with_threads(4)
            .with_chunk_size(10_000_000);

        let chunks = calculate_chunks(100_000_000, &config);
        assert_eq!(chunks.len(), 4);

        // 100 MB splits evenly across 4 chunks, so every chunk — including
        // the remainder-absorbing last one — is exactly 25 MB. (This used to
        // be an if/else on `i < 3` whose two arms asserted the same value;
        // the dead branch is collapsed, and contiguity is now checked too.)
        let mut expected_start = 0;
        for (i, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.index, i);
            assert_eq!(chunk.size(), 25_000_000);
            assert_eq!(chunk.start, expected_start);
            expected_start = chunk.end;
        }
        assert_eq!(expected_start, 100_000_000);
    }

    #[test]
    fn test_calculate_chunks_empty_file() {
        let config = ParallelConfig::default();
        let chunks = calculate_chunks(0, &config);
        assert_eq!(chunks.len(), 0);
    }

    #[test]
    fn test_find_record_boundary_at_start() {
        let data = b"@READ1\nACGT\n+\nIIII\n@READ2\nGGGG\n+\nIIII\n";
        let mut cursor = Cursor::new(data);

        // Offset 0 is always treated as a record boundary.
        let boundary = find_record_boundary(&mut cursor, 0).unwrap();
        assert_eq!(boundary, Some(0));
    }

    #[test]
    fn test_find_record_boundary_mid_record() {
        let data = b"@READ1\nACGT\n+\nIIII\n@READ2\nGGGG\n+\nIIII\n";
        let mut cursor = Cursor::new(data);

        // Offset 10 lands inside READ1's sequence line; the next header line
        // ("@READ2") begins at byte 19 (7 + 5 + 2 + 5). The old assertion
        // only required `10 < boundary <= len`; pin the exact offset.
        let boundary = find_record_boundary(&mut cursor, 10).unwrap();
        assert_eq!(boundary, Some(19));
    }

    #[test]
    fn test_find_record_boundary_at_eof() {
        let data = b"@READ1\nACGT\n+\nIIII\n";
        let mut cursor = Cursor::new(data);

        // Seeking exactly to EOF leaves nothing to scan.
        let boundary = find_record_boundary(&mut cursor, data.len() as u64).unwrap();
        assert_eq!(boundary, None);
    }

    #[test]
    fn test_file_chunk_size() {
        let chunk = FileChunk::new(0, 1000, 0);
        assert_eq!(chunk.size(), 1000);

        let chunk = FileChunk::new(500, 1500, 1);
        assert_eq!(chunk.size(), 1000);
    }
}