use core::num::NonZeroU32;
use alloc::vec::Vec;
use bincode::{
de::Decoder,
enc::Encoder,
error::{DecodeError, EncodeError},
Decode, Encode,
};
use hashbrown::HashMap;
use crate::errors::{Result, RucrfError};
use crate::feature::{self, FeatureProvider};
use crate::lattice::{Edge, Lattice};
use crate::utils::FromU32;
/// Common interface of models that can pick a path through a [`Lattice`].
pub trait Model {
/// Searches `lattice` and returns the selected path together with its score.
///
/// The vector holds the edges of the path in order from the first node; the
/// `f64` is the score accumulated along that path.
fn search_best_path(&self, lattice: &Lattice) -> (Vec<Edge>, f64);
}
/// Model that keeps the weights and the feature-to-weight index tables separately.
pub struct RawModel {
// Flat list of weights referenced by the two index tables below.
weights: Vec<f64>,
// Indexed by `unigram feature ID - 1`. A `Some(n)` entry refers to `weights[n - 1]`
// (the stored index is 1-based); `None` means the feature has no weight.
unigram_weight_indices: Vec<Option<NonZeroU32>>,
// Outer index: feature ID on the left side of a pair; map key: feature ID on the
// right side; value: 0-based index into `weights`.
// NOTE(review): ID 0 is used on either side for the sequence boundary (see `merge`)
// — confirm against `feature::apply_bigram`.
bigram_weight_indices: Vec<HashMap<u32, u32>>,
// Supplies the feature sets that are looked up by edge label.
provider: FeatureProvider,
}
impl Decode for RawModel {
fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, DecodeError> {
let weights = Decode::decode(decoder)?;
let unigram_weight_indices: Vec<Option<NonZeroU32>> = Decode::decode(decoder)?;
let bigram_weight_indices: Vec<Vec<(u32, u32)>> = Decode::decode(decoder)?;
let provider: FeatureProvider = Decode::decode(decoder)?;
Ok(Self {
weights,
unigram_weight_indices,
bigram_weight_indices: bigram_weight_indices
.into_iter()
.map(|v| v.into_iter().collect())
.collect(),
provider,
})
}
}
bincode::impl_borrow_decode!(RawModel);
impl Encode for RawModel {
    /// Writes weights, unigram index table, bigram index table and feature provider,
    /// in that order.
    ///
    /// Each bigram hash map is flattened into a list of `(key, value)` pairs first.
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
        let mut flattened: Vec<Vec<(u32, u32)>> =
            Vec::with_capacity(self.bigram_weight_indices.len());
        for table in &self.bigram_weight_indices {
            let pairs = table.iter().map(|(&key, &widx)| (key, widx)).collect();
            flattened.push(pairs);
        }

        Encode::encode(&self.weights, encoder)?;
        Encode::encode(&self.unigram_weight_indices, encoder)?;
        Encode::encode(&flattened, encoder)?;
        Encode::encode(&self.provider, encoder)
    }
}
impl RawModel {
    /// Creates a model from its parts.
    ///
    /// * `weights` - flat list of weights.
    /// * `unigram_weight_indices` - for each unigram feature ID (minus one), the
    ///   1-based index of its weight, if any.
    /// * `bigram_weight_indices` - for each left-side feature ID, a map from the
    ///   right-side feature ID to the 0-based index of the pair's weight.
    /// * `provider` - the feature provider that maps labels to feature sets.
    #[cfg(feature = "train")]
    pub(crate) fn new(
        weights: Vec<f64>,
        unigram_weight_indices: Vec<Option<NonZeroU32>>,
        bigram_weight_indices: Vec<HashMap<u32, u32>>,
        provider: FeatureProvider,
    ) -> Self {
        Self {
            weights,
            unigram_weight_indices,
            bigram_weight_indices,
            provider,
        }
    }

    /// Returns a mutable reference to the feature provider of this model.
    pub fn feature_provider(&mut self) -> &mut FeatureProvider {
        &mut self.provider
    }

