1use bitvec::prelude::*;
2use seq_io::fastq::RefRecord;
3use std::ops::Range;
4
5use crate::DNA_BASES;
6use crate::color::{COLOR_BACKGROUND, COLOR_BASES, COLOR_QUALS};
7use crate::color::{color_background, color_head};
8use crate::reverse_complement;
9use anyhow::{Context, Result, bail};
10use bstr::ByteSlice;
11use regex::bytes::{Regex, RegexBuilder, RegexSet, RegexSetBuilder};
12use seq_io::fastq::{OwnedRecord, Record};
13
/// Options shared by every [`Matcher`] implementation.
///
/// `Default` yields a plain forward-strand, non-inverted search.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct MatcherOpts {
    /// Report reads that do *not* match instead of reads that do.
    pub invert_match: bool,
    /// Also search the reverse complement of each read's bases.
    pub reverse_complement: bool,
}
23
24fn to_bitvec(ranges: impl Iterator<Item = Range<usize>>, len: usize) -> BitVec {
26 let mut vec = bitvec![0; len];
27 ranges.for_each(|range| {
28 for index in range {
29 vec.set(index, true);
30 }
31 });
32 vec
33}
34
35fn bases_colored(
38 bases: &[u8],
39 quals: &[u8],
40 ranges: impl Iterator<Item = Range<usize>>,
41) -> (Vec<u8>, Vec<u8>) {
42 let mut colored_bases = Vec::with_capacity(bases.len());
44 let mut colored_quals = Vec::with_capacity(bases.len());
45
46 let bits = to_bitvec(ranges, bases.len());
48
49 let mut last_color_on = false;
52 let mut last_bases_index = 0;
53 let mut cur_bases_index = 0;
54 for base_color_on in bits.iter() {
55 if *base_color_on {
56 if !last_color_on {
58 if last_bases_index + 1 < cur_bases_index {
60 COLOR_BACKGROUND
61 .paint(&bases[last_bases_index..cur_bases_index])
62 .write_to(&mut colored_bases)
63 .unwrap();
64 COLOR_BACKGROUND
65 .paint(&quals[last_bases_index..cur_bases_index])
66 .write_to(&mut colored_quals)
67 .unwrap();
68 }
69 last_bases_index = cur_bases_index;
71 }
72
73 last_color_on = true;
74 } else {
75 if last_color_on {
77 if last_bases_index + 1 < cur_bases_index {
79 COLOR_BASES
80 .paint(&bases[last_bases_index..cur_bases_index])
81 .write_to(&mut colored_bases)
82 .unwrap();
83 COLOR_QUALS
84 .paint(&quals[last_bases_index..cur_bases_index])
85 .write_to(&mut colored_quals)
86 .unwrap();
87 }
88 last_bases_index = cur_bases_index;
90 }
91 last_color_on = false;
92 }
93 cur_bases_index += 1;
94 }
95 if last_bases_index + 1 < cur_bases_index {
97 if last_color_on {
98 COLOR_BASES
99 .paint(&bases[last_bases_index..cur_bases_index])
100 .write_to(&mut colored_bases)
101 .unwrap();
102 COLOR_QUALS
103 .paint(&quals[last_bases_index..cur_bases_index])
104 .write_to(&mut colored_quals)
105 .unwrap();
106 } else {
107 COLOR_BACKGROUND
108 .paint(&bases[last_bases_index..cur_bases_index])
109 .write_to(&mut colored_bases)
110 .unwrap();
111 COLOR_BACKGROUND
112 .paint(&quals[last_bases_index..cur_bases_index])
113 .write_to(&mut colored_quals)
114 .unwrap();
115 }
116 }
117
118 (colored_bases, colored_quals)
119}
120
121pub fn validate_fixed_pattern(pattern: &str) -> Result<()> {
123 for (index, base) in pattern.chars().enumerate() {
124 if !DNA_BASES.contains(&(base as u8)) {
125 bail!(
126 "Fixed pattern must contain only DNA bases: {} .. [{}] .. {}",
127 &pattern[0..index],
128 &pattern[index..=index],
129 &pattern[index + 1..],
130 )
131 }
132 }
133 Ok(())
134}
135
136pub trait Matcher {
138 fn opts(&self) -> MatcherOpts;
140
141 fn bases_match(&self, bases: &[u8]) -> bool;
143
144 fn color_matched_bases(&self, bases: &[u8], quals: &[u8]) -> (Vec<u8>, Vec<u8>);
148
149 #[inline]
151 fn read_match(&self, read: &RefRecord) -> bool {
152 let bases_match = self.bases_match(read.seq());
153 if self.opts().invert_match {
154 bases_match
155 && (!self.opts().reverse_complement
156 || self.bases_match(&reverse_complement(read.seq())))
157 } else {
158 bases_match
159 || (self.opts().reverse_complement
160 && self.bases_match(&reverse_complement(read.seq())))
161 }
162 }
163
164 #[inline]
167 fn color(&self, read: &mut OwnedRecord, match_found: bool) {
168 if match_found {
169 let (seq, qual) = self.color_matched_bases(&read.seq, &read.qual);
170 read.head = color_head(&read.head);
171 read.seq = seq;
172 read.qual = qual;
173 } else {
174 read.head = color_background(&read.head);
176 read.seq = color_background(&read.seq);
177 read.qual = color_background(&read.qual);
178 }
179 }
180}
181
182pub struct FixedStringMatcher {
184 pattern: Vec<u8>,
185 opts: MatcherOpts,
186}
187
188impl Matcher for FixedStringMatcher {
189 #[inline]
190 fn bases_match(&self, bases: &[u8]) -> bool {
191 bases.find(&self.pattern).is_some() != self.opts.invert_match
192 }
193
194 fn color_matched_bases(&self, bases: &[u8], quals: &[u8]) -> (Vec<u8>, Vec<u8>) {
195 let ranges = bases.find_iter(&self.pattern).map(|start| Range {
196 start,
197 end: start + self.pattern.len(),
198 });
199 if self.opts().reverse_complement {
200 let bases_revcomp = &reverse_complement(bases);
201 let ranges_revcomp = bases_revcomp
202 .find_iter(&self.pattern)
203 .map(|start| bases.len() - start - self.pattern.len())
204 .map(|start| Range {
205 start,
206 end: start + self.pattern.len(),
207 });
208 bases_colored(bases, quals, ranges.chain(ranges_revcomp))
209 } else {
210 bases_colored(bases, quals, ranges)
211 }
212 }
213
214 #[inline]
215 fn opts(&self) -> MatcherOpts {
216 self.opts
217 }
218}
219
220impl FixedStringMatcher {
221 pub fn new(pattern: &str, opts: MatcherOpts) -> Self {
222 let pattern = pattern.as_bytes().to_vec();
223 Self { pattern, opts }
224 }
225}
226
227pub struct FixedStringSetMatcher {
229 patterns: Vec<Vec<u8>>,
230 opts: MatcherOpts,
231}
232
233impl Matcher for FixedStringSetMatcher {
234 #[inline]
235 fn bases_match(&self, bases: &[u8]) -> bool {
236 self.patterns
237 .iter()
238 .any(|pattern| bases.find(pattern).is_some())
239 != self.opts.invert_match
240 }
241
242 fn color_matched_bases(&self, bases: &[u8], quals: &[u8]) -> (Vec<u8>, Vec<u8>) {
243 let ranges = self.patterns.iter().flat_map(|pattern| {
244 bases
245 .find_iter(&pattern)
246 .map(|start| Range {
247 start,
248 end: start + pattern.len(),
249 })
250 .collect::<Vec<_>>()
251 });
252 if self.opts().reverse_complement {
253 let bases_revcomp = &reverse_complement(bases);
254 let ranges_revcomp = self.patterns.iter().flat_map(|pattern| {
255 bases_revcomp
256 .find_iter(&pattern)
257 .map(|start| bases.len() - start - pattern.len())
258 .map(|start| Range {
259 start,
260 end: start + pattern.len(),
261 })
262 .collect::<Vec<_>>()
263 });
264 bases_colored(bases, quals, ranges.chain(ranges_revcomp))
265 } else {
266 bases_colored(bases, quals, ranges)
267 }
268 }
269
270 #[inline]
271 fn opts(&self) -> MatcherOpts {
272 self.opts
273 }
274}
275
276impl FixedStringSetMatcher {
277 pub fn new<I, S>(patterns: I, opts: MatcherOpts) -> Self
278 where
279 S: AsRef<str>,
280 I: IntoIterator<Item = S>,
281 {
282 let patterns: Vec<Vec<u8>> = patterns
283 .into_iter()
284 .map(|pattern| pattern.as_ref().to_owned().as_bytes().to_vec())
285 .collect();
286 Self { patterns, opts }
287 }
288}
289
290pub struct RegexMatcher {
292 regex: Regex,
293 opts: MatcherOpts,
294}
295
296impl RegexMatcher {
297 pub fn new(pattern: &str, opts: MatcherOpts) -> Self {
298 let regex = RegexBuilder::new(pattern)
299 .build()
300 .context(format!("Invalid regular expression: {}", pattern))
301 .unwrap();
302 Self { regex, opts }
303 }
304}
305
306impl Matcher for RegexMatcher {
307 #[inline]
308 fn bases_match(&self, bases: &[u8]) -> bool {
309 self.regex.is_match(bases) != self.opts.invert_match
310 }
311
312 fn color_matched_bases(&self, bases: &[u8], quals: &[u8]) -> (Vec<u8>, Vec<u8>) {
313 let ranges = self.regex.find_iter(bases).map(|m| m.range());
314 if self.opts().reverse_complement {
315 let bases_revcomp = &reverse_complement(bases);
316 let ranges_revcomp =
317 self.regex
318 .find_iter(bases_revcomp)
319 .map(|m| m.range())
320 .map(|range| Range {
321 start: bases.len() - range.start - range.len(),
322 end: bases.len() - range.start,
323 });
324 bases_colored(bases, quals, ranges.chain(ranges_revcomp))
325 } else {
326 bases_colored(bases, quals, ranges)
327 }
328 }
329
330 #[inline]
331 fn opts(&self) -> MatcherOpts {
332 self.opts
333 }
334}
335
336pub struct RegexSetMatcher {
337 regex_set: RegexSet,
338 regex_matchers: Vec<RegexMatcher>,
339 opts: MatcherOpts,
340}
341
342impl RegexSetMatcher {
344 pub fn new<I, S>(patterns: I, opts: MatcherOpts) -> Self
345 where
346 S: AsRef<str>,
347 I: IntoIterator<Item = S>,
348 {
349 let string_patterns: Vec<String> = patterns
350 .into_iter()
351 .map(|p| p.as_ref().to_string())
352 .collect();
353 let regex_set = RegexSetBuilder::new(string_patterns.clone())
354 .dfa_size_limit(usize::MAX)
355 .build()
356 .unwrap();
357 let regex_matchers: Vec<RegexMatcher> = string_patterns
358 .into_iter()
359 .map(|pattern| RegexMatcher::new(pattern.as_ref(), opts))
360 .collect();
361 Self {
362 regex_set,
363 regex_matchers,
364 opts,
365 }
366 }
367}
368
369impl Matcher for RegexSetMatcher {
370 #[inline]
371 fn bases_match(&self, bases: &[u8]) -> bool {
372 self.regex_set.is_match(bases) != self.opts.invert_match
373 }
374
375 fn color_matched_bases(&self, bases: &[u8], quals: &[u8]) -> (Vec<u8>, Vec<u8>) {
376 let ranges = self
377 .regex_matchers
378 .iter()
379 .flat_map(|r| r.regex.find_iter(bases).map(|m| m.range()));
380 if self.opts().reverse_complement {
381 let bases_revcomp = &reverse_complement(bases);
382 let ranges_revcomp = self.regex_matchers.iter().flat_map(|r| {
383 r.regex
384 .find_iter(bases_revcomp)
385 .map(|m| m.range())
386 .map(|range| Range {
387 start: bases.len() - range.start - range.len(),
388 end: bases.len() - range.start,
389 })
390 });
391 bases_colored(bases, quals, ranges.chain(ranges_revcomp))
392 } else {
393 bases_colored(bases, quals, ranges)
394 }
395 }
396
397 #[inline]
398 fn opts(&self) -> MatcherOpts {
399 self.opts
400 }
401}
402
403pub struct MatcherFactory;
405
406impl MatcherFactory {
407 pub fn new_matcher(
408 pattern: &Option<String>,
409 fixed_strings: bool,
410 regexp: &Vec<String>,
411 match_opts: MatcherOpts,
412 ) -> Box<dyn Matcher + Sync + Send> {
413 match (fixed_strings, &pattern) {
414 (true, Some(pattern)) => Box::new(FixedStringMatcher::new(pattern, match_opts)),
415 (false, Some(pattern)) => Box::new(RegexMatcher::new(pattern, match_opts)),
416 (true, None) => Box::new(FixedStringSetMatcher::new(regexp, match_opts)),
417 (false, None) => Box::new(RegexSetMatcher::new(regexp, match_opts)),
418 }
419 }
420}
421
#[cfg(test)]
pub mod tests {
    use crate::matcher::*;
    use rstest::rstest;

