use crate::lattice::{Lattice, Node, NodeId, NodeType, INVALID_NODE_ID};
use crate::viterbi::{ConnectionCost, SpacePenalty};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// One complete path through the lattice together with its cost and rank.
#[derive(Debug, Clone)]
pub struct NbestPath {
/// Ids of the lattice nodes that make up the path. Paths built by
/// `ImprovedNbestSearcher` are ordered from the BOS side to the EOS side and
/// do not contain the BOS/EOS nodes themselves.
pub node_ids: Vec<NodeId>,
/// Accumulated cost of the whole path.
pub total_cost: i32,
/// Rank supplied by the creator of the path; `ImprovedNbestSearcher` assigns
/// zero-based ranks in the order paths are emitted.
pub rank: usize,
}
impl NbestPath {
    /// Bundles node ids, their accumulated cost and a rank into a path.
    #[must_use]
    pub const fn new(node_ids: Vec<NodeId>, total_cost: i32, rank: usize) -> Self {
        Self {
            node_ids,
            total_cost,
            rank,
        }
    }

    /// Returns `true` when the path holds no node ids.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.node_ids.is_empty()
    }

    /// Number of node ids stored in the path.
    #[must_use]
    pub fn len(&self) -> usize {
        self.node_ids.len()
    }

    /// Accumulated cost of the path (same value as the `total_cost` field).
    #[must_use]
    pub const fn cost(&self) -> i32 {
        self.total_cost
    }

    /// Resolves the stored ids against `lattice`, in path order.
    ///
    /// Ids for which `lattice.node` returns `None` are silently skipped.
    pub fn nodes<'a>(&'a self, lattice: &'a Lattice) -> impl Iterator<Item = &'a Node> + 'a {
        self.node_ids
            .iter()
            .copied()
            .filter_map(move |node_id| lattice.node(node_id))
    }

    /// Surface string of every resolvable node, in path order.
    #[must_use]
    pub fn surfaces<'a>(&'a self, lattice: &'a Lattice) -> Vec<&'a str> {
        let mut surfaces = Vec::new();
        for node in self.nodes(lattice) {
            surfaces.push(node.surface.as_ref());
        }
        surfaces
    }

    /// First comma-separated field of every resolvable node's `feature`
    /// string, in path order.
    #[must_use]
    pub fn pos_tags<'a>(&'a self, lattice: &'a Lattice) -> Vec<&'a str> {
        self.nodes(lattice)
            .map(|node| {
                let mut fields = node.feature.split(',');
                fields.next().unwrap_or_default()
            })
            .collect()
    }
}
/// Ordered collection of `NbestPath` values; the default value is empty.
#[derive(Debug, Clone, Default)]
pub struct NbestResult {
// Paths in the order they were handed to `new`; `best()` returns the first.
paths: Vec<NbestPath>,
}
impl NbestResult {
    /// Wraps `paths` exactly as given; no sorting or re-ranking is performed.
    #[must_use]
    pub const fn new(paths: Vec<NbestPath>) -> Self {
        Self { paths }
    }

    /// Returns `true` when no path is stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.paths.is_empty()
    }

    /// Number of stored paths.
    #[must_use]
    pub fn len(&self) -> usize {
        self.paths.len()
    }

    /// The first stored path, or `None` when the result is empty.
    #[must_use]
    pub fn best(&self) -> Option<&NbestPath> {
        self.get(0)
    }

    /// The path at position `index`, or `None` when out of range.
    #[must_use]
    pub fn get(&self, index: usize) -> Option<&NbestPath> {
        self.paths.get(index)
    }

    /// Borrowing iterator over the paths in stored order.
    pub fn iter(&self) -> impl Iterator<Item = &NbestPath> {
        self.paths.iter()
    }

    /// Consumes the result and hands back the underlying vector.
    #[must_use]
    pub fn into_paths(self) -> Vec<NbestPath> {
        let Self { paths } = self;
        paths
    }

    /// Copies every path into a `(node_ids, total_cost)` tuple, keeping order.
    #[must_use]
    pub fn to_pairs(&self) -> Vec<(Vec<NodeId>, i32)> {
        let mut pairs = Vec::with_capacity(self.paths.len());
        for path in &self.paths {
            pairs.push((path.node_ids.clone(), path.total_cost));
        }
        pairs
    }
}
/// Consuming iteration yields the stored paths by value, in stored order.
impl IntoIterator for NbestResult {
    type Item = NbestPath;
    type IntoIter = std::vec::IntoIter<NbestPath>;

