use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::sync::atomic::Ordering;
use std::thread;
use std::io;
use crossbeam_queue::ArrayQueue;
use bitvec::prelude::*;
use rayon::prelude::*;
use uf_rush::UFRush;
use crate::pos::{PosT, offset, is_rev, incr_pos, incr_pos_by, decr_pos, decr_pos_by, make_pos_t};
use crate::seqindex::SeqIndex;
use iitree_rs::IITree;
/// Thomas Wang's 64-bit integer mix function.
///
/// Each shift-and-add step of the classic formulation is written here as the equivalent
/// wrapping multiplication:
/// `(!k) + (k << 21)` is `k * (2^21 - 1) - 1`, `k + (k << 3) + (k << 8)` is `k * 265`,
/// `k + (k << 2) + (k << 4)` is `k * 21`, and `k + (k << 31)` is `k * (2^31 + 1)`.
/// All arithmetic wraps modulo 2^64.
#[inline]
pub fn wang_hash_64(key: u64) -> u64 {
    let mut h = key.wrapping_mul(0x1F_FFFF).wrapping_sub(1);
    h ^= h >> 24;
    h = h.wrapping_mul(265);
    h ^= h >> 14;
    h = h.wrapping_mul(21);
    h ^= h >> 28;
    h.wrapping_mul(0x8000_0001)
}
/// Half-open interval `[begin, end)` of offsets.
///
/// A single position `p` is stored as `Range { begin: p, end: p + 1 }`, and the interval
/// length is `end - begin`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Range {
    /// First offset contained in the interval.
    pub begin: u64,
    /// One past the last offset contained in the interval.
    pub end: u64,
}

impl Range {
    /// Builds the interval `[begin, end)`.
    pub fn new(begin: u64, end: u64) -> Self {
        Self { begin, end }
    }
}
/// A half-open interval `[start, end)` paired with the position it maps to (`data`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Match {
    /// First offset of the interval.
    pub start: u64,
    /// One past the last offset of the interval.
    pub end: u64,
    /// Position associated with `start`.
    pub data: PosT,
}

impl Match {
    /// Builds a match covering `[start, end)` that maps to `data`.
    pub fn new(start: u64, end: u64, data: PosT) -> Self {
        Self { start, end, data }
    }

    /// Number of offsets covered, `end - start`.
    ///
    /// Overflows (panicking in debug builds) if `end < start`.
    pub fn length(&self) -> u64 {
        let Self { start, end, .. } = *self;
        end - start
    }
}
/// Bounded lock-free queue of ranges to explore, as `(position, length)` pairs.
/// For reverse-strand positions, the worker that consumes these treats `offset(position)`
/// as the range's highest offset (the range is `[offset - length + 1, offset]`).
type RangeAtomicQueue = ArrayQueue<(PosT, u64)>;
/// Bounded lock-free queue of discovered overlaps, each paired with `is_rev(match.data)`.
type OverlapAtomicQueue = ArrayQueue<(Match, bool)>;
/// Fixed-size bit set whose bits can be set from many threads at once through `&self`.
///
/// Bits are packed LSB-first into `AtomicU64` words and set with an atomic `fetch_or`.
/// Concurrent [`AtomicBitVec::set`] calls therefore never race and never lose updates.
/// Individual bits can be read as `bitvec.bits[i]`. Read the final contents only after
/// every writer has finished, for example after joining the worker threads.
#[derive(Debug)]
struct AtomicBitVec {
    /// Read view of the bits: `bits[i]` is `true` once `set(i)` has been called.
    bits: AtomicBits,
}

/// Storage behind [`AtomicBitVec`]: `len` bits packed LSB-first into atomic 64-bit words.
#[derive(Debug)]
struct AtomicBits {
    words: Vec<std::sync::atomic::AtomicU64>,
    len: usize,
}

