// ragc_core/segment_buffer.rs

1// Buffered segment storage matching C++ AGC's CBufferedSegPart
2// Reference: agc_compressor.h lines 27-536
3
4use ahash::AHashMap;
5use std::collections::BTreeSet;
6use std::sync::atomic::{AtomicI32, Ordering};
7use std::sync::Mutex;
8
/// Block size for atomic work distribution
/// Matches C++ AGC's PART_ID_STEP (agc_compressor.h)
///
/// NOTE(review): not referenced anywhere in this file; presumably the
/// compressor claims part IDs in chunks of this size to reduce atomic
/// contention — confirm at the call sites.
pub const PART_ID_STEP: i32 = 128;
12
/// Segment part data
///
/// Matches C++ AGC's seg_part_t (agc_compressor.h:29-120) and kk_seg_part_t (lines 124-165)
///
/// Ordering: by (sample_name, contig_name, seg_part_no) for deterministic
/// storage — see the `Ord` impl below. Note that the derived
/// `PartialEq`/`Eq` compare *all* fields, while `Ord` compares only the
/// three key fields, so ordered containers (e.g. `BTreeSet`) treat parts
/// sharing the key as duplicates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentPart {
    /// First k-mer of the segment
    pub kmer1: u64,
    /// Last k-mer of the segment
    pub kmer2: u64,
    /// Sample the segment originates from
    pub sample_name: String,
    /// Contig the segment originates from
    pub contig_name: String,
    /// Compressed segment bytes
    pub seg_data: Vec<u8>,
    /// Whether the segment is stored reverse-complemented
    pub is_rev_comp: bool,
    /// Part number within the contig
    pub seg_part_no: u32,
}
35
impl PartialOrd for SegmentPart {
    /// Delegates to [`Ord::cmp`]; always `Some` since the order is total.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
41
42impl Ord for SegmentPart {
43    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
44        // Match C++ AGC ordering: (sample_name, contig_name, seg_part_no)
45        (&self.sample_name, &self.contig_name, self.seg_part_no).cmp(&(
46            &other.sample_name,
47            &other.contig_name,
48            other.seg_part_no,
49        ))
50    }
51}
52
/// Thread-safe list of segments for one group
///
/// Matches C++ AGC's list_seg_part_t (agc_compressor.h:169-296)
///
/// Two independent mutexes guard the storage and the read cursor; any
/// method that takes both must use a consistent acquisition order to stay
/// deadlock-free.
struct SegmentPartList {
    /// Vector of segments (grows during compression, read during storage)
    parts: Mutex<Vec<SegmentPart>>,

