use std::collections::VecDeque;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use super::traits::{CachePolicy, LazyWfst, Wfst};
use super::{StateId, WeightedTransition};
use crate::semiring::Semiring;
/// Cached expansion result for a single state of a lazily expanded WFST.
///
/// `Pending` is the default variant. A [`StateSource`] may also return it for
/// a state it cannot expand.
#[derive(Clone, Debug, Default)]
pub enum LazyState<L, W: Semiring> {
    /// The state has no known finality or transitions.
    #[default]
    Pending,
    /// The state has been expanded.
    Computed {
        /// Whether the state is final.
        is_final: bool,
        /// Final weight; `W::zero()` for states built with `LazyState::non_final`.
        final_weight: W,
        /// Outgoing transitions; up to four are stored inline in the `SmallVec`.
        transitions: SmallVec<[WeightedTransition<L, W>; 4]>,
    },
}
impl<L, W: Semiring> LazyState<L, W> {
    /// Builds an expanded, non-final state with the given outgoing arcs.
    ///
    /// The final weight is set to the semiring zero.
    pub fn non_final(transitions: SmallVec<[WeightedTransition<L, W>; 4]>) -> Self {
        Self::Computed {
            transitions,
            is_final: false,
            final_weight: W::zero(),
        }
    }

    /// Builds an expanded, final state whose final weight is `weight`.
    pub fn final_state(weight: W, transitions: SmallVec<[WeightedTransition<L, W>; 4]>) -> Self {
        Self::Computed {
            final_weight: weight,
            transitions,
            is_final: true,
        }
    }

    /// Returns `true` once the state has been expanded, i.e. it is not `Pending`.
    #[inline]
    pub fn is_computed(&self) -> bool {
        !matches!(self, Self::Pending)
    }