    fn into_iter(self) -> Self::IntoIter {
        let Self { paths } = self;
        paths.into_iter()
    }
}
/// One way of reaching a lattice node, recorded during the forward pass.
///
/// Equality and ordering consider `cost` only, and the ordering is reversed so
/// that `BinaryHeap` (a max-heap) pops the cheapest candidate first.
#[derive(Debug, Clone)]
struct NodeCandidate {
    /// Accumulated cost of reaching the node through this candidate.
    cost: i32,
    /// Predecessor node id (`INVALID_NODE_ID` for the seed entry on BOS).
    prev_node_id: NodeId,
    /// Index into the predecessor's own candidate list.
    prev_candidate_idx: usize,
}

impl PartialEq for NodeCandidate {
    fn eq(&self, rhs: &Self) -> bool {
        self.cmp(rhs) == Ordering::Equal
    }
}

impl Eq for NodeCandidate {}

impl PartialOrd for NodeCandidate {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(self.cmp(rhs))
    }
}

impl Ord for NodeCandidate {
    fn cmp(&self, rhs: &Self) -> Ordering {
        // Lower cost compares as "greater" so a max-heap behaves as a min-heap.
        self.cost.cmp(&rhs.cost).reverse()
    }
}
/// A partially traced path used while walking back from EOS.
///
/// Equality and ordering consider `cost` only, and the ordering is reversed so
/// that `BinaryHeap` (a max-heap) pops the cheapest entry first.
#[derive(Debug, Clone)]
struct BackwardCandidate {
    /// Node the trace currently stands on.
    node_id: NodeId,
    /// Which candidate of that node is being followed.
    candidate_idx: usize,
    /// Cost carried along with the trace.
    cost: i32,
    /// Node ids collected so far, in the order they were visited.
    path: Vec<NodeId>,
}

impl PartialEq for BackwardCandidate {
    fn eq(&self, rhs: &Self) -> bool {
        self.cmp(rhs) == Ordering::Equal
    }
}

impl Eq for BackwardCandidate {}

impl PartialOrd for BackwardCandidate {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(self.cmp(rhs))
    }
}

impl Ord for BackwardCandidate {
    fn cmp(&self, rhs: &Self) -> Ordering {
        // Lower cost compares as "greater" so a max-heap behaves as a min-heap.
        self.cost.cmp(&rhs.cost).reverse()
    }
}
/// N-best path searcher: a forward pass keeps several candidates per lattice
/// node, and a backward pass traces complete paths back from EOS.
#[derive(Debug, Clone)]
pub struct ImprovedNbestSearcher {
// Upper bound on the number of paths a search returns.
max_results: usize,
// Upper bound on the candidates retained for each node in the forward pass.
max_candidates_per_node: usize,
// Penalty source consulted for nodes flagged with `has_space_before`.
space_penalty: SpacePenalty,
}
impl ImprovedNbestSearcher {
    /// Creates a searcher that returns at most `n` paths.
    ///
    /// The per-node candidate limit starts at `2 * max(n, 2)` and the space
    /// penalty at `SpacePenalty::default()`.
    #[must_use]
    pub fn new(n: usize) -> Self {
        Self {
            max_results: n,
            // Saturate so a very large `n` (e.g. `usize::MAX` meaning
            // "everything") cannot overflow.
            max_candidates_per_node: n.max(2).saturating_mul(2),
            space_penalty: SpacePenalty::default(),
        }
    }

    /// Overrides how many candidates are retained per node in the forward pass.
    ///
    /// With `k == 0` no node other than BOS ever receives a candidate, so
    /// searches come back empty.
    #[must_use]
    pub const fn with_max_candidates(mut self, k: usize) -> Self {
        self.max_candidates_per_node = k;
        self
    }

