use crate::lattice::{Lattice, Node, NodeId, NodeType, INVALID_NODE_ID};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::rc::Rc;
#[cfg(feature = "simd")]
pub mod simd;
#[cfg(feature = "simd")]
pub use simd::{simd_forward_pass_position, simd_update_node_cost};
/// Sums four costs strictly left to right, clamping to `i32::MIN`/`i32::MAX`
/// after every step instead of wrapping on overflow.
#[inline(always)]
const fn saturating_add_chain(a: i32, b: i32, c: i32, d: i32) -> i32 {
    let first_two = i32::saturating_add(a, b);
    let first_three = i32::saturating_add(first_two, c);
    i32::saturating_add(first_three, d)
}
/// Cost of joining two adjacent lattice nodes.
///
/// The Viterbi searcher in this module calls `cost(prev.right_id, node.left_id)`:
/// `right_id` is the right context id of the earlier node and `left_id` is the
/// left context id of the later node. The returned value is added (with
/// saturation) to the running path cost.
pub trait ConnectionCost {
    /// Transition cost from a node with right context `right_id` to a node
    /// with left context `left_id`.
    fn cost(&self, right_id: u16, left_id: u16) -> i32;
}
/// Connection model that charges nothing for any transition; useful when only
/// word costs should decide the path.
#[derive(Debug, Clone, Default)]
pub struct ZeroConnectionCost;

impl ConnectionCost for ZeroConnectionCost {
    /// Always zero, whatever the context ids are.
    #[inline(always)]
    fn cost(&self, _: u16, _: u16) -> i32 {
        i32::default()
    }
}
/// Connection model that returns one constant for every transition.
#[derive(Debug, Clone)]
pub struct FixedConnectionCost {
    /// Value returned for every `(right_id, left_id)` pair.
    pub default_cost: i32,
}

impl FixedConnectionCost {
    /// Builds a model that charges `cost` for every transition.
    #[must_use]
    pub const fn new(cost: i32) -> Self {
        let default_cost = cost;
        Self { default_cost }
    }
}

impl ConnectionCost for FixedConnectionCost {
    /// Ignores both ids and returns the configured constant.
    #[inline(always)]
    fn cost(&self, _: u16, _: u16) -> i32 {
        let Self { default_cost } = *self;
        default_cost
    }
}
/// Lets any dictionary connection matrix be passed directly to the searchers.
///
/// The lookup forwards to `Matrix::get(right_id, left_id)` with the arguments
/// in the same order they were received.
impl<T: mecab_ko_dict::Matrix> ConnectionCost for T {
    #[inline(always)]
    fn cost(&self, right_id: u16, left_id: u16) -> i32 {
        self.get(right_id, left_id)
    }
}
/// Extra cost applied to tokens that are preceded by whitespace, keyed by the
/// token's left context id.
///
/// Entries are kept sorted by id with at most one entry per id, so [`get`]
/// is a binary search with a well-defined answer.
///
/// [`get`]: SpacePenalty::get
#[derive(Debug, Clone, Default)]
pub struct SpacePenalty {
    /// `(left_id, penalty)` pairs, sorted by `left_id`, ids unique.
    penalties: Vec<(u16, i32)>,
}

impl SpacePenalty {
    /// Creates an empty table (every lookup returns 0).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Built-in table: left ids `1700..1760` and `1780..1810` each get a
    /// penalty of 6000.
    #[must_use]
    pub fn korean_default() -> Self {
        // Both ranges are ascending and disjoint, so the result is already
        // sorted and free of duplicates.
        let penalties = (1700u16..1760)
            .chain(1780..1810)
            .map(|id| (id, 6000))
            .collect();
        Self { penalties }
    }

    /// Parses a `"id,penalty;id,penalty;..."` configuration string.
    ///
    /// Whitespace around entries and numbers is ignored. Entries that are not
    /// exactly a `u16` id and an `i32` penalty separated by one comma are
    /// skipped. If an id appears more than once, the last entry wins.
    #[must_use]
    pub fn from_dicrc(config: &str) -> Self {
        let mut table = Self::new();
        for (id, penalty) in config.split(';').filter_map(Self::parse_entry) {
            table.add(id, penalty);
        }
        table
    }

