use crate::kmer::{Kmer, KmerBits};
use crate::minimizer::{MinimizerInfo, MinimizerIterator};
use crate::encoding::encode_base;
/// Outcome of a single k-mer lookup.
///
/// A value of `u64::MAX` in `kmer_id` marks the result as "not found"
/// (see `is_found`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LookupResult {
    /// Identifier of the matched k-mer; `u64::MAX` when nothing matched.
    pub kmer_id: u64,
    /// Index of the k-mer relative to the start of the string that contains it.
    pub kmer_id_in_string: u64,
    /// Offset of the matched k-mer.
    // NOTE(review): presumably an absolute base offset into the string set — confirm.
    pub kmer_offset: u64,
    /// Orientation of the match: positive for the k-mer as queried, `-1` when
    /// the match was obtained through its reverse complement.
    pub kmer_orientation: i8,
    /// Identifier of the string containing the match.
    pub string_id: u64,
    /// Start of the containing string.
    pub string_begin: u64,
    /// End of the containing string; `string_end - string_begin` is its length.
    pub string_end: u64,
    /// Whether the minimizer used for the lookup was found.
    pub minimizer_found: bool,
}
impl LookupResult {
    /// Builds the sentinel value that represents "no match".
    ///
    /// Every `u64` field is set to `u64::MAX`, the orientation is `1`, and
    /// `minimizer_found` is `true`.
    pub fn not_found() -> Self {
        const INVALID: u64 = u64::MAX;
        Self {
            kmer_id: INVALID,
            kmer_id_in_string: INVALID,
            kmer_offset: INVALID,
            kmer_orientation: 1,
            string_id: INVALID,
            string_begin: INVALID,
            string_end: INVALID,
            minimizer_found: true,
        }
    }

    /// Returns `true` when this result refers to an actual k-mer, i.e. when
    /// `kmer_id` is not the `u64::MAX` sentinel.
    #[inline]
    pub fn is_found(&self) -> bool {
        self.kmer_id != u64::MAX
    }

