use coitrees::{COITree, Interval, IntervalTree};
use crate::vcf::genotype::VariantRecord;
/// A single non-reference allele placed on one haplotype.
///
/// Describes a substitution of `ref_len` reference bases starting at
/// `ref_pos` by the bases in `alt_bases`.
#[derive(Debug, Clone)]
pub struct HaplotypeVariant {
/// Position of the first replaced reference base; used as an index into
/// the reference slice passed to `Haplotype::extract_fragment`.
pub ref_pos: u32,
/// Number of reference bases replaced (length of the record's REF allele).
pub ref_len: u32,
/// Bases emitted in place of the replaced reference bases.
pub alt_bases: Vec<u8>,
}
/// One haplotype: the set of variants assigned to a single chromosome copy,
/// indexed by an interval tree for overlap queries.
pub struct Haplotype {
// Index of this haplotype in the vector returned by `build_haplotypes`.
allele_index: usize,
// Variants carried by this haplotype, in insertion order.
variant_data: Vec<HaplotypeVariant>,
// Interval tree over the variants' reference spans (inclusive end); each
// interval's metadata is the variant's index into `variant_data`.
variant_tree: COITree<u32, u32>,
}
impl Haplotype {
    /// Returns the index of this haplotype in the vector produced by
    /// `build_haplotypes`.
    #[must_use]
    pub fn allele_index(&self) -> usize {
        self.allele_index
    }

    /// Extracts the sequence of this haplotype for a fragment that starts at
    /// reference position `ref_start`.
    ///
    /// The reference is walked from `ref_start`; wherever a variant of this
    /// haplotype starts, its ALT bases are emitted instead of its REF span.
    /// The walk stops once `fragment_len` bases were produced or the
    /// reference is exhausted, so the result may be shorter than
    /// `fragment_len`.
    ///
    /// Returns `(bases, ref_positions)`, two vectors of equal length.
    /// `ref_positions[i]` is the reference position `bases[i]` derives from;
    /// every base of an ALT allele reports the variant's start position.
    ///
    /// Special cases:
    /// * If the fragment starts inside a variant's REF span, that variant is
    ///   skipped and extraction resumes at the first reference base after it.
    /// * A variant that starts inside the REF span of an already applied
    ///   variant (or duplicates its position) is ignored; it does not
    ///   prevent later variants from being applied.
    #[must_use]
    pub fn extract_fragment(
        &self,
        reference: &[u8],
        ref_start: u32,
        fragment_len: usize,
    ) -> (Vec<u8>, Vec<u32>) {
        let start = ref_start as usize;
        let mut window_end = start.saturating_add(fragment_len).min(reference.len());
        // Deletions make the fragment span more reference than `fragment_len`,
        // so the first window can miss variants near the fragment's end.
        // Widen the window until the walk stays inside the queried region.
        // `window_end` strictly grows and is capped by `reference.len()`, so
        // the loop terminates.
        loop {
            let overlapping = self.overlapping_variants(ref_start, window_end);
            let (bases, ref_positions, reached) =
                self.walk_fragment(reference, start, fragment_len, &overlapping);
            if reached <= window_end || window_end >= reference.len() {
                return (bases, ref_positions);
            }
            window_end = reached.min(reference.len());
        }
    }