    /// Virtual begin index for pop operations (avoid actually removing elements)
    virt_begin: Mutex<usize>,
}
63
64impl SegmentPartList {
65    fn new() -> Self {
66        SegmentPartList {
67            parts: Mutex::new(Vec::new()),
68            virt_begin: Mutex::new(0),
69        }
70    }
71
72    /// Add segment (thread-safe)
73    ///
74    /// Matches C++ AGC's list_seg_part_t::emplace (agc_compressor.h:224-228)
75    fn emplace(&self, part: SegmentPart) {
76        let mut parts = self.parts.lock().unwrap();
77        parts.push(part);
78    }
79
80    /// Sort segments by (sample, contig, part_no)
81    ///
82    /// Matches C++ AGC's list_seg_part_t::sort (agc_compressor.h:235-238)
83    fn sort(&self) {
84        let mut parts = self.parts.lock().unwrap();
85        parts.sort();
86    }
87
88    /// Pop next segment (uses virt_begin to avoid actual removal)
89    ///
90    /// Matches C++ AGC's list_seg_part_t::pop (agc_compressor.h:251-265)
91    fn pop(&self) -> Option<SegmentPart> {
92        let mut virt_begin = self.virt_begin.lock().unwrap();
93        let parts = self.parts.lock().unwrap();
94
95        if *virt_begin >= parts.len() {
96            drop(parts); // Release lock before clearing
97            let mut parts = self.parts.lock().unwrap();
98            *virt_begin = 0;
99            parts.clear();
100            return None;
101        }
102
103        let part = parts[*virt_begin].clone();
104        *virt_begin += 1;
105
106        Some(part)
107    }
108
109    /// Check if list is empty (from virt_begin perspective)
110    ///
111    /// Matches C++ AGC's list_seg_part_t::empty (agc_compressor.h:246-249)
112    fn is_empty(&self) -> bool {
113        let virt_begin = self.virt_begin.lock().unwrap();
114        let parts = self.parts.lock().unwrap();
115        *virt_begin >= parts.len()
116    }
117
118    /// Clear list and reset virt_begin
119    ///
120    /// Matches C++ AGC's list_seg_part_t::clear (agc_compressor.h:240-244)
121    fn clear(&self) {
122        let mut parts = self.parts.lock().unwrap();
123        let mut virt_begin = self.virt_begin.lock().unwrap();
124        parts.clear();
125        *virt_begin = 0;
126    }
127
128    fn size(&self) -> usize {
129        let parts = self.parts.lock().unwrap();
130        parts.len()
131    }
132}
133
/// Buffered segments storage
///
/// Matches C++ AGC's CBufferedSegPart (agc_compressor.h:27-536)
///
/// **Architecture**:
/// - `vl_seg_part`: Vector of thread-safe lists, indexed by group_id (KNOWN segments)
/// - `s_seg_part`: BTreeSet of NEW segments (not yet assigned group_id)
///
/// **Workflow**:
/// 1. During compression: Workers call `add_known()` or `add_new()`
/// 2. At registration barrier: Main thread calls `sort_known()`, `process_new()`, `distribute_segments()`
/// 3. During storage: Workers call `get_vec_id()` and `get_part()` to read segments
/// 4. After storage: Main thread calls `clear()`
pub struct BufferedSegments {
    /// KNOWN segments indexed by group_id
    /// Matches C++ AGC's vector<list_seg_part_t> vl_seg_part (line 298)
    vl_seg_part: Vec<SegmentPartList>,

    /// NEW segments (no group_id yet)
    /// Matches C++ AGC's set<kk_seg_part_t> s_seg_part (line 300)
    ///
    /// NOTE(review): the set orders (and deduplicates) by `SegmentPart`'s
    /// `Ord`, i.e. by (sample, contig, part_no) only; two parts sharing that
    /// key but differing in k-mers/data would collide — confirm this matches
    /// the C++ comparator's intent.
    s_seg_part: Mutex<BTreeSet<SegmentPart>>,

    /// Atomic counter for reading segments (starts at size-1, decrements to 0)
    /// Matches C++ AGC's atomic<int32_t> a_v_part_id (line 303)
    a_v_part_id: AtomicI32,

