use std::collections::VecDeque;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use crate::semiring::{LogWeight, Semiring};
use crate::wfst::StateId;
/// A single decoding hypothesis handled by the token-group machinery.
#[derive(Clone, Debug)]
pub struct Token {
/// State id that selects the token's group: `TokenGroupManager::process_token`
/// looks the group up by this value.
pub base_state: StateId,
/// Second state id carried with the token; presumably the grammar-side
/// state — TODO confirm, nothing in this file reads it.
pub grammar_state: StateId,
/// Weight that is combined (via `plus`) into the owning group's
/// `best_forward_prob`.
pub forward_prob: LogWeight,
/// Id of the predecessor token, if any. Not dereferenced in this file.
pub prev_token: Option<TokenId>,
/// Id of the arc associated with this token, if any. Not dereferenced in this file.
pub prev_arc: Option<ArcId>,
}
/// Identifier type used for `Token::prev_token`.
pub type TokenId = u32;
/// Identifier type used for `Token::prev_arc`.
pub type ArcId = u32;
/// A weighted connection between two token groups.
///
/// `TokenGroupManager::add_link` stores one copy on the target group (as a
/// preceding link) and one on the source group (as a succeeding link).
#[derive(Clone, Debug)]
pub struct GroupLink {
/// Id of the group the link starts from.
pub source_group: TokenGroupId,
/// Id of the group the link points to.
pub target_group: TokenGroupId,
/// Weight of the link; `TokenGroup::add_preceding_link` folds it into the
/// target group's `best_forward_prob`.
pub weight: LogWeight,
/// Flag recorded with the link; it is stored but not read anywhere in this file.
pub is_word_arc: bool,
}
/// Index of a `TokenGroup` inside `TokenGroupPool` (used as a `Vec` index there).
pub type TokenGroupId = u32;
/// All hypotheses that share one `base_state`, together with the links that
/// connect the group to other groups.
#[derive(Clone, Debug)]
pub struct TokenGroup {
/// State id shared by every token routed to this group.
pub base_state: StateId,
/// Running `plus`-combination of the weights of added tokens and preceding
/// links; starts at `LogWeight::zero()` for a group made with `new`.
pub best_forward_prob: LogWeight,
/// Whether the group stores its tokens individually. `with_token` sets it,
/// and callers may set it directly since the field is public.
pub expanded: bool,
// Tokens stored in the group (inline capacity 4 before spilling to the heap).
tokens: SmallVec<[Token; 4]>,
// Links arriving at this group.
preceding_links: SmallVec<[GroupLink; 4]>,
// Links leaving this group.
succeeding_links: SmallVec<[GroupLink; 4]>,
/// Frame number supplied when the group was constructed.
pub frame: u32,
}
impl TokenGroup {
    /// Creates an empty, unexpanded group for `base_state` at `frame`.
    ///
    /// The accumulated weight starts at `LogWeight::zero()`.
    pub fn new(base_state: StateId, frame: u32) -> Self {
        let tokens = SmallVec::new();
        let preceding_links = SmallVec::new();
        let succeeding_links = SmallVec::new();
        Self {
            base_state,
            best_forward_prob: LogWeight::zero(),
            expanded: false,
            tokens,
            preceding_links,
            succeeding_links,
            frame,
        }
    }

    /// Creates an already-expanded group holding exactly `token`.
    ///
    /// The group's accumulated weight is set to the token's `forward_prob`.
    pub fn with_token(base_state: StateId, token: Token, frame: u32) -> Self {
        let mut group = Self::new(base_state, frame);
        group.best_forward_prob = token.forward_prob.clone();
        group.expanded = true;
        group.tokens.push(token);
        group
    }

    /// Stores `token` and folds its weight into `best_forward_prob` via `plus`.
    pub fn add_token(&mut self, token: Token) {
        let merged = self.best_forward_prob.plus(&token.forward_prob);
        self.best_forward_prob = merged;
        self.tokens.push(token);
    }

