use std::collections::HashMap;
use crate::semiring::{LogWeight, Semiring};
#[cfg(test)]
use crate::wfst::Wfst;
use crate::wfst::{MutableWfst, StateId, VectorWfst};
/// Numeric label identifying a word piece in the lexicon transducer.
pub type WordPieceId = u32;
/// Numeric label identifying a grapheme; `build_character_lexicon` uses a
/// character's Unicode scalar value as its id.
pub type GraphemeId = u32;
/// One lexicon pronunciation: a word piece together with the grapheme sequence it spells.
#[derive(Clone, Debug)]
pub struct LexiconEntry {
    /// Word-piece label of this entry.
    pub word_piece: WordPieceId,
    /// Graphemes spelled by the word piece, in order (may be empty).
    pub graphemes: Vec<GraphemeId>,
    /// Raw weight of the entry; the transducer builder wraps it with `LogWeight::new`.
    /// NOTE(review): sign/scale convention (cost vs. log-probability) is not fixed here — confirm.
    pub weight: f64,
}

impl LexiconEntry {
    /// Builds an entry whose weight is `0.0`.
    pub fn new(word_piece: WordPieceId, graphemes: Vec<GraphemeId>) -> Self {
        Self::with_weight(word_piece, graphemes, 0.0)
    }

    /// Builds an entry with the caller-supplied `weight`.
    pub fn with_weight(word_piece: WordPieceId, graphemes: Vec<GraphemeId>, weight: f64) -> Self {
        Self { word_piece, graphemes, weight }
    }
}
/// Options controlling lexicon-transducer construction.
///
/// NOTE(review): no builder in this file reads these fields (`add_lexicon_entry`
/// takes the config as an unused `_config` parameter) — confirm where they are consumed.
#[derive(Clone, Debug)]
pub struct LexiconConfig {
    /// Whether a grapheme string may be covered by several word-piece decompositions
    /// (interpretation inferred from the name — confirm).
    pub allow_multiple_decompositions: bool,
    /// Initial weight value; defaults to `0.0`.
    pub init_weight: f64,
    /// Optional grapheme used as a word separator; defaults to `None`.
    pub word_boundary: Option<GraphemeId>,
}

impl Default for LexiconConfig {
    /// Multiple decompositions allowed, `init_weight` of `0.0`, and no word boundary.
    fn default() -> Self {
        let word_boundary = None;
        let init_weight = 0.0;
        Self {
            allow_multiple_decompositions: true,
            init_weight,
            word_boundary,
        }
    }
}
/// Builds a lexicon transducer from `entries`.
///
/// A single hub state is created and marked as both the start state and a final
/// state (weight `LogWeight::one()`). Every entry is then attached to that hub by
/// `add_lexicon_entry`, so any sequence of entries can be chained, including the
/// empty sequence.
pub fn build_lexicon_transducer(
    entries: &[LexiconEntry],
    config: &LexiconConfig,
) -> VectorWfst<WordPieceId, LogWeight> {
    let mut lexicon = VectorWfst::new();
    let hub = lexicon.add_state();
    lexicon.set_start(hub);
    lexicon.set_final(hub, LogWeight::one());
    entries
        .iter()
        .for_each(|entry| add_lexicon_entry(&mut lexicon, hub, entry, config));
    lexicon
}
/// Appends `entry` to `fst` as a cycle that leaves `start` and returns to it.
///
/// For an entry with word piece `w`, graphemes `g0 .. g(n-1)` and weight `c`:
///
/// * `n == 0`: one self-loop on `start`, labelled `(Some(w), None)`, with weight
///   `LogWeight::new(c)`.
/// * `n >= 1`: a chain of `n` arcs `start -> s1 -> ... -> start`. The first arc is
///   labelled `(Some(w), Some(g0))` and carries `LogWeight::new(c)`. Every later arc is
///   labelled `(None, Some(gi))` with weight `LogWeight::one()`. Exactly `n - 1` new
///   states are created. In particular, a single-grapheme entry is a direct self-loop on
///   `start`, with no intermediate state and no arc that is epsilon on both sides.
///
/// `_config` is accepted for interface compatibility but is not consulted.
fn add_lexicon_entry(
    fst: &mut VectorWfst<WordPieceId, LogWeight>,
    start: StateId,
    entry: &LexiconEntry,
    _config: &LexiconConfig,
) {
    let num_graphemes = entry.graphemes.len();
    if num_graphemes == 0 {
        fst.add_arc(
            start,
            Some(entry.word_piece),
            None,
            start,
            LogWeight::new(entry.weight),
        );
        return;
    }

    let mut current = start;
    for (i, &grapheme) in entry.graphemes.iter().enumerate() {
        // The last grapheme closes the cycle back to `start`; earlier ones extend the
        // chain. Previously a one-grapheme entry got an extra state plus an
        // epsilon:epsilon arc back to `start`, which composition then had to handle.
        let next = if i + 1 == num_graphemes {
            start
        } else {
            fst.add_state()
        };
        // Only the first arc consumes the word piece and carries the entry weight.
        let (input, weight) = if i == 0 {
            (Some(entry.word_piece), LogWeight::new(entry.weight))
        } else {
            (None, LogWeight::one())
        };
        fst.add_arc(current, input, Some(grapheme), next, weight);
        current = next;
    }
}
pub fn build_target_graph(graphemes: &[GraphemeId]) -> VectorWfst<GraphemeId, LogWeight> {
let mut fst = VectorWfst::new();
if graphemes.is_empty() {
let s = fst.add_state();
fst.set_start(s);
fst.set_final(s, LogWeight::one());
return fst;
}
let mut states = Vec::with_capacity(graphemes.len() + 1);
for _ in 0..=graphemes.len() {
states.push(fst.add_state());
}
fst.set_start(states[0]);
fst.set_final(states[graphemes.len()], LogWeight::one());
for (i, &grapheme) in graphemes.iter().enumerate() {
fst.add_arc(
states[i],
Some(grapheme),
Some(grapheme),
states[i + 1],
LogWeight::one(),
);
}
fst
}
/// Vocabulary sizes plus an initialization flag for marginalization.
///
/// NOTE(review): `initialize` currently only sets `initialized`; the lexicon entries
/// it receives are neither stored nor inspected.
#[derive(Clone, Debug)]
pub struct MarginalizationContext {
    /// Word-piece vocabulary size (presumably the emission alphabet size — confirm).
    pub vocab_size: usize,
    /// Grapheme vocabulary size.
    pub grapheme_vocab_size: usize,
    /// `false` after `new`, `true` after `initialize`.
    pub initialized: bool,
}

