use std::cmp::min;
use crate::decode;
use crate::encode;
use super::byte_is_nocall;
use super::samples::Sample;
use crate::bitenc::BitEnc;
use ahash::HashMap as AHashMap;
use ahash::HashMapExt;
/// Initial capacity (number of entries) requested for the matcher's
/// read-barcode -> match cache when a `BarcodeMatcher` is constructed.
const STARTING_CACHE_SIZE: usize = 1_000_000;
/// The outcome of successfully assigning a read barcode to a sample.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct BarcodeMatch {
/// Index of the best-matching sample (position in the sample list the
/// matcher was built from).
pub best_match: usize,
/// Mismatch count reported for the best-matching sample barcode.
pub best_mismatches: u8,
/// Mismatch count reported for the second-best sample barcode. Remains at
/// the initial value of 255 when no second candidate was recorded (for
/// example, when there is only one sample).
// NOTE(review): counts are produced by `BitEnc::hamming` with a cap argument;
// whether the reported value can be a truncated count depends on that helper
// — confirm against its implementation.
pub next_best_mismatches: u8,
}
/// Assigns read barcodes to samples by comparing them against a fixed set of
/// per-sample matching patterns, optionally caching successful assignments.
#[derive(Clone, Debug)]
pub struct BarcodeMatcher {
/// Copies of the input samples whose `barcode` has been replaced by the
/// ASCII-uppercased matching pattern.
samples: Vec<Sample>,
/// Encoded form of each sample's matching pattern, in the same order as
/// `samples`.
sample_barcodes: Vec<BitEnc>,
/// Largest number of no-call bytes (as judged by `byte_is_nocall`) found in
/// any single matching pattern.
max_ns_in_barcodes: usize,
/// A match is rejected when the best candidate has more mismatches than this.
max_mismatches: u8,
/// A match is rejected when the second-best candidate has fewer than this
/// many more mismatches than the best candidate.
min_mismatch_delta: u8,
/// When true, `assign` looks up and stores results in `cache`.
use_cache: bool,
/// Maps raw read-barcode bytes to a previously computed successful match.
/// Failed assignments are not stored.
cache: AHashMap<Vec<u8>, BarcodeMatch>,
}
impl BarcodeMatcher {
/// Creates a matcher that uses each sample's own `barcode` as its matching
/// pattern.
///
/// The barcode bytes of every sample are copied into a pattern list and the
/// construction is delegated to [`BarcodeMatcher::with_patterns`]; all of its
/// preconditions (and panics) therefore apply here too.
#[must_use]
pub fn new(
    samples: &[Sample],
    max_mismatches: u8,
    min_mismatch_delta: u8,
    use_cache: bool,
) -> Self {
    let mut patterns: Vec<Vec<u8>> = Vec::with_capacity(samples.len());
    for sample in samples {
        patterns.push(sample.barcode.clone().into_bytes());
    }
    Self::with_patterns(samples, patterns, max_mismatches, min_mismatch_delta, use_cache)
}
/// Creates a matcher from `samples` and an explicit matching pattern for each
/// sample.
///
/// `patterns[i]` is the pattern used to match reads to `samples[i]`. Each
/// pattern is ASCII-uppercased, stored as the `barcode` of the matcher's own
/// copy of the sample, and encoded for comparison. The largest number of
/// no-call bytes in any pattern is recorded for later use by `assign`.
///
/// The assignment cache is only pre-allocated when `use_cache` is true; with
/// caching disabled the map is never touched, so an empty, non-allocating map
/// is stored instead of reserving `STARTING_CACHE_SIZE` entries.
///
/// # Panics
///
/// Panics if `samples` is empty, if the number of patterns differs from the
/// number of samples, if any pattern is empty, if the patterns do not all have
/// the same length, or if an uppercased pattern is not valid UTF-8.
#[must_use]
pub fn with_patterns(
    samples: &[Sample],
    patterns: Vec<Vec<u8>>,
    max_mismatches: u8,
    min_mismatch_delta: u8,
    use_cache: bool,
) -> Self {
    assert!(!samples.is_empty(), "Must provide at least one sample");
    assert!(
        patterns.len() == samples.len(),
        "Number of patterns ({}) must match number of samples ({})",
        patterns.len(),
        samples.len(),
    );
    assert!(
        patterns.iter().all(|p| !p.is_empty()),
        "Sample matching pattern cannot be empty string",
    );
    let pattern_len = patterns[0].len();
    assert!(
        patterns.iter().all(|p| p.len() == pattern_len),
        "All sample matching patterns must have the same length",
    );

    let mut max_ns_in_barcodes = 0usize;
    let mut modified_samples = samples.to_vec();
    let mut sample_barcodes = Vec::with_capacity(samples.len());
    for (sample, pattern) in modified_samples.iter_mut().zip(patterns) {
        // Validate UTF-8 first (as before), then reuse the same buffer for
        // counting, encoding and storage instead of cloning it.
        let pattern_upper = String::from_utf8(pattern.to_ascii_uppercase())
            .expect("matching pattern must be valid UTF-8");
        let num_ns = pattern_upper.bytes().filter(|&b| byte_is_nocall(b)).count();
        max_ns_in_barcodes = max_ns_in_barcodes.max(num_ns);
        sample_barcodes.push(encode(pattern_upper.as_bytes()));
        sample.barcode = pattern_upper;
    }

    // Only pay for the large up-front cache allocation when it will be used.
    let cache = if use_cache {
        AHashMap::with_capacity(STARTING_CACHE_SIZE)
    } else {
        AHashMap::new()
    };

    Self {
        samples: modified_samples,
        sample_barcodes,
        max_ns_in_barcodes,
        max_mismatches,
        min_mismatch_delta,
        use_cache,
        cache,
    }
}
/// Counts mismatches between an observed (read) barcode and an expected
/// (sample) barcode, both in encoded form.
///
/// `max_mismatches` is widened to `u32` and forwarded to `BitEnc::hamming`;
/// the returned count is narrowed back to `u8`. `sample` is only used to build
/// the panic message.
///
/// # Panics
///
/// Panics if the two barcodes contain a different number of symbols, or if the
/// count returned by `hamming` does not fit in a `u8`.
fn count_mismatches(
    observed_bases: &BitEnc,
    expected_bases: &BitEnc,
    sample: &Sample,
    max_mismatches: u8,
) -> u8 {
    let observed_len = observed_bases.nr_symbols();
    let expected_len = expected_bases.nr_symbols();
    // The message arguments (including the decode of the observed barcode) are
    // only evaluated when the lengths differ, so the matching path stays cheap.
    assert_eq!(
        observed_len,
        expected_len,
        "Read barcode ({}) length ({}) differs from expected barcode ({}) length ({}) for sample {}",
        decode(observed_bases),
        observed_len,
        sample.barcode,
        expected_len,
        sample.sample_id
    );
    let limit = u32::from(max_mismatches);
    let mismatches = observed_bases.hamming(expected_bases, limit);
    u8::try_from(mismatches).expect("Overflow on number of mismatch bases")
}
/// Returns the byte length of the first sample's stored barcode. The
/// constructor requires all matching patterns to share one length, so this is
/// the length every read barcode is compared against.
fn expected_barcode_length(&self) -> usize {
    let first_sample = &self.samples[0];
    first_sample.barcode.len()
}
/// Compares `read_bases` against every sample barcode and returns the best
/// match, if it is good enough.
///
/// The lowest and second-lowest mismatch counts are tracked across all
/// samples. A match is returned only when the lowest count does not exceed
/// `self.max_mismatches` and the second-lowest count exceeds it by at least
/// `self.min_mismatch_delta`; otherwise `None` is returned.
#[must_use]
fn assign_internal(&self, read_bases: &[u8]) -> Option<BarcodeMatch> {
    let delta = self.min_mismatch_delta;
    let observed = encode(read_bases);

    // `winner` starts one past the last valid index.
    // NOTE(review): it is only overwritten when some sample reports fewer than
    // 255 mismatches; with max_mismatches == 255 and a delta of 0 the sentinel
    // could in principle be returned — confirm this is unreachable in practice.
    let mut winner = self.samples.len();
    let mut lowest = u8::MAX;
    let mut runner_up = u8::MAX;
    // Cap handed to `count_mismatches` (and from there to `BitEnc::hamming`);
    // presumably lets the comparison stop early — confirm against `hamming`.
    let mut cutoff = u8::MAX;

    for (index, expected) in self.sample_barcodes.iter().enumerate() {
        let mismatches =
            Self::count_mismatches(&observed, expected, &self.samples[index], cutoff);

        let ranking_changed = if mismatches < lowest {
            runner_up = lowest;
            lowest = mismatches;
            winner = index;
            true
        } else if mismatches < runner_up {
            runner_up = mismatches;
            true
        } else {
            false
        };

        // Tighten the cap whenever the runner-up moved; the guard keeps
        // `runner_up + delta` from overflowing a u8.
        if ranking_changed && runner_up < u8::MAX - delta {
            cutoff = min(cutoff, runner_up + delta);
        }
    }

    if lowest <= self.max_mismatches && (runner_up - lowest) >= delta {
        Some(BarcodeMatch {
            best_match: winner,
            best_mismatches: lowest,
            next_best_mismatches: runner_up,
        })
    } else {
        None
    }
}
/// Assigns a read barcode to a sample, returning `None` when no sample
/// matches well enough.
///
/// Returns `None` without comparing anything when the read is shorter than
/// the expected barcode length, or when it contains more no-call bytes than
/// `max_mismatches` plus the largest no-call count of any sample pattern.
/// With caching enabled, a previously stored match for the same bytes is
/// returned directly and newly found matches are stored; failed assignments
/// are not stored.
///
/// # Panics
///
/// A read longer than the expected barcode length reaches the comparison and
/// triggers the length-mismatch panic in `count_mismatches`.
pub fn assign(&mut self, read_bases: &[u8]) -> Option<BarcodeMatch> {
    if read_bases.len() < self.expected_barcode_length() {
        return None;
    }

    let no_call_count = read_bases.iter().filter(|&&b| byte_is_nocall(b)).count();
    let no_call_limit = usize::from(self.max_mismatches) + self.max_ns_in_barcodes;
    if no_call_count > no_call_limit {
        return None;
    }

    if !self.use_cache {
        return self.assign_internal(read_bases);
    }

    if let Some(hit) = self.cache.get(read_bases) {
        return Some(*hit);
    }

    let found = self.assign_internal(read_bases)?;
    self.cache.insert(read_bases.to_vec(), found);
    Some(found)
}
}
// Unit tests for `BarcodeMatcher`: low-level mismatch counting first, then
// end-to-end assignment with the cache both enabled and disabled.
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
// Builds a sample with the given barcode, id `sample_<idx>` and ordinal `idx`.
fn barcode_to_sample(barcode: &str, idx: usize) -> Sample {
Sample {
barcode: barcode.to_string(),
sample_id: format!("sample_{idx}").to_string(),
read_structures: None,
ordinal: idx,
}
}
// Builds one sample per barcode, numbered by position in the slice.
fn barcodes_to_samples(barcodes: &[&str]) -> Vec<Sample> {
barcodes
.iter()
.enumerate()
.map(|(idx, barcode)| barcode_to_sample(barcode, idx))
.collect::<Vec<_>>()
}
// Convenience wrapper: encodes both strings and counts mismatches with the
// maximum cap (255), using a throwaway sample for the error message.
fn count_mismatches(observed_bases: &str, expected_bases: &str) -> u8 {
let sample = barcode_to_sample(expected_bases, 0);
BarcodeMatcher::count_mismatches(
&encode(observed_bases.as_bytes()),
&encode(expected_bases.as_bytes()),
&sample,
255,
)
}
// Construction succeeds with a single valid sample.
#[rstest]
#[case(true)]
#[case(false)]
fn test_barcode_matcher_instantiation_can_succeed(#[case] use_cache: bool) {
let samples = barcodes_to_samples(&["ACGT"]);
let _matcher = BarcodeMatcher::new(&samples, 2, 1, use_cache);
}
// Construction panics when the sample list is empty.
#[rstest]
#[case(true)]
#[case(false)]
#[should_panic(expected = "Must provide at least one sample")]
fn test_barcode_matcher_fails_if_no_samples_provided(#[case] use_cache: bool) {
let samples = barcodes_to_samples(&[]);
let _matcher = BarcodeMatcher::new(&samples, 2, 1, use_cache);
}
// An empty observed barcode against a non-empty expected one is a length error.
#[test]
#[should_panic(
expected = "Read barcode () length (0) differs from expected barcode (CTATGT) length (6) for sample sample_0"
)]
fn empty_read_barcode_fails_length_mismatch() {
count_mismatches("", "CTATGT");
}
// Two empty barcodes have equal length and zero mismatches.
#[test]
fn empty_string_can_run_in_count_mismatches() {
assert_eq!(count_mismatches("", ""), 0);
}
#[test]
fn find_no_mismatches() {
assert_eq!(count_mismatches("GATTACA", "GATTACA"), 0,);
}
// Ns in the expected (sample) barcode match any observed base.
#[test]
fn ns_in_expected_barcode_dont_contribute_to_mismatch_counter() {
assert_eq!(count_mismatches("GATTACA", "GANNACA"), 0,);
}
#[test]
fn all_ns_barcode_have_no_mismatches() {
assert_eq!(count_mismatches("GANNACA", "NNNNNNN"), 0,);
}
#[test]
fn find_two_mismatches() {
assert_eq!(count_mismatches("GATTACA", "GACCACA"), 2,);
}
// Same inputs as `ns_in_expected_barcode_dont_contribute_to_mismatch_counter`.
#[test]
fn not_count_no_calls() {
assert_eq!(count_mismatches("GATTACA", "GANNACA"), 0,);
}
#[test]
fn find_compare_two_sequences_that_have_all_mismatches() {
assert_eq!(count_mismatches("GATTACA", "CTAATGT"), 7,);
}
// IUPAC codes in the expected barcode match the bases they cover; the same
// codes in the observed barcode do not match plain bases in the expected one.
#[test]
fn find_compare_iupac_barcode() {
assert_eq!(count_mismatches("ACGTTAAACCGAAACA", "ACGTUMRWSYKVHDBN"), 0,);
assert_eq!(count_mismatches("ACGTUMRWSYKVHDBN", "ACGTTAAACCGAAACA"), 11,);
}
// Matching is asymmetric: an observed code matches only when the expected
// code is at least as permissive (e.g. observed N vs expected R mismatches,
// observed R vs expected N does not).
#[test]
fn count_mismatches_iupac_bases_assymetry() {
assert_eq!(count_mismatches("N", "R"), 1,);
assert_eq!(count_mismatches("N", "N"), 0,);
assert_eq!(count_mismatches("R", "R"), 0,);
assert_eq!(count_mismatches("R", "V"), 0,);
assert_eq!(count_mismatches("R", "D"), 0,);
assert_eq!(count_mismatches("R", "N"), 0,);
assert_eq!(count_mismatches("R", "B"), 1,);
}
// Barcodes of different non-zero lengths panic with a descriptive message.
#[test]
#[should_panic(
expected = "Read barcode (GATTA) length (5) differs from expected barcode (CTATGT) length (6) for sample sample_0"
)]
fn find_compare_two_sequences_of_different_length() {
let _mismatches = count_mismatches("GATTA", "CTATGT");
}
// A read identical to the first sample's barcode is assigned to it with zero
// mismatches.
#[rstest]
#[case(true)]
#[case(false)]
fn test_assign_exact_match(#[case] use_cache: bool) {
const EXPECTED_BARCODE_INDEX: usize = 0;
let samples = barcodes_to_samples(&["ACGT", "AAAG", "CACA"]);
let mut matcher = BarcodeMatcher::new(&samples, 2, 2, use_cache);
assert_eq!(
matcher.assign(samples[EXPECTED_BARCODE_INDEX].barcode.as_bytes()),
Some(BarcodeMatch {
best_match: EXPECTED_BARCODE_INDEX,
best_mismatches: 0,
next_best_mismatches: 3,
}),
);
}
// One substitution relative to the first sample still assigns to it.
#[rstest]
#[case(true)]
#[case(false)]
fn test_assign_imprecise_match(#[case] use_cache: bool) {
let samples = barcodes_to_samples(&["AAAT", "AGAG", "CACA"]);
let mut matcher = BarcodeMatcher::new(&samples, 2, 2, use_cache);
let test_barcode: &[u8] = b"GAAT";
let expected = BarcodeMatch { best_match: 0, best_mismatches: 1, next_best_mismatches: 3 };
assert_eq!(matcher.assign(test_barcode), Some(expected));
}
// An N in the read counts as one mismatch against the sample barcode.
#[rstest]
#[case(true)]
#[case(false)]
fn test_assign_precise_match_with_no_call(#[case] use_cache: bool) {
let samples = barcodes_to_samples(&["AAAT", "AGAG", "CACA"]);
let mut matcher = BarcodeMatcher::new(&samples, 2, 2, use_cache);
let test_barcode: &[u8; 4] = b"NAAT";
let expected = BarcodeMatch { best_match: 0, best_mismatches: 1, next_best_mismatches: 3 };
assert_eq!(matcher.assign(test_barcode), Some(expected));
}
// An N plus a substitution in the read gives two mismatches, still within
// the allowed maximum of two.
#[rstest]
#[case(true)]
#[case(false)]
fn test_assign_imprecise_match_with_no_call(#[case] use_cache: bool) {
let samples = barcodes_to_samples(&["AAATTT", "AGAGGG", "CACAGG"]);
let mut matcher = BarcodeMatcher::new(&samples, 2, 2, use_cache);
let test_barcode: &[u8; 6] = b"NAGTTT";
let expected = BarcodeMatch { best_match: 0, best_mismatches: 2, next_best_mismatches: 5 };
assert_eq!(matcher.assign(test_barcode), Some(expected));
}
// An N in the sample barcode is not counted, so only the single real
// substitution contributes.
#[rstest]
#[case(true)]
#[case(false)]
fn test_sample_no_call_doesnt_contribute_to_mismatch_number(#[case] use_cache: bool) {
let samples = barcodes_to_samples(&["NAGTTT", "AGAGGG", "CACAGG"]);
let mut matcher = BarcodeMatcher::new(&samples, 1, 2, use_cache);
let test_barcode: &[u8; 6] = b"AAATTT";
let expected = BarcodeMatch { best_match: 0, best_mismatches: 1, next_best_mismatches: 4 };
assert_eq!(matcher.assign(test_barcode), Some(expected));
}
// Same read as above but with only one mismatch allowed: the read's N pushes
// it over the limit and nothing is assigned.
#[rstest]
#[case(true)]
#[case(false)]
fn test_read_no_call_contributes_to_mismatch_number(#[case] use_cache: bool) {
let samples = barcodes_to_samples(&["AAATTT", "AGAGGG", "CACAGG"]);
let mut matcher = BarcodeMatcher::new(&samples, 1, 2, use_cache);
let test_barcode: &[u8; 6] = b"NAGTTT";
assert_eq!(matcher.assign(test_barcode), None);
}
// With zero mismatches allowed, a read far from every sample is unassigned.
#[rstest]
#[case(true)]
#[case(false)]
fn test_produce_no_match_if_too_many_mismatches(#[case] use_cache: bool) {
let samples = barcodes_to_samples(&["AAGCTAG", "CAGCTAG", "GAGCTAG", "TAGCTAG"]);
let assignment_barcode: &[u8] = b"ATCGATC";
let mut matcher = BarcodeMatcher::new(&samples, 0, 100, use_cache);
assert_eq!(matcher.assign(assignment_barcode), None);
}
// The read equals the last sample exactly, but another sample is only two
// mismatches away, which is below the required delta of three.
#[rstest]
#[case(true)]
#[case(false)]
fn test_produce_no_match_if_within_mismatch_delta(#[case] use_cache: bool) {
let samples = barcodes_to_samples(&["AAAAAAAA", "CCCCCCCC", "GGGGGGGG", "GGGGGGTT"]);
let assignment_barcode: &[u8] = samples[3].barcode.as_bytes();
let mut matcher = BarcodeMatcher::new(&samples, 100, 3, use_cache);
assert_eq!(matcher.assign(assignment_barcode), None);
}
// With zero mismatches allowed, a single N in the read prevents assignment.
#[rstest]
#[case(true)]
#[case(false)]
fn test_produce_no_match_if_too_many_mismatches_via_nocalls(#[case] use_cache: bool) {
let samples = barcodes_to_samples(&["AAAAAAAA", "CCCCCCCC", "GGGGGGGG", "GGGGGGTT"]);
let assignment_barcode: &[u8] = b"GGGGGGTN";
let mut matcher = BarcodeMatcher::new(&samples, 0, 100, use_cache);
assert_eq!(matcher.assign(assignment_barcode), None);
}
}