    /// Parses one `"id,penalty"` entry, or `None` if it is malformed.
    fn parse_entry(entry: &str) -> Option<(u16, i32)> {
        let (id, penalty) = entry.split_once(',')?;
        let id = id.trim().parse::<u16>().ok()?;
        // A second comma ends up in `penalty` and makes this parse fail.
        let penalty = penalty.trim().parse::<i32>().ok()?;
        Some((id, penalty))
    }

    /// Sets the penalty for `left_id`, replacing any existing entry for it.
    pub fn add(&mut self, left_id: u16, penalty: i32) {
        match self.penalties.binary_search_by_key(&left_id, |&(id, _)| id) {
            Ok(idx) => self.penalties[idx].1 = penalty,
            Err(idx) => self.penalties.insert(idx, (left_id, penalty)),
        }
    }

    /// Penalty for `left_id`, or 0 if the id has no entry.
    #[must_use]
    #[inline]
    pub fn get(&self, left_id: u16) -> i32 {
        self.penalties
            .binary_search_by_key(&left_id, |&(id, _)| id)
            .map_or(0, |idx| self.penalties[idx].1)
    }

    /// `true` when no ids have a penalty.
    #[must_use]
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.penalties.is_empty()
    }

    /// Number of distinct ids with a penalty.
    #[must_use]
    #[inline]
    pub fn len(&self) -> usize {
        self.penalties.len()
    }
}
/// 1-best Viterbi search over a [`Lattice`].
///
/// The forward pass writes into each node it visits the cheapest cumulative
/// cost over its predecessors (`total_cost`) and the predecessor achieving it
/// (`prev_node_id`). The backward pass then follows those back-pointers from
/// EOS.
#[derive(Debug, Clone)]
pub struct ViterbiSearcher {
    /// Extra cost for nodes flagged `has_space_before`, looked up by the
    /// node's `left_id`.
    pub space_penalty: SpacePenalty,
}

impl Default for ViterbiSearcher {
    fn default() -> Self {
        Self::new()
    }
}

impl ViterbiSearcher {
    /// Creates a searcher with an empty space-penalty table.
    #[must_use]
    pub fn new() -> Self {
        Self {
            space_penalty: SpacePenalty::default(),
        }
    }

