// ragc_core/segment_buffer.rs

// Buffered segment storage matching C++ AGC's CBufferedSegPart
// Reference: agc_compressor.h lines 27-536

use std::collections::{BTreeSet, HashMap};
use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::Mutex;

/// Block size for atomic work distribution.
///
/// Matches C++ AGC's PART_ID_STEP (agc_compressor.h).
pub const PART_ID_STEP: i32 = 128;

/// Segment part data.
///
/// Matches C++ AGC's seg_part_t (agc_compressor.h:29-120) and kk_seg_part_t (lines 124-165)
///
/// Fields:
/// - `kmer1`, `kmer2`: First and last k-mers of segment
/// - `sample_name`, `contig_name`: Origin of segment
/// - `seg_data`: Compressed segment bytes
/// - `is_rev_comp`: Whether segment is reverse complemented
/// - `seg_part_no`: Part number within contig
///
/// Ordering: By (sample_name, contig_name, seg_part_no) for deterministic storage.
///
/// NOTE(review): `PartialEq` is derived (compares every field) while `Ord`
/// compares only (sample_name, contig_name, seg_part_no), so two parts can be
/// `Ordering::Equal` yet `!=`. `BTreeSet` dedups on `Ord` — confirm that
/// collapsing parts that differ only in k-mers/data is intended.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentPart {
    pub kmer1: u64,
    pub kmer2: u64,
    pub sample_name: String,
    pub contig_name: String,
    pub seg_data: Vec<u8>,
    pub is_rev_comp: bool,
    pub seg_part_no: u32,
}

impl PartialOrd for SegmentPart {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SegmentPart {
    /// Match C++ AGC ordering: (sample_name, contig_name, seg_part_no).
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.sample_name
            .cmp(&other.sample_name)
            .then_with(|| self.contig_name.cmp(&other.contig_name))
            .then_with(|| self.seg_part_no.cmp(&other.seg_part_no))
    }
}

/// Thread-safe list of segments for one group
///
/// Matches C++ AGC's list_seg_part_t (agc_compressor.h:169-296)
///
/// NOTE(review): both mutexes must be acquired in the same order
/// (`virt_begin` before `parts`) by every method, otherwise two threads
/// taking them in opposite orders can deadlock.
struct SegmentPartList {
    /// Vector of segments (grows during compression, read during storage)
    parts: Mutex<Vec<SegmentPart>>,

    /// Virtual begin index for pop operations (avoid actually removing elements)
    virt_begin: Mutex<usize>,
}