    /// Records an incoming link and folds the link's own weight into
    /// `best_forward_prob` via `plus`.
    ///
    /// Only `link.weight` is used; the source group is not consulted here.
    pub fn add_preceding_link(&mut self, link: GroupLink) {
        let merged = self.best_forward_prob.plus(&link.weight);
        self.best_forward_prob = merged;
        self.preceding_links.push(link);
    }

    /// Records an outgoing link. The accumulated weight is left untouched.
    pub fn add_succeeding_link(&mut self, link: GroupLink) {
        self.succeeding_links.push(link);
    }

    /// Number of tokens stored in the group.
    pub fn num_tokens(&self) -> usize {
        self.tokens.len()
    }

    /// Read-only view of the stored tokens.
    pub fn tokens(&self) -> &[Token] {
        self.tokens.as_slice()
    }

    /// Mutable access to the token storage.
    ///
    /// Changes made through this handle do not update `best_forward_prob`.
    pub fn tokens_mut(&mut self) -> &mut SmallVec<[Token; 4]> {
        &mut self.tokens
    }

    /// Read-only view of the incoming links.
    pub fn preceding_links(&self) -> &[GroupLink] {
        self.preceding_links.as_slice()
    }

    /// Read-only view of the outgoing links.
    pub fn succeeding_links(&self) -> &[GroupLink] {
        self.succeeding_links.as_slice()
    }

    /// Returns `true` when the group holds neither tokens nor incoming links.
    ///
    /// Outgoing links are not considered.
    pub fn is_empty(&self) -> bool {
        if !self.tokens.is_empty() {
            return false;
        }
        self.preceding_links.is_empty()
    }
}
/// Owns every `TokenGroup` and tracks which groups belong to the current frame.
#[derive(Debug)]
pub struct TokenGroupPool {
// Every group created since the last `clear`; a `TokenGroupId` is an index into this vector.
groups: Vec<TokenGroup>,
// `base_state` -> group id for the current frame only; emptied by `advance_frame`.
current_frame_map: FxHashMap<StateId, TokenGroupId>,
// Frame counter stamped onto newly created groups.
current_frame: u32,
}
impl TokenGroupPool {
    /// Creates an empty pool positioned at frame 0.
    pub fn new() -> Self {
        Self {
            groups: Vec::new(),
            current_frame_map: FxHashMap::default(),
            current_frame: 0,
        }
    }

    /// Creates an empty pool at frame 0 with room reserved for `capacity`
    /// groups and `capacity` current-frame map entries.
    pub fn with_capacity(capacity: usize) -> Self {
        let groups = Vec::with_capacity(capacity);
        let current_frame_map =
            FxHashMap::with_capacity_and_hasher(capacity, Default::default());
        Self {
            groups,
            current_frame_map,
            current_frame: 0,
        }
    }

    /// Moves to the next frame and forgets the current-frame lookup table.
    ///
    /// Groups created earlier stay in the pool and remain reachable by id.
    pub fn advance_frame(&mut self) {
        self.current_frame += 1;
        self.current_frame_map.clear();
    }

    /// Returns the id of the current frame's group for `base_state`,
    /// creating an empty group stamped with the current frame when none exists.
    pub fn get_or_create(&mut self, base_state: StateId) -> TokenGroupId {
        match self.current_frame_map.get(&base_state).copied() {
            Some(existing) => existing,
            None => {
                // The new group's id is its index in `groups`.
                let fresh = self.groups.len() as TokenGroupId;
                let group = TokenGroup::new(base_state, self.current_frame);
                self.groups.push(group);
                self.current_frame_map.insert(base_state, fresh);
                fresh
            }
        }
    }

    /// Looks up a group by id; `None` when the id is out of range.
    pub fn get(&self, id: TokenGroupId) -> Option<&TokenGroup> {
        let index = id as usize;
        self.groups.get(index)
    }

