use std::fmt::{self, Debug};
use std::hash::Hash;
use super::{StackAction, StackSymbol};
use crate::semiring::Semiring;
use crate::wfst::StateId;
#[derive(Clone, PartialEq)]
pub struct PdaTransition<L, W: Semiring> {
pub from: StateId,
pub input: Option<L>,
pub stack_top: StackSymbol,
pub stack_action: StackAction,
pub to: StateId,
pub weight: W,
}
impl<L, W: Semiring> PdaTransition<L, W> {
pub fn new(
from: StateId,
input: Option<L>,
stack_top: StackSymbol,
stack_action: StackAction,
to: StateId,
weight: W,
) -> Self {
Self {
from,
input,
stack_top,
stack_action,
to,
weight,
}
}
pub fn epsilon(
from: StateId,
stack_top: StackSymbol,
stack_action: StackAction,
to: StateId,
weight: W,
) -> Self {
Self::new(from, None, stack_top, stack_action, to, weight)
}
pub fn is_epsilon(&self) -> bool {
self.input.is_none()
}
pub fn matches(&self, input: Option<&L>, stack_top: StackSymbol) -> bool
where
L: PartialEq,
{
if self.stack_top != stack_top {
return false;
}
match (&self.input, input) {
(None, _) => true, (Some(a), Some(b)) => a == b,
(Some(_), None) => false,
}
}
pub fn source(&self) -> StateId {
self.from
}
pub fn target(&self) -> StateId {
self.to
}
pub fn net_stack_change(&self) -> i32 {
self.stack_action.net_change()
}
}
impl<L, W: Semiring + Clone> PdaTransition<L, W> {
    /// Returns a copy of this transition with its input label transformed by `f`.
    ///
    /// Epsilon transitions (`input == None`) stay epsilon, and `f` is never
    /// called for them. All other fields (states, stack symbol, stack action,
    /// weight) are copied or cloned unchanged.
    ///
    /// The label is only borrowed and handed to `f`, so `L` does not need to
    /// be `Clone`. Only the stack action and the weight are cloned.
    pub fn map_input<F, M>(&self, f: F) -> PdaTransition<M, W>
    where
        F: FnOnce(&L) -> M,
    {
        PdaTransition {
            from: self.from,
            input: self.input.as_ref().map(f),
            stack_top: self.stack_top,
            stack_action: self.stack_action.clone(),
            to: self.to,
            weight: self.weight.clone(),
        }
    }
}
impl<L: Debug, W: Semiring + Debug> Debug for PdaTransition<L, W> {
    /// Renders the transition on one line as
    /// `PdaTransition { FROM --INPUT/TOP ACTION-- TO (w=WEIGHT) }`.
    ///
    /// `INPUT` is the label's `Debug` form, or `ε` for an epsilon transition.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("PdaTransition { ")?;
        write!(f, "{} --", self.from)?;
        if let Some(label) = self.input.as_ref() {
            write!(f, "{:?}", label)?;
        } else {
            f.write_str("ε")?;
        }
        write!(f, "/{} {}-- {} ", self.stack_top, self.stack_action, self.to)?;
        write!(f, "(w={:?}) }}", self.weight)
    }
}
impl<L: Eq + Hash, W: Semiring + PartialEq> Eq for PdaTransition<L, W> {}
impl<L: Hash, W: Semiring + Hash> Hash for PdaTransition<L, W> {
    /// Feeds every field into `state` in declaration order.
    ///
    /// The exhaustive destructuring (no `..`) makes adding a field without
    /// updating this impl a compile error. That keeps hashing aligned with the
    /// derived `PartialEq`, which compares all fields.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let Self {
            from,
            input,
            stack_top,
            stack_action,
            to,
            weight,
        } = self;
        from.hash(state);
        input.hash(state);
        stack_top.hash(state);
        stack_action.hash(state);
        to.hash(state);
        weight.hash(state);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::TropicalWeight;

    #[test]
    fn test_transition_creation() {
        let trans: PdaTransition<char, TropicalWeight> = PdaTransition::new(
            0,
            Some('a'),
            StackSymbol::BOTTOM,
            StackAction::Push(vec![StackSymbol::BOTTOM, StackSymbol::new(1)]),
            1,
            TropicalWeight::one(),
        );
        assert_eq!(trans.from, 0);
        assert_eq!(trans.to, 1);
        assert_eq!(trans.input, Some('a'));
        assert!(!trans.is_epsilon());
    }

    #[test]
    fn test_epsilon_transition() {
        let trans: PdaTransition<char, TropicalWeight> = PdaTransition::epsilon(
            0,
            StackSymbol::BOTTOM,
            StackAction::Noop,
            1,
            TropicalWeight::one(),
        );
        assert!(trans.is_epsilon());
        assert_eq!(trans.input, None);
    }

    #[test]
    fn test_matches() {
        let trans: PdaTransition<char, TropicalWeight> = PdaTransition::new(
            0,
            Some('a'),
            StackSymbol::new(1),
            StackAction::Pop,
            1,
            TropicalWeight::one(),
        );
        assert!(trans.matches(Some(&'a'), StackSymbol::new(1)));
        assert!(!trans.matches(Some(&'b'), StackSymbol::new(1)));
        assert!(!trans.matches(Some(&'a'), StackSymbol::new(2)));
        assert!(!trans.matches(None, StackSymbol::new(1)));
    }

    #[test]
    fn test_epsilon_matches() {
        let trans: PdaTransition<char, TropicalWeight> = PdaTransition::epsilon(
            0,
            StackSymbol::new(1),
            StackAction::Pop,
            1,
            TropicalWeight::one(),
        );
        assert!(trans.matches(Some(&'a'), StackSymbol::new(1)));
        assert!(trans.matches(Some(&'b'), StackSymbol::new(1)));
        assert!(trans.matches(None, StackSymbol::new(1)));
        assert!(!trans.matches(Some(&'a'), StackSymbol::new(2)));
    }

    #[test]
    fn test_source_target() {
        let trans: PdaTransition<char, TropicalWeight> = PdaTransition::new(
            5,
            Some('x'),
            StackSymbol::BOTTOM,
            StackAction::Noop,
            10,
            TropicalWeight::one(),
        );
        assert_eq!(trans.source(), 5);
        assert_eq!(trans.target(), 10);
    }

    #[test]
    fn test_net_stack_change() {
        let trans1: PdaTransition<char, TropicalWeight> = PdaTransition::new(
            0,
            Some('a'),
            StackSymbol::BOTTOM,
            StackAction::Pop,
            1,
            TropicalWeight::one(),
        );
        assert_eq!(trans1.net_stack_change(), -1);
        let trans2: PdaTransition<char, TropicalWeight> = PdaTransition::new(
            0,
            Some('a'),
            StackSymbol::BOTTOM,
            StackAction::Push(vec![StackSymbol::new(1), StackSymbol::new(2)]),
            1,
            TropicalWeight::one(),
        );
        assert_eq!(trans2.net_stack_change(), 1);
    }

    #[test]
    fn test_map_input() {
        let trans: PdaTransition<i32, TropicalWeight> = PdaTransition::new(
            0,
            Some(5),
            StackSymbol::BOTTOM,
            StackAction::Noop,
            1,
            TropicalWeight::one(),
        );
        let mapped = trans.map_input(|&x| x * 2);
        assert_eq!(mapped.input, Some(10));
    }

    /// Mapping an epsilon transition must keep it epsilon, never invoke the
    /// mapper, and carry the remaining fields over unchanged.
    #[test]
    fn test_map_input_epsilon_preserves_structure() {
        let trans: PdaTransition<i32, TropicalWeight> = PdaTransition::epsilon(
            3,
            StackSymbol::new(1),
            StackAction::Pop,
            4,
            TropicalWeight::one(),
        );
        let mut called = false;
        let mapped: PdaTransition<String, TropicalWeight> = trans.map_input(|x| {
            called = true;
            x.to_string()
        });
        assert!(!called);
        assert!(mapped.is_epsilon());
        assert_eq!(mapped.source(), 3);
        assert_eq!(mapped.target(), 4);
        assert!(mapped.stack_top == StackSymbol::new(1));
        assert_eq!(mapped.net_stack_change(), -1);
    }

    #[test]
    fn test_debug_format() {
        let trans: PdaTransition<char, TropicalWeight> = PdaTransition::new(
            0,
            Some('a'),
            StackSymbol::BOTTOM,
            StackAction::Pop,
            1,
            TropicalWeight::new(1.0),
        );
        let debug = format!("{:?}", trans);
        assert!(debug.contains("PdaTransition"));
        assert!(debug.contains("'a'"));
    }

    /// Epsilon transitions are rendered with `ε` in place of an input label.
    #[test]
    fn test_debug_format_epsilon() {
        let trans: PdaTransition<char, TropicalWeight> = PdaTransition::epsilon(
            0,
            StackSymbol::BOTTOM,
            StackAction::Noop,
            1,
            TropicalWeight::one(),
        );
        let debug = format!("{:?}", trans);
        assert!(debug.starts_with("PdaTransition {"));
        assert!(debug.contains("ε"));
        assert!(debug.ends_with('}'));
    }
}