use super::char_class::{CharDefinition, InvokeTiming};
use super::dictionary::Dictionary;
use super::double_array::DoubleArray;
use super::id::WordIdentifier;
use std::collections::{HashSet, VecDeque};
/// Context id standing in for the beginning of the sentence (BOS) when the
/// transition cost into a word starting at position 0 is looked up.
pub const BOS_CONTEXT_ID: usize = 0;
/// Context id standing in for the end of the sentence (EOS) when the
/// transition cost out of a word that reaches the end of the text is looked up.
pub const EOS_CONTEXT_ID: usize = 0;
/// Row of the DP table that represents BOS. It is stored as the "previous row"
/// of every word starting at position 0 and is the stop marker of the back-trace.
const NODE_BOS: usize = 0;
/// Word lattice built over the characters of an input text, together with the
/// cost table used to pick the cheapest path through it.
#[derive(Debug)]
pub struct Lattice {
/// `indices[i]` lists the candidate words starting at character position `i`,
/// each paired with its length in characters.
pub indices: Vec<Vec<(WordIdentifier, usize)>>,
/// Cost table: row `0` is BOS, row `i + 1` corresponds to `indices[i]`, and the
/// last row holds EOS in slot `0`. Each entry is
/// `(cost, previous row, previous slot)`.
pub dp: Vec<Vec<(i32, usize, usize)>>,
}
impl Lattice {
    /// Builds the lattice for `text`.
    ///
    /// Character positions are explored breadth-first starting at `0`; a
    /// position is only expanded when some word ends exactly there. For every
    /// expanded position the candidates are collected in this order:
    ///
    /// 1. unknown words, when the character class is invoked `Always`;
    /// 2. every dictionary word that is a prefix of the remaining text
    ///    (walked through `da`, shortest first);
    /// 3. unknown words, when nothing was found so far and the character
    ///    class is invoked as `Fallback`.
    ///
    /// The cost table is computed by `get_dp_table` once all candidates are
    /// known.
    ///
    /// # Panics
    /// Panics when `dict.resolve_homonyms` fails for an id that `da` reported
    /// as a word end.
    pub fn parse<D: Dictionary>(text: &str, da: &DoubleArray, dict: &D) -> Lattice {
        // Positions are character offsets, so decode once instead of
        // re-scanning the UTF-8 text for every lookup.
        let chars: Vec<char> = text.chars().collect();
        let len = chars.len();
        let mut indices: Vec<Vec<(WordIdentifier, usize)>> = vec![vec![]; len];
        let mut open_indices = VecDeque::from(vec![0]);
        let mut visited = HashSet::with_capacity(len);
        let char_defs = chars
            .iter()
            .map(|c| dict.classify_char(c))
            .collect::<Vec<&CharDefinition>>();

        while let Some(index) = open_indices.pop_front() {
            // `insert` returns false for a position that was already expanded.
            if index >= len || !visited.insert(index) {
                continue;
            }
            let def = char_defs[index];

            if let InvokeTiming::Always = def.timing {
                Self::push_unknown_words(dict, def, text, index, &mut indices, &mut open_indices);
            }

            if let Ok((mut cursor, _)) = da.init(chars[index]) {
                // `end` is the exclusive end of the prefix matched so far.
                let mut end = index + 1;
                loop {
                    if let Ok(stop_id) = da.stop(cursor as usize) {
                        open_indices.push_back(end);
                        for wid in dict.resolve_homonyms(&stop_id).unwrap().iter() {
                            indices[index].push((
                                WordIdentifier::Known(
                                    *wid,
                                    chars[index..end].iter().copied().collect(),
                                ),
                                end - index,
                            ));
                        }
                    }
                    if end >= len {
                        break;
                    }
                    match da.transition(cursor as usize, chars[end]) {
                        Ok((next, _)) => cursor = next,
                        Err(_) => break,
                    }
                    end += 1;
                }
            }

            if indices[index].is_empty() && matches!(def.timing, InvokeTiming::Fallback) {
                Self::push_unknown_words(dict, def, text, index, &mut indices, &mut open_indices);
            }
        }

        Lattice {
            dp: get_dp_table(&indices, dict),
            indices,
        }
    }

    /// Registers the unknown-word candidates starting at `index` and opens the
    /// position right after the unknown surface form.
    fn push_unknown_words<D: Dictionary>(
        dict: &D,
        def: &CharDefinition,
        text: &str,
        index: usize,
        indices: &mut [Vec<(WordIdentifier, usize)>],
        open_indices: &mut VecDeque<usize>,
    ) {
        let surface_form = dict.take_unknown_chars_seq(def, text, &index);
        let surface_len = surface_form.chars().count();
        open_indices.push_back(index + surface_len);
        for (wid, _) in dict.get_unknown_morphemes_by_class(&def.class) {
            indices[index].push((
                WordIdentifier::Unknown(wid, surface_form.to_string()),
                surface_len,
            ));
        }
    }

    /// Returns a copy of every candidate word in the lattice, ordered by start
    /// position and, within a position, by insertion order.
    pub fn word_identifiers(&self) -> Vec<WordIdentifier> {
        self.indices
            .iter()
            .flatten()
            .map(|(wid, _)| wid.clone())
            .collect()
    }

