use ahash::AHashMap;
use std::collections::BTreeSet;
use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::Mutex;
/// Public step constant with value 128.
///
/// NOTE(review): nothing in this file reads it; the name suggests a stride
/// between part ids used by callers elsewhere — confirm before changing.
pub const PART_ID_STEP: i32 = 128;
/// One buffered piece of a contig together with the k-mer pair that selects its group.
///
/// Equality is derived and therefore compares every field, whereas the manual
/// `Ord` implementation for this type only looks at `sample_name`,
/// `contig_name` and `seg_part_no`. Ordered containers such as `BTreeSet`
/// consequently treat two parts with the same three values as the same element
/// even when the remaining fields differ.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentPart {
/// First component of the `(kmer1, kmer2)` key used to look up the group id.
pub kmer1: u64,
/// Second component of the `(kmer1, kmer2)` key used to look up the group id.
pub kmer2: u64,
/// Sample name; primary sort key.
pub sample_name: String,
/// Contig name; secondary sort key.
pub contig_name: String,
/// Raw payload bytes of the part; not interpreted anywhere in this file.
pub seg_data: Vec<u8>,
/// Flag carried along unchanged; presumably marks reverse-complemented data — confirm with callers.
pub is_rev_comp: bool,
/// Part number; tertiary sort key.
pub seg_part_no: u32,
}
/// Partial order for [`SegmentPart`]: always defined and identical to the
/// total order supplied by the type's `Ord` implementation.
impl PartialOrd for SegmentPart {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Total order for [`SegmentPart`]: by `sample_name`, then `contig_name`,
/// then `seg_part_no`.
///
/// The k-mers, the payload and the `is_rev_comp` flag do not take part in the
/// comparison, so two parts that differ only in those fields compare as
/// `Equal` here even though the derived `==` reports them as different.
impl Ord for SegmentPart {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.sample_name
            .cmp(&other.sample_name)
            .then_with(|| self.contig_name.cmp(&other.contig_name))
            .then_with(|| self.seg_part_no.cmp(&other.seg_part_no))
    }
}
/// Mutex-guarded FIFO of the [`SegmentPart`]s buffered for a single group.
///
/// All state lives behind one lock, so every method is a single critical
/// section: no two methods can acquire locks in conflicting order, and a part
/// added while another thread is popping can never be discarded.
struct SegmentPartList {
    /// Parts that were added and not yet popped; the front is the next one to pop.
    parts: Mutex<std::collections::VecDeque<SegmentPart>>,
}

impl SegmentPartList {
    /// Creates an empty list.
    fn new() -> Self {
        SegmentPartList {
            parts: Mutex::new(std::collections::VecDeque::new()),
        }
    }

    /// Appends `part` behind all parts that are still waiting to be popped.
    ///
    /// Panics if the lock is poisoned (as do all methods of this type).
    fn emplace(&self, part: SegmentPart) {
        self.parts.lock().unwrap().push_back(part);
    }

    /// Stably sorts the parts that have not been popped yet using `SegmentPart`'s `Ord`.
    fn sort(&self) {
        // `make_contiguous` exposes the queue as one slice so the std stable sort applies.
        self.parts.lock().unwrap().make_contiguous().sort();
    }

    /// Removes and returns the oldest waiting part, or `None` when nothing is left.
    ///
    /// The part is moved out, not cloned.
    fn pop(&self) -> Option<SegmentPart> {
        self.parts.lock().unwrap().pop_front()
    }

    /// Returns `true` when no part is waiting to be popped.
    fn is_empty(&self) -> bool {
        self.parts.lock().unwrap().is_empty()
    }

    /// Drops every waiting part.
    fn clear(&self) {
        self.parts.lock().unwrap().clear();
    }