impl MarginalizationContext {
    /// Creates a context that has not been initialized yet.
    pub fn new(vocab_size: usize, grapheme_vocab_size: usize) -> Self {
        let initialized = false;
        Self {
            vocab_size,
            grapheme_vocab_size,
            initialized,
        }
    }

    /// Marks the context as initialized; `_entries` is ignored.
    pub fn initialize(&mut self, _entries: &[LexiconEntry]) {
        self.initialized = true;
    }
}
/// Returns a loss value for `emissions` with respect to `target`.
///
/// NOTE(review): this is currently a placeholder. The returned value is just the
/// forward score of `emissions` (via `compute_emission_score`). `_lexicon` is never
/// read, and the target graph built from `target` is discarded without being combined
/// with anything, so the result does not depend on `target` or on the lexicon. A real
/// marginalized loss presumably composes emissions, lexicon and target graph — confirm
/// and implement before relying on this value for training.
pub fn marginalized_loss(
    emissions: &VectorWfst<WordPieceId, LogWeight>,
    _lexicon: &VectorWfst<WordPieceId, LogWeight>,
    target: &[GraphemeId],
) -> f64 {
    let emission_score = compute_emission_score(emissions);
    // Built but not yet consumed (see the NOTE above).
    let _target_fst = build_target_graph(target);
    emission_score
}
/// Wraps `emissions` in a `GradientWfst` and returns the raw value of its forward score.
///
/// NOTE(review): presumably the log-semiring total weight over all paths — confirm
/// against `forward_score`.
fn compute_emission_score(emissions: &VectorWfst<WordPieceId, LogWeight>) -> f64 {
    use super::gradient::GradientWfst;
    use super::forward_score::forward_score;
    let differentiable = GradientWfst::from_wfst(emissions);
    let total = forward_score(&differentiable);
    total.value()
}
/// Output bundle for a marginalization pass.
///
/// NOTE(review): not constructed anywhere in this file — confirm the producer.
#[derive(Clone, Debug)]
pub struct MarginalizationResult {
    /// Scalar loss (presumably the value `marginalized_loss` computes — confirm).
    pub loss: f64,
    /// Per-emission gradients; layout/indexing is not established here — confirm.
    pub emission_gradients: Vec<f64>,
    /// Summary statistics about the decompositions considered.
    pub stats: MarginalizationStats,
}
/// Summary statistics for word-piece decompositions; `Default` gives zero counts
/// and an empty best decomposition.
#[derive(Clone, Debug, Default)]
pub struct MarginalizationStats {
    /// Number of decompositions counted.
    pub num_decompositions: usize,
    /// Mean decomposition length (unit presumably word pieces — confirm).
    pub avg_decomposition_length: f64,
    /// Word-piece sequence of the best decomposition (empty if none).
    pub best_decomposition: Vec<WordPieceId>,
}
pub fn build_identity_lexicon(vocab_size: usize) -> Vec<LexiconEntry> {
(0..vocab_size as WordPieceId)
.map(|wp| LexiconEntry::new(wp, vec![wp]))
.collect()
}
pub fn build_character_lexicon(word_pieces: &HashMap<WordPieceId, String>) -> Vec<LexiconEntry> {
word_pieces
.iter()
.map(|(&wp_id, wp_str)| {
let graphemes: Vec<GraphemeId> = wp_str.chars().map(|c| c as GraphemeId).collect();
LexiconEntry::new(wp_id, graphemes)
})
.collect()
}
#[cfg(test)]
mod tests {
    // Unit tests for lexicon entries, configs, transducer/target-graph builders and
    // the lexicon helper constructors.
    use super::*;
    use crate::wfst::NO_STATE;

