#![allow(unsafe_code)]
use super::{DenseMatrix, Matrix, INVALID_CONNECTION_COST};
use std::simd::{
    cmp::{SimdOrd, SimdPartialOrd},
    i32x8,
    num::{SimdInt, SimdUint},
    u16x16, u16x8, Select, Simd,
};
/// Number of lanes processed per step by the 8-wide batch helpers.
const SIMD_LANES_8: usize = 8;
/// Number of lanes processed per step by the 16-wide batch helpers
/// (twice the 8-wide width).
const SIMD_LANES_16: usize = SIMD_LANES_8 * 2;
/// Batched connection-cost lookups on top of [`Matrix`].
///
/// Every method pairs `right_ids[i]` with `left_ids[i]` and produces one cost
/// per pair, in input order.
pub trait SimdMatrix: Matrix {
    /// Returns the costs of exactly eight `(right_id, left_id)` pairs.
    fn batch_get_8(&self, right_ids: &[u16; 8], left_ids: &[u16; 8]) -> [i32; 8];

    /// Returns the costs of exactly sixteen `(right_id, left_id)` pairs.
    fn batch_get_16(&self, right_ids: &[u16; 16], left_ids: &[u16; 16]) -> [i32; 16];

    /// Writes the cost of each `(right_ids[i], left_ids[i])` pair into `output[i]`.
    ///
    /// How slices of differing lengths are handled is left to the implementor.
    fn batch_get_slice(&self, right_ids: &[u16], left_ids: &[u16], output: &mut [i32]);
}
impl SimdMatrix for DenseMatrix {
    /// Looks up eight `(right_id, left_id)` pairs via `batch_lookup_simd_8`.
    #[inline]
    fn batch_get_8(&self, right_ids: &[u16; 8], left_ids: &[u16; 8]) -> [i32; 8] {
        batch_lookup_simd_8(self, right_ids, left_ids)
    }

    /// Looks up sixteen `(right_id, left_id)` pairs via `batch_lookup_simd_16`.
    #[inline]
    fn batch_get_16(&self, right_ids: &[u16; 16], left_ids: &[u16; 16]) -> [i32; 16] {
        batch_lookup_simd_16(self, right_ids, left_ids)
    }

