use std::cell::RefCell;
use std::cmp::{min, Ordering};
use std::collections::BinaryHeap;
use std::rc::Rc;
use crate::TokenID;
/// A candidate token occupying a span of the sentence inside a `Lattice`.
///
/// Nodes refer to each other by their index in `Lattice::nodes`.
#[derive(Debug, Clone)]
pub struct Node {
/// Offset in the sentence at which this token starts.
pub pos: usize,
/// Identifier of the token this node represents.
pub token_id: TokenID,
/// Length of the token, expressed in the same units as `pos`.
pub token_len: usize,
/// Score of this token on its own.
pub score: f64,
/// Index of the best predecessor node, filled in by `Lattice::viterbi`
/// (`None` until then).
pub prev: Option<usize>,
/// Best accumulated score of a path ending with this node, filled in by
/// `Lattice::viterbi`.
pub backtrack_score: f64,
}
impl PartialEq for Node {
fn eq(&self, other: &Node) -> bool {
self.token_id == other.token_id
}
}
impl Node {
    /// Builds a node for the token `token_id` starting at `pos`, spanning
    /// `token_len` and carrying `score`.
    ///
    /// The Viterbi bookkeeping starts out blank: no predecessor and a
    /// backtrack score of `0.0`.
    pub fn new(pos: usize, token_id: TokenID, token_len: usize, score: f64) -> Self {
        let prev = None;
        let backtrack_score = 0.0;
        Self {
            pos,
            token_id,
            token_len,
            score,
            prev,
            backtrack_score,
        }
    }
}
/// A lattice of candidate tokens over a sentence.
///
/// All nodes live in the `nodes` arena; the adjacency tables store indices
/// into that arena.
#[derive(Debug)]
pub struct Lattice<'a> {
/// The bytes being segmented.
pub sentence: &'a [u8],
/// `begin_nodes[pos]` lists the indices of the nodes that start at `pos`.
pub begin_nodes: Vec<Vec<usize>>,
/// `end_nodes[pos]` lists the indices of the nodes that end at `pos`
/// (that is, nodes with `pos + token_len == pos` of this slot).
pub end_nodes: Vec<Vec<usize>>,
/// Arena holding every node of the lattice, sentinels included.
pub nodes: Vec<Node>,
// Index in `nodes` of the beginning-of-sentence sentinel.
bos_idx: usize,
// Index in `nodes` of the end-of-sentence sentinel.
eos_idx: usize,
}
impl<'a> Lattice<'a> {
/// Creates an empty lattice that is not yet bound to any sentence.
///
/// The node arena is allocated up front with room for `1024 * 1024` nodes.
/// Call [`Lattice::from`] before inserting nodes or decoding.
pub fn new() -> Self {
    let nodes = Vec::with_capacity(1024 * 1024);
    Self {
        sentence: &[],
        begin_nodes: Vec::new(),
        end_nodes: Vec::new(),
        nodes,
        bos_idx: 0,
        eos_idx: 0,
    }
}
/// Resets the lattice so that it describes `sentence`.
///
/// Every adjacency list currently held is handed back to `vec_pool`, the
/// node arena is emptied, and `sentence.len() + 1` fresh lists per table are
/// taken from the pool. Finally the BOS sentinel (token id `TokenID::MAX - 1`,
/// ending at offset 0) and the EOS sentinel (token id `TokenID::MAX`,
/// beginning at `sentence.len()`) are installed.
pub fn from(&mut self, sentence: &'a [u8], vec_pool: &mut VecPool) {
    let slots = sentence.len() + 1;

    self.sentence = sentence;
    self.nodes.clear();

    // Recycle the old adjacency lists: all of `begin_nodes` first, then `end_nodes`.
    for recycled in self.begin_nodes.drain(..).chain(self.end_nodes.drain(..)) {
        vec_pool.put(recycled);
    }

    self.begin_nodes.resize_with(slots, || vec_pool.get());
    self.end_nodes.resize_with(slots, || vec_pool.get());

    // The arena was just cleared, so BOS lands at index 0 and EOS at index 1.
    self.bos_idx = self.nodes.len();
    self.nodes.push(Node::new(0, TokenID::MAX - 1, 0, 0.0));
    self.eos_idx = self.nodes.len();
    self.nodes
        .push(Node::new(sentence.len(), TokenID::MAX, 0, 0.0));

    self.end_nodes[0].push(self.bos_idx);
    self.begin_nodes[sentence.len()].push(self.eos_idx);
}
/// Adds a candidate token to the lattice.
///
/// The new node is registered as starting at `pos` and ending at
/// `pos + token_len`, then appended to the node arena.
///
/// # Panics
///
/// Panics if `pos` or `pos + token_len` is outside the adjacency tables
/// sized by [`Lattice::from`].
pub fn insert(&mut self, pos: usize, token_id: TokenID, token_len: usize, score: f64) {
    // The node will occupy the next free slot of the arena.
    let new_idx = self.nodes.len();
    let end = pos + token_len;

    self.begin_nodes[pos].push(new_idx);
    self.end_nodes[end].push(new_idx);

    let node = Node::new(pos, token_id, token_len, score);
    self.nodes.push(node);
}
/// Finds the highest-scoring path through the lattice.
///
/// For every node, the best predecessor among the nodes ending where it
/// starts is recorded in `prev`, and the accumulated score of that path
/// (including the node's own score) in `backtrack_score`. The path is then
/// recovered by following `prev` links back from the first node that begins
/// at the end of the sentence.
///
/// Returns the nodes of the best path in sentence order, or an empty vector
/// as soon as a node is found with no node ending at its start position.
///
/// NOTE(review): the backward walk starts *at* the EOS sentinel, so the
/// returned path ends with the EOS node (the BOS node is not included). The
/// `n >= 2` branch of `nbest` omits both sentinels — confirm which behaviour
/// callers expect.
///
/// # Panics
///
/// Panics if the lattice has not been initialised with [`Lattice::from`].
pub fn viterbi(&mut self) -> Vec<Node> {
    let sentence_len = self.sentence.len();

    // Forward pass: pick the best predecessor of every node.
    for pos in 0..=sentence_len {
        for &rnode_idx in &self.begin_nodes[pos] {
            let rnode_score = self.nodes[rnode_idx].score;

            let mut best: Option<(usize, f64)> = None;
            for &lnode_idx in &self.end_nodes[pos] {
                let score = self.nodes[lnode_idx].backtrack_score + rnode_score;
                if best.map_or(true, |(_, best_score)| score > best_score) {
                    best = Some((lnode_idx, score));
                }
            }

            match best {
                Some((lnode_idx, score)) => {
                    let rnode = &mut self.nodes[rnode_idx];
                    rnode.prev = Some(lnode_idx);
                    rnode.backtrack_score = score;
                }
                None => {
                    // Nothing ends where this node starts: no complete path exists.
                    self.nodes[rnode_idx].prev = None;
                    return Vec::new();
                }
            }
        }
    }

    // Backward pass: follow `prev` links until a node without predecessor (BOS).
    let mut results = Vec::with_capacity(sentence_len / 4);
    let mut node_idx = self.begin_nodes[sentence_len][0];
    while let Some(prev_node_idx) = self.nodes[node_idx].prev {
        // One clone per node is enough; the original cloned each node twice.
        results.push(self.nodes[node_idx].clone());
        node_idx = prev_node_idx;
    }

    results.reverse();
    results
}
/// Returns up to `n` highest-scoring paths through the lattice.
///
/// * `n == 0` yields no paths.
/// * `n == 1` yields the single path returned by [`Lattice::viterbi`].
/// * Otherwise hypotheses are expanded backwards from the EOS sentinel using
///   a priority queue keyed on `fx`: the Viterbi `backtrack_score` of the
///   hypothesis' node plus `gx`, the summed scores from the successor
///   hypothesis to EOS. A hypothesis that reaches BOS is a finished path.
///   Paths returned by this branch contain neither BOS nor EOS.
///
/// NOTE(review): for `n == 1` the path comes straight from `viterbi`, which
/// includes the EOS node, unlike the `n >= 2` branch — confirm with callers.
///
/// Fewer than `n` paths are returned when the agenda runs dry first.
///
/// # Panics
///
/// Panics if the lattice has not been initialised with [`Lattice::from`].
pub fn nbest(&mut self, n: usize) -> Vec<Vec<Node>> {
    // Agenda size above which the agenda is pruned.
    const MAX_AGENDA_SIZE: usize = 100_000;
    // Upper bound on the number of hypotheses kept when pruning.
    const MIN_AGENDA_SIZE: usize = 512;

    // Turns a path of arena indices into owned nodes (one clone per node).
    fn collect_nodes_from_indices(indices: &[usize], nodes: &[Node]) -> Vec<Node> {
        indices.iter().map(|&i| nodes[i].clone()).collect()
    }

    match n {
        0 => vec![],
        1 => vec![self.viterbi()],
        _ => {
            let mut agenda: Agenda = BinaryHeap::new();
            let mut hypotheses: Vec<Vec<usize>> = vec![];

            // Seed the search with the EOS sentinel; use the stored index
            // rather than a hard-coded `1`.
            let eos_idx = self.eos_idx;
            let eos_score = self.nodes[eos_idx].score;
            agenda.push(Hypothesis::new(eos_idx, None, eos_score, eos_score));

            // Fill in `backtrack_score` for every node; it is the heuristic part of `fx`.
            self.viterbi();

            while let Some(best) = agenda.pop() {
                let top: HypothesisRef = Rc::new(RefCell::new(best));
                let node_idx = top.borrow().node_idx;

                // Identify BOS by arena index: comparing token ids would also
                // match a vocabulary token that happens to share the sentinel id.
                if node_idx == self.bos_idx {
                    let mut path = Vec::new();
                    let mut current: HypothesisRef = Rc::clone(
                        top.borrow()
                            .next
                            .as_ref()
                            .expect("a hypothesis at BOS always has a successor"),
                    );
                    loop {
                        let successor = match current.borrow().next.as_ref() {
                            Some(successor) => Rc::clone(successor),
                            // The hypothesis without successor is EOS; leave it out.
                            None => break,
                        };
                        path.push(current.borrow().node_idx);
                        current = successor;
                    }

                    hypotheses.push(path);
                    if hypotheses.len() == n {
                        break;
                    }
                } else {
                    let node_pos = self.nodes[node_idx].pos;
                    let top_gx = top.borrow().gx;

                    // Extend the hypothesis with every node ending where this one starts.
                    for &lnode in &self.end_nodes[node_pos] {
                        let fx = self.nodes[lnode].backtrack_score + top_gx;
                        let gx = self.nodes[lnode].score + top_gx;
                        agenda.push(Hypothesis::new(lnode, Some(Rc::clone(&top)), fx, gx));
                    }

                    // Keep memory bounded: retain only the most promising hypotheses.
                    if agenda.len() > MAX_AGENDA_SIZE {
                        let keep = min(MIN_AGENDA_SIZE, n.saturating_mul(10));
                        let mut pruned: Agenda = BinaryHeap::new();
                        for _ in 0..keep {
                            if let Some(hypothesis) = agenda.pop() {
                                pruned.push(hypothesis);
                            }
                        }
                        agenda = pruned;
                    }
                }
            }

            hypotheses
                .iter()
                .map(|indices| collect_nodes_from_indices(indices, &self.nodes))
                .collect()
        }
    }
}
/// Runs a forward-backward pass over the lattice and accumulates per-token
/// weights into `expected`.
///
/// `alpha[i]` is the log-sum-exp of the scores of all partial paths leading
/// up to (but excluding) node `i`; `beta[i]` is the same for the partial
/// paths following node `i`. For every node that begins before the end of
/// the sentence, `exp(alpha + score + beta - z)` is added to
/// `expected[token_id]`, where `z` is `alpha` of the EOS sentinel.
///
/// Returns `z`.
///
/// # Panics
///
/// Panics if the lattice has not been initialised with [`Lattice::from`] or
/// if a token id is out of bounds for `expected`.
pub fn populate_marginal(&self, expected: &mut [f64]) -> f64 {
    let len = self.sentence.len();
    let num_nodes = self.nodes.len();
    let mut alpha = vec![0.0; num_nodes];
    let mut beta = vec![0.0; num_nodes];

    // Forward pass.
    for pos in 0..=len {
        for &rid in &self.begin_nodes[pos] {
            let preds = &self.end_nodes[pos];
            for &lid in preds {
                // The first predecessor initialises the accumulator.
                let is_first = lid == preds[0];
                let incoming = self.nodes[lid].score + alpha[lid];
                alpha[rid] = log_sum_exp(alpha[rid], incoming, is_first);
            }
        }
    }

    // Backward pass.
    for pos in (0..=len).rev() {
        for &lid in &self.end_nodes[pos] {
            let succs = &self.begin_nodes[pos];
            for &rid in succs {
                // The first successor initialises the accumulator.
                let is_first = rid == succs[0];
                let outgoing = self.nodes[rid].score + beta[rid];
                beta[lid] = log_sum_exp(beta[lid], outgoing, is_first);
            }
        }
    }

    let z = alpha[self.eos_idx];

    for pos in 0..len {
        for &node_idx in &self.begin_nodes[pos] {
            let node = &self.nodes[node_idx];
            let total = alpha[node_idx] + node.score + beta[node_idx] - z;
            expected[node.token_id as usize] += total.exp();
        }
    }

    z
}
}
impl<'a> Default for Lattice<'a> {
fn default() -> Self {
Self::new()
}
}
/// Computes `ln(exp(x) + exp(y))` without overflowing.
///
/// When `init_mode` is set, `x` is treated as an empty accumulator and `y`
/// is returned unchanged. Otherwise, if the two values are more than `50.0`
/// apart the smaller one is considered negligible and the larger one is
/// returned as is.
///
/// Two equal infinite inputs return that infinity; the general formula
/// would evaluate `inf - inf` and yield NaN. NaN inputs still yield NaN.
fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {
    // Gap beyond which exp(vmin - vmax) is treated as zero.
    const MINUS_LOG_EPSILON: f64 = 50.0;

    if init_mode {
        return y;
    }

    let (vmin, vmax) = if x > y { (y, x) } else { (x, y) };

    // (-inf, -inf) and (+inf, +inf): the difference below would be NaN.
    if vmin == vmax && vmax.is_infinite() {
        return vmax;
    }

    if vmax > vmin + MINUS_LOG_EPSILON {
        vmax
    } else {
        vmax + ((vmin - vmax).exp() + 1.0).ln()
    }
}
/// Shared handle to a hypothesis, used to chain a hypothesis to its successor.
type HypothesisRef = Rc<RefCell<Hypothesis>>;
/// Priority queue of hypotheses; the one with the largest `fx` is popped first.
type Agenda = BinaryHeap<Hypothesis>;

