use std::cmp::Ordering;
use std::collections::BinaryHeap;
use smallvec::SmallVec;
use crate::backend::LatticeBackend;
use crate::lattice::{EdgeId, Lattice, LatticePath, NodeId};
use crate::semiring::Semiring;
/// A path prefix under exploration during the n-best search.
///
/// Records the edges taken so far, the node those edges lead to, and the
/// weight accumulated along them.
#[derive(Clone, Debug)]
struct PartialPath<W: Semiring> {
/// Node at which this prefix currently ends.
node: NodeId,
/// Ids of the edges taken so far, in traversal order.
edges: SmallVec<[EdgeId; 16]>,
/// Combined weight of `edges`: starts at `W::one()` and is folded with `times`.
weight: W,
}
impl<W: Semiring> PartialPath<W> {
    /// Creates an empty prefix sitting on `start` with the weight `W::one()`.
    fn new(start: NodeId) -> Self {
        let edges = SmallVec::new();
        let weight = W::one();
        Self {
            node: start,
            edges,
            weight,
        }
    }

    /// Returns a new prefix that follows `edge_id` to `target`, leaving `self` untouched.
    ///
    /// The edge list is cloned, so prefer [`Self::extend_move`] when `self` is
    /// no longer needed afterwards.
    fn extend(&self, edge_id: EdgeId, target: NodeId, edge_weight: W) -> Self {
        let mut edges = self.edges.clone();
        edges.push(edge_id);
        let weight = self.weight.times(&edge_weight);
        Self {
            node: target,
            edges,
            weight,
        }
    }

    /// Consuming variant of [`Self::extend`]: reuses the existing edge list
    /// instead of cloning it.
    fn extend_move(self, edge_id: EdgeId, target: NodeId, edge_weight: W) -> Self {
        let Self {
            mut edges, weight, ..
        } = self;
        edges.push(edge_id);
        let weight = weight.times(&edge_weight);
        Self {
            node: target,
            edges,
            weight,
        }
    }

    /// Converts the prefix into a `LatticePath` carrying the same weight and
    /// edges, and marks that path as complete.
    fn into_lattice_path(self) -> LatticePath<W> {
        let Self { edges, weight, .. } = self;
        let mut finished = LatticePath::with_weight(weight);
        finished.edges = edges;
        finished.mark_complete();
        finished
    }
}
/// Newtype around `PartialPath` used as the element type of the search heap.
///
/// Its comparison impls look only at the wrapped path's weight.
struct OrderedPath<W: Semiring>(PartialPath<W>);
impl<W: Semiring> PartialEq for OrderedPath<W> {
fn eq(&self, other: &Self) -> bool {
self.0.weight == other.0.weight
}
}
// Marker impl: `Eq` is a supertrait of `Ord`, which `BinaryHeap` requires of its elements.
impl<W: Semiring> Eq for OrderedPath<W> {}
/// Partial ordering that always defers to the total order defined by `Ord`.
impl<W: Semiring> PartialOrd for OrderedPath<W> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Never `None`: every pair is ordered by `Ord::cmp`.
        Some(Ord::cmp(self, other))
    }
}
/// Reversed weight ordering.
///
/// A path whose weight is `natural_less` than the other's compares as
/// `Greater`, so a max-heap such as `BinaryHeap` pops it first. Pairs for
/// which neither weight is less than the other, or for which `natural_less`
/// gives no answer, compare as `Equal`.
impl<W: Semiring> Ord for OrderedPath<W> {
    fn cmp(&self, other: &Self) -> Ordering {
        let (mine, theirs) = (&self.0.weight, &other.0.weight);
        match mine.natural_less(theirs) {
            // Inverted on purpose: "less" weight means higher heap priority.
            Some(true) => Ordering::Greater,
            // Only probe the opposite direction once we know `mine` is not less.
            Some(false) if matches!(theirs.natural_less(mine), Some(true)) => Ordering::Less,
            _ => Ordering::Equal,
        }
    }
}
/// Lazy iterator over complete paths through a lattice.
///
/// Paths are produced by repeatedly popping the highest-priority partial path
/// from a heap and expanding it along its outgoing edges; at most `limit`
/// complete paths are yielded.
pub struct NBestIterator<'a, W: Semiring, B: LatticeBackend> {
/// Lattice being searched.
lattice: &'a Lattice<W, B>,
/// Partial paths that still await expansion.
heap: BinaryHeap<OrderedPath<W>>,
/// Node at which a partial path counts as complete.
end: NodeId,
/// Maximum number of paths this iterator will yield.
limit: usize,
/// Number of paths yielded so far.
count: usize,
}
impl<'a, W: Semiring, B: LatticeBackend> NBestIterator<'a, W, B> {
    /// Builds an iterator that yields at most `n` complete paths of `lattice`.
    ///
    /// The search frontier is seeded with a single empty path located at the
    /// lattice's start node; no expansion happens until `next` is called.
    pub fn new(lattice: &'a Lattice<W, B>, n: usize) -> Self {
        let mut frontier = BinaryHeap::new();
        let (origin, goal) = (lattice.start(), lattice.end());
        let seed = PartialPath::new(origin);
        frontier.push(OrderedPath(seed));

        Self {
            lattice,
            heap: frontier,
            end: goal,
            limit: n,
            count: 0,
        }
    }
}
impl<'a, W: Semiring, B: LatticeBackend> Iterator for NBestIterator<'a, W, B> {
    type Item = LatticePath<W>;

