use crate::semiring::{LogWeight, Semiring};
use crate::wfst::{MutableWfst, StateId, VectorWfst, Wfst};
/// Integer identifier of a token; used as the arc label type of every graph built in this file.
pub type TokenId = u32;
/// Blank token id that `TokenGraphConfig::default()` assigns to `blank_id`.
pub const BLANK_TOKEN: TokenId = 0;
/// Selects which topology `build_token_graph` constructs for a single token.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TokenGraphType {
/// Two states: an emitting token arc into the final state, which then carries a
/// non-emitting token self-loop.
Standard,
/// Same two states as `Standard` but without the token self-loop.
Spike,
/// Linear chain of token arcs; every state after the first token arc is final.
DurationLimited {
/// Number of token arcs in the chain (the chain has `max_duration + 1` states).
max_duration: usize,
},
/// One emitting token arc followed by a linear chain of blank arcs.
EquallySpaced {
/// Number of blank arcs chained after the token arc.
blank_count: usize,
},
}
/// Options shared by the graph builders in this file.
#[derive(Clone, Debug)]
pub struct TokenGraphConfig {
/// Topology that `build_token_graph` dispatches on.
pub graph_type: TokenGraphType,
/// When true, the Standard, Spike and DurationLimited builders add blank self-loops and
/// `build_vocabulary_graph` numbers tokens from 1 instead of 0. The EquallySpaced builder
/// does not read this flag.
pub include_blank: bool,
/// Label placed on blank arcs.
pub blank_id: TokenId,
/// Raw value handed to `LogWeight::new` for every labelled arc the builders create.
pub init_weight: f64,
}
impl Default for TokenGraphConfig {
fn default() -> Self {
Self {
graph_type: TokenGraphType::Standard,
include_blank: true,
blank_id: BLANK_TOKEN,
init_weight: 0.0,
}
}
}
/// Builds the graph for a single `token`, using the topology named by `config.graph_type`.
///
/// This is a pure dispatcher: each variant forwards `token`, its own parameter (if any)
/// and `config` to the matching private builder and returns that builder's result.
pub fn build_token_graph(
    token: TokenId,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    let topology = config.graph_type;
    match topology {
        TokenGraphType::EquallySpaced { blank_count: blanks } => {
            build_equally_spaced_token_graph(token, blanks, config)
        }
        TokenGraphType::DurationLimited { max_duration: limit } => {
            build_duration_limited_token_graph(token, limit, config)
        }
        TokenGraphType::Spike => build_spike_token_graph(token, config),
        TokenGraphType::Standard => build_standard_token_graph(token, config),
    }
}
/// Builds the two-state "standard" topology for `token`.
///
/// The start state reaches the final state through an arc that consumes and emits `token`;
/// the final state then carries a self-loop that consumes `token` without output. When
/// `config.include_blank` is set, both states also get a non-emitting self-loop on
/// `config.blank_id`. Every labelled arc is weighted with `LogWeight::new(config.init_weight)`.
fn build_standard_token_graph(
    token: TokenId,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    // Fresh weight per arc, exactly as if written out at each call site.
    let arc_weight = || LogWeight::new(config.init_weight);

    let mut graph = VectorWfst::new();
    let before = graph.add_state();
    let after = graph.add_state();
    graph.set_start(before);
    graph.set_final(after, LogWeight::one());

    graph.add_arc(before, Some(token), Some(token), after, arc_weight());
    graph.add_arc(after, Some(token), None, after, arc_weight());

    if !config.include_blank {
        return graph;
    }

    // Blank self-loops are added start state first, then final state.
    for &state in &[before, after] {
        graph.add_arc(state, Some(config.blank_id), None, state, arc_weight());
    }
    graph
}
/// Builds the two-state "spike" topology for `token`.
///
/// A single arc from the start state to the final state consumes and emits `token`; there
/// is no token self-loop. When `config.include_blank` is set, both states get a non-emitting
/// self-loop on `config.blank_id`. Every labelled arc is weighted with
/// `LogWeight::new(config.init_weight)`.
fn build_spike_token_graph(
    token: TokenId,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    // Fresh weight per arc, exactly as if written out at each call site.
    let arc_weight = || LogWeight::new(config.init_weight);

    let mut graph = VectorWfst::new();
    let before = graph.add_state();
    let after = graph.add_state();
    graph.set_start(before);
    graph.set_final(after, LogWeight::one());

    graph.add_arc(before, Some(token), Some(token), after, arc_weight());

    if !config.include_blank {
        return graph;
    }

    // Blank self-loops are added start state first, then final state.
    for &state in &[before, after] {
        graph.add_arc(state, Some(config.blank_id), None, state, arc_weight());
    }
    graph
}
/// Builds a chain-shaped topology in which `token` is consumed a bounded number of times.
///
/// With `d = max(max_duration, 1)` the graph has states `0..=d`:
/// - state 0 is the start state;
/// - the arc `0 -> 1` consumes and emits `token`;
/// - each arc `i -> i + 1` for `1 <= i < d` consumes `token` without output;
/// - states `1..=d` are all final with weight `LogWeight::one()`;
/// - when `config.include_blank` is set, state 0 and state `d` (only those two) get a
///   non-emitting self-loop on `config.blank_id`.
///
/// Every labelled arc is weighted with `LogWeight::new(config.init_weight)`.
///
/// A `max_duration` of 0 is treated as 1. The chain always needs the emitting arc
/// `0 -> 1`, and without the clamp only one state would be allocated, so indexing the
/// target of that arc would panic.
fn build_duration_limited_token_graph(
    token: TokenId,
    max_duration: usize,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    // Guarantees that `states[1]` exists below.
    let max_duration = max_duration.max(1);

    let mut fst = VectorWfst::new();
    let mut states = Vec::with_capacity(max_duration + 1);
    for _ in 0..=max_duration {
        states.push(fst.add_state());
    }
    fst.set_start(states[0]);
    fst.set_final(states[max_duration], LogWeight::one());

    // First frame of the token: the only arc that produces output.
    fst.add_arc(
        states[0],
        Some(token),
        Some(token),
        states[1],
        LogWeight::new(config.init_weight),
    );

    // Further frames of the same token advance along the chain without output; stopping
    // early is allowed, so every intermediate state is final as well.
    for i in 1..max_duration {
        fst.add_arc(
            states[i],
            Some(token),
            None,
            states[i + 1],
            LogWeight::new(config.init_weight),
        );
        fst.set_final(states[i], LogWeight::one());
    }

    if config.include_blank {
        fst.add_arc(
            states[0],
            Some(config.blank_id),
            None,
            states[0],
            LogWeight::new(config.init_weight),
        );
        fst.add_arc(
            states[max_duration],
            Some(config.blank_id),
            None,
            states[max_duration],
            LogWeight::new(config.init_weight),
        );
    }
    fst
}
/// Builds a topology where the token arc is followed by a chain of `blank_count` blank arcs.
///
/// The graph has `blank_count + 2` states in a line. The first arc consumes and emits
/// `token`; each following arc consumes `config.blank_id` without output. Both the state
/// right after the token arc and the last state of the chain are final. Every arc is
/// weighted with `LogWeight::new(config.init_weight)`.
///
/// `config.include_blank` is not consulted here: the blank arcs are always present.
fn build_equally_spaced_token_graph(
    token: TokenId,
    blank_count: usize,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    // Fresh weight per arc, exactly as if written out at each call site.
    let arc_weight = || LogWeight::new(config.init_weight);

    let mut graph = VectorWfst::new();
    let chain: Vec<_> = (0..blank_count + 2).map(|_| graph.add_state()).collect();
    let (entry, emitted, last) = (chain[0], chain[1], chain[chain.len() - 1]);

    graph.set_start(entry);
    graph.set_final(last, LogWeight::one());
    graph.add_arc(entry, Some(token), Some(token), emitted, arc_weight());

    // Consecutive pairs starting at the post-token state: exactly `blank_count` hops.
    for hop in chain[1..].windows(2) {
        graph.add_arc(hop[0], Some(config.blank_id), None, hop[1], arc_weight());
    }

    graph.set_final(emitted, LogWeight::one());
    graph
}
/// Builds one graph that loops over the token graphs of `vocab_size` consecutive token ids.
///
/// A single root state is both start and final. For each token id a fresh copy of
/// `build_token_graph(token_id, config)` is appended with its state ids shifted past the
/// states already present. The root gets an epsilon arc (no input, no output, weight
/// `LogWeight::one()`) to the copy's start state, and every final state of the copy gets
/// an epsilon arc back to the root carrying that state's final weight.
///
/// Token ids start at 1 when `config.include_blank` is set and at 0 otherwise.
/// NOTE(review): this numbering does not look at `config.blank_id`, so it presumably
/// expects the blank id to be 0 — confirm against callers.
pub fn build_vocabulary_graph(
    vocab_size: usize,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    let mut vocab = VectorWfst::new();
    let root = vocab.add_state();
    vocab.set_start(root);
    vocab.set_final(root, LogWeight::one());

    let first_token: TokenId = if config.include_blank { 1 } else { 0 };
    let end_token = first_token + vocab_size as TokenId;

    for token_id in first_token..end_token {
        let token_fst = build_token_graph(token_id, config);
        let token_states = token_fst.num_states();
        // Shift applied to every state id of this token graph inside `vocab`.
        let offset = vocab.num_states() as StateId;

        for _ in 0..token_states {
            vocab.add_state();
        }

        // Copy all arcs first so the per-state arc order of the token graph is preserved.
        for local in 0..token_states as StateId {
            for arc in token_fst.transitions(local) {
                vocab.add_arc(
                    local + offset,
                    arc.input.clone(),
                    arc.output.clone(),
                    arc.to + offset,
                    arc.weight,
                );
            }
        }

        // Entry into the token graph from the shared root.
        vocab.add_arc(root, None, None, token_fst.start() + offset, LogWeight::one());

        // Exit back to the root from each final state of the token graph.
        let final_states = (0..token_states as StateId).filter(|&state| token_fst.is_final(state));
        for state in final_states {
            vocab.add_arc(state + offset, None, None, root, token_fst.final_weight(state));
        }
    }
    vocab
}
/// Builds a one-state graph that accepts any number of blank labels.
///
/// The single state is both start and final (weight `LogWeight::one()`) and carries one
/// self-loop that consumes `config.blank_id` without output, weighted with
/// `LogWeight::new(config.init_weight)`.
pub fn build_blank_graph(config: &TokenGraphConfig) -> VectorWfst<TokenId, LogWeight> {
    let mut graph = VectorWfst::new();
    let only_state = graph.add_state();
    let blank_label = Some(config.blank_id);

    graph.set_start(only_state);
    graph.set_final(only_state, LogWeight::one());
    graph.add_arc(
        only_state,
        blank_label,
        None,
        only_state,
        LogWeight::new(config.init_weight),
    );
    graph
}
/// Size summary of a graph, as produced by `TokenGraphStats::from_wfst`.
#[derive(Clone, Debug, Default)]
pub struct TokenGraphStats {
/// Number of states in the graph.
pub num_states: usize,
/// Total number of arcs, summed over all states.
pub num_arcs: usize,
/// Topology the graph was built with; `from_wfst` leaves this as `None`.
pub graph_type: Option<TokenGraphType>,
}
impl TokenGraphStats {
    /// Counts the states and arcs of `fst`.
    ///
    /// `num_arcs` is the sum of `fst.transitions(s).len()` over every state id
    /// `0..fst.num_states()`. The graph itself does not record which topology produced
    /// it, so `graph_type` is always `None` in the result.
    pub fn from_wfst<L: Clone + Send + Sync>(fst: &VectorWfst<L, LogWeight>) -> Self {
        let state_count = fst.num_states();

        let mut arc_count: usize = 0;
        for state in 0..state_count as StateId {
            arc_count += fst.transitions(state).len();
        }

        Self {
            graph_type: None,
            num_arcs: arc_count,
            num_states: state_count,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Structural tests for the graph builders. Each test pins the exact state and arc
    //! counts the builders produce, so an extra or missing arc fails the suite.

    use super::*;
    use crate::wfst::NO_STATE;

    /// The default configuration is the standard topology with blank id `BLANK_TOKEN`.
    #[test]
    fn test_token_graph_config_default() {
        let config = TokenGraphConfig::default();
        assert_eq!(config.graph_type, TokenGraphType::Standard);
        assert!(config.include_blank);
        assert_eq!(config.blank_id, BLANK_TOKEN);
    }

    /// Standard topology: 2 states; token arc + token self-loop + one blank loop per state.
    #[test]
    fn test_standard_token_graph() {
        let config = TokenGraphConfig::default();
        let graph = build_token_graph(1, &config);
        assert!(graph.start() != NO_STATE);
        assert_eq!(graph.num_states(), 2);
        assert!(graph.is_final(1));
        let stats = TokenGraphStats::from_wfst(&graph);
        assert_eq!(stats.num_arcs, 4);
    }

    /// Spike topology: 2 states; token arc + one blank loop per state, no token self-loop.
    #[test]
    fn test_spike_token_graph() {
        let config = TokenGraphConfig {
            graph_type: TokenGraphType::Spike,
            ..Default::default()
        };
        let graph = build_token_graph(1, &config);
        assert_eq!(graph.num_states(), 2);
        assert!(graph.is_final(1));
        let stats = TokenGraphStats::from_wfst(&graph);
        assert_eq!(stats.num_arcs, 3);
    }

    /// Duration-limited topology without blanks: `max_duration + 1` states joined by
    /// `max_duration` token arcs.
    #[test]
    fn test_duration_limited_token_graph() {
        let config = TokenGraphConfig {
            graph_type: TokenGraphType::DurationLimited { max_duration: 3 },
            include_blank: false,
            ..Default::default()
        };
        let graph = build_token_graph(1, &config);
        assert_eq!(graph.num_states(), 4);
        let stats = TokenGraphStats::from_wfst(&graph);
        assert_eq!(stats.num_arcs, 3);
    }

    /// Equally-spaced topology: `blank_count + 2` states; one token arc then `blank_count`
    /// blank arcs, with the post-token state and the last state final.
    #[test]
    fn test_equally_spaced_token_graph() {
        let config = TokenGraphConfig {
            graph_type: TokenGraphType::EquallySpaced { blank_count: 2 },
            ..Default::default()
        };
        let graph = build_token_graph(1, &config);
        assert_eq!(graph.num_states(), 4);
        assert!(graph.is_final(1));
        assert!(graph.is_final(3));
    }

    /// Vocabulary graph over three spike tokens: one root state plus two states per token.
    #[test]
    fn test_vocabulary_graph() {
        let config = TokenGraphConfig {
            graph_type: TokenGraphType::Spike,
            include_blank: true,
            blank_id: 0,
            init_weight: 0.0,
        };
        let graph = build_vocabulary_graph(3, &config);
        assert_eq!(graph.num_states(), 7);
        assert!(graph.start() != NO_STATE);
    }

    /// Blank graph: a single final state carrying one blank self-loop.
    #[test]
    fn test_blank_graph() {
        let config = TokenGraphConfig::default();
        let graph = build_blank_graph(&config);
        assert_eq!(graph.num_states(), 1);
        assert!(graph.is_final(0));
        let stats = TokenGraphStats::from_wfst(&graph);
        assert_eq!(stats.num_arcs, 1);
    }

    /// `from_wfst` reports the sizes of the default standard graph and no topology tag.
    #[test]
    fn test_token_graph_stats() {
        let config = TokenGraphConfig::default();
        let graph = build_token_graph(1, &config);
        let stats = TokenGraphStats::from_wfst(&graph);
        assert_eq!(stats.num_states, 2);
        assert_eq!(stats.num_arcs, 4);
        assert!(stats.graph_type.is_none());
    }

    /// Every state after the first token arc is final, so the token may stop at any
    /// duration from 1 up to `max_duration`.
    #[test]
    fn test_duration_limited_all_states_reachable() {
        let config = TokenGraphConfig {
            graph_type: TokenGraphType::DurationLimited { max_duration: 2 },
            include_blank: false,
            ..Default::default()
        };
        let graph = build_token_graph(1, &config);
        assert_eq!(graph.num_states(), 3);
        assert!(graph.is_final(1));
        assert!(graph.is_final(2));
    }

    /// The equally-spaced chain contains exactly one token arc plus `blank_count` blank arcs.
    #[test]
    fn test_equally_spaced_requires_blanks() {
        let config = TokenGraphConfig {
            graph_type: TokenGraphType::EquallySpaced { blank_count: 2 },
            include_blank: true,
            blank_id: 0,
            ..Default::default()
        };
        let graph = build_token_graph(5, &config);
        let stats = TokenGraphStats::from_wfst(&graph);
        assert_eq!(stats.num_arcs, 3);
    }
}