1use std::cmp::min;
2
3use crate::decode;
4use crate::encode;
5
6use super::byte_is_nocall;
7use super::samples::Sample;
8use crate::bitenc::BitEnc;
9use ahash::HashMap as AHashMap;
10use ahash::HashMapExt;
11
12const STARTING_CACHE_SIZE: usize = 1_000_000;
13
/// The outcome of successfully assigning a read barcode to one of the samples.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct BarcodeMatch {
    /// Index (into the matcher's sample list) of the sample whose barcode matched best.
    pub best_match: usize,
    /// Number of mismatches counted against the best-matching barcode.
    pub best_mismatches: u8,
    /// Number of mismatches counted against the second-best barcode; remains at 255 when no
    /// other barcode produced a lower count.
    pub next_best_mismatches: u8,
}
26
/// Assigns observed read barcodes to samples by counting mismatches against every sample's
/// matching pattern, optionally remembering successful assignments in a cache.
#[derive(Clone, Debug)]
pub struct BarcodeMatcher {
    /// Copies of the input samples whose `barcode` has been replaced by the upper-cased
    /// matching pattern.
    samples: Vec<Sample>,
    /// Encoded form of each sample's upper-cased matching pattern, in the same order as
    /// `samples`.
    sample_barcodes: Vec<BitEnc>,
    /// Largest number of no-call bases found in any single matching pattern.
    max_ns_in_barcodes: usize,
    /// Maximum number of mismatches the best match may have and still be reported.
    max_mismatches: u8,
    /// Minimum required difference between the second-best and best mismatch counts.
    min_mismatch_delta: u8,
    /// Whether `assign` consults and populates `cache`.
    use_cache: bool,
    /// Previously computed successful matches, keyed by the raw read bases.
    cache: AHashMap<Vec<u8>, BarcodeMatch>,
}
46
47impl BarcodeMatcher {
48 #[must_use]
55 pub fn new(
56 samples: &[Sample],
57 max_mismatches: u8,
58 min_mismatch_delta: u8,
59 use_cache: bool,
60 ) -> Self {
61 let patterns: Vec<Vec<u8>> =
62 samples.iter().map(|s| s.barcode.as_bytes().to_vec()).collect();
63 Self::with_patterns(samples, patterns, max_mismatches, min_mismatch_delta, use_cache)
64 }
65
66 #[must_use]
85 pub fn with_patterns(
86 samples: &[Sample],
87 patterns: Vec<Vec<u8>>,
88 max_mismatches: u8,
89 min_mismatch_delta: u8,
90 use_cache: bool,
91 ) -> Self {
92 assert!(!samples.is_empty(), "Must provide at least one sample");
93 assert!(
94 patterns.len() == samples.len(),
95 "Number of patterns ({}) must match number of samples ({})",
96 patterns.len(),
97 samples.len(),
98 );
99 assert!(
100 patterns.iter().all(|p| !p.is_empty()),
101 "Sample matching pattern cannot be empty string",
102 );
103 let pattern_len = patterns[0].len();
104 assert!(
105 patterns.iter().all(|p| p.len() == pattern_len),
106 "All sample matching patterns must have the same length",
107 );
108
109 let mut max_ns_in_barcodes = 0;
110 let mut modified_samples = samples.to_vec();
111 let mut sample_barcodes = Vec::with_capacity(samples.len());
112 for (sample, pattern) in modified_samples.iter_mut().zip(patterns.into_iter()) {
113 let pattern_upper: Vec<u8> = pattern.iter().map(u8::to_ascii_uppercase).collect();
114 sample.barcode = String::from_utf8(pattern_upper.clone())
117 .expect("matching pattern must be valid UTF-8");
118 let num_ns: usize = pattern_upper.iter().filter(|&&b| byte_is_nocall(b)).count();
119 max_ns_in_barcodes = max_ns_in_barcodes.max(num_ns);
120 sample_barcodes.push(encode(&pattern_upper));
121 }
122 Self {
123 samples: modified_samples,
124 sample_barcodes,
125 max_ns_in_barcodes,
126 max_mismatches,
127 min_mismatch_delta,
128 use_cache,
129 cache: AHashMap::with_capacity(STARTING_CACHE_SIZE),
130 }
131 }
132
133 fn count_mismatches(
135 observed_bases: &BitEnc,
136 expected_bases: &BitEnc,
137 sample: &Sample,
138 max_mismatches: u8,
139 ) -> u8 {
140 if observed_bases.nr_symbols() != expected_bases.nr_symbols() {
141 let observed_string = decode(observed_bases);
142 assert_eq!(
143 observed_bases.nr_symbols(),
144 expected_bases.nr_symbols(),
145 "Read barcode ({}) length ({}) differs from expected barcode ({}) length ({}) for sample {}",
146 observed_string,
147 observed_bases.nr_symbols(),
148 sample.barcode,
149 expected_bases.nr_symbols(),
150 sample.sample_id
151 );
152 }
153 let count = observed_bases.hamming(expected_bases, u32::from(max_mismatches));
154 u8::try_from(count).expect("Overflow on number of mismatch bases")
155 }
156
157 fn expected_barcode_length(&self) -> usize {
159 self.samples[0].barcode.len()
160 }
161
162 #[must_use]
164 fn assign_internal(&self, read_bases: &[u8]) -> Option<BarcodeMatch> {
165 let mut best_barcode_index = self.samples.len();
166 let mut best_mismatches = 255u8;
167 let mut next_best_mismatches = 255u8;
168 let mut max_mismatches = 255u8;
169 let read_bases = encode(read_bases); for (index, sample_barcode) in self.sample_barcodes.iter().enumerate() {
171 let mismatches = Self::count_mismatches(
172 &read_bases,
173 sample_barcode,
174 &self.samples[index],
175 max_mismatches,
176 );
177 if mismatches < best_mismatches {
178 next_best_mismatches = best_mismatches;
179 best_mismatches = mismatches;
180 best_barcode_index = index;
181 if next_best_mismatches < 255u8 - self.min_mismatch_delta {
182 max_mismatches =
183 min(max_mismatches, next_best_mismatches + self.min_mismatch_delta);
184 }
185 } else if mismatches < next_best_mismatches {
186 next_best_mismatches = mismatches;
187 if next_best_mismatches < 255u8 - self.min_mismatch_delta {
188 max_mismatches =
189 min(max_mismatches, next_best_mismatches + self.min_mismatch_delta);
190 }
191 }
192 }
193
194 if best_mismatches > self.max_mismatches
195 || (next_best_mismatches - best_mismatches) < self.min_mismatch_delta
196 {
197 None
198 } else {
199 Some(BarcodeMatch {
200 best_match: best_barcode_index,
201 best_mismatches,
202 next_best_mismatches,
203 })
204 }
205 }
206
207 pub fn assign(&mut self, read_bases: &[u8]) -> Option<BarcodeMatch> {
211 if read_bases.len() < self.expected_barcode_length() {
213 return None;
214 }
215 let num_no_calls = read_bases.iter().filter(|&&b| byte_is_nocall(b)).count();
216 if num_no_calls > (self.max_mismatches as usize) + self.max_ns_in_barcodes {
217 None
218 } else if self.use_cache {
219 if let Some(cached_match) = self.cache.get(read_bases) {
220 Some(*cached_match)
221 } else {
222 let maybe_match = self.assign_internal(read_bases);
223 if let Some(internal_val) = maybe_match {
224 self.cache.insert(read_bases.to_vec(), internal_val);
225 };
226 maybe_match
227 }
228 } else {
229 self.assign_internal(read_bases)
230 }
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Builds a sample with the given barcode, id `sample_{idx}`, ordinal `idx` and no read
    /// structures.
    fn barcode_to_sample(barcode: &str, idx: usize) -> Sample {
        let sample_id = format!("sample_{idx}");
        Sample { barcode: barcode.to_owned(), sample_id, read_structures: None, ordinal: idx }
    }

    /// Builds one sample per barcode, numbering them by their position in `barcodes`.
    fn barcodes_to_samples(barcodes: &[&str]) -> Vec<Sample> {
        let mut samples = Vec::with_capacity(barcodes.len());
        for (idx, barcode) in barcodes.iter().enumerate() {
            samples.push(barcode_to_sample(barcode, idx));
        }
        samples
    }

    /// Encodes both barcodes and counts mismatches with the limit set to 255, using a sample
    /// built from `expected_bases`.
    fn count_mismatches(observed_bases: &str, expected_bases: &str) -> u8 {
        let sample = barcode_to_sample(expected_bases, 0);
        let observed = encode(observed_bases.as_bytes());
        let expected = encode(expected_bases.as_bytes());
        BarcodeMatcher::count_mismatches(&observed, &expected, &sample, 255)
    }

    // ---------------------------------------------------------------------------------------
    // Construction
    // ---------------------------------------------------------------------------------------

    #[rstest]
    #[case(true)]
    #[case(false)]
    fn test_barcode_matcher_instantiation_can_succeed(#[case] use_cache: bool) {
        let samples = barcodes_to_samples(&["ACGT"]);
        let _matcher = BarcodeMatcher::new(&samples, 2, 1, use_cache);
    }

    #[rstest]
    #[case(true)]
    #[case(false)]
    #[should_panic(expected = "Must provide at least one sample")]
    fn test_barcode_matcher_fails_if_no_samples_provided(#[case] use_cache: bool) {
        let samples = barcodes_to_samples(&[]);
        let _matcher = BarcodeMatcher::new(&samples, 2, 1, use_cache);
    }

    // ---------------------------------------------------------------------------------------
    // count_mismatches
    // ---------------------------------------------------------------------------------------

    /// An empty read compared against a non-empty expected barcode trips the length check.
    #[test]
    #[should_panic(
        expected = "Read barcode () length (0) differs from expected barcode (CTATGT) length (6) for sample sample_0"
    )]
    fn empty_read_barcode_fails_length_mismatch() {
        count_mismatches("", "CTATGT");
    }

    #[test]
    fn empty_string_can_run_in_count_mismatches() {
        assert_eq!(count_mismatches("", ""), 0);
    }

    #[test]
    fn find_no_mismatches() {
        assert_eq!(count_mismatches("GATTACA", "GATTACA"), 0);
    }

    #[test]
    fn ns_in_expected_barcode_dont_contribute_to_mismatch_counter() {
        assert_eq!(count_mismatches("GATTACA", "GANNACA"), 0);
    }

    #[test]
    fn all_ns_barcode_have_no_mismatches() {
        assert_eq!(count_mismatches("GANNACA", "NNNNNNN"), 0);
    }

    #[test]
    fn find_two_mismatches() {
        assert_eq!(count_mismatches("GATTACA", "GACCACA"), 2);
    }

    #[test]
    fn not_count_no_calls() {
        assert_eq!(count_mismatches("GATTACA", "GANNACA"), 0);
    }

    #[test]
    fn find_compare_two_sequences_that_have_all_mismatches() {
        assert_eq!(count_mismatches("GATTACA", "CTAATGT"), 7);
    }

    /// Swapping which side carries the IUPAC codes changes the mismatch count.
    #[test]
    fn find_compare_iupac_barcode() {
        let plain = "ACGTTAAACCGAAACA";
        let iupac = "ACGTUMRWSYKVHDBN";
        assert_eq!(count_mismatches(plain, iupac), 0);
        assert_eq!(count_mismatches(iupac, plain), 11);
    }

    /// Pairs of (observed, expected, expected mismatch count) for single IUPAC bases.
    #[test]
    fn count_mismatches_iupac_bases_assymetry() {
        assert_eq!(count_mismatches("N", "R"), 1);
        assert_eq!(count_mismatches("N", "N"), 0);
        assert_eq!(count_mismatches("R", "R"), 0);
        assert_eq!(count_mismatches("R", "V"), 0);
        assert_eq!(count_mismatches("R", "D"), 0);
        assert_eq!(count_mismatches("R", "N"), 0);
        assert_eq!(count_mismatches("R", "B"), 1);
    }

    #[test]
    #[should_panic(
        expected = "Read barcode (GATTA) length (5) differs from expected barcode (CTATGT) length (6) for sample sample_0"
    )]
    fn find_compare_two_sequences_of_different_length() {
        let _mismatches = count_mismatches("GATTA", "CTATGT");
    }

    // ---------------------------------------------------------------------------------------
    // assign
    // ---------------------------------------------------------------------------------------

    /// A read identical to the first sample's barcode is assigned to it with zero mismatches.
    #[rstest]
    #[case(true)]
    #[case(false)]
    fn test_assign_exact_match(#[case] use_cache: bool) {
        const EXPECTED_BARCODE_INDEX: usize = 0;
        let samples = barcodes_to_samples(&["ACGT", "AAAG", "CACA"]);
        let mut matcher = BarcodeMatcher::new(&samples, 2, 2, use_cache);
        let read = samples[EXPECTED_BARCODE_INDEX].barcode.as_bytes();
        let expected = BarcodeMatch {
            best_match: EXPECTED_BARCODE_INDEX,
            best_mismatches: 0,
            next_best_mismatches: 3,
        };
        assert_eq!(matcher.assign(read), Some(expected));
    }

    /// One mismatch against the first sample, three against each of the others.
    #[rstest]
    #[case(true)]
    #[case(false)]
    fn test_assign_imprecise_match(#[case] use_cache: bool) {
        let samples = barcodes_to_samples(&["AAAT", "AGAG", "CACA"]);
        let mut matcher = BarcodeMatcher::new(&samples, 2, 2, use_cache);
        let read: &[u8] = b"GAAT";
        let expected = BarcodeMatch { best_match: 0, best_mismatches: 1, next_best_mismatches: 3 };
        assert_eq!(matcher.assign(read), Some(expected));
    }

    /// The no-call in the read counts as the single mismatch against the first sample.
    #[rstest]
    #[case(true)]
    #[case(false)]
    fn test_assign_precise_match_with_no_call(#[case] use_cache: bool) {
        let samples = barcodes_to_samples(&["AAAT", "AGAG", "CACA"]);
        let mut matcher = BarcodeMatcher::new(&samples, 2, 2, use_cache);
        let read: &[u8; 4] = b"NAAT";
        let expected = BarcodeMatch { best_match: 0, best_mismatches: 1, next_best_mismatches: 3 };
        assert_eq!(matcher.assign(read), Some(expected));
    }

    /// A no-call plus one wrong base gives two mismatches against the first sample.
    #[rstest]
    #[case(true)]
    #[case(false)]
    fn test_assign_imprecise_match_with_no_call(#[case] use_cache: bool) {
        let samples = barcodes_to_samples(&["AAATTT", "AGAGGG", "CACAGG"]);
        let mut matcher = BarcodeMatcher::new(&samples, 2, 2, use_cache);
        let read: &[u8; 6] = b"NAGTTT";
        let expected = BarcodeMatch { best_match: 0, best_mismatches: 2, next_best_mismatches: 5 };
        assert_eq!(matcher.assign(read), Some(expected));
    }

    /// The no-call in the sample barcode is not counted; only the wrong base is.
    #[rstest]
    #[case(true)]
    #[case(false)]
    fn test_sample_no_call_doesnt_contribute_to_mismatch_number(#[case] use_cache: bool) {
        let samples = barcodes_to_samples(&["NAGTTT", "AGAGGG", "CACAGG"]);
        let mut matcher = BarcodeMatcher::new(&samples, 1, 2, use_cache);
        let read: &[u8; 6] = b"AAATTT";
        let expected = BarcodeMatch { best_match: 0, best_mismatches: 1, next_best_mismatches: 4 };
        assert_eq!(matcher.assign(read), Some(expected));
    }

    /// With only one mismatch allowed, the read's no-call plus one wrong base is too many.
    #[rstest]
    #[case(true)]
    #[case(false)]
    fn test_read_no_call_contributes_to_mismatch_number(#[case] use_cache: bool) {
        let samples = barcodes_to_samples(&["AAATTT", "AGAGGG", "CACAGG"]);
        let mut matcher = BarcodeMatcher::new(&samples, 1, 2, use_cache);
        let read: &[u8; 6] = b"NAGTTT";
        assert_eq!(matcher.assign(read), None);
    }

    #[rstest]
    #[case(true)]
    #[case(false)]
    fn test_produce_no_match_if_too_many_mismatches(#[case] use_cache: bool) {
        let samples = barcodes_to_samples(&["AAGCTAG", "CAGCTAG", "GAGCTAG", "TAGCTAG"]);
        let read: &[u8] = b"ATCGATC";
        let mut matcher = BarcodeMatcher::new(&samples, 0, 100, use_cache);
        assert_eq!(matcher.assign(read), None);
    }

    /// The read equals the last barcode exactly, but the third barcode is only two mismatches
    /// away, which is inside the required delta of three.
    #[rstest]
    #[case(true)]
    #[case(false)]
    fn test_produce_no_match_if_within_mismatch_delta(#[case] use_cache: bool) {
        let samples = barcodes_to_samples(&["AAAAAAAA", "CCCCCCCC", "GGGGGGGG", "GGGGGGTT"]);
        let read: &[u8] = samples[3].barcode.as_bytes();
        let mut matcher = BarcodeMatcher::new(&samples, 100, 3, use_cache);
        assert_eq!(matcher.assign(read), None);
    }

    #[rstest]
    #[case(true)]
    #[case(false)]
    fn test_produce_no_match_if_too_many_mismatches_via_nocalls(#[case] use_cache: bool) {
        let samples = barcodes_to_samples(&["AAAAAAAA", "CCCCCCCC", "GGGGGGGG", "GGGGGGTT"]);
        let read: &[u8] = b"GGGGGGTN";
        let mut matcher = BarcodeMatcher::new(&samples, 0, 100, use_cache);
        assert_eq!(matcher.assign(read), None);
    }
}