    /// Merges the feature weights into one weight per feature set and a connection
    /// matrix, producing a [`MergedModel`].
    ///
    /// For every feature set of the provider:
    ///
    /// * the weights of its unigram features are summed into a single weight, and
    /// * its two bigram feature lists are deduplicated into 1-based connection IDs
    ///   (`left_id` from `bigram_right()`, `right_id` from `bigram_left()`).
    ///
    /// The matrix then stores, for each pair of connection IDs, the sum of the weights
    /// of the feature pairs taken position by position. Row `0` holds the connections
    /// from the start of the sequence and key `0` in each row holds the connection to
    /// the end of the sequence. Sums whose absolute value is below [`f64::EPSILON`]
    /// are omitted.
    ///
    /// # Errors
    ///
    /// Returns the error built by `RucrfError::model_scale` when a connection ID does
    /// not fit in a `u32`.
    // The `unwrap` on `NonZeroU32::new` cannot fail because the ID is always
    // `len + 1`; indexing into `weights` only panics if the index tables are
    // inconsistent with `weights`, which is treated as an internal invariant.
    #[allow(clippy::missing_panics_doc)]
    pub fn merge(&self) -> Result<MergedModel> {
        let mut left_conn_ids = HashMap::new();
        let mut right_conn_ids = HashMap::new();
        let mut left_conn_to_right_feats = vec![];
        let mut right_conn_to_left_feats = vec![];
        let mut new_feature_sets = vec![];
        for feature_set in &self.provider.feature_sets {
            // Collapse all unigram features of this set into one weight.
            let mut weight = 0.0;
            for fid in feature_set.unigram() {
                let fid = usize::from_u32(fid.get() - 1);
                if let Some(widx) = self.unigram_weight_indices.get(fid).copied().flatten() {
                    weight += self.weights[usize::from_u32(widx.get() - 1)];
                }
            }
            // Identical feature lists share one connection ID; a new ID is handed out
            // only the first time a list is seen.
            let left_id = {
                let new_id = u32::try_from(left_conn_to_right_feats.len() + 1)
                    .map_err(|_| RucrfError::model_scale("connection ID too large"))?;
                *left_conn_ids
                    .raw_entry_mut()
                    .from_key(feature_set.bigram_right())
                    .or_insert_with(|| {
                        let features = feature_set.bigram_right().to_vec();
                        left_conn_to_right_feats.push(features.clone());
                        (features, NonZeroU32::new(new_id).unwrap())
                    })
                    .1
            };
            let right_id = {
                let new_id = u32::try_from(right_conn_to_left_feats.len() + 1)
                    .map_err(|_| RucrfError::model_scale("connection ID too large"))?;
                *right_conn_ids
                    .raw_entry_mut()
                    .from_key(feature_set.bigram_left())
                    .or_insert_with(|| {
                        let features = feature_set.bigram_left().to_vec();
                        right_conn_to_left_feats.push(features.clone());
                        (features, NonZeroU32::new(new_id).unwrap())
                    })
                    .1
            };
            new_feature_sets.push(MergedFeatureSet {
                weight,
                left_id,
                right_id,
            });
        }

        let mut matrix = vec![];

        // Row 0: start of the sequence followed by each left connection ID.
        // A checked lookup keeps an empty bigram table from panicking, like the
        // `.get(..)` lookups used for every other row below.
        let start_indices = self.bigram_weight_indices.first();
        let mut m = HashMap::new();
        for (i, left_ids) in left_conn_to_right_feats.iter().enumerate() {
            let mut weight = 0.0;
            for fid in left_ids.iter().flatten() {
                if let Some(&widx) = start_indices.and_then(|hm| hm.get(&fid.get())) {
                    weight += self.weights[usize::from_u32(widx)];
                }
            }
            if weight.abs() >= f64::EPSILON {
                m.insert(
                    u32::try_from(i + 1)
                        .map_err(|_| RucrfError::model_scale("connection ID too large"))?,
                    weight,
                );
            }
        }
        matrix.push(m);

        // Rows 1..: one row per right connection ID.
        for right_ids in &right_conn_to_left_feats {
            let mut m = HashMap::new();

            // Key 0: this right connection ID followed by the end of the sequence.
            let mut weight = 0.0;
            for fid in right_ids.iter().flatten() {
                let right_id = usize::from_u32(fid.get());
                if let Some(&widx) = self
                    .bigram_weight_indices
                    .get(right_id)
                    .and_then(|hm| hm.get(&0))
                {
                    weight += self.weights[usize::from_u32(widx)];
                }
            }
            if weight.abs() >= f64::EPSILON {
                m.insert(0, weight);
            }

            // Keys 1..: this right connection ID followed by each left connection ID.
            // The two feature lists are paired position by position.
            for (i, left_ids) in left_conn_to_right_feats.iter().enumerate() {
                let mut weight = 0.0;
                for (right_id, left_id) in right_ids.iter().zip(left_ids) {
                    if let (Some(right_id), Some(left_id)) = (right_id, left_id) {
                        let right_id = usize::from_u32(right_id.get());
                        let left_id = left_id.get();
                        if let Some(&widx) = self
                            .bigram_weight_indices
                            .get(right_id)
                            .and_then(|hm| hm.get(&left_id))
                        {
                            weight += self.weights[usize::from_u32(widx)];
                        }
                    }
                }
                if weight.abs() >= f64::EPSILON {
                    m.insert(
                        u32::try_from(i + 1)
                            .map_err(|_| RucrfError::model_scale("connection ID too large"))?,
                        weight,
                    );
                }
            }
            matrix.push(m);
        }

        Ok(MergedModel {
            feature_sets: new_feature_sets,
            matrix,
            left_conn_to_right_feats,
            right_conn_to_left_feats,
        })
    }

