Skip to main content

jam_rs/
bias.rs

1use anyhow::{Context, Result};
2use indicatif::ProgressBar;
3use jamhash::jamhash_u64;
4use needletail::{Sequence, parse_fastx_file};
5use std::io::{Read, Write};
6use std::path::Path;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9
/// Magic bytes identifying a serialized bias table (format v3).
const BIAS_MAGIC: &[u8; 4] = b"BIA3";
/// On-disk format version; `load`/`from_bytes` reject anything else.
const BIAS_VERSION: u32 = 3;

/// Default count-min sketch width (cells per row, ~1M).
const DEFAULT_CMS_WIDTH: usize = 1 << 20;
/// Default number of count-min sketch rows.
const DEFAULT_CMS_DEPTH: usize = 5;
/// Fixed-point scale used to quantize log-ratio weights into i8.
const QUANTIZATION_SCALE: f32 = 10.0;
/// Maximum number of raw hashes kept (reservoir) for threshold calibration.
const MAX_SAMPLE_HASHES: usize = 100_000;

/// Count-min sketch and hashing parameters shared across a bias run.
#[derive(Debug, Clone)]
pub struct CMSConfig {
    /// Cells per sketch row.
    pub width: usize,
    /// Number of sketch rows.
    pub depth: usize,
    /// k-mer size.
    pub k: u8,
    /// Keep only hashes below `u64::MAX / fscale` (a 1/fscale fraction).
    pub fscale: u64,
}

impl Default for CMSConfig {
    /// Defaults: 2^20-wide rows, 5 rows, k = 21, 1/1000 hash fraction.
    fn default() -> Self {
        CMSConfig {
            k: 21,
            fscale: 1000,
            width: DEFAULT_CMS_WIDTH,
            depth: DEFAULT_CMS_DEPTH,
        }
    }
}
36
/// Count-min sketch over u64 hashes.
///
/// `depth` rows of `width` counters each; an item is counted once per row at
/// a seed-mixed position. Estimates take the minimum across rows, so they can
/// overcount (on collisions) but never undercount.
#[derive(Debug, Clone)]
pub struct CountMinSketch {
    /// Cells per row.
    width: usize,
    /// Number of rows (independent mixing seeds).
    depth: usize,
    /// One multiplicative mixing seed per row.
    seeds: Vec<u64>,
    /// Row-major counter storage, `width * depth` cells.
    counts: Vec<u64>,
}

impl CountMinSketch {
    /// Builds a sketch with deterministic per-row seeds derived from a fixed
    /// constant (seed of row i is the constant plus i).
    pub fn new(width: usize, depth: usize) -> Self {
        let base: u64 = 0x517cc1b727220a95;
        let seeds = (0..depth as u64).map(|row| base.wrapping_add(row)).collect();
        Self {
            width,
            depth,
            seeds,
            counts: vec![0u64; width * depth],
        }
    }

    /// Builds a sketch with caller-supplied seeds (used when the seeds must
    /// match another sketch, e.g. after deserialization).
    ///
    /// # Panics
    /// Panics if `seeds.len() != depth`.
    pub fn with_seeds(width: usize, depth: usize, seeds: Vec<u64>) -> Self {
        assert_eq!(seeds.len(), depth);
        Self {
            width,
            depth,
            seeds,
            counts: vec![0u64; width * depth],
        }
    }

    /// Flat storage index of `hash` in `row`: multiplicative mixing, then
    /// modulo the row width.
    #[inline]
    fn index(&self, row: usize, hash: u64) -> usize {
        let mixed = hash.wrapping_mul(self.seeds[row]);
        row * self.width + (mixed as usize % self.width)
    }

    /// Counts one occurrence of `hash` in every row (saturating, so counters
    /// cannot overflow).
    #[inline]
    pub fn increment(&mut self, hash: u64) {
        for row in 0..self.depth {
            let cell = self.index(row, hash);
            self.counts[cell] = self.counts[cell].saturating_add(1);
        }
    }

    /// Estimated count for `hash`: the minimum over all rows (0 for a sketch
    /// with zero depth).
    #[inline]
    pub fn estimate(&self, hash: u64) -> u64 {
        (0..self.depth)
            .map(|row| self.index(row, hash))
            .map(|cell| self.counts[cell])
            .min()
            .unwrap_or(0)
    }

    /// Cells per row.
    pub fn width(&self) -> usize {
        self.width
    }
    /// Number of rows.
    pub fn depth(&self) -> usize {
        self.depth
    }
    /// Per-row mixing seeds.
    pub fn seeds(&self) -> &[u64] {
        &self.seeds
    }
    /// Raw row-major counters.
    pub fn counts(&self) -> &[u64] {
        &self.counts
    }

