use std::hash::Hash;
use super::{MultiTapeLabel, MultiTapeWfst};
use crate::semiring::Semiring;
use crate::wfst::{MutableWfst, VectorWfst, WeightedTransition};
/// Pairs a source value with the index of the tape to project onto.
///
/// `N` is the tape count the index is validated against; `T` is not
/// constrained here (presumably a multi-tape WFST — confirm at call sites).
#[derive(Debug, Clone)]
pub struct ProjectSource<T, const N: usize> {
    source: T,
    tape: usize,
}

impl<T, const N: usize> ProjectSource<T, N> {
    /// Creates a projection request for tape `tape` of `source`.
    ///
    /// # Panics
    ///
    /// Panics if `tape >= N`.
    pub fn new(source: T, tape: usize) -> Self {
        // Report the tape count rather than `N - 1`: when `N == 0` this branch is
        // always taken and the subtraction would itself overflow, hiding the real
        // error behind an arithmetic-overflow panic (or `usize::MAX` in release).
        assert!(
            tape < N,
            "Tape index {} out of range (tape count {})",
            tape,
            N
        );
        Self { source, tape }
    }

    /// Borrows the wrapped source.
    pub fn source(&self) -> &T {
        &self.source
    }

    /// Returns the validated tape index (always `< N`).
    pub fn tape(&self) -> usize {
        self.tape
    }
}
/// Single-tape WFST produced by [`project`].
///
/// Each transition carries the selected tape's label as both its input and
/// its output label.
#[derive(Debug, Clone)]
pub struct ProjectedWfst<L, W: Semiring> {
    wfst: VectorWfst<L, W>,
}

