use super::DictParams;
use crate::dictionary::frequency::estimate_frequency;
use core::convert::TryInto;
use std::collections::HashMap;
use std::vec::Vec;
/// Length, in bytes, of every k-mer handled by this module.
pub(super) const K: usize = 16;
/// A k-mer: a fixed-size run of exactly [`K`] bytes.
pub(super) type KMer = [u8; K];
/// A candidate chunk of input together with the score it was assigned.
///
/// Equality and ordering look at [`Segment::score`] only; the bytes in
/// [`Segment::raw`] never take part in comparisons, so two segments with
/// different contents but the same score compare as equal.
pub struct Segment {
    /// The bytes that make up the segment.
    pub raw: Vec<u8>,
    /// The score assigned to the segment; the sole comparison key.
    pub score: usize,
}

impl Ord for Segment {
    /// Orders segments by ascending score.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        Ord::cmp(&self.score, &other.score)
    }
}

impl PartialOrd for Segment {
    /// Always `Some`, deferring to the total order defined by [`Ord`].
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for Segment {
    /// Two segments are equal exactly when their scores are equal.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == core::cmp::Ordering::Equal
    }
}

impl Eq for Segment {}
/// Mutable state carried across segment-scoring calls.
pub struct Context {
/// Frequency estimate recorded for every k-mer that has been scored so far.
/// A k-mer already present in this map is skipped when later segments are
/// scored, so it contributes nothing to their score.
pub frequencies: HashMap<KMer, usize>,
}
/// Returns the highest-scoring segment of `epoch`.
///
/// `epoch` is cut into consecutive chunks of `params.segment_size` bytes (the
/// final chunk may be shorter). Each chunk is scored, in order, against
/// `collection_sample` by `score_segment`, which also updates `ctx`. A chunk
/// replaces the current leader only when its score is strictly greater, so
/// ties go to the earliest chunk. If no chunk scores above zero, the first
/// chunk is returned with a score of `0`.
///
/// # Panics
///
/// Panics if `epoch` is empty or if `params.segment_size` is zero.
pub fn pick_best_segment(
    params: &DictParams,
    ctx: &mut Context,
    epoch: &'_ [u8],
    collection_sample: &'_ [u8],
) -> Segment {
    let segment_len = params.segment_size as usize;

    // The first chunk is the fallback winner when nothing scores above zero.
    let first_chunk = epoch
        .chunks(segment_len)
        .next()
        .expect("at least one segment");

    let (winner, winner_score) = epoch.chunks(segment_len).fold(
        (first_chunk, 0_usize),
        |(leader, leader_score), candidate| {
            // Every chunk is scored, even losing ones, because scoring
            // updates `ctx` as a side effect.
            let candidate_score = score_segment(ctx, collection_sample, candidate);
            if candidate_score > leader_score {
                (candidate, candidate_score)
            } else {
                (leader, leader_score)
            }
        },
    );

    Segment {
        raw: winner.to_vec(),
        score: winner_score,
    }
}
/// Scores `segment` by the estimated frequency of the k-mers it introduces.
///
/// Every overlapping `K`-byte window of `segment` is visited in order. A
/// window whose k-mer is already recorded in `ctx.frequencies` adds nothing.
/// Otherwise its frequency in `collection_sample` is obtained from
/// `estimate_frequency`, stored in `ctx.frequencies`, and added to the score.
/// A k-mer therefore counts at most once per `ctx`, whether it repeats inside
/// this segment or was seen in an earlier call.
///
/// Returns the sum of the newly recorded frequencies; a segment shorter than
/// `K` bytes has no windows and scores `0`.
fn score_segment(ctx: &mut Context, collection_sample: &[u8], segment: &[u8]) -> usize {
    let mut segment_score = 0;
    // `windows(K)` yields nothing for segments shorter than `K`, so no
    // separate length guard or index arithmetic is needed.
    for window in segment.windows(K) {
        let kmer: &KMer = window.try_into().expect("Failed to make kmer");
        // The entry API hashes the k-mer once, both to test for presence and
        // to insert; an occupied entry means the k-mer was already counted.
        if let ::std::collections::hash_map::Entry::Vacant(slot) = ctx.frequencies.entry(*kmer) {
            let kmer_score = estimate_frequency(kmer, collection_sample);
            slot.insert(kmer_score);
            segment_score += kmer_score;
        }
    }
    segment_score
}
/// Decides how many epochs to use and how many k-mers each epoch covers.
///
/// The first attempt uses one epoch per segment that fits in `max_dict_size`
/// (`max_dict_size / params.segment_size`, but at least one) and divides
/// `num_kmers` evenly among them. If that leaves fewer than 10 000 k-mers per
/// epoch, the epoch size is instead set to `min(10_000, num_kmers)` and the
/// epoch count is recomputed from it.
///
/// Returns `(num_epochs, epoch_size)`. `num_epochs * epoch_size` never
/// exceeds `num_kmers`. When `num_kmers` is zero there is nothing to
/// partition and `(0, 0)` is returned.
///
/// # Panics
///
/// Panics if `params.segment_size` is zero (division by zero).
pub fn compute_epoch_info(
    params: &DictParams,
    max_dict_size: usize,
    num_kmers: usize,
) -> (usize, usize) {
    /// Smallest epoch size accepted from the first, segment-based split.
    const MIN_EPOCH_SIZE: usize = 10_000;

    let num_epochs = usize::max(1, max_dict_size / params.segment_size as usize);

    // Without this guard the fallback below would evaluate `0 / 0` and panic.
    if num_kmers == 0 {
        return (0, 0);
    }

    let epoch_size = num_kmers / num_epochs;
    if epoch_size >= MIN_EPOCH_SIZE {
        assert!(epoch_size * num_epochs <= num_kmers);
        return (num_epochs, epoch_size);
    }

    // Epochs would be too small: grow them to the minimum (capped by the
    // number of k-mers available) and use however many of those fit.
    let epoch_size = usize::min(MIN_EPOCH_SIZE, num_kmers);
    (num_kmers / epoch_size, epoch_size)
}