    /// Summary statistics over all cells:
    /// (min, max, mean, standard deviation, number of non-zero cells).
    pub fn cell_stats(&self) -> (u64, u64, f64, f64, usize) {
        let n = self.counts.len() as f64;
        let lo = self.counts.iter().copied().min().unwrap_or(0);
        let hi = self.counts.iter().copied().max().unwrap_or(0);
        let total: u64 = self.counts.iter().sum();
        let mean = total as f64 / n;
        let variance = self
            .counts
            .iter()
            .map(|&c| {
                let d = c as f64 - mean;
                d * d
            })
            .sum::<f64>()
            / n;
        let occupied = self.counts.iter().filter(|&&c| c > 0).count();
        (lo, hi, mean, variance.sqrt(), occupied)
    }
}
123
/// Intermediate result of streaming one read set: count-min sketch counts
/// plus a bounded reservoir of raw hashes used for threshold calibration.
#[derive(Debug, Clone)]
pub(crate) struct RawHashCounts {
    /// Configuration the counts were built with.
    pub(crate) config: CMSConfig,
    /// Per-cell hash counts.
    pub(crate) cms: CountMinSketch,
    /// Total number of fraction-selected hashes counted.
    pub(crate) total: u64,
    /// Reservoir sample of hashes (at most MAX_SAMPLE_HASHES entries).
    pub(crate) samples: Vec<u64>,
}
131
132impl RawHashCounts {
133    pub(crate) fn new(config: CMSConfig) -> Self {
134        let cms = CountMinSketch::new(config.width, config.depth);
135        Self {
136            config,
137            cms,
138            total: 0,
139            samples: Vec::with_capacity(MAX_SAMPLE_HASHES),
140        }
141    }
142
143    pub(crate) fn build(
144        paths: &[&Path],
145        config: CMSConfig,
146        record_counter: &AtomicU64,
147        hash_counter: &AtomicU64,
148    ) -> Result<Self> {
149        let frac_max = u64::MAX / config.fscale;
150        let k = config.k;
151
152        let mut raw = RawHashCounts::new(config);
153        for path in paths {
154            process_path(&mut raw, path, k, frac_max, record_counter, hash_counter)?;
155        }
156
157        if raw.samples.len() > MAX_SAMPLE_HASHES {
158            downsample_samples(&mut raw.samples);
159        }
160
161        Ok(raw)
162    }
163}
164
/// Streams one sequence file into `raw`: every k-mer whose hash falls below
/// `frac_max` is counted in the CMS and reservoir-sampled into `raw.samples`.
///
/// Empty input files are silently skipped; any other parse failure is an
/// error. `record_counter`/`hash_counter` are shared progress counters.
fn process_path(
    raw: &mut RawHashCounts,
    path: &Path,
    k: u8,
    frac_max: u64,
    record_counter: &AtomicU64,
    hash_counter: &AtomicU64,
) -> Result<()> {
    let mut reader = match parse_fastx_file(path) {
        Ok(reader) => reader,
        // An empty input file is not an error for our purposes.
        Err(e) if e.kind == needletail::errors::ParseErrorKind::EmptyFile => {
            return Ok(());
        }
        Err(e) => {
            return Err(e).with_context(|| format!("Failed to parse: {}", path.display()));
        }
    };

    while let Some(record) = reader.next() {
        let record = record.context("Failed to parse sequence record")?;
        // NOTE(review): normalize(false) presumably uppercases/cleans the
        // sequence without reverse-complementing -- confirm needletail docs.
        let seq = record.normalize(false);
        record_counter.fetch_add(1, Ordering::Relaxed);

        // Record too short to contain a single k-mer.
        if seq.len() < k as usize {
            continue;
        }

        // NOTE(review): bit_kmers(k, true) yields 2-bit packed k-mers; the
        // `true` flag presumably selects canonical form -- confirm.
        for (_, kmer, _) in seq.bit_kmers(k, true) {
            let hash = jamhash_u64(kmer.0);
            // Fraction filter: keep hashes in the lowest 1/fscale of space.
            if hash < frac_max {
                raw.cms.increment(hash);
                raw.total += 1;
                if raw.samples.len() < MAX_SAMPLE_HASHES {
                    raw.samples.push(hash);
                } else {
                    // Reservoir sampling with a deterministic hash-derived
                    // index instead of an RNG: the current hash replaces a
                    // pseudo-random slot with probability
                    // MAX_SAMPLE_HASHES / total.
                    let seen = raw.total;
                    let pick = (jamhash_u64(hash ^ seen) % seen) as usize;
                    if pick < MAX_SAMPLE_HASHES {
                        raw.samples[pick] = hash;
                    }
                }
                hash_counter.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    Ok(())
}
213
214fn downsample_samples(samples: &mut Vec<u64>) {
215    if samples.len() <= MAX_SAMPLE_HASHES {
216        return;
217    }
218    samples.sort_unstable_by_key(|&hash| jamhash_u64(hash));
219    samples.truncate(MAX_SAMPLE_HASHES);
220}
221
/// Options for `HashBiasTable::create`.
#[derive(Debug, Clone)]
pub struct BiasCreateConfig {
    /// Sketch dimensions plus k-mer size and hash fraction.
    pub cms: CMSConfig,
    /// Additive smoothing applied to both sides of the log ratio.
    pub alpha: f32,
    /// If set, calibrate the threshold toward this fold enrichment instead
    /// of the maximum achievable one.
    pub target_fold_enrichment: Option<f32>,
}
228
/// Outcome of threshold calibration over sampled hash weights.
#[derive(Debug, Clone, Copy)]
pub struct CalibrationResult {
    /// Chosen quantized weight threshold.
    pub threshold: i8,
    /// Fraction of positive-sample hashes passing the threshold.
    pub positive_retention: f32,
    /// Fraction of negative-sample hashes passing the threshold.
    pub negative_retention: f32,
    /// positive_retention / negative_retention at the chosen threshold.
    pub fold_enrichment: f32,
    /// Best enrichment achievable at any threshold during the scan.
    pub max_fold_enrichment: f32,
}
237
/// Quantized log-ratio weight table keyed by k-mer hash, with a calibrated
/// pass/fail threshold. Serializable via `save`/`load` and
/// `to_bytes`/`from_bytes`.
#[derive(Debug, Clone)]
pub struct HashBiasTable {
    /// CMS parameters the weights were derived under.
    pub config: CMSConfig,
    /// Per-row mixing seeds (taken from the positive sketch).
    seeds: Vec<u64>,
    /// Quantized weights, `width * depth` cells, row-major.
    weights: Vec<i8>,
    /// Smoothing constant used when building.
    pub alpha: f32,
    /// Minimum weight a hash needs to pass the filter.
    pub threshold: i8,
    /// Positive retention at `threshold` (from calibration).
    pub positive_retention: f32,
    /// Negative retention at `threshold` (from calibration).
    pub negative_retention: f32,
    /// Maximum enrichment found during calibration (not serialized).
    pub max_fold_enrichment: f32,
}
249
250fn validate_cms_compatibility(positive: &RawHashCounts, negative: &RawHashCounts) -> Result<()> {
251    if positive.config.k != negative.config.k {
252        anyhow::bail!(
253            "k-mer size mismatch: positive={}, negative={}",
254            positive.config.k,
255            negative.config.k
256        );
257    }
258    if positive.config.fscale != negative.config.fscale {
259        anyhow::bail!(
260            "fscale mismatch: positive={}, negative={}",
261            positive.config.fscale,
262            negative.config.fscale
263        );
264    }
265    if positive.config.width != negative.config.width
266        || positive.config.depth != negative.config.depth
267    {
268        anyhow::bail!(
269            "CMS dimensions mismatch: positive={}x{}, negative={}x{}",
270            positive.config.width,
271            positive.config.depth,
272            negative.config.width,
273            negative.config.depth
274        );
275    }
276    Ok(())
277}
278
279impl HashBiasTable {
280    pub fn create(
281        positive_paths: &[&Path],
282        negative_paths: &[&Path],
283        config: &BiasCreateConfig,
284        progress: Option<ProgressBar>,
285    ) -> Result<Self> {
286        let record_counter = Arc::new(AtomicU64::new(0));
287        let hash_counter = Arc::new(AtomicU64::new(0));
288        let stop_flag = Arc::new(AtomicBool::new(false));
289
290        let update_handle = progress.as_ref().map(|pb| {
291            let pb = pb.clone();
292            let record_counter = Arc::clone(&record_counter);
293            let hash_counter = Arc::clone(&hash_counter);
294            let stop_flag = Arc::clone(&stop_flag);
295
296            std::thread::spawn(move || {
297                loop {
298                    if stop_flag.load(Ordering::Relaxed) || pb.is_finished() {
299                        break;
300                    }
301                    let records = record_counter.load(Ordering::Relaxed);
302                    let hashes = hash_counter.load(Ordering::Relaxed);
303                    pb.set_message(format!(
304                        "{} records, {} hashes",
305                        format_number(records),
306                        format_number(hashes)
307                    ));
308                    std::thread::sleep(std::time::Duration::from_millis(100));
309                }
310            })
311        });
312
313        let (pos_raw, neg_raw) = rayon::join(
314            || {
315                RawHashCounts::build(
316                    positive_paths,
317                    config.cms.clone(),
318                    &record_counter,
319                    &hash_counter,
320                )
321            },
322            || {
323                RawHashCounts::build(
324                    negative_paths,
325                    config.cms.clone(),
326                    &record_counter,
327                    &hash_counter,
328                )
329            },
330        );
331
332        stop_flag.store(true, Ordering::Relaxed);
333        if let Some(handle) = update_handle {
334            let _ = handle.join();
335        }
336
337        let pos_raw = pos_raw?;
338        let neg_raw = neg_raw?;
339
340        if let Some(ref pb) = progress {
341            pb.set_message("Computing bias weights...");
342        }
343
344        let table = Self::build(
345            &pos_raw,
346            &neg_raw,
347            config.alpha,
348            config.target_fold_enrichment,
349        )?;
350
351        if let Some(ref pb) = progress {
352            pb.finish();
353        }
354
355        Ok(table)
356    }
357
358    pub(crate) fn build(
359        positive: &RawHashCounts,
360        negative: &RawHashCounts,
361        alpha: f32,
362        target_fold_enrichment: Option<f32>,
363    ) -> Result<Self> {
364        validate_cms_compatibility(positive, negative)?;
365
366        let width = positive.config.width;
367        let depth = positive.config.depth;
368        let seeds = positive.cms.seeds().to_vec();
369
370        let pos_counts = positive.cms.counts();
371        let neg_counts = negative.cms.counts();
372
373        let pos_total = positive.total as f64;
374        let neg_total = negative.total as f64;
375
376        let mut weights = vec![0i8; width * depth];
377
378        if pos_total > 0.0 && neg_total > 0.0 {
379            let scale = pos_total.max(neg_total);
380
381            for i in 0..(width * depth) {
382                let norm_pos = (pos_counts[i] as f64 / pos_total) * scale;
383                let norm_neg = (neg_counts[i] as f64 / neg_total) * scale;
384                let adj_neg = (norm_neg - norm_pos).max(0.0) as f32;
385                let norm_pos_f32 = norm_pos as f32;
386
387                let log_ratio = ((norm_pos_f32 + alpha) / (adj_neg + alpha)).ln();
388                let quantized = (log_ratio * QUANTIZATION_SCALE).clamp(-127.0, 127.0) as i8;
389                weights[i] = quantized;
390            }
391        }
392
393        let calibration = calibrate_threshold(
394            positive,
395            negative,
396            &weights,
397            &seeds,
398            width,
399            target_fold_enrichment,
400        )?;
401
402        Ok(Self {
403            config: positive.config.clone(),
404            seeds,
405            weights,
406            alpha,
407            threshold: calibration.threshold,
408            positive_retention: calibration.positive_retention,
409            negative_retention: calibration.negative_retention,
410            max_fold_enrichment: calibration.max_fold_enrichment,
411        })
412    }
413
414    #[inline]
415    fn index(&self, row: usize, hash: u64) -> usize {
416        let mixed = hash.wrapping_mul(self.seeds[row]);
417        row * self.config.width + (mixed as usize % self.config.width)
418    }
419
420    #[inline]
421    pub fn weight(&self, hash: u64) -> i8 {
422        (0..self.config.depth)
423            .map(|row| self.weights[self.index(row, hash)])
424            .min()
425            .unwrap_or(0)
426    }
427
428    #[inline]
429    pub fn passes_filter(&self, hash: u64) -> bool {
430        self.weight(hash) >= self.threshold
431    }
432
433    pub fn k(&self) -> u8 {
434        self.config.k
435    }
436    pub fn fscale(&self) -> u64 {
437        self.config.fscale
438    }
439
440    pub fn fold_enrichment(&self) -> f32 {
441        if self.negative_retention > 0.0 {
442            self.positive_retention / self.negative_retention
443        } else {
444            f32::INFINITY
445        }
446    }
447
448    pub fn save(&self, path: &Path) -> Result<()> {
449        let mut file = std::fs::File::create(path)
450            .with_context(|| format!("Failed to create bias table file: {}", path.display()))?;
451
452        file.write_all(BIAS_MAGIC)?;
453        file.write_all(&BIAS_VERSION.to_le_bytes())?;
454        file.write_all(&[self.config.k])?;
455        file.write_all(&self.config.fscale.to_le_bytes())?;
456        file.write_all(&(self.config.width as u32).to_le_bytes())?;
457        file.write_all(&[self.config.depth as u8])?;
458        file.write_all(&self.alpha.to_le_bytes())?;
459        file.write_all(&[self.threshold as u8])?;
460        file.write_all(&self.positive_retention.to_le_bytes())?;
461        file.write_all(&self.negative_retention.to_le_bytes())?;
462
463        for &seed in &self.seeds {
464            file.write_all(&seed.to_le_bytes())?;
465        }
466        for &w in &self.weights {
467            file.write_all(&[w as u8])?;
468        }
469
470        Ok(())
471    }
472
473    pub fn load(path: &Path) -> Result<Self> {
474        let mut file = std::fs::File::open(path)
475            .with_context(|| format!("Failed to open bias table file: {}", path.display()))?;
476
477        let mut magic = [0u8; 4];
478        file.read_exact(&mut magic)?;
479
480        if &magic != BIAS_MAGIC {
481            anyhow::bail!("Invalid bias table file (bad magic): {}", path.display());
482        }
483
484        let mut buf4 = [0u8; 4];
485        file.read_exact(&mut buf4)?;
486        let version = u32::from_le_bytes(buf4);
487        if version != BIAS_VERSION {
488            anyhow::bail!(
489                "Unsupported bias table version {} (expected {})",
490                version,
491                BIAS_VERSION
492            );
493        }
494
495        let mut k_buf = [0u8; 1];
496        file.read_exact(&mut k_buf)?;
497        let k = k_buf[0];
498
499        let mut buf8 = [0u8; 8];
500        file.read_exact(&mut buf8)?;
501        let fscale = u64::from_le_bytes(buf8);
502
503        file.read_exact(&mut buf4)?;
504        let width = u32::from_le_bytes(buf4) as usize;
505
506        let mut depth_buf = [0u8; 1];
507        file.read_exact(&mut depth_buf)?;
508        let depth = depth_buf[0] as usize;
509
510        file.read_exact(&mut buf4)?;
511        let alpha = f32::from_le_bytes(buf4);
512
513        let mut threshold_buf = [0u8; 1];
514        file.read_exact(&mut threshold_buf)?;
515        let threshold = threshold_buf[0] as i8;
516
517        file.read_exact(&mut buf4)?;
518        let positive_retention = f32::from_le_bytes(buf4);
519
520        file.read_exact(&mut buf4)?;
521        let negative_retention = f32::from_le_bytes(buf4);
522
523        let mut seeds = Vec::with_capacity(depth);
524        for _ in 0..depth {
525            file.read_exact(&mut buf8)?;
526            seeds.push(u64::from_le_bytes(buf8));
527        }
528
529        let mut weights = vec![0i8; width * depth];
530        let mut weight_buf = vec![0u8; width * depth];
531        file.read_exact(&mut weight_buf)?;
532        for (i, &b) in weight_buf.iter().enumerate() {
533            weights[i] = b as i8;
534        }
535
536        let config = CMSConfig {
537            width,
538            depth,
539            k,
540            fscale,
541        };
542
543        let max_fold_enrichment = if negative_retention > 0.0 {
544            positive_retention / negative_retention
545        } else {
546            f32::INFINITY
547        };
548
549        Ok(Self {
550            config,
551            seeds,
552            weights,
553            alpha,
554            threshold,
555            positive_retention,
556            negative_retention,
557            max_fold_enrichment,
558        })
559    }
560
561    pub fn to_bytes(&self) -> Vec<u8> {
562        let header_size = 4 + 4 + 1 + 8 + 4 + 1 + 4 + 1 + 4 + 4;
563        let seeds_size = self.config.depth * 8;
564        let weights_size = self.config.width * self.config.depth;
565        let total_size = header_size + seeds_size + weights_size;
566
567        let mut out = Vec::with_capacity(total_size);
568        out.extend_from_slice(BIAS_MAGIC);
569        out.extend_from_slice(&BIAS_VERSION.to_le_bytes());
570        out.push(self.config.k);
571        out.extend_from_slice(&self.config.fscale.to_le_bytes());
572        out.extend_from_slice(&(self.config.width as u32).to_le_bytes());
573        out.push(self.config.depth as u8);
574        out.extend_from_slice(&self.alpha.to_le_bytes());
575        out.push(self.threshold as u8);
576        out.extend_from_slice(&self.positive_retention.to_le_bytes());
577        out.extend_from_slice(&self.negative_retention.to_le_bytes());
578
579        for &seed in &self.seeds {
580            out.extend_from_slice(&seed.to_le_bytes());
581        }
582        for &w in &self.weights {
583            out.push(w as u8);
584        }
585
586        out
587    }
588
589    pub fn from_bytes(data: &[u8]) -> Result<Self> {
590        if data.len() < 35 {
591            anyhow::bail!("Bias table data too small: {} bytes", data.len());
592        }
593
594        let magic: [u8; 4] = data[0..4].try_into().unwrap();
595        if &magic != BIAS_MAGIC {
596            anyhow::bail!("Invalid bias table magic bytes");
597        }
598
599        let version = u32::from_le_bytes(data[4..8].try_into().unwrap());
600        if version != BIAS_VERSION {
601            anyhow::bail!("Unsupported bias table version {}", version);
602        }
603
604        let k = data[8];
605        let fscale = u64::from_le_bytes(data[9..17].try_into().unwrap());
606        let width = u32::from_le_bytes(data[17..21].try_into().unwrap()) as usize;
607        let depth = data[21] as usize;
608        let alpha = f32::from_le_bytes(data[22..26].try_into().unwrap());
609        let threshold = data[26] as i8;
610        let positive_retention = f32::from_le_bytes(data[27..31].try_into().unwrap());
611        let negative_retention = f32::from_le_bytes(data[31..35].try_into().unwrap());
612
613        let seeds_start = 35;
614        let seeds_end = seeds_start + depth * 8;
615        let weights_start = seeds_end;
616        let weights_end = weights_start + width * depth;
617
618        if data.len() < weights_end {
619            anyhow::bail!(
620                "Bias table data truncated: expected {} bytes, got {}",
621                weights_end,
622                data.len()
623            );
624        }
625
626        let mut seeds = Vec::with_capacity(depth);
627        for i in 0..depth {
628            let offset = seeds_start + i * 8;
629            seeds.push(u64::from_le_bytes(
630                data[offset..offset + 8].try_into().unwrap(),
631            ));
632        }
633
634        let mut weights = vec![0i8; width * depth];
635        for (i, &b) in data[weights_start..weights_end].iter().enumerate() {
636            weights[i] = b as i8;
637        }
638
639        let config = CMSConfig {
640            width,
641            depth,
642            k,
643            fscale,
644        };
645
646        let max_fold_enrichment = if negative_retention > 0.0 {
647            positive_retention / negative_retention
648        } else {
649            f32::INFINITY
650        };
651
652        Ok(Self {
653            config,
654            seeds,
655            weights,
656            alpha,
657            threshold,
658            positive_retention,
659            negative_retention,
660            max_fold_enrichment,
661        })
662    }
663
664    pub fn weight_stats(&self) -> (f32, f32, f32, f32, usize) {
665        let min = *self.weights.iter().min().unwrap_or(&0) as f32 / QUANTIZATION_SCALE;
666        let max = *self.weights.iter().max().unwrap_or(&0) as f32 / QUANTIZATION_SCALE;
667        let sum: i64 = self.weights.iter().map(|&w| w as i64).sum();
668        let mean = sum as f32 / self.weights.len() as f32 / QUANTIZATION_SCALE;
669        let variance: f32 = self
670            .weights
671            .iter()
672            .map(|&w| {
673                let d = w as f32 / QUANTIZATION_SCALE - mean;
674                d * d
675            })
676            .sum::<f32>()
677            / self.weights.len() as f32;
678        let positive = self.weights.iter().filter(|&&w| w > 0).count();
679        (min, max, mean, variance.sqrt(), positive)
680    }
681
682    pub fn memory_usage(&self) -> usize {
683        self.weights.len() + self.seeds.len() * 8
684    }
685
686    pub fn threshold_f32(&self) -> f32 {
687        self.threshold as f32 / QUANTIZATION_SCALE
688    }
689
690    pub fn print_stats(&self) {
691        let (min, max, mean, std, positive) = self.weight_stats();
692        let total_cells = self.config.width * self.config.depth;
693        eprintln!("Hash Bias Table (v3)");
694        eprintln!("  k-mer size:     {}", self.config.k);
695        eprintln!("  fscale:         {}", self.config.fscale);
696        eprintln!(
697            "  CMS dimensions: {} x {}",
698            self.config.width, self.config.depth
699        );
700        eprintln!("  Smoothing (alpha): {:.1}", self.alpha);
701        eprintln!(
702            "  Threshold: {:.2} (quantized: {})",
703            self.threshold_f32(),
704            self.threshold
705        );
706        eprintln!(
707            "  Positive retention: {:.2}%",
708            self.positive_retention * 100.0
709        );
710        eprintln!(
711            "  Negative retention: {:.2}%",
712            self.negative_retention * 100.0
713        );
714        eprintln!("  Fold enrichment: {:.2}x", self.fold_enrichment());
715        eprintln!(
716            "  Weight stats: min={:.2}, max={:.2}, mean={:.2}, std={:.2}",
717            min, max, mean, std
718        );
719        eprintln!(
720            "  Positive weights: {} ({:.1}%)",
721            positive,
722            positive as f64 / total_cells as f64 * 100.0
723        );
724    }
725}
726
727fn calibrate_threshold(
728    positive: &RawHashCounts,
729    negative: &RawHashCounts,
730    weights: &[i8],
731    seeds: &[u64],
732    width: usize,
733    target_fold_enrichment: Option<f32>,
734) -> Result<CalibrationResult> {
735    let sample_hashes = |raw: &RawHashCounts, max_samples: usize| -> Vec<u64> {
736        if raw.samples.len() <= max_samples {
737            return raw.samples.clone();
738        }
739        let step = raw.samples.len() / max_samples;
740        raw.samples
741            .iter()
742            .step_by(step)
743            .take(max_samples)
744            .copied()
745            .collect()
746    };
747
748    let estimate_weight = |hash: u64| -> i8 {
749        let depth = seeds.len();
750        (0..depth)
751            .map(|row| {
752                let mixed = hash.wrapping_mul(seeds[row]);
753                let idx = row * width + (mixed as usize % width);
754                weights[idx]
755            })
756            .min()
757            .unwrap_or(0)
758    };
759
760    let pos_sample_weights: Vec<i8> = sample_hashes(positive, 100_000)
761        .iter()
762        .map(|&h| estimate_weight(h))
763        .collect();
764    let neg_sample_weights: Vec<i8> = sample_hashes(negative, 100_000)
765        .iter()
766        .map(|&h| estimate_weight(h))
767        .collect();
768
769    if pos_sample_weights.is_empty() || neg_sample_weights.is_empty() {
770        return Ok(CalibrationResult {
771            threshold: 0,
772            positive_retention: 1.0,
773            negative_retention: 1.0,
774            fold_enrichment: 1.0,
775            max_fold_enrichment: 1.0,
776        });
777    }
778
779    let mut max_enrichment = 0.0f32;
780    let mut max_threshold = 0i8;
781    let mut max_pos_ret = 1.0f32;
782    let mut max_neg_ret = 1.0f32;
783
784    for t in -127i8..=127i8 {
785        let pos_passing = pos_sample_weights.iter().filter(|&&w| w >= t).count();
786        let neg_passing = neg_sample_weights.iter().filter(|&&w| w >= t).count();
787
788        let pos_ret = pos_passing as f32 / pos_sample_weights.len() as f32;
789        let neg_ret = neg_passing as f32 / neg_sample_weights.len().max(1) as f32;
790
791        if neg_ret < 1e-6 {
792            continue;
793        }
794
795        let enrichment = pos_ret / neg_ret;
796        if enrichment > max_enrichment {
797            max_enrichment = enrichment;
798            max_threshold = t;
799            max_pos_ret = pos_ret;
800            max_neg_ret = neg_ret;
801        }
802    }
803
804    match target_fold_enrichment {
805        None => Ok(CalibrationResult {
806            threshold: max_threshold,
807            positive_retention: max_pos_ret,
808            negative_retention: max_neg_ret,
809            fold_enrichment: max_enrichment,
810            max_fold_enrichment: max_enrichment,
811        }),
812        Some(target) => {
813            if target > max_enrichment {
814                return Ok(CalibrationResult {
815                    threshold: max_threshold,
816                    positive_retention: max_pos_ret,
817                    negative_retention: max_neg_ret,
818                    fold_enrichment: max_enrichment,
819                    max_fold_enrichment: max_enrichment,
820                });
821            }
822
823            let mut best_threshold = 0i8;
824            let mut best_diff = f32::MAX;
825            let mut best_pos_ret = 1.0f32;
826            let mut best_neg_ret = 1.0f32;
827
828            for t in -127i8..=127i8 {
829                let pos_passing = pos_sample_weights.iter().filter(|&&w| w >= t).count();
830                let neg_passing = neg_sample_weights.iter().filter(|&&w| w >= t).count();
831
832                let pos_ret = pos_passing as f32 / pos_sample_weights.len() as f32;
833                let neg_ret = neg_passing as f32 / neg_sample_weights.len().max(1) as f32;
834
835                if neg_ret < 1e-6 {
836                    continue;
837                }
838
839                let enrichment = pos_ret / neg_ret;
840                let diff = (enrichment - target).abs();
841
842                if diff < best_diff {
843                    best_diff = diff;
844                    best_threshold = t;
845                    best_pos_ret = pos_ret;
846                    best_neg_ret = neg_ret;
847                }
848            }
849
850            Ok(CalibrationResult {
851                threshold: best_threshold,
852                positive_retention: best_pos_ret,
853                negative_retention: best_neg_ret,
854                fold_enrichment: if best_neg_ret > 0.0 {
855                    best_pos_ret / best_neg_ret
856                } else {
857                    f32::INFINITY
858                },
859                max_fold_enrichment: max_enrichment,
860            })
861        }
862    }
863}
864
/// Formats a count with a metric suffix (K/M/G) and two decimals; values
/// below 1000 are printed verbatim.
fn format_number(n: u64) -> String {
    const G: u64 = 1_000_000_000;
    const M: u64 = 1_000_000;
    const K: u64 = 1_000;
    match n {
        _ if n >= G => format!("{:.2}G", n as f64 / G as f64),
        _ if n >= M => format!("{:.2}M", n as f64 / M as f64),
        _ if n >= K => format!("{:.2}K", n as f64 / K as f64),
        _ => n.to_string(),
    }
}
876
/// Formats a base-pair count with a Kbp/Mbp/Gbp unit and two decimals;
/// values below 1000 are printed as plain "<n> bp".
pub fn format_bp(bp: u64) -> String {
    let (value, unit) = if bp >= 1_000_000_000 {
        (bp as f64 / 1_000_000_000.0, "Gbp")
    } else if bp >= 1_000_000 {
        (bp as f64 / 1_000_000.0, "Mbp")
    } else if bp >= 1_000 {
        (bp as f64 / 1_000.0, "Kbp")
    } else {
        return format!("{} bp", bp);
    };
    format!("{:.2} {}", value, unit)
}
888
/// Exact byte size of a serialized bias table built with the default CMS
/// dimensions: 35-byte fixed header + 8 bytes per row seed + 1 byte per
/// weight cell (matches the `to_bytes`/`from_bytes` layout).
pub const BIAS_TABLE_SERIALIZED_SIZE: usize =
    35 + DEFAULT_CMS_DEPTH * 8 + DEFAULT_CMS_WIDTH * DEFAULT_CMS_DEPTH;
891
892impl PartialEq for HashBiasTable {
893    fn eq(&self, other: &Self) -> bool {
894        self.config.k == other.config.k
895            && self.config.fscale == other.config.fscale
896            && self.config.width == other.config.width
897            && self.config.depth == other.config.depth
898            && self.alpha == other.alpha
899            && self.threshold == other.threshold
900            && self.positive_retention == other.positive_retention
901            && self.negative_retention == other.negative_retention
902            && self.seeds == other.seeds
903            && self.weights == other.weights
904    }
905}
906
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::path::Path;
    use tempfile::NamedTempFile;

    /// Writes `sequences` to a temporary FASTA file, one `>seq_{i}` record per
    /// entry. The file is deleted when the returned handle is dropped, so the
    /// handle must outlive every use of its path.
    fn create_fasta(sequences: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for (i, seq) in sequences.iter().enumerate() {
            writeln!(file, ">seq_{}", i).unwrap();
            writeln!(file, "{}", seq).unwrap();
        }
        file
    }

    /// Builds `RawHashCounts` for a single FASTA file with throwaway progress
    /// counters (the tests never inspect them). Factored out because this
    /// exact call was repeated verbatim in almost every test below.
    fn build_raw(path: &Path, config: CMSConfig) -> RawHashCounts {
        RawHashCounts::build(&[path], config, &AtomicU64::new(0), &AtomicU64::new(0)).unwrap()
    }

    /// Increment/estimate round-trip on a sketch far from saturation.
    #[test]
    fn test_cms_basic() {
        let mut cms = CountMinSketch::new(1024, 5);
        let hash = 0x12345678u64;

        // A never-inserted item must estimate to zero.
        assert_eq!(cms.estimate(hash), 0);

        cms.increment(hash);
        assert_eq!(cms.estimate(hash), 1);

        for _ in 0..9 {
            cms.increment(hash);
        }
        // One distinct key in 1024 slots cannot collide with anything,
        // so the count-min estimate is exact here.
        assert_eq!(cms.estimate(hash), 10);
    }

    /// A deliberately undersized sketch (16 slots, 100 keys) forces
    /// collisions: CMS may overestimate, but must never undercount an item
    /// that was inserted.
    #[test]
    fn test_cms_collision_handling() {
        let mut cms = CountMinSketch::new(16, 5);

        for i in 0..100u64 {
            cms.increment(i);
        }

        for i in 0..100u64 {
            assert!(cms.estimate(i) >= 1);
        }
    }

    /// Counting a single 40 bp sequence with k=11 and fscale=1 must record
    /// at least one hash.
    #[test]
    fn test_raw_hash_counts_build() {
        let fasta = create_fasta(&["ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG"]);
        let config = CMSConfig {
            width: 1024,
            depth: 3,
            k: 11,
            fscale: 1,
        };

        let raw = build_raw(fasta.path(), config);
        assert!(raw.total > 0);
    }

    /// Builds a bias table from AT-rich positives vs GC-rich negatives
    /// (disjoint k-mer content) and sanity-checks the learned threshold.
    #[test]
    fn test_hash_bias_table_build() {
        let pos = create_fasta(&[
            "ATATATATATATATATATATATATATATATATATATATAT",
            "TATATATATATATATATATATATATATATATATATATAT",
        ]);
        let neg = create_fasta(&[
            "GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC",
            "CGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCG",
        ]);

        let config = CMSConfig {
            width: 1024,
            depth: 3,
            k: 11,
            fscale: 1,
        };

        let pos_raw = build_raw(pos.path(), config.clone());
        let neg_raw = build_raw(neg.path(), config);

        let table = HashBiasTable::build(&pos_raw, &neg_raw, 1.0, Some(5.0)).unwrap();
        // Sanity bound only: the threshold is expected to stay within the
        // quantized signed-8-bit weight range (see QUANTIZATION_SCALE).
        assert!(table.threshold >= -127);
    }

    /// Save to disk and load back: config, threshold, and weights must
    /// survive the round trip unchanged.
    #[test]
    fn test_hash_bias_table_save_load() {
        let pos = create_fasta(&["ATATATATATATATATATATATATATATATATATATATAT"]);
        let neg = create_fasta(&["GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"]);

        let config = CMSConfig {
            width: 1024,
            depth: 3,
            k: 11,
            fscale: 10,
        };

        let pos_raw = build_raw(pos.path(), config.clone());
        let neg_raw = build_raw(neg.path(), config);

        let table = HashBiasTable::build(&pos_raw, &neg_raw, 1.0, Some(2.0)).unwrap();

        let output = NamedTempFile::new().unwrap();
        table.save(output.path()).unwrap();

        let loaded = HashBiasTable::load(output.path()).unwrap();
        assert_eq!(table.config.k, loaded.config.k);
        assert_eq!(table.threshold, loaded.threshold);
        assert_eq!(table.weights, loaded.weights);
    }

    /// In-memory `to_bytes`/`from_bytes` must round-trip the entire table;
    /// full equality relies on the `PartialEq` impl covering every field.
    #[test]
    fn test_hash_bias_table_bytes_roundtrip() {
        let pos = create_fasta(&["ATATATATATATATATATATATATATATATATATATATAT"]);
        let neg = create_fasta(&["GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"]);

        let config = CMSConfig {
            width: 512,
            depth: 3,
            k: 11,
            fscale: 10,
        };

        let pos_raw = build_raw(pos.path(), config.clone());
        let neg_raw = build_raw(neg.path(), config);

        let table = HashBiasTable::build(&pos_raw, &neg_raw, 1.0, Some(2.0)).unwrap();

        let bytes = table.to_bytes();
        let loaded = HashBiasTable::from_bytes(&bytes).unwrap();

        assert_eq!(table, loaded);
    }

    /// `passes_filter` must give every hash a verdict, deterministically.
    #[test]
    fn test_passes_filter() {
        let pos = create_fasta(&["ATATATATATATATATATATATATATATATATATATATAT"]);
        let neg = create_fasta(&["GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"]);

        let config = CMSConfig {
            width: 1024,
            depth: 3,
            k: 11,
            fscale: 1,
        };

        let pos_raw = build_raw(pos.path(), config.clone());
        let neg_raw = build_raw(neg.path(), config);

        let table = HashBiasTable::build(&pos_raw, &neg_raw, 1.0, Some(2.0)).unwrap();

        let mut passed = 0;
        let mut failed = 0;
        for h in 0..1000u64 {
            if table.passes_filter(h) {
                passed += 1;
            } else {
                failed += 1;
            }
        }

        // The previous assertion (`passed > 0 || failed > 0`) was a tautology:
        // exactly one of the counters increments on every iteration, so it
        // could never fail. Assert exhaustiveness and determinism instead.
        assert_eq!(passed + failed, 1000);
        for h in 0..1000u64 {
            assert_eq!(table.passes_filter(h), table.passes_filter(h));
        }
    }

    /// With no explicit target (`None`), `build` searches for the threshold
    /// maximizing fold enrichment; the result must be no worse than applying
    /// no filter at all (enrichment >= 1.0).
    #[test]
    fn test_maximize_fold_enrichment() {
        let pos = create_fasta(&[
            "ATATATATATATATATATATATATATATATATATATATAT",
            "TATATATATATATATATATATATATATATATATATATAT",
        ]);
        let neg = create_fasta(&[
            "GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC",
            "CGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCG",
        ]);

        let config = CMSConfig {
            width: 1024,
            depth: 3,
            k: 11,
            fscale: 1,
        };

        let pos_raw = build_raw(pos.path(), config.clone());
        let neg_raw = build_raw(neg.path(), config);

        let table = HashBiasTable::build(&pos_raw, &neg_raw, 1.0, None).unwrap();
        assert!(table.threshold >= -127);
        assert!(table.fold_enrichment() >= 1.0);
    }

    /// Smoke test for the one-shot `HashBiasTable::create` entry point,
    /// which counts both sample sets and builds the table internally.
    #[test]
    fn test_create_unified() {
        let pos = create_fasta(&["ATATATATATATATATATATATATATATATATATATATAT"]);
        let neg = create_fasta(&["GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"]);

        let config = BiasCreateConfig {
            cms: CMSConfig {
                width: 1024,
                depth: 3,
                k: 11,
                fscale: 1,
            },
            alpha: 1.0,
            target_fold_enrichment: None,
        };

        let table = HashBiasTable::create(&[pos.path()], &[neg.path()], &config, None).unwrap();

        assert!(table.threshold >= -127);
        assert!(table.fold_enrichment() >= 1.0);
    }
}