63impl SegmentPartList {
64    fn new() -> Self {
65        SegmentPartList {
66            parts: Mutex::new(Vec::new()),
67            virt_begin: Mutex::new(0),
68        }
69    }
70
71    /// Add segment (thread-safe)
72    ///
73    /// Matches C++ AGC's list_seg_part_t::emplace (agc_compressor.h:224-228)
74    fn emplace(&self, part: SegmentPart) {
75        let mut parts = self.parts.lock().unwrap();
76        parts.push(part);
77    }
78
79    /// Sort segments by (sample, contig, part_no)
80    ///
81    /// Matches C++ AGC's list_seg_part_t::sort (agc_compressor.h:235-238)
82    fn sort(&self) {
83        let mut parts = self.parts.lock().unwrap();
84        parts.sort();
85    }
86
87    /// Pop next segment (uses virt_begin to avoid actual removal)
88    ///
89    /// Matches C++ AGC's list_seg_part_t::pop (agc_compressor.h:251-265)
90    fn pop(&self) -> Option<SegmentPart> {
91        let mut virt_begin = self.virt_begin.lock().unwrap();
92        let parts = self.parts.lock().unwrap();
93
94        if *virt_begin >= parts.len() {
95            drop(parts); // Release lock before clearing
96            let mut parts = self.parts.lock().unwrap();
97            *virt_begin = 0;
98            parts.clear();
99            return None;
100        }
101
102        let part = parts[*virt_begin].clone();
103        *virt_begin += 1;
104
105        Some(part)
106    }
107
108    /// Check if list is empty (from virt_begin perspective)
109    ///
110    /// Matches C++ AGC's list_seg_part_t::empty (agc_compressor.h:246-249)
111    fn is_empty(&self) -> bool {
112        let virt_begin = self.virt_begin.lock().unwrap();
113        let parts = self.parts.lock().unwrap();
114        *virt_begin >= parts.len()
115    }
116
117    /// Clear list and reset virt_begin
118    ///
119    /// Matches C++ AGC's list_seg_part_t::clear (agc_compressor.h:240-244)
120    fn clear(&self) {
121        let mut parts = self.parts.lock().unwrap();
122        let mut virt_begin = self.virt_begin.lock().unwrap();
123        parts.clear();
124        *virt_begin = 0;
125    }
126
127    fn size(&self) -> usize {
128        let parts = self.parts.lock().unwrap();
129        parts.len()
130    }
131}
132
/// Buffered segments storage
///
/// Matches C++ AGC's CBufferedSegPart (agc_compressor.h:27-536)
///
/// **Architecture**:
/// - `vl_seg_part`: Vector of thread-safe lists, indexed by group_id (KNOWN segments)
/// - `s_seg_part`: BTreeSet of NEW segments (not yet assigned group_id)
///
/// **Workflow**:
/// 1. During compression: Workers call `add_known()` or `add_new()`
/// 2. At registration barrier: Main thread calls `sort_known()`, `process_new()`, `distribute_segments()`
/// 3. During storage: Workers call `get_vec_id()` and `get_part()` to read segments
/// 4. After storage: Main thread calls `clear()`
pub struct BufferedSegments {
    /// KNOWN segments indexed by group_id
    /// Matches C++ AGC's vector<list_seg_part_t> vl_seg_part (line 298)
    vl_seg_part: Vec<SegmentPartList>,

    /// NEW segments (no group_id yet)
    /// Matches C++ AGC's set<kk_seg_part_t> s_seg_part (line 300)
    /// NOTE(review): dedup follows SegmentPart's Ord, which ignores k-mers
    /// and data — confirm that is the intended set semantics.
    s_seg_part: Mutex<BTreeSet<SegmentPart>>,

    /// Atomic counter for reading segments (starts at size-1, decrements to 0)
    /// Matches C++ AGC's atomic<int32_t> a_v_part_id (line 303)
    a_v_part_id: AtomicI32,

    /// Mutex for resizing vl_seg_part
    /// Matches C++ AGC's mutex mtx (line 301)
    resize_mtx: Mutex<()>,
}

