use std::hash::Hash;
use super::{MultiTapeLabel, MultiTapeTransition};
use crate::semiring::Semiring;
use crate::wfst::StateId;
/// Core read-only interface of a weighted finite-state machine whose
/// transitions carry `N` parallel labels, one per tape.
///
/// `L` is the label type, `W` the weight type (a [`Semiring`]), and states
/// are addressed by [`StateId`].
pub trait MultiTapeWfst<L, W, const N: usize>: Clone + Send + Sync
where
    L: Clone + Eq + Hash + Send + Sync,
    W: Semiring,
{
    /// The state every path starts from.
    fn start(&self) -> StateId;

    /// Whether `state` is marked as final.
    fn is_final(&self, state: StateId) -> bool;

    /// The final weight attached to `state`.
    fn final_weight(&self, state: StateId) -> W;

    /// The outgoing transitions of `state`, as a borrowed slice.
    fn transitions(&self, state: StateId) -> &[MultiTapeTransition<L, W, N>];

    /// Number of states reported by the implementation.
    fn num_states(&self) -> usize;

    /// Number of transitions reported by the implementation.
    fn num_transitions(&self) -> usize;

    /// Iterator over the ids of all states.
    fn states(&self) -> impl Iterator<Item = StateId>;

    /// Iterator over the ids of the final states.
    fn final_states(&self) -> impl Iterator<Item = StateId>;

    /// `true` exactly when [`num_states`](Self::num_states) reports zero.
    fn is_empty(&self) -> bool {
        matches!(self.num_states(), 0)
    }

    /// Number of tapes; this is always the const parameter `N`.
    fn num_tapes(&self) -> usize {
        N
    }
}
/// Query and evaluation helpers layered on top of [`MultiTapeWfst`].
///
/// Every method has a default implementation written purely in terms of the
/// core trait methods, so implementors normally do not override anything.
pub trait MultiTapeWfstOps<L, W, const N: usize>: MultiTapeWfst<L, W, N>
where
    L: Clone + Eq + Hash + Send + Sync,
    W: Semiring + Clone,
{
    /// Collects the transitions leaving `state` whose label on `tape`
    /// (as reported by `tape_label`) is `Some(label)`.
    fn transitions_matching_tape(
        &self,
        state: StateId,
        tape: usize,
        label: &L,
    ) -> Vec<&MultiTapeTransition<L, W, N>> {
        self.transitions(state)
            .iter()
            .filter(|t| t.tape_label(tape) == Some(label))
            .collect()
    }

    /// Collects the transitions leaving `state` for which `is_epsilon()` holds.
    fn epsilon_transitions(&self, state: StateId) -> Vec<&MultiTapeTransition<L, W, N>> {
        self.transitions(state)
            .iter()
            .filter(|t| t.is_epsilon())
            .collect()
    }

    /// Collects the transitions leaving `state` that are epsilon on `tape`
    /// (per `is_tape_epsilon`).
    fn tape_epsilon_transitions(
        &self,
        state: StateId,
        tape: usize,
    ) -> Vec<&MultiTapeTransition<L, W, N>> {
        self.transitions(state)
            .iter()
            .filter(|t| t.is_tape_epsilon(tape))
            .collect()
    }

    /// Collects the transitions leaving `state` that are *not* epsilon on
    /// `tape`; the complement of [`tape_epsilon_transitions`](Self::tape_epsilon_transitions).
    fn tape_non_epsilon_transitions(
        &self,
        state: StateId,
        tape: usize,
    ) -> Vec<&MultiTapeTransition<L, W, N>> {
        self.transitions(state)
            .iter()
            .filter(|t| !t.is_tape_epsilon(tape))
            .collect()
    }

    /// Returns the distinct labels that appear on `tape` across all
    /// transitions of the machine.
    ///
    /// Labels are deduplicated through a `HashSet`, so the order of the
    /// returned vector is unspecified.
    fn tape_alphabet(&self, tape: usize) -> Vec<L> {
        let mut labels = std::collections::HashSet::new();
        for state in self.states() {
            for trans in self.transitions(state) {
                if let Some(label) = trans.tape_label(tape) {
                    labels.insert(label.clone());
                }
            }
        }
        labels.into_iter().collect()
    }

    /// Returns `true` if any transition in the machine satisfies `is_epsilon()`.
    fn has_epsilon_transitions(&self) -> bool {
        self.states()
            .any(|s| self.transitions(s).iter().any(|t| t.is_epsilon()))
    }

    /// Returns `true` if any transition in the machine is epsilon on `tape`.
    fn tape_has_epsilon(&self, tape: usize) -> bool {
        self.states()
            .any(|s| self.transitions(s).iter().any(|t| t.is_tape_epsilon(tape)))
    }

    /// Counts transitions by walking every state and summing the lengths of
    /// their transition slices (unlike `num_transitions`, which is whatever
    /// the implementation reports).
    fn count_transitions(&self) -> usize {
        self.states().map(|s| self.transitions(s).len()).sum()
    }

    /// Returns `true` if [`transduce`](Self::transduce) finds an accepting
    /// path for `input`.
    fn accepts(&self, input: &[MultiTapeLabel<L, N>]) -> bool
    where
        L: PartialEq,
    {
        self.transduce(input).is_some()
    }

    /// Searches for an accepting path that reads `input`.
    ///
    /// A non-epsilon step consumes one element of `input` and is taken only
    /// when the transition's full `labels` value equals that element.
    /// Transitions with `is_epsilon()` may be taken anywhere without consuming
    /// input. At each state, label-matching transitions are tried before
    /// epsilon transitions. Once `input` is exhausted, a final state ends the
    /// search immediately.
    ///
    /// Returns the weight of the *first* accepting path found by this
    /// depth-first search. That weight is the `times` product of the arc
    /// weights in path order, followed by the final weight. It is not the
    /// semiring sum, and not the best weight, over all accepting paths.
    /// Returns `None` when no accepting path exists.
    ///
    /// Epsilon cycles are not followed: an epsilon move back into a state
    /// already entered at the current input position is skipped, so the
    /// search always terminates. The recursion depth grows with the input
    /// length.
    fn transduce(&self, input: &[MultiTapeLabel<L, N>]) -> Option<W>
    where
        L: PartialEq,
    {
        /// Depth-first search from `state` over the remaining `input`.
        ///
        /// `eps_chain` holds the states entered on the current search path
        /// since the last input symbol was consumed. An epsilon move back into
        /// one of them would restart an identical sub-search (same state, same
        /// remaining input) and recurse forever, so such moves are skipped.
        /// Skipping them never loses acceptance, because removing an epsilon
        /// cycle from an accepting path still leaves an accepting path.
        fn transduce_from<L, W, const N: usize, T>(
            wfst: &T,
            state: StateId,
            input: &[MultiTapeLabel<L, N>],
            eps_chain: &mut Vec<StateId>,
        ) -> Option<W>
        where
            L: Clone + Eq + Hash + Send + Sync + PartialEq,
            W: Semiring + Clone,
            T: MultiTapeWfst<L, W, N>,
        {
            match input.split_first() {
                Some((label, rest)) => {
                    for trans in wfst.transitions(state) {
                        if trans.labels == *label {
                            // Consuming a symbol moves to a new input position,
                            // so the epsilon-cycle guard starts afresh there.
                            let mut next_chain = Vec::new();
                            if let Some(w) = transduce_from(wfst, trans.to, rest, &mut next_chain)
                            {
                                return Some(trans.weight.clone().times(&w));
                            }
                        }
                    }
                }
                None => {
                    if wfst.is_final(state) {
                        return Some(wfst.final_weight(state));
                    }
                }
            }

            // Epsilon moves stay at the same input position.
            eps_chain.push(state);
            let mut found = None;
            for trans in wfst.transitions(state) {
                if !trans.is_epsilon() || eps_chain.contains(&trans.to) {
                    continue;
                }
                if let Some(w) = transduce_from(wfst, trans.to, input, eps_chain) {
                    found = Some(trans.weight.clone().times(&w));
                    break;
                }
            }
            eps_chain.pop();
            found
        }

        transduce_from(self, self.start(), input, &mut Vec::new())
    }
}
/// Blanket implementation: any [`MultiTapeWfst`] whose weight type is also
/// `Clone` receives all default [`MultiTapeWfstOps`] helpers.
impl<T, L, W, const N: usize> MultiTapeWfstOps<L, W, N> for T
where
    L: Clone + Eq + Hash + Send + Sync,
    W: Semiring + Clone,
    T: MultiTapeWfst<L, W, N>,
{
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multitape::VectorMultiTapeWfst;
    use crate::semiring::TropicalWeight;

    /// Two states and two tapes: a single arc `start --(a, x)/one--> end`,
    /// with `end` final at weight one.
    fn make_simple_mt() -> VectorMultiTapeWfst<char, TropicalWeight, 2> {
        use crate::multitape::MultiTapeWfstBuilder;

        let mut b = MultiTapeWfstBuilder::<char, TropicalWeight, 2>::new();
        let start = b.add_state();
        let end = b.add_state();
        b.set_start(start);
        b.set_final(end, TropicalWeight::one());
        b.add_transition(
            start,
            end,
            MultiTapeLabel::from_values(['a', 'x']),
            TropicalWeight::one(),
        );
        b.build()
    }

    #[test]
    fn test_basic_ops() {
        let machine = make_simple_mt();
        assert!(!machine.is_empty());
        assert_eq!(machine.num_tapes(), 2);
    }

    #[test]
    fn test_tape_alphabet() {
        let machine = make_simple_mt();
        assert!(machine.tape_alphabet(0).contains(&'a'));
        assert!(machine.tape_alphabet(1).contains(&'x'));
    }

    #[test]
    fn test_transduce() {
        let machine = make_simple_mt();
        let matching = [MultiTapeLabel::from_values(['a', 'x'])];
        let mismatching = [MultiTapeLabel::from_values(['b', 'y'])];
        assert!(machine.accepts(&matching));
        assert!(!machine.accepts(&mismatching));
    }
}