    #[test]
    fn test_lexicon_entry_creation() {
        let entry = LexiconEntry::new(1, vec![10, 11, 12]);
        assert_eq!(entry.word_piece, 1);
        assert_eq!(entry.graphemes, vec![10, 11, 12]);
        assert_eq!(entry.weight, 0.0);
    }

    #[test]
    fn test_lexicon_entry_with_weight() {
        let entry = LexiconEntry::with_weight(2, vec![20, 21], -0.5);
        assert_eq!(entry.word_piece, 2);
        assert_eq!(entry.weight, -0.5);
    }

    #[test]
    fn test_lexicon_config_default() {
        let config = LexiconConfig::default();
        assert!(config.allow_multiple_decompositions);
        assert_eq!(config.init_weight, 0.0);
        assert!(config.word_boundary.is_none());
    }

    #[test]
    fn test_build_lexicon_transducer() {
        let config = LexiconConfig::default();
        let entries = [
            LexiconEntry::new(1, vec![10, 11]),
            LexiconEntry::new(2, vec![20]),
        ];
        let fst = build_lexicon_transducer(&entries, &config);
        assert_ne!(fst.start(), NO_STATE);
        assert!(fst.num_states() > 1);
    }

    #[test]
    fn test_build_lexicon_empty_entry() {
        let config = LexiconConfig::default();
        let entries = [LexiconEntry::new(1, vec![])];
        let fst = build_lexicon_transducer(&entries, &config);
        assert_ne!(fst.start(), NO_STATE);
    }

    #[test]
    fn test_build_target_graph() {
        let fst = build_target_graph(&[10, 11, 12]);
        assert_eq!(fst.num_states(), 4);
        assert_ne!(fst.start(), NO_STATE);
        assert!(fst.is_final(3));
    }

    #[test]
    fn test_build_target_graph_empty() {
        let graphemes: [GraphemeId; 0] = [];
        let fst = build_target_graph(&graphemes);
        assert_eq!(fst.num_states(), 1);
        assert!(fst.is_final(0));
    }

    #[test]
    fn test_marginalization_context() {
        let mut ctx = MarginalizationContext::new(100, 256);
        assert!(!ctx.initialized);
        ctx.initialize(&[LexiconEntry::new(0, vec![0])]);
        assert!(ctx.initialized);
    }

    #[test]
    fn test_build_identity_lexicon() {
        let lexicon = build_identity_lexicon(10);
        assert_eq!(lexicon.len(), 10);
        for (idx, entry) in lexicon.iter().enumerate() {
            assert_eq!(entry.word_piece, idx as WordPieceId);
            assert_eq!(entry.graphemes, vec![idx as GraphemeId]);
        }
    }

    #[test]
    fn test_build_character_lexicon() {
        let mut word_pieces = HashMap::new();
        word_pieces.insert(0, "a".to_string());
        word_pieces.insert(1, "bc".to_string());
        let lexicon = build_character_lexicon(&word_pieces);
        assert_eq!(lexicon.len(), 2);
        let bc_entry = lexicon
            .iter()
            .find(|e| e.word_piece == 1)
            .expect("differentiable/marginalization.rs: required value was None/Err");
        assert_eq!(
            bc_entry.graphemes,
            vec!['b' as GraphemeId, 'c' as GraphemeId]
        );
    }

    #[test]
    fn test_marginalization_stats_default() {
        let stats = MarginalizationStats::default();
        assert_eq!(stats.num_decompositions, 0);
        assert!(stats.best_decomposition.is_empty());
    }
}