use anyhow::{Context, Result};
use indicatif::ProgressBar;
use jamhash::jamhash_u64;
use needletail::{Sequence, parse_fastx_file};
use std::io::{Read, Write};
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};

const BIAS_MAGIC: &[u8; 4] = b"BIA3";
const BIAS_VERSION: u32 = 3;

const DEFAULT_CMS_WIDTH: usize = 1 << 20;
const DEFAULT_CMS_DEPTH: usize = 5;
const QUANTIZATION_SCALE: f32 = 10.0;
const MAX_SAMPLE_HASHES: usize = 100_000;

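/// Dimensions and hashing parameters shared by the positive and negative
/// count-min sketches: `width` x `depth` counter cells, k-mer size `k`, and the
/// fractional-hashing divisor `fscale` (only hashes below `u64::MAX / fscale`
/// are counted).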
#[derive(Debug, Clone)]
pub struct CMSConfig {
    pub width: usize,
    pub depth: usize,
    pub k: u8,
    pub fscale: u64,
}

impl Default for CMSConfig {
    fn default() -> Self {
        Self {
            width: DEFAULT_CMS_WIDTH,
            depth: DEFAULT_CMS_DEPTH,
            k: 21,
            fscale: 1000,
        }
    }
}

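/// A count-min sketch: `depth` rows of `width` counters, each row indexed by a
/// seeded multiplicative hash. An increment touches one cell per row; an estimate
/// takes the minimum across rows, so counts are never underestimated, only
/// (occasionally) overestimated under collisions.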
#[derive(Debug, Clone)]
pub struct CountMinSketch {
    width: usize,
    depth: usize,
    seeds: Vec<u64>,
    counts: Vec<u64>,
}

impl CountMinSketch {
    pub fn new(width: usize, depth: usize) -> Self {
        let seeds: Vec<u64> = (0..depth)
            .map(|i| 0x517cc1b727220a95u64.wrapping_add(i as u64))
            .collect();
        let counts = vec![0u64; width * depth];
        Self {
            width,
            depth,
            seeds,
            counts,
        }
    }

    pub fn with_seeds(width: usize, depth: usize, seeds: Vec<u64>) -> Self {
        assert_eq!(seeds.len(), depth);
        let counts = vec![0u64; width * depth];
        Self {
            width,
            depth,
            seeds,
            counts,
        }
    }

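    /// Maps a hash to its cell in row `row` of the flat `counts` array by mixing
    /// the hash with the row's seed and reducing modulo the sketch width.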
    #[inline]
    fn index(&self, row: usize, hash: u64) -> usize {
        let mixed = hash.wrapping_mul(self.seeds[row]);
        row * self.width + (mixed as usize % self.width)
    }

    #[inline]
    pub fn increment(&mut self, hash: u64) {
        for row in 0..self.depth {
            let idx = self.index(row, hash);
            self.counts[idx] = self.counts[idx].saturating_add(1);
        }
    }

    #[inline]
    pub fn estimate(&self, hash: u64) -> u64 {
        (0..self.depth)
            .map(|row| self.counts[self.index(row, hash)])
            .min()
            .unwrap_or(0)
    }

    pub fn width(&self) -> usize {
        self.width
    }
    pub fn depth(&self) -> usize {
        self.depth
    }
    pub fn seeds(&self) -> &[u64] {
        &self.seeds
    }
    pub fn counts(&self) -> &[u64] {
        &self.counts
    }

    pub fn cell_stats(&self) -> (u64, u64, f64, f64, usize) {
        let min = *self.counts.iter().min().unwrap_or(&0);
        let max = *self.counts.iter().max().unwrap_or(&0);
        let sum: u64 = self.counts.iter().sum();
        let mean = sum as f64 / self.counts.len() as f64;
        let variance: f64 = self
            .counts
            .iter()
            .map(|&c| {
                let d = c as f64 - mean;
                d * d
            })
            .sum::<f64>()
            / self.counts.len() as f64;
        let non_zero = self.counts.iter().filter(|&&c| c > 0).count();
        (min, max, mean, variance.sqrt(), non_zero)
    }
}

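/// Intermediate per-class state accumulated while streaming input files: the
/// count-min sketch of k-mer hash counts, the total number of counted hashes,
/// and a bounded sample of hashes kept for threshold calibration.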
#[derive(Debug, Clone)]
pub(crate) struct RawHashCounts {
    pub(crate) config: CMSConfig,
    pub(crate) cms: CountMinSketch,
    pub(crate) total: u64,
    pub(crate) samples: Vec<u64>,
}

impl RawHashCounts {
    pub(crate) fn new(config: CMSConfig) -> Self {
        let cms = CountMinSketch::new(config.width, config.depth);
        Self {
            config,
            cms,
            total: 0,
            samples: Vec::with_capacity(MAX_SAMPLE_HASHES),
        }
    }

    pub(crate) fn build(
        paths: &[&Path],
        config: CMSConfig,
        record_counter: &AtomicU64,
        hash_counter: &AtomicU64,
    ) -> Result<Self> {
        let frac_max = u64::MAX / config.fscale;
        let k = config.k;

        let mut raw = RawHashCounts::new(config);
        for path in paths {
            process_path(&mut raw, path, k, frac_max, record_counter, hash_counter)?;
        }

        if raw.samples.len() > MAX_SAMPLE_HASHES {
            downsample_samples(&mut raw.samples);
        }

        Ok(raw)
    }
}

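/// Streams one FASTA/FASTQ file, hashing every canonical k-mer and counting only
/// those whose hash falls below `frac_max` (= `u64::MAX / fscale`), i.e. roughly a
/// 1/fscale subsample of all k-mers. Empty input files are silently skipped.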
fn process_path(
    raw: &mut RawHashCounts,
    path: &Path,
    k: u8,
    frac_max: u64,
    record_counter: &AtomicU64,
    hash_counter: &AtomicU64,
) -> Result<()> {
    let mut reader = match parse_fastx_file(path) {
        Ok(reader) => reader,
        Err(e) if e.kind == needletail::errors::ParseErrorKind::EmptyFile => {
            return Ok(());
        }
        Err(e) => {
            return Err(e).with_context(|| format!("Failed to parse: {}", path.display()));
        }
    };

    while let Some(record) = reader.next() {
        let record = record.context("Failed to parse sequence record")?;
        let seq = record.normalize(false);
        record_counter.fetch_add(1, Ordering::Relaxed);

        if seq.len() < k as usize {
            continue;
        }

        for (_, kmer, _) in seq.bit_kmers(k, true) {
            let hash = jamhash_u64(kmer.0);
            if hash < frac_max {
                raw.cms.increment(hash);
                raw.total += 1;
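                // Reservoir-style sampling: keep at most MAX_SAMPLE_HASHES hashes,
                // replacing existing entries with decreasing probability as more
                // hashes are seen, so the sample stays roughly uniform over the stream.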
                if raw.samples.len() < MAX_SAMPLE_HASHES {
                    raw.samples.push(hash);
                } else {
                    let seen = raw.total;
                    let pick = (jamhash_u64(hash ^ seen) % seen) as usize;
                    if pick < MAX_SAMPLE_HASHES {
                        raw.samples[pick] = hash;
                    }
                }
                hash_counter.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    Ok(())
}

fn downsample_samples(samples: &mut Vec<u64>) {
    if samples.len() <= MAX_SAMPLE_HASHES {
        return;
    }
    samples.sort_unstable_by_key(|&hash| jamhash_u64(hash));
    samples.truncate(MAX_SAMPLE_HASHES);
}

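/// Parameters for building a bias table from positive and negative example files:
/// the shared sketch configuration, the additive smoothing constant `alpha`, and an
/// optional target fold enrichment for threshold calibration (`None` maximizes it).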
#[derive(Debug, Clone)]
pub struct BiasCreateConfig {
    pub cms: CMSConfig,
    pub alpha: f32,
    pub target_fold_enrichment: Option<f32>,
}

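/// Outcome of threshold calibration: the chosen quantized threshold, the fraction of
/// sampled positive and negative hashes that pass it, and the resulting (and maximum
/// achievable) fold enrichment.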
#[derive(Debug, Clone, Copy)]
pub struct CalibrationResult {
    pub threshold: i8,
    pub positive_retention: f32,
    pub negative_retention: f32,
    pub fold_enrichment: f32,
    pub max_fold_enrichment: f32,
}

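/// A compact filter that scores k-mer hashes by how much more often they occur in
/// positive than in negative training data. Each cell stores a log-odds weight
/// quantized to an `i8` (scaled by `QUANTIZATION_SCALE`); a hash's weight is the
/// minimum across rows, and it passes the filter when that weight reaches
/// `threshold`.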
#[derive(Debug, Clone)]
pub struct HashBiasTable {
    pub config: CMSConfig,
    seeds: Vec<u64>,
    weights: Vec<i8>,
    pub alpha: f32,
    pub threshold: i8,
    pub positive_retention: f32,
    pub negative_retention: f32,
    pub max_fold_enrichment: f32,
}

fn validate_cms_compatibility(positive: &RawHashCounts, negative: &RawHashCounts) -> Result<()> {
    if positive.config.k != negative.config.k {
        anyhow::bail!(
            "k-mer size mismatch: positive={}, negative={}",
            positive.config.k,
            negative.config.k
        );
    }
    if positive.config.fscale != negative.config.fscale {
        anyhow::bail!(
            "fscale mismatch: positive={}, negative={}",
            positive.config.fscale,
            negative.config.fscale
        );
    }
    if positive.config.width != negative.config.width
        || positive.config.depth != negative.config.depth
    {
        anyhow::bail!(
            "CMS dimensions mismatch: positive={}x{}, negative={}x{}",
            positive.config.width,
            positive.config.depth,
            negative.config.width,
            negative.config.depth
        );
    }
    Ok(())
}

impl HashBiasTable {
    pub fn create(
        positive_paths: &[&Path],
        negative_paths: &[&Path],
        config: &BiasCreateConfig,
        progress: Option<ProgressBar>,
    ) -> Result<Self> {
        let record_counter = Arc::new(AtomicU64::new(0));
        let hash_counter = Arc::new(AtomicU64::new(0));
        let stop_flag = Arc::new(AtomicBool::new(false));

        let update_handle = progress.as_ref().map(|pb| {
            let pb = pb.clone();
            let record_counter = Arc::clone(&record_counter);
            let hash_counter = Arc::clone(&hash_counter);
            let stop_flag = Arc::clone(&stop_flag);

            std::thread::spawn(move || {
                loop {
                    if stop_flag.load(Ordering::Relaxed) || pb.is_finished() {
                        break;
                    }
                    let records = record_counter.load(Ordering::Relaxed);
                    let hashes = hash_counter.load(Ordering::Relaxed);
                    pb.set_message(format!(
                        "{} records, {} hashes",
                        format_number(records),
                        format_number(hashes)
                    ));
                    std::thread::sleep(std::time::Duration::from_millis(100));
                }
            })
        });

        let (pos_raw, neg_raw) = rayon::join(
            || {
                RawHashCounts::build(
                    positive_paths,
                    config.cms.clone(),
                    &record_counter,
                    &hash_counter,
                )
            },
            || {
                RawHashCounts::build(
                    negative_paths,
                    config.cms.clone(),
                    &record_counter,
                    &hash_counter,
                )
            },
        );

        stop_flag.store(true, Ordering::Relaxed);
        if let Some(handle) = update_handle {
            let _ = handle.join();
        }

        let pos_raw = pos_raw?;
        let neg_raw = neg_raw?;

        if let Some(ref pb) = progress {
            pb.set_message("Computing bias weights...");
        }

        let table = Self::build(
            &pos_raw,
            &neg_raw,
            config.alpha,
            config.target_fold_enrichment,
        )?;

        if let Some(ref pb) = progress {
            pb.finish();
        }

        Ok(table)
    }

    pub(crate) fn build(
        positive: &RawHashCounts,
        negative: &RawHashCounts,
        alpha: f32,
        target_fold_enrichment: Option<f32>,
    ) -> Result<Self> {
        validate_cms_compatibility(positive, negative)?;

        let width = positive.config.width;
        let depth = positive.config.depth;
        let seeds = positive.cms.seeds().to_vec();

        let pos_counts = positive.cms.counts();
        let neg_counts = negative.cms.counts();

        let pos_total = positive.total as f64;
        let neg_total = negative.total as f64;

        let mut weights = vec![0i8; width * depth];

        if pos_total > 0.0 && neg_total > 0.0 {
            let scale = pos_total.max(neg_total);

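            // Per-cell weight: normalize both counts to a common scale, treat only the
            // negative excess over the positive count as "negative" evidence, and store
            // the smoothed log-odds ln((pos + alpha) / (excess_neg + alpha)), quantized
            // to an i8 by multiplying by QUANTIZATION_SCALE and clamping to [-127, 127].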
            for i in 0..(width * depth) {
                let norm_pos = (pos_counts[i] as f64 / pos_total) * scale;
                let norm_neg = (neg_counts[i] as f64 / neg_total) * scale;
                let adj_neg = (norm_neg - norm_pos).max(0.0) as f32;
                let norm_pos_f32 = norm_pos as f32;

                let log_ratio = ((norm_pos_f32 + alpha) / (adj_neg + alpha)).ln();
                let quantized = (log_ratio * QUANTIZATION_SCALE).clamp(-127.0, 127.0) as i8;
                weights[i] = quantized;
            }
        }

        let calibration = calibrate_threshold(
            positive,
            negative,
            &weights,
            &seeds,
            width,
            target_fold_enrichment,
        )?;

        Ok(Self {
            config: positive.config.clone(),
            seeds,
            weights,
            alpha,
            threshold: calibration.threshold,
            positive_retention: calibration.positive_retention,
            negative_retention: calibration.negative_retention,
            max_fold_enrichment: calibration.max_fold_enrichment,
        })
    }

    #[inline]
    fn index(&self, row: usize, hash: u64) -> usize {
        let mixed = hash.wrapping_mul(self.seeds[row]);
        row * self.config.width + (mixed as usize % self.config.width)
    }

    #[inline]
    pub fn weight(&self, hash: u64) -> i8 {
        (0..self.config.depth)
            .map(|row| self.weights[self.index(row, hash)])
            .min()
            .unwrap_or(0)
    }

    #[inline]
    pub fn passes_filter(&self, hash: u64) -> bool {
        self.weight(hash) >= self.threshold
    }

    pub fn k(&self) -> u8 {
        self.config.k
    }
    pub fn fscale(&self) -> u64 {
        self.config.fscale
    }

    pub fn fold_enrichment(&self) -> f32 {
        if self.negative_retention > 0.0 {
            self.positive_retention / self.negative_retention
        } else {
            f32::INFINITY
        }
    }

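    /// Writes the table in the v3 binary layout, all integers little-endian:
    /// magic "BIA3" (4 bytes) | version (u32) | k (u8) | fscale (u64) | width (u32) |
    /// depth (u8) | alpha (f32) | threshold (i8) | positive_retention (f32) |
    /// negative_retention (f32), followed by `depth` u64 seeds and
    /// `width * depth` i8 weights. `load` and `from_bytes` expect this exact layout.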
    pub fn save(&self, path: &Path) -> Result<()> {
        let mut file = std::fs::File::create(path)
            .with_context(|| format!("Failed to create bias table file: {}", path.display()))?;

        file.write_all(BIAS_MAGIC)?;
        file.write_all(&BIAS_VERSION.to_le_bytes())?;
        file.write_all(&[self.config.k])?;
        file.write_all(&self.config.fscale.to_le_bytes())?;
        file.write_all(&(self.config.width as u32).to_le_bytes())?;
        file.write_all(&[self.config.depth as u8])?;
        file.write_all(&self.alpha.to_le_bytes())?;
        file.write_all(&[self.threshold as u8])?;
        file.write_all(&self.positive_retention.to_le_bytes())?;
        file.write_all(&self.negative_retention.to_le_bytes())?;

        for &seed in &self.seeds {
            file.write_all(&seed.to_le_bytes())?;
        }
        for &w in &self.weights {
            file.write_all(&[w as u8])?;
        }

        Ok(())
    }

    pub fn load(path: &Path) -> Result<Self> {
        let mut file = std::fs::File::open(path)
            .with_context(|| format!("Failed to open bias table file: {}", path.display()))?;

        let mut magic = [0u8; 4];
        file.read_exact(&mut magic)?;

        if &magic != BIAS_MAGIC {
            anyhow::bail!("Invalid bias table file (bad magic): {}", path.display());
        }

        let mut buf4 = [0u8; 4];
        file.read_exact(&mut buf4)?;
        let version = u32::from_le_bytes(buf4);
        if version != BIAS_VERSION {
            anyhow::bail!(
                "Unsupported bias table version {} (expected {})",
                version,
                BIAS_VERSION
            );
        }

        let mut k_buf = [0u8; 1];
        file.read_exact(&mut k_buf)?;
        let k = k_buf[0];

        let mut buf8 = [0u8; 8];
        file.read_exact(&mut buf8)?;
        let fscale = u64::from_le_bytes(buf8);

        file.read_exact(&mut buf4)?;
        let width = u32::from_le_bytes(buf4) as usize;

        let mut depth_buf = [0u8; 1];
        file.read_exact(&mut depth_buf)?;
        let depth = depth_buf[0] as usize;

        file.read_exact(&mut buf4)?;
        let alpha = f32::from_le_bytes(buf4);

        let mut threshold_buf = [0u8; 1];
        file.read_exact(&mut threshold_buf)?;
        let threshold = threshold_buf[0] as i8;

        file.read_exact(&mut buf4)?;
        let positive_retention = f32::from_le_bytes(buf4);

        file.read_exact(&mut buf4)?;
        let negative_retention = f32::from_le_bytes(buf4);

        let mut seeds = Vec::with_capacity(depth);
        for _ in 0..depth {
            file.read_exact(&mut buf8)?;
            seeds.push(u64::from_le_bytes(buf8));
        }

        let mut weights = vec![0i8; width * depth];
        let mut weight_buf = vec![0u8; width * depth];
        file.read_exact(&mut weight_buf)?;
        for (i, &b) in weight_buf.iter().enumerate() {
            weights[i] = b as i8;
        }

        let config = CMSConfig {
            width,
            depth,
            k,
            fscale,
        };

        let max_fold_enrichment = if negative_retention > 0.0 {
            positive_retention / negative_retention
        } else {
            f32::INFINITY
        };

        Ok(Self {
            config,
            seeds,
            weights,
            alpha,
            threshold,
            positive_retention,
            negative_retention,
            max_fold_enrichment,
        })
    }

    pub fn to_bytes(&self) -> Vec<u8> {
        let header_size = 4 + 4 + 1 + 8 + 4 + 1 + 4 + 1 + 4 + 4;
        let seeds_size = self.config.depth * 8;
        let weights_size = self.config.width * self.config.depth;
        let total_size = header_size + seeds_size + weights_size;

        let mut out = Vec::with_capacity(total_size);
        out.extend_from_slice(BIAS_MAGIC);
        out.extend_from_slice(&BIAS_VERSION.to_le_bytes());
        out.push(self.config.k);
        out.extend_from_slice(&self.config.fscale.to_le_bytes());
        out.extend_from_slice(&(self.config.width as u32).to_le_bytes());
        out.push(self.config.depth as u8);
        out.extend_from_slice(&self.alpha.to_le_bytes());
        out.push(self.threshold as u8);
        out.extend_from_slice(&self.positive_retention.to_le_bytes());
        out.extend_from_slice(&self.negative_retention.to_le_bytes());

        for &seed in &self.seeds {
            out.extend_from_slice(&seed.to_le_bytes());
        }
        for &w in &self.weights {
            out.push(w as u8);
        }

        out
    }

    pub fn from_bytes(data: &[u8]) -> Result<Self> {
        if data.len() < 35 {
            anyhow::bail!("Bias table data too small: {} bytes", data.len());
        }

        let magic: [u8; 4] = data[0..4].try_into().unwrap();
        if &magic != BIAS_MAGIC {
            anyhow::bail!("Invalid bias table magic bytes");
        }

        let version = u32::from_le_bytes(data[4..8].try_into().unwrap());
        if version != BIAS_VERSION {
            anyhow::bail!("Unsupported bias table version {}", version);
        }

        let k = data[8];
        let fscale = u64::from_le_bytes(data[9..17].try_into().unwrap());
        let width = u32::from_le_bytes(data[17..21].try_into().unwrap()) as usize;
        let depth = data[21] as usize;
        let alpha = f32::from_le_bytes(data[22..26].try_into().unwrap());
        let threshold = data[26] as i8;
        let positive_retention = f32::from_le_bytes(data[27..31].try_into().unwrap());
        let negative_retention = f32::from_le_bytes(data[31..35].try_into().unwrap());

        let seeds_start = 35;
        let seeds_end = seeds_start + depth * 8;
        let weights_start = seeds_end;
        let weights_end = weights_start + width * depth;

        if data.len() < weights_end {
            anyhow::bail!(
                "Bias table data truncated: expected {} bytes, got {}",
                weights_end,
                data.len()
            );
        }

        let mut seeds = Vec::with_capacity(depth);
        for i in 0..depth {
            let offset = seeds_start + i * 8;
            seeds.push(u64::from_le_bytes(
                data[offset..offset + 8].try_into().unwrap(),
            ));
        }

        let mut weights = vec![0i8; width * depth];
        for (i, &b) in data[weights_start..weights_end].iter().enumerate() {
            weights[i] = b as i8;
        }

        let config = CMSConfig {
            width,
            depth,
            k,
            fscale,
        };

        let max_fold_enrichment = if negative_retention > 0.0 {
            positive_retention / negative_retention
        } else {
            f32::INFINITY
        };

        Ok(Self {
            config,
            seeds,
            weights,
            alpha,
            threshold,
            positive_retention,
            negative_retention,
            max_fold_enrichment,
        })
    }

    pub fn weight_stats(&self) -> (f32, f32, f32, f32, usize) {
        let min = *self.weights.iter().min().unwrap_or(&0) as f32 / QUANTIZATION_SCALE;
        let max = *self.weights.iter().max().unwrap_or(&0) as f32 / QUANTIZATION_SCALE;
        let sum: i64 = self.weights.iter().map(|&w| w as i64).sum();
        let mean = sum as f32 / self.weights.len() as f32 / QUANTIZATION_SCALE;
        let variance: f32 = self
            .weights
            .iter()
            .map(|&w| {
                let d = w as f32 / QUANTIZATION_SCALE - mean;
                d * d
            })
            .sum::<f32>()
            / self.weights.len() as f32;
        let positive = self.weights.iter().filter(|&&w| w > 0).count();
        (min, max, mean, variance.sqrt(), positive)
    }

    pub fn memory_usage(&self) -> usize {
        self.weights.len() + self.seeds.len() * 8
    }

    pub fn threshold_f32(&self) -> f32 {
        self.threshold as f32 / QUANTIZATION_SCALE
    }

    pub fn print_stats(&self) {
        let (min, max, mean, std, positive) = self.weight_stats();
        let total_cells = self.config.width * self.config.depth;
        eprintln!("Hash Bias Table (v3)");
        eprintln!(" k-mer size: {}", self.config.k);
        eprintln!(" fscale: {}", self.config.fscale);
        eprintln!(
            " CMS dimensions: {} x {}",
            self.config.width, self.config.depth
        );
        eprintln!(" Smoothing (alpha): {:.1}", self.alpha);
        eprintln!(
            " Threshold: {:.2} (quantized: {})",
            self.threshold_f32(),
            self.threshold
        );
        eprintln!(
            " Positive retention: {:.2}%",
            self.positive_retention * 100.0
        );
        eprintln!(
            " Negative retention: {:.2}%",
            self.negative_retention * 100.0
        );
        eprintln!(" Fold enrichment: {:.2}x", self.fold_enrichment());
        eprintln!(
            " Weight stats: min={:.2}, max={:.2}, mean={:.2}, std={:.2}",
            min, max, mean, std
        );
        eprintln!(
            " Positive weights: {} ({:.1}%)",
            positive,
            positive as f64 / total_cells as f64 * 100.0
        );
    }
}

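/// Picks the weight threshold by scoring the sampled positive and negative hashes
/// against the freshly built weight table and sweeping thresholds from -127 to 127.
/// Without a target, it keeps the threshold with the highest positive/negative
/// retention ratio (fold enrichment); with a target, it keeps the threshold whose
/// enrichment is closest to that target, falling back to the maximum when the target
/// is unreachable.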
fn calibrate_threshold(
    positive: &RawHashCounts,
    negative: &RawHashCounts,
    weights: &[i8],
    seeds: &[u64],
    width: usize,
    target_fold_enrichment: Option<f32>,
) -> Result<CalibrationResult> {
    let sample_hashes = |raw: &RawHashCounts, max_samples: usize| -> Vec<u64> {
        if raw.samples.len() <= max_samples {
            return raw.samples.clone();
        }
        let step = raw.samples.len() / max_samples;
        raw.samples
            .iter()
            .step_by(step)
            .take(max_samples)
            .copied()
            .collect()
    };

    let estimate_weight = |hash: u64| -> i8 {
        let depth = seeds.len();
        (0..depth)
            .map(|row| {
                let mixed = hash.wrapping_mul(seeds[row]);
                let idx = row * width + (mixed as usize % width);
                weights[idx]
            })
            .min()
            .unwrap_or(0)
    };

    let pos_sample_weights: Vec<i8> = sample_hashes(positive, 100_000)
        .iter()
        .map(|&h| estimate_weight(h))
        .collect();
    let neg_sample_weights: Vec<i8> = sample_hashes(negative, 100_000)
        .iter()
        .map(|&h| estimate_weight(h))
        .collect();

    if pos_sample_weights.is_empty() || neg_sample_weights.is_empty() {
        return Ok(CalibrationResult {
            threshold: 0,
            positive_retention: 1.0,
            negative_retention: 1.0,
            fold_enrichment: 1.0,
            max_fold_enrichment: 1.0,
        });
    }

    let mut max_enrichment = 0.0f32;
    let mut max_threshold = 0i8;
    let mut max_pos_ret = 1.0f32;
    let mut max_neg_ret = 1.0f32;

    for t in -127i8..=127i8 {
        let pos_passing = pos_sample_weights.iter().filter(|&&w| w >= t).count();
        let neg_passing = neg_sample_weights.iter().filter(|&&w| w >= t).count();

        let pos_ret = pos_passing as f32 / pos_sample_weights.len() as f32;
        let neg_ret = neg_passing as f32 / neg_sample_weights.len().max(1) as f32;

        if neg_ret < 1e-6 {
            continue;
        }

        let enrichment = pos_ret / neg_ret;
        if enrichment > max_enrichment {
            max_enrichment = enrichment;
            max_threshold = t;
            max_pos_ret = pos_ret;
            max_neg_ret = neg_ret;
        }
    }

    match target_fold_enrichment {
        None => Ok(CalibrationResult {
            threshold: max_threshold,
            positive_retention: max_pos_ret,
            negative_retention: max_neg_ret,
            fold_enrichment: max_enrichment,
            max_fold_enrichment: max_enrichment,
        }),
        Some(target) => {
            if target > max_enrichment {
                return Ok(CalibrationResult {
                    threshold: max_threshold,
                    positive_retention: max_pos_ret,
                    negative_retention: max_neg_ret,
                    fold_enrichment: max_enrichment,
                    max_fold_enrichment: max_enrichment,
                });
            }

            let mut best_threshold = 0i8;
            let mut best_diff = f32::MAX;
            let mut best_pos_ret = 1.0f32;
            let mut best_neg_ret = 1.0f32;

            for t in -127i8..=127i8 {
                let pos_passing = pos_sample_weights.iter().filter(|&&w| w >= t).count();
                let neg_passing = neg_sample_weights.iter().filter(|&&w| w >= t).count();

                let pos_ret = pos_passing as f32 / pos_sample_weights.len() as f32;
                let neg_ret = neg_passing as f32 / neg_sample_weights.len().max(1) as f32;

                if neg_ret < 1e-6 {
                    continue;
                }

                let enrichment = pos_ret / neg_ret;
                let diff = (enrichment - target).abs();

                if diff < best_diff {
                    best_diff = diff;
                    best_threshold = t;
                    best_pos_ret = pos_ret;
                    best_neg_ret = neg_ret;
                }
            }

            Ok(CalibrationResult {
                threshold: best_threshold,
                positive_retention: best_pos_ret,
                negative_retention: best_neg_ret,
                fold_enrichment: if best_neg_ret > 0.0 {
                    best_pos_ret / best_neg_ret
                } else {
                    f32::INFINITY
                },
                max_fold_enrichment: max_enrichment,
            })
        }
    }
}

fn format_number(n: u64) -> String {
    if n >= 1_000_000_000 {
        format!("{:.2}G", n as f64 / 1_000_000_000.0)
    } else if n >= 1_000_000 {
        format!("{:.2}M", n as f64 / 1_000_000.0)
    } else if n >= 1_000 {
        format!("{:.2}K", n as f64 / 1_000.0)
    } else {
        format!("{}", n)
    }
}

pub fn format_bp(bp: u64) -> String {
    if bp >= 1_000_000_000 {
        format!("{:.2} Gbp", bp as f64 / 1_000_000_000.0)
    } else if bp >= 1_000_000 {
        format!("{:.2} Mbp", bp as f64 / 1_000_000.0)
    } else if bp >= 1_000 {
        format!("{:.2} Kbp", bp as f64 / 1_000.0)
    } else {
        format!("{} bp", bp)
    }
}

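/// Expected serialized size of a bias table built with the default CMS dimensions:
/// the 35-byte header plus 8 bytes per seed and 1 byte per weight cell.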
pub const BIAS_TABLE_SERIALIZED_SIZE: usize =
    35 + DEFAULT_CMS_DEPTH * 8 + DEFAULT_CMS_WIDTH * DEFAULT_CMS_DEPTH;

impl PartialEq for HashBiasTable {
    fn eq(&self, other: &Self) -> bool {
        self.config.k == other.config.k
            && self.config.fscale == other.config.fscale
            && self.config.width == other.config.width
            && self.config.depth == other.config.depth
            && self.alpha == other.alpha
            && self.threshold == other.threshold
            && self.positive_retention == other.positive_retention
            && self.negative_retention == other.negative_retention
            && self.seeds == other.seeds
            && self.weights == other.weights
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    fn create_fasta(sequences: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for (i, seq) in sequences.iter().enumerate() {
            writeln!(file, ">seq_{}", i).unwrap();
            writeln!(file, "{}", seq).unwrap();
        }
        file
    }

    #[test]
    fn test_cms_basic() {
        let mut cms = CountMinSketch::new(1024, 5);
        let hash = 0x12345678u64;

        assert_eq!(cms.estimate(hash), 0);

        cms.increment(hash);
        assert_eq!(cms.estimate(hash), 1);

        for _ in 0..9 {
            cms.increment(hash);
        }
        assert_eq!(cms.estimate(hash), 10);
    }

    #[test]
    fn test_cms_collision_handling() {
        let mut cms = CountMinSketch::new(16, 5);

        for i in 0..100u64 {
            cms.increment(i);
        }

        for i in 0..100u64 {
            assert!(cms.estimate(i) >= 1);
        }
    }

    #[test]
    fn test_raw_hash_counts_build() {
        let fasta = create_fasta(&["ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG"]);
        let config = CMSConfig {
            width: 1024,
            depth: 3,
            k: 11,
            fscale: 1,
        };

        let raw = RawHashCounts::build(
            &[fasta.path()],
            config,
            &AtomicU64::new(0),
            &AtomicU64::new(0),
        )
        .unwrap();
        assert!(raw.total > 0);
    }

    #[test]
    fn test_hash_bias_table_build() {
        let pos = create_fasta(&[
            "ATATATATATATATATATATATATATATATATATATATAT",
            "TATATATATATATATATATATATATATATATATATATAT",
        ]);
        let neg = create_fasta(&[
            "GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC",
            "CGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCG",
        ]);

        let config = CMSConfig {
            width: 1024,
            depth: 3,
            k: 11,
            fscale: 1,
        };

        let pos_raw = RawHashCounts::build(
            &[pos.path()],
            config.clone(),
            &AtomicU64::new(0),
            &AtomicU64::new(0),
        )
        .unwrap();
        let neg_raw = RawHashCounts::build(
            &[neg.path()],
            config,
            &AtomicU64::new(0),
            &AtomicU64::new(0),
        )
        .unwrap();

        let table = HashBiasTable::build(&pos_raw, &neg_raw, 1.0, Some(5.0)).unwrap();
        assert!(table.threshold >= -127);
    }

    #[test]
    fn test_hash_bias_table_save_load() {
        let pos = create_fasta(&["ATATATATATATATATATATATATATATATATATATATAT"]);
        let neg = create_fasta(&["GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"]);

        let config = CMSConfig {
            width: 1024,
            depth: 3,
            k: 11,
            fscale: 10,
        };

        let pos_raw = RawHashCounts::build(
            &[pos.path()],
            config.clone(),
            &AtomicU64::new(0),
            &AtomicU64::new(0),
        )
        .unwrap();
        let neg_raw = RawHashCounts::build(
            &[neg.path()],
            config,
            &AtomicU64::new(0),
            &AtomicU64::new(0),
        )
        .unwrap();

        let table = HashBiasTable::build(&pos_raw, &neg_raw, 1.0, Some(2.0)).unwrap();

        let output = NamedTempFile::new().unwrap();
        table.save(output.path()).unwrap();

        let loaded = HashBiasTable::load(output.path()).unwrap();
        assert_eq!(table.config.k, loaded.config.k);
        assert_eq!(table.threshold, loaded.threshold);
        assert_eq!(table.weights, loaded.weights);
    }

    #[test]
    fn test_hash_bias_table_bytes_roundtrip() {
        let pos = create_fasta(&["ATATATATATATATATATATATATATATATATATATATAT"]);
        let neg = create_fasta(&["GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"]);

        let config = CMSConfig {
            width: 512,
            depth: 3,
            k: 11,
            fscale: 10,
        };

        let pos_raw = RawHashCounts::build(
            &[pos.path()],
            config.clone(),
            &AtomicU64::new(0),
            &AtomicU64::new(0),
        )
        .unwrap();
        let neg_raw = RawHashCounts::build(
            &[neg.path()],
            config,
            &AtomicU64::new(0),
            &AtomicU64::new(0),
        )
        .unwrap();

        let table = HashBiasTable::build(&pos_raw, &neg_raw, 1.0, Some(2.0)).unwrap();

        let bytes = table.to_bytes();
        let loaded = HashBiasTable::from_bytes(&bytes).unwrap();

        assert_eq!(table, loaded);
    }

    #[test]
    fn test_passes_filter() {
        let pos = create_fasta(&["ATATATATATATATATATATATATATATATATATATATAT"]);
        let neg = create_fasta(&["GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"]);

        let config = CMSConfig {
            width: 1024,
            depth: 3,
            k: 11,
            fscale: 1,
        };

        let pos_raw = RawHashCounts::build(
            &[pos.path()],
            config.clone(),
            &AtomicU64::new(0),
            &AtomicU64::new(0),
        )
        .unwrap();
        let neg_raw = RawHashCounts::build(
            &[neg.path()],
            config,
            &AtomicU64::new(0),
            &AtomicU64::new(0),
        )
        .unwrap();

        let table = HashBiasTable::build(&pos_raw, &neg_raw, 1.0, Some(2.0)).unwrap();

        let mut passed = 0;
        let mut failed = 0;
        for h in 0..1000u64 {
            if table.passes_filter(h) {
                passed += 1;
            } else {
                failed += 1;
            }
        }

        assert!(passed > 0 || failed > 0);
    }

    #[test]
    fn test_maximize_fold_enrichment() {
        let pos = create_fasta(&[
            "ATATATATATATATATATATATATATATATATATATATAT",
            "TATATATATATATATATATATATATATATATATATATAT",
        ]);
        let neg = create_fasta(&[
            "GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC",
            "CGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCG",
        ]);

        let config = CMSConfig {
            width: 1024,
            depth: 3,
            k: 11,
            fscale: 1,
        };

        let pos_raw = RawHashCounts::build(
            &[pos.path()],
            config.clone(),
            &AtomicU64::new(0),
            &AtomicU64::new(0),
        )
        .unwrap();
        let neg_raw = RawHashCounts::build(
            &[neg.path()],
            config,
            &AtomicU64::new(0),
            &AtomicU64::new(0),
        )
        .unwrap();

        let table = HashBiasTable::build(&pos_raw, &neg_raw, 1.0, None).unwrap();
        assert!(table.threshold >= -127);
        assert!(table.fold_enrichment() >= 1.0);
    }

    #[test]
    fn test_create_unified() {
        let pos = create_fasta(&["ATATATATATATATATATATATATATATATATATATATAT"]);
        let neg = create_fasta(&["GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"]);

        let config = BiasCreateConfig {
            cms: CMSConfig {
                width: 1024,
                depth: 3,
                k: 11,
                fscale: 1,
            },
            alpha: 1.0,
            target_fold_enrichment: None,
        };

        let table = HashBiasTable::create(&[pos.path()], &[neg.path()], &config, None).unwrap();

        assert!(table.threshold >= -127);
        assert!(table.fold_enrichment() >= 1.0);
    }
}