    /// Returns the indices into `variant_data` of all variants whose
    /// reference span intersects `[ref_start, window_end)`, ordered by start
    /// position (ties broken by index so the order is deterministic).
    #[expect(clippy::cast_possible_wrap, reason = "genomic coords < i32::MAX")]
    fn overlapping_variants(&self, ref_start: u32, window_end: usize) -> Vec<u32> {
        #[expect(
            clippy::cast_possible_truncation,
            reason = "window_end bounded by reference.len() which fits i32"
        )]
        let last = window_end.saturating_sub(1) as i32;
        let mut indices: Vec<u32> = Vec::new();
        self.variant_tree.query(ref_start as i32, last, |node| {
            // NOTE(review): `.clone()` is kept instead of a plain copy,
            // presumably because the callback's metadata is `T` or `&T`
            // depending on the coitrees backend - confirm before simplifying.
            #[allow(clippy::clone_on_copy)]
            indices.push(node.metadata.clone());
        });
        indices.sort_unstable_by_key(|&idx| (self.variant_data[idx as usize].ref_pos, idx));
        indices
    }

    /// Walks the reference from `start`, applying the variants listed in
    /// `overlapping` (indices into `variant_data`, sorted by start position).
    ///
    /// Returns the bases, their reference positions, and the reference
    /// position at which the walk stopped (exclusive end of the consumed
    /// reference region).
    fn walk_fragment(
        &self,
        reference: &[u8],
        start: usize,
        fragment_len: usize,
        overlapping: &[u32],
    ) -> (Vec<u8>, Vec<u32>, usize) {
        let mut bases = Vec::with_capacity(fragment_len);
        let mut ref_positions = Vec::with_capacity(fragment_len);
        let mut ref_pos = start;
        let mut pending = overlapping
            .iter()
            .map(|&idx| &self.variant_data[idx as usize])
            .peekable();

        // Variants that end at or before the start are irrelevant; a variant
        // that straddles the start is skipped by jumping past its REF span.
        while let Some(&var) = pending.peek() {
            let var_end = var.ref_pos as usize + var.ref_len as usize;
            if var_end > ref_pos {
                if (var.ref_pos as usize) >= ref_pos {
                    break;
                }
                ref_pos = var_end;
            }
            pending.next();
        }

        while bases.len() < fragment_len && ref_pos < reference.len() {
            // Discard variants that start behind the cursor: they overlap a
            // variant that was already applied and must not block later ones.
            while pending
                .next_if(|var| (var.ref_pos as usize) < ref_pos)
                .is_some()
            {}

            #[expect(clippy::cast_possible_truncation, reason = "ref positions fit in u32")]
            let pos = ref_pos as u32;
            if let Some(var) = pending.next_if(|var| var.ref_pos as usize == ref_pos) {
                // Emit as much of the ALT allele as still fits the fragment.
                let room = fragment_len - bases.len();
                let emitted = var.alt_bases.len().min(room);
                bases.extend_from_slice(&var.alt_bases[..emitted]);
                ref_positions.resize(bases.len(), pos);
                ref_pos += var.ref_len as usize;
            } else {
                bases.push(reference[ref_pos]);
                ref_positions.push(pos);
                ref_pos += 1;
            }
        }
        (bases, ref_positions, ref_pos)
    }
}
/// Distributes the non-reference alleles of `variants` over `max_ploidy`
/// haplotypes and builds one [`Haplotype`] per slot.
///
/// For each record, the allele at genotype index `i` is assigned to a
/// haplotype slot:
/// * phased genotypes map index `i` to haplotype `i`;
/// * unphased genotypes map through a fresh random permutation of the slots
///   drawn from `rng` (one permutation per record).
///
/// Missing alleles, reference alleles (allele number 0), alleles for which
/// `allele_bases` yields nothing, and genotype entries beyond `max_ploidy`
/// are ignored.
///
/// The returned vector always has `max_ploidy` entries; entry `i` reports
/// `allele_index() == i`.
pub fn build_haplotypes(
    variants: &[VariantRecord],
    max_ploidy: usize,
    rng: &mut impl rand::Rng,
) -> Vec<Haplotype> {
    let mut records_by_hap: Vec<Vec<HaplotypeVariant>> =
        std::iter::repeat_with(Vec::new).take(max_ploidy).collect();
    let mut spans_by_hap: Vec<Vec<Interval<u32>>> =
        std::iter::repeat_with(Vec::new).take(max_ploidy).collect();

    for record in variants {
        let genotype = &record.genotype;

        // Identity mapping for phased calls; Fisher-Yates shuffle otherwise.
        let mut slot_for_allele: Vec<usize> = (0..max_ploidy).collect();
        if !genotype.is_phased() {
            let mut upper = slot_for_allele.len();
            while upper > 1 {
                upper -= 1;
                let pick = rng.random_range(0..=upper);
                slot_for_allele.swap(upper, pick);
            }
        }

        for (slot, allele) in genotype.alleles().iter().enumerate().take(max_ploidy) {
            // Only called, non-reference alleles contribute a variant.
            let allele_num = match allele {
                Some(num) if *num != 0 => *num,
                _ => continue,
            };
            let Some(alt) = record.allele_bases(allele_num) else {
                continue;
            };

            let ref_allele_len = record.ref_allele.len();
            let hap = slot_for_allele[slot];
            let data_idx = records_by_hap[hap].len();
            records_by_hap[hap].push(HaplotypeVariant {
                ref_pos: record.position,
                #[expect(clippy::cast_possible_truncation, reason = "ref allele < 4 GB")]
                ref_len: ref_allele_len as u32,
                alt_bases: alt.to_vec(),
            });

            // Interval ends are inclusive: last reference base of the REF allele.
            let last_ref_pos = (record.position as usize + ref_allele_len).saturating_sub(1);
            #[expect(
                clippy::cast_possible_wrap,
                reason = "genomic coords and variant index < i32::MAX / u32::MAX"
            )]
            #[expect(
                clippy::cast_possible_truncation,
                reason = "genomic coords and variant index < i32::MAX / u32::MAX"
            )]
            let span = Interval::new(record.position as i32, last_ref_pos as i32, data_idx as u32);
            spans_by_hap[hap].push(span);
        }
    }

    let mut haplotypes = Vec::with_capacity(max_ploidy);
    for (allele_index, (variant_data, spans)) in
        records_by_hap.into_iter().zip(spans_by_hap).enumerate()
    {
        haplotypes.push(Haplotype {
            allele_index,
            variant_data,
            variant_tree: COITree::new(&spans),
        });
    }
    haplotypes
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vcf::genotype::Genotype;

    /// Builds a single-ALT record replacing `ref_allele` by `alt_allele` at `pos`.
    fn indel(pos: u32, ref_allele: &[u8], alt_allele: &[u8], gt: &str) -> VariantRecord {
        VariantRecord {
            position: pos,
            ref_allele: ref_allele.to_vec(),
            alt_alleles: vec![alt_allele.to_vec()],
            genotype: Genotype::parse(gt).unwrap(),
        }
    }

    /// Builds a single-base substitution record at `pos`.
    fn snp(pos: u32, ref_base: u8, alt_base: u8, gt: &str) -> VariantRecord {
        indel(pos, &[ref_base], &[alt_base], gt)
    }

    /// Builds the two haplotypes of a diploid sample from `variants`.
    fn diploid(variants: &[VariantRecord]) -> Vec<Haplotype> {
        build_haplotypes(variants, 2, &mut rand::rng())
    }

    #[test]
    fn test_extract_fragment_no_variants() {
        let haps = diploid(&[]);
        assert_eq!(haps.len(), 2);

        let (bases, positions) = haps[0].extract_fragment(b"ACGTACGTACGT", 2, 5);
        assert_eq!(&bases, b"GTACG");
        assert_eq!(&positions, &[2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_extract_fragment_with_snp() {
        let reference = b"AAAAAAAA";
        let haps = diploid(&[snp(3, b'A', b'T', "0|1")]);

        // Haplotype 0 carries the reference allele, haplotype 1 the ALT.
        let (ref_hap, _) = haps[0].extract_fragment(reference, 0, 8);
        assert_eq!(&ref_hap, b"AAAAAAAA");
        let (alt_hap, _) = haps[1].extract_fragment(reference, 0, 8);
        assert_eq!(&alt_hap, b"AAATAAAA");
    }

    #[test]
    fn test_extract_fragment_with_insertion() {
        let haps = diploid(&[indel(3, b"A", b"ATT", "0|1")]);

        let (bases, positions) = haps[1].extract_fragment(b"AAAAAAAA", 0, 10);
        assert_eq!(&bases, b"AAAATTAAAA");
        // All inserted bases report the variant's start position.
        assert_eq!(&positions, &[0, 1, 2, 3, 3, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_extract_fragment_with_deletion() {
        let haps = diploid(&[indel(4, b"ACG", b"A", "0|1")]);

        let (bases, _) = haps[1].extract_fragment(b"ACGTACGTAC", 0, 8);
        assert_eq!(&bases, b"ACGTATAC");
    }

    #[test]
    fn test_hom_alt_both_haplotypes_affected() {
        let reference = b"AAAA";
        let haps = diploid(&[snp(1, b'A', b'T', "1|1")]);

        let (first, _) = haps[0].extract_fragment(reference, 0, 4);
        let (second, _) = haps[1].extract_fragment(reference, 0, 4);
        assert_eq!(&first, b"ATAA");
        assert_eq!(&second, b"ATAA");
    }

    #[test]
    fn test_phased_allele_assignment() {
        let reference = b"AAAA";
        let haps = diploid(&[snp(1, b'A', b'T', "1|0")]);

        // Phased: the first genotype entry lands on haplotype 0.
        let (first, _) = haps[0].extract_fragment(reference, 0, 4);
        let (second, _) = haps[1].extract_fragment(reference, 0, 4);
        assert_eq!(&first, b"ATAA");
        assert_eq!(&second, b"AAAA");
    }

    #[test]
    fn test_fragment_starts_mid_reference() {
        let haps = diploid(&[snp(5, b'C', b'T', "0|1")]);

        let (bases, positions) = haps[1].extract_fragment(b"ACGTACGTAC", 3, 5);
        assert_eq!(&bases, b"TATGT");
        assert_eq!(&positions, &[3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_fragment_starts_within_deletion() {
        let haps = diploid(&[indel(2, b"GTA", b"G", "0|1")]);

        // Start position 3 lies inside the deleted span, so output begins at 5.
        let (bases, _) = haps[1].extract_fragment(b"ACGTACGTAC", 3, 5);
        assert_eq!(&bases, b"CGTAC");
    }

    #[test]
    fn test_adjacent_variants() {
        let haps = diploid(&[snp(1, b'A', b'T', "0|1"), snp(2, b'A', b'C', "0|1")]);

        let (bases, _) = haps[1].extract_fragment(b"AAAA", 0, 4);
        assert_eq!(&bases, b"ATCA");
    }

    #[test]
    fn test_variant_at_position_zero() {
        let haps = diploid(&[snp(0, b'A', b'T', "0|1")]);

        let (bases, _) = haps[1].extract_fragment(b"ACGT", 0, 4);
        assert_eq!(&bases, b"TCGT");
    }

    #[test]
    fn test_unphased_hom_alt_both_haplotypes() {
        let reference = b"AAAA";
        let haps = diploid(&[snp(1, b'A', b'T', "1/1")]);

        // Whatever the random slot assignment, both haplotypes get the ALT.
        let (first, _) = haps[0].extract_fragment(reference, 0, 4);
        let (second, _) = haps[1].extract_fragment(reference, 0, 4);
        assert_eq!(&first, b"ATAA");
        assert_eq!(&second, b"ATAA");
    }
}