Skip to main content

rsomics_plink_tdt/
lib.rs

1//! PLINK1 `--tdt`: transmission disequilibrium test over complete trios.
2//!
3//! For each variant, over affected offspring whose two parents are both
4//! genotyped and Mendel-consistent, count the minor allele transmitted (T) and
5//! untransmitted (U) by heterozygous parents. The statistic is McNemar's:
6//!   CHISQ = (T-U)² / (T+U),  1 df,  OR = T/U.
7//!
8//! A1 is the minor allele by founder allele frequency (PLINK's convention),
9//! which fixes the T/U orientation and the A1/A2 labels.
10
11use rsomics_pgen::Pgen;
12use std::io::{self, Write};
13
/// One `.tdt` output row: a tested variant with its transmission counts,
/// oriented so that `a1` is the minor allele among founders (on a tie the bim
/// A1/A2 order is kept).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TdtRecord {
    /// Chromosome name as stored in the variant table; it is mapped to
    /// PLINK's numeric code only when written out.
    pub chrom: String,
    /// Variant ID.
    pub snp: String,
    /// Base-pair position.
    pub bp: u64,
    /// Tested allele (founder-minor).
    pub a1: String,
    /// The other allele.
    pub a2: String,
    /// Number of times heterozygous parents transmitted `a1`.
    pub t: u32,
    /// Number of times heterozygous parents did not transmit `a1`.
    pub u: u32,
}
23
24/// Flat 64-entry table keyed by the 6-bit `(dad<<4)|(mom<<2)|child` genotype
25/// triple. Each entry packs the transmitted (high nibble) and untransmitted
26/// (low nibble) bim-A1 count; both are at most 2. Codes are PLINK 2-bit:
27/// 0=HomA1, 1=Missing, 2=Het, 3=HomA2 — missing or Mendel-inconsistent trios
28/// score zero.
29fn build_transmit_lut() -> [u8; 64] {
30    let mut lut = [0u8; 64];
31    for dad in 0..4u8 {
32        for mom in 0..4u8 {
33            for child in 0..4u8 {
34                let (t, u) = transmit(dad, mom, child);
35                lut[((dad << 4) | (mom << 2) | child) as usize] = (t << 4) | u;
36            }
37        }
38    }
39    lut
40}
41
/// True for an actual diploid call (0=HomA1, 2=Het, 3=HomA2). The missing
/// code 1 and any value above 3 are rejected.
fn called(code: u8) -> bool {
    code <= 3 && code != 1
}
46
/// Transmitted/untransmitted bim-A1 counts `(t, u)` for one trio genotype
/// triple (PLINK 2-bit codes: 0=HomA1, 1=Missing, 2=Het, 3=HomA2).
///
/// Only heterozygous parents contribute. A missing call anywhere, or a child
/// genotype that no (dad allele, mom allele) pair can produce, scores `(0, 0)`.
/// With both parents and the child heterozygous the phase is ambiguous, but
/// either resolution gives `(1, 1)`.
fn transmit(dad: u8, mom: u8, child: u8) -> (u8, u8) {
    /// A1 copies (1 = A1, 0 = A2) a parent with genotype `code` can pass on;
    /// empty when the code is not a call.
    fn alleles(code: u8) -> &'static [u8] {
        match code {
            0 => &[1],
            2 => &[1, 0],
            3 => &[0],
            _ => &[],
        }
    }

    // Total A1 dosage the child received; bail out on a missing child.
    let child_a1: u8 = match child {
        0 => 2,
        2 => 1,
        3 => 0,
        _ => return (0, 0),
    };
    let (dad_opts, mom_opts) = (alleles(dad), alleles(mom));
    if dad_opts.is_empty() || mom_opts.is_empty() {
        return (0, 0);
    }

    // Pick an ordered (dad allele, mom allele) pair reproducing the child.
    let Some((da, ma)) = dad_opts
        .iter()
        .flat_map(|&d| mom_opts.iter().map(move |&m| (d, m)))
        .filter(|&(d, m)| d + m == child_a1)
        .last()
    else {
        return (0, 0);
    };

    // A heterozygous parent scores T when it passed A1 and U when it passed A2;
    // homozygous parents are uninformative.
    let score = |genotype: u8, passed: u8| match (genotype, passed) {
        (2, 1) => (1u8, 0u8),
        (2, _) => (0, 1),
        _ => (0, 0),
    };
    let (dad_t, dad_u) = score(dad, da);
    let (mom_t, mom_u) = score(mom, ma);
    (dad_t + mom_t, dad_u + mom_u)
}
78
/// The distinct alleles a parent with genotype `code` can transmit, as A1
/// counts (1 = A1, 0 = A2). Missing or unknown codes yield an empty slice.
fn parent_alleles(code: u8) -> &'static [u8] {
    const HOM_A1: &[u8] = &[1];
    const HET: &[u8] = &[1, 0];
    const HOM_A2: &[u8] = &[0];
    const UNCALLED: &[u8] = &[];
    match code {
        0 => HOM_A1,
        2 => HET,
        3 => HOM_A2,
        _ => UNCALLED,
    }
}
88
/// Sample indices (positions in the sample list) of one parent–offspring trio.
struct Trio {
    /// Index of the father.
    dad: usize,
    /// Index of the mother.
    mom: usize,
    /// Index of the affected child.
    child: usize,
}
94
95/// Affected offspring whose parents are both present in the same family.
96fn trios(pgen: &Pgen) -> Vec<Trio> {
97    use std::collections::HashMap;
98    let mut by_key: HashMap<(&str, &str), usize> = HashMap::new();
99    for (i, s) in pgen.samples.iter().enumerate() {
100        by_key.insert((s.fid.as_str(), s.iid.as_str()), i);
101    }
102    pgen.samples
103        .iter()
104        .enumerate()
105        .filter(|(_, s)| s.phen == "2" && s.pid != "0" && s.mid != "0")
106        .filter_map(|(child, s)| {
107            let dad = *by_key.get(&(s.fid.as_str(), s.pid.as_str()))?;
108            let mom = *by_key.get(&(s.fid.as_str(), s.mid.as_str()))?;
109            Some(Trio { dad, mom, child })
110        })
111        .collect()
112}
113
114fn founder_mask(pgen: &Pgen) -> Vec<bool> {
115    pgen.samples
116        .iter()
117        .map(|s| s.pid == "0" && s.mid == "0")
118        .collect()
119}
120
/// 2-bit genotype code of sample `s` within a packed variant `row`: four
/// samples per byte, lowest-order bits first. Panics if `s` lies beyond `row`.
#[inline]
fn code_at(row: &[u8], s: usize) -> u8 {
    let byte = row[s / 4];
    let shift = 2 * (s % 4);
    (byte >> shift) & 0b11
}
125
/// A1 copies minus A2 copies carried by one genotype code: +2 for HomA1 (0),
/// -2 for HomA2 (3), and 0 for Het or Missing.
#[inline]
fn dosage_diff(code: u8) -> i32 {
    if code == 0 {
        2
    } else if code == 3 {
        -2
    } else {
        0
    }
}
135
136/// PLINK's `--tdt` skips the fully-haploid chromosomes Y and MT; autosomes,
137/// the unplaced chromosome 0, X (23) and XY (25) are all tested.
138fn tdt_tested(chrom: &str) -> bool {
139    !matches!(report_chrom(chrom).as_str(), "24" | "26")
140}
141
/// One trio prepared for the per-variant loop. The two flags say whether this
/// trio is responsible for adding its father's and mother's genotype to the
/// founder allele tally. Attributing each counted parent to a single trio
/// keeps it from being counted more than once.
struct TrioWork {
    /// Father's sample index.
    dad: usize,
    /// Mother's sample index.
    mom: usize,
    /// Affected child's sample index.
    child: usize,
    /// Add the father's dosage to the founder tally for this trio.
    count_dad: bool,
    /// Add the mother's dosage to the founder tally for this trio.
    count_mom: bool,
}
152
153#[must_use]
154pub fn tdt(pgen: &Pgen) -> Vec<TdtRecord> {
155    use rayon::prelude::*;
156    let lut = build_transmit_lut();
157    let triples = trios(pgen);
158
159    // The minor allele is decided over founders. A trio's parents are founders,
160    // so their dosage is folded into the trio loop — but a founder may parent
161    // several trios, so attribute it to its first trio to count it once.
162    let founders = founder_mask(pgen);
163    let mut seen = vec![false; pgen.n_samples()];
164    let work: Vec<TrioWork> = triples
165        .iter()
166        .map(|t| {
167            let count_dad = !std::mem::replace(&mut seen[t.dad], true);
168            let count_mom = !std::mem::replace(&mut seen[t.mom], true);
169            TrioWork {
170                dad: t.dad,
171                mom: t.mom,
172                child: t.child,
173                count_dad,
174                count_mom,
175            }
176        })
177        .collect();
178    let other_founders: Vec<u32> = (0..pgen.n_samples())
179        .filter(|&s| founders[s] && !seen[s])
180        .map(|s| s as u32)
181        .collect();
182
183    let bpv = pgen.bytes_per_variant();
184    let gt = &pgen.gt_raw;
185
186    (0..pgen.n_variants())
187        .into_par_iter()
188        .filter(|&v| tdt_tested(&pgen.variants[v].chrom))
189        .map(|v| {
190            let row = &gt[v * bpv..v * bpv + bpv];
191            let (mut t_a1, mut u_a1) = (0u32, 0u32);
192            let mut diff = 0i32;
193            for w in &work {
194                let dad = code_at(row, w.dad);
195                let mom = code_at(row, w.mom);
196                let key = (dad << 4) | (mom << 2) | code_at(row, w.child);
197                let packed = lut[key as usize];
198                t_a1 += u32::from(packed >> 4);
199                u_a1 += u32::from(packed & 0x0f);
200                if w.count_dad {
201                    diff += dosage_diff(dad);
202                }
203                if w.count_mom {
204                    diff += dosage_diff(mom);
205                }
206            }
207            for &s in &other_founders {
208                diff += dosage_diff(code_at(row, s as usize));
209            }
210            let var = &pgen.variants[v];
211            let (a1, a2, t, u) = if diff <= 0 {
212                (&var.a1, &var.a2, t_a1, u_a1)
213            } else {
214                (&var.a2, &var.a1, u_a1, t_a1)
215            };
216            TdtRecord {
217                chrom: var.chrom.clone(),
218                snp: var.id.clone(),
219                bp: var.pos,
220                a1: a1.clone(),
221                a2: a2.clone(),
222                t,
223                u,
224            }
225        })
226        .collect()
227}
228
/// Chromosome label as PLINK prints it. X, Y, XY and MT (plus their listed
/// lower-case / `M` aliases) become 23, 24, 25 and 26; any other name is
/// returned unchanged.
fn report_chrom(chrom: &str) -> String {
    let code = match chrom {
        "X" | "x" => "23",
        "Y" | "y" => "24",
        "XY" | "xy" => "25",
        "MT" | "mt" | "M" | "m" => "26",
        other => other,
    };
    String::from(code)
}
239
240struct Widths {
241    chr: usize,
242    snp: usize,
243    a1: usize,
244    a2: usize,
245}
246
247impl Widths {
248    fn measure(records: &[TdtRecord]) -> Self {
249        let mut chr = 0;
250        let mut snp = 0;
251        let mut a1 = 0;
252        let mut a2 = 0;
253        for r in records {
254            chr = chr.max(report_chrom(&r.chrom).len());
255            snp = snp.max(r.snp.len());
256            a1 = a1.max(r.a1.len());
257            a2 = a2.max(r.a2.len());
258        }
259        Self {
260            chr: chr.max(2) + 2,
261            snp: if snp < 5 { 5 } else { snp + 3 },
262            a1: a1.max(2) + 2,
263            a2: a2.max(2) + 2,
264        }
265    }
266}
267
268/// Write the records in PLINK's `.tdt` layout (default, non-poo).
269pub fn write_tdt<W: Write>(records: &[TdtRecord], out: &mut W) -> io::Result<()> {
270    let w = Widths::measure(records);
271    writeln!(
272        out,
273        "{:>cw$}{:>sw$}{:>13}{:>a1$}{:>a2$}{:>7}{:>7}{:>13}{:>13}{:>13} ",
274        "CHR",
275        "SNP",
276        "BP",
277        "A1",
278        "A2",
279        "T",
280        "U",
281        "OR",
282        "CHISQ",
283        "P",
284        cw = w.chr,
285        sw = w.snp,
286        a1 = w.a1,
287        a2 = w.a2,
288    )?;
289    for r in records {
290        let (or, chisq, p) = stats(r.t, r.u);
291        writeln!(
292            out,
293            "{:>cw$}{:>sw$}{:>13}{:>a1$}{:>a2$}{:>7}{:>7}{:>13}{:>13}{:>13}  ",
294            report_chrom(&r.chrom),
295            r.snp,
296            r.bp,
297            r.a1,
298            r.a2,
299            r.t,
300            r.u,
301            or,
302            chisq,
303            p,
304            cw = w.chr,
305            sw = w.snp,
306            a1 = w.a1,
307            a2 = w.a2,
308        )?;
309    }
310    Ok(())
311}
312
313/// OR / CHISQ / P tokens. T+U=0 → all NA; U=0 → OR NA; T=0 → OR 0.
314fn stats(t: u32, u: u32) -> (String, String, String) {
315    let n = t + u;
316    if n == 0 {
317        return ("NA".into(), "NA".into(), "NA".into());
318    }
319    let or = if u == 0 {
320        "NA".to_string()
321    } else {
322        fmt_g(f64::from(t) / f64::from(u))
323    };
324    let diff = f64::from(t) - f64::from(u);
325    let chisq = diff * diff / f64::from(n);
326    let p = chisq_1df_sf(chisq);
327    (or, fmt_g(chisq), fmt_g(p))
328}
329
330/// Upper-tail probability of a 1-df chi-square = the regularised upper
331/// incomplete gamma `Q(1/2, x/2)`, evaluated to full precision.
332fn chisq_1df_sf(x: f64) -> f64 {
333    if x <= 0.0 {
334        return 1.0;
335    }
336    gamma_q(0.5, x / 2.0)
337}
338
/// Natural log of the gamma function, `ln Γ(z)`, for `z > 0`. Uses the
/// six-coefficient Lanczos approximation from Numerical Recipes (`gammln`).
fn ln_gamma(z: f64) -> f64 {
    const COEFFS: [f64; 6] = [
        76.180_091_729_471_46,
        -86.505_320_329_416_77,
        24.014_098_240_830_91,
        -1.231_739_572_450_155,
        0.001_208_650_973_866_179,
        -0.000_005_395_239_384_953,
    ];
    const SQRT_2PI: f64 = 2.506_628_274_631_000_5;

    let base = z + 5.5;
    let tmp = base - (z + 0.5) * base.ln();
    // Series sum c0 + Σ c_j / (z + j) for j = 1..=6.
    let (series, _) = COEFFS
        .iter()
        .fold((1.000_000_000_190_015, z), |(acc, y), &c| {
            let y = y + 1.0;
            (acc + c / y, y)
        });
    -tmp + (SQRT_2PI * series / z).ln()
}
359
360/// Regularised upper incomplete gamma `Q(a, x)` (Numerical Recipes `gser`/`gcf`).
361fn gamma_q(a: f64, x: f64) -> f64 {
362    if x < a + 1.0 {
363        1.0 - gamma_p_series(a, x)
364    } else {
365        gamma_q_cf(a, x)
366    }
367}
368
369fn gamma_p_series(a: f64, x: f64) -> f64 {
370    let gln = ln_gamma(a);
371    let mut ap = a;
372    let mut sum = 1.0 / a;
373    let mut del = sum;
374    for _ in 0..400 {
375        ap += 1.0;
376        del *= x / ap;
377        sum += del;
378        if del.abs() < sum.abs() * 1e-16 {
379            break;
380        }
381    }
382    sum * (-x + a * x.ln() - gln).exp()
383}
384
385fn gamma_q_cf(a: f64, x: f64) -> f64 {
386    const TINY: f64 = 1e-300;
387    let gln = ln_gamma(a);
388    let mut b = x + 1.0 - a;
389    let mut c = 1.0 / TINY;
390    let mut d = 1.0 / b;
391    let mut h = d;
392    for i in 1..400 {
393        let an = -(i as f64) * (i as f64 - a);
394        b += 2.0;
395        d = an * d + b;
396        if d.abs() < TINY {
397            d = TINY;
398        }
399        c = b + an / c;
400        if c.abs() < TINY {
401            c = TINY;
402        }
403        d = 1.0 / d;
404        let del = d * c;
405        h *= del;
406        if (del - 1.0).abs() < 1e-16 {
407            break;
408        }
409    }
410    (-x + a * x.ln() - gln).exp() * h
411}
412
413/// PLINK's numeric output format: the shortest round-tripping decimal rounded
414/// to 4 significant figures with round-half-to-even, then `%g`-displayed
415/// (trailing zeros stripped, scientific notation outside the exponent range
416/// -4..4). PLINK's dtoa differs from libc `%g` only at exact half-way ties,
417/// which it breaks toward the even digit; reproducing it needs the shortest
418/// decimal, not the raw binary value.
419fn fmt_g(x: f64) -> String {
420    const SIG: usize = 4;
421    if x.is_nan() {
422        return "nan".to_string();
423    }
424    if x == 0.0 {
425        return "0".to_string();
426    }
427    let neg = x < 0.0;
428    // Shortest round-tripping decimal digits (no sign, no point) and the
429    // power-of-ten exponent of the leading digit.
430    let (digits, lead_exp) = shortest_decimal(x.abs());
431    let (digits, exp) = round_sig_half_even(&digits, lead_exp, SIG);
432
433    let mut s = if !(-4..SIG as i32).contains(&exp) {
434        let mant = mantissa(&digits, 1);
435        format!("{mant}e{}{:02}", if exp < 0 { '-' } else { '+' }, exp.abs())
436    } else if exp >= 0 {
437        mantissa(&digits, (exp + 1) as usize)
438    } else {
439        let zeros = "0".repeat((-exp - 1) as usize);
440        strip_trailing(&format!("0.{zeros}{digits}"))
441    };
442    if neg {
443        s.insert(0, '-');
444    }
445    s
446}
447
/// Break `x` into the digits of its shortest round-tripping decimal (as Rust's
/// `{:e}` prints it) and the base-ten exponent of the first digit. For
/// example, `0.93125` gives `("93125", -1)`. `x` must be finite and positive.
fn shortest_decimal(x: f64) -> (String, i32) {
    let rendered = format!("{x:e}");
    let (mantissa_part, exp_part) = rendered.split_once('e').unwrap();
    let lead_exp: i32 = exp_part.parse().unwrap();
    let digits = mantissa_part
        .chars()
        .filter(char::is_ascii_digit)
        .collect::<String>();
    (digits, lead_exp)
}
458
/// Round the decimal digit string `digits` (whose leading digit is worth
/// `10^lead_exp`) to `sig` significant figures with round-half-to-even, then
/// drop trailing zeros while keeping at least one digit.
///
/// Returns the digits and the power of ten of the leading digit. The power
/// grows by one when a carry ripples through an all-nines prefix, e.g.
/// `"99995"` at 4 figures becomes `"1"` one power higher.
fn round_sig_half_even(digits: &str, lead_exp: i32, sig: usize) -> (String, i32) {
    let mut kept: Vec<u8> = digits.bytes().map(|b| b - b'0').collect();
    let mut lead = lead_exp;
    if kept.len() > sig {
        let dropped = kept.split_off(sig);
        let next = dropped[0];
        let round_up = if next == 5 {
            // A true tie only when nothing non-zero follows; then go to even.
            dropped[1..].iter().any(|&d| d != 0) || kept[sig - 1] % 2 == 1
        } else {
            next > 5
        };
        if round_up {
            match kept.iter().rposition(|&d| d != 9) {
                Some(i) => {
                    kept[i] += 1;
                    kept[i + 1..].fill(0);
                }
                None => {
                    // 99..9 + 1 = 100..0: keep `sig` digits, shift exponent.
                    kept.fill(0);
                    kept.insert(0, 1);
                    kept.pop();
                    lead += 1;
                }
            }
        }
    }
    while kept.len() > 1 && kept.last() == Some(&0) {
        kept.pop();
    }
    (kept.into_iter().map(|d| char::from(d + b'0')).collect(), lead)
}
499
/// Lay out `digits` with `int_len` digits before the decimal point. Short
/// input is right-padded with zeros. When a fractional part exists, its
/// trailing zeros (and a then-dangling point) are removed.
fn mantissa(digits: &str, int_len: usize) -> String {
    let mut padded = String::from(digits);
    while padded.len() < int_len {
        padded.push('0');
    }
    if padded.len() <= int_len {
        return padded;
    }
    let (whole, frac) = padded.split_at(int_len);
    let joined = format!("{whole}.{frac}");
    joined.trim_end_matches('0').trim_end_matches('.').to_string()
}
514
/// Drop trailing zeros after a decimal point, and then the point itself if
/// nothing follows it. Strings without a point are returned unchanged.
fn strip_trailing(s: &str) -> String {
    if !s.contains('.') {
        return s.to_owned();
    }
    s.trim_end_matches('0').trim_end_matches('.').to_owned()
}
522
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn transmit_table_matches_verified_combos() {
        // Spot-check entries validated byte-exact against plink 1.9 over all 64
        // (dad, mom, child) combos. (t_a1, u_a1) where A1 is the bim-A1 allele.
        assert_eq!(transmit(0, 2, 0), (1, 0)); // mom het, child HomA1
        assert_eq!(transmit(0, 2, 2), (0, 1)); // mom het, child Het
        assert_eq!(transmit(2, 2, 2), (1, 1)); // both het, child Het
        assert_eq!(transmit(2, 2, 0), (2, 0)); // both het, child HomA1
        assert_eq!(transmit(2, 2, 3), (0, 2)); // both het, child HomA2
        assert_eq!(transmit(0, 0, 2), (0, 0)); // Mendel-inconsistent
        assert_eq!(transmit(1, 2, 2), (0, 0)); // missing parent
        assert_eq!(transmit(2, 3, 2), (1, 0));
    }

    #[test]
    fn lut_agrees_with_transmit_for_every_key() {
        // The packed LUT must be a faithful cache of `transmit` for all 64
        // (dad, mom, child) keys, including those with the missing code 1.
        let lut = build_transmit_lut();
        for key in 0..64u8 {
            let (t, u) = transmit(key >> 4, (key >> 2) & 0b11, key & 0b11);
            assert_eq!(lut[usize::from(key)], (t << 4) | u, "key {key:#08b}");
        }
    }

    #[test]
    fn tested_chromosomes_skip_y_and_mt() {
        for c in ["1", "22", "0", "X", "x", "23", "XY", "25"] {
            assert!(tdt_tested(c), "{c} should be tested");
        }
        for c in ["Y", "y", "24", "MT", "mt", "M", "26"] {
            assert!(!tdt_tested(c), "{c} should be skipped");
        }
    }

    #[test]
    fn stats_edge_cases() {
        assert_eq!(stats(0, 0), ("NA".into(), "NA".into(), "NA".into()));
        let (or, chisq, _) = stats(3, 0);
        assert_eq!(or, "NA");
        assert_eq!(chisq, "3");
        let (or, _, _) = stats(0, 2);
        assert_eq!(or, "0");
    }

    #[test]
    fn stats_regular_case() {
        // (10 - 4)^2 / 14 = 2.5714...; OR = 10 / 4.
        let (or, chisq, _) = stats(10, 4);
        assert_eq!(or, "2.5");
        assert_eq!(chisq, "2.571");
    }

    #[test]
    fn g_formatting_matches_plink() {
        assert_eq!(fmt_g(1.273), "1.273");
        assert_eq!(fmt_g(0.8868), "0.8868");
        assert_eq!(fmt_g(0.1573), "0.1573");
        assert_eq!(fmt_g(0.0), "0");
        assert_eq!(fmt_g(1.44), "1.44");
        assert_eq!(fmt_g(3.0), "3");
        assert_eq!(fmt_g(0.0006871), "0.0006871");
    }

    #[test]
    fn g_formatting_switches_to_scientific() {
        // %g keeps fixed notation down to exponent -4 and below exponent 4.
        assert_eq!(fmt_g(0.0001), "0.0001");
        assert_eq!(fmt_g(0.00001234), "1.234e-05");
        assert_eq!(fmt_g(123456.0), "1.235e+05");
        // A carry out of the top digit bumps the exponent into scientific.
        assert_eq!(fmt_g(9999.5), "1e+04");
        assert_eq!(fmt_g(-2.5), "-2.5");
    }

    #[test]
    fn g_half_ties_round_to_even() {
        // PLINK breaks exact decimal half-ties toward the even digit.
        assert_eq!(fmt_g(894.0 / 960.0), "0.9312");
        assert_eq!(fmt_g(473.0 / 400.0), "1.182");
        assert_eq!(fmt_g(696.0 / 640.0), "1.088");
        assert_eq!(fmt_g(763.0 / 800.0), "0.9538");
        assert_eq!(fmt_g(431.0 / 400.0), "1.078");
        assert_eq!(fmt_g(0.91875), "0.9188");
    }

    #[test]
    fn p_value_matches_plink() {
        // chi2.sf(x, 1) values plink prints.
        assert_eq!(fmt_g(chisq_1df_sf(2.0)), "0.1573");
        assert_eq!(fmt_g(chisq_1df_sf(5.76)), "0.0164");
        assert_eq!(fmt_g(chisq_1df_sf(0.0)), "1");
        assert_eq!(fmt_g(chisq_1df_sf(3.0)), "0.08326");
        // Continued-fraction branch (x/2 >= 1.5): erfc(sqrt(5)).
        assert_eq!(fmt_g(chisq_1df_sf(10.0)), "0.001565");
    }
}