use bit_set::BitSet;
use hashbrown::{HashMap, HashSet};
use std::fs::File;
use std::io::{BufWriter, Write};
use crate::ska_dict::bit_encoding::UInt;
use crate::skalo::utils::{Config, DataInfo, VariantInfo};
/// Groups of variants keyed by a pair of `IntT` values (the two "extremities" of a group, as
/// they are called in `dereplicate_indels`); each key maps to the variants of that group.
type VariantGroups<IntT> = HashMap<(IntT, IntT), Vec<VariantInfo>>;
/// Dereplicates the indel groups, genotypes them and writes them to `<output_name>_indels.vcf`.
///
/// For every dereplicated group, the samples carrying each variant are looked up in
/// `kmer_2_samples` using the k-mer encoded from the first `k_graph + 1` bases of the variant.
/// A group is written when the proportion of uninformative samples (carrying neither or both
/// of the first two variants) does not exceed `config.max_missing` and each of these two
/// variants is carried exclusively by at least one sample.
///
/// In the VCF record, REF is the allele with the largest sample set and ALT the second one.
/// The bases flanking the indel are reported in the INFO column as `before=...;after=...`.
///
/// Groups with fewer than two variants, or with a variant whose k-mer is absent from
/// `kmer_2_samples`, cannot be genotyped and are skipped with a warning.
///
/// Returns the set of entry k-mers produced by `dereplicate_indels`.
///
/// # Panics
/// Panics if the VCF file cannot be created or written.
pub fn process_indels<IntT: for<'a> UInt<'a>>(
    indel_groups: VariantGroups<IntT>,
    kmer_2_samples: &HashMap<IntT, BitSet>,
    data_info: &DataInfo,
    config: &Config,
) -> HashSet<IntT> {
    log::info!("Processing indels");

    let (final_indels, entries_indels) = dereplicate_indels(indel_groups, data_info.k_graph);

    let vcf_filename = format!("{}_indels.vcf", config.output_name);
    let file = File::create(&vcf_filename).expect("Unable to create VCF file");
    let mut writer = BufWriter::new(file);

    writeln!(writer, "##fileformat=VCFv4.2").unwrap();
    writeln!(
        writer,
        "# REF corresponds to the most frequent variant among samples"
    )
    .unwrap();
    writeln!(
        writer,
        "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t{}",
        data_info.sample_names.join("\t")
    )
    .unwrap();

    let n_samples = data_info.sample_names.len();
    let mut nb_indels = 0;

    for vec_variants in final_indels.values() {
        // One sample set per variant, borrowed from the map and kept in the same order as
        // `vec_variants`, so that it can later be zipped with the extracted alleles.
        let sample_sets: Option<Vec<&BitSet>> = vec_variants
            .iter()
            .map(|variant| {
                let encoded_kmer =
                    IntT::encode_kmer(&variant.sequence.get_range(0, data_info.k_graph + 1));
                kmer_2_samples.get(&encoded_kmer)
            })
            .collect();

        let bitset_vec = match sample_sets {
            Some(sets) if sets.len() >= 2 => sets,
            _ => {
                log::warn!(
                    "Skipping an indel group: fewer than two variants or variant k-mer without sample information"
                );
                continue;
            }
        };

        // NOTE(review): the filter below looks at the first two variants in input order,
        // whereas the record is written for the two most frequent ones; both coincide when
        // the group holds exactly two variants.
        let mut missing_samples = 0_usize;
        let mut ref_present = false;
        let mut alt_present = false;
        for i in 0..n_samples {
            match (bitset_vec[0].contains(i), bitset_vec[1].contains(i)) {
                // Samples with neither or both variants are counted as missing.
                (false, false) | (true, true) => missing_samples += 1,
                (true, false) => ref_present = true,
                (false, true) => alt_present = true,
            }
        }

        let proportion_missing = missing_samples as f32 / n_samples as f32;
        if proportion_missing <= config.max_missing && ref_present && alt_present {
            nb_indels += 1;

            let (vec_inserts, last_kmer) = extract_middle_bases(vec_variants, data_info.k_graph);
            let first_kmer = vec_variants[0].sequence.decode()[..data_info.k_graph].to_string();

            // (allele, size of its sample set, sample set), most frequent allele first.
            // The sort is stable: alleles with equal counts keep their input order.
            let mut variants: Vec<(String, usize, &BitSet)> = vec_inserts
                .into_iter()
                .zip(bitset_vec.iter().copied())
                .map(|(allele, bitset)| (allele, bitset.len(), bitset))
                .collect();
            variants.sort_by(|a, b| b.1.cmp(&a.1));

            let (ref_allele, _ref_count, ref_bitset) = &variants[0];
            let (alt_allele, _alt_count, alt_bitset) = &variants[1];

            let sample_calls: Vec<&str> = (0..n_samples)
                .map(|i| match (ref_bitset.contains(i), alt_bitset.contains(i)) {
                    (true, true) => "0/1",
                    (true, false) => "0",
                    (false, true) => "1",
                    (false, false) => ".",
                })
                .collect();

            // Columns: CHROM POS ID REF ALT QUAL FILTER INFO FORMAT samples...
            // The key=value annotation goes in INFO (8th column); FILTER is left as '.'.
            writeln!(
                writer,
                ".\t.\t.\t{}\t{}\t.\t.\tbefore={};after={}\tGT\t{}",
                ref_allele,
                alt_allele,
                first_kmer,
                last_kmer,
                sample_calls.join("\t")
            )
            .unwrap();
        }
    }

    // Dropping a BufWriter ignores flush errors, so flush explicitly.
    writer.flush().expect("Unable to write VCF file");

    log::info!("{nb_indels} indels");
    entries_indels
}
/// Keeps at most one indel group per first extremity, favouring the shortest groups.
///
/// Groups are visited by increasing total decoded length of their variants; ties are broken
/// on the first, then on the second extremity, so the outcome does not depend on the
/// iteration order of the input map. A group is kept only if its first extremity has not
/// been recorded yet; keeping a group records both of its extremities together with their
/// `IntT::rev_comp` counterparts.
///
/// Returns the kept groups and the set of recorded k-mers.
fn dereplicate_indels<IntT: for<'a> UInt<'a>>(
    mut indel_groups: VariantGroups<IntT>,
    k_graph: usize,
) -> (VariantGroups<IntT>, HashSet<IntT>) {
    let mut entries_indels: HashSet<IntT> = HashSet::new();
    let mut final_indels: VariantGroups<IntT> = HashMap::new();

    // (key, summed length of the decoded variants) for every group.
    let mut sorted_extremities: Vec<((IntT, IntT), usize)> = indel_groups
        .iter()
        .map(|(key, variants)| {
            let total_length: usize = variants
                .iter()
                .map(|variant| variant.sequence.decode().len())
                .sum();
            (*key, total_length)
        })
        .collect();

    // Keys are unique, so this comparison is a total order and an unstable sort is safe.
    sorted_extremities.sort_unstable_by(|a, b| {
        a.1.cmp(&b.1)
            .then_with(|| a.0 .0.cmp(&b.0 .0))
            .then_with(|| a.0 .1.cmp(&b.0 .1))
    });

    for (combined_ext, _) in sorted_extremities {
        if entries_indels.contains(&combined_ext.0) {
            continue;
        }
        // The map is owned: move the variants out instead of cloning them.
        if let Some(vec_variants) = indel_groups.remove(&combined_ext) {
            let rc_1 = IntT::rev_comp(combined_ext.0, k_graph);
            let rc_2 = IntT::rev_comp(combined_ext.1, k_graph);
            entries_indels.insert(combined_ext.0);
            entries_indels.insert(rc_1);
            entries_indels.insert(combined_ext.1);
            entries_indels.insert(rc_2);
            final_indels.insert(combined_ext, vec_variants);
        }
    }

    (final_indels, entries_indels)
}
fn extract_middle_bases(vec_variants: &[VariantInfo], k_graph: usize) -> (Vec<String>, String) {
let reduced_seq: Vec<String> = vec_variants
.iter()
.map(|variant| {
let seq = variant.sequence.decode();
seq[k_graph..].to_string()
})
.collect();
let mut identical = true;
let mut n_nucl = 0;
while identical {
n_nucl += 1;
let mut all_ends: HashSet<String> = HashSet::new();
for seq in &reduced_seq {
if n_nucl > seq.len() {
identical = false;
} else {
let last_n_chars: Vec<String> = seq
.chars()
.rev()
.take(n_nucl)
.map(|c| c.to_string())
.collect();
let concatenated_last_chars: String = last_n_chars.into_iter().rev().collect();
all_ends.insert(concatenated_last_chars.clone());
}
}
if all_ends.len() > 1 {
identical = false;
}
}
n_nucl -= 1;
let pos_end = reduced_seq[0].len() - n_nucl;
let mut last_kmer = reduced_seq[0][pos_end..].to_string();
if last_kmer.len() > k_graph {
last_kmer = last_kmer[..k_graph].to_string();
}
let mut vec_middles: Vec<String> = Vec::new();
for seq in &reduced_seq {
let pos2_end = seq.len() - n_nucl;
let mut middle_bases = &seq[..pos2_end];
if middle_bases.is_empty() {
middle_bases = "-";
}
vec_middles.push(middle_bases.to_string());
}
(vec_middles, last_kmer)
}