    /// Returns the number of parts still waiting to be popped.
    ///
    /// Always consistent with [`is_empty`](Self::is_empty): the result is zero
    /// exactly when `is_empty()` is `true`.
    fn size(&self) -> usize {
        self.parts.lock().unwrap().len()
    }
}
/// Buffers segment parts per group and holds not-yet-grouped parts until
/// `process_new` assigns them to a group.
pub struct BufferedSegments {
/// One list per group; the index into this vector is the group id.
vl_seg_part: Vec<SegmentPartList>,
/// Parts added via `add_new`, kept ordered by `SegmentPart`'s `Ord` until `process_new` drains them.
s_seg_part: Mutex<BTreeSet<SegmentPart>>,
/// Countdown cursor: reset by `restart_read_vec`, decremented by every `get_vec_id` call.
a_v_part_id: AtomicI32,
/// Held for the duration of `process_new`, `clear` and `restart_read_vec`.
resize_mtx: Mutex<()>,
}
impl BufferedSegments {
pub fn new(no_raw_groups: usize) -> Self {
let mut vl_seg_part = Vec::with_capacity(no_raw_groups);
for _ in 0..no_raw_groups {
vl_seg_part.push(SegmentPartList::new());
}
BufferedSegments {
vl_seg_part,
s_seg_part: Mutex::new(BTreeSet::new()),
a_v_part_id: AtomicI32::new(0),
resize_mtx: Mutex::new(()),
}
}
pub fn add_known(
&self,
group_id: u32,
kmer1: u64,
kmer2: u64,
sample_name: String,
contig_name: String,
seg_data: Vec<u8>,
is_rev_comp: bool,
seg_part_no: u32,
) {
self.vl_seg_part[group_id as usize].emplace(SegmentPart {
kmer1,
kmer2,
sample_name,
contig_name,
seg_data,
is_rev_comp,
seg_part_no,
});
}
pub fn add_new(
&self,
kmer1: u64,
kmer2: u64,
sample_name: String,
contig_name: String,
seg_data: Vec<u8>,
is_rev_comp: bool,
seg_part_no: u32,
) {
let mut s_seg_part = self.s_seg_part.lock().unwrap();
s_seg_part.insert(SegmentPart {
kmer1,
kmer2,
sample_name,
contig_name,
seg_data,
is_rev_comp,
seg_part_no,
});
}
pub fn sort_known(&self, _num_threads: usize) {
for list in &self.vl_seg_part {
list.sort();
}
}
pub fn process_new(
&mut self,
map_segments: &std::sync::Mutex<AHashMap<(u64, u64), u32>>,
) -> u32 {
let _lock = self.resize_mtx.lock().unwrap();
let mut s_seg_part = self.s_seg_part.lock().unwrap();
if s_seg_part.is_empty() {
return 0;
}
let global_map = map_segments.lock().unwrap();
let mut m_kmers = AHashMap::new();
let mut group_id = self.vl_seg_part.len() as u32;
for part in s_seg_part.iter() {
let key = (part.kmer1, part.kmer2);
if let Some(&existing_group_id) = global_map.get(&key) {
m_kmers.insert(key, existing_group_id);
} else if !m_kmers.contains_key(&key) {
m_kmers.insert(key, group_id);
group_id += 1;
}
}
drop(global_map);
let no_new = group_id - self.vl_seg_part.len() as u32;
let new_size = group_id as usize;
if self.vl_seg_part.capacity() < new_size {
self.vl_seg_part
.reserve((new_size as f64 * 1.2) as usize - self.vl_seg_part.len());
}
while self.vl_seg_part.len() < new_size {
self.vl_seg_part.push(SegmentPartList::new());
}
for part in s_seg_part.iter() {
let key = (part.kmer1, part.kmer2);
let group_id = m_kmers[&key] as usize;
self.vl_seg_part[group_id].emplace(part.clone());
}
s_seg_part.clear();
no_new
}
pub fn get_num_new(&self) -> usize {
let s_seg_part = self.s_seg_part.lock().unwrap();
s_seg_part.len()
}
pub fn distribute_segments(&self, src_id: u32, dest_id_from: u32, dest_id_to: u32) {
let src_id = src_id as usize;
let no_in_src = self.vl_seg_part[src_id].size();
let mut dest_id_curr = dest_id_from;
for _ in 0..no_in_src {
if dest_id_curr != src_id as u32 {
if let Some(part) = self.vl_seg_part[src_id].pop() {
self.vl_seg_part[dest_id_curr as usize].emplace(part);
}
}
dest_id_curr += 1;
if dest_id_curr == dest_id_to {
dest_id_curr = dest_id_from;
}
}
}
pub fn clear(&mut self, _num_threads: usize) {
let _lock = self.resize_mtx.lock().unwrap();
let mut s_seg_part = self.s_seg_part.lock().unwrap();
s_seg_part.clear();
drop(s_seg_part);
for list in &self.vl_seg_part {
list.clear();
}
}
pub fn restart_read_vec(&self) {
let _lock = self.resize_mtx.lock().unwrap();
self.a_v_part_id
.store((self.vl_seg_part.len() - 1) as i32, Ordering::SeqCst);
}
pub fn get_vec_id(&self) -> i32 {
self.a_v_part_id.fetch_sub(1, Ordering::SeqCst)
}
pub fn is_empty_part(&self, group_id: i32) -> bool {
if group_id < 0 || group_id as usize >= self.vl_seg_part.len() {
return true;
}
self.vl_seg_part[group_id as usize].is_empty()
}
pub fn get_part(
&self,
group_id: i32,
) -> Option<(u64, u64, String, String, Vec<u8>, bool, u32)> {
if group_id < 0 || group_id as usize >= self.vl_seg_part.len() {
return None;
}
self.vl_seg_part[group_id as usize].pop().map(|part| {
(
part.kmer1,
part.kmer2,
part.sample_name,
part.contig_name,
part.seg_data,
part.is_rev_comp,
part.seg_part_no,
)
})
}
pub fn get_no_parts(&self) -> usize {
self.vl_seg_part.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `sample1`/`chr1` part keyed by the k-mer pair `(100, 200)`.
    fn chr1_part(seg_data: Vec<u8>, seg_part_no: u32) -> SegmentPart {
        SegmentPart {
            kmer1: 100,
            kmer2: 200,
            sample_name: "sample1".to_string(),
            contig_name: "chr1".to_string(),
            seg_data,
            is_rev_comp: false,
            seg_part_no,
        }
    }

    /// Adds a `sample1`/`chr1` part keyed by `(100, 200)` to an existing group.
    fn add_chr1_known(buf: &BufferedSegments, group_id: u32, seg_data: Vec<u8>, seg_part_no: u32) {
        buf.add_known(
            group_id,
            100,
            200,
            "sample1".to_string(),
            "chr1".to_string(),
            seg_data,
            false,
            seg_part_no,
        );
    }

    /// Parts of the same sample and contig are ordered by part number.
    #[test]
    fn test_segment_part_ordering() {
        let earlier = chr1_part(vec![0, 1, 2], 0);
        let later = chr1_part(vec![3, 4, 5], 1);
        assert!(earlier < later);
    }

    /// A part added to a known group makes that group non-empty.
    #[test]
    fn test_buffered_segments_add_known() {
        let buf = BufferedSegments::new(10);
        add_chr1_known(&buf, 5, vec![0, 1, 2, 3], 0);
        assert!(!buf.is_empty_part(5));
    }

    /// A pending part with an unknown key gets a fresh group appended at the end.
    #[test]
    fn test_buffered_segments_add_new_and_process() {
        let mut buf = BufferedSegments::new(10);
        buf.add_new(
            300,
            400,
            "sample1".to_string(),
            "chr1".to_string(),
            vec![4, 5, 6, 7],
            false,
            0,
        );
        let map_segments = std::sync::Mutex::new(AHashMap::new());
        let no_new = buf.process_new(&map_segments);
        assert_eq!(no_new, 1);
        assert_eq!(buf.get_no_parts(), 11);
        assert!(!buf.is_empty_part(10));
    }

    /// Ids are handed out from the last group down to 0, then turn negative.
    #[test]
    fn test_buffered_segments_get_vec_id() {
        let buf = BufferedSegments::new(5);
        buf.restart_read_vec();
        for expected in (0..5).rev() {
            assert_eq!(buf.get_vec_id(), expected);
        }
        assert!(buf.get_vec_id() < 0);
    }

    /// A popped part carries every field it was added with, and the group is empty afterwards.
    #[test]
    fn test_buffered_segments_get_part() {
        let buf = BufferedSegments::new(10);
        add_chr1_known(&buf, 5, vec![0, 1, 2, 3], 0);

        let popped = buf.get_part(5);
        assert!(popped.is_some());
        let (kmer1, kmer2, sample, contig, data, is_rev, part_no) = popped.unwrap();
        assert_eq!(kmer1, 100);
        assert_eq!(kmer2, 200);
        assert_eq!(sample, "sample1");
        assert_eq!(contig, "chr1");
        assert_eq!(data, vec![0, 1, 2, 3]);
        assert!(!is_rev);
        assert_eq!(part_no, 0);

        assert!(buf.get_part(5).is_none());
    }

    /// After sorting, parts come out in ascending part-number order.
    #[test]
    fn test_buffered_segments_sort() {
        let buf = BufferedSegments::new(1);
        for &part_no in &[2u32, 0, 1] {
            add_chr1_known(&buf, 0, vec![part_no as u8], part_no);
        }
        buf.sort_known(1);
        for expected in 0..3u32 {
            let popped = buf.get_part(0).unwrap();
            assert_eq!(popped.6, expected);
        }
    }
}