    /// Runs `read_match` on a single-record FASTQ built from `seq` (with dummy `X`
    /// qualities), once with `invert_match` on and once with it off.
    ///
    /// Without inversion the result must equal `expected`; with inversion it must be the
    /// opposite of `expected`.
    fn assert_read_match<M, F>(make_matcher: F, reverse_complement: bool, seq: &str, expected: bool)
    where
        M: Matcher,
        F: Fn(MatcherOpts) -> M,
    {
        for invert_match in [true, false] {
            let opts = MatcherOpts {
                invert_match,
                reverse_complement,
            };
            let matcher = make_matcher(opts);
            let qual = "X".repeat(seq.len());
            let record = format!("@id\n{seq}\n+\n{qual}\n");
            let mut reader = seq_io::fastq::Reader::new(record.as_bytes());
            let read_record = reader.next().unwrap().unwrap();
            assert_eq!(
                matcher.read_match(&read_record),
                expected != invert_match,
                "seq={seq} reverse_complement={reverse_complement} invert_match={invert_match}"
            );
        }
    }

    #[rstest]
    #[case(vec![(0, 1)], "AGG", bitvec![1, 0, 0])]
    #[case(vec![(2, 3)], "AGG", bitvec![0, 0, 1])]
    #[case(vec![(1, 2)], "AGG", bitvec![0, 1, 0])]
    #[case(vec![(0, 0)], "AGG", bitvec![0, 0, 0])]
    #[case(vec![(0, 3)], "AGG", bitvec![1, 1, 1])]
    #[case(vec![(1, 4)], "AGGTC", bitvec![0, 1, 1, 1, 0])]
    #[case(vec![(0, 2), (3, 5)], "AGGTC", bitvec![1, 1, 0, 1, 1])]
    #[case(vec![(0, 3), (3, 5)], "AGGTC", bitvec![1, 1, 1, 1, 1])]
    #[case(vec![(0, 4), (3, 5)], "AGGTC", bitvec![1, 1, 1, 1, 1])]
    #[case(vec![(0, 3), (0, 5)], "AGGTC", bitvec![1, 1, 1, 1, 1])]
    #[case(vec![(4, 5), (0, 2)], "AGGTC", bitvec![1, 1, 0, 0, 1])]
    fn test_to_bitvec(
        #[case] ranges: Vec<(usize, usize)>,
        #[case] bases: &str,
        #[case] expected: BitVec,
    ) {
        let ranges = ranges.into_iter().map(|(start, end)| start..end);
        let result_bitvec = to_bitvec(ranges, bases.len());
        assert_eq!(result_bitvec, expected);
    }