    /// Returns the table that maps unigram feature IDs (minus one) to weight indices.
    #[must_use]
    pub fn unigram_weight_indices(&self) -> &[Option<NonZeroU32>] {
        &self.unigram_weight_indices
    }

    /// Returns the table that maps pairs of bigram feature IDs to weight indices.
    #[must_use]
    pub fn bigram_weight_indices(&self) -> &[HashMap<u32, u32>] {
        &self.bigram_weight_indices
    }

    /// Returns the flat list of weights.
    #[must_use]
    pub fn weights(&self) -> &[f64] {
        &self.weights
    }
}
impl Model for RawModel {
    /// Finds the highest-scoring path through `lattice` using the raw feature weights.
    ///
    /// The search runs backwards over the nodes: every edge first receives the sum of
    /// its unigram weights, then, from the last node towards the first, each edge adds
    /// the best score reachable through its target node, including the bigram weights
    /// applied by `feature::apply_bigram`. Ties keep the earliest candidate.
    ///
    /// Unigram feature IDs that have no entry in the unigram index table contribute
    /// nothing, matching how `merge` treats them.
    ///
    /// # Panics
    ///
    /// Panics if the lattice has no nodes, or if an index table refers to a position
    /// outside `weights`.
    #[must_use]
    fn search_best_path(&self, lattice: &Lattice) -> (Vec<Edge>, f64) {
        // best_scores[node][n] = (target node, index of the best entry at the target,
        //                         label of the edge, accumulated score).
        let mut best_scores = vec![vec![]; lattice.nodes().len()];
        // Sentinel entry at the last node: no label, zero score.
        best_scores[lattice.nodes().len() - 1].push((0, 0, None, 0.0));
        for (i, node) in lattice.nodes().iter().enumerate() {
            for edge in node.edges() {
                let mut score = 0.0;
                if let Some(feature_set) = self.provider.get_feature_set(edge.label) {
                    for &fid in feature_set.unigram() {
                        let fid = usize::from_u32(fid.get() - 1);
                        // Checked lookup: a feature ID beyond the table has no weight.
                        if let Some(widx) =
                            self.unigram_weight_indices.get(fid).copied().flatten()
                        {
                            score += self.weights[usize::from_u32(widx.get() - 1)];
                        }
                    }
                }
                best_scores[i].push((edge.target(), 0, Some(edge.label), score));
            }
        }

        // Backward pass: pick, for every edge, the best continuation at its target.
        for i in (0..lattice.nodes().len() - 1).rev() {
            for j in 0..best_scores[i].len() {
                let (k, _, curr_label, _) = best_scores[i][j];
                let mut best_score = f64::NEG_INFINITY;
                let mut best_idx = 0;
                for (p, &(_, _, next_label, mut score)) in best_scores[k].iter().enumerate() {
                    feature::apply_bigram(
                        curr_label,
                        next_label,
                        &self.provider,
                        &self.bigram_weight_indices,
                        |widx| {
                            score += self.weights[usize::from_u32(widx)];
                        },
                    );
                    if score > best_score {
                        best_score = score;
                        best_idx = p;
                    }
                }
                best_scores[i][j].1 = best_idx;
                best_scores[i][j].3 += best_score;
            }
        }

        // Choose the first edge, taking the transition from the sequence start
        // (no previous label) into account.
        let mut best_score = f64::NEG_INFINITY;
        let mut idx = 0;
        for (p, &(_, _, next_label, mut score)) in best_scores[0].iter().enumerate() {
            feature::apply_bigram(
                None,
                next_label,
                &self.provider,
                &self.bigram_weight_indices,
                |widx| {
                    score += self.weights[usize::from_u32(widx)];
                },
            );
            if score > best_score {
                best_score = score;
                idx = p;
            }
        }

        // Follow the stored back-pointers from the first node to the last.
        let mut pos = 0;
        let mut best_path = vec![];
        while pos < lattice.nodes().len() - 1 {
            let edge = &lattice.nodes()[pos].edges()[idx];
            idx = best_scores[pos][idx].1;
            pos = edge.target();
            best_path.push(Edge::new(pos, edge.label()));
        }
        (best_path, best_score)
    }
}
/// Per-feature-set data of a [`MergedModel`], as produced by `RawModel::merge`.
#[derive(Clone, Copy, Debug, Decode, Encode)]
pub struct MergedFeatureSet {
/// Sum of the weights of the unigram features of the set.
pub weight: f64,
/// Connection ID used when the label follows another one; it is the key looked up
/// inside a row of the connection matrix.
pub left_id: NonZeroU32,
/// Connection ID used when the label precedes another one; it selects the row of
/// the connection matrix.
pub right_id: NonZeroU32,
}
/// Model whose feature weights have been merged into per-label weights and a
/// connection matrix (see `RawModel::merge`).
pub struct MergedModel {
/// Merged data per feature set, indexed by `label - 1`.
pub feature_sets: Vec<MergedFeatureSet>,
/// Connection weights. The row index is the right connection ID of the preceding
/// label (row `0` is used for the start of the sequence); the key is the left
/// connection ID of the following label (key `0` is used for the end of the
/// sequence). Missing entries are treated as `0.0` by the search.
pub matrix: Vec<HashMap<u32, f64>>,
/// For each left connection ID (index + 1), the bigram feature list it stands for.
pub left_conn_to_right_feats: Vec<Vec<Option<NonZeroU32>>>,
/// For each right connection ID (index + 1), the bigram feature list it stands for.
pub right_conn_to_left_feats: Vec<Vec<Option<NonZeroU32>>>,
}
impl Decode for MergedModel {
fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, DecodeError> {
let feature_sets = Decode::decode(decoder)?;
let matrix: Vec<Vec<(u32, f64)>> = Decode::decode(decoder)?;
let left_conn_to_right_feats = Decode::decode(decoder)?;
let right_conn_to_left_feats = Decode::decode(decoder)?;
Ok(Self {
feature_sets,
matrix: matrix
.into_iter()
.map(|x| x.into_iter().collect())
.collect(),
left_conn_to_right_feats,
right_conn_to_left_feats,
})
}
}
bincode::impl_borrow_decode!(MergedModel);
impl Encode for MergedModel {
    /// Writes feature sets, connection matrix and the two connection-to-feature
    /// tables, in that order.
    ///
    /// Each matrix row is flattened into a list of `(key, weight)` pairs. The rows are
    /// read by reference, so the matrix is not cloned just to be serialized.
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
        let matrix: Vec<Vec<(u32, f64)>> = self
            .matrix
            .iter()
            .map(|row| row.iter().map(|(&k, &v)| (k, v)).collect())
            .collect();
        Encode::encode(&self.feature_sets, encoder)?;
        Encode::encode(&matrix, encoder)?;
        Encode::encode(&self.left_conn_to_right_feats, encoder)?;
        Encode::encode(&self.right_conn_to_left_feats, encoder)?;
        Ok(())
    }
}
impl Model for MergedModel {
#[must_use]
fn search_best_path(&self, lattice: &Lattice) -> (Vec<Edge>, f64) {
let mut best_scores = vec![vec![]; lattice.nodes().len()];
best_scores[lattice.nodes().len() - 1].push((0, 0, None, 0.0));
for (i, node) in lattice.nodes().iter().enumerate() {
for edge in node.edges() {
let label = usize::from_u32(edge.label.get() - 1);
let score = self.feature_sets.get(label).map_or(0.0, |s| s.weight);
best_scores[i].push((edge.target(), 0, Some(edge.label), score));
}
}
for i in (0..lattice.nodes().len() - 1).rev() {
for j in 0..best_scores[i].len() {
let (k, _, curr_label, _) = best_scores[i][j];
let mut best_score = f64::NEG_INFINITY;
let mut best_idx = 0;
let curr_id = curr_label.map_or(Some(0), |label| {
self.feature_sets
.get(usize::from_u32(label.get() - 1))
.map(|s| s.right_id.get())
});
for (p, &(_, _, next_label, mut score)) in best_scores[k].iter().enumerate() {
let next_id = next_label.map_or(Some(0), |label| {
self.feature_sets
.get(usize::from_u32(label.get() - 1))
.map(|s| s.left_id.get())
});
if let (Some(curr_id), Some(next_id)) = (curr_id, next_id) {
score += self
.matrix
.get(usize::from_u32(curr_id))
.and_then(|hm| hm.get(&next_id))
.unwrap_or(&0.0);
}
if score > best_score {
best_score = score;
best_idx = p;
}
}
best_scores[i][j].1 = best_idx;
best_scores[i][j].3 += best_score;
}
}
let mut best_score = f64::NEG_INFINITY;
let mut idx = 0;
for (p, &(_, _, next_label, mut score)) in best_scores[0].iter().enumerate() {
let next_id = next_label.map_or(Some(0), |label| {
self.feature_sets
.get(usize::from_u32(label.get() - 1))
.map(|s| s.right_id.get())
});
if let Some(next_id) = next_id {
score += self
.matrix
.get(0)
.and_then(|hm| hm.get(&next_id))
.unwrap_or(&0.0);
}
if score > best_score {
best_score = score;
idx = p;
}
}
let mut pos = 0;
let mut best_path = vec![];
while pos < lattice.nodes().len() - 1 {
let edge = &lattice.nodes()[pos].edges()[idx];
idx = best_scores[pos][idx].1;
pos = edge.target();
best_path.push(Edge::new(pos, edge.label()));
}
(best_path, best_score)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::num::NonZeroU32;
    use crate::lattice::Edge;
    use crate::test_utils::{self, hashmap};

    /// Builds the hand-written model shared by both tests.
    fn build_raw_model() -> RawModel {
        RawModel {
            weights: vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
                16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 13.0, 24.0, 5.0, 26.0, 27.0, 6.0,
            ],
            unigram_weight_indices: vec![
                NonZeroU32::new(2),
                NonZeroU32::new(4),
                NonZeroU32::new(6),
                NonZeroU32::new(8),
            ],
            bigram_weight_indices: vec![
                hashmap![0 => 28, 1 => 0, 2 => 2, 3 => 4, 4 => 6],
                hashmap![0 => 8, 1 => 9, 2 => 10, 3 => 11, 4 => 12],
                hashmap![0 => 13, 1 => 14, 2 => 15, 3 => 16, 4 => 17],
                hashmap![0 => 18, 1 => 19, 2 => 20, 3 => 21, 4 => 22],
                hashmap![0 => 23, 1 => 24, 2 => 25, 3 => 26, 4 => 27],
            ],
            provider: test_utils::generate_test_feature_provider(),
        }
    }

    /// The path both model variants are expected to select on the test lattice.
    fn expected_best_path() -> Vec<Edge> {
        vec![
            Edge::new(1, NonZeroU32::new(1).unwrap()),
            Edge::new(2, NonZeroU32::new(2).unwrap()),
            Edge::new(3, NonZeroU32::new(6).unwrap()),
            Edge::new(4, NonZeroU32::new(7).unwrap()),
            Edge::new(5, NonZeroU32::new(4).unwrap()),
        ]
    }

    // The raw model searches with the unmerged weights.
    #[test]
    fn test_search_best_path() {
        let raw = build_raw_model();
        let lattice = test_utils::generate_test_lattice();

        let (path, score) = raw.search_best_path(&lattice);

        assert_eq!(expected_best_path(), path);
        assert!((194.0 - score).abs() < f64::EPSILON);
    }

    // The merged model must select the same path with the same score.
    #[test]
    fn test_hashed_search_best_path() {
        let merged = build_raw_model().merge().unwrap();
        let lattice = test_utils::generate_test_lattice();

        let (path, score) = merged.search_best_path(&lattice);

        assert_eq!(expected_best_path(), path);
        assert!((194.0 - score).abs() < f64::EPSILON);
    }
}