    /// Mutable lookup by id; `None` when the id is out of range.
    pub fn get_mut(&mut self, id: TokenGroupId) -> Option<&mut TokenGroup> {
        let index = id as usize;
        self.groups.get_mut(index)
    }

    /// Total number of groups held, across all frames.
    pub fn len(&self) -> usize {
        self.groups.len()
    }

    /// `true` when the pool holds no groups at all.
    pub fn is_empty(&self) -> bool {
        self.groups.is_empty()
    }

    /// The current frame counter.
    pub fn current_frame(&self) -> u32 {
        self.current_frame
    }

    /// Drops every group and resets the pool to frame 0.
    pub fn clear(&mut self) {
        self.groups.clear();
        self.current_frame_map.clear();
        self.current_frame = 0;
    }

    /// Iterates over `(id, group)` pairs registered for the current frame.
    ///
    /// Order follows the hash map's iteration order.
    pub fn current_frame_groups(&self) -> impl Iterator<Item = (TokenGroupId, &TokenGroup)> {
        self.current_frame_map.values().filter_map(move |&group_id| {
            let group = self.groups.get(group_id as usize)?;
            Some((group_id, group))
        })
    }
}
impl Default for TokenGroupPool {
fn default() -> Self {
Self::new()
}
}
/// Approximate priority queue that quantises `f64` weights into a fixed
/// number of FIFO buckets.
///
/// Lower weights map to lower bucket indices and are popped first; items that
/// land in the same bucket come out in insertion order. Weights outside the
/// configured range are clamped to the first or last bucket.
#[derive(Debug)]
pub struct BucketQueue<T> {
    // One FIFO per quantisation step; always holds at least one bucket.
    buckets: Vec<VecDeque<T>>,
    // Lower bound on the index of the first non-empty bucket
    // (`buckets.len()` when nothing has been inserted since the last reset).
    min_bucket: usize,
    // Total number of queued items across all buckets.
    len: usize,
    // Buckets per unit of weight.
    scale: f64,
    // Weight that maps to bucket 0.
    offset: f64,
}

impl<T> BucketQueue<T> {
    /// Creates a queue with `num_buckets` buckets covering
    /// `[min_weight, max_weight]`.
    ///
    /// A `num_buckets` of 0 is treated as 1 so the queue is always usable
    /// (previously this underflowed in `num_buckets - 1`). When the range is
    /// not positive the scale falls back to `1.0`.
    pub fn new(num_buckets: usize, min_weight: f64, max_weight: f64) -> Self {
        let num_buckets = num_buckets.max(1);
        let range = max_weight - min_weight;
        let scale = if range > 0.0 {
            (num_buckets - 1) as f64 / range
        } else {
            1.0
        };
        Self {
            buckets: (0..num_buckets).map(|_| VecDeque::new()).collect(),
            min_bucket: num_buckets,
            len: 0,
            scale,
            offset: min_weight,
        }
    }

    /// Creates a queue covering the weight range `[0.0, 100.0]`.
    pub fn default_for_log_probs(num_buckets: usize) -> Self {
        Self::new(num_buckets, 0.0, 100.0)
    }

    /// Appends `item` to the bucket selected by `weight`.
    pub fn insert(&mut self, weight: f64, item: T) {
        let bucket = self.weight_to_bucket(weight);
        self.buckets[bucket].push_back(item);
        self.len += 1;
        if bucket < self.min_bucket {
            self.min_bucket = bucket;
        }
    }

    /// Removes and returns the oldest item of the lowest non-empty bucket,
    /// or `None` when the queue is empty.
    pub fn pop(&mut self) -> Option<T> {
        if self.len == 0 {
            return None;
        }
        // Advance past drained buckets; `insert` lowers `min_bucket` again
        // whenever something lands in an earlier bucket.
        while self.min_bucket < self.buckets.len() {
            if let Some(item) = self.buckets[self.min_bucket].pop_front() {
                self.len -= 1;
                return Some(item);
            }
            self.min_bucket += 1;
        }
        None
    }