    /// Replaces the space-penalty table, builder style.
    #[must_use]
    pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
        self.space_penalty = penalty;
        self
    }

    /// Finds the cheapest path through `lattice`.
    ///
    /// Overwrites `total_cost` / `prev_node_id` on the lattice's nodes. Returns
    /// the ids of the best path in sentence order, excluding BOS and EOS. The
    /// result is empty when EOS is left without a predecessor (no path).
    pub fn search<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) -> Vec<NodeId> {
        self.forward_pass(lattice, conn_cost);
        Self::backward_pass(lattice)
    }

    /// Visits positions `0..=char_len` in order and relaxes every node that
    /// starts at a position against all nodes that end there.
    ///
    /// Assumes non-BOS nodes have positive length, so every node ending at
    /// `pos` was already finalised at an earlier position.
    fn forward_pass<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) {
        // Buffers are reused across positions to avoid per-position allocation.
        let mut starting_ids: Vec<NodeId> = Vec::new();
        let mut ending_nodes: Vec<(NodeId, i32, u16)> = Vec::new();
        for pos in 0..=lattice.char_len() {
            starting_ids.clear();
            starting_ids.extend(lattice.nodes_starting_at(pos).map(|n| n.id));
            Self::snapshot_endings(lattice, pos, &mut ending_nodes);
            for &node_id in &starting_ids {
                self.update_node_cost_with_endings(lattice, conn_cost, node_id, &ending_nodes);
            }
        }
    }

    /// Replaces `out` with `(id, total_cost, right_id)` for each node ending at `pos`.
    ///
    /// Copying the values out releases the borrow on the lattice so nodes can
    /// be updated while the snapshot is consulted.
    fn snapshot_endings(lattice: &Lattice, pos: usize, out: &mut Vec<(NodeId, i32, u16)>) {
        out.clear();
        out.extend(
            lattice
                .nodes_ending_at(pos)
                .map(|n| (n.id, n.total_cost, n.right_id)),
        );
    }

    /// Relaxes `node_id` against `ending_nodes` (`(id, total_cost, right_id)`
    /// triples) and stores the best cost and predecessor on the node.
    ///
    /// With the `simd` feature and at least 8 candidates, the work is delegated
    /// to `simd::simd_update_node_cost`. Otherwise:
    /// - predecessors whose cost is `i32::MAX` (unreached) are skipped;
    /// - if none remain, the node gets `i32::MAX` / `INVALID_NODE_ID`;
    /// - ties keep the earliest predecessor in `ending_nodes`.
    #[inline]
    fn update_node_cost_with_endings<C: ConnectionCost>(
        &self,
        lattice: &mut Lattice,
        conn_cost: &C,
        node_id: NodeId,
        ending_nodes: &[(NodeId, i32, u16)],
    ) {
        #[cfg(feature = "simd")]
        if ending_nodes.len() >= 8 {
            let (best_cost, best_prev) = simd::simd_update_node_cost(
                lattice,
                conn_cost,
                node_id,
                ending_nodes,
                &self.space_penalty,
            );
            if let Some(node) = lattice.node_mut(node_id) {
                node.total_cost = best_cost;
                node.prev_node_id = best_prev;
            }
            return;
        }
        let (left_id, word_cost, has_space) = {
            let Some(node) = lattice.node(node_id) else {
                return;
            };
            (node.left_id, node.word_cost, node.has_space_before)
        };
        // The penalty depends only on the node itself, so compute it once.
        let space_penalty = self.space_penalty_for(left_id, has_space);
        let mut best_cost = i32::MAX;
        let mut best_prev = INVALID_NODE_ID;
        for &(prev_id, prev_cost, prev_right_id) in ending_nodes {
            if prev_cost == i32::MAX {
                continue;
            }
            let connection = conn_cost.cost(prev_right_id, left_id);
            let total = saturating_add_chain(prev_cost, connection, word_cost, space_penalty);
            if total < best_cost {
                best_cost = total;
                best_prev = prev_id;
            }
        }
        if let Some(node) = lattice.node_mut(node_id) {
            node.total_cost = best_cost;
            node.prev_node_id = best_prev;
        }
    }

    /// Space penalty for a node with `left_id`, or 0 if it does not follow whitespace.
    fn space_penalty_for(&self, left_id: u16, has_space_before: bool) -> i32 {
        if has_space_before {
            self.space_penalty.get(left_id)
        } else {
            0
        }
    }

    /// Test-only convenience: relaxes `node_id` against the nodes ending at `pos`.
    ///
    /// Shares the relaxation logic with the forward pass instead of
    /// duplicating it.
    #[cfg(test)]
    // Not every test build calls this helper; it exists for direct unit testing.
    #[allow(dead_code)]
    fn update_node_cost<C: ConnectionCost>(
        &self,
        lattice: &mut Lattice,
        conn_cost: &C,
        node_id: NodeId,
        pos: usize,
    ) {
        let mut ending_nodes = Vec::new();
        Self::snapshot_endings(lattice, pos, &mut ending_nodes);
        self.update_node_cost_with_endings(lattice, conn_cost, node_id, &ending_nodes);
    }

    /// Follows `prev_node_id` from EOS and returns the visited ids in sentence
    /// order, excluding BOS and EOS.
    ///
    /// Stops at `INVALID_NODE_ID` or at an id the lattice cannot resolve.
    fn backward_pass(lattice: &Lattice) -> Vec<NodeId> {
        let mut path = Vec::new();
        let mut current_id = lattice.eos().id;
        while current_id != INVALID_NODE_ID {
            let Some(node) = lattice.node(current_id) else {
                break;
            };
            if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
                path.push(current_id);
            }
            current_id = node.prev_node_id;
        }
        path.reverse();
        path
    }

    /// Cost stored on EOS by the last forward pass (`i32::MAX` if unreached).
    #[must_use]
    pub fn get_best_cost(&self, lattice: &Lattice) -> i32 {
        lattice.eos().total_cost
    }

    /// `true` when EOS has a finite cost and a predecessor after the forward pass.
    #[must_use]
    pub fn has_valid_path(&self, lattice: &Lattice) -> bool {
        lattice.eos().total_cost != i32::MAX && lattice.eos().prev_node_id != INVALID_NODE_ID
    }
}
/// Immutable singly linked list of node ids.
///
/// The `Rc` links let several partial paths share a common tail without
/// copying it.
#[derive(Debug, Clone)]
struct PathNode {
    node_id: NodeId,
    prev: Option<Rc<Self>>,
}

