#![allow(unsafe_code)]
use crate::lattice::{Lattice, NodeId};
use crate::viterbi::{ConnectionCost, SpacePenalty};
use std::simd::{cmp::SimdPartialOrd, i32x8, num::SimdInt, Select};
/// Number of `i32` lanes handled per SIMD step. This must equal the width of the
/// `i32x8` vectors used throughout this module (several helpers convert
/// `[i32; SIMD_LANES]` arrays with `i32x8::from_array`).
const SIMD_LANES: usize = 8;
/// Minimum number of predecessor candidates at which `simd_update_node_cost`
/// switches from the scalar loop to the chunked SIMD path.
const SIMD_THRESHOLD: usize = 8;
/// Finds the cheapest way to reach `node_id` from one of `prev_nodes`.
///
/// Each entry of `prev_nodes` is `(predecessor_id, accumulated_cost, right_id)`.
/// The result is `(best_total_cost, best_predecessor_id)`.
///
/// If the node has a space before it, the space penalty for its `left_id`
/// is added to every candidate. Candidate lists with at least
/// `SIMD_THRESHOLD` entries go through the SIMD batch path; shorter lists
/// use the scalar loop.
///
/// If `node_id` is not present in `lattice`, this returns
/// `(i32::MAX, INVALID_NODE_ID)`.
#[inline]
pub fn simd_update_node_cost<C: ConnectionCost>(
    lattice: &Lattice,
    conn_cost: &C,
    node_id: NodeId,
    prev_nodes: &[(NodeId, i32, u16)],
    space_penalty: &SpacePenalty,
) -> (i32, NodeId) {
    let Some(node) = lattice.node(node_id) else {
        return (i32::MAX, crate::lattice::INVALID_NODE_ID);
    };

    let lid = node.left_id;
    let own_cost = node.word_cost;
    let penalty = if !node.has_space_before {
        0
    } else {
        space_penalty.get(lid)
    };

    if prev_nodes.len() < SIMD_THRESHOLD {
        scalar_cost_calculation(prev_nodes, conn_cost, lid, own_cost, penalty)
    } else {
        simd_batch_cost_calculation(prev_nodes, conn_cost, lid, own_cost, penalty)
    }
}
/// Picks the cheapest predecessor out of `prev_nodes`.
///
/// Full groups of `SIMD_LANES` candidates go through `process_chunk_simd`.
/// Any leftover candidates are handled one by one, skipping unreachable
/// (`i32::MAX`) ones.
///
/// Ties go to the earliest candidate, because a later candidate only replaces
/// the current best on a strictly smaller total. If nothing beats
/// `i32::MAX`, the result is `(i32::MAX, INVALID_NODE_ID)`.
#[inline]
fn simd_batch_cost_calculation<C: ConnectionCost>(
    prev_nodes: &[(NodeId, i32, u16)],
    conn_cost: &C,
    left_id: u16,
    word_cost: i32,
    space_penalty: i32,
) -> (i32, NodeId) {
    let mut best: (i32, NodeId) = (i32::MAX, crate::lattice::INVALID_NODE_ID);

    let groups = prev_nodes.chunks_exact(SIMD_LANES);
    // `remainder()` borrows from `prev_nodes`, not from `groups`, so it
    // remains usable after `groups` is consumed below.
    let leftovers = groups.remainder();

    for group in groups {
        let (group_min, lane) =
            process_chunk_simd(group, conn_cost, left_id, word_cost, space_penalty);
        if group_min < best.0 {
            best = (group_min, group[lane].0);
        }
    }

    for &(candidate_id, candidate_cost, candidate_right_id) in leftovers {
        if candidate_cost == i32::MAX {
            continue;
        }
        let edge = conn_cost.cost(candidate_right_id, left_id);
        let total = saturating_add_chain(candidate_cost, edge, word_cost, space_penalty);
        if total < best.0 {
            best = (total, candidate_id);
        }
    }

    best
}
/// Evaluates up to `SIMD_LANES` predecessor candidates with SIMD arithmetic.
///
/// Returns `(min_total, lane_index)` for the cheapest lane. If `chunk` has
/// fewer than `SIMD_LANES` entries, the missing lanes are padded with
/// `prev_cost = i32::MAX` and `right_id = 0`.
///
/// Lanes whose predecessor cost is `i32::MAX` are forced to a total of
/// `i32::MAX`. This covers unreachable predecessors and padding lanes, and
/// mirrors the `continue` in `scalar_cost_calculation`.
///
/// The mask is needed because the totals are built by sequential saturating
/// additions. `i32::MAX` plus a negative connection, word or space cost drops
/// back below `i32::MAX`. Without the mask, an unreachable predecessor could
/// win, or a padding lane (index `>= chunk.len()`) could be returned.
///
/// When every lane is masked (or saturates), the result is `(i32::MAX, 0)`.
#[inline]
fn process_chunk_simd<C: ConnectionCost>(
    chunk: &[(NodeId, i32, u16)],
    conn_cost: &C,
    left_id: u16,
    word_cost: i32,
    space_penalty: i32,
) -> (i32, usize) {
    let mut prev_costs = [i32::MAX; SIMD_LANES];
    let mut right_ids = [0u16; SIMD_LANES];
    for (i, (_, cost, right_id)) in chunk.iter().enumerate().take(SIMD_LANES) {
        prev_costs[i] = *cost;
        right_ids[i] = *right_id;
    }
    let conn_costs = batch_connection_cost_lookup(conn_cost, &right_ids, left_id);
    let mut totals = simd_calculate_totals(&prev_costs, &conn_costs, word_cost, space_penalty);

    // Unreachable / padding lanes must never be selected (see doc comment).
    for (total, &prev_cost) in totals.iter_mut().zip(prev_costs.iter()) {
        if prev_cost == i32::MAX {
            *total = i32::MAX;
        }
    }

    find_min_with_index(&totals)
}
/// Looks up the connection cost from each of the `SIMD_LANES` right ids to
/// the shared `left_id`.
///
/// Lane `i` of the result is `conn_cost.cost(right_ids[i], left_id)`.
#[inline(always)]
fn batch_connection_cost_lookup<C: ConnectionCost>(
    conn_cost: &C,
    right_ids: &[u16; SIMD_LANES],
    left_id: u16,
) -> [i32; SIMD_LANES] {
    std::array::from_fn(|lane| conn_cost.cost(right_ids[lane], left_id))
}
/// Computes, per lane, `prev + conn + word_cost + space_penalty`.
///
/// The terms are added left to right, and each step saturates, so the result
/// matches `saturating_add_chain` applied lane by lane.
#[inline]
fn simd_calculate_totals(
    prev_costs: &[i32; SIMD_LANES],
    conn_costs: &[i32; SIMD_LANES],
    word_cost: i32,
    space_penalty: i32,
) -> [i32; SIMD_LANES] {
    let addends = [
        i32x8::from_array(*conn_costs),
        i32x8::splat(word_cost),
        i32x8::splat(space_penalty),
    ];
    let mut running = i32x8::from_array(*prev_costs);
    for addend in addends {
        running = saturating_add_simd(running, addend);
    }
    running.to_array()
}
/// Lane-wise saturating `i32` addition.
///
/// Each lane of the result is `a + b` clamped to `[i32::MIN, i32::MAX]`,
/// which is the same as `i32::saturating_add` applied per lane.
///
/// Vector `+` wraps on overflow, so saturation is detected from the signs of
/// the operands and of the wrapped sum:
/// - both operands positive and wrapped sum negative: overflow, clamp to
///   `i32::MAX`;
/// - both operands negative and wrapped sum non-negative: underflow, clamp
///   to `i32::MIN`.
///
/// The underflow test must be `>= 0`, not `> 0`. `i32::MIN + i32::MIN` wraps
/// to exactly `0`, and a strict comparison would return `0` for it.
#[inline]
fn saturating_add_simd(a: i32x8, b: i32x8) -> i32x8 {
    let zero = i32x8::splat(0);
    let wrapped = a + b;

    let overflow = a.simd_gt(zero) & b.simd_gt(zero) & wrapped.simd_lt(zero);
    let underflow = a.simd_lt(zero) & b.simd_lt(zero) & wrapped.simd_ge(zero);

    let clamped_high = overflow.select(i32x8::splat(i32::MAX), wrapped);
    underflow.select(i32x8::splat(i32::MIN), clamped_high)
}
/// Returns the smallest value in `values` and the index of its first
/// occurrence, so the lowest lane wins ties.
#[inline]
fn find_min_with_index(values: &[i32; SIMD_LANES]) -> (i32, usize) {
    let smallest = i32x8::from_array(*values).reduce_min();
    let lane = values
        .iter()
        .position(|&candidate| candidate == smallest)
        .unwrap_or(0);
    (smallest, lane)
}
/// Scalar fallback: finds the cheapest predecessor by scanning every
/// candidate in order.
///
/// Candidates with an accumulated cost of `i32::MAX` are unreachable and
/// skipped. A later candidate only replaces the current best on a strictly
/// smaller total, so the earliest candidate wins ties. If nothing beats
/// `i32::MAX`, the result is `(i32::MAX, INVALID_NODE_ID)`.
#[inline]
fn scalar_cost_calculation<C: ConnectionCost>(
    prev_nodes: &[(NodeId, i32, u16)],
    conn_cost: &C,
    left_id: u16,
    word_cost: i32,
    space_penalty: i32,
) -> (i32, NodeId) {
    let seed = (i32::MAX, crate::lattice::INVALID_NODE_ID);
    prev_nodes
        .iter()
        .filter(|&&(_, accumulated, _)| accumulated != i32::MAX)
        .fold(seed, |best, &(candidate_id, accumulated, right)| {
            let edge = conn_cost.cost(right, left_id);
            let total = saturating_add_chain(accumulated, edge, word_cost, space_penalty);
            if total < best.0 {
                (total, candidate_id)
            } else {
                best
            }
        })
}
/// Computes `a + b + c + d` from left to right, saturating at every step.
///
/// Because each step clamps, the result can differ from a single
/// mathematically exact sum. For example, `MAX + 1 - 5` gives `MAX - 5`.
#[inline(always)]
fn saturating_add_chain(a: i32, b: i32, c: i32, d: i32) -> i32 {
    [b, c, d].into_iter().fold(a, i32::saturating_add)
}
/// Fetches `SIMD_LANES` connection costs in one call through
/// `SimdMatrix::batch_get_8`, pairing every entry of `right_ids` with the
/// same `left_id`.
#[cfg(feature = "simd-dict")]
#[inline]
pub fn batch_connection_cost<M>(
    matrix: &M,
    right_ids: &[u16; SIMD_LANES],
    left_id: u16,
) -> [i32; SIMD_LANES]
where
    M: mecab_ko_dict::matrix::simd::SimdMatrix,
{
    matrix.batch_get_8(right_ids, &[left_id; SIMD_LANES])
}
/// Runs one forward (Viterbi) step at `pos`.
///
/// For every node starting at `pos`, this picks the cheapest node ending at
/// `pos` as its predecessor. It then stores the resulting `total_cost` and
/// `prev_node_id` on the node. Nodes that `node_mut` cannot find are left
/// untouched.
pub fn simd_forward_pass_position<C: ConnectionCost>(
    lattice: &mut Lattice,
    conn_cost: &C,
    space_penalty: &SpacePenalty,
    pos: usize,
) {
    let targets: Vec<NodeId> = lattice
        .nodes_starting_at(pos)
        .map(|node| node.id)
        .collect();

    for target in targets {
        // The candidate list is rebuilt for each target, so costs written for
        // an earlier target are visible here. NOTE(review): if a node can
        // never both start and end at `pos`, this could be hoisted out of
        // the loop; confirm against the lattice construction.
        let candidates: Vec<(NodeId, i32, u16)> = lattice
            .nodes_ending_at(pos)
            .map(|node| (node.id, node.total_cost, node.right_id))
            .collect();

        let (cost, predecessor) =
            simd_update_node_cost(lattice, conn_cost, target, &candidates, space_penalty);

        let Some(node) = lattice.node_mut(target) else {
            continue;
        };
        node.total_cost = cost;
        node.prev_node_id = predecessor;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::viterbi::ZeroConnectionCost;

    #[test]
    fn test_simd_calculate_totals() {
        let prev = [100, 200, 300, 400, 500, 600, 700, 800];
        let conn = [10, 20, 30, 40, 50, 60, 70, 80];
        let totals = simd_calculate_totals(&prev, &conn, 1000, 0);
        assert_eq!(totals[0], 1110);
        assert_eq!(totals[7], 1880);
    }

    #[test]
    fn test_find_min_with_index() {
        let lanes = [500, 300, 800, 100, 600, 200, 700, 400];
        let (smallest, lane) = find_min_with_index(&lanes);
        assert_eq!(smallest, 100);
        assert_eq!(lane, 3);
    }

    #[test]
    fn test_saturating_add_simd() {
        let lhs = i32x8::from_array([i32::MAX - 10, 100, 200, 300, 400, 500, 600, 700]);
        let rhs = i32x8::from_array([20, 50, 60, 70, 80, 90, 100, 110]);
        let sums = saturating_add_simd(lhs, rhs).to_array();
        assert_eq!(sums[0], i32::MAX);
        assert_eq!(sums[1], 150);
        assert_eq!(sums[7], 810);
    }

    #[test]
    fn test_scalar_cost_calculation() {
        let candidates = vec![(1, 100, 10), (2, 200, 20), (3, 300, 30)];
        let conn = ZeroConnectionCost;
        let (cost, prev) = scalar_cost_calculation(&candidates, &conn, 5, 1000, 0);
        assert_eq!(cost, 1100);
        assert_eq!(prev, 1);
    }

    #[test]
    fn test_simd_batch_cost_calculation() {
        let candidates: Vec<(NodeId, i32, u16)> = (0..16)
            .map(|i| (i as NodeId, (i as i32) * 100, i as u16))
            .collect();
        let conn = ZeroConnectionCost;
        let (cost, prev) = simd_batch_cost_calculation(&candidates, &conn, 5, 1000, 0);
        assert_eq!(cost, 1000);
        assert_eq!(prev, 0);
    }

    #[test]
    fn test_saturating_add_chain() {
        assert_eq!(saturating_add_chain(100, 200, 300, 400), 1000);
        assert_eq!(saturating_add_chain(i32::MAX, 1, 0, 0), i32::MAX);
        assert_eq!(saturating_add_chain(i32::MAX - 100, 50, 50, 50), i32::MAX);
    }

    #[test]
    fn test_simd_overflow_handling() {
        let sums = saturating_add_simd(i32x8::splat(i32::MAX), i32x8::splat(1)).to_array();
        assert_eq!(sums, [i32::MAX; 8]);
    }

    #[test]
    fn test_simd_underflow_handling() {
        let sums = saturating_add_simd(i32x8::splat(i32::MIN), i32x8::splat(-1)).to_array();
        assert_eq!(sums, [i32::MIN; 8]);
    }
}