    /// Borrows the outgoing arcs, or returns `None` while the state is `Pending`.
    #[inline]
    pub fn transitions(&self) -> Option<&[WeightedTransition<L, W>]> {
        if let Self::Computed { transitions, .. } = self {
            Some(&transitions[..])
        } else {
            None
        }
    }
}
/// Producer of per-state expansions consumed by a [`LazyWfstWrapper`].
///
/// Implementors must be `Clone + Send + Sync`.
pub trait StateSource<L, W: Semiring>: Clone + Send + Sync {
    /// Computes the finality, final weight and outgoing transitions of `state`.
    ///
    /// The wrapper caches whatever is returned, including `LazyState::Pending`.
    fn compute_state(&self, state: StateId) -> LazyState<L, W>;
    /// The start state of the automaton.
    fn start(&self) -> StateId;
    /// Optional estimate of the total number of states; `None` by default.
    ///
    /// The wrapper uses it to pre-size its cache and reports it from `num_states`.
    fn num_states_hint(&self) -> Option<usize> {
        None
    }
}
/// Adapts a [`StateSource`] into a WFST that expands states on demand and keeps
/// the results according to a [`CachePolicy`].
#[derive(Debug)]
pub struct LazyWfstWrapper<S, L, W>
where
    S: StateSource<L, W>,
    W: Semiring,
{
    /// Producer of state expansions.
    source: S,
    /// Expanded states currently held in memory, keyed by state id.
    cache: FxHashMap<StateId, LazyState<L, W>>,
    /// Recency queue used by `CachePolicy::Lru`: evictions pop from the front,
    /// insertions and touches push to the back.
    access_order: VecDeque<StateId>,
    /// Active caching strategy.
    policy: CachePolicy,
    /// How many times a computed state has been inserted (recomputations
    /// included); not reset by `clear_cache`.
    computed_count: u32,
    /// Start state, read from the source once at construction.
    start: StateId,
}
impl<S, L, W> Clone for LazyWfstWrapper<S, L, W>
where
    S: StateSource<L, W> + Clone,
    L: Clone,
    W: Semiring,
{
    /// Clones the source, every cached state and the LRU queue, so the copy
    /// evolves independently of `self`.
    fn clone(&self) -> Self {
        let Self {
            source,
            cache,
            access_order,
            policy,
            computed_count,
            start,
        } = self;
        Self {
            source: source.clone(),
            cache: cache.clone(),
            access_order: access_order.clone(),
            policy: *policy,
            computed_count: *computed_count,
            start: *start,
        }
    }
}
impl<S, L, W> LazyWfstWrapper<S, L, W>
where
    S: StateSource<L, W>,
    L: Clone + Send + Sync,
    W: Semiring,
{
    /// Wraps `source` using the default [`CachePolicy`].
    ///
    /// Nothing is expanded up front; only the start state id is read from the
    /// source.
    pub fn new(source: S) -> Self {
        Self::with_cache_policy(source, CachePolicy::default())
    }

    /// Wraps `source` using the given cache `policy`.
    ///
    /// Storage is pre-sized from the source's `num_states_hint` (16 when no hint
    /// is given). The size is capped at what `policy` can actually retain, so a
    /// large hint does not cause a large allocation for a bounded cache.
    pub fn with_cache_policy(source: S, policy: CachePolicy) -> Self {
        let start = source.start();
        let capacity = Self::initial_capacity(source.num_states_hint(), policy);
        // The recency queue is only maintained under LRU.
        let order_capacity = if matches!(policy, CachePolicy::Lru { .. }) { capacity } else { 0 };
        Self {
            source,
            cache: FxHashMap::with_capacity_and_hasher(capacity, Default::default()),
            access_order: VecDeque::with_capacity(order_capacity),
            policy,
            computed_count: 0,
            start,
        }
    }

    /// Initial cache capacity: the hint (or 16), bounded by the policy's
    /// maximum number of resident states.
    fn initial_capacity(hint: Option<usize>, policy: CachePolicy) -> usize {
        let hint = hint.unwrap_or(16);
        match policy {
            CachePolicy::NoCache => 1,
            CachePolicy::CacheAll => hint,
            CachePolicy::Lru { max_states } => hint.min(max_states.max(1)),
        }
    }

    /// Returns the cached entry for `state`, asking the source to compute it
    /// on a miss. Under LRU, a hit refreshes the state's recency.
    fn ensure_computed(&mut self, state: StateId) -> &LazyState<L, W> {
        if self.cache.contains_key(&state) {
            if matches!(self.policy, CachePolicy::Lru { .. }) {
                self.touch_lru(state);
            }
        } else {
            let computed = self.source.compute_state(state);
            self.insert_cached(state, computed);
        }
        // `insert_cached` always leaves `state` resident under every policy, so
        // this lookup cannot fail.
        self.cache.get(&state).expect("State should be cached")
    }

    /// Stores a freshly computed `state` according to the active policy and
    /// bumps the computation counter.
    ///
    /// The state is always resident afterwards, because callers hand out
    /// references into the cache.
    fn insert_cached(&mut self, state: StateId, computed: LazyState<L, W>) {
        self.computed_count = self.computed_count.saturating_add(1);
        match self.policy {
            CachePolicy::NoCache => {
                // Keep only the state just computed. Storing nothing would make
                // `ensure_computed` (and hence `transitions_lazy`) panic.
                self.cache.clear();
                self.access_order.clear();
                self.cache.insert(state, computed);
            }
            CachePolicy::CacheAll => {
                self.cache.insert(state, computed);
            }
            CachePolicy::Lru { max_states } => {
                if self.cache.contains_key(&state) {
                    // Re-expanding a resident entry (e.g. a cached `Pending`):
                    // replace it in place rather than queueing a duplicate id.
                    self.cache.insert(state, computed);
                    self.touch_lru(state);
                } else {
                    self.evict_lru_down_to(max_states.saturating_sub(1));
                    self.cache.insert(state, computed);
                    self.access_order.push_back(state);
                }
            }
        }
    }

    /// Evicts entries, least recently used first, until at most `limit` remain.
    ///
    /// States cached while a non-LRU policy was active are not in
    /// `access_order`. Once the queue is exhausted, those states are evicted in
    /// hash order so the bound is still honoured.
    fn evict_lru_down_to(&mut self, limit: usize) {
        while self.cache.len() > limit {
            let victim = match self.access_order.pop_front() {
                Some(s) => s,
                None => match self.cache.keys().next() {
                    Some(&s) => s,
                    None => break,
                },
            };
            self.cache.remove(&victim);
        }
    }

    /// Marks `state` as most recently used, starting to track it if it was
    /// cached before LRU was enabled. Linear in the queue length.
    fn touch_lru(&mut self, state: StateId) {
        if let Some(pos) = self.access_order.iter().position(|&s| s == state) {
            self.access_order.remove(pos);
        }
        self.access_order.push_back(state);
    }

    /// Borrows the underlying state source.
    pub fn source(&self) -> &S {
        &self.source
    }

    /// Mutably borrows the underlying state source.
    ///
    /// Already-cached states are not invalidated by changes made through this.
    pub fn source_mut(&mut self) -> &mut S {
        &mut self.source
    }

    /// Consumes the wrapper, discarding the cache and returning the source.
    pub fn into_source(self) -> S {
        self.source
    }
}
/// Read-only [`Wfst`] view over whatever is currently cached.
///
/// None of these methods expand states. A state that has not been expanded
/// (or has been evicted) reads as non-final, with weight `W::zero()` and no
/// transitions. Use the [`LazyWfst`] methods to force expansion.
impl<S, L, W> Wfst<L, W> for LazyWfstWrapper<S, L, W>
where
    S: StateSource<L, W>,
    L: Clone + Send + Sync,
    W: Semiring,
{
    fn start(&self) -> StateId {
        self.start
    }

    fn is_final(&self, state: StateId) -> bool {
        matches!(
            self.cache.get(&state),
            Some(LazyState::Computed { is_final: true, .. })
        )
    }

    fn final_weight(&self, state: StateId) -> W {
        match self.cache.get(&state) {
            Some(LazyState::Computed { final_weight, .. }) => *final_weight,
            Some(LazyState::Pending) | None => W::zero(),
        }
    }

    fn transitions(&self, state: StateId) -> &[WeightedTransition<L, W>] {
        match self.cache.get(&state) {
            Some(cached) => cached.transitions().unwrap_or(&[]),
            None => &[],
        }
    }

    /// The source's state-count hint, or 0 when it gives none. This is not the
    /// number of expanded states.
    fn num_states(&self) -> usize {
        self.source.num_states_hint().unwrap_or_default()
    }
}
impl<S, L, W> LazyWfst<L, W> for LazyWfstWrapper<S, L, W>
where
S: StateSource<L, W>,
L: Clone + Send + Sync,
W: Semiring,
{
fn is_expanded(&self, state: StateId) -> bool {
self.cache
.get(&state)
.map(|s| s.is_computed())
.unwrap_or(false)
}
fn expand(&mut self, state: StateId) {
if !self.is_expanded(state) {
let computed = self.source.compute_state(state);
self.insert_cached(state, computed);
}
}
fn transitions_lazy(&mut self, state: StateId) -> &[WeightedTransition<L, W>] {
self.ensure_computed(state);
self.transitions(state)
}
fn cache_policy(&self) -> CachePolicy {
self.policy
}
fn set_cache_policy(&mut self, policy: CachePolicy) {
self.policy = policy;
}
fn computed_states(&self) -> usize {
self.computed_count as usize
}
fn clear_cache(&mut self) {
self.cache.clear();
self.access_order.clear();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::TropicalWeight;

    /// Straight chain `0 -a-> 1 -a-> ... -a-> num_states-1` with unit-cost
    /// arcs. The last state is final with weight one; ids at or beyond
    /// `num_states` come back as `LazyState::Pending`.
    #[derive(Clone)]
    struct LinearChainSource {
        num_states: usize,
    }

    impl StateSource<char, TropicalWeight> for LinearChainSource {
        fn compute_state(&self, state: StateId) -> LazyState<char, TropicalWeight> {
            let idx = state as usize;
            if idx >= self.num_states {
                return LazyState::Pending;
            }
            let mut arcs = SmallVec::new();
            if idx + 1 == self.num_states {
                // The tail of the chain: final, no outgoing arcs.
                return LazyState::final_state(TropicalWeight::one(), arcs);
            }
            arcs.push(WeightedTransition::new(
                state,
                Some('a'),
                Some('a'),
                state + 1,
                TropicalWeight::new(1.0),
            ));
            LazyState::non_final(arcs)
        }

        fn start(&self) -> StateId {
            0
        }

        fn num_states_hint(&self) -> Option<usize> {
            Some(self.num_states)
        }
    }

    #[test]
    fn test_lazy_wrapper_basic() {
        let mut lazy = LazyWfstWrapper::new(LinearChainSource { num_states: 5 });
        assert_eq!(lazy.start(), 0);
        assert_eq!(lazy.computed_states(), 0);
        for (state, expected_computed) in [(0, 1_usize), (1, 2_usize)] {
            assert_eq!(lazy.transitions_lazy(state).len(), 1);
            assert_eq!(lazy.computed_states(), expected_computed);
        }
        lazy.expand(4);
        assert!(lazy.is_expanded(4));
        assert_eq!(lazy.computed_states(), 3);
    }

    #[test]
    fn test_lru_eviction() {
        let mut lazy = LazyWfstWrapper::with_cache_policy(
            LinearChainSource { num_states: 10 },
            CachePolicy::Lru { max_states: 3 },
        );
        (0..5).for_each(|state| lazy.expand(state));
        assert_eq!(lazy.cache.len(), 3);
        // Only the three most recently expanded states survive.
        assert!((2..5).all(|state| lazy.is_expanded(state)));
        assert!((0..2).all(|state| !lazy.is_expanded(state)));
    }

    #[test]
    fn test_clear_cache() {
        let mut lazy = LazyWfstWrapper::new(LinearChainSource { num_states: 5 });
        for state in 0..3 {
            lazy.expand(state);
        }
        assert_eq!(lazy.cache.len(), 3);
        assert_eq!(lazy.computed_states(), 3);
        lazy.clear_cache();
        assert!(lazy.cache.is_empty());
        // Clearing drops cached states but keeps the computation count.
        assert_eq!(lazy.computed_states(), 3);
    }
}