    /// Returns the item `pop` would remove next, without removing it.
    pub fn peek(&self) -> Option<&T> {
        if self.len == 0 {
            return None;
        }
        self.buckets
            .iter()
            .skip(self.min_bucket)
            .find_map(|bucket| bucket.front())
    }

    /// Number of queued items.
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when no items are queued.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Removes every item while keeping the bucket layout.
    pub fn clear(&mut self) {
        for bucket in &mut self.buckets {
            bucket.clear();
        }
        self.len = 0;
        self.min_bucket = self.buckets.len();
    }

    /// Maps a weight to its bucket index, clamped to the valid range.
    fn weight_to_bucket(&self, weight: f64) -> usize {
        let normalized = (weight - self.offset) * self.scale;
        // Float-to-int `as` casts saturate (and map NaN to 0), so extreme
        // weights end up in the first or last bucket after clamping.
        let bucket = normalized.round() as isize;
        // `buckets` is never empty (see `new`), so `len() - 1` cannot underflow.
        bucket.clamp(0, (self.buckets.len() - 1) as isize) as usize
    }

    /// Returns the number of items currently held in each bucket.
    pub fn histogram(&self) -> Vec<usize> {
        self.buckets.iter().map(|b| b.len()).collect()
    }

    /// Drops every item in buckets with an index greater than `max_bucket`
    /// and returns how many items were removed.
    ///
    /// A `max_bucket` at or beyond the last bucket (including `usize::MAX`)
    /// removes nothing.
    pub fn prune_beyond(&mut self, max_bucket: usize) -> usize {
        // Saturate so `usize::MAX` cannot wrap around to 0 and wipe the queue.
        let first_pruned = max_bucket.saturating_add(1);
        let pruned: usize = self
            .buckets
            .iter_mut()
            .skip(first_pruned)
            .map(|bucket| {
                let removed = bucket.len();
                bucket.clear();
                removed
            })
            .sum();
        self.len -= pruned;
        pruned
    }
}

impl<T> Default for BucketQueue<T> {
    /// A 100-bucket queue over the weight range `[0.0, 100.0]`.
    fn default() -> Self {
        Self::default_for_log_probs(100)
    }
}
/// Tunables for [`TokenGroupManager`].
#[derive(Clone, Debug)]
pub struct TokenGroupConfig {
    /// Per-group token limit (default 32).
    /// NOTE(review): no code in this file reads this value — confirm where it is enforced.
    pub max_tokens_per_group: usize,
    /// Overall group limit (default 10000).
    /// NOTE(review): no code in this file reads this value — confirm where it is enforced.
    pub max_groups: usize,
    /// Number of buckets handed to the manager's `BucketQueue` (default 100).
    pub num_buckets: usize,
    /// When `true` (the default), `TokenGroupManager::process_token` only
    /// accumulates the weight of a token that arrives over a non-word arc at
    /// an unexpanded group, instead of storing the token.
    pub lazy_evaluation: bool,
}

