1use rsomics_pgen::Pgen;
12use std::io::{self, Write};
13
/// One output row of the transmission test: a variant plus its transmitted
/// and untransmitted allele counts.
///
/// The allele orientation (`a1`/`a2`) and the matching `t`/`u` counts are
/// whatever the producer of the record chose; [`write_tdt`] prints them as-is.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TdtRecord {
    /// Chromosome label exactly as stored on the variant.
    pub chrom: String,
    /// Variant identifier.
    pub snp: String,
    /// Base-pair position of the variant.
    pub bp: u64,
    /// Allele reported in the `A1` column.
    pub a1: String,
    /// Allele reported in the `A2` column.
    pub a2: String,
    /// Value printed in the `T` column (transmissions of `a1`).
    pub t: u32,
    /// Value printed in the `U` column (non-transmissions of `a1`).
    pub u: u32,
}
23
24fn build_transmit_lut() -> [u8; 64] {
30 let mut lut = [0u8; 64];
31 for dad in 0..4u8 {
32 for mom in 0..4u8 {
33 for child in 0..4u8 {
34 let (t, u) = transmit(dad, mom, child);
35 lut[((dad << 4) | (mom << 2) | child) as usize] = (t << 4) | u;
36 }
37 }
38 }
39 lut
40}
41
/// Returns `true` when `code` is one of the genotype codes 0, 2 or 3.
///
/// Every other value (in particular code 1) counts as "not called".
/// NOTE(review): presumably this mirrors the PLINK 2-bit encoding in which
/// `01` means missing — confirm against the reader.
fn called(code: u8) -> bool {
    [0u8, 2, 3].contains(&code)
}
46
47fn transmit(dad: u8, mom: u8, child: u8) -> (u8, u8) {
48 if !called(dad) || !called(mom) || !called(child) {
49 return (0, 0);
50 }
51 let child_a1 = match child {
52 0 => 2u8,
53 2 => 1,
54 _ => 0,
55 };
56 let dad_alleles = parent_alleles(dad);
59 let mom_alleles = parent_alleles(mom);
60 let mut sol = None;
61 for &da in dad_alleles {
62 for &ma in mom_alleles {
63 if da + ma == child_a1 {
64 sol = Some((da, ma));
65 }
66 }
67 }
68 let Some((da, ma)) = sol else { return (0, 0) };
69 let (mut t, mut u) = (0u8, 0u8);
70 if dad == 2 {
71 if da == 1 { t += 1 } else { u += 1 }
72 }
73 if mom == 2 {
74 if ma == 1 { t += 1 } else { u += 1 }
75 }
76 (t, u)
77}
78
/// Lists the A1 contributions (1 = passes A1, 0 = passes the other allele)
/// a parent with genotype `code` can hand to a child.
///
/// Code 0 yields only `1`, code 3 only `0`, code 2 both (in the order
/// `1, 0`); any other code yields an empty slice.
fn parent_alleles(code: u8) -> &'static [u8] {
    const ALWAYS_A1: &[u8] = &[1];
    const EITHER: &[u8] = &[1, 0];
    const NEVER_A1: &[u8] = &[0];
    const NOTHING: &[u8] = &[];
    match code {
        0 => ALWAYS_A1,
        2 => EITHER,
        3 => NEVER_A1,
        _ => NOTHING,
    }
}
88
/// Sample indices of one father / mother / child triple.
struct Trio {
    /// Index of the father in the sample list.
    dad: usize,
    /// Index of the mother in the sample list.
    mom: usize,
    /// Index of the child in the sample list.
    child: usize,
}
94
95fn trios(pgen: &Pgen) -> Vec<Trio> {
97 use std::collections::HashMap;
98 let mut by_key: HashMap<(&str, &str), usize> = HashMap::new();
99 for (i, s) in pgen.samples.iter().enumerate() {
100 by_key.insert((s.fid.as_str(), s.iid.as_str()), i);
101 }
102 pgen.samples
103 .iter()
104 .enumerate()
105 .filter(|(_, s)| s.phen == "2" && s.pid != "0" && s.mid != "0")
106 .filter_map(|(child, s)| {
107 let dad = *by_key.get(&(s.fid.as_str(), s.pid.as_str()))?;
108 let mom = *by_key.get(&(s.fid.as_str(), s.mid.as_str()))?;
109 Some(Trio { dad, mom, child })
110 })
111 .collect()
112}
113
114fn founder_mask(pgen: &Pgen) -> Vec<bool> {
115 pgen.samples
116 .iter()
117 .map(|s| s.pid == "0" && s.mid == "0")
118 .collect()
119}
120
/// Extracts the 2-bit code of sample `s` from a packed row.
///
/// Four samples share one byte, with the lowest-numbered sample in the two
/// least significant bits. Panics if `s / 4` is out of bounds for `row`.
#[inline]
fn code_at(row: &[u8], s: usize) -> u8 {
    let byte = row[s >> 2];
    let shift = (s & 3) << 1;
    (byte >> shift) & 0b11
}
125
/// Contribution of one genotype code to the running "A1 minus A2" allele
/// balance: `+2` for code 0, `-2` for code 3, and `0` for everything else.
#[inline]
fn dosage_diff(code: u8) -> i32 {
    if code == 0 {
        2
    } else if code == 3 {
        -2
    } else {
        0
    }
}
135
136fn tdt_tested(chrom: &str) -> bool {
139 !matches!(report_chrom(chrom).as_str(), "24" | "26")
140}
141
/// A trio plus flags saying whether each parent still has to be added to the
/// per-variant allele balance.
struct TrioWork {
    /// Index of the father in the sample list.
    dad: usize,
    /// Index of the mother in the sample list.
    mom: usize,
    /// Index of the child in the sample list.
    child: usize,
    /// `true` when this trio is the one that counts the father's genotype.
    count_dad: bool,
    /// `true` when this trio is the one that counts the mother's genotype.
    count_mom: bool,
}
152
153#[must_use]
154pub fn tdt(pgen: &Pgen) -> Vec<TdtRecord> {
155 use rayon::prelude::*;
156 let lut = build_transmit_lut();
157 let triples = trios(pgen);
158
159 let founders = founder_mask(pgen);
163 let mut seen = vec![false; pgen.n_samples()];
164 let work: Vec<TrioWork> = triples
165 .iter()
166 .map(|t| {
167 let count_dad = !std::mem::replace(&mut seen[t.dad], true);
168 let count_mom = !std::mem::replace(&mut seen[t.mom], true);
169 TrioWork {
170 dad: t.dad,
171 mom: t.mom,
172 child: t.child,
173 count_dad,
174 count_mom,
175 }
176 })
177 .collect();
178 let other_founders: Vec<u32> = (0..pgen.n_samples())
179 .filter(|&s| founders[s] && !seen[s])
180 .map(|s| s as u32)
181 .collect();
182
183 let bpv = pgen.bytes_per_variant();
184 let gt = &pgen.gt_raw;
185
186 (0..pgen.n_variants())
187 .into_par_iter()
188 .filter(|&v| tdt_tested(&pgen.variants[v].chrom))
189 .map(|v| {
190 let row = >[v * bpv..v * bpv + bpv];
191 let (mut t_a1, mut u_a1) = (0u32, 0u32);
192 let mut diff = 0i32;
193 for w in &work {
194 let dad = code_at(row, w.dad);
195 let mom = code_at(row, w.mom);
196 let key = (dad << 4) | (mom << 2) | code_at(row, w.child);
197 let packed = lut[key as usize];
198 t_a1 += u32::from(packed >> 4);
199 u_a1 += u32::from(packed & 0x0f);
200 if w.count_dad {
201 diff += dosage_diff(dad);
202 }
203 if w.count_mom {
204 diff += dosage_diff(mom);
205 }
206 }
207 for &s in &other_founders {
208 diff += dosage_diff(code_at(row, s as usize));
209 }
210 let var = &pgen.variants[v];
211 let (a1, a2, t, u) = if diff <= 0 {
212 (&var.a1, &var.a2, t_a1, u_a1)
213 } else {
214 (&var.a2, &var.a1, u_a1, t_a1)
215 };
216 TdtRecord {
217 chrom: var.chrom.clone(),
218 snp: var.id.clone(),
219 bp: var.pos,
220 a1: a1.clone(),
221 a2: a2.clone(),
222 t,
223 u,
224 }
225 })
226 .collect()
227}
228
/// Maps chromosome names to the numeric codes used in the report.
///
/// `X`→`23`, `Y`→`24`, `XY`→`25` and `MT`/`M`→`26`, each accepted in all-upper
/// or all-lower case only. Any other label is returned unchanged.
fn report_chrom(chrom: &str) -> String {
    let code = match chrom {
        "X" | "x" => "23",
        "Y" | "y" => "24",
        "XY" | "xy" => "25",
        "MT" | "mt" | "M" | "m" => "26",
        unchanged => unchanged,
    };
    code.to_string()
}
239
/// Column widths, in characters, of the variable-width report columns.
struct Widths {
    /// Width of the `CHR` column.
    chr: usize,
    /// Width of the `SNP` column.
    snp: usize,
    /// Width of the `A1` column.
    a1: usize,
    /// Width of the `A2` column.
    a2: usize,
}
246
247impl Widths {
248 fn measure(records: &[TdtRecord]) -> Self {
249 let mut chr = 0;
250 let mut snp = 0;
251 let mut a1 = 0;
252 let mut a2 = 0;
253 for r in records {
254 chr = chr.max(report_chrom(&r.chrom).len());
255 snp = snp.max(r.snp.len());
256 a1 = a1.max(r.a1.len());
257 a2 = a2.max(r.a2.len());
258 }
259 Self {
260 chr: chr.max(2) + 2,
261 snp: if snp < 5 { 5 } else { snp + 3 },
262 a1: a1.max(2) + 2,
263 a2: a2.max(2) + 2,
264 }
265 }
266}
267
268pub fn write_tdt<W: Write>(records: &[TdtRecord], out: &mut W) -> io::Result<()> {
270 let w = Widths::measure(records);
271 writeln!(
272 out,
273 "{:>cw$}{:>sw$}{:>13}{:>a1$}{:>a2$}{:>7}{:>7}{:>13}{:>13}{:>13} ",
274 "CHR",
275 "SNP",
276 "BP",
277 "A1",
278 "A2",
279 "T",
280 "U",
281 "OR",
282 "CHISQ",
283 "P",
284 cw = w.chr,
285 sw = w.snp,
286 a1 = w.a1,
287 a2 = w.a2,
288 )?;
289 for r in records {
290 let (or, chisq, p) = stats(r.t, r.u);
291 writeln!(
292 out,
293 "{:>cw$}{:>sw$}{:>13}{:>a1$}{:>a2$}{:>7}{:>7}{:>13}{:>13}{:>13} ",
294 report_chrom(&r.chrom),
295 r.snp,
296 r.bp,
297 r.a1,
298 r.a2,
299 r.t,
300 r.u,
301 or,
302 chisq,
303 p,
304 cw = w.chr,
305 sw = w.snp,
306 a1 = w.a1,
307 a2 = w.a2,
308 )?;
309 }
310 Ok(())
311}
312
313fn stats(t: u32, u: u32) -> (String, String, String) {
315 let n = t + u;
316 if n == 0 {
317 return ("NA".into(), "NA".into(), "NA".into());
318 }
319 let or = if u == 0 {
320 "NA".to_string()
321 } else {
322 fmt_g(f64::from(t) / f64::from(u))
323 };
324 let diff = f64::from(t) - f64::from(u);
325 let chisq = diff * diff / f64::from(n);
326 let p = chisq_1df_sf(chisq);
327 (or, fmt_g(chisq), fmt_g(p))
328}
329
330fn chisq_1df_sf(x: f64) -> f64 {
333 if x <= 0.0 {
334 return 1.0;
335 }
336 gamma_q(0.5, x / 2.0)
337}
338
/// Natural logarithm of the gamma function for `z > 0`, using a six-term
/// Lanczos-style series (the constants are those of the classic
/// Numerical Recipes `gammln` routine).
fn ln_gamma(z: f64) -> f64 {
    const C: [f64; 6] = [
        76.180_091_729_471_46,
        -86.505_320_329_416_77,
        24.014_098_240_830_91,
        -1.231_739_572_450_155,
        0.001_208_650_973_866_179,
        -0.000_005_395_239_384_953,
    ];
    let shifted = z + 5.5;
    let tmp = shifted - (z + 0.5) * shifted.ln();
    // Accumulate c_k / (z + k) for k = 1..=6 on top of the leading constant.
    let (series, _) = C
        .iter()
        .fold((1.000_000_000_190_015_f64, z), |(acc, denom), &coef| {
            let denom = denom + 1.0;
            (acc + coef / denom, denom)
        });
    (2.506_628_274_631_000_5 * series / z).ln() - tmp
}
359
360fn gamma_q(a: f64, x: f64) -> f64 {
362 if x < a + 1.0 {
363 1.0 - gamma_p_series(a, x)
364 } else {
365 gamma_q_cf(a, x)
366 }
367}
368
369fn gamma_p_series(a: f64, x: f64) -> f64 {
370 let gln = ln_gamma(a);
371 let mut ap = a;
372 let mut sum = 1.0 / a;
373 let mut del = sum;
374 for _ in 0..400 {
375 ap += 1.0;
376 del *= x / ap;
377 sum += del;
378 if del.abs() < sum.abs() * 1e-16 {
379 break;
380 }
381 }
382 sum * (-x + a * x.ln() - gln).exp()
383}
384
385fn gamma_q_cf(a: f64, x: f64) -> f64 {
386 const TINY: f64 = 1e-300;
387 let gln = ln_gamma(a);
388 let mut b = x + 1.0 - a;
389 let mut c = 1.0 / TINY;
390 let mut d = 1.0 / b;
391 let mut h = d;
392 for i in 1..400 {
393 let an = -(i as f64) * (i as f64 - a);
394 b += 2.0;
395 d = an * d + b;
396 if d.abs() < TINY {
397 d = TINY;
398 }
399 c = b + an / c;
400 if c.abs() < TINY {
401 c = TINY;
402 }
403 d = 1.0 / d;
404 let del = d * c;
405 h *= del;
406 if (del - 1.0).abs() < 1e-16 {
407 break;
408 }
409 }
410 (-x + a * x.ln() - gln).exp() * h
411}
412
413fn fmt_g(x: f64) -> String {
420 const SIG: usize = 4;
421 if x.is_nan() {
422 return "nan".to_string();
423 }
424 if x == 0.0 {
425 return "0".to_string();
426 }
427 let neg = x < 0.0;
428 let (digits, lead_exp) = shortest_decimal(x.abs());
431 let (digits, exp) = round_sig_half_even(&digits, lead_exp, SIG);
432
433 let mut s = if !(-4..SIG as i32).contains(&exp) {
434 let mant = mantissa(&digits, 1);
435 format!("{mant}e{}{:02}", if exp < 0 { '-' } else { '+' }, exp.abs())
436 } else if exp >= 0 {
437 mantissa(&digits, (exp + 1) as usize)
438 } else {
439 let zeros = "0".repeat((-exp - 1) as usize);
440 strip_trailing(&format!("0.{zeros}{digits}"))
441 };
442 if neg {
443 s.insert(0, '-');
444 }
445 s
446}
447
/// Splits the `{:e}` rendering of `x` into its decimal digits (without the
/// decimal point) and the decimal exponent of the leading digit.
///
/// Panics when the rendering has no `e` part, which is the case for
/// non-finite input.
fn shortest_decimal(x: f64) -> (String, i32) {
    let rendered = format!("{x:e}");
    let (mantissa_part, exponent_part) = rendered.split_once('e').unwrap();
    let exponent = exponent_part.parse::<i32>().unwrap();
    let digits = mantissa_part
        .chars()
        .filter(char::is_ascii_digit)
        .collect::<String>();
    (digits, exponent)
}
458
/// Rounds a decimal digit string to at most `sig` significant digits, with
/// ties going to the even digit.
///
/// `lead_exp` is the decimal exponent of the first digit. The returned
/// exponent is one higher when rounding carries out of the leading digit
/// (e.g. `9999|5` becomes `1` with the exponent bumped). Trailing zeros are
/// removed from the result, keeping at least one digit.
fn round_sig_half_even(digits: &str, lead_exp: i32, sig: usize) -> (String, i32) {
    // Drops trailing zeros and turns digit values back into text.
    fn render(mut values: Vec<u8>, exp: i32) -> (String, i32) {
        while values.len() > 1 && values.last() == Some(&0) {
            values.pop();
        }
        let text = values.into_iter().map(|v| char::from(v + b'0')).collect();
        (text, exp)
    }

    let values: Vec<u8> = digits.bytes().map(|b| b - b'0').collect();
    if values.len() <= sig {
        return render(values, lead_exp);
    }

    let (head, tail) = values.split_at(sig);
    let mut kept = head.to_vec();
    let round_up = match tail[0].cmp(&5) {
        std::cmp::Ordering::Greater => true,
        // Exact tie only when nothing non-zero follows; then round to even.
        std::cmp::Ordering::Equal => tail[1..].iter().any(|&v| v != 0) || kept[sig - 1] % 2 == 1,
        std::cmp::Ordering::Less => false,
    };

    let mut exp = lead_exp;
    if round_up {
        match kept.iter().rposition(|&v| v != 9) {
            Some(pos) => {
                kept[pos] += 1;
                for v in &mut kept[pos + 1..] {
                    *v = 0;
                }
            }
            None => {
                // Every kept digit was 9: the carry creates a new leading 1.
                for v in kept.iter_mut() {
                    *v = 0;
                }
                kept.insert(0, 1);
                exp += 1;
                kept.pop();
            }
        }
    }
    render(kept, exp)
}
499
500fn mantissa(digits: &str, int_len: usize) -> String {
503 let padded = if digits.len() < int_len {
504 format!("{digits}{}", "0".repeat(int_len - digits.len()))
505 } else {
506 digits.to_string()
507 };
508 if padded.len() <= int_len {
509 padded
510 } else {
511 strip_trailing(&format!("{}.{}", &padded[..int_len], &padded[int_len..]))
512 }
513}
514
/// Removes trailing zeros, and then any trailing `.`, from a number that
/// contains a decimal point. Strings without a `.` are returned unchanged.
fn strip_trailing(s: &str) -> String {
    let trimmed = if s.contains('.') {
        s.trim_end_matches('0').trim_end_matches('.')
    } else {
        s
    };
    trimmed.to_owned()
}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks `transmit` for called, missing and inconsistent trios.
    #[test]
    fn transmit_table_matches_verified_combos() {
        // ((dad, mom, child), (transmitted, untransmitted))
        let cases: [((u8, u8, u8), (u8, u8)); 8] = [
            ((0, 2, 0), (1, 0)),
            ((0, 2, 2), (0, 1)),
            ((2, 2, 2), (1, 1)),
            ((2, 2, 0), (2, 0)),
            ((2, 2, 3), (0, 2)),
            ((0, 0, 2), (0, 0)),
            ((1, 2, 2), (0, 0)),
            ((2, 3, 2), (1, 0)),
        ];
        for &((dad, mom, child), expected) in &cases {
            assert_eq!(transmit(dad, mom, child), expected);
        }
    }

    /// Zero counts and a zero denominator produce `NA` fields.
    #[test]
    fn stats_edge_cases() {
        assert_eq!(stats(0, 0), ("NA".into(), "NA".into(), "NA".into()));

        let (or, chisq, _) = stats(3, 0);
        assert_eq!(or, "NA");
        assert_eq!(chisq, "3");

        let (or, _, _) = stats(0, 2);
        assert_eq!(or, "0");
    }

    /// Four-significant-digit formatting without trailing zeros.
    #[test]
    fn g_formatting_matches_plink() {
        let cases = [
            (1.273, "1.273"),
            (0.8868, "0.8868"),
            (0.1573, "0.1573"),
            (0.0, "0"),
            (1.44, "1.44"),
            (3.0, "3"),
            (0.0006871, "0.0006871"),
        ];
        for &(value, expected) in &cases {
            assert_eq!(fmt_g(value), expected);
        }
    }

    /// Values whose fifth significant digit is an exact 5 round to even.
    #[test]
    fn g_half_ties_round_to_even() {
        let cases = [
            (894.0 / 960.0, "0.9312"),
            (473.0 / 400.0, "1.182"),
            (696.0 / 640.0, "1.088"),
            (763.0 / 800.0, "0.9538"),
            (431.0 / 400.0, "1.078"),
            (0.91875, "0.9188"),
        ];
        for &(value, expected) in &cases {
            assert_eq!(fmt_g(value), expected);
        }
    }

    /// Upper-tail probabilities for a few chi-square statistics.
    #[test]
    fn p_value_matches_plink() {
        let cases = [
            (2.0, "0.1573"),
            (5.76, "0.0164"),
            (0.0, "1"),
            (3.0, "0.08326"),
        ];
        for &(chisq, expected) in &cases {
            assert_eq!(fmt_g(chisq_1df_sf(chisq)), expected);
        }
    }
}