Skip to main content

holodeck_lib/
haplotype.rs

1use coitrees::{COITree, Interval, IntervalTree};
2
3use crate::vcf::genotype::VariantRecord;
4
/// A single variant assigned to a specific haplotype.
///
/// During fragment extraction, the `ref_len` reference bases starting at
/// `ref_pos` are replaced by `alt_bases`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HaplotypeVariant {
    /// 0-based reference position where the variant starts.
    pub ref_pos: u32,
    /// Length of the reference allele in bases. Uses `u32` to support
    /// structural variants with large reference spans.
    pub ref_len: u32,
    /// Alternate allele bases emitted in place of the reference allele. May
    /// be longer (insertion) or shorter (deletion) than `ref_len`.
    pub alt_bases: Vec<u8>,
}
16
17/// Sparse representation of one haplotype — a reference overlay of variants.
18///
19/// Instead of materializing a full haplotype sequence (which would require
20/// ~250MB per haplotype for human chr1), this stores only the differences from
21/// the reference as a sorted set of variants in a [`COITree`] for efficient
22/// range queries.
23///
24/// Fragment extraction works by walking the reference sequence and
25/// substituting alt alleles at variant positions on the fly.
26///
27/// Because coitrees requires `Copy + Default` for metadata, we store variant
28/// indices (as `u32`) in the tree and keep the actual variant data in a
29/// separate `Vec`.
30pub struct Haplotype {
31    /// 0-based haplotype allele index (e.g., 0 or 1 for diploid).
32    allele_index: usize,
33    /// Variant data, indexed by position in this vec.
34    variant_data: Vec<HaplotypeVariant>,
35    /// Interval tree mapping genomic ranges to indices in `variant_data`.
36    variant_tree: COITree<u32, u32>,
37}
38
39impl Haplotype {
40    /// Return the allele index of this haplotype.
41    #[must_use]
42    pub fn allele_index(&self) -> usize {
43        self.allele_index
44    }
45
46    /// Extract a fragment from this haplotype at the given reference
47    /// coordinates.
48    ///
49    /// Walks the reference from `ref_start` and produces `fragment_len` bases,
50    /// substituting alternate alleles where this haplotype has variants.
51    /// Returns the fragment bases and a list of reference positions
52    /// corresponding to each fragment base (for golden BAM coordinate mapping).
53    ///
54    /// # Arguments
55    /// * `reference` — Full reference sequence for this contig.
56    /// * `ref_start` — 0-based start position on the reference.
57    /// * `fragment_len` — Desired number of output bases.
58    ///
59    /// # Returns
60    /// A tuple of `(fragment_bases, ref_positions)` where `ref_positions[i]`
61    /// is the reference position corresponding to `fragment_bases[i]`. For
62    /// inserted bases, the reference position is that of the base preceding
63    /// the insertion.
64    #[must_use]
65    #[expect(clippy::cast_possible_wrap, reason = "genomic coords < i32::MAX")]
66    pub fn extract_fragment(
67        &self,
68        reference: &[u8],
69        ref_start: u32,
70        fragment_len: usize,
71    ) -> (Vec<u8>, Vec<u32>) {
72        let mut bases = Vec::with_capacity(fragment_len);
73        let mut ref_positions = Vec::with_capacity(fragment_len);
74
75        // Collect overlapping variant indices. We query exactly the range we
76        // need: [ref_start, ref_start + fragment_len). Deletions whose ref
77        // allele extends past our window are still caught because the tree
78        // stores them with their full ref allele span.
79        let query_end = (ref_start as usize + fragment_len).min(reference.len());
80        #[expect(
81            clippy::cast_possible_truncation,
82            reason = "query_end bounded by reference.len() which fits i32"
83        )]
84        let query_end_i32 = query_end.saturating_sub(1) as i32;
85        let mut overlapping_indices: Vec<u32> = Vec::new();
86        self.variant_tree.query(ref_start as i32, query_end_i32, |node| {
87            // clone() rather than *deref because coitrees' query callback
88            // yields &T on NEON (ARM) but T directly on nosimd (x86).
89            #[allow(clippy::clone_on_copy)]
90            overlapping_indices.push(node.metadata.clone());
91        });
92
93        // Sort variants by position for sequential processing.
94        overlapping_indices.sort_unstable_by_key(|&idx| self.variant_data[idx as usize].ref_pos);
95
96        let mut ref_pos = ref_start as usize;
97        let mut var_idx = 0;
98
99        // Advance past variants entirely before our window, and handle
100        // variants that start before ref_start but span into it (e.g., a
101        // deletion that started upstream).
102        while var_idx < overlapping_indices.len() {
103            let var = &self.variant_data[overlapping_indices[var_idx] as usize];
104            let var_end = var.ref_pos as usize + var.ref_len as usize;
105            if var_end <= ref_pos {
106                // Variant is entirely before our window — skip it.
107                var_idx += 1;
108            } else if (var.ref_pos as usize) < ref_pos {
109                // Variant starts before our window but spans into it.
110                // For a deletion, the ref bases are consumed — skip past them.
111                // For an insertion at this position, the alt bases were already
112                // partially consumed upstream, so we skip the variant entirely.
113                ref_pos = var_end;
114                var_idx += 1;
115            } else {
116                break;
117            }
118        }
119
120        while bases.len() < fragment_len && ref_pos < reference.len() {
121            // Check if the current reference position is a variant start.
122            if var_idx < overlapping_indices.len() {
123                let var = &self.variant_data[overlapping_indices[var_idx] as usize];
124                if var.ref_pos as usize == ref_pos {
125                    // Emit alt allele bases.
126                    for &b in &var.alt_bases {
127                        if bases.len() >= fragment_len {
128                            break;
129                        }
130                        bases.push(b);
131                        #[expect(
132                            clippy::cast_possible_truncation,
133                            reason = "ref positions fit in u32"
134                        )]
135                        ref_positions.push(ref_pos as u32);
136                    }
137                    // Skip over reference allele bases.
138                    ref_pos += var.ref_len as usize;
139                    var_idx += 1;
140                    continue;
141                }
142            }
143
144            // Emit reference base.
145            bases.push(reference[ref_pos]);
146            #[expect(clippy::cast_possible_truncation, reason = "ref positions fit in u32")]
147            ref_positions.push(ref_pos as u32);
148            ref_pos += 1;
149        }
150
151        // Truncate to exact fragment length (alt alleles may have added extra).
152        bases.truncate(fragment_len);
153        ref_positions.truncate(fragment_len);
154
155        (bases, ref_positions)
156    }
157}
158
159/// Build haplotypes from a set of variant records for one contig.
160///
161/// For each allele index up to `max_ploidy`, constructs a [`Haplotype`]
162/// containing only the variants assigned to that allele.
163///
164/// Phased genotypes assign alleles deterministically. Unphased genotypes
165/// assign non-reference alleles to haplotypes using the provided RNG.
166///
167/// # Arguments
168/// * `variants` — Sorted variant records for this contig.
169/// * `max_ploidy` — Maximum ploidy across all variants (e.g. 2 for diploid).
170/// * `rng` — Random number generator for unphased genotype assignment.
171pub fn build_haplotypes(
172    variants: &[VariantRecord],
173    max_ploidy: usize,
174    rng: &mut impl rand::Rng,
175) -> Vec<Haplotype> {
176    // Collect variant data and tree intervals per haplotype.
177    let mut variant_data_per_hap: Vec<Vec<HaplotypeVariant>> =
178        (0..max_ploidy).map(|_| Vec::new()).collect();
179    let mut intervals_per_hap: Vec<Vec<Interval<u32>>> =
180        (0..max_ploidy).map(|_| Vec::new()).collect();
181
182    for vr in variants {
183        let gt = &vr.genotype;
184
185        // For unphased genotypes, generate a random permutation mapping
186        // allele indices to haplotype indices. This avoids artificial
187        // phasing of nearby variants while ensuring each allele goes to
188        // exactly one haplotype (unlike independent random draws, which
189        // would incorrectly place both alleles of a hom-alt on the same
190        // haplotype 25% of the time for diploid).
191        let hap_permutation: Vec<usize> = if gt.is_phased() {
192            (0..max_ploidy).collect()
193        } else {
194            let mut perm: Vec<usize> = (0..max_ploidy).collect();
195            // Fisher-Yates shuffle.
196            for i in (1..perm.len()).rev() {
197                let j = rng.random_range(0..=i);
198                perm.swap(i, j);
199            }
200            perm
201        };
202
203        for (allele_idx, allele) in gt.alleles().iter().enumerate() {
204            if allele_idx >= max_ploidy {
205                break;
206            }
207
208            let Some(allele_num) = allele else { continue };
209            if *allele_num == 0 {
210                continue;
211            }
212
213            let Some(alt_bases) = vr.allele_bases(*allele_num) else {
214                continue;
215            };
216
217            let hap_var = HaplotypeVariant {
218                ref_pos: vr.position,
219                #[expect(clippy::cast_possible_truncation, reason = "ref allele < 4 GB")]
220                ref_len: vr.ref_allele.len() as u32,
221                alt_bases: alt_bases.to_vec(),
222            };
223
224            let target_hap = hap_permutation[allele_idx];
225
226            // Store the variant and create a tree interval pointing to it.
227            let data_idx = variant_data_per_hap[target_hap].len();
228            variant_data_per_hap[target_hap].push(hap_var);
229
230            let end_pos = (vr.position as usize + vr.ref_allele.len()).saturating_sub(1);
231            #[expect(
232                clippy::cast_possible_wrap,
233                reason = "genomic coords and variant index < i32::MAX / u32::MAX"
234            )]
235            #[expect(
236                clippy::cast_possible_truncation,
237                reason = "genomic coords and variant index < i32::MAX / u32::MAX"
238            )]
239            let iv = Interval::new(vr.position as i32, end_pos as i32, data_idx as u32);
240            intervals_per_hap[target_hap].push(iv);
241        }
242    }
243
244    variant_data_per_hap
245        .into_iter()
246        .zip(intervals_per_hap)
247        .enumerate()
248        .map(|(i, (data, ivs))| Haplotype {
249            allele_index: i,
250            variant_data: data,
251            variant_tree: COITree::new(&ivs),
252        })
253        .collect()
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vcf::genotype::Genotype;

    /// Build a simple SNP variant record.
    fn snp(pos: u32, ref_base: u8, alt_base: u8, gt: &str) -> VariantRecord {
        VariantRecord {
            position: pos,
            ref_allele: vec![ref_base],
            alt_alleles: vec![vec![alt_base]],
            genotype: Genotype::parse(gt).unwrap(),
        }
    }

    /// Build an indel variant record.
    fn indel(pos: u32, ref_allele: &[u8], alt_allele: &[u8], gt: &str) -> VariantRecord {
        VariantRecord {
            position: pos,
            ref_allele: ref_allele.to_vec(),
            alt_alleles: vec![alt_allele.to_vec()],
            genotype: Genotype::parse(gt).unwrap(),
        }
    }

    #[test]
    fn test_extract_fragment_no_variants() {
        let reference = b"ACGTACGTACGT";
        let haps = build_haplotypes(&[], 2, &mut rand::rng());
        assert_eq!(haps.len(), 2);

        let (bases, positions) = haps[0].extract_fragment(reference, 2, 5);
        assert_eq!(&bases, b"GTACG");
        assert_eq!(&positions, &[2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_extract_fragment_with_snp() {
        let reference = b"AAAAAAAA";
        let variants = vec![snp(3, b'A', b'T', "0|1")];
        let haps = build_haplotypes(&variants, 2, &mut rand::rng());

        // Haplotype 0 should have reference (allele 0).
        let (bases, _) = haps[0].extract_fragment(reference, 0, 8);
        assert_eq!(&bases, b"AAAAAAAA");

        // Haplotype 1 should have the SNP (allele 1).
        let (bases, _) = haps[1].extract_fragment(reference, 0, 8);
        assert_eq!(&bases, b"AAATAAAA");
    }

    #[test]
    fn test_extract_fragment_with_insertion() {
        let reference = b"AAAAAAAA";
        // Insertion: A -> ATT at position 3 (ref allele len 1, alt len 3).
        let variants = vec![indel(3, b"A", b"ATT", "0|1")];
        let haps = build_haplotypes(&variants, 2, &mut rand::rng());

        // Haplotype 1 has the insertion.
        let (bases, positions) = haps[1].extract_fragment(reference, 0, 10);
        assert_eq!(&bases, b"AAAATTAAAA");
        // Inserted bases all map back to the ref position of the variant (3).
        assert_eq!(&positions, &[0, 1, 2, 3, 3, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_extract_fragment_with_deletion() {
        let reference = b"ACGTACGTAC";
        // Deletion: ACG -> A at position 4. Replaces 3 ref bases (ACG at
        // positions 4-6) with 1 alt base (A).
        let variants = vec![indel(4, b"ACG", b"A", "0|1")];
        let haps = build_haplotypes(&variants, 2, &mut rand::rng());

        let (bases, positions) = haps[1].extract_fragment(reference, 0, 8);
        assert_eq!(&bases, b"ACGTATAC");
        // Reference positions 5 and 6 are deleted, so the walk jumps to 7.
        assert_eq!(&positions, &[0, 1, 2, 3, 4, 7, 8, 9]);
    }

    #[test]
    fn test_hom_alt_both_haplotypes_affected() {
        let reference = b"AAAA";
        let variants = vec![snp(1, b'A', b'T', "1|1")];
        let haps = build_haplotypes(&variants, 2, &mut rand::rng());

        let (bases0, _) = haps[0].extract_fragment(reference, 0, 4);
        let (bases1, _) = haps[1].extract_fragment(reference, 0, 4);
        assert_eq!(&bases0, b"ATAA");
        assert_eq!(&bases1, b"ATAA");
    }

    #[test]
    fn test_phased_allele_assignment() {
        let reference = b"AAAA";
        // Phased 1|0: alt on haplotype 0, ref on haplotype 1.
        let variants = vec![snp(1, b'A', b'T', "1|0")];
        let haps = build_haplotypes(&variants, 2, &mut rand::rng());

        let (bases0, _) = haps[0].extract_fragment(reference, 0, 4);
        let (bases1, _) = haps[1].extract_fragment(reference, 0, 4);
        assert_eq!(&bases0, b"ATAA");
        assert_eq!(&bases1, b"AAAA");
    }

    #[test]
    fn test_fragment_starts_mid_reference() {
        let reference = b"ACGTACGTAC";
        let variants = vec![snp(5, b'C', b'T', "0|1")];
        let haps = build_haplotypes(&variants, 2, &mut rand::rng());

        // Fragment starting at position 3, length 5: covers pos 3-7.
        let (bases, positions) = haps[1].extract_fragment(reference, 3, 5);
        assert_eq!(&bases, b"TATGT");
        assert_eq!(&positions, &[3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_fragment_starts_within_deletion() {
        // Reference: ACGTACGTAC (positions 0-9)
        // Deletion at pos 2: GTA (3 bases) -> G (1 base)
        // After variant: AC + G + CGTAC = ACGCGTAC
        let reference = b"ACGTACGTAC";
        let variants = vec![indel(2, b"GTA", b"G", "0|1")];
        let haps = build_haplotypes(&variants, 2, &mut rand::rng());

        // Fragment starting at position 3 (mid-deletion). The deletion
        // consumes ref positions 2-4, so starting at 3 means we're inside
        // the deletion. The walk should skip past the deletion end
        // (position 5) and continue from there.
        let (bases, _) = haps[1].extract_fragment(reference, 3, 5);
        assert_eq!(&bases, b"CGTAC");
    }

    #[test]
    fn test_adjacent_variants() {
        // Two adjacent SNPs with no reference gap between them.
        let reference = b"AAAA";
        let variants = vec![snp(1, b'A', b'T', "0|1"), snp(2, b'A', b'C', "0|1")];
        let haps = build_haplotypes(&variants, 2, &mut rand::rng());

        let (bases, _) = haps[1].extract_fragment(reference, 0, 4);
        assert_eq!(&bases, b"ATCA");
    }

    #[test]
    fn test_variant_at_position_zero() {
        let reference = b"ACGT";
        let variants = vec![snp(0, b'A', b'T', "0|1")];
        let haps = build_haplotypes(&variants, 2, &mut rand::rng());

        let (bases, _) = haps[1].extract_fragment(reference, 0, 4);
        assert_eq!(&bases, b"TCGT");
    }

    #[test]
    fn test_unphased_hom_alt_both_haplotypes() {
        // Unphased hom-alt must place the alt on both haplotypes.
        let reference = b"AAAA";
        let variants = vec![snp(1, b'A', b'T', "1/1")];
        let haps = build_haplotypes(&variants, 2, &mut rand::rng());

        let (bases0, _) = haps[0].extract_fragment(reference, 0, 4);
        let (bases1, _) = haps[1].extract_fragment(reference, 0, 4);
        assert_eq!(&bases0, b"ATAA");
        assert_eq!(&bases1, b"ATAA");
    }

    #[test]
    fn test_unphased_het_assigned_to_exactly_one_haplotype() {
        // Whichever haplotype the RNG picks, exactly one carries the alt.
        let reference = b"AAAA";
        let variants = vec![snp(1, b'A', b'T', "0/1")];
        let haps = build_haplotypes(&variants, 2, &mut rand::rng());

        let fragments: Vec<Vec<u8>> =
            haps.iter().map(|h| h.extract_fragment(reference, 0, 4).0).collect();
        let alt_count = fragments.iter().filter(|f| f.as_slice() == b"ATAA").count();
        let ref_count = fragments.iter().filter(|f| f.as_slice() == b"AAAA").count();
        assert_eq!((alt_count, ref_count), (1, 1));
    }

    #[test]
    fn test_fragment_clipped_at_reference_end() {
        // Requesting more bases than remain yields only what the reference has.
        let reference = b"ACGT";
        let haps = build_haplotypes(&[], 2, &mut rand::rng());

        let (bases, positions) = haps[0].extract_fragment(reference, 2, 10);
        assert_eq!(&bases, b"GT");
        assert_eq!(&positions, &[2, 3]);
    }

    #[test]
    fn test_zero_length_fragment_is_empty() {
        let reference = b"ACGT";
        let haps = build_haplotypes(&[], 2, &mut rand::rng());

        let (bases, positions) = haps[0].extract_fragment(reference, 0, 0);
        assert!(bases.is_empty());
        assert!(positions.is_empty());
    }

    #[test]
    fn test_allele_indices_match_haplotype_order() {
        let haps = build_haplotypes(&[], 3, &mut rand::rng());
        let indices: Vec<usize> = haps.iter().map(Haplotype::allele_index).collect();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    #[test]
    fn test_zero_ploidy_yields_no_haplotypes() {
        let variants = vec![snp(1, b'A', b'T', "0/1")];
        let haps = build_haplotypes(&variants, 0, &mut rand::rng());
        assert!(haps.is_empty());
    }
}