impl Default for TokenGroupConfig {
    /// Returns the stock configuration: 32 tokens per group, 10000 groups,
    /// 100 buckets, lazy evaluation enabled.
    fn default() -> Self {
        TokenGroupConfig {
            lazy_evaluation: true,
            num_buckets: 100,
            max_groups: 10_000,
            max_tokens_per_group: 32,
        }
    }
}
/// Counters collected by `TokenGroupManager`; all start at zero.
#[derive(Clone, Debug, Default)]
pub struct TokenGroupStats {
/// Incremented once per `TokenGroupManager::process_token` call.
pub tokens_processed: usize,
/// Number of groups created; used as the divisor when
/// `finalize_stats` computes `avg_tokens_per_group`.
pub groups_created: usize,
/// Incremented when `TokenGroupManager::expand_group` flips a group to expanded.
pub expansions: usize,
/// Incremented once per `TokenGroupManager::add_link` call.
pub ops_saved: usize,
/// `tokens_processed / groups_created`, filled in by `finalize_stats`
/// when `groups_created` is non-zero.
pub avg_tokens_per_group: f64,
}
/// Snapshot of one frame, as returned by `TokenGroupManager::advance_frame`.
#[derive(Clone, Debug)]
pub struct GroupedFrame {
/// Ids of the groups that were registered for the frame.
pub active_groups: Vec<TokenGroupId>,
/// `plus`-combination of the `best_forward_prob` of those groups.
pub best_forward_prob: LogWeight,
/// Always `false` where this file constructs a frame; the intended
/// meaning is not shown here — TODO confirm with callers.
pub needs_expansion: bool,
}
impl GroupedFrame {
    /// Creates a frame with no active groups, a `LogWeight::zero()`
    /// accumulated weight and `needs_expansion` cleared.
    pub fn new() -> Self {
        let active_groups = Vec::new();
        let best_forward_prob = LogWeight::zero();
        Self {
            active_groups,
            best_forward_prob,
            needs_expansion: false,
        }
    }
}
impl Default for GroupedFrame {
fn default() -> Self {
Self::new()
}
}
/// Ties together the group pool, the bucket queue of group ids and the
/// running statistics.
#[derive(Debug)]
pub struct TokenGroupManager {
// Settings supplied at construction time.
config: TokenGroupConfig,
// Storage for every group plus the current-frame lookup table.
pool: TokenGroupPool,
// Group ids queued under their group's accumulated weight.
queue: BucketQueue<TokenGroupId>,
// Counters updated by the methods below.
stats: TokenGroupStats,
}
impl TokenGroupManager {
    /// Creates a manager with an empty pool, zeroed statistics and a bucket
    /// queue of `config.num_buckets` buckets built with
    /// `BucketQueue::default_for_log_probs`.
    pub fn new(config: TokenGroupConfig) -> Self {
        let num_buckets = config.num_buckets;
        Self {
            config,
            pool: TokenGroupPool::new(),
            queue: BucketQueue::default_for_log_probs(num_buckets),
            stats: TokenGroupStats::default(),
        }
    }

    /// Creates a manager using `TokenGroupConfig::default()`.
    pub fn default_config() -> Self {
        Self::new(TokenGroupConfig::default())
    }

    /// Routes `token` to the current frame's group for `token.base_state`
    /// and returns that group's id.
    ///
    /// The token is stored (and the group marked expanded) when it arrived
    /// over a word arc, when the group is already expanded, or when lazy
    /// evaluation is disabled. Otherwise only its weight is folded into the
    /// group's `best_forward_prob` and the token itself is dropped.
    ///
    /// The group id is then queued under the group's accumulated weight, so
    /// the same id can sit in the queue more than once.
    pub fn process_token(&mut self, token: Token, is_word_arc: bool) -> TokenGroupId {
        let groups_before = self.pool.len();
        let group_id = self.pool.get_or_create(token.base_state);
        // The pool only grows when `get_or_create` had to make a new group;
        // counting that here is what feeds `finalize_stats`.
        if self.pool.len() != groups_before {
            self.stats.groups_created += 1;
        }
        self.stats.tokens_processed += 1;

        let group = self.pool.get_mut(group_id).expect("just created");
        if is_word_arc || group.expanded || !self.config.lazy_evaluation {
            group.expanded = true;
            group.add_token(token);
        } else {
            group.best_forward_prob = group.best_forward_prob.plus(&token.forward_prob);
        }

        let weight = group.best_forward_prob.value();
        self.queue.insert(weight, group_id);
        group_id
    }

    /// Records a link from `source_group` to `target_group`.
    ///
    /// The target receives it as a preceding link (which also folds `weight`
    /// into the target's accumulated weight) and the source receives it as a
    /// succeeding link. Unknown group ids are skipped silently; `ops_saved`
    /// is incremented on every call.
    pub fn add_link(
        &mut self,
        source_group: TokenGroupId,
        target_group: TokenGroupId,
        weight: LogWeight,
        is_word_arc: bool,
    ) {
        let link = GroupLink {
            source_group,
            target_group,
            weight,
            is_word_arc,
        };
        if let Some(target) = self.pool.get_mut(target_group) {
            target.add_preceding_link(link.clone());
        }
        if let Some(source) = self.pool.get_mut(source_group) {
            source.add_succeeding_link(link);
        }
        self.stats.ops_saved += 1;
    }

