1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::viterbi::{Lattice, WordId};
/// A frontier node of the backward A* search performed by `NBestGenerator`.
///
/// Each node identifies one lattice edge (by the byte position whose `ends_at`
/// list holds it and its index within that list) together with the cost
/// bookkeeping of the partial path running from EOS back to that edge.
#[derive(Clone, Debug)]
struct QueueElement {
    /// Byte position whose `ends_at` list contains the current edge.
    byte_pos: u32,
    /// Position of the current edge inside `ends_at[byte_pos]`.
    edge_index: u16,
    /// Estimated total cost f(x) = g(x) + h(x); this is the heap key.
    fx: i64,
    /// Real cost g(x) accumulated while walking backward from EOS.
    gx: i64,
    /// Index (into the generator's element store) of the neighbouring element
    /// one step closer to EOS; `None` for the initial EOS element.
    prev: Option<usize>,
}

// `BinaryHeap` pops its greatest element, so the comparison is inverted:
// the element with the smallest `fx` compares greatest and comes out first.
impl Ord for QueueElement {
    fn cmp(&self, other: &Self) -> Ordering {
        self.fx.cmp(&other.fx).reverse()
    }
}

impl PartialOrd for QueueElement {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

// Equality deliberately looks only at `fx`, matching the ordering above.
impl PartialEq for QueueElement {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for QueueElement {}
/// Enumerates paths through a `Lattice`, cheapest first, via backward A* search.
///
/// Run forward Viterbi with `set_text_nbest` beforehand: the generator reads
/// the transitions recorded in `all_paths` and uses each edge's `path_cost`
/// as the heuristic while walking from EOS back toward BOS.
pub struct NBestGenerator<'a> {
    /// Lattice prepared by `set_text_nbest`, borrowed for the generator's lifetime.
    lattice: &'a Lattice,
    /// Search frontier; the element with the lowest f(x) is popped first.
    queue: BinaryHeap<QueueElement>,
    /// Every expanded element, retained so paths can be rebuilt via `prev` links.
    elements: Vec<QueueElement>,
}
impl<'a> NBestGenerator<'a> {
    /// Creates a generator over `lattice` and seeds the search with its EOS edge.
    ///
    /// `lattice` must already have been processed with `set_text_nbest()` so
    /// that the recorded path transitions and per-edge `path_cost` values exist.
    /// If no edge ends at the end of the text, the queue stays empty and
    /// `next` returns `None` immediately.
    pub fn new(lattice: &'a Lattice) -> Self {
        let mut generator = NBestGenerator {
            lattice,
            queue: BinaryHeap::new(),
            elements: Vec::new(),
        };
        generator.init();
        generator
    }

    /// Pushes the initial search element (the EOS edge) onto the queue.
    ///
    /// EOS is taken to be the last edge pushed to `ends_at[text_len]`. Its
    /// `path_cost` seeds f(x), and g(x) starts at 0.
    fn init(&mut self) {
        let text_len = self.lattice.text_len();
        let eos_edges = self.lattice.edges_at(text_len);
        if eos_edges.is_empty() {
            return;
        }
        let last = eos_edges.len() - 1;
        // Queue elements store edge indices as u16. An index that does not fit
        // would silently wrap with a bare `as` cast and point at the wrong edge.
        if last > usize::from(u16::MAX) {
            return;
        }
        let eos_index = last as u16;
        let eos_edge = &eos_edges[last];
        self.queue.push(QueueElement {
            byte_pos: text_len as u32,
            edge_index: eos_index,
            fx: eos_edge.path_cost as i64,
            gx: 0,
            prev: None,
        });
    }

    /// Returns the next-best path together with its total cost.
    ///
    /// The path is a list of `(byte_start, WordId)` pairs in left-to-right
    /// order, where `byte_start` is the edge's `start_index`. The EOS edge is
    /// omitted. The cost is the f(x) of the element that reached BOS; lower is
    /// better. Returns `None` once no more paths are available.
    // NOTE(review): kept as an inherent method (hence the clippy allow) rather
    // than an `Iterator` impl; consider implementing `Iterator` so callers can
    // use adaptors like `take(n)`.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> Option<(Vec<(usize, WordId)>, i64)> {
        while let Some(current) = self.queue.pop() {
            let byte_pos = current.byte_pos as usize;
            let edge_index = current.edge_index as usize;
            let edges = self.lattice.edges_at(byte_pos);
            // Defensive: ignore elements whose edge index is out of range.
            if edge_index >= edges.len() {
                continue;
            }
            let edge = &edges[edge_index];

            // An edge with no predecessor (`left_index == u16::MAX`) is BOS, so
            // the chain from here back to EOS forms a complete path.
            if edge.left_index == u16::MAX {
                return Some((self.reconstruct_path(&current), current.fx));
            }

            // Record this element so that its expansions can link back to it.
            let current_idx = self.elements.len();
            self.elements.push(current.clone());

            // Expand every recorded transition that ends in this edge.
            for path_entry in self.lattice.paths_at(byte_pos) {
                if path_entry.edge_index != current.edge_index {
                    continue; // Transition into a different edge at this position.
                }
                let left_pos = path_entry.left_pos as usize;
                let left_index = path_entry.left_index as usize;
                let left_edges = self.lattice.edges_at(left_pos);
                if left_index >= left_edges.len() {
                    continue;
                }
                let left_edge = &left_edges[left_index];

                // path_entry.cost = left_edge.path_cost + conn_cost + penalty, so
                // subtracting the left edge's path_cost isolates this transition.
                let conn_and_penalty = path_entry.cost as i64 - left_edge.path_cost as i64;
                // g(x): real cost of everything to the right of left_edge,
                // including the current edge's own word cost.
                let new_gx = current.gx + conn_and_penalty + edge.word_entry.word_cost as i64;
                // f(x) = h(x) + g(x), with h(x) = left_edge.path_cost from the
                // forward Viterbi pass.
                let new_fx = left_edge.path_cost as i64 + new_gx;

                self.queue.push(QueueElement {
                    byte_pos: left_pos as u32,
                    edge_index: left_index as u16,
                    fx: new_fx,
                    gx: new_gx,
                    prev: Some(current_idx),
                });
            }
        }
        None
    }

    /// Rebuilds the word sequence for a search element that reached BOS.
    ///
    /// Starts at `bos_elem.prev` (the edge right after BOS) and follows `prev`
    /// links toward EOS. Because each link points one step closer to EOS, the
    /// edges come out in left-to-right order. The EOS edge, recognised by
    /// `start_index == stop_index`, is skipped.
    fn reconstruct_path(&self, bos_elem: &QueueElement) -> Vec<(usize, WordId)> {
        let mut path = Vec::new();
        let mut maybe_idx = bos_elem.prev;
        while let Some(idx) = maybe_idx {
            let elem = &self.elements[idx];
            let edges = self.lattice.edges_at(elem.byte_pos as usize);
            let edge = &edges[elem.edge_index as usize];
            if edge.start_index != edge.stop_index {
                path.push((edge.start_index as usize, edge.word_entry.word_id));
            }
            maybe_idx = elem.prev;
        }
        path
    }
}