    #[rstest]
    #[case(false, "AG", "AGG", true)]
    #[case(false, "CC", "AGG", false)]
    #[case(true, "CC", "AGG", true)]
    #[case(true, "TT", "AGG", false)]
    #[case(false, "AT", "ATGAT", true)]
    #[case(true, "CG", "GCCG", true)]
    #[case(false, "AGAG", "AGAGAGAG", true)]
    #[case(true, "TCTC", "AGAGAGAG", true)]
    fn test_fixed_string_matcher_read_match(
        #[case] reverse_complement: bool,
        #[case] pattern: &str,
        #[case] seq: &str,
        #[case] expected: bool,
    ) {
        assert_read_match(
            |opts| FixedStringMatcher::new(pattern, opts),
            reverse_complement,
            seq,
            expected,
        );
    }

    #[rstest]
    #[case(false, vec!["A", "AGG", "G"], "AGGG", true)]
    #[case(true, vec!["A", "AGG", "G"], "TCCC", true)]
    #[case(false, vec!["A", "AGG", "G"], "TTTT", false)]
    #[case(true, vec!["T", "AAA"], "CCCCC", false)]
    #[case(false, vec!["AGG", "C", "TT"], "AGGTT", true)]
    #[case(true, vec!["AGG", "C", "TT"], "GGGGG", true)]
    #[case(false, vec!["AC", "TT"], "TTACGTT", true)]
    #[case(true, vec!["GT", "AA"], "TTACGTT", true)]
    #[case(false, vec!["GAGA", "AGTT"], "GAGAGTT", true)]
    #[case(true, vec!["CTCT", "AACT"], "GAGAGTT", true)]
    fn test_fixed_string_set_matcher_read_match(
        #[case] reverse_complement: bool,
        #[case] patterns: Vec<&str>,
        #[case] seq: &str,
        #[case] expected: bool,
    ) {
        assert_read_match(
            |opts| FixedStringSetMatcher::new(patterns.iter(), opts),
            reverse_complement,
            seq,
            expected,
        );
    }