/// A partial path explored by the n-best search.
struct Hypothesis {
    /// Index of the lattice node this hypothesis stands on.
    node_idx: usize,
    /// The hypothesis this one was expanded from, if any.
    next: Option<HypothesisRef>,
    /// Priority of the hypothesis in the agenda.
    fx: f64,
    /// Score accumulated so far along the chain of `next` links.
    gx: f64,
}

impl Hypothesis {
    /// Creates a hypothesis for node `node_idx` chained to `next`, with
    /// priority `fx` and accumulated score `gx`.
    pub fn new(node_idx: usize, next: Option<HypothesisRef>, fx: f64, gx: f64) -> Self {
        Self {
            node_idx,
            next,
            fx,
            gx,
        }
    }
}

/// Equality is defined through [`Ord::cmp`] so that `==` and `cmp` can never
/// disagree.
impl PartialEq for Hypothesis {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Hypothesis {}

impl PartialOrd for Hypothesis {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Hypotheses are ordered by `fx` alone.
///
/// `f64::total_cmp` gives a genuine total order (reflexive, antisymmetric,
/// NaN-safe), which `BinaryHeap` requires. The previous implementation never
/// returned `Ordering::Equal`, so ties compared as `Greater` in both
/// directions.
impl Ord for Hypothesis {
    fn cmp(&self, other: &Self) -> Ordering {
        self.fx.total_cmp(&other.fx)
    }
}
/// A pool of reusable `Vec<usize>` buffers.
///
/// Vectors returned with [`VecPool::put`] keep their allocation and are
/// handed out again by [`VecPool::get`].
#[derive(Default)]
pub struct VecPool {
    pool: Vec<Vec<usize>>,
}

impl VecPool {
    /// Creates a pool holding `a` empty vectors, each with room for at least
    /// `b` elements.
    ///
    /// Every vector is allocated individually: `vec![Vec::with_capacity(b); a]`
    /// would clone the element, and a cloned empty `Vec` has no capacity.
    pub fn with_capacity(a: usize, b: usize) -> Self {
        Self {
            pool: (0..a).map(|_| Vec::with_capacity(b)).collect(),
        }
    }

    /// Takes a vector out of the pool, or returns a new empty one when the
    /// pool is exhausted.
    pub fn get(&mut self) -> Vec<usize> {
        self.pool.pop().unwrap_or_default()
    }

    /// Clears `vec` and stores it in the pool for later reuse.
    pub fn put(&mut self, mut vec: Vec<usize>) {
        vec.clear();
        self.pool.push(vec);
    }
}