#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use crate::protein::AminoAcid;
use crate::scoring::ScoringMatrix;
use crate::error::Result;
/// Number of `i32` lanes handled per batch.
///
/// The AVX2 kernel loads scratch buffers of this length as a single `__m256i`
/// (8 lanes x 32 bits = 256 bits); the scalar code paths reuse the same value
/// as their column-block size.
const SIMD_WIDTH: usize = 8;
/// Reports whether the AVX2 code path may be used on the running CPU.
///
/// On x86_64 this is a runtime CPU-feature query; on every other architecture
/// the answer is always `false`.
#[inline]
fn has_avx2() -> bool {
    // Exactly one of the two blocks survives cfg-stripping and becomes the
    // function's tail expression.
    #[cfg(target_arch = "x86_64")]
    {
        is_x86_feature_detected!("avx2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Fills a Smith-Waterman local-alignment score matrix, preferring the AVX2 kernel.
///
/// Dispatches at runtime: when the CPU reports AVX2 support the vectorised
/// kernel is used, otherwise the portable scalar implementation runs.
///
/// # Arguments
/// * `seq1`, `seq2` - the two sequences to align (rows / columns of the matrix).
/// * `matrix` - substitution scores for residue pairs.
/// * `open_penalty` - forwarded to the kernels, which currently ignore it.
/// * `extend_penalty` - value added to a cell's score for each gap step.
///
/// # Returns
/// The `(seq1.len() + 1) x (seq2.len() + 1)` score matrix together with the
/// row and column index of a highest-scoring cell (`(0, 0)` when no cell
/// scores above zero).
///
/// # Errors
/// Propagates whatever error the selected kernel returns.
#[cfg(target_arch = "x86_64")]
pub fn smith_waterman_avx2(
    seq1: &[AminoAcid],
    seq2: &[AminoAcid],
    matrix: &ScoringMatrix,
    open_penalty: i32,
    extend_penalty: i32,
) -> Result<(Vec<Vec<i32>>, usize, usize)> {
    // Guard clause: no AVX2 means the scalar kernel does the work.
    if !has_avx2() {
        return smith_waterman_striped(seq1, seq2, matrix, open_penalty, extend_penalty);
    }

    // SAFETY: `has_avx2()` has just confirmed that the running CPU supports
    // AVX2, which is the only requirement of the `#[target_feature]` kernel.
    unsafe { smith_waterman_avx2_optimized(seq1, seq2, matrix, open_penalty, extend_penalty) }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn smith_waterman_avx2_optimized(
seq1: &[AminoAcid],
seq2: &[AminoAcid],
matrix: &ScoringMatrix,
_open_penalty: i32,
extend_penalty: i32,
) -> Result<(Vec<Vec<i32>>, usize, usize)> {
let m = seq1.len();
let n = seq2.len();
let mut h = vec![vec![0i32; n + 1]; m + 1];
let mut max_score = 0;
let mut max_i = 0;
let mut max_j = 0;
let mut scores = vec![vec![0i32; n]; m];
for i in 0..m {
for j in 0..n {
scores[i][j] = matrix.score(seq1[i], seq2[j]);
}
}
let extend_vec = _mm256_set1_epi32(extend_penalty);
let zero_vec = _mm256_setzero_si256();
let mut batch_i = vec![0usize; SIMD_WIDTH];
let mut batch_j = vec![0usize; SIMD_WIDTH];
let mut diag_vals = vec![0i32; SIMD_WIDTH];
let mut up_vals = vec![0i32; SIMD_WIDTH];
let mut left_vals = vec![0i32; SIMD_WIDTH];
let mut scores_vals = vec![0i32; SIMD_WIDTH];
let mut results = vec![0i32; SIMD_WIDTH];
for k in 1..=(m + n) {
let i_start = std::cmp::max(1, k as i32 - n as i32) as usize;
let i_end = std::cmp::min(m, k - 1);
if i_start > i_end {
continue; }
for batch_start in (i_start..=i_end).step_by(SIMD_WIDTH) {
let batch_end = std::cmp::min(batch_start + SIMD_WIDTH, i_end + 1);
let batch_len = batch_end - batch_start;
for (batch_idx, i) in (batch_start..batch_end).enumerate() {
let j = k - i;
batch_i[batch_idx] = i;
batch_j[batch_idx] = j;
if i > 0 && j > 0 {
diag_vals[batch_idx] = h[i - 1][j - 1];
}
if i > 0 && j <= n {
up_vals[batch_idx] = h[i - 1][j];
}
if i <= m && j > 0 {
left_vals[batch_idx] = h[i][j - 1];
}
if i > 0 && j > 0 {
scores_vals[batch_idx] = scores[i - 1][j - 1];
} else {
scores_vals[batch_idx] = 0;
}
}
for batch_idx in batch_len..SIMD_WIDTH {
diag_vals[batch_idx] = 0;
up_vals[batch_idx] = 0;
left_vals[batch_idx] = 0;
scores_vals[batch_idx] = 0;
}
let diag_vec = _mm256_loadu_si256(diag_vals.as_ptr() as *const __m256i);
let up_vec = _mm256_loadu_si256(up_vals.as_ptr() as *const __m256i);
let left_vec = _mm256_loadu_si256(left_vals.as_ptr() as *const __m256i);
let scores_vec = _mm256_loadu_si256(scores_vals.as_ptr() as *const __m256i);
let diag_result = _mm256_add_epi32(diag_vec, scores_vec);
let up_result = _mm256_add_epi32(up_vec, extend_vec);
let left_result = _mm256_add_epi32(left_vec, extend_vec);
let max_du = _mm256_max_epi32(diag_result, up_result);
let max_dul = _mm256_max_epi32(max_du, left_result);
let result_vec = _mm256_max_epi32(max_dul, zero_vec);
_mm256_storeu_si256(results.as_mut_ptr() as *mut __m256i, result_vec);
for (batch_idx, i) in (batch_start..batch_end).enumerate() {
let j = k - i;
h[i][j] = results[batch_idx];
if h[i][j] > max_score {
max_score = h[i][j];
max_i = i;
max_j = j;
}
}
}
}
Ok((h, max_i, max_j))
}
/// Portable scalar Smith-Waterman fill, used when AVX2 is not available.
///
/// Computes `H[i][j] = max(0, H[i-1][j-1] + s(i, j), H[i-1][j] + extend,
/// H[i][j-1] + extend)` in a single row-major pass, which is O(m * n). Each
/// cell only needs the previous row and the cell to its left, both of which a
/// row-major sweep has already finalised, so no outer anti-diagonal loop is
/// needed.
///
/// # Arguments
/// * `seq1`, `seq2` - sequences indexing the rows / columns of the matrix.
/// * `matrix` - substitution scores for residue pairs.
/// * `_open_penalty` - accepted for signature parity; not used.
/// * `extend_penalty` - value added to the upper / left neighbour for a gap step.
///
/// # Returns
/// The `(m + 1) x (n + 1)` score matrix and the position of the first cell (in
/// row-major order) holding the maximum score; `(0, 0)` when no cell scores
/// above zero or either sequence is empty.
fn smith_waterman_striped(
    seq1: &[AminoAcid],
    seq2: &[AminoAcid],
    matrix: &ScoringMatrix,
    _open_penalty: i32,
    extend_penalty: i32,
) -> Result<(Vec<Vec<i32>>, usize, usize)> {
    let m = seq1.len();
    let n = seq2.len();
    let mut h = vec![vec![0i32; n + 1]; m + 1];
    let mut max_score = 0;
    let mut max_i = 0;
    let mut max_j = 0;

    for i in 1..=m {
        let residue = seq1[i - 1];
        for j in 1..=n {
            let diagonal = h[i - 1][j - 1] + matrix.score(residue, seq2[j - 1]);
            let up = h[i - 1][j] + extend_penalty;
            let left = h[i][j - 1] + extend_penalty;
            // Local alignment: scores are clamped at zero.
            let cell = diagonal.max(up).max(left).max(0);
            h[i][j] = cell;
            // Strict comparison keeps the first maximum in row-major order.
            if cell > max_score {
                max_score = cell;
                max_i = i;
                max_j = j;
            }
        }
    }

    Ok((h, max_i, max_j))
}
/// Non-x86_64 stand-in for the AVX2 Smith-Waterman entry point.
///
/// All arguments are ignored.
///
/// # Errors
/// Always returns `Error::Custom` stating that the kernel needs x86_64.
#[cfg(not(target_arch = "x86_64"))]
pub fn smith_waterman_avx2(
    _seq1: &[AminoAcid],
    _seq2: &[AminoAcid],
    _matrix: &ScoringMatrix,
    _open_penalty: i32,
    _extend_penalty: i32,
) -> Result<(Vec<Vec<i32>>, usize, usize)> {
    let reason = String::from("AVX2 kernel requires x86_64 architecture");
    Err(crate::error::Error::Custom(reason))
}
/// Fills a Needleman-Wunsch global-alignment score matrix in column blocks.
///
/// Despite the name, this routine issues no SIMD intrinsics; it walks the
/// matrix in vertical strips of `SIMD_WIDTH` columns using scalar arithmetic.
///
/// # Arguments
/// * `seq1`, `seq2` - sequences indexing the rows / columns of the matrix.
/// * `matrix` - substitution scores for residue pairs.
/// * `open_penalty` - per-position value used to initialise the first row and
///   first column (`index * open_penalty`).
/// * `extend_penalty` - value added to the upper / left neighbour for a gap
///   step inside the matrix.
///
/// # Returns
/// The `(seq1.len() + 1) x (seq2.len() + 1)` score matrix.
#[cfg(target_arch = "x86_64")]
pub fn needleman_wunsch_avx2(
    seq1: &[AminoAcid],
    seq2: &[AminoAcid],
    matrix: &ScoringMatrix,
    open_penalty: i32,
    extend_penalty: i32,
) -> Result<Vec<Vec<i32>>> {
    let rows = seq1.len();
    let cols = seq2.len();
    let mut table = vec![vec![0i32; cols + 1]; rows + 1];

    // Border initialisation: first column, then first row.
    for (i, row) in table.iter_mut().enumerate() {
        row[0] = (i as i32) * open_penalty;
    }
    for (j, cell) in table[0].iter_mut().enumerate() {
        *cell = (j as i32) * open_penalty;
    }

    for strip_start in (1..=cols).step_by(SIMD_WIDTH) {
        let strip_end = (strip_start + SIMD_WIDTH).min(cols + 1);

        for i in 1..=rows {
            let residue = seq1[i - 1];

            // Phase 1: look up the substitution scores for this strip of row `i`.
            let mut substitution = [0i32; SIMD_WIDTH];
            for (slot, j) in substitution.iter_mut().zip(strip_start..strip_end) {
                *slot = matrix.score(residue, seq2[j - 1]);
            }

            // Phase 2: apply the recurrence left to right, since each cell
            // needs the cell to its left in the same row.
            for (&pair_score, j) in substitution.iter().zip(strip_start..strip_end) {
                let via_diagonal = table[i - 1][j - 1] + pair_score;
                let via_up = table[i - 1][j] + extend_penalty;
                let via_left = table[i][j - 1] + extend_penalty;
                table[i][j] = via_diagonal.max(via_up).max(via_left);
            }
        }
    }

    Ok(table)
}
/// Non-x86_64 stand-in for the Needleman-Wunsch kernel.
///
/// All arguments are ignored.
///
/// # Errors
/// Always returns `Error::Custom` stating that the kernel needs x86_64.
#[cfg(not(target_arch = "x86_64"))]
pub fn needleman_wunsch_avx2(
    _seq1: &[AminoAcid],
    _seq2: &[AminoAcid],
    _matrix: &ScoringMatrix,
    _open_penalty: i32,
    _extend_penalty: i32,
) -> Result<Vec<Vec<i32>>> {
    let reason = String::from("AVX2 kernel requires x86_64 architecture");
    Err(crate::error::Error::Custom(reason))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scoring::ScoringMatrix;

    /// Aligns two single-residue sequences through the public entry point and
    /// checks the result instead of discarding it.
    ///
    /// On x86_64 either kernel must succeed and produce a 2x2 matrix whose only
    /// interior cell is the zero-clamped substitution score; on other
    /// architectures the entry point must report an error.
    #[test]
    fn test_avx2_smith_waterman_fallback() {
        let matrix = ScoringMatrix::default();
        let seq1 = vec![AminoAcid::Alanine];
        let seq2 = vec![AminoAcid::Alanine];
        let result = smith_waterman_avx2(&seq1, &seq2, &matrix, -11, -1);

        #[cfg(target_arch = "x86_64")]
        {
            let (h, max_i, max_j) = match result {
                Ok(filled) => filled,
                Err(_) => panic!("smith_waterman_avx2 must succeed on x86_64"),
            };

            // One border row/column plus one row/column per residue.
            assert_eq!(h.len(), 2);
            assert!(h.iter().all(|row| row.len() == 2));
            assert_eq!(h[0], vec![0, 0]);
            assert_eq!(h[1][0], 0);

            // With zero borders and a negative gap step, the single interior
            // cell is the substitution score clamped at zero.
            let expected = matrix.score(seq1[0], seq2[0]).max(0);
            assert_eq!(h[1][1], expected);

            let expected_position = if expected > 0 { (1, 1) } else { (0, 0) };
            assert_eq!((max_i, max_j), expected_position);
        }

        #[cfg(not(target_arch = "x86_64"))]
        assert!(result.is_err());
    }
}