    /// Marks the group as expanded and counts the expansion.
    ///
    /// Does nothing when the id is unknown or the group is already expanded.
    pub fn expand_group(&mut self, group_id: TokenGroupId) {
        if let Some(group) = self.pool.get_mut(group_id) {
            if !group.expanded {
                group.expanded = true;
                self.stats.expansions += 1;
            }
        }
    }

    /// Snapshots the current frame, then advances the pool to the next frame
    /// and empties the queue.
    ///
    /// The returned frame lists the ids registered for the frame that just
    /// ended together with their combined weight; `needs_expansion` is
    /// always `false`.
    pub fn advance_frame(&mut self) -> GroupedFrame {
        let frame = GroupedFrame {
            active_groups: self.pool.current_frame_map.values().copied().collect(),
            best_forward_prob: self.compute_best_forward_prob(),
            needs_expansion: false,
        };
        self.pool.advance_frame();
        self.queue.clear();
        frame
    }

    /// Combines (via `plus`) the accumulated weight of every group in the
    /// current frame, starting from `LogWeight::zero()`.
    fn compute_best_forward_prob(&self) -> LogWeight {
        self.pool
            .current_frame_groups()
            .fold(LogWeight::zero(), |acc, (_id, group)| {
                acc.plus(&group.best_forward_prob)
            })
    }

    /// Drops queued group ids whose bucket lies beyond the bucket of
    /// `threshold` and returns how many entries were removed.
    ///
    /// Only the queue is affected; the groups stay in the pool.
    pub fn prune(&mut self, threshold: f64) -> usize {
        // Clamp against the queue's real bucket count; `saturating_sub`
        // keeps a zero bucket count from underflowing.
        let last_bucket = self.queue.buckets.len().saturating_sub(1);
        let scaled = ((threshold - self.queue.offset) * self.queue.scale).round();
        // The float-to-int cast saturates: negative or NaN thresholds map to bucket 0.
        let max_bucket = (scaled as usize).min(last_bucket);
        self.queue.prune_beyond(max_bucket)
    }

    /// Read-only access to the collected statistics.
    pub fn stats(&self) -> &TokenGroupStats {
        &self.stats
    }

    /// Mutable access to a group by id; `None` when the id is unknown.
    pub fn group_mut(&mut self, id: TokenGroupId) -> Option<&mut TokenGroup> {
        self.pool.get_mut(id)
    }

    /// Read-only access to a group by id; `None` when the id is unknown.
    pub fn group(&self, id: TokenGroupId) -> Option<&TokenGroup> {
        self.pool.get(id)
    }

    /// Total number of groups in the pool, across all frames.
    pub fn num_groups(&self) -> usize {
        self.pool.len()
    }

    /// Resets the pool, the queue and the statistics.
    pub fn clear(&mut self) {
        self.pool.clear();
        self.queue.clear();
        self.stats = TokenGroupStats::default();
    }