impl PathNode {
    /// Prepends `node_id` to the chain `prev`.
    const fn new(node_id: NodeId, prev: Option<Rc<Self>>) -> Self {
        Self { node_id, prev }
    }

    /// Collects the chain as `[end of the prev chain, ..., self]`.
    ///
    /// Walks the `prev` links starting from `self`, then reverses, so the last
    /// element is always `self.node_id`.
    fn to_vec(&self) -> Vec<NodeId> {
        let mut path = Vec::new();
        let mut current = Some(self);
        while let Some(node) = current {
            path.push(node.node_id);
            current = node.prev.as_deref();
        }
        path.reverse();
        path
    }
}
/// Partial hypothesis in the backward best-first search of [`NbestSearcher`].
///
/// Equality and ordering look only at `cost`. The ordering is reversed, so a
/// [`BinaryHeap`] of candidates pops the cheapest one first.
#[derive(Debug, Clone)]
struct NbestCandidate {
    /// Next node to expand: the BOS-side edge of the already-fixed suffix.
    node_id: NodeId,
    /// Heap priority: the forward Viterbi cost of `node_id` plus
    /// `backward_cost`. Because the forward cost is the exact best prefix
    /// cost, this is the cost of the cheapest complete path ending with this
    /// hypothesis' suffix.
    cost: i32,
    /// Cost of everything after `node_id` up to and including EOS:
    /// connection costs, word costs and space penalties.
    backward_cost: i32,
    /// Non-BOS/EOS nodes already fixed after `node_id`. The head is the one
    /// nearest BOS, and each `prev` link moves one token towards EOS.
    path: Option<Rc<PathNode>>,
}

impl Eq for NbestCandidate {}

impl PartialEq for NbestCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cost == other.cost
    }
}

impl Ord for NbestCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed: lower cost means higher priority in the max-heap.
        other.cost.cmp(&self.cost)
    }
}

impl PartialOrd for NbestCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Enumerates the cheapest complete paths through a lattice in
/// non-decreasing cost order.
///
/// After the Viterbi forward pass, paths are expanded backwards from EOS by a
/// best-first (A*) search. The search uses each node's forward cost as an
/// exact heuristic.
#[derive(Debug, Clone)]
pub struct NbestSearcher {
    viterbi: ViterbiSearcher,
    max_results: usize,
}

impl NbestSearcher {
    /// Creates a searcher that returns at most `n` paths.
    #[must_use]
    pub fn new(n: usize) -> Self {
        Self {
            viterbi: ViterbiSearcher::new(),
            max_results: n,
        }
    }

