use bit_set::BitSet;
use hashbrown::{HashMap, HashSet};
use crate::ska_dict::bit_encoding::UInt;
use crate::skalo::output_snps::create_fasta_and_vcf;
use crate::skalo::positioning::{extract_genomic_kmers, scan_variants};
use crate::skalo::process_indels::process_indels;
use crate::skalo::utils::{Config, DataInfo, VariantInfo};
/// Variants grouped under a pair of `IntT` keys.
///
/// NOTE(review): the keys look like the encoded k-mers delimiting a group of
/// alternative paths (both members are looked up in the indel k-mer set, the
/// second after reverse-complementing) — confirm against the code building
/// these maps.
type VariantGroups<IntT> = HashMap<(IntT, IntT), Vec<VariantInfo>>;
/// Filters variant groups, calls SNPs from them and writes the results.
///
/// The function proceeds in the following order:
/// 1. If `config.reference_genome` is set, the reference is read with
///    `extract_genomic_kmers` so that SNPs can later be positioned with
///    `scan_variants`.
/// 2. `indel_groups` is handed to `process_indels`; the set it returns is
///    used to filter variants and to skip whole groups.
/// 3. Every variant for which `find_internal_indels` counts more than
///    `config.max_indel_kmers` k-mers is dropped.
/// 4. Groups are ordered by decreasing ratio of (number of variants) over
///    (length of the first variant's sequence).
/// 5. For each remaining group with at least two variants, every candidate
///    position from `get_potential_snp` is turned into one column with one
///    character per sample (`'-'` when no variant covers the sample, `'N'`
///    when variants disagree). Columns passing `check_missing_data` and
///    `config.max_missing` are kept.
/// 6. Kept columns are stored under their reference position when
///    positioning is enabled, or under a running counter otherwise, and the
///    result is passed to `create_fasta_and_vcf`.
///
/// # Panics
///
/// Panics if a (k+1)-mer read from a variant path is missing from
/// `kmer_2_samples`.
pub fn analyse_variant_groups<IntT: for<'a> UInt<'a>>(
    mut variant_groups: VariantGroups<IntT>,
    indel_groups: VariantGroups<IntT>,
    kmer_2_samples: HashMap<IntT, BitSet>,
    config: &Config,
    data_info: &DataInfo,
) {
    // Positioning is only possible when a reference genome was supplied;
    // otherwise empty placeholders are forwarded to the output writer.
    let (do_positioning, kmer_map, genome_name, genome_seq) =
        if let Some(path) = &config.reference_genome {
            log::info!("Reading reference genome");
            let (extracted_kmer_map, seq, name) =
                extract_genomic_kmers(path.clone(), data_info.k_graph);
            (true, extracted_kmer_map, name, seq)
        } else {
            (
                false,
                HashMap::<u128, Vec<u32>>::new(),
                "".to_string(),
                Vec::<u8>::new(),
            )
        };

    let entries_indels = process_indels(indel_groups, &kmer_2_samples, data_info, config);

    // `retain` filters each group in a single pass and keeps the original
    // order, instead of shifting the tail of the vector on every removal.
    log::info!("Filtering paths");
    for vec_variant in variant_groups.values_mut() {
        vec_variant.retain(|variant| {
            find_internal_indels(variant, &entries_indels, data_info) <= config.max_indel_kmers
        });
    }

    // Groups left empty by the filter have no first variant and are dropped
    // here by `filter_map`.
    log::info!("Sorting variant groups");
    let mut sorted_keys: Vec<_> = variant_groups
        .iter()
        .filter_map(|(key, value)| {
            value.first().map(|variant_info| {
                let sequence_length = variant_info.sequence.len();
                let ratio = value.len() as f64 / sequence_length as f64;
                (key, ratio)
            })
        })
        .collect();
    sorted_keys.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .expect("ratio is never NaN: its numerator is at least 1")
    });

    log::info!("Processing SNPs");
    // (k+1)-mers already consumed by an accepted SNP, in both orientations.
    let mut entries_done: HashSet<IntT> = HashSet::new();
    let mut final_snps: HashMap<u32, Vec<char>> = HashMap::new();
    let mut not_positioned = 0;
    let mut counter = 0;

    for (key, _) in sorted_keys {
        // Skip groups whose delimiting k-mers belong to an indel.
        if entries_indels.contains(&key.0)
            || entries_indels.contains(&IntT::rev_comp(key.1, data_info.k_graph))
        {
            continue;
        }

        let vec_variants = variant_groups
            .get(key)
            .expect("key was taken from variant_groups");
        if vec_variants.len() < 2 {
            continue;
        }

        let real_snp_pos = get_potential_snp(vec_variants);
        let mut kmers_to_save: HashSet<IntT> = HashSet::new();
        let mut found_snp_pos: HashMap<usize, Vec<char>> =
            HashMap::with_capacity(real_snp_pos.len());

        for &pos in &real_snp_pos {
            let mut snp_column = vec!['-'; data_info.sample_names.len()];
            let mut tmp_kmers: HashSet<IntT> = HashSet::new();
            let mut new_snp = true;

            for variant in vec_variants {
                let seq = &variant.sequence;
                // (k+1)-mers ending and starting at the SNP position.
                // NOTE(review): `pos - k_graph` assumes every SNP position is
                // at least `k_graph` away from the sequence start — confirm.
                let full_before =
                    IntT::encode_kmer(&seq.get_range(pos - data_info.k_graph, pos + 1));
                let full_after =
                    IntT::encode_kmer(&seq.get_range(pos, pos + data_info.k_graph + 1));
                let rc_after = IntT::rev_comp(full_after, data_info.k_graph + 1);

                if entries_done.contains(&full_before) || entries_done.contains(&rc_after) {
                    // One of the k-mers was already used by an earlier SNP.
                    new_snp = false;
                    continue;
                }

                let last_nucl = IntT::get_last_nucl(full_before);
                let samples = kmer_2_samples
                    .get(&full_before)
                    .expect("k-mer from a variant path must be present in kmer_2_samples");
                for sample_index in samples {
                    let current = snp_column[sample_index];
                    // A sample supported by two different nucleotides is ambiguous.
                    snp_column[sample_index] = if current == '-' || current == last_nucl {
                        last_nucl
                    } else {
                        'N'
                    };
                }

                tmp_kmers.insert(full_before);
                tmp_kmers.insert(IntT::rev_comp(full_before, data_info.k_graph + 1));
                tmp_kmers.insert(full_after);
                tmp_kmers.insert(rc_after);
            }

            if new_snp {
                let (true_variant, ratio_missing) =
                    check_missing_data(data_info.sample_names.len(), &snp_column);
                if true_variant && ratio_missing <= config.max_missing {
                    kmers_to_save.extend(tmp_kmers);
                    found_snp_pos.insert(pos, snp_column);
                }
            }
        }

        // K-mers are only marked as done once the whole group was examined.
        entries_done.extend(kmers_to_save);

        if found_snp_pos.is_empty() {
            continue;
        }

        if !do_positioning {
            // Without a reference, SNPs are simply numbered in discovery order.
            for (_, column) in found_snp_pos {
                counter += 1;
                final_snps.insert(counter, column);
            }
            continue;
        }

        let (position_found, position, orientation) =
            scan_variants(vec_variants, data_info.k_graph, &kmer_map);
        if !position_found {
            not_positioned += found_snp_pos.len();
            continue;
        }

        let seq_length = vec_variants[0].sequence.len();
        let is_forward = orientation == "for";
        for (pos, column) in found_snp_pos {
            // For a non-forward match the offset is measured from the other
            // end of the path and the nucleotides are complemented.
            let (final_position, final_column) = if is_forward {
                (position + (pos - data_info.k_graph) as u32, column)
            } else {
                (
                    position + (seq_length - pos - data_info.k_graph - 1) as u32,
                    complement_snp(&column),
                )
            };

            // The first SNP found at a position wins; later ones are only counted.
            if final_snps.contains_key(&final_position) {
                not_positioned += 1;
            } else {
                final_snps.insert(final_position, final_column);
            }
        }
    }

    if do_positioning {
        log::info!(
            "{} SNPs (+ {} w/o position)",
            final_snps.len(),
            not_positioned
        );
    } else {
        log::info!("{} SNPs", final_snps.len());
    }

    create_fasta_and_vcf(
        genome_name,
        genome_seq,
        data_info.sample_names.clone(),
        final_snps,
        config,
    );
}
/// Counts the k-mers of `variant`'s sequence that are present in `entries_indels`.
///
/// The sequence is decoded and scanned with a sliding window of
/// `data_info.k_graph` characters; each window is encoded with
/// `IntT::encode_kmer_str` and looked up in `entries_indels`.
///
/// Returns 0 for a sequence shorter than `k_graph` (the window count is
/// computed with a saturating subtraction, so it cannot underflow).
fn find_internal_indels<IntT: for<'a> UInt<'a>>(
    variant: &VariantInfo,
    entries_indels: &HashSet<IntT>,
    data_info: &DataInfo,
) -> usize {
    let sequence = variant.sequence.decode();
    let k_graph = data_info.k_graph;

    // NOTE(review): with `len - k` windows the last k-mer of the sequence
    // (starting at `len - k`) is never examined — confirm this is intended.
    let nb_windows = sequence.len().saturating_sub(k_graph);

    (0..nb_windows)
        .filter(|&start| {
            let kmer = IntT::encode_kmer_str(&sequence[start..start + k_graph]);
            entries_indels.contains(&kmer)
        })
        .count()
}
/// Returns the positions at which the variants of a group really differ.
///
/// Candidate positions are the union of every variant's `vec_snps`. A
/// candidate is kept only when at least two distinct nucleotides among
/// `A`, `C`, `G` and `T` are observed at that position across the variants.
/// Variants whose sequence is too short to contain the position, and bytes
/// other than those four nucleotides, are ignored.
fn get_potential_snp(vec_variant: &[VariantInfo]) -> HashSet<usize> {
    let mut candidate_positions: HashSet<usize> = HashSet::new();
    for variant in vec_variant {
        candidate_positions.extend(&variant.vec_snps);
    }

    candidate_positions
        .into_iter()
        .filter(|&pos| {
            // One flag per nucleotide, in the order A, C, G, T.
            let mut seen = [false; 4];
            for variant in vec_variant {
                let seq = &variant.sequence;
                if pos >= seq.len() {
                    continue;
                }
                match seq.get_range(pos, pos + 1)[0] {
                    b'A' => seen[0] = true,
                    b'C' => seen[1] = true,
                    b'G' => seen[2] = true,
                    b'T' => seen[3] = true,
                    _ => {}
                }
            }
            seen.iter().filter(|&&present| present).count() > 1
        })
        .collect()
}
/// Summarises one SNP column.
///
/// Returns a pair made of:
/// * `true` when at least two different nucleotides among `A`, `T`, `G` and
///   `C` occur in `snp_column`;
/// * the number of entries that are none of those four characters, divided
///   by `nb_total`.
pub fn check_missing_data(nb_total: usize, snp_column: &[char]) -> (bool, f32) {
    // One bit per nucleotide seen in the column.
    let mut seen_mask: u8 = 0;
    let mut nb_missing: i32 = 0;

    for call in snp_column.iter().copied() {
        match call {
            'A' => seen_mask |= 0b0001,
            'T' => seen_mask |= 0b0010,
            'G' => seen_mask |= 0b0100,
            'C' => seen_mask |= 0b1000,
            // Anything else ('-', 'N', ...) counts as missing data.
            _ => nb_missing += 1,
        }
    }

    let is_variable = seen_mask.count_ones() >= 2;
    let missing_ratio = nb_missing as f32 / nb_total as f32;
    (is_variable, missing_ratio)
}
/// Complements every nucleotide of an SNP column.
///
/// `A`/`T` and `C`/`G` are swapped, while `-` and `N` are left unchanged.
///
/// # Panics
///
/// Panics on any other character.
fn complement_snp(dna: &[char]) -> Vec<char> {
    let mut complemented = Vec::with_capacity(dna.len());
    for &nucleotide in dna {
        let partner = match nucleotide {
            'A' => 'T',
            'T' => 'A',
            'C' => 'G',
            'G' => 'C',
            '-' | 'N' => nucleotide,
            _ => panic!("Invalid nucleotide: {nucleotide}"),
        };
        complemented.push(partner);
    }
    complemented
}