use crate::Scoring;
pub fn smith_waterman(
needle: &str,
haystack: &str,
scoring: &Scoring,
) -> (u16, Vec<Vec<u16>>, bool) {
let needle = needle.as_bytes();
let haystack = haystack.as_bytes();
let delimiter_bytes: Vec<u8> = scoring.delimiters.bytes().collect();
let mut score_matrix = vec![vec![0; haystack.len()]; needle.len()];
let mut all_time_max_score = 0;
for i in 0..needle.len() {
let (prev_col_scores, curr_col_scores) = if i > 0 {
let (prev_col_scores_slice, curr_col_scores_slice) = score_matrix.split_at_mut(i);
(&prev_col_scores_slice[i - 1], &mut curr_col_scores_slice[0])
} else {
(&vec![0; haystack.len()], &mut score_matrix[i])
};
let mut up_score_simd: u16 = 0;
let mut up_gap_penalty_mask = true;
let needle_char = needle[i];
let needle_is_uppercase = needle_char.is_ascii_uppercase();
let needle_char = needle_char.to_ascii_lowercase();
let mut left_gap_penalty_mask = true;
let mut delimiter_bonus_enabled = false;
let mut prev_haystack_is_delimiter = false;
let mut prev_haystack_is_lowercase = false;
for j in 0..haystack.len() {
let is_prefix = j == 0;
let is_offset_prefix =
j == 1 && prev_col_scores[0] == 0 && !haystack[0].is_ascii_alphabetic();
let haystack_char = haystack[j];
let haystack_is_uppercase = haystack_char.is_ascii_uppercase();
let haystack_is_lowercase = haystack_char.is_ascii_lowercase();
let haystack_char = haystack_char.to_ascii_lowercase();
let haystack_is_delimiter = delimiter_bytes.contains(&haystack_char);
let matched_casing_mask = needle_is_uppercase == haystack_is_uppercase;
let match_score = if is_prefix {
scoring.match_score + scoring.prefix_bonus
} else if is_offset_prefix {
scoring.match_score + scoring.offset_prefix_bonus
} else {
scoring.match_score
};
let diag = if is_prefix { 0 } else { prev_col_scores[j - 1] };
let is_match = needle_char == haystack_char;
let diag_score = if is_match {
diag + match_score
+ if prev_haystack_is_delimiter && delimiter_bonus_enabled && !haystack_is_delimiter { scoring.delimiter_bonus } else { 0 }
+ if !is_prefix && haystack_is_uppercase && prev_haystack_is_lowercase { scoring.capitalization_bonus } else { 0 }
+ if matched_casing_mask { scoring.matching_case_bonus } else { 0 }
} else {
diag.saturating_sub(scoring.mismatch_penalty)
};
let up_gap_penalty = if up_gap_penalty_mask {
scoring.gap_open_penalty
} else {
scoring.gap_extend_penalty
};
let up_score = up_score_simd.saturating_sub(up_gap_penalty);
let left = prev_col_scores[j];
let left_gap_penalty = if left_gap_penalty_mask {
scoring.gap_open_penalty
} else {
scoring.gap_extend_penalty
};
let left_score = left.saturating_sub(left_gap_penalty);
let max_score = diag_score.max(up_score).max(left_score);
let diag_mask = max_score == diag_score;
up_gap_penalty_mask = max_score != up_score || diag_mask;
left_gap_penalty_mask = max_score != left_score || diag_mask;
prev_haystack_is_lowercase = haystack_is_lowercase;
prev_haystack_is_delimiter = haystack_is_delimiter;
delimiter_bonus_enabled |= !prev_haystack_is_delimiter;
up_score_simd = max_score;
curr_col_scores[j] = max_score;
all_time_max_score = all_time_max_score.max(max_score);
}
}
let mut max_score = all_time_max_score;
let exact = haystack == needle;
if exact {
max_score += scoring.exact_match_bonus;
}
(max_score, score_matrix, exact)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Scoring, r#const::*, smith_waterman::simd::smith_waterman as smith_waterman_simd};

    /// Score of one matched character whose case agrees with the needle.
    const CHAR_SCORE: u16 = MATCH_SCORE + MATCHING_CASE_BONUS;

    /// Scores `needle` against `haystack` under `Scoring::default()` with both the
    /// reference and the SIMD implementation, asserts the two agree, and returns the
    /// shared score.
    fn get_score(needle: &str, haystack: &str) -> u16 {
        let defaults = Scoring::default();
        let (reference, _, _) = smith_waterman(needle, haystack, &defaults);
        let simd = smith_waterman_simd::<16, 1>(needle, &[haystack], None, &defaults).0[0];
        assert_eq!(reference, simd, "Reference and SIMD scores don't match");
        reference
    }

    /// `Scoring::default()` with the gap penalties raised to 100 (open) and 80 (extend).
    fn strict_scoring() -> Scoring {
        Scoring {
            gap_open_penalty: 100,
            gap_extend_penalty: 80,
            ..Scoring::default()
        }
    }

    #[test]
    fn test_score_basic() {
        // A mid-word match earns only the plain per-character score.
        for needle in ["b", "c"] {
            assert_eq!(get_score(needle, "abc"), CHAR_SCORE);
        }
    }

    #[test]
    fn test_score_prefix() {
        let with_prefix = CHAR_SCORE + PREFIX_BONUS;
        assert_eq!(get_score("a", "abc"), with_prefix);
        assert_eq!(get_score("a", "aabc"), with_prefix);
        // 'a' is not at index 0 and the leading 'b' is a letter: no prefix bonus.
        assert_eq!(get_score("a", "babc"), CHAR_SCORE);
    }

    #[test]
    fn test_score_offset_prefix() {
        let with_offset = CHAR_SCORE + OFFSET_PREFIX_BONUS;
        // After a single leading non-letter, the next character earns the offset bonus.
        assert_eq!(get_score("a", "-a"), with_offset);
        assert_eq!(get_score("-a", "-ab"), 2 * CHAR_SCORE + PREFIX_BONUS);
        assert_eq!(get_score("a", "'a"), with_offset);
        // A leading letter does not trigger it.
        assert_eq!(get_score("a", "Ba"), CHAR_SCORE);
    }

    #[test]
    fn test_score_exact_match() {
        let exact = |len: u16| len * CHAR_SCORE + EXACT_MATCH_BONUS + PREFIX_BONUS;
        assert_eq!(get_score("a", "a"), exact(1));
        assert_eq!(get_score("abc", "abc"), exact(3));
        // Partial overlaps never receive the exact-match bonus.
        let two_char_prefix = 2 * CHAR_SCORE + PREFIX_BONUS;
        assert_eq!(get_score("ab", "abc"), two_char_prefix);
        assert_eq!(get_score("abc", "ab"), two_char_prefix);
    }

    #[test]
    fn test_score_delimiter() {
        let after_delimiter = CHAR_SCORE + DELIMITER_BONUS;
        assert_eq!(get_score("-", "a--bc"), CHAR_SCORE);
        assert_eq!(get_score("b", "a-b"), after_delimiter);
        assert_eq!(get_score("a", "a-b-c"), CHAR_SCORE + PREFIX_BONUS);
        assert_eq!(get_score("b", "a--b"), after_delimiter);
        assert_eq!(get_score("c", "a--bc"), CHAR_SCORE);
        assert_eq!(get_score("a", "-a--bc"), CHAR_SCORE + OFFSET_PREFIX_BONUS);
    }

    #[test]
    fn test_score_no_delimiter_for_delimiter_chars() {
        // Matching a delimiter character itself earns no delimiter bonus.
        for haystack in ["a-bc", "a--bc"] {
            assert_eq!(get_score("-", haystack), CHAR_SCORE);
        }
        assert!(get_score("a_b", "a_bb") > get_score("a_b", "a__b"));
    }

    #[test]
    fn test_score_affine_gap() {
        let all_four = CHAR_SCORE * 4;
        // A one-byte gap pays the open penalty; each further byte pays the extend penalty.
        assert_eq!(get_score("test", "Uterst"), all_four - GAP_OPEN_PENALTY);
        assert_eq!(
            get_score("test", "Uterrst"),
            all_four - GAP_OPEN_PENALTY - GAP_EXTEND_PENALTY
        );
    }

    #[test]
    fn test_score_capital_bonus() {
        // A case mismatch forfeits the matching-case bonus.
        assert_eq!(get_score("a", "A"), MATCH_SCORE + PREFIX_BONUS);
        assert_eq!(get_score("A", "Aa"), CHAR_SCORE + PREFIX_BONUS);
        // Only a lowercase -> uppercase transition earns the capitalization bonus.
        assert_eq!(get_score("D", "forDist"), CHAR_SCORE + CAPITALIZATION_BONUS);
        assert_eq!(get_score("D", "foRDist"), CHAR_SCORE);
        assert_eq!(get_score("D", "FOR_DIST"), CHAR_SCORE + DELIMITER_BONUS);
    }

    #[test]
    fn test_score_prefix_beats_delimiter() {
        assert!(get_score("swap", "swap(test)") > get_score("swap", "iter_swap(test)"));
        assert!(get_score("_", "_private_member") > get_score("_", "public_member"));
    }

    #[test]
    fn test_scattered_match_should_score_low() {
        let defaults = Scoring::default();
        let (good_score, _, _) = smith_waterman("SortedMap", "SortedArrayMap", &defaults);
        let (bad_score, _, _) =
            smith_waterman("SortedMap", "LightSourceTeamApiKeys", &defaults);
        assert!(
            good_score > bad_score,
            "SortedArrayMap (score={}) should score higher than \
             LightSourceTeamApiKeys (score={}) for needle 'SortedMap'",
            good_score,
            bad_score
        );
    }

    #[test]
    fn test_high_gap_penalties_reject_scattered_matches() {
        let strict = strict_scoring();
        let (good_score, _, _) = smith_waterman("SortedMap", "SortedArrayMap", &strict);
        let (bad_score, _, _) = smith_waterman("SortedMap", "LightSourceTeamApiKeys", &strict);
        assert!(
            good_score > 0,
            "Good match should still have a positive score, got {}",
            good_score
        );
        assert!(
            good_score > bad_score,
            "Good match (score={}) should score higher than \
             scattered match (score={}) with high gap penalties",
            good_score,
            bad_score
        );

        let (default_bad_score, _, _) =
            smith_waterman("SortedMap", "LightSourceTeamApiKeys", &Scoring::default());
        assert!(
            default_bad_score > bad_score,
            "Default scoring bad_score ({}) should be higher than \
             strict scoring bad_score ({}) — confirms gap penalties take effect",
            default_bad_score,
            bad_score
        );
    }

    #[test]
    fn test_gap_penalties_affect_reference_scoring() {
        let (default_score, _, _) =
            smith_waterman("test", "Uterrrrrst", &Scoring::default());
        let (strict_score, _, _) = smith_waterman("test", "Uterrrrrst", &strict_scoring());
        assert!(
            default_score > strict_score,
            "Default score ({}) should be higher than strict score ({}) for gapped match",
            default_score,
            strict_score,
        );
    }
}