    /// Length of the string containing the match, or `0` when nothing was found.
    #[inline]
    pub fn string_length(&self) -> u64 {
        if !self.is_found() {
            return 0;
        }
        self.string_end - self.string_begin
    }
}
impl Default for LookupResult {
fn default() -> Self {
Self::not_found()
}
}
/// State for looking up a stream of consecutive k-mers.
///
/// The query remembers the previous k-mer, its minimizers and the previous
/// result so that a following lookup can either be answered by stepping along
/// the string of the previous match or fall back to a fresh search.
pub struct StreamingQuery<const K: usize>
where
    Kmer<K>: KmerBits,
{
    /// K-mer length; `new` asserts it equals the const generic `K`.
    k: usize,
    /// The `m` value given to `new` (also forwarded to both minimizer iterators).
    _m: usize,
    /// Whether lookups use the canonical search path.
    _canonical: bool,
    /// `true` until the first valid lookup after construction or `reset`.
    start: bool,
    /// Current k-mer, once a valid lookup has been seen.
    kmer: Option<Kmer<K>>,
    /// Reverse complement of `kmer`, maintained alongside it.
    kmer_rc: Option<Kmer<K>>,
    /// Minimizer iterator fed with the forward k-mers.
    minimizer_it: MinimizerIterator,
    /// Minimizer iterator fed with the reverse-complement k-mers.
    minimizer_it_rc: MinimizerIterator,
    /// Minimizer info of the current forward k-mer.
    curr_mini_info: MinimizerInfo,
    /// Minimizer info of the previous forward k-mer.
    prev_mini_info: MinimizerInfo,
    /// Minimizer info of the current reverse-complement k-mer.
    curr_mini_info_rc: MinimizerInfo,
    /// Minimizer info of the previous reverse-complement k-mer.
    prev_mini_info_rc: MinimizerInfo,
    /// How many more one-step extensions the current string allows.
    remaining_string_bases: u64,
    /// Result of the most recent lookup.
    result: LookupResult,
    /// Number of successful full searches.
    num_searches: u64,
    /// Number of lookups answered by extension.
    num_extensions: u64,
    /// Number of lookups rejected as invalid input.
    num_invalid: u64,
    /// Number of lookups that found nothing.
    num_negative: u64,
}
impl<const K: usize> StreamingQuery<K>
where
    Kmer<K>: KmerBits,
{
    /// Creates a query in its initial ("start") state with all counters at zero.
    ///
    /// `k` and `m` are forwarded to both minimizer iterators; `canonical`
    /// selects which search path `seed` uses.
    ///
    /// # Panics
    ///
    /// Panics if `k` differs from the const generic `K`.
    pub fn new(k: usize, m: usize, canonical: bool) -> Self {
        assert_eq!(k, K, "k parameter must match const generic K");
        // Placeholder minimizer used until the first real k-mer is processed.
        let dummy_mini = MinimizerInfo::new(u64::MAX, 0, 0);
        Self {
            k,
            _m: m,
            _canonical: canonical,
            start: true,
            kmer: None,
            kmer_rc: None,
            minimizer_it: MinimizerIterator::with_seed(k, m, 1),
            minimizer_it_rc: MinimizerIterator::with_seed(k, m, 1),
            curr_mini_info: dummy_mini,
            prev_mini_info: dummy_mini,
            curr_mini_info_rc: dummy_mini,
            prev_mini_info_rc: dummy_mini,
            remaining_string_bases: 0,
            result: LookupResult::not_found(),
            num_searches: 0,
            num_extensions: 0,
            num_invalid: 0,
            num_negative: 0,
        }
    }

    /// Returns the query to its start state.
    ///
    /// The next lookup must supply a complete k-mer again. The statistics
    /// counters are left untouched.
    pub fn reset(&mut self) {
        self.start = true;
        self.remaining_string_bases = 0;
        self.result = LookupResult::not_found();
        self.minimizer_it.set_position(0);
        self.minimizer_it_rc.set_position(0);
    }

    /// Processes the next k-mer without a dictionary.
    ///
    /// The internal k-mer and minimizer state is updated, but since there is
    /// nothing to search in, the returned result never reports a match.
    pub fn lookup(&mut self, kmer_bytes: &[u8]) -> LookupResult {
        self.lookup_internal(kmer_bytes, None)
    }

    /// Processes the next k-mer and looks it up in `dict`.
    pub fn lookup_with_dict(
        &mut self,
        kmer_bytes: &[u8],
        dict: &crate::dictionary::Dictionary,
    ) -> LookupResult {
        self.lookup_internal(kmer_bytes, Some(dict))
    }

    /// Shared implementation of `lookup` and `lookup_with_dict`.
    ///
    /// In the start state the whole slice must be a valid k-mer of length `k`.
    /// Afterwards only the byte at index `k - 1` is validated and used: the new
    /// k-mer is derived from the previous one by shifting in that base. Invalid
    /// input (including a slice too short to contain index `k - 1`) bumps the
    /// invalid counter, resets the query and yields a "not found" result.
    fn lookup_internal(
        &mut self,
        kmer_bytes: &[u8],
        dict_opt: Option<&crate::dictionary::Dictionary>,
    ) -> LookupResult {
        let is_valid = if self.start {
            self.is_valid_kmer_bytes(kmer_bytes)
        } else {
            // Checked access: a slice shorter than `k` is invalid input, not a
            // reason to panic on an out-of-bounds index.
            match kmer_bytes.get(self.k - 1) {
                Some(&last) => self.is_valid_base(last),
                None => false,
            }
        };
        if !is_valid {
            self.num_invalid += 1;
            self.reset();
            return self.result.clone();
        }

        if self.start {
            // First k-mer of a run: build both strands from scratch.
            let km = Kmer::<K>::from_ascii_unchecked(kmer_bytes);
            self.kmer = Some(km);
            let rc = km.reverse_complement();
            self.kmer_rc = Some(rc);
            self.curr_mini_info = self.minimizer_it.next(km);
            self.curr_mini_info_rc = self.minimizer_it_rc.next(rc);
        } else if let Some(mut km) = self.kmer {
            // Forward strand: move every base one slot down, then append.
            for i in 0..(self.k - 1) {
                let base = km.get_base(i + 1);
                km.set_base(i, base);
            }
            // In bounds: the validity check above proved index `k - 1` exists.
            let new_base = kmer_bytes[self.k - 1];
            if let Ok(encoded) = encode_base(new_base) {
                km.set_base(self.k - 1, encoded);
                self.kmer = Some(km);
                if let Some(mut km_rc) = self.kmer_rc {
                    // Reverse strand: move every base one slot up, then put the
                    // complement of the new base at the front.
                    for i in (1..self.k).rev() {
                        let base = km_rc.get_base(i - 1);
                        km_rc.set_base(i, base);
                    }
                    let complement = crate::encoding::complement_base(encoded);
                    km_rc.set_base(0, complement);
                    self.kmer_rc = Some(km_rc);
                    self.curr_mini_info = self.minimizer_it.next(km);
                    self.curr_mini_info_rc = self.minimizer_it_rc.next(km_rc);
                }
            }
        }

        // Extension is only possible with a dictionary and while the current
        // string still has room; everything else is a fresh search.
        match dict_opt {
            Some(dict) if self.remaining_string_bases != 0 => self.try_extend(dict),
            _ => self.seed(dict_opt),
        }

        self.prev_mini_info = self.curr_mini_info;
        self.prev_mini_info_rc = self.curr_mini_info_rc;
        self.start = false;
        self.result.clone()
    }

    /// Returns `true` when `bytes` has exactly `k` bytes, all of them one of
    /// `ACGT` in either case.
    fn is_valid_kmer_bytes(&self, bytes: &[u8]) -> bool {
        bytes.len() == self.k && bytes.iter().all(|&b| self.is_valid_base(b))
    }

    /// Returns `true` when `b` is one of `ACGT` in either case.
    fn is_valid_base(&self, b: u8) -> bool {
        matches!(b, b'A' | b'C' | b'G' | b'T' | b'a' | b'c' | b'g' | b't')
    }

    /// Performs a full search for the current k-mer and stores the outcome in
    /// `self.result`.
    ///
    /// On a hit, `remaining_string_bases` is set to the number of one-step
    /// extensions still possible in the direction given by the orientation.
    /// Without a dictionary the result is always "not found".
    ///
    /// # Panics
    ///
    /// Panics if one of the internal consistency assertions fails.
    fn seed(&mut self, dict_opt: Option<&crate::dictionary::Dictionary>) {
        self.remaining_string_bases = 0;

        // Shortcut: mid-stream, both minimizers unchanged, and the previous
        // lookup already reported that the minimizer was not found. The stored
        // negative result is kept and no search is performed.
        if !self.start
            && self.curr_mini_info.value == self.prev_mini_info.value
            && self.curr_mini_info_rc.value == self.prev_mini_info_rc.value
            && !self.result.minimizer_found
        {
            assert_eq!(self.result.kmer_id, u64::MAX);
            self.num_negative += 1;
            return;
        }

        if let (Some(dict), Some(kmer)) = (dict_opt, self.kmer) {
            if self._canonical {
                let kmer_rc = kmer.reverse_complement();
                let mini_fwd = dict.extract_minimizer::<K>(&kmer);
                let mini_rc = dict.extract_minimizer::<K>(&kmer_rc);
                if mini_fwd.value < mini_rc.value {
                    self.result =
                        dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_fwd);
                } else if mini_rc.value < mini_fwd.value {
                    self.result =
                        dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_rc);
                } else {
                    // Equal minimizer values: try the forward one first and fall
                    // back to the reverse-complement one on a miss.
                    self.result =
                        dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_fwd);
                    if self.result.kmer_id == u64::MAX {
                        self.result =
                            dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_rc);
                    }
                }
            } else {
                let mini_fwd = dict.extract_minimizer::<K>(&kmer);
                self.result = dict.lookup_regular_streaming::<K>(&kmer, mini_fwd);
                let minimizer_found = self.result.minimizer_found;
                if self.result.kmer_id == u64::MAX {
                    // Forward strand missed: retry with the reverse complement.
                    assert_eq!(self.result.kmer_orientation, 1);
                    let kmer_rc = kmer.reverse_complement();
                    let mini_rc = dict.extract_minimizer::<K>(&kmer_rc);
                    self.result = dict.lookup_regular_streaming::<K>(&kmer_rc, mini_rc);
                    self.result.kmer_orientation = -1;
                    // The minimizer counts as found if either strand found it.
                    let minimizer_rc_found = self.result.minimizer_found;
                    self.result.minimizer_found = minimizer_rc_found || minimizer_found;
                }
            }

            if self.result.kmer_id == u64::MAX {
                self.num_negative += 1;
                return;
            }

            assert!(self.result.minimizer_found);
            self.num_searches += 1;
            let string_size = self.result.string_end - self.result.string_begin;
            if self.result.kmer_orientation > 0 {
                // Walking forward: k-mers left after this one in the string.
                self.remaining_string_bases =
                    (string_size - self.k as u64) - self.result.kmer_id_in_string;
            } else {
                // Walking backward: k-mers before this one in the string.
                self.remaining_string_bases = self.result.kmer_id_in_string;
            }
        } else {
            self.result = LookupResult::not_found();
            self.num_negative += 1;
        }
    }

    /// Tries to answer the current k-mer by stepping one position along the
    /// string of the previous match instead of searching again.
    ///
    /// The k-mer stored at the neighbouring position (next or previous,
    /// depending on the orientation) is decoded and compared with the current
    /// k-mer and its reverse complement. On a match the ids and offset in
    /// `self.result` are shifted by the orientation and the remaining budget is
    /// decremented; otherwise a full `seed` search is performed.
    fn try_extend(&mut self, dict: &crate::dictionary::Dictionary) {
        if let (Some(kmer), Some(kmer_rc)) = (self.kmer, self.kmer_rc) {
            let abs_pos =
                self.result.kmer_id_in_string as usize + self.result.string_begin as usize;
            let next_abs_pos = if self.result.kmer_orientation > 0 {
                abs_pos + 1
            } else {
                abs_pos.wrapping_sub(1)
            };
            let expected_kmer: Kmer<K> = dict.spss().decode_kmer_at(next_abs_pos);
            if expected_kmer.bits() == kmer.bits() || expected_kmer.bits() == kmer_rc.bits() {
                self.num_extensions += 1;
                let delta = self.result.kmer_orientation as i64;
                self.result.kmer_id = (self.result.kmer_id as i64 + delta) as u64;
                self.result.kmer_id_in_string =
                    (self.result.kmer_id_in_string as i64 + delta) as u64;
                self.result.kmer_offset = (self.result.kmer_offset as i64 + delta) as u64;
                self.remaining_string_bases -= 1;
                return;
            }
        }
        self.seed(Some(dict));
    }

    /// Number of lookups answered by a successful full search.
    pub fn num_searches(&self) -> u64 {
        self.num_searches
    }

    /// Number of lookups answered by extending the previous match.
    pub fn num_extensions(&self) -> u64 {
        self.num_extensions
    }

    /// Number of lookups that found a k-mer (searches plus extensions).
    pub fn num_positive_lookups(&self) -> u64 {
        self.num_searches + self.num_extensions
    }

    /// Number of lookups that found nothing.
    pub fn num_negative_lookups(&self) -> u64 {
        self.num_negative
    }

    /// Number of lookups rejected because the input was invalid.
    pub fn num_invalid_lookups(&self) -> u64 {
        self.num_invalid
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A "not found" result carries the `u64::MAX` sentinel id.
    #[test]
    fn test_lookup_result_creation() {
        let result = LookupResult::not_found();
        assert!(!result.is_found());
        assert_eq!(result.kmer_id, u64::MAX);
    }

    /// `string_length` is `string_end - string_begin` for a found result.
    #[test]
    fn test_lookup_result_string_length() {
        let mut result = LookupResult::not_found();
        result.string_begin = 100;
        result.string_end = 200;
        result.kmer_id = 42;
        assert_eq!(result.string_length(), 100);
    }

    /// A new query stores its parameters and starts with zeroed counters.
    #[test]
    fn test_streaming_query_creation() {
        let query: StreamingQuery<31> = StreamingQuery::new(31, 13, true);
        assert_eq!(query.k, 31);
        assert_eq!(query._m, 13);
        assert!(query._canonical);
        assert_eq!(query.num_searches(), 0);
    }

    /// `reset` restores the start state but leaves the counters alone.
    #[test]
    fn test_streaming_query_reset() {
        let mut query: StreamingQuery<31> = StreamingQuery::new(31, 13, false);
        query.num_searches = 10;
        query.num_extensions = 5;
        query.reset();
        assert!(query.start);
        assert_eq!(query.remaining_string_bases, 0);
        assert!(!query.result.is_found());
        assert_eq!(query.num_searches(), 10);
        assert_eq!(query.num_extensions(), 5);
    }

    /// Validation rejects wrong lengths and non-ACGT bases independently.
    #[test]
    fn test_streaming_query_validation() {
        let query: StreamingQuery<31> = StreamingQuery::new(31, 13, true);
        // Exactly k = 31 valid bases.
        assert!(query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACG"));
        // Too short.
        assert!(!query.is_valid_kmer_bytes(b"ACGT"));
        // Too long (32 bytes): rejected on length alone.
        assert!(!query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACGN"));
        // Correct length (31 bytes) but containing an invalid base.
        assert!(!query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACN"));
        assert!(query.is_valid_base(b'A'));
        assert!(query.is_valid_base(b'a'));
        assert!(!query.is_valid_base(b'N'));
    }

    /// Invalid input in the start state is counted and yields "not found".
    #[test]
    fn test_streaming_query_lookup_invalid() {
        let mut query: StreamingQuery<15> = StreamingQuery::new(15, 7, true);
        let result = query.lookup(b"ACGT");
        assert!(!result.is_found());
        assert_eq!(query.num_invalid_lookups(), 1);
        query.reset();
        let result = query.lookup(b"ACGTACGTACGTACN");
        assert!(!result.is_found());
        assert_eq!(query.num_invalid_lookups(), 2);
    }

    /// After a valid lookup the query leaves the start state and stays out of it.
    #[test]
    fn test_streaming_query_incremental_update() {
        let mut query: StreamingQuery<9> = StreamingQuery::new(9, 5, false);
        let _result1 = query.lookup(b"ACGTACGTA");
        assert!(!query.start);
        let _result2 = query.lookup(b"CGTACGTAC");
        assert!(!query.start);
    }
}
/// A [`StreamingQuery`] bound to the dictionary it searches.
///
/// The engine borrows the dictionary for `'a` and passes it to every lookup.
pub struct StreamingQueryEngine<'a, const K: usize>
where
    Kmer<K>: KmerBits,
{
    /// Dictionary that lookups are run against.
    dict: &'a crate::dictionary::Dictionary,
    /// Streaming state carried from one lookup to the next.
    query: StreamingQuery<K>,
}
impl<'a, const K: usize> StreamingQueryEngine<'a, K>
where
    Kmer<K>: KmerBits,
{
    /// Creates an engine for `dict`, configuring the inner query from the
    /// dictionary's own `k`, `m` and canonical setting.
    pub fn new(dict: &'a crate::dictionary::Dictionary) -> Self {
        let is_canonical = dict.canonical();
        let query = StreamingQuery::new(dict.k(), dict.m(), is_canonical);
        Self { dict, query }
    }

    /// Resets the inner streaming state.
    pub fn reset(&mut self) {
        self.query.reset();
    }

    /// Looks up the next k-mer of the stream in the bound dictionary.
    pub fn lookup(&mut self, kmer_bytes: &[u8]) -> LookupResult {
        let dict = self.dict;
        self.query.lookup_with_dict(kmer_bytes, dict)
    }

    /// Number of lookups answered by a successful full search.
    pub fn num_searches(&self) -> u64 {
        self.query.num_searches()
    }

    /// Number of lookups answered by extending the previous match.
    pub fn num_extensions(&self) -> u64 {
        self.query.num_extensions()
    }

    /// Snapshot of all four lookup counters.
    pub fn stats(&self) -> StreamingQueryStats {
        let q = &self.query;
        StreamingQueryStats {
            num_searches: q.num_searches(),
            num_extensions: q.num_extensions(),
            num_invalid: q.num_invalid_lookups(),
            num_negative: q.num_negative_lookups(),
        }
    }
}
/// Counters describing how a stream of lookups was answered.
#[derive(Debug, Clone)]
pub struct StreamingQueryStats {
    /// Lookups answered by a successful full search.
    pub num_searches: u64,
    /// Lookups answered by extending the previous match.
    pub num_extensions: u64,
    /// Lookups rejected because the input was invalid.
    pub num_invalid: u64,
    /// Lookups that found nothing.
    pub num_negative: u64,
}