    /// Traces the cheapest path back from EOS to BOS.
    ///
    /// Each element of the result is `(row, slot)` into `dp`, listed from the
    /// first word to the last; `row - 1` is the matching index into `indices`.
    ///
    /// Returns `None` when the table has no EOS entry, when EOS was never
    /// reached (its cost is still `i32::MAX`), or when a back-pointer is
    /// missing or does not lead to an earlier row.
    pub fn find_best_path(&self) -> Option<Vec<(usize, usize)>> {
        let mut row = self.dp.len().checked_sub(1)?;
        let mut slot = 0;
        // An untouched entry is `(i32::MAX, 0, 0)`; its back-pointer row equals
        // NODE_BOS, so without this check an unreachable EOS would look like
        // an empty but valid path.
        if self.dp[row].first()?.0 == i32::MAX {
            return None;
        }

        let mut path = vec![];
        loop {
            let &(_, prev_row, prev_slot) = self.dp.get(row)?.get(slot)?;
            if prev_row == NODE_BOS {
                break;
            }
            // Back-pointers must move towards BOS; anything else would loop
            // forever on a malformed table.
            if prev_row >= row {
                return None;
            }
            path.push((prev_row, prev_slot));
            row = prev_row;
            slot = prev_slot;
        }
        // Collected from EOS backwards; callers expect text order.
        path.reverse();
        Some(path)
    }

    /// Returns the words along the cheapest path in text order, or `None`
    /// when `find_best_path` finds no path.
    pub fn find_best(&self) -> Option<Vec<WordIdentifier>> {
        let best_path = self.find_best_path()?;
        Some(
            best_path
                .iter()
                .map(|&(row, slot)| self.indices[row - 1][slot].0.clone())
                .collect(),
        )
    }
}
/// Builds the cost table for the lattice described by `indices`.
///
/// `indices[i]` lists the candidate words starting at character position `i`
/// together with their length in characters. The returned table has
/// `indices.len() + 2` rows, each as wide as the longest candidate list:
/// row `0` is BOS, row `i + 1` mirrors `indices[i]`, and the last row holds
/// EOS in slot `0`. Every entry is `(cost, previous row, previous slot)`;
/// entries that are never reached keep `(i32::MAX, 0, 0)`.
///
/// When there are no candidates at all (including empty input) the rows are
/// returned empty.
///
/// # Panics
/// Panics when `dict` cannot resolve a word identifier or a transition cost.
fn get_dp_table<D: Dictionary>(
    indices: &[Vec<(WordIdentifier, usize)>],
    dict: &D,
) -> Vec<Vec<(i32, usize, usize)>> {
    const UNREACHED: i32 = i32::MAX;
    let len = indices.len();
    let eos_row = len + 1;
    // `max()` is `None` for empty input; treat that like "no candidates".
    let max_num_childs = indices.iter().map(Vec::len).max().unwrap_or(0);
    let mut dp: Vec<Vec<(i32, usize, usize)>> =
        vec![vec![(UNREACHED, 0, 0); max_num_childs]; len + 2];
    if max_num_childs == 0 {
        return dp;
    }

    dp[0][0] = (0, 0, 0);
    for (slot, (right_wid, _)) in indices[0].iter().enumerate() {
        let right = dict.get(right_wid).unwrap();
        // Widen to i32 before adding, like every other cost computation here,
        // so the sum cannot overflow the dictionary's narrower cost type.
        let cost = (*dict
            .transition_cost(&BOS_CONTEXT_ID, &right.right_context_id)
            .unwrap() as i32)
            + right.cost as i32;
        dp[1][slot] = (cost, NODE_BOS, 0);
    }

    // NOTE(review): a word's own cost is added when it is entered (as `right`)
    // and again when it is left (as `left`), so it is counted twice along a
    // path - confirm this weighting is intended.
    for (i, index) in indices.iter().enumerate() {
        for (j, (left_wid, wlen)) in index.iter().enumerate() {
            let wlen = *wlen;
            let before_cost = dp[i + 1][j].0;
            // Nothing leads to this word; relaxing from the sentinel would
            // overflow i32 and could fabricate a path.
            if before_cost == UNREACHED {
                continue;
            }
            let left = dict.get(left_wid).unwrap();

            if i + wlen >= len {
                // The word reaches (or overruns) the end of the text: connect
                // it to the single EOS node.
                let cost = (*dict
                    .transition_cost(&left.left_context_id, &EOS_CONTEXT_ID)
                    .unwrap() as i32)
                    + (left.cost as i32)
                    + before_cost;
                if cost < dp[eos_row][0].0 {
                    dp[eos_row][0] = (cost, i + 1, j);
                }
                continue;
            }

            let next_row = i + 1 + wlen;
            for (k, (right_wid, _)) in indices[i + wlen].iter().enumerate() {
                let right = dict.get(right_wid).unwrap();
                let cost = (*dict
                    .transition_cost(&left.left_context_id, &right.right_context_id)
                    .unwrap() as i32)
                    + left.cost as i32
                    + right.cost as i32
                    + before_cost;
                if cost < dp[next_row][k].0 {
                    dp[next_row][k] = (cost, i + 1, j);
                }
            }
        }
    }
    dp
}