    #[rstest]
    #[case(false, "^A", "AGG", true)]
    #[case(false, "^T", "AGG", false)]
    #[case(true, "^C", "AGG", true)]
    #[case(true, "^T", "AGG", false)]
    #[case(false, "A.A", "ATATA", true)]
    #[case(true, "T.G", "CACACA", false)]
    fn test_regex_matcher_read_match(
        #[case] reverse_complement: bool,
        #[case] pattern: &str,
        #[case] seq: &str,
        #[case] expected: bool,
    ) {
        assert_read_match(
            |opts| RegexMatcher::new(pattern, opts),
            reverse_complement,
            seq,
            expected,
        );
    }

    // NOTE(review): patterns such as `$T` / `$A` place the end anchor before a base and can
    // never match, so they only act as never-matching members of the set. `T$` / `A$` may
    // have been intended — confirm before changing, as that would alter expected results.
    #[rstest]
    #[case(false, vec!["^A.G", "C..", "$T"], "AGGCTT", true)]
    #[case(true, vec!["^T.C", "..G", "$A"], "AGGCTT", true)]
    #[case(false, vec!["^A.G", "G..", "$T"], "CCTCA", false)]
    #[case(true, vec!["$A", "C.CC"], "CCTCA", false)]
    #[case(false, vec!["^T", ".GG", "A.+G"], "ATCTACTACG", true)]
    #[case(true, vec!["^C", ".CC", "C+.T"], "ATCTACTACG", true)]
    #[case(false, vec!["^T", "T.A"], "TTAATAA", true)]
    #[case(true, vec!["^T", "T.A"], "AATA", true)]
    #[case(false, vec!["^T", "T.+G"], "TAGAGTG", true)]
    #[case(true, vec!["^A", "A.+C"], "TAGAGTG", true)]
    fn test_regex_set_matcher_read_match(
        #[case] reverse_complement: bool,
        #[case] patterns: Vec<&str>,
        #[case] seq: &str,
        #[case] expected: bool,
    ) {
        assert_read_match(
            |opts| RegexSetMatcher::new(patterns.iter(), opts),
            reverse_complement,
            seq,
            expected,
        );
    }

    #[test]
    fn test_validate_fixed_pattern_is_ok() {
        let pattern = "AGTGTGATG";
        let result = validate_fixed_pattern(pattern);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_fixed_pattern_error() {
        let pattern = "AXGTGTGATG";
        let msg = String::from("Fixed pattern must contain only DNA bases: A .. [X] .. GTGTGATG");
        let result = validate_fixed_pattern(pattern);
        let inner = result.unwrap_err().to_string();
        assert_eq!(inner, msg);
    }
}