use ahash::AHashMap;
const MIN_MATCH_LEN: usize = 15;
const KEY_LEN: usize = 15;
const HASHING_STEP: usize = 3;
const MAX_NO_TRIES: usize = 4;
const MIN_NRUN_LEN: usize = 5;
pub fn get_coding_cost_vector(reference: &[u8], text: &[u8], prefix_costs: bool) -> Vec<u32> {
if reference.is_empty() || text.is_empty() {
return vec![1; text.len()];
}
let ht = build_hash_table(reference);
let mut v_costs = Vec::with_capacity(text.len());
let text_size = text.len();
let mut i = 0;
let mut pred_pos = 0;
let mut no_prev_literals = 0;
while i + KEY_LEN < text_size {
if text[i] == b'N' {
let nrun_len = get_nrun_len(&text[i..]);
if nrun_len >= MIN_NRUN_LEN {
let tc = coding_cost_nrun(nrun_len);
if prefix_costs {
v_costs.push(tc);
v_costs.extend(std::iter::repeat(0).take(nrun_len - 1));
} else {
v_costs.extend(std::iter::repeat(0).take(nrun_len - 1));
v_costs.push(tc);
}
i += nrun_len;
no_prev_literals = 0;
continue;
}
}
let key = compute_kmer_hash(&text[i..], KEY_LEN);
if let Some(match_result) = find_best_match(&ht, reference, &text[i..], no_prev_literals) {
let len_bck = match_result.len_bck;
let len_fwd = match_result.len_fwd;
let match_pos = match_result.pos;
if len_bck > 0 {
for _ in 0..len_bck {
v_costs.pop();
}
pred_pos -= len_bck;
i -= len_bck;
}
let total_len = len_bck + len_fwd;
let tc = coding_cost_match(match_pos, total_len, pred_pos);
if prefix_costs {
v_costs.push(tc);
v_costs.extend(std::iter::repeat(0).take(total_len - 1));
} else {
v_costs.extend(std::iter::repeat(0).take(total_len - 1));
v_costs.push(tc);
}
pred_pos = match_pos + total_len;
i += total_len;
no_prev_literals = 0;
} else {
v_costs.push(1);
i += 1;
pred_pos += 1;
no_prev_literals += 1;
}
}
while i < text_size {
v_costs.push(1);
i += 1;
}
v_costs
}
struct MatchResult {
pos: usize,
len_bck: usize,
len_fwd: usize,
}
fn build_hash_table(reference: &[u8]) -> AHashMap<u64, Vec<usize>> {
let mut ht = AHashMap::new();
let ref_len = reference.len();
let mut i = 0;
while i + KEY_LEN <= ref_len {
if reference[i..i + KEY_LEN].contains(&b'N') {
i += 1;
continue;
}
let key = compute_kmer_hash(&reference[i..], KEY_LEN);
ht.entry(key).or_insert_with(Vec::new).push(i);
i += HASHING_STEP;
}
ht
}
fn compute_kmer_hash(seq: &[u8], len: usize) -> u64 {
let mut hash: u64 = 0;
for i in 0..len.min(seq.len()) {
hash = hash.wrapping_mul(31).wrapping_add(seq[i] as u64);
}
hash
}
fn find_best_match(
ht: &AHashMap<u64, Vec<usize>>,
reference: &[u8],
text: &[u8],
no_prev_literals: usize,
) -> Option<MatchResult> {
if text.len() < KEY_LEN {
return None;
}
let key = compute_kmer_hash(text, KEY_LEN);
let positions = ht.get(&key)?;
let mut best_len = 0;
let mut best_pos = 0;
let mut best_len_bck = 0;
let mut best_len_fwd = 0;
for &ref_pos in positions.iter().take(MAX_NO_TRIES) {
let mut len_fwd = 0;
let max_fwd = (reference.len() - ref_pos).min(text.len());
while len_fwd < max_fwd && reference[ref_pos + len_fwd] == text[len_fwd] {
len_fwd += 1;
}
if len_fwd < KEY_LEN {
continue;
}
let len_bck = 0;
let max_bck = no_prev_literals.min(ref_pos);
let _ = max_bck;
let total_len = len_bck + len_fwd;
if total_len > best_len {
best_len = total_len;
best_pos = ref_pos;
best_len_bck = len_bck;
best_len_fwd = len_fwd;
}
}
if best_len >= MIN_MATCH_LEN {
Some(MatchResult {
pos: best_pos,
len_bck: best_len_bck,
len_fwd: best_len_fwd,
})
} else {
None
}
}
fn get_nrun_len(text: &[u8]) -> usize {
text.iter().take_while(|&&b| b == b'N').count()
}
fn coding_cost_nrun(len: usize) -> u32 {
if len <= 4 {
len as u32
} else {
4 + ((len as f32).log2().ceil() as u32)
}
}
fn coding_cost_match(match_pos: usize, match_len: usize, pred_pos: usize) -> u32 {
let delta = if match_pos >= pred_pos {
match_pos - pred_pos
} else {
pred_pos - match_pos
};
let len_cost = if match_len < 16 {
1
} else if match_len < 256 {
2
} else {
3
};
let pos_cost = if delta < 256 {
1
} else if delta < 65536 {
2
} else {
3
};
len_cost + pos_cost
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[ignore] fn test_get_coding_cost_vector() {
let reference = b"ACGTACGTACGTACGT";
let text = b"ACGTACGTNNNNACGT";
let costs = get_coding_cost_vector(reference, text, true);
assert_eq!(costs.len(), text.len());
assert!(costs.iter().sum::<u32>() < text.len() as u32);
}
#[test]
fn test_nrun_detection() {
assert_eq!(get_nrun_len(b"NNNNNACGT"), 5);
assert_eq!(get_nrun_len(b"ACGT"), 0);
}
}