    /// Mutex for resizing vl_seg_part
    /// Matches C++ AGC's mutex mtx (line 301)
    resize_mtx: Mutex<()>,
}
164
165impl BufferedSegments {
166    /// Create new buffered segments storage
167    ///
168    /// Matches C++ AGC's CBufferedSegPart::CBufferedSegPart (agc_compressor.h:308-311)
169    ///
170    /// # Arguments
171    /// * `no_raw_groups` - Initial number of group IDs
172    pub fn new(no_raw_groups: usize) -> Self {
173        let mut vl_seg_part = Vec::with_capacity(no_raw_groups);
174        for _ in 0..no_raw_groups {
175            vl_seg_part.push(SegmentPartList::new());
176        }
177
178        BufferedSegments {
179            vl_seg_part,
180            s_seg_part: Mutex::new(BTreeSet::new()),
181            a_v_part_id: AtomicI32::new(0),
182            resize_mtx: Mutex::new(()),
183        }
184    }
185
186    /// Add segment to KNOWN group
187    ///
188    /// Matches C++ AGC's CBufferedSegPart::add_known (agc_compressor.h:320-324)
189    ///
190    /// Thread-safe: `SegmentPartList::emplace()` has internal mutex
191    pub fn add_known(
192        &self,
193        group_id: u32,
194        kmer1: u64,
195        kmer2: u64,
196        sample_name: String,
197        contig_name: String,
198        seg_data: Vec<u8>,
199        is_rev_comp: bool,
200        seg_part_no: u32,
201    ) {
202        self.vl_seg_part[group_id as usize].emplace(SegmentPart {
203            kmer1,
204            kmer2,
205            sample_name,
206            contig_name,
207            seg_data,
208            is_rev_comp,
209            seg_part_no,
210        });
211    }
212
213    /// Add NEW segment (not yet assigned group_id)
214    ///
215    /// Matches C++ AGC's CBufferedSegPart::add_new (agc_compressor.h:326-331)
216    ///
217    /// Thread-safe: Locks `s_seg_part` mutex
218    pub fn add_new(
219        &self,
220        kmer1: u64,
221        kmer2: u64,
222        sample_name: String,
223        contig_name: String,
224        seg_data: Vec<u8>,
225        is_rev_comp: bool,
226        seg_part_no: u32,
227    ) {
228        let mut s_seg_part = self.s_seg_part.lock().unwrap();
229        s_seg_part.insert(SegmentPart {
230            kmer1,
231            kmer2,
232            sample_name,
233            contig_name,
234            seg_data,
235            is_rev_comp,
236            seg_part_no,
237        });
238    }
239
240    /// Sort all KNOWN segments in parallel
241    ///
242    /// Matches C++ AGC's CBufferedSegPart::sort_known (agc_compressor.h:333-377)
243    ///
244    /// **Note**: In C++ AGC, this uses std::async for parallelism. For now, we'll use
245    /// a simple sequential implementation. Parallelism can be added later with rayon.
246    ///
247    /// # Arguments
248    /// * `_num_threads` - Number of threads (unused in sequential version)
249    pub fn sort_known(&self, _num_threads: usize) {
250        // TODO: Implement parallel sorting with rayon
251        // For now, sequential sorting
252        for list in &self.vl_seg_part {
253            list.sort();
254        }
255    }
256
257    /// Process NEW segments: assign group IDs and move to KNOWN
258    ///
259    /// Matches C++ AGC's CBufferedSegPart::process_new (agc_compressor.h:384-415)
260    ///
261    /// Returns: Number of NEW groups created
262    pub fn process_new(
263        &mut self,
264        map_segments: &std::sync::Mutex<AHashMap<(u64, u64), u32>>,
265    ) -> u32 {
266        let _lock = self.resize_mtx.lock().unwrap();
267        let mut s_seg_part = self.s_seg_part.lock().unwrap();
268
269        if s_seg_part.is_empty() {
270            return 0;
271        }
272
273        // Assign group IDs to unique (kmer1, kmer2) pairs
274        // CRITICAL FIX: Check global map_segments FIRST before assigning new group IDs
275        let global_map = map_segments.lock().unwrap();
276        let mut m_kmers = AHashMap::new();
277        let mut group_id = self.vl_seg_part.len() as u32;
278
279        for part in s_seg_part.iter() {
280            let key = (part.kmer1, part.kmer2);
281
282            // Check global registry first for existing group
283            if let Some(&existing_group_id) = global_map.get(&key) {
284                m_kmers.insert(key, existing_group_id);
285            } else if !m_kmers.contains_key(&key) {
286                // New group needed - assign new ID
287                m_kmers.insert(key, group_id);
288                group_id += 1;
289            }
290        }
291
292        drop(global_map);
293
294        let no_new = group_id - self.vl_seg_part.len() as u32;
295
296        // Resize vl_seg_part to accommodate new groups
297        let new_size = group_id as usize;
298        if self.vl_seg_part.capacity() < new_size {
299            self.vl_seg_part
300                .reserve((new_size as f64 * 1.2) as usize - self.vl_seg_part.len());
301        }
302        while self.vl_seg_part.len() < new_size {
303            self.vl_seg_part.push(SegmentPartList::new());
304        }
305
306        // Move NEW segments to KNOWN groups
307        for part in s_seg_part.iter() {
308            let key = (part.kmer1, part.kmer2);
309            let group_id = m_kmers[&key] as usize;
310
311            self.vl_seg_part[group_id].emplace(part.clone());
312        }
313
314        s_seg_part.clear();
315
316        no_new
317    }
318
319    /// Get the number of NEW segments (not yet assigned group_id)
320    ///
321    /// **For testing** - count segments in s_seg_part
322    pub fn get_num_new(&self) -> usize {
323        let s_seg_part = self.s_seg_part.lock().unwrap();
324        s_seg_part.len()
325    }
326
327    /// Distribute segments from src_id to range [dest_from, dest_to)
328    ///
329    /// Matches C++ AGC's CBufferedSegPart::distribute_segments (agc_compressor.h:417-435)
330    ///
331    /// **Pattern**: Round-robin distribution
332    ///
333    /// Example: `distribute_segments(0, 0, num_workers)`
334    /// - Distributes group 0 segments among workers 0..num_workers
335    pub fn distribute_segments(&self, src_id: u32, dest_id_from: u32, dest_id_to: u32) {
336        let src_id = src_id as usize;
337        let no_in_src = self.vl_seg_part[src_id].size();
338        let mut dest_id_curr = dest_id_from;
339
340        for _ in 0..no_in_src {
341            if dest_id_curr != src_id as u32 {
342                if let Some(part) = self.vl_seg_part[src_id].pop() {
343                    self.vl_seg_part[dest_id_curr as usize].emplace(part);
344                }
345            }
346
347            dest_id_curr += 1;
348            if dest_id_curr == dest_id_to {
349                dest_id_curr = dest_id_from;
350            }
351        }
352    }
353
354    /// Clear all buffered segments
355    ///
356    /// Matches C++ AGC's CBufferedSegPart::clear (agc_compressor.h:461-507)
357    ///
358    /// # Arguments
359    /// * `_num_threads` - Number of threads (unused in sequential version)
360    pub fn clear(&mut self, _num_threads: usize) {
361        // TODO: Implement parallel clearing with rayon
362        // For now, sequential clearing
363        let _lock = self.resize_mtx.lock().unwrap();
364
365        let mut s_seg_part = self.s_seg_part.lock().unwrap();
366        s_seg_part.clear();
367        drop(s_seg_part);
368
369        for list in &self.vl_seg_part {
370            list.clear();
371        }
372    }
373
374    /// Restart reading from highest group_id
375    ///
376    /// Matches C++ AGC's CBufferedSegPart::restart_read_vec (agc_compressor.h:509-514)
377    pub fn restart_read_vec(&self) {
378        let _lock = self.resize_mtx.lock().unwrap();
379        self.a_v_part_id
380            .store((self.vl_seg_part.len() - 1) as i32, Ordering::SeqCst);
381    }
382
383    /// Atomically get next group_id to process (decrements from size-1 to 0)
384    ///
385    /// Matches C++ AGC's CBufferedSegPart::get_vec_id (agc_compressor.h:516-520)
386    ///
387    /// Returns: group_id to process, or negative if done
388    pub fn get_vec_id(&self) -> i32 {
389        self.a_v_part_id.fetch_sub(1, Ordering::SeqCst)
390    }
391
392    /// Check if group is empty
393    ///
394    /// Matches C++ AGC's CBufferedSegPart::is_empty_part (agc_compressor.h:527-530)
395    pub fn is_empty_part(&self, group_id: i32) -> bool {
396        if group_id < 0 || group_id as usize >= self.vl_seg_part.len() {
397            return true;
398        }
399        self.vl_seg_part[group_id as usize].is_empty()
400    }
401
402    /// Pop next segment from group
403    ///
404    /// Matches C++ AGC's CBufferedSegPart::get_part (agc_compressor.h:532-535)
405    ///
406    /// Returns: (kmer1, kmer2, sample_name, contig_name, seg_data, is_rev_comp, seg_part_no)
407    pub fn get_part(
408        &self,
409        group_id: i32,
410    ) -> Option<(u64, u64, String, String, Vec<u8>, bool, u32)> {
411        if group_id < 0 || group_id as usize >= self.vl_seg_part.len() {
412            return None;
413        }
414
415        self.vl_seg_part[group_id as usize].pop().map(|part| {
416            (
417                part.kmer1,
418                part.kmer2,
419                part.sample_name,
420                part.contig_name,
421                part.seg_data,
422                part.is_rev_comp,
423                part.seg_part_no,
424            )
425        })
426    }
427
428    /// Get current number of groups
429    pub fn get_no_parts(&self) -> usize {
430        self.vl_seg_part.len()
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a SegmentPart with the fixed test k-mers/sample/contig.
    fn make_part(seg_data: Vec<u8>, seg_part_no: u32) -> SegmentPart {
        SegmentPart {
            kmer1: 100,
            kmer2: 200,
            sample_name: "sample1".to_string(),
            contig_name: "chr1".to_string(),
            seg_data,
            is_rev_comp: false,
            seg_part_no,
        }
    }

    #[test]
    fn test_segment_part_ordering() {
        let earlier = make_part(vec![0, 1, 2], 0);
        let later = make_part(vec![3, 4, 5], 1);

        // Same sample and contig, so seg_part_no decides the order.
        assert!(earlier < later);
    }

    #[test]
    fn test_buffered_segments_add_known() {
        let storage = BufferedSegments::new(10);

        storage.add_known(
            5,
            100,
            200,
            "sample1".to_string(),
            "chr1".to_string(),
            vec![0, 1, 2, 3],
            false,
            0,
        );

        assert!(!storage.is_empty_part(5));
    }

    #[test]
    fn test_buffered_segments_add_new_and_process() {
        let mut storage = BufferedSegments::new(10);

        storage.add_new(
            300,
            400,
            "sample1".to_string(),
            "chr1".to_string(),
            vec![4, 5, 6, 7],
            false,
            0,
        );

        // Processing against an empty global registry mints one fresh group.
        let registry = std::sync::Mutex::new(AHashMap::new());
        assert_eq!(storage.process_new(&registry), 1);

        // The fresh group lands right after the initial 10.
        assert_eq!(storage.get_no_parts(), 11);
        assert!(!storage.is_empty_part(10));
    }

    #[test]
    fn test_buffered_segments_get_vec_id() {
        let storage = BufferedSegments::new(5);

        storage.restart_read_vec();

        // Counter hands out 4, 3, 2, 1, 0 and then goes negative.
        for expected in (0..5).rev() {
            assert_eq!(storage.get_vec_id(), expected);
        }
        assert!(storage.get_vec_id() < 0);
    }

    #[test]
    fn test_buffered_segments_get_part() {
        let storage = BufferedSegments::new(10);

        storage.add_known(
            5,
            100,
            200,
            "sample1".to_string(),
            "chr1".to_string(),
            vec![0, 1, 2, 3],
            false,
            0,
        );

        let (kmer1, kmer2, sample, contig, data, is_rev, part_no) =
            storage.get_part(5).expect("one segment was buffered");
        assert_eq!(kmer1, 100);
        assert_eq!(kmer2, 200);
        assert_eq!(sample, "sample1");
        assert_eq!(contig, "chr1");
        assert_eq!(data, vec![0, 1, 2, 3]);
        assert!(!is_rev);
        assert_eq!(part_no, 0);

        // The group is now drained.
        assert!(storage.get_part(5).is_none());
    }

    #[test]
    fn test_buffered_segments_sort() {
        let storage = BufferedSegments::new(1);

        // Insert out of order; sort_known must restore 0, 1, 2.
        for part_no in [2u32, 0, 1] {
            storage.add_known(
                0,
                100,
                200,
                "sample1".to_string(),
                "chr1".to_string(),
                vec![part_no as u8],
                false,
                part_no,
            );
        }

        storage.sort_known(1);

        for expected in 0..3u32 {
            let popped = storage.get_part(0).expect("three segments buffered");
            assert_eq!(popped.6, expected); // seg_part_no
        }
    }
}