    /// Replaces the space penalty used for nodes flagged `has_space_before`.
    #[must_use]
    pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
        self.space_penalty = penalty;
        self
    }

    /// Runs the forward and backward passes and returns up to `max_results`
    /// paths in non-decreasing cost order.
    ///
    /// The lattice is mutated: every node that receives at least one candidate
    /// gets its `total_cost` and `prev_node_id` set from its cheapest one.
    /// A lattice with two or fewer nodes yields an empty result.
    pub fn search<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) -> NbestResult {
        if lattice.node_count() <= 2 {
            return NbestResult::default();
        }
        let candidates = self.forward_pass_kbest(lattice, conn_cost);
        self.backward_pass_nbest(lattice, &candidates)
    }

    /// Forward pass: for every node, keeps the cheapest
    /// `max_candidates_per_node` ways of reaching it.
    ///
    /// Returns one candidate list per node, indexed by node id; each list is
    /// in non-decreasing cost order. Costs are combined with saturating
    /// addition, and predecessors whose cost is `i32::MAX` are not extended.
    fn forward_pass_kbest<C: ConnectionCost>(
        &self,
        lattice: &mut Lattice,
        conn_cost: &C,
    ) -> Vec<Vec<NodeCandidate>> {
        let node_count = lattice.node_count();
        let char_len = lattice.char_len();
        let mut candidates: Vec<Vec<NodeCandidate>> = vec![Vec::new(); node_count];

        // Seed entry: BOS costs nothing and has no predecessor.
        let bos_id = lattice.bos().id;
        candidates[bos_id as usize].push(NodeCandidate {
            cost: 0,
            prev_node_id: INVALID_NODE_ID,
            prev_candidate_idx: 0,
        });

        // Scratch buffers reused for every position. Ids are copied out so the
        // lattice can be borrowed mutably further down.
        let mut starting_ids: Vec<NodeId> = Vec::new();
        let mut ending_data: Vec<(NodeId, u16)> = Vec::new();

        for pos in 0..=char_len {
            starting_ids.clear();
            starting_ids.extend(lattice.nodes_starting_at(pos).map(|n| n.id));
            ending_data.clear();
            ending_data.extend(lattice.nodes_ending_at(pos).map(|n| (n.id, n.right_id)));

            for &node_id in &starting_ids {
                let Some(node) = lattice.node(node_id) else {
                    continue;
                };
                let (left_id, word_cost, has_space) =
                    (node.left_id, node.word_cost, node.has_space_before);
                let space_penalty = if has_space {
                    self.space_penalty.get(left_id)
                } else {
                    0
                };

                // Gather every (predecessor, predecessor-candidate) extension.
                let mut new_candidates: BinaryHeap<NodeCandidate> = BinaryHeap::new();
                for &(prev_id, prev_right_id) in &ending_data {
                    let prev_candidates = &candidates[prev_id as usize];
                    if prev_candidates.is_empty() {
                        continue;
                    }
                    let connection = conn_cost.cost(prev_right_id, left_id);
                    for (idx, prev_cand) in prev_candidates.iter().enumerate() {
                        if prev_cand.cost == i32::MAX {
                            continue;
                        }
                        let total = prev_cand
                            .cost
                            .saturating_add(connection)
                            .saturating_add(word_cost)
                            .saturating_add(space_penalty);
                        new_candidates.push(NodeCandidate {
                            cost: total,
                            prev_node_id: prev_id,
                            prev_candidate_idx: idx,
                        });
                    }
                }

                // Size the buffer by what is actually available rather than by
                // the configured limit, which may be arbitrarily large.
                let keep = self.max_candidates_per_node.min(new_candidates.len());
                let mut selected: Vec<NodeCandidate> = Vec::with_capacity(keep);
                selected.extend(std::iter::from_fn(|| new_candidates.pop()).take(keep));

                // Mirror the cheapest candidate onto the node itself.
                if let Some(best) = selected.first() {
                    if let Some(node) = lattice.node_mut(node_id) {
                        node.total_cost = best.cost;
                        node.prev_node_id = best.prev_node_id;
                    }
                }
                candidates[node_id as usize] = selected;
            }
        }
        candidates
    }

    /// Backward pass: turns the EOS candidates into complete paths.
    ///
    /// Every EOS candidate identifies exactly one path through the
    /// `prev_node_id` / `prev_candidate_idx` links. Traces are expanded
    /// cheapest first, so paths are emitted in non-decreasing cost order until
    /// `max_results` paths exist or no trace is left. BOS and EOS nodes are
    /// left out of the returned node id lists.
    fn backward_pass_nbest(
        &self,
        lattice: &Lattice,
        candidates: &[Vec<NodeCandidate>],
    ) -> NbestResult {
        let eos_id = lattice.eos().id;
        let eos_candidates = &candidates[eos_id as usize];
        if eos_candidates.is_empty() {
            return NbestResult::default();
        }

        // There can never be more results than EOS candidates, so do not
        // pre-allocate from the (possibly huge) configured limit.
        let mut results: Vec<NbestPath> =
            Vec::with_capacity(self.max_results.min(eos_candidates.len()));

        let mut heap: BinaryHeap<BackwardCandidate> = BinaryHeap::new();
        for (idx, cand) in eos_candidates.iter().enumerate() {
            heap.push(BackwardCandidate {
                node_id: eos_id,
                candidate_idx: idx,
                cost: cand.cost,
                path: Vec::new(),
            });
        }

        while results.len() < self.max_results {
            let Some(BackwardCandidate {
                node_id,
                candidate_idx,
                cost,
                mut path,
            }) = heap.pop()
            else {
                break;
            };

            // A dangling candidate index means the trace cannot be continued.
            let Some(cand) = candidates[node_id as usize].get(candidate_idx) else {
                continue;
            };

            // Reached the seed entry: the trace is complete. Nodes were
            // collected back to front, so flip them.
            if cand.prev_node_id == INVALID_NODE_ID {
                path.reverse();
                let rank = results.len();
                results.push(NbestPath::new(path, cost, rank));
                continue;
            }

            let Some(node) = lattice.node(node_id) else {
                continue;
            };
            if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
                path.push(node_id);
            }
            // The trace owns its path, so it is moved along instead of cloned.
            heap.push(BackwardCandidate {
                node_id: cand.prev_node_id,
                candidate_idx: cand.prev_candidate_idx,
                cost,
                path,
            });
        }
        NbestResult::new(results)
    }
}
impl ImprovedNbestSearcher {
    /// Runs `search` and returns each path as a `(node_ids, total_cost)` tuple,
    /// in the same order as the search result.
    ///
    /// The result is consumed, so the node id vectors are moved out rather
    /// than cloned from a temporary that is dropped right afterwards.
    pub fn search_pairs<C: ConnectionCost>(
        &self,
        lattice: &mut Lattice,
        conn_cost: &C,
    ) -> Vec<(Vec<NodeId>, i32)> {
        self.search(lattice, conn_cost)
            .into_iter()
            .map(|path| (path.node_ids, path.total_cost))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lattice::NodeBuilder;
    use crate::viterbi::ZeroConnectionCost;

    /// Builds a node whose left and right context ids are both `$ctx`.
    macro_rules! node {
        ($surface:expr, $start:expr, $end:expr, $ctx:expr, $cost:expr) => {
            NodeBuilder::new($surface, $start, $end)
                .left_id($ctx)
                .right_id($ctx)
                .word_cost($cost)
        };
    }

    /// A lattice with exactly one segmentation yields exactly one path.
    #[test]
    fn test_nbest_single_path() {
        let mut lattice = Lattice::new("AB");
        lattice.add_node(node!("A", 0, 1, 1, 100));
        lattice.add_node(node!("B", 1, 2, 2, 200));

        let searcher = ImprovedNbestSearcher::new(5);
        let results = searcher.search(&mut lattice, &ZeroConnectionCost);

        assert_eq!(results.len(), 1);
        assert_eq!(results.best().unwrap().cost(), 300);
    }

    /// Two competing segmentations come back ordered by cost.
    #[test]
    fn test_nbest_multiple_paths() {
        let mut lattice = Lattice::new("AB");
        lattice.add_node(node!("A", 0, 1, 1, 100));
        lattice.add_node(node!("B", 1, 2, 2, 200));
        lattice.add_node(node!("AB", 0, 2, 3, 350));

        let searcher = ImprovedNbestSearcher::new(5);
        let results = searcher.search(&mut lattice, &ZeroConnectionCost);

        assert_eq!(results.len(), 2);
        assert_eq!(results.get(0).unwrap().cost(), 300);
        assert_eq!(results.get(1).unwrap().cost(), 350);
    }

    /// Checks both cost and surface forms on a Korean input.
    #[test]
    fn test_nbest_korean_example() {
        let mut lattice = Lattice::new("아버지가");
        lattice.add_node(node!("아버지", 0, 3, 1, 1000));
        lattice.add_node(node!("가", 3, 4, 2, 500));
        lattice.add_node(node!("아버", 0, 2, 3, 3000));
        lattice.add_node(node!("지가", 2, 4, 4, 3000));

        let searcher = ImprovedNbestSearcher::new(3);
        let results = searcher.search(&mut lattice, &ZeroConnectionCost);
        assert!(results.len() >= 2);

        let first = results.best().unwrap();
        assert_eq!(first.cost(), 1500);
        assert_eq!(first.surfaces(&lattice), vec!["아버지", "가"]);

        let runner_up = results.get(1).unwrap();
        assert_eq!(runner_up.cost(), 6000);
        assert_eq!(runner_up.surfaces(&lattice), vec!["아버", "지가"]);
    }

    /// Exercises both borrowing and consuming iteration over a result.
    #[test]
    fn test_nbest_result_api() {
        let mut lattice = Lattice::new("AB");
        lattice.add_node(node!("A", 0, 1, 1, 100));
        lattice.add_node(node!("B", 1, 2, 2, 200));

        let searcher = ImprovedNbestSearcher::new(5);

        let borrowed = searcher.search(&mut lattice, &ZeroConnectionCost);
        for path in borrowed.iter() {
            assert!(!path.is_empty());
            assert!(path.cost() > 0);
        }

        let consumed = searcher.search(&mut lattice, &ZeroConnectionCost);
        for path in consumed {
            assert!(!path.is_empty());
        }
    }

    /// An empty input produces no paths.
    #[test]
    fn test_nbest_empty_lattice() {
        let mut lattice = Lattice::new("");
        let searcher = ImprovedNbestSearcher::new(5);
        let results = searcher.search(&mut lattice, &ZeroConnectionCost);
        assert!(results.is_empty());
    }

    /// The tuple-returning entry point reports the same single path.
    #[test]
    fn test_nbest_compatibility_pairs() {
        let mut lattice = Lattice::new("AB");
        lattice.add_node(node!("AB", 0, 2, 1, 300));

        let searcher = ImprovedNbestSearcher::new(5);
        let pairs = searcher.search_pairs(&mut lattice, &ZeroConnectionCost);

        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0].1, 300);
    }

    /// With an explicit candidate limit, results stay sorted by cost.
    #[test]
    fn test_nbest_with_max_candidates() {
        let mut lattice = Lattice::new("ABC");
        lattice.add_node(NodeBuilder::new("A", 0, 1).word_cost(100));
        lattice.add_node(NodeBuilder::new("B", 1, 2).word_cost(100));
        lattice.add_node(NodeBuilder::new("C", 2, 3).word_cost(100));
        lattice.add_node(NodeBuilder::new("AB", 0, 2).word_cost(180));
        lattice.add_node(NodeBuilder::new("BC", 1, 3).word_cost(180));
        lattice.add_node(NodeBuilder::new("ABC", 0, 3).word_cost(250));

        let searcher = ImprovedNbestSearcher::new(5).with_max_candidates(10);
        let results = searcher.search(&mut lattice, &ZeroConnectionCost);
        assert!(results.len() >= 2);

        let costs: Vec<i32> = results.iter().map(NbestPath::cost).collect();
        for pair in costs.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
    }
}