    /// Replaces the space-penalty table used by both passes, builder style.
    #[must_use]
    pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
        self.viterbi.space_penalty = penalty;
        self
    }

    /// Returns up to `max_results` `(path, cost)` pairs, cheapest first.
    ///
    /// Each path lists node ids in sentence order, without BOS/EOS. The
    /// forward pass always runs, overwriting node costs and back-pointers.
    /// The result is empty when no BOS→EOS path exists or `max_results` is 0.
    pub fn search<C: ConnectionCost>(
        &self,
        lattice: &mut Lattice,
        conn_cost: &C,
    ) -> Vec<(Vec<NodeId>, i32)> {
        self.viterbi.forward_pass(lattice, conn_cost);
        if self.max_results == 0 || !self.viterbi.has_valid_path(lattice) {
            return Vec::new();
        }
        if self.max_results == 1 {
            // The single best path is exactly the Viterbi back-pointer chain.
            let path = ViterbiSearcher::backward_pass(lattice);
            let cost = self.viterbi.get_best_cost(lattice);
            return vec![(path, cost)];
        }
        self.search_nbest(lattice, conn_cost)
    }

    /// Backward best-first enumeration of complete paths, starting at EOS.
    ///
    /// Expects `forward_pass` to have run on `lattice` with the same
    /// `conn_cost` and space penalties.
    fn search_nbest<C: ConnectionCost>(
        &self,
        lattice: &Lattice,
        conn_cost: &C,
    ) -> Vec<(Vec<NodeId>, i32)> {
        let mut results: Vec<(Vec<NodeId>, i32)> = Vec::new();
        let eos = lattice.eos();
        if eos.total_cost == i32::MAX {
            return results;
        }
        let start_positions = Self::start_positions(lattice);
        let mut heap: BinaryHeap<NbestCandidate> = BinaryHeap::new();
        heap.push(NbestCandidate {
            node_id: eos.id,
            cost: eos.total_cost,
            backward_cost: 0,
            path: None,
        });

        while let Some(candidate) = heap.pop() {
            let Some(node) = lattice.node(candidate.node_id) else {
                continue;
            };

            if node.node_type == NodeType::Bos {
                // Chain links run from the BOS-side head towards EOS, and
                // `to_vec` lists the far end first; flip into sentence order.
                let mut path = candidate.path.map_or_else(Vec::new, |head| head.to_vec());
                path.reverse();
                results.push((path, candidate.cost));
                if results.len() >= self.max_results {
                    break;
                }
                continue;
            }

            // Only nodes registered as starting somewhere can be expanded.
            let Some(&start) = start_positions.get(&candidate.node_id) else {
                continue;
            };

            let path = if node.node_type == NodeType::Eos {
                candidate.path
            } else {
                Some(Rc::new(PathNode::new(candidate.node_id, candidate.path)))
            };

            // What `node` itself adds once connected, mirroring the forward
            // pass: its word cost plus any space penalty.
            let space_penalty = if node.has_space_before {
                self.viterbi.space_penalty.get(node.left_id)
            } else {
                0
            };
            let own_cost = node.word_cost.saturating_add(space_penalty);

            for prev in lattice.nodes_ending_at(start) {
                if prev.total_cost == i32::MAX {
                    // Not reachable from BOS.
                    continue;
                }
                let backward_cost = conn_cost
                    .cost(prev.right_id, node.left_id)
                    .saturating_add(own_cost)
                    .saturating_add(candidate.backward_cost);
                heap.push(NbestCandidate {
                    node_id: prev.id,
                    cost: prev.total_cost.saturating_add(backward_cost),
                    backward_cost,
                    path: path.clone(),
                });
            }
        }
        results
    }

    /// Maps the id of every node yielded by `nodes_starting_at(pos)`, for
    /// `pos` in `0..=char_len`, to that start position.
    fn start_positions(lattice: &Lattice) -> std::collections::HashMap<NodeId, usize> {
        (0..=lattice.char_len())
            .flat_map(move |pos| lattice.nodes_starting_at(pos).map(move |n| (n.id, pos)))
            .collect()
    }
}
/// A search result bundled with the lattice it refers to: a path of node ids
/// and its total cost.
pub struct ViterbiResult<'a> {
    lattice: &'a Lattice,
    path: Vec<NodeId>,
    total_cost: i32,
}

impl<'a> ViterbiResult<'a> {
    /// Wraps `path` (node ids in sentence order) and its cost.
    #[must_use]
    pub const fn new(lattice: &'a Lattice, path: Vec<NodeId>, total_cost: i32) -> Self {
        Self {
            lattice,
            path,
            total_cost,
        }
    }

    /// Nodes of the path, in order. Ids the lattice cannot resolve are skipped.
    pub fn nodes(&self) -> impl Iterator<Item = &'a Node> + '_ {
        self.path.iter().filter_map(|&id| self.lattice.node(id))
    }

    /// Total cost passed to [`ViterbiResult::new`].
    #[must_use]
    pub const fn cost(&self) -> i32 {
        self.total_cost
    }

    /// Number of ids in the path. This includes any id that [`nodes`] would
    /// skip.
    ///
    /// [`nodes`]: ViterbiResult::nodes
    #[must_use]
    pub fn len(&self) -> usize {
        self.path.len()
    }