164impl BufferedSegments {
165    /// Create new buffered segments storage
166    ///
167    /// Matches C++ AGC's CBufferedSegPart::CBufferedSegPart (agc_compressor.h:308-311)
168    ///
169    /// # Arguments
170    /// * `no_raw_groups` - Initial number of group IDs
171    pub fn new(no_raw_groups: usize) -> Self {
172        let mut vl_seg_part = Vec::with_capacity(no_raw_groups);
173        for _ in 0..no_raw_groups {
174            vl_seg_part.push(SegmentPartList::new());
175        }
176
177        BufferedSegments {
178            vl_seg_part,
179            s_seg_part: Mutex::new(BTreeSet::new()),
180            a_v_part_id: AtomicI32::new(0),
181            resize_mtx: Mutex::new(()),
182        }
183    }
184
185    /// Add segment to KNOWN group
186    ///
187    /// Matches C++ AGC's CBufferedSegPart::add_known (agc_compressor.h:320-324)
188    ///
189    /// Thread-safe: `SegmentPartList::emplace()` has internal mutex
190    pub fn add_known(
191        &self,
192        group_id: u32,
193        kmer1: u64,
194        kmer2: u64,
195        sample_name: String,
196        contig_name: String,
197        seg_data: Vec<u8>,
198        is_rev_comp: bool,
199        seg_part_no: u32,
200    ) {
201        self.vl_seg_part[group_id as usize].emplace(SegmentPart {
202            kmer1,
203            kmer2,
204            sample_name,
205            contig_name,
206            seg_data,
207            is_rev_comp,
208            seg_part_no,
209        });
210    }
211
212    /// Add NEW segment (not yet assigned group_id)
213    ///
214    /// Matches C++ AGC's CBufferedSegPart::add_new (agc_compressor.h:326-331)
215    ///
216    /// Thread-safe: Locks `s_seg_part` mutex
217    pub fn add_new(
218        &self,
219        kmer1: u64,
220        kmer2: u64,
221        sample_name: String,
222        contig_name: String,
223        seg_data: Vec<u8>,
224        is_rev_comp: bool,
225        seg_part_no: u32,
226    ) {
227        let mut s_seg_part = self.s_seg_part.lock().unwrap();
228        s_seg_part.insert(SegmentPart {
229            kmer1,
230            kmer2,
231            sample_name,
232            contig_name,
233            seg_data,
234            is_rev_comp,
235            seg_part_no,
236        });
237    }
238
239    /// Sort all KNOWN segments in parallel
240    ///
241    /// Matches C++ AGC's CBufferedSegPart::sort_known (agc_compressor.h:333-377)
242    ///
243    /// **Note**: In C++ AGC, this uses std::async for parallelism. For now, we'll use
244    /// a simple sequential implementation. Parallelism can be added later with rayon.
245    ///
246    /// # Arguments
247    /// * `_num_threads` - Number of threads (unused in sequential version)
248    pub fn sort_known(&self, _num_threads: usize) {
249        // TODO: Implement parallel sorting with rayon
250        // For now, sequential sorting
251        for list in &self.vl_seg_part {
252            list.sort();
253        }
254    }
255
256    /// Process NEW segments: assign group IDs and move to KNOWN
257    ///
258    /// Matches C++ AGC's CBufferedSegPart::process_new (agc_compressor.h:384-415)
259    ///
260    /// Returns: Number of NEW groups created
261    pub fn process_new(&mut self) -> u32 {
262        let _lock = self.resize_mtx.lock().unwrap();
263        let mut s_seg_part = self.s_seg_part.lock().unwrap();
264
265        if s_seg_part.is_empty() {
266            return 0;
267        }
268
269        // Assign group IDs to unique (kmer1, kmer2) pairs
270        let mut m_kmers = std::collections::HashMap::new();
271        let mut group_id = self.vl_seg_part.len() as u32;
272
273        for part in s_seg_part.iter() {
274            let key = (part.kmer1, part.kmer2);
275            if !m_kmers.contains_key(&key) {
276                m_kmers.insert(key, group_id);
277                group_id += 1;
278            }
279        }
280
281        let no_new = group_id - self.vl_seg_part.len() as u32;
282
283        // Resize vl_seg_part to accommodate new groups
284        let new_size = group_id as usize;
285        if self.vl_seg_part.capacity() < new_size {
286            self.vl_seg_part
287                .reserve((new_size as f64 * 1.2) as usize - self.vl_seg_part.len());
288        }
289        while self.vl_seg_part.len() < new_size {
290            self.vl_seg_part.push(SegmentPartList::new());
291        }
292
293        // Move NEW segments to KNOWN groups
294        for part in s_seg_part.iter() {
295            let key = (part.kmer1, part.kmer2);
296            let group_id = m_kmers[&key] as usize;
297
298            self.vl_seg_part[group_id].emplace(part.clone());
299        }
300
301        s_seg_part.clear();
302
303        no_new
304    }
305
306    /// Get the number of NEW segments (not yet assigned group_id)
307    ///
308    /// **For testing** - count segments in s_seg_part
309    pub fn get_num_new(&self) -> usize {
310        let s_seg_part = self.s_seg_part.lock().unwrap();
311        s_seg_part.len()
312    }
313
314    /// Distribute segments from src_id to range [dest_from, dest_to)
315    ///
316    /// Matches C++ AGC's CBufferedSegPart::distribute_segments (agc_compressor.h:417-435)
317    ///
318    /// **Pattern**: Round-robin distribution
319    ///
320    /// Example: `distribute_segments(0, 0, num_workers)`
321    /// - Distributes group 0 segments among workers 0..num_workers
322    pub fn distribute_segments(&self, src_id: u32, dest_id_from: u32, dest_id_to: u32) {
323        let src_id = src_id as usize;
324        let no_in_src = self.vl_seg_part[src_id].size();
325        let mut dest_id_curr = dest_id_from;
326
327        for _ in 0..no_in_src {
328            if dest_id_curr != src_id as u32 {
329                if let Some(part) = self.vl_seg_part[src_id].pop() {
330                    self.vl_seg_part[dest_id_curr as usize].emplace(part);
331                }
332            }
333
334            dest_id_curr += 1;
335            if dest_id_curr == dest_id_to {
336                dest_id_curr = dest_id_from;
337            }
338        }
339    }
340
341    /// Clear all buffered segments
342    ///
343    /// Matches C++ AGC's CBufferedSegPart::clear (agc_compressor.h:461-507)
344    ///
345    /// # Arguments
346    /// * `_num_threads` - Number of threads (unused in sequential version)
347    pub fn clear(&mut self, _num_threads: usize) {
348        // TODO: Implement parallel clearing with rayon
349        // For now, sequential clearing
350        let _lock = self.resize_mtx.lock().unwrap();
351
352        let mut s_seg_part = self.s_seg_part.lock().unwrap();
353        s_seg_part.clear();
354        drop(s_seg_part);
355
356        for list in &self.vl_seg_part {
357            list.clear();
358        }
359    }
360
361    /// Restart reading from highest group_id
362    ///
363    /// Matches C++ AGC's CBufferedSegPart::restart_read_vec (agc_compressor.h:509-514)
364    pub fn restart_read_vec(&self) {
365        let _lock = self.resize_mtx.lock().unwrap();
366        self.a_v_part_id
367            .store((self.vl_seg_part.len() - 1) as i32, Ordering::SeqCst);
368    }
369
370    /// Atomically get next group_id to process (decrements from size-1 to 0)
371    ///
372    /// Matches C++ AGC's CBufferedSegPart::get_vec_id (agc_compressor.h:516-520)
373    ///
374    /// Returns: group_id to process, or negative if done
375    pub fn get_vec_id(&self) -> i32 {
376        self.a_v_part_id.fetch_sub(1, Ordering::SeqCst)
377    }
378
379    /// Check if group is empty
380    ///
381    /// Matches C++ AGC's CBufferedSegPart::is_empty_part (agc_compressor.h:527-530)
382    pub fn is_empty_part(&self, group_id: i32) -> bool {
383        if group_id < 0 || group_id as usize >= self.vl_seg_part.len() {
384            return true;
385        }
386        self.vl_seg_part[group_id as usize].is_empty()
387    }
388
389    /// Pop next segment from group
390    ///
391    /// Matches C++ AGC's CBufferedSegPart::get_part (agc_compressor.h:532-535)
392    ///
393    /// Returns: (kmer1, kmer2, sample_name, contig_name, seg_data, is_rev_comp, seg_part_no)
394    pub fn get_part(
395        &self,
396        group_id: i32,
397    ) -> Option<(u64, u64, String, String, Vec<u8>, bool, u32)> {
398        if group_id < 0 || group_id as usize >= self.vl_seg_part.len() {
399            return None;
400        }
401
402        self.vl_seg_part[group_id as usize].pop().map(|part| {
403            (
404                part.kmer1,
405                part.kmer2,
406                part.sample_name,
407                part.contig_name,
408                part.seg_data,
409                part.is_rev_comp,
410                part.seg_part_no,
411            )
412        })
413    }
414
415    /// Get current number of groups
416    pub fn get_no_parts(&self) -> usize {
417        self.vl_seg_part.len()
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a test segment for sample1/chr1 with the given data and part number.
    fn make_part(data: Vec<u8>, part_no: u32) -> SegmentPart {
        SegmentPart {
            kmer1: 100,
            kmer2: 200,
            sample_name: "sample1".to_string(),
            contig_name: "chr1".to_string(),
            seg_data: data,
            is_rev_comp: false,
            seg_part_no: part_no,
        }
    }

    /// Push a segment into a KNOWN group, reusing the fixture defaults.
    fn add_fixture(buf: &BufferedSegments, group_id: u32, data: Vec<u8>, part_no: u32) {
        buf.add_known(
            group_id,
            100,
            200,
            "sample1".to_string(),
            "chr1".to_string(),
            data,
            false,
            part_no,
        );
    }

    #[test]
    fn test_segment_part_ordering() {
        // Same sample/contig: seg_part_no decides the order.
        assert!(make_part(vec![0, 1, 2], 0) < make_part(vec![3, 4, 5], 1));
    }

    #[test]
    fn test_buffered_segments_add_known() {
        let buf = BufferedSegments::new(10);
        add_fixture(&buf, 5, vec![0, 1, 2, 3], 0);
        assert!(!buf.is_empty_part(5));
    }

    #[test]
    fn test_buffered_segments_add_new_and_process() {
        let mut buf = BufferedSegments::new(10);

        // A NEW segment has no group yet.
        buf.add_new(
            300,
            400,
            "sample1".to_string(),
            "chr1".to_string(),
            vec![4, 5, 6, 7],
            false,
            0,
        );

        // Processing assigns exactly one fresh group...
        assert_eq!(buf.process_new(), 1);

        // ...appended after the initial 10 groups, at index 10.
        assert_eq!(buf.get_no_parts(), 11);
        assert!(!buf.is_empty_part(10));
    }

    #[test]
    fn test_buffered_segments_get_vec_id() {
        let buf = BufferedSegments::new(5);
        buf.restart_read_vec();

        // Counts down 4, 3, 2, 1, 0, then goes negative.
        for expected in (0..5).rev() {
            assert_eq!(buf.get_vec_id(), expected);
        }
        assert!(buf.get_vec_id() < 0);
    }

    #[test]
    fn test_buffered_segments_get_part() {
        let buf = BufferedSegments::new(10);
        add_fixture(&buf, 5, vec![0, 1, 2, 3], 0);

        let (kmer1, kmer2, sample, contig, data, is_rev, part_no) =
            buf.get_part(5).expect("segment should be present");
        assert_eq!(kmer1, 100);
        assert_eq!(kmer2, 200);
        assert_eq!(sample, "sample1");
        assert_eq!(contig, "chr1");
        assert_eq!(data, vec![0, 1, 2, 3]);
        assert!(!is_rev);
        assert_eq!(part_no, 0);

        // Group is exhausted after the single pop.
        assert!(buf.get_part(5).is_none());
    }

    #[test]
    fn test_buffered_segments_sort() {
        let buf = BufferedSegments::new(1);

        // Insert out of order: part numbers 2, 0, 1.
        for &part_no in &[2u32, 0, 1] {
            add_fixture(&buf, 0, vec![part_no as u8], part_no);
        }

        buf.sort_known(1);

        // Pops come back sorted by seg_part_no: 0, 1, 2.
        for expected in 0..3u32 {
            let part = buf.get_part(0).unwrap();
            assert_eq!(part.6, expected); // seg_part_no
        }
    }
}