    /// Fills `output[i]` with the cost of `(right_ids[i], left_ids[i])`.
    ///
    /// Only the first `min(right_ids.len(), left_ids.len(), output.len())`
    /// positions are written; any further elements of `output` are left
    /// untouched. Full groups of `SIMD_LANES_8` pairs go through
    /// `batch_get_8`, and the leftover tail is looked up one pair at a time
    /// with `get`.
    fn batch_get_slice(&self, right_ids: &[u16], left_ids: &[u16], output: &mut [i32]) {
        let len = right_ids.len().min(left_ids.len()).min(output.len());

        let right_chunks = right_ids[..len].chunks_exact(SIMD_LANES_8);
        let left_chunks = left_ids[..len].chunks_exact(SIMD_LANES_8);
        // All three views share the same length, so their remainders line up.
        let (right_tail, left_tail) = (right_chunks.remainder(), left_chunks.remainder());
        let mut out_chunks = output[..len].chunks_exact_mut(SIMD_LANES_8);

        for ((out, right), left) in out_chunks.by_ref().zip(right_chunks).zip(left_chunks) {
            // `chunks_exact` only yields chunks of exactly SIMD_LANES_8
            // elements, so these conversions cannot fail.
            let right: &[u16; SIMD_LANES_8] = right
                .try_into()
                .expect("chunks_exact yields SIMD_LANES_8-element chunks");
            let left: &[u16; SIMD_LANES_8] = left
                .try_into()
                .expect("chunks_exact yields SIMD_LANES_8-element chunks");
            out.copy_from_slice(&self.batch_get_8(right, left));
        }

        let out_tail = out_chunks.into_remainder();
        for ((slot, &right), &left) in out_tail.iter_mut().zip(right_tail).zip(left_tail) {
            *slot = self.get(right, left);
        }
    }
}
#[inline]
fn batch_lookup_simd_8(
matrix: &DenseMatrix,
right_ids: &[u16; 8],
left_ids: &[u16; 8],
) -> [i32; 8] {
let lsize = matrix.left_size() as u16;
let costs_ref = matrix.costs();
let right_vec = u16x8::from_array(*right_ids);
let left_vec = u16x8::from_array(*left_ids);
let lsize_vec = u16x8::splat(lsize);
let left_scaled = left_vec * lsize_vec;
let indices = right_vec + left_scaled;
let indices_array = indices.to_array();
let mut result = [INVALID_CONNECTION_COST; 8];
for (i, &idx) in indices_array.iter().enumerate() {
let idx = idx as usize;
if idx < costs_ref.len() {
result[i] = costs_ref[idx] as i32;
}
}
result
}
#[inline]
fn batch_lookup_simd_16(
matrix: &DenseMatrix,
right_ids: &[u16; 16],
left_ids: &[u16; 16],
) -> [i32; 16] {
let lsize = matrix.left_size() as u16;
let costs_ref = matrix.costs();
let right_vec = u16x16::from_array(*right_ids);
let left_vec = u16x16::from_array(*left_ids);
let lsize_vec = u16x16::splat(lsize);
let left_scaled = left_vec * lsize_vec;
let indices = right_vec + left_scaled;
let indices_array = indices.to_array();
let mut result = [INVALID_CONNECTION_COST; 16];
for (i, &idx) in indices_array.iter().enumerate() {
let idx = idx as usize;
if idx < costs_ref.len() {
result[i] = costs_ref[idx] as i32;
}
}
result
}
/// Returns the smallest of eight costs together with the lane it came from.
///
/// When several lanes share the minimum, the lowest lane index is reported.
#[inline]
pub fn simd_find_min_cost_8(costs: &[i32; 8]) -> (i32, usize) {
    let lowest = i32x8::from_array(*costs).reduce_min();
    // The minimum is always one of the inputs, so `position` finds it; the
    // fallback of 0 only exists to keep the expression total.
    let lane = costs.iter().position(|&cost| cost == lowest).unwrap_or(0);
    (lowest, lane)
}
#[inline]
pub fn simd_calculate_total_costs_8(
prev_costs: &[i32; 8],
conn_costs: &[i32; 8],
word_cost: i32,
space_penalty: i32,
) -> [i32; 8] {
let prev_vec = i32x8::from_array(*prev_costs);
let conn_vec = i32x8::from_array(*conn_costs);
let word_vec = i32x8::splat(word_cost);
let penalty_vec = i32x8::splat(space_penalty);
let total = prev_vec + conn_vec + word_vec + penalty_vec;
total.to_array()
}
#[inline]
pub fn simd_saturating_add_8(a: &[i32; 8], b: &[i32; 8]) -> [i32; 8] {
let a_vec = i32x8::from_array(*a);
let b_vec = i32x8::from_array(*b);
let sum = a_vec + b_vec;
let a_pos = a_vec.simd_gt(i32x8::splat(0));
let b_pos = b_vec.simd_gt(i32x8::splat(0));
let sum_neg = sum.simd_lt(i32x8::splat(0));
let overflow = a_pos & b_pos & sum_neg;
let max_vec = i32x8::splat(i32::MAX);
overflow.select(max_vec, sum).to_array()
}
/// Lane-wise minimum of two 8-lane cost arrays: `out[i] = min(a[i], b[i])`.
#[inline]
pub fn simd_min_across_8(a: &[i32; 8], b: &[i32; 8]) -> [i32; 8] {
    let [lhs, rhs] = [a, b].map(|lanes| i32x8::from_array(*lanes));
    SimdOrd::simd_min(lhs, rhs).to_array()
}
#[inline]
#[allow(dead_code)]
fn batch_lookup_generic_8(matrix: &DenseMatrix, right_ids: &[u16], left_ids: &[u16]) -> Vec<i32> {
const LANES: usize = 8;
let len = right_ids.len().min(left_ids.len());
let mut result = Vec::with_capacity(len);
let lsize = matrix.left_size() as u16;
let costs_ref = matrix.costs();
let chunks = len / LANES;
for i in 0..chunks {
let start = i * LANES;
let end = start + LANES;
let right_slice = &right_ids[start..end];
let left_slice = &left_ids[start..end];
let right_arr: [u16; 8] = right_slice.try_into().unwrap_or([0; 8]);
let left_arr: [u16; 8] = left_slice.try_into().unwrap_or([0; 8]);
let right_vec = Simd::<u16, 8>::from_array(right_arr);
let left_vec = Simd::<u16, 8>::from_array(left_arr);
let lsize_vec = Simd::<u16, 8>::splat(lsize);
let left_scaled = left_vec * lsize_vec;
let indices = right_vec + left_scaled;
let indices_arr = indices.to_array();
for idx_u16 in indices_arr {
let idx = idx_u16 as usize;
let cost = if idx < costs_ref.len() {
costs_ref[idx] as i32
} else {
INVALID_CONNECTION_COST
};
result.push(cost);
}
}
for i in (chunks * LANES)..len {
result.push(matrix.get(right_ids[i], left_ids[i]));
}
result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Diagonal entries 0..8 hold (id + 1) * 100 and are read back in one batch.
    #[test]
    fn test_batch_get_8() {
        let mut matrix = DenseMatrix::new(10, 10, 0);
        for i in 0..8 {
            matrix.set(i, i, (i as i16 + 1) * 100);
        }
        let ids: [u16; 8] = std::array::from_fn(|k| k as u16);
        let costs = matrix.batch_get_8(&ids, &ids);
        for (lane, expected) in [(0, 100), (1, 200), (7, 800)] {
            assert_eq!(costs[lane], expected);
        }
    }

    /// A 20-element slice exercises both the 8-lane groups and the scalar tail.
    #[test]
    fn test_batch_get_slice() {
        let mut matrix = DenseMatrix::new(20, 20, 0);
        for i in 0..20 {
            matrix.set(i, i, (i as i16) * 10);
        }
        let ids: Vec<u16> = (0..20).collect();
        let mut output = vec![0i32; ids.len()];
        matrix.batch_get_slice(&ids, &ids, &mut output);
        for (i, &cost) in output.iter().enumerate() {
            assert_eq!(cost, (i as i32) * 10);
        }
    }

    #[test]
    fn test_simd_find_min_cost_8() {
        let costs = [100, 50, 200, 25, 300, 10, 150, 75];
        assert_eq!(simd_find_min_cost_8(&costs), (10, 5));
    }

    #[test]
    fn test_simd_calculate_total_costs_8() {
        let prev_costs = [100, 200, 300, 400, 500, 600, 700, 800];
        let conn_costs = [10, 20, 30, 40, 50, 60, 70, 80];
        let (word_cost, space_penalty) = (1000, 500);
        let totals =
            simd_calculate_total_costs_8(&prev_costs, &conn_costs, word_cost, space_penalty);
        for lane in [0, 1, 7] {
            assert_eq!(
                totals[lane],
                prev_costs[lane] + conn_costs[lane] + word_cost + space_penalty
            );
        }
    }

    #[test]
    fn test_simd_saturating_add_8() {
        let a = [i32::MAX - 10, 100, 200, 300, 400, 500, 600, 700];
        let b = [20, 50, 60, 70, 80, 90, 100, 110];
        let sums = simd_saturating_add_8(&a, &b);
        assert_eq!(sums[0], i32::MAX);
        assert_eq!(sums[1], 150);
        assert_eq!(sums[7], 810);
    }

    #[test]
    fn test_simd_min_across_8() {
        let a = [100, 200, 300, 400, 500, 600, 700, 800];
        let b = [150, 150, 250, 450, 450, 550, 750, 750];
        let mins = simd_min_across_8(&a, &b);
        assert_eq!(mins[0], 100);
        assert_eq!(mins[1], 150);
        assert_eq!(mins[7], 750);
    }

    /// Every id lies far outside a 10x10 matrix, so every lane is invalid.
    #[test]
    fn test_batch_get_boundary() {
        let matrix = DenseMatrix::new(10, 10, 0);
        let ids: [u16; 8] = std::array::from_fn(|k| 50 + 10 * k as u16);
        for cost in matrix.batch_get_8(&ids, &ids) {
            assert_eq!(cost, INVALID_CONNECTION_COST);
        }
    }

    /// In-range and out-of-range ids mixed within a single batch.
    #[test]
    fn test_batch_get_mixed() {
        let mut matrix = DenseMatrix::new(10, 10, 0);
        matrix.set(0, 0, 100);
        matrix.set(5, 5, 500);
        let ids = [0, 1, 2, 5, 50, 60, 70, 80];
        let costs = matrix.batch_get_8(&ids, &ids);
        assert_eq!(costs[0], 100);
        assert_eq!(costs[1], 0);
        assert_eq!(costs[3], 500);
        assert_eq!(costs[4], INVALID_CONNECTION_COST);
    }
}