    /// `true` when the path holds no ids.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.path.is_empty()
    }

    /// Surface strings of the resolvable nodes, in path order.
    ///
    /// The strings borrow from the lattice (`'a`), not from `self`, so they
    /// remain usable after this result is dropped.
    #[must_use]
    pub fn surfaces(&self) -> Vec<&'a str> {
        self.nodes().map(|n| n.surface.as_ref()).collect()
    }
}
// Unit tests for the space-penalty table, the Viterbi searcher and the N-best
// searcher, using small hand-built lattices.
#[cfg(test)]
// `unwrap` is acceptable in tests: a failed lookup should simply fail the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::lattice::NodeBuilder;

    /// Connection-cost table for tests: explicit `(right_id, left_id)` entries,
    /// falling back to `default` for every other pair.
    struct TestConnectionCost {
        costs: std::collections::HashMap<(u16, u16), i32>,
        default: i32,
    }

    impl TestConnectionCost {
        fn new(default: i32) -> Self {
            Self {
                costs: std::collections::HashMap::new(),
                default,
            }
        }

        // Overrides the cost of one specific transition.
        fn set(&mut self, right_id: u16, left_id: u16, cost: i32) {
            self.costs.insert((right_id, left_id), cost);
        }
    }

    impl ConnectionCost for TestConnectionCost {
        fn cost(&self, right_id: u16, left_id: u16) -> i32 {
            self.costs
                .get(&(right_id, left_id))
                .copied()
                .unwrap_or(self.default)
        }
    }

    #[test]
    fn test_space_penalty_default() {
        let penalty = SpacePenalty::default();
        assert!(penalty.is_empty());
        assert_eq!(penalty.get(100), 0);
    }

    #[test]
    fn test_space_penalty_from_dicrc() {
        let penalty = SpacePenalty::from_dicrc("100,5000;200,3000;300,1000");
        assert_eq!(penalty.len(), 3);
        assert_eq!(penalty.get(100), 5000);
        assert_eq!(penalty.get(200), 3000);
        assert_eq!(penalty.get(300), 1000);
        // Ids that are not listed have no penalty.
        assert_eq!(penalty.get(999), 0);
    }

    #[test]
    fn test_space_penalty_korean_default() {
        let penalty = SpacePenalty::korean_default();
        assert!(!penalty.is_empty());
        assert!(penalty.get(1785) > 0);
    }

    // Single segmentation "A" + "B": path cost is the sum of word costs (100 + 200).
    #[test]
    fn test_viterbi_simple_path() {
        let mut lattice = Lattice::new("AB");
        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );
        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();
        let path = searcher.search(&mut lattice, &conn_cost);
        assert_eq!(path.len(), 2);
        let first = lattice.node(path[0]).unwrap();
        assert_eq!(first.surface.as_ref(), "A");
        let second = lattice.node(path[1]).unwrap();
        assert_eq!(second.surface.as_ref(), "B");
        let total_cost = searcher.get_best_cost(&lattice);
        assert_eq!(total_cost, 300);
    }

    // "A" + "B" (100 + 200 = 300) beats the single token "AB" (500).
    #[test]
    fn test_viterbi_choose_best_path() {
        let mut lattice = Lattice::new("AB");
        lattice.add_node(
            NodeBuilder::new("AB", 0, 2)
                .left_id(1)
                .right_id(1)
                .word_cost(500),
        );
        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(2)
                .right_id(2)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(3)
                .right_id(3)
                .word_cost(200),
        );
        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();
        let path = searcher.search(&mut lattice, &conn_cost);
        assert_eq!(path.len(), 2);
        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "A");
        assert_eq!(lattice.node(path[1]).unwrap().surface.as_ref(), "B");
    }

    // The 2 -> 3 transition costs 500, so "A" + "B" totals 700 and the
    // single token "AB" (300) wins.
    #[test]
    fn test_viterbi_with_connection_cost() {
        let mut lattice = Lattice::new("AB");
        lattice.add_node(
            NodeBuilder::new("AB", 0, 2)
                .left_id(1)
                .right_id(1)
                .word_cost(300),
        );
        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(2)
                .right_id(2)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(3)
                .right_id(3)
                .word_cost(100),
        );
        let mut conn_cost = TestConnectionCost::new(0);
        conn_cost.set(2, 3, 500);
        let searcher = ViterbiSearcher::new();
        let path = searcher.search(&mut lattice, &conn_cost);
        assert_eq!(path.len(), 1);
        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "AB");
    }

    // "B" follows a space and its left id 100 carries a 1000 penalty, so
    // "A" + "B" totals 1200 and "AB" (500) wins. Node positions suggest the
    // lattice drops the space from its character positions — confirm in Lattice.
    #[test]
    fn test_viterbi_with_space_penalty() {
        let mut lattice = Lattice::new("A B");
        lattice.add_node(
            NodeBuilder::new("AB", 0, 2)
                .left_id(1)
                .right_id(1)
                .word_cost(500),
        );
        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(2)
                .right_id(2)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(100)
                .right_id(3)
                .word_cost(100)
                .has_space_before(true),
        );
        let mut penalty = SpacePenalty::new();
        penalty.add(100, 1000);
        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new().with_space_penalty(penalty);
        let path = searcher.search(&mut lattice, &conn_cost);
        assert_eq!(path.len(), 1);
        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "AB");
    }

    // "아버지" + "가" (1000 + 500) beats "아버" + "지가" (3000 + 3000).
    #[test]
    fn test_viterbi_korean_example() {
        let mut lattice = Lattice::new("아버지가");
        lattice.add_node(
            NodeBuilder::new("아버지", 0, 3)
                .left_id(1)
                .right_id(1)
                .word_cost(1000),
        );
        lattice.add_node(
            NodeBuilder::new("가", 3, 4)
                .left_id(100)
                .right_id(100)
                .word_cost(500),
        );
        lattice.add_node(
            NodeBuilder::new("아버", 0, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(3000),
        );
        lattice.add_node(
            NodeBuilder::new("지가", 2, 4)
                .left_id(3)
                .right_id(3)
                .word_cost(3000),
        );
        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();
        let path = searcher.search(&mut lattice, &conn_cost);
        assert_eq!(path.len(), 2);
        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "아버지");
        assert_eq!(lattice.node(path[1]).unwrap().surface.as_ref(), "가");
    }

    // Empty input: the path holds no tokens.
    #[test]
    fn test_viterbi_empty_lattice() {
        let mut lattice = Lattice::new("");
        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();
        let path = searcher.search(&mut lattice, &conn_cost);
        assert!(path.is_empty());
    }

    // Only "A" (0..1) exists for "ABC", so nothing reaches EOS.
    #[test]
    fn test_viterbi_no_path() {
        let mut lattice = Lattice::new("ABC");
        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();
        let path = searcher.search(&mut lattice, &conn_cost);
        assert!(!searcher.has_valid_path(&lattice));
        assert!(path.is_empty());
    }

    // N-best with n = 1 takes the Viterbi shortcut and reports the best cost.
    #[test]
    fn test_nbest_single() {
        let mut lattice = Lattice::new("AB");
        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );
        let conn_cost = ZeroConnectionCost;
        let searcher = NbestSearcher::new(1);
        let results = searcher.search(&mut lattice, &conn_cost);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, 300);
    }

    #[test]
    fn test_viterbi_result_helper() {
        let mut lattice = Lattice::new("AB");
        let _id1 = lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        let _id2 = lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );
        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();
        let path = searcher.search(&mut lattice, &conn_cost);
        let cost = searcher.get_best_cost(&lattice);
        let result = ViterbiResult::new(&lattice, path, cost);
        assert_eq!(result.len(), 2);
        assert_eq!(result.cost(), 300);
        assert_eq!(result.surfaces(), vec!["A", "B"]);
    }

    // Expected 780 = (0->1: 100) + 500 + (1->2: 50) + 100 + (2->0: 30),
    // presumably because BOS/EOS use context id 0 — confirm in Lattice.
    #[test]
    fn test_viterbi_with_dense_matrix() {
        use mecab_ko_dict::DenseMatrix;
        let mut matrix = DenseMatrix::new(3, 3, 0);
        matrix.set(0, 1, 100);
        matrix.set(1, 2, 50);
        matrix.set(2, 0, 30);
        matrix.set(0, 2, 5000);
        matrix.set(1, 0, 200);
        let mut lattice = Lattice::new("책을");
        lattice.add_node(
            NodeBuilder::new("책", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(500),
        );
        lattice.add_node(
            NodeBuilder::new("을", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(100),
        );
        let searcher = ViterbiSearcher::new();
        let path = searcher.search(&mut lattice, &matrix);
        assert!(!path.is_empty());
        let result = ViterbiResult::new(&lattice, path, searcher.get_best_cost(&lattice));
        assert_eq!(result.surfaces(), vec!["책", "을"]);
        assert_eq!(result.cost(), 780);
    }
}