    /// Returns the next complete path, or `None` once `limit` paths have been
    /// yielded or the frontier is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        if self.count >= self.limit {
            return None;
        }

        while let Some(OrderedPath(current)) = self.heap.pop() {
            // A prefix that has reached the end node is a finished path.
            if current.node == self.end {
                self.count += 1;
                return Some(current.into_lattice_path());
            }

            // Expand along every outgoing edge. All but the final edge need a
            // clone of the prefix; the final one can consume it, which saves
            // one copy of the edge list per expansion.
            let mut outgoing = self.lattice.outgoing_edges(current.node).peekable();
            while let Some(edge) = outgoing.next() {
                let (edge_id, target, weight) = (edge.id, edge.target, edge.weight);
                if outgoing.peek().is_some() {
                    let branch = current.extend(edge_id, target, weight);
                    self.heap.push(OrderedPath(branch));
                } else {
                    let branch = current.extend_move(edge_id, target, weight);
                    self.heap.push(OrderedPath(branch));
                    // `current` has been moved; nothing is left to expand.
                    break;
                }
            }
        }

        None
    }
}
/// Collects up to `n` complete paths of `lattice` into a `Vec`, in the order
/// `NBestIterator` yields them.
///
/// The lattice is only read here: the `&mut` borrow is passed on to
/// `NBestIterator::new` as a shared reference.
pub fn nbest<W: Semiring, B: LatticeBackend>(
    lattice: &mut Lattice<W, B>,
    n: usize,
) -> Vec<LatticePath<W>> {
    let search = NBestIterator::new(&*lattice, n);
    search.collect::<Vec<_>>()
}
/// Unit tests for `nbest` and `NBestIterator` on small hand-built tropical lattices.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::HashMapBackend;
    use crate::lattice::{EdgeMetadata, LatticeBuilder};
    use crate::semiring::TropicalWeight;

    /// Two parallel edges give two paths, lowest weight first.
    #[test]
    fn test_nbest_simple() {
        let backend = HashMapBackend::new();
        let mut builder = LatticeBuilder::new(backend);
        builder.add_correction(0, 1, "a", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(0, 1, "b", TropicalWeight::new(2.0), EdgeMetadata::default());
        let mut lattice = builder.build(1);

        let paths = nbest(&mut lattice, 10);
        assert_eq!(paths.len(), 2);
        assert_eq!(paths[0].weight.value(), 1.0);
        assert_eq!(paths[1].weight.value(), 2.0);
    }

    /// Two positions with two alternatives each give four paths: 2.0, 3.0, 3.0, 4.0.
    #[test]
    fn test_nbest_multi_position() {
        let backend = HashMapBackend::new();
        let mut builder = LatticeBuilder::new(backend);
        builder.add_correction(0, 1, "a", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(0, 1, "b", TropicalWeight::new(2.0), EdgeMetadata::default());
        builder.add_correction(1, 2, "c", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(1, 2, "d", TropicalWeight::new(2.0), EdgeMetadata::default());
        let mut lattice = builder.build(2);

        let paths = nbest(&mut lattice, 10);
        assert_eq!(paths.len(), 4);

        let weights: Vec<_> = paths.iter().map(|p| p.weight.value()).collect();
        assert_eq!(weights[0], 2.0);
        // "a d" and "b c" tie at 3.0, so either may come first; the weight at
        // both middle positions is 3.0 regardless. The previous
        // `x == 3.0 || x == 3.0` form repeated the same operand on both sides.
        assert_eq!(weights[1], 3.0);
        assert_eq!(weights[2], 3.0);
        assert_eq!(weights[3], 4.0);
    }

    /// Asking for 3 of 10 available paths returns exactly the 3 lowest weights.
    #[test]
    fn test_nbest_limit() {
        let backend = HashMapBackend::new();
        let mut builder = LatticeBuilder::new(backend);
        for i in 0..10 {
            builder.add_correction(
                0,
                1,
                &format!("word{}", i),
                TropicalWeight::new(i as f64),
                EdgeMetadata::default(),
            );
        }
        let mut lattice = builder.build(1);

        let paths = nbest(&mut lattice, 3);
        assert_eq!(paths.len(), 3);
        assert_eq!(paths[0].weight.value(), 0.0);
        assert_eq!(paths[1].weight.value(), 1.0);
        assert_eq!(paths[2].weight.value(), 2.0);
    }

    /// A lattice built without any edges yields exactly one path, and it is empty.
    #[test]
    fn test_nbest_empty_lattice() {
        let backend = HashMapBackend::new();
        let builder: LatticeBuilder<TropicalWeight, _> = LatticeBuilder::new(backend);
        let mut lattice = builder.build(0);

        let paths = nbest(&mut lattice, 10);
        assert_eq!(paths.len(), 1);
        assert!(paths[0].is_empty());
    }

    /// A single chain of two edges yields one path with the combined weight and both words.
    #[test]
    fn test_nbest_single_path() {
        let backend = HashMapBackend::new();
        let mut builder = LatticeBuilder::new(backend);
        builder.add_correction(
            0,
            1,
            "hello",
            TropicalWeight::new(1.0),
            EdgeMetadata::default(),
        );
        builder.add_correction(
            1,
            2,
            "world",
            TropicalWeight::new(2.0),
            EdgeMetadata::default(),
        );
        let mut lattice = builder.build(2);

        let paths = nbest(&mut lattice, 10);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].weight.value(), 3.0);

        let words = paths[0].to_words(&lattice);
        assert_eq!(words, vec!["hello", "world"]);
    }

    /// Diamond-shaped lattice: two routes of different total weight, cheapest first.
    #[test]
    fn test_nbest_diamond() {
        let backend = HashMapBackend::new();
        let mut builder = LatticeBuilder::new(backend);
        builder.add_correction(0, 1, "a", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(0, 2, "b", TropicalWeight::new(2.0), EdgeMetadata::default());
        builder.add_correction(1, 3, "c", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(2, 3, "d", TropicalWeight::new(0.5), EdgeMetadata::default());
        let mut lattice = builder.build(3);

        let paths = nbest(&mut lattice, 10);
        assert_eq!(paths.len(), 2);
        assert_eq!(paths[0].weight.value(), 2.0);
        assert_eq!(paths[1].weight.value(), 2.5);
    }

    /// Driving `NBestIterator` by hand: it stops after `n` paths even though more exist.
    #[test]
    fn test_nbest_iterator() {
        let backend = HashMapBackend::new();
        let mut builder = LatticeBuilder::new(backend);
        builder.add_correction(0, 1, "a", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(0, 1, "b", TropicalWeight::new(2.0), EdgeMetadata::default());
        builder.add_correction(0, 1, "c", TropicalWeight::new(3.0), EdgeMetadata::default());
        let lattice = builder.build(1);

        let mut iter = NBestIterator::new(&lattice, 2);
        let first = iter.next().expect("first path");
        assert_eq!(first.weight.value(), 1.0);
        let second = iter.next().expect("second path");
        assert_eq!(second.weight.value(), 2.0);
        assert!(iter.next().is_none());
    }

    /// Output order follows weight, not the order in which edges were added.
    #[test]
    fn test_nbest_preserves_order() {
        let backend = HashMapBackend::new();
        let mut builder = LatticeBuilder::new(backend);
        builder.add_correction(0, 1, "c", TropicalWeight::new(3.0), EdgeMetadata::default());
        builder.add_correction(0, 1, "a", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(0, 1, "b", TropicalWeight::new(2.0), EdgeMetadata::default());
        let mut lattice = builder.build(1);

        let paths = nbest(&mut lattice, 3);
        assert_eq!(paths[0].weight.value(), 1.0);
        assert_eq!(paths[1].weight.value(), 2.0);
        assert_eq!(paths[2].weight.value(), 3.0);
    }
}
/// Property-based tests for `nbest` over generated lattices.
#[cfg(test)]
mod property_tests {
    use super::*;
    use crate::test_utils::{arb_diamond_lattice, arb_linear_lattice, arb_tropical_lattice};
    use proptest::prelude::*;
    use proptest::strategy::ValueTree;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]

        // A lattice from `arb_linear_lattice` has exactly one path.
        #[test]
        fn nbest_linear_returns_one(
            mut lattice in arb_linear_lattice(4)
        ) {
            let path_count = nbest(&mut lattice, 100).len();
            prop_assert_eq!(path_count, 1);
        }

        // Weights never decrease along the returned list (up to a 1e-9 tolerance).
        #[test]
        fn nbest_returns_sorted(
            mut lattice in arb_tropical_lattice(3, 3)
        ) {
            let paths = nbest(&mut lattice, 50);
            for (idx, pair) in paths.windows(2).enumerate() {
                let (earlier, later) = (&pair[0], &pair[1]);
                prop_assert!(
                    earlier.weight.value() <= later.weight.value() + 1e-9,
                    "Path {} (weight {}) > Path {} (weight {})",
                    idx, earlier.weight.value(),
                    idx + 1, later.weight.value()
                );
            }
        }

        // Never more than `n` paths come back.
        #[test]
        fn nbest_respects_limit(
            mut lattice in arb_diamond_lattice(4), n in 1usize..10
        ) {
            let paths = nbest(&mut lattice, n);
            prop_assert!(paths.len() <= n);
        }

        // Generated tropical lattices always contain at least one path.
        #[test]
        fn nbest_returns_at_least_one(
            mut lattice in arb_tropical_lattice(2, 2)
        ) {
            let paths = nbest(&mut lattice, 10);
            prop_assert!(!paths.is_empty());
        }

        // The single best n-best path has the same weight as the Viterbi path.
        #[test]
        fn nbest_first_matches_viterbi(
            mut lattice in arb_tropical_lattice(3, 2)
        ) {
            use crate::path::viterbi;

            let viterbi_result = viterbi(&mut lattice);
            let nbest_paths = nbest(&mut lattice, 1);

            prop_assert!(viterbi_result.success);
            prop_assert_eq!(nbest_paths.len(), 1);

            // Read the weights only after the length check so indexing cannot panic.
            let expected = viterbi_result.path.weight.value();
            let actual = nbest_paths[0].weight.value();
            prop_assert!(
                (expected - actual).abs() < 1e-9,
                "Weight mismatch: viterbi={}, nbest={}",
                expected, actual
            );
        }

        // The diamond lattice generated for `n` is expected to hold 2^n paths.
        #[test]
        fn nbest_diamond_path_count(n in 1usize..5) {
            let mut runner = proptest::test_runner::TestRunner::deterministic();
            let mut lattice = arb_diamond_lattice(n)
                .new_tree(&mut runner)
                .expect("generate lattice")
                .current();

            let paths = nbest(&mut lattice, 1000);
            let expected_count = 1usize << n;
            prop_assert_eq!(paths.len(), expected_count);
        }

        // Every returned path is flagged as complete.
        #[test]
        fn nbest_paths_complete(
            mut lattice in arb_tropical_lattice(3, 2)
        ) {
            let paths = nbest(&mut lattice, 10);
            for path in paths.iter() {
                prop_assert!(path.is_complete);
            }
        }

        // Every returned path has length 4 for `arb_tropical_lattice(4, 2)`.
        #[test]
        fn nbest_paths_same_length(
            mut lattice in arb_tropical_lattice(4, 2)
        ) {
            let paths = nbest(&mut lattice, 20);
            for path in paths.iter() {
                prop_assert_eq!(path.len(), 4, "Path length {} != expected 4", path.len());
            }
        }
    }
}