impl<L, W> ProjectedWfst<L, W>
where
    L: Clone + Eq + Hash + Send + Sync,
    W: Semiring + Clone,
{
    /// Borrows the projected WFST.
    pub fn wfst(&self) -> &VectorWfst<L, W> {
        let Self { wfst } = self;
        wfst
    }

    /// Consumes the wrapper, handing back the projected WFST by value.
    pub fn into_wfst(self) -> VectorWfst<L, W> {
        let Self { wfst } = self;
        wfst
    }
}
/// Projects a multi-tape WFST onto a single tape, producing an acceptor.
///
/// Every transition of `source` is copied with both its input and output
/// label set to the label on tape `tape`. A transition with no label on that
/// tape becomes an epsilon (`None`) transition. States, final weights, the
/// start state and transition weights are copied unchanged.
///
/// # Panics
///
/// Panics if `tape >= N`, or if a state id yielded by `source.states()` does
/// not equal the id returned by `add_state` on the output WFST.
pub fn project<L, W, T, const N: usize>(source: &T, tape: usize) -> ProjectedWfst<L, W>
where
    L: Clone + Eq + Hash + Send + Sync,
    W: Semiring + Clone,
    T: MultiTapeWfst<L, W, N>,
{
    // Report the tape count rather than `N - 1`, which would overflow for N == 0.
    assert!(
        tape < N,
        "Tape index {} out of range (tape count {})",
        tape,
        N
    );
    let mut wfst = VectorWfst::new();

    // First pass: create every state (and its final weight) so all states exist
    // before any transition referring to them is added.
    for state in source.states() {
        let new_state = wfst.add_state();
        // State ids are reused verbatim below, so they must line up exactly.
        assert_eq!(
            state, new_state,
            "source state id does not match the id allocated in the projection"
        );
        if source.is_final(state) {
            wfst.set_final(state, source.final_weight(state));
        }
    }
    wfst.set_start(source.start());

    // Second pass: copy transitions, keeping only the selected tape's label.
    for state in source.states() {
        for trans in source.transitions(state) {
            let label = trans.tape_label(tape).cloned();
            wfst.add_transition(WeightedTransition::new(
                trans.from,
                label.clone(),
                label,
                trans.to,
                trans.weight.clone(),
            ));
        }
    }
    ProjectedWfst { wfst }
}
/// Projects a multi-tape WFST onto a selection of `M` of its `N` tapes.
///
/// Output tape `i` carries the label that source tape `tapes[i]` has on the
/// same transition. Because indices are not required to be distinct or sorted,
/// `tapes` may reorder (e.g. `[2, 0]`) or repeat source tapes. States, final
/// weights, the start state and transition weights are copied unchanged.
///
/// # Panics
///
/// Panics if any entry of `tapes` is `>= N`, or if a state id yielded by
/// `source.states()` does not equal the id returned by the builder's
/// `add_state`.
pub fn project_tapes<L, W, T, const N: usize, const M: usize>(
    source: &T,
    tapes: [usize; M],
) -> crate::multitape::VectorMultiTapeWfst<L, W, M>
where
    L: Clone + Eq + Hash + Send + Sync,
    W: Semiring + Clone,
    T: MultiTapeWfst<L, W, N>,
{
    use crate::multitape::MultiTapeWfstBuilder;

    // Validate every index before building anything. Report the tape count
    // rather than `N - 1`, which would overflow for N == 0.
    for &tape in &tapes {
        assert!(
            tape < N,
            "Tape index {} out of range (tape count {})",
            tape,
            N
        );
    }
    let mut builder = MultiTapeWfstBuilder::<L, W, M>::new();

    // First pass: create every state (and its final weight) so all states exist
    // before any transition referring to them is added.
    for state in source.states() {
        let new_state = builder.add_state();
        // State ids are reused verbatim below, so they must line up exactly.
        assert_eq!(
            state, new_state,
            "source state id does not match the id allocated in the projection"
        );
        if source.is_final(state) {
            builder.set_final(state, source.final_weight(state));
        }
    }
    builder.set_start(source.start());

    // Second pass: copy transitions, gathering the selected tapes' labels.
    for state in source.states() {
        for trans in source.transitions(state) {
            let new_labels: [Option<L>; M] =
                std::array::from_fn(|i| trans.tape_label(tapes[i]).cloned());
            builder.add_transition(
                trans.from,
                trans.to,
                MultiTapeLabel::new(new_labels),
                trans.weight.clone(),
            );
        }
    }
    builder.build()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multitape::MultiTapeWfstBuilder;
    use crate::semiring::TropicalWeight;
    use crate::wfst::Wfst;

    /// Builds a 3-tape linear machine with all weights `one()`:
    /// `0 --[a, x, 1]--> 1 --[b, y, 2]--> 2 (final)`.
    fn make_test_mt() -> crate::multitape::VectorMultiTapeWfst<char, TropicalWeight, 3> {
        let mut builder = MultiTapeWfstBuilder::<char, TropicalWeight, 3>::new();
        let s0 = builder.add_state();
        let s1 = builder.add_state();
        let s2 = builder.add_final_state(TropicalWeight::one());
        builder.set_start(s0);
        builder.add_transition(
            s0,
            s1,
            MultiTapeLabel::from_values(['a', 'x', '1']),
            TropicalWeight::one(),
        );
        builder.add_transition(
            s1,
            s2,
            MultiTapeLabel::from_values(['b', 'y', '2']),
            TropicalWeight::one(),
        );
        builder.build()
    }

    #[test]
    fn test_project_to_tape_0() {
        let mt = make_test_mt();
        let projected = project(&mt, 0);
        let wfst = projected.wfst();
        assert_eq!(wfst.num_states(), 3);
        let transitions = wfst.transitions(0);
        assert_eq!(transitions.len(), 1);
        assert_eq!(transitions[0].input, Some('a'));
    }

    #[test]
    fn test_project_to_tape_1() {
        let mt = make_test_mt();
        let projected = project(&mt, 1);
        let wfst = projected.wfst();
        let transitions = wfst.transitions(0);
        assert_eq!(transitions.len(), 1);
        assert_eq!(transitions[0].input, Some('x'));
    }

    #[test]
    fn test_project_to_tape_2() {
        let mt = make_test_mt();
        let projected = project(&mt, 2);
        let wfst = projected.wfst();
        let transitions = wfst.transitions(0);
        assert_eq!(transitions.len(), 1);
        assert_eq!(transitions[0].input, Some('1'));
    }

    #[test]
    fn test_project_second_transition() {
        let mt = make_test_mt();
        let projected = project(&mt, 1);
        let transitions = projected.wfst().transitions(1);
        assert_eq!(transitions.len(), 1);
        assert_eq!(transitions[0].input, Some('y'));
    }

    #[test]
    fn test_project_preserves_finals() {
        let mt = make_test_mt();
        let projected = project(&mt, 0);
        let wfst = projected.wfst();
        assert!(!wfst.is_final(0));
        assert!(!wfst.is_final(1));
        assert!(wfst.is_final(2));
    }

    #[test]
    fn test_project_preserves_start() {
        let mt = make_test_mt();
        let projected = project(&mt, 0);
        assert_eq!(projected.wfst().start(), 0);
    }

    #[test]
    fn test_project_epsilon_tape() {
        let mut builder = MultiTapeWfstBuilder::<char, TropicalWeight, 2>::new();
        let s0 = builder.add_state();
        let s1 = builder.add_final_state(TropicalWeight::one());
        builder.set_start(s0);
        builder.add_transition(
            s0,
            s1,
            MultiTapeLabel::single(0, 'a'),
            TropicalWeight::one(),
        );
        let mt = builder.build();
        let projected = project(&mt, 1);
        let wfst = projected.wfst();
        let transitions = wfst.transitions(0);
        assert_eq!(transitions.len(), 1);
        assert_eq!(transitions[0].input, None);
    }

    #[test]
    fn test_project_tapes() {
        let mt = make_test_mt();
        let projected: crate::multitape::VectorMultiTapeWfst<char, TropicalWeight, 2> =
            project_tapes(&mt, [0, 2]);
        assert_eq!(projected.num_states(), 3);
        assert_eq!(projected.num_tapes(), 2);
        use crate::multitape::MultiTapeWfst;
        let trans = &projected.transitions(0)[0];
        assert_eq!(trans.tape_label(0), Some(&'a'));
        assert_eq!(trans.tape_label(1), Some(&'1'));
    }

    #[test]
    fn test_project_tapes_reorders() {
        use crate::multitape::MultiTapeWfst;
        let mt = make_test_mt();
        let projected: crate::multitape::VectorMultiTapeWfst<char, TropicalWeight, 2> =
            project_tapes(&mt, [2, 0]);
        let trans = &projected.transitions(1)[0];
        assert_eq!(trans.tape_label(0), Some(&'2'));
        assert_eq!(trans.tape_label(1), Some(&'b'));
    }

    #[test]
    #[should_panic(expected = "out of range")]
    fn test_project_invalid_tape() {
        let mt = make_test_mt();
        let _ = project(&mt, 5);
    }

    #[test]
    #[should_panic(expected = "out of range")]
    fn test_project_tapes_invalid_tape() {
        let mt = make_test_mt();
        let _: crate::multitape::VectorMultiTapeWfst<char, TropicalWeight, 2> =
            project_tapes(&mt, [0, 3]);
    }

    #[test]
    fn test_project_source_accessors() {
        let source = ProjectSource::<&str, 3>::new("mt", 2);
        assert_eq!(source.tape(), 2);
        assert_eq!(*source.source(), "mt");
    }

    #[test]
    #[should_panic(expected = "out of range")]
    fn test_project_source_invalid_tape() {
        let _ = ProjectSource::<(), 3>::new((), 3);
    }
}