    /// Computes `avg_tokens_per_group` as `tokens_processed / groups_created`.
    ///
    /// Leaves the field untouched when no group has been created.
    pub fn finalize_stats(&mut self) {
        if self.stats.groups_created > 0 {
            self.stats.avg_tokens_per_group =
                self.stats.tokens_processed as f64 / self.stats.groups_created as f64;
        }
    }
}
impl Default for TokenGroupManager {
fn default() -> Self {
Self::default_config()
}
}
// Unit tests for the group, pool, bucket queue and manager types defined above.
#[cfg(test)]
mod tests {
use super::*;
// A group made with `new` starts with a zero weight, unexpanded and empty.
#[test]
fn test_token_group_creation() {
let group = TokenGroup::new(0, 0);
assert_eq!(group.base_state, 0);
assert!(group.best_forward_prob.is_zero());
assert!(!group.expanded);
assert!(group.is_empty());
}
// `with_token` yields an expanded group whose weight equals the token's weight.
#[test]
fn test_token_group_with_token() {
let token = Token {
base_state: 0,
grammar_state: 1,
forward_prob: LogWeight::new(1.0),
prev_token: None,
prev_arc: None,
};
let group = TokenGroup::with_token(0, token, 0);
assert!(group.expanded);
assert_eq!(group.num_tokens(), 1);
assert!(group
.best_forward_prob
.approx_eq(&LogWeight::new(1.0), 0.001));
}
// Adding two tokens of weight 1.0 must combine their weights with `plus`.
#[test]
fn test_token_group_add_token() {
let mut group = TokenGroup::new(0, 0);
group.expanded = true;
let token = Token {
base_state: 0,
grammar_state: 1,
forward_prob: LogWeight::new(1.0),
prev_token: None,
prev_arc: None,
};
group.add_token(token);
assert_eq!(group.num_tokens(), 1);
let token2 = Token {
base_state: 0,
grammar_state: 2,
forward_prob: LogWeight::new(1.0),
prev_token: None,
prev_arc: None,
};
group.add_token(token2);
assert_eq!(group.num_tokens(), 2);
// The expectation treats weights as negative logs: -ln(e^-1 + e^-1).
let expected = -(2.0 * (-1.0_f64).exp()).ln();
assert!(
group
.best_forward_prob
.approx_eq(&LogWeight::new(expected), 0.01),
"Expected {:?}, got {:?}",
expected,
group.best_forward_prob
);
}
// Within one frame the same base state maps to the same group id.
#[test]
fn test_token_group_pool() {
let mut pool = TokenGroupPool::new();
let id1 = pool.get_or_create(0);
let id2 = pool.get_or_create(1);
let id3 = pool.get_or_create(0);
assert_eq!(id1, id3);
assert_ne!(id1, id2);
assert_eq!(pool.len(), 2);
}
// After advancing the frame the same base state gets a fresh group.
#[test]
fn test_token_group_pool_advance_frame() {
let mut pool = TokenGroupPool::new();
let id1 = pool.get_or_create(0);
assert_eq!(pool.current_frame(), 0);
pool.advance_frame();
assert_eq!(pool.current_frame(), 1);
let id2 = pool.get_or_create(0);
assert_ne!(id1, id2);
}
// Items come out in ascending weight order regardless of insertion order.
#[test]
fn test_bucket_queue_basic() {
let mut queue: BucketQueue<u32> = BucketQueue::new(10, 0.0, 10.0);
queue.insert(5.0, 1);
queue.insert(2.0, 2);
queue.insert(8.0, 3);
assert_eq!(queue.len(), 3);
assert_eq!(queue.pop(), Some(2)); assert_eq!(queue.pop(), Some(1)); assert_eq!(queue.pop(), Some(3)); assert_eq!(queue.pop(), None);
}
// Weights 2.0, 2.1 and 2.2 all round to bucket 2 here, so they pop in insertion order.
#[test]
fn test_bucket_queue_same_bucket() {
let mut queue: BucketQueue<u32> = BucketQueue::new(10, 0.0, 10.0);
queue.insert(2.0, 1);
queue.insert(2.1, 2);
queue.insert(2.2, 3);
assert_eq!(queue.len(), 3);
assert_eq!(queue.pop(), Some(1));
assert_eq!(queue.pop(), Some(2));
assert_eq!(queue.pop(), Some(3));
}
// Only the item whose bucket index exceeds 5 (weight 9.0 -> bucket 8) is pruned.
#[test]
fn test_bucket_queue_prune() {
let mut queue: BucketQueue<u32> = BucketQueue::new(10, 0.0, 10.0);
queue.insert(1.0, 1);
queue.insert(5.0, 2);
queue.insert(9.0, 3);
let pruned = queue.prune_beyond(5);
assert_eq!(pruned, 1); assert_eq!(queue.len(), 2);
}
// With 5 buckets over [0, 4] the scale is 1, so each weight is its own bucket index.
#[test]
fn test_bucket_queue_histogram() {
let mut queue: BucketQueue<u32> = BucketQueue::new(5, 0.0, 4.0);
queue.insert(0.0, 1);
queue.insert(0.0, 2);
queue.insert(2.0, 3);
queue.insert(4.0, 4);
let hist = queue.histogram();
assert_eq!(hist[0], 2); assert_eq!(hist[2], 1); assert_eq!(hist[4], 1); }
// A lazily processed token still contributes its weight to the group.
#[test]
fn test_token_group_manager_basic() {
let mut manager = TokenGroupManager::default_config();
let token = Token {
base_state: 0,
grammar_state: 1,
forward_prob: LogWeight::new(1.0),
prev_token: None,
prev_arc: None,
};
let group_id = manager.process_token(token, false);
let group = manager.group(group_id).expect("group exists");
assert!(group
.best_forward_prob
.approx_eq(&LogWeight::new(1.0), 0.001));
}
// With lazy evaluation on, only a token arriving over a word arc expands its group.
#[test]
fn test_token_group_manager_word_arc() {
let config = TokenGroupConfig {
lazy_evaluation: true,
..Default::default()
};
let mut manager = TokenGroupManager::new(config);
let token1 = Token {
base_state: 0,
grammar_state: 1,
forward_prob: LogWeight::new(1.0),
prev_token: None,
prev_arc: None,
};
let id1 = manager.process_token(token1, false);
let group1 = manager.group(id1).expect("lazy group exists");
assert!(!group1.expanded, "non-word arc should be lazy");
let token2 = Token {
base_state: 1,
grammar_state: 2,
forward_prob: LogWeight::new(2.0),
prev_token: None,
prev_arc: None,
};
let id2 = manager.process_token(token2, true);
let group2 = manager.group(id2).expect("group exists");
assert!(group2.expanded);
}
// Every processed token is counted, whether it was stored or only accumulated.
#[test]
fn test_token_group_manager_stats() {
let mut manager = TokenGroupManager::default_config();
for i in 0..5 {
let token = Token {
base_state: i % 2, grammar_state: i,
forward_prob: LogWeight::new(1.0),
prev_token: None,
prev_arc: None,
};
manager.process_token(token, i == 4); }
assert_eq!(manager.stats().tokens_processed, 5);
}
// Three distinct base states in one frame produce three active groups.
#[test]
fn test_grouped_frame() {
let mut manager = TokenGroupManager::default_config();
for i in 0..3 {
let token = Token {
base_state: i,
grammar_state: i,
forward_prob: LogWeight::new(1.0),
prev_token: None,
prev_arc: None,
};
manager.process_token(token, true);
}
let frame = manager.advance_frame();
assert_eq!(frame.active_groups.len(), 3);
}
// A link between groups of consecutive frames is recorded on the target and counted.
#[test]
fn test_group_link() {
let mut manager = TokenGroupManager::default_config();
let token1 = Token {
base_state: 0,
grammar_state: 0,
forward_prob: LogWeight::new(1.0),
prev_token: None,
prev_arc: None,
};
let id1 = manager.process_token(token1, true);
manager.advance_frame();
let token2 = Token {
base_state: 1,
grammar_state: 1,
forward_prob: LogWeight::new(2.0),
prev_token: None,
prev_arc: None,
};
let id2 = manager.process_token(token2, false);
manager.add_link(id1, id2, LogWeight::new(0.5), false);
let group2 = manager.group(id2).expect("group exists");
assert_eq!(group2.preceding_links().len(), 1);
assert_eq!(manager.stats().ops_saved, 1);
}
}