impl AtomicBits {
    /// Splits a bit index into `(word index, single-bit mask)`.
    ///
    /// # Panics
    /// Panics if `index >= len`.
    fn locate(&self, index: usize) -> (usize, u64) {
        assert!(
            index < self.len,
            "AtomicBitVec index {} out of range for length {}",
            index,
            self.len
        );
        (index / 64, 1u64 << (index % 64))
    }

    /// Returns the current value of bit `index`.
    fn get(&self, index: usize) -> bool {
        let (word, mask) = self.locate(index);
        (self.words[word].load(Ordering::Relaxed) & mask) != 0
    }
}

impl std::ops::Index<usize> for AtomicBits {
    type Output = bool;

    fn index(&self, index: usize) -> &bool {
        // `Index` must return a reference, so hand out references to promoted constants.
        if self.get(index) {
            &true
        } else {
            &false
        }
    }
}

impl AtomicBitVec {
    /// Creates a bit set holding `size` bits, all cleared.
    fn new(size: usize) -> Self {
        let n_words = (size + 63) / 64;
        AtomicBitVec {
            bits: AtomicBits {
                words: (0..n_words)
                    .map(|_| std::sync::atomic::AtomicU64::new(0))
                    .collect(),
                len: size,
            },
        }
    }

    /// Sets bit `index` and returns whether it was already set.
    ///
    /// Safe to call concurrently from several threads. Exactly one caller observes `false`
    /// for a given bit.
    ///
    /// # Panics
    /// Panics if `index` is not less than the size passed to [`AtomicBitVec::new`].
    fn set(&self, index: usize) -> bool {
        let (word, mask) = self.bits.locate(index);
        // Relaxed is enough here. All read-modify-writes on one atomic word are totally
        // ordered, so the "previous value" answer is exact. Visibility of the final contents
        // to other threads comes from joining the writer threads.
        (self.bits.words[word].fetch_or(mask, Ordering::Relaxed) & mask) != 0
    }
}
/// Attaches output-sequence position `s_pos` to input position `q_pos` in `range_buffer`.
///
/// `range_buffer` is keyed by the most recent input position of each open range. The
/// previous input position is obtained with `decr_pos(q_pos)`, and there are three cases:
///
/// * If it has an open range and `q_pos` lies at a sequence start, that range is written to
///   the interval trees, removed, and a new one-base range starts at `q_pos`. For forward
///   positions `q_pos` itself is checked; for reverse positions the previous one is.
/// * Else, if the open range ends exactly at `s_pos`, it is re-keyed to `q_pos` and grown
///   by one base.
/// * Otherwise a new one-base range `[s_pos, s_pos + 1)` is inserted under `q_pos`.
///
/// # Errors
/// Propagates errors from flushing a finished range.
pub fn extend_range(
    s_pos: u64,
    q_pos: PosT,
    range_buffer: &mut HashMap<PosT, Range>,
    seqidx: &SeqIndex,
    node_iitree: &mut IITree<u64, PosT>,
    path_iitree: &mut IITree<u64, PosT>,
) -> io::Result<()> {
    let mut q_last_pos = q_pos;
    decr_pos(&mut q_last_pos);

    let previous = match range_buffer.get(&q_last_pos).copied() {
        Some(open) => open,
        None => {
            range_buffer.insert(q_pos, Range::new(s_pos, s_pos + 1));
            return Ok(());
        }
    };

    let boundary_offset = if is_rev(q_pos) {
        offset(q_last_pos)
    } else {
        offset(q_pos)
    };

    if seqidx.seq_start(boundary_offset) {
        // Never extend a range across the start of a new input sequence.
        flush_single_range(&previous, q_last_pos, node_iitree, path_iitree)?;
        range_buffer.remove(&q_last_pos);
        range_buffer.insert(q_pos, Range::new(s_pos, s_pos + 1));
    } else if previous.end == s_pos {
        range_buffer.remove(&q_last_pos);
        range_buffer.insert(q_pos, Range::new(previous.begin, s_pos + 1));
    } else {
        range_buffer.insert(q_pos, Range::new(s_pos, s_pos + 1));
    }
    Ok(())
}
/// Records one finished range in both interval trees.
///
/// `range_in_s` is the output-sequence interval. `match_end_pos_in_q` is the input
/// position of its last base.
///
/// * `node_iitree` gets `[s.begin, s.end)` mapped to the input position where the match
///   starts.
/// * `path_iitree` gets the input interval mapped to the matching output position.
///
/// For forward matches the input interval is `[offset + 1 - len, offset + 1)` and both
/// positions are forward. For reverse matches the input interval is
/// `[offset, offset(decr_pos_by(end, len)))`; the output position is `(s.end - 1, rev)`, and
/// the input position is one below the far end of that interval, also reverse.
///
/// Always returns `Ok(())`.
fn flush_single_range(
    range_in_s: &Range,
    match_end_pos_in_q: PosT,
    node_iitree: &mut IITree<u64, PosT>,
    path_iitree: &mut IITree<u64, PosT>,
) -> io::Result<()> {
    let len = range_in_s.end - range_in_s.begin;
    let (s_begin, s_end) = (range_in_s.begin, range_in_s.end);

    let (pos_in_s, pos_in_q, q_begin, q_end) = if is_rev(match_end_pos_in_q) {
        let q_low = offset(match_end_pos_in_q);
        let mut far_end = match_end_pos_in_q;
        decr_pos_by(&mut far_end, len as usize);
        let q_high = offset(far_end);
        (
            make_pos_t(s_end - 1, true),
            make_pos_t(q_high - 1, true),
            q_low,
            q_high,
        )
    } else {
        let q_high = offset(match_end_pos_in_q) + 1;
        let q_low = q_high - len;
        (
            make_pos_t(s_begin, false),
            make_pos_t(q_low, false),
            q_low,
            q_high,
        )
    };

    node_iitree.add(s_begin, s_end, pos_in_q);
    path_iitree.add(q_begin, q_end, pos_in_s);
    Ok(())
}
/// Writes every buffered range that can no longer be extended and removes it from the buffer.
///
/// A range can still be extended only if it ends exactly at `s_pos`, the position about to
/// be emitted. All other ranges are flushed through `flush_single_range`. Passing a
/// `s_pos` that no range ends at (e.g. one past the sequence length) flushes everything.
///
/// Ranges are flushed in ascending key order. `HashMap` iteration order is randomized per
/// process, and flushing in that order would make the interval-tree input differ from run
/// to run.
///
/// # Errors
/// Propagates the first error from `flush_single_range`. Ranges not yet flushed stay in
/// the buffer.
pub fn flush_ranges(
    s_pos: u64,
    range_buffer: &mut HashMap<PosT, Range>,
    node_iitree: &mut IITree<u64, PosT>,
    path_iitree: &mut IITree<u64, PosT>,
) -> io::Result<()> {
    let mut finished: Vec<(PosT, Range)> = range_buffer
        .iter()
        .filter(|(_, range)| range.end != s_pos)
        .map(|(&key, &range)| (key, range))
        .collect();
    // Keys are unique, so an unstable sort is fully deterministic.
    finished.sort_unstable_by_key(|&(key, _)| key);
    for (key, range) in finished {
        flush_single_range(&range, key, node_iitree, path_iitree)?;
        range_buffer.remove(&key);
    }
    Ok(())
}
/// Calls `lambda` once for every maximal run of positions in `range` not marked in `seen_bv`.
///
/// `range.data` is the position paired with `range.start`. It is advanced with `incr_pos`
/// in lock-step with the offset, so each emitted `Match` carries the position paired with
/// its own start. Runs are reported in increasing offset order.
///
/// # Panics
/// Panics if an offset in `range` is out of bounds for `seen_bv`.
fn for_each_fresh_range<F>(range: &Match, seen_bv: &[bool], mut lambda: F)
where
    F: FnMut(Match),
{
    let mut cursor = range.start;
    let mut cursor_pos = range.data;
    loop {
        // Skip over already-seen offsets.
        while cursor < range.end && seen_bv[cursor as usize] {
            cursor += 1;
            incr_pos(&mut cursor_pos);
        }
        if cursor >= range.end {
            break;
        }
        // Consume one maximal unseen run.
        let (run_start, run_pos) = (cursor, cursor_pos);
        while cursor < range.end && !seen_bv[cursor as usize] {
            cursor += 1;
            incr_pos(&mut cursor_pos);
        }
        lambda(Match::new(run_start, cursor, run_pos));
    }
}
fn handle_range(
s: Match,
curr_bv: &AtomicBitVec,
ovlp_q: &OverlapAtomicQueue,
todo_in: &RangeAtomicQueue,
) {
let mut all_set_there = true;
let mut n = s.data;
for _i in s.start..s.end {
let was_set = curr_bv.set(offset(n) as usize);
all_set_there = all_set_there && was_set;
incr_pos(&mut n);
}
let _ = ovlp_q.push((s, is_rev(s.data)));
if !all_set_there {
let item = (make_pos_t(offset(s.data), is_rev(s.data)), s.end - s.start);
let _ = todo_in.push(item);
}
}
/// Visits every alignment interval in `aln_iitree` overlapping `b` and handles its fresh parts.
///
/// Each overlapping interval is clipped to `[b.start, b.end)`. Its paired position is
/// advanced by however many offsets were cut from the front, so interval and position stay
/// aligned. Every run of the clipped interval not marked in `seen_bv` is then passed to
/// `handle_range`.
///
/// Any error returned by the interval-tree query is discarded, exactly as before.
/// NOTE(review): discarded query errors mean silently missing overlaps; consider surfacing them.
///
/// # Panics
/// Panics if a clipped interval is empty.
fn explore_overlaps(
    b: &Match,
    seen_bv: &[bool],
    curr_bv: &AtomicBitVec,
    aln_iitree: &IITree<u64, PosT>,
    ovlp_q: &OverlapAtomicQueue,
    todo_in: &RangeAtomicQueue,
) {
    let _ = aln_iitree.overlap(b.start, b.end, |_idx, start, end, pos| {
        let mut target = pos;
        let lo = if start < b.start {
            incr_pos_by(&mut target, (b.start - start) as usize);
            b.start
        } else {
            start
        };
        let hi = if end > b.end { b.end } else { end };
        assert!(lo < hi);
        let clipped = Match::new(lo, hi, target);
        for_each_fresh_range(&clipped, seen_bv, |fresh| {
            handle_range(fresh, curr_bv, ovlp_q, todo_in);
        });
    });
}
/// Emits one chunk of the output sequence from its disjoint sets.
///
/// `dsets` holds `(set_id, offset)` pairs with the members of each set contiguous.
///
/// For each new set:
/// * Its base (read from `seqidx` at the first member's offset, `'N'` if absent) is
///   appended to `seq_v_out`.
/// * Ranges that can no longer be extended are flushed.
/// * Each member's input position is attached to the new base with `extend_range`. The
///   position is forward if its base matches, otherwise reverse.
///
/// When `repeat_max` or `min_repeat_dist` is non-zero, some additional members from the
/// same input sequence are deferred instead:
/// * members beyond the first `repeat_max`, or
/// * members closer than `min_repeat_dist` to the previous member.
///
/// Each deferred copy index gets its own extra base, appended after the set. Copy indices
/// are emitted in ascending order; a `BTreeMap` is used so the output is identical on every
/// run, whereas `HashMap` order is randomized.
///
/// # Errors
/// Propagates errors from `extend_range` and `flush_ranges`.
///
/// # Panics
/// Panics if a member's base matches neither strand of its set's base.
fn write_graph_chunk(
    seqidx: &SeqIndex,
    node_iitree: &mut IITree<u64, PosT>,
    path_iitree: &mut IITree<u64, PosT>,
    seq_v_out: &mut Vec<u8>,
    range_buffer: &mut HashMap<PosT, Range>,
    dsets: Vec<(u64, u64)>,
    repeat_max: u64,
    min_repeat_dist: u64,
) -> io::Result<()> {
    let tracking_repeats = repeat_max != 0 || min_repeat_dist != 0;
    let mut seq_v_length = seq_v_out.len() as u64;
    let mut last_dset_id = u64::MAX;
    let mut current_base = 0u8;
    let mut seq_counts: HashMap<u64, u64> = HashMap::new();
    let mut last_seq_pos: HashMap<u64, PosT> = HashMap::new();
    // copy index -> deferred input positions that get their own extra base.
    let mut todos: std::collections::BTreeMap<u64, Vec<PosT>> = std::collections::BTreeMap::new();

    for (curr_dset_id, curr_offset) in dsets {
        let base = seqidx.at(curr_offset).unwrap_or('N') as u8;
        if curr_dset_id != last_dset_id {
            // Deferred copies belong to the previous set, so they reuse its base.
            emit_repeat_copies(
                current_base,
                &todos,
                seq_v_out,
                &mut seq_v_length,
                seqidx,
                range_buffer,
                node_iitree,
                path_iitree,
            )?;
            todos.clear();
            seq_counts.clear();
            last_seq_pos.clear();

            current_base = base;
            seq_v_out.push(current_base);
            seq_v_length += 1;
            flush_ranges(seq_v_length - 1, range_buffer, node_iitree, path_iitree)?;
            last_dset_id = curr_dset_id;
        }

        let mut curr_q_pos = make_pos_t(curr_offset, false);
        if current_base != seqidx.at_pos(curr_q_pos).unwrap_or('N') as u8 {
            curr_q_pos = make_pos_t(curr_offset, true);
        }
        assert_eq!(current_base, seqidx.at_pos(curr_q_pos).unwrap_or('N') as u8);

        if let Some(curr_seq_id) = seqidx.seq_id_at(curr_offset) {
            let curr_seq_id = curr_seq_id as u64;
            // 0 = attach to the set's own base; k > 0 = defer to extra copy k.
            let mut copy_index = 0u64;
            if tracking_repeats {
                let seen = seq_counts.get(&curr_seq_id).copied().unwrap_or(0);
                let too_close = min_repeat_dist != 0
                    && match last_seq_pos.get(&curr_seq_id) {
                        Some(&last_pos) => {
                            offset(curr_q_pos).abs_diff(offset(last_pos)) < min_repeat_dist
                        }
                        None => false,
                    };
                let too_many = repeat_max != 0 && seen + 1 > repeat_max;
                seq_counts.insert(curr_seq_id, seen + 1);
                if too_close || too_many {
                    copy_index = seen + 1;
                }
            }
            if copy_index == 0 {
                extend_range(
                    seq_v_length - 1,
                    curr_q_pos,
                    range_buffer,
                    seqidx,
                    node_iitree,
                    path_iitree,
                )?;
            } else {
                todos.entry(copy_index).or_default().push(curr_q_pos);
            }
            last_seq_pos.insert(curr_seq_id, curr_q_pos);
        }
    }

    emit_repeat_copies(
        current_base,
        &todos,
        seq_v_out,
        &mut seq_v_length,
        seqidx,
        range_buffer,
        node_iitree,
        path_iitree,
    )
}

/// Appends one copy of `base` per entry of `todos` (ascending copy index).
/// Each entry's input positions are attached to its new base via `extend_range`.
// Mirrors write_graph_chunk's state; bundling it into a struct would add more noise than it removes.
#[allow(clippy::too_many_arguments)]
fn emit_repeat_copies(
    base: u8,
    todos: &std::collections::BTreeMap<u64, Vec<PosT>>,
    seq_v_out: &mut Vec<u8>,
    seq_v_length: &mut u64,
    seqidx: &SeqIndex,
    range_buffer: &mut HashMap<PosT, Range>,
    node_iitree: &mut IITree<u64, PosT>,
    path_iitree: &mut IITree<u64, PosT>,
) -> io::Result<()> {
    for positions in todos.values() {
        seq_v_out.push(base);
        *seq_v_length += 1;
        let s_pos = *seq_v_length - 1;
        for &pos in positions {
            extend_range(s_pos, pos, range_buffer, seqidx, node_iitree, path_iitree)?;
        }
    }
    Ok(())
}
pub fn compute_transitive_closures(
seqidx: Arc<SeqIndex>,
aln_iitree: Arc<Mutex<IITree<u64, PosT>>>,
seq_v_file: &str,
node_iitree: Arc<Mutex<IITree<u64, PosT>>>,
path_iitree: Arc<Mutex<IITree<u64, PosT>>>,
repeat_max: u64,
min_repeat_dist: u64,
transclose_batch_size: u64,
show_progress: bool,
num_threads: usize,
) -> io::Result<usize> {
use std::fs::File;
use std::io::Write;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering};
eprintln!("[transclosure] Starting transitive closure computation");
eprintln!("[transclosure] Using {} threads", num_threads);
node_iitree.lock().unwrap().open_writer()?;
path_iitree.lock().unwrap().open_writer()?;
let mut seq_v_out = Vec::new();
let input_seq_length = seqidx.seq_length() as usize;
let mut q_seen_bv = vec![false; input_seq_length];
let mut range_buffer: HashMap<PosT, Range> = HashMap::new();
let mut bases_seen = 0u64;
let mut i = 0;
while i < input_seq_length {
while i < input_seq_length && q_seen_bv[i] {
i += 1;
}
if i >= input_seq_length {
break;
}
let chunk_start = i;
let mut bases_to_consider = 0;
let mut chunk_end = chunk_start;
while bases_to_consider < transclose_batch_size as usize && chunk_end < input_seq_length {
if !q_seen_bv[chunk_end] {
bases_to_consider += 1;
}
chunk_end += 1;
}
if show_progress {
eprintln!(
"[transclosure] {:.2}% {}-{} overlap_collect",
(bases_seen as f64 / input_seq_length as f64) * 100.0,
chunk_start,
chunk_end
);
}
let q_curr_bv = AtomicBitVec::new(input_seq_length);
let todo_in = Arc::new(ArrayQueue::new(100000));
let todo_out = Arc::new(ArrayQueue::new(100000));
let ovlp_q = Arc::new(ArrayQueue::new(100000));
let mut todo: VecDeque<(PosT, u64)> = VecDeque::new();
let mut ovlp: Vec<(Match, bool)> = Vec::new();
for_each_fresh_range(
&Match::new(chunk_start as u64, chunk_end as u64, 0),
&q_seen_bv,
|b| {
for j in b.start..b.end {
q_curr_bv.set(j as usize);
}
let range = (make_pos_t(b.start, false), b.end - b.start);
if todo_out.push(range).is_err() {
todo.push_back(range);
}
},
);
let work_todo = Arc::new(AtomicBool::new(true));
let aln_iitree_clone = Arc::clone(&aln_iitree);
let q_curr_bv_shared = Arc::new(q_curr_bv);
let workers: Vec<_> = (0..num_threads)
.map(|_| {
let work_todo = Arc::clone(&work_todo);
let todo_out = Arc::clone(&todo_out);
let todo_in = Arc::clone(&todo_in);
let ovlp_q = Arc::clone(&ovlp_q);
let aln_iitree = Arc::clone(&aln_iitree_clone);
let q_curr_bv = Arc::clone(&q_curr_bv_shared);
let q_seen_bv_clone = q_seen_bv.clone();
thread::spawn(move || {
while work_todo.load(Ordering::Relaxed) {
if let Some(item) = todo_out.pop() {
let (pos, match_len) = item;
let n = if !is_rev(pos) {
offset(pos)
} else {
offset(pos) - match_len + 1
};
let range_start = n;
let range_end = n + match_len;
if let Ok(aln_guard) = aln_iitree.lock() {
explore_overlaps(
&Match::new(range_start, range_end, pos),
&q_seen_bv_clone,
&q_curr_bv,
&aln_guard,
&ovlp_q,
&todo_in,
);
}
} else {
thread::sleep(std::time::Duration::from_nanos(1));
}
}
})
})
.collect();
let mut empty_iter_count = 0;
while !todo_in.is_empty() || !todo.is_empty() || !todo_out.is_empty() || !ovlp_q.is_empty() || empty_iter_count < 1000 {
thread::sleep(std::time::Duration::from_nanos(10));
while let Some(item) = todo_in.pop() {
todo.push_back(item);
}
while let Some(item) = todo.front().copied() {
if todo_out.push(item).is_ok() {
todo.pop_front();
empty_iter_count = 0;
} else {
break;
}
}
while let Some(o) = ovlp_q.pop() {
ovlp.push(o);
}
if todo_in.is_empty() && todo.is_empty() && todo_out.is_empty() && ovlp_q.is_empty() {
empty_iter_count += 1;
} else {
empty_iter_count = 0;
}
}
work_todo.store(false, Ordering::Relaxed);
for worker in workers {
worker.join().ok();
}
if show_progress {
eprintln!(
"[transclosure] {:.2}% {}-{} union_find",
(bases_seen as f64 / input_seq_length as f64) * 100.0,
chunk_start,
chunk_end
);
}
let q_curr_bv_final = Arc::try_unwrap(q_curr_bv_shared).unwrap();
let mut q_curr_positions = Vec::new();
for pos in 0..input_seq_length {
if q_curr_bv_final.bits[pos] {
q_curr_positions.push(pos as u64);
}
}
let q_curr_bv_count = q_curr_positions.len();
if q_curr_bv_count == 0 {
i = chunk_end;
continue;
}
let dsets = UFRush::new(q_curr_bv_count);
ovlp.par_iter().for_each(|s| {
let r = &s.0;
let mut p = r.data;
for j in r.start..r.end {
let j_rank = q_curr_positions.binary_search(&j).unwrap();
let p_rank = q_curr_positions.binary_search(&offset(p)).unwrap();
dsets.unite(j_rank, p_rank);
incr_pos(&mut p);
}
});
if show_progress {
eprintln!(
"[transclosure] {:.2}% {}-{} dset_write",
(bases_seen as f64 / input_seq_length as f64) * 100.0,
chunk_start,
chunk_end
);
}
let mut dsets_vec: Vec<(u64, u64)> = q_curr_positions
.par_iter()
.enumerate()
.filter_map(|(j, &p)| {
if !q_seen_bv[p as usize] {
Some((dsets.find(j) as u64, p))
} else {
None
}
})
.collect();
if dsets_vec.is_empty() {
i = chunk_end;
continue;
}
if show_progress {
eprintln!(
"[transclosure] {:.2}% {}-{} dset_sort",
(bases_seen as f64 / input_seq_length as f64) * 100.0,
chunk_start,
chunk_end
);
}
dsets_vec.par_sort_unstable();
let mut c = 0u64;
let mut last_id = dsets_vec[0].0;
for d in &mut dsets_vec {
if d.0 != last_id {
c += 1;
last_id = d.0;
}
d.0 = c;
}
let mut dsets_by_min_pos = vec![(u64::MAX, 0u64); (c + 1) as usize];
for i in 0..=c {
dsets_by_min_pos[i as usize].1 = i;
}
for d in &dsets_vec {
let minpos = &mut dsets_by_min_pos[d.0 as usize].0;
*minpos = (*minpos).min(d.1);
}
dsets_by_min_pos.par_sort_unstable();
let mut dset_names = vec![0u64; (c + 1) as usize];
for (x, d) in dsets_by_min_pos.iter().enumerate() {
dset_names[d.1 as usize] = x as u64;
}
for d in &mut dsets_vec {
d.0 = dset_names[d.0 as usize];
}
dsets_vec.par_sort_unstable();
for d in &dsets_vec {
q_seen_bv[d.1 as usize] = true;
bases_seen += 1;
}
if show_progress {
eprintln!(
"[transclosure] {:.2}% {}-{} graph_emission",
(bases_seen as f64 / input_seq_length as f64) * 100.0,
chunk_start,
chunk_end
);
}
{
let mut node_guard = node_iitree.lock().unwrap();
let mut path_guard = path_iitree.lock().unwrap();
write_graph_chunk(
&seqidx,
&mut node_guard,
&mut path_guard,
&mut seq_v_out,
&mut range_buffer,
dsets_vec,
repeat_max,
min_repeat_dist,
)?;
}
i = chunk_end;
}
let mut file = File::create(seq_v_file)?;
file.write_all(&seq_v_out)?;
let seq_bytes = seq_v_out.len();
{
let mut node_guard = node_iitree.lock().unwrap();
let mut path_guard = path_iitree.lock().unwrap();
flush_ranges(seq_bytes as u64 + 1, &mut range_buffer, &mut node_guard, &mut path_guard)?;
}
if show_progress {
eprintln!("[transclosure] Building node_iitree and path_iitree indexes");
}
node_iitree.lock().unwrap().close_writer()?;
path_iitree.lock().unwrap().close_writer()?;
node_iitree.lock().unwrap().index()?;
path_iitree.lock().unwrap().index()?;
eprintln!("[transclosure] Transitive closure computation complete");
Ok(seq_bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wang_hash() {
        let val = 12345u64;
        assert_eq!(wang_hash_64(val), wang_hash_64(val));
        assert_ne!(wang_hash_64(12345), wang_hash_64(54321));
    }

    #[test]
    fn test_range() {
        let r = Range::new(10, 20);
        assert_eq!(r.begin, 10);
        assert_eq!(r.end, 20);
    }

    #[test]
    fn test_match() {
        let m = Match::new(0, 100, make_pos_t(50, false));
        assert_eq!(m.length(), 100);
    }

    /// `set` returns the previous value and leaves neighbouring bits untouched.
    #[test]
    fn test_atomic_bitvec_set() {
        let bv = AtomicBitVec::new(70);
        assert!(!bv.set(65));
        assert!(bv.set(65));
        assert!(bv.bits[65]);
        assert!(!bv.bits[64]);
        assert!(!bv.set(0));
        assert!(bv.bits[0]);
    }

    /// Unseen runs are reported as maximal, in order, with the first run keeping `data`.
    #[test]
    fn test_for_each_fresh_range_splits_runs() {
        let seen = [false, false, true, false, true, true, false];
        let mut runs = Vec::new();
        for_each_fresh_range(&Match::new(0, 7, make_pos_t(0, false)), &seen, |m| runs.push(m));
        let spans: Vec<(u64, u64)> = runs.iter().map(|m| (m.start, m.end)).collect();
        assert_eq!(spans, vec![(0, 2), (3, 4), (6, 7)]);
        assert_eq!(runs[0].data, make_pos_t(0, false));
    }

    /// A fully seen range or an empty range produces no callbacks.
    #[test]
    fn test_for_each_fresh_range_nothing_fresh() {
        let seen = [true; 4];
        let mut calls = 0;
        for_each_fresh_range(&Match::new(0, 4, make_pos_t(0, false)), &seen, |_| calls += 1);
        for_each_fresh_range(&Match::new(2, 2, make_pos_t(2, false)), &seen, |_| calls += 1);
        assert_eq!(calls, 0);
    }
}