use ordered_float::OrderedFloat;
use smallvec::SmallVec;
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
/// A single character-level edit step.
///
/// Variant order is significant: the derived `Ord` sorts by variant first
/// (`Copy < Insert < Delete < Substitute < Transpose`), then by payload.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum EditOp {
    /// Keep a character unchanged.
    Copy(char),
    /// Emit a character that has no input counterpart.
    Insert(char),
    /// Consume an input character without emitting anything.
    Delete(char),
    /// Replace one character with another.
    Substitute {
        /// Character read from the input.
        from: char,
        /// Character written to the output.
        to: char,
    },
    /// Exchange the two characters `a` and `b`.
    Transpose {
        /// First character of the pair.
        a: char,
        /// Second character of the pair.
        b: char,
    },
}

impl EditOp {
    /// Cost used when no custom weighting is supplied: `0.0` for a copy and
    /// `1.0` for every other operation.
    #[inline]
    pub fn default_cost(&self) -> f64 {
        match *self {
            EditOp::Copy(_) => 0.0,
            EditOp::Insert(_)
            | EditOp::Delete(_)
            | EditOp::Substitute { .. }
            | EditOp::Transpose { .. } => 1.0,
        }
    }

    /// Returns `true` only for the [`EditOp::Copy`] variant.
    #[inline]
    pub fn is_copy(&self) -> bool {
        matches!(*self, EditOp::Copy(_))
    }

    /// Character this operation reports as its output, if any.
    ///
    /// `Delete` reports nothing. `Transpose` reports only `b`, since the
    /// return type can carry a single character.
    #[inline]
    pub fn output_char(&self) -> Option<char> {
        let emitted = match *self {
            EditOp::Delete(_) => return None,
            EditOp::Copy(ch) | EditOp::Insert(ch) => ch,
            EditOp::Substitute { to, .. } => to,
            EditOp::Transpose { b, .. } => b,
        };
        Some(emitted)
    }

    /// Character this operation reports as its input, if any.
    ///
    /// `Insert` reports nothing. `Transpose` reports only `a`, since the
    /// return type can carry a single character.
    #[inline]
    pub fn input_char(&self) -> Option<char> {
        let consumed = match *self {
            EditOp::Insert(_) => return None,
            EditOp::Copy(ch) | EditOp::Delete(ch) => ch,
            EditOp::Substitute { from, .. } => from,
            EditOp::Transpose { a, .. } => a,
        };
        Some(consumed)
    }
}

/// Compact textual form: `=c` copy, `+c` insert, `-c` delete, `a>b`
/// substitute, `~ab` transpose.
impl std::fmt::Display for EditOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            EditOp::Copy(ch) => write!(f, "={}", ch),
            EditOp::Insert(ch) => write!(f, "+{}", ch),
            EditOp::Delete(ch) => write!(f, "-{}", ch),
            EditOp::Substitute { from, to } => write!(f, "{}>{}", from, to),
            EditOp::Transpose { a, b } => write!(f, "~{}{}", a, b),
        }
    }
}
/// An ordered list of [`EditOp`]s, backed by a `SmallVec` with an inline
/// capacity of eight operations.
pub type EditSequence = SmallVec<[EditOp; 8]>;
/// A cost together with the alternative edit sequences that achieve it.
///
/// `plus` keeps the cheaper operand (merging alternatives on a tie) and
/// `times` concatenates sequences while adding costs. Comparison, equality
/// and hashing elsewhere in this file look at the cost only.
#[derive(Clone, Debug)]
pub struct EditWeight {
/// Alternative edit sequences that share `cost`; empty for `zero()`.
sequences: SmallVec<[EditSequence; 4]>,
/// Cost of every stored sequence; `f64::INFINITY` for `zero()`.
cost: OrderedFloat<f64>,
}
impl EditWeight {
#[inline]
pub fn new(sequence: EditSequence, cost: f64) -> Self {
let mut sequences = SmallVec::new();
sequences.push(sequence);
EditWeight {
sequences,
cost: OrderedFloat(cost),
}
}
#[inline]
pub fn single(op: EditOp, cost: f64) -> Self {
let mut seq = EditSequence::new();
seq.push(op);
Self::new(seq, cost)
}
#[inline]
pub fn from_op(op: EditOp) -> Self {
Self::single(op, op.default_cost())
}
#[inline]
pub fn zero() -> Self {
EditWeight {
sequences: SmallVec::new(),
cost: OrderedFloat(f64::INFINITY),
}
}
#[inline]
pub fn one() -> Self {
Self::new(EditSequence::new(), 0.0)
}
#[inline]
pub fn cost(&self) -> f64 {
self.cost.into_inner()
}
#[inline]
pub fn num_alternatives(&self) -> usize {
self.sequences.len()
}
#[inline]
pub fn sequences(&self) -> impl Iterator<Item = impl Iterator<Item = &EditOp>> {
self.sequences.iter().map(|seq| seq.iter())
}
#[inline]
pub fn sequences_slice(&self) -> &[EditSequence] {
&self.sequences
}
#[inline]
pub fn is_zero(&self) -> bool {
self.cost.is_infinite() || self.sequences.is_empty()
}
#[inline]
pub fn is_one(&self) -> bool {
self.cost.into_inner() == 0.0 && self.sequences.len() == 1 && self.sequences[0].is_empty()
}
pub fn plus(&self, other: &Self) -> Self {
match self.cost.cmp(&other.cost) {
Ordering::Less => self.clone(),
Ordering::Greater => other.clone(),
Ordering::Equal => {
if self.is_zero() {
return other.clone();
}
if other.is_zero() {
return self.clone();
}
let mut merged = self.sequences.clone();
for seq in &other.sequences {
if !merged.contains(seq) {
merged.push(seq.clone());
}
}
EditWeight {
sequences: merged,
cost: self.cost,
}
}
}
}
pub fn times(&self, other: &Self) -> Self {
if self.is_zero() || other.is_zero() {
return Self::zero();
}
let mut sequences = SmallVec::new();
for seq1 in &self.sequences {
for seq2 in &other.sequences {
let mut combined = seq1.clone();
combined.extend(seq2.iter().cloned());
sequences.push(combined);
}
}
if sequences.len() > 100 {
sequences.truncate(100);
}
EditWeight {
sequences,
cost: OrderedFloat(self.cost.into_inner() + other.cost.into_inner()),
}
}
pub fn divide(&self, other: &Self) -> Option<Self> {
if other.is_zero() {
return None;
}
let new_cost = self.cost.into_inner() - other.cost.into_inner();
if new_cost.is_nan() {
return None;
}
Some(EditWeight {
sequences: self.sequences.clone(),
cost: OrderedFloat(new_cost),
})
}
pub fn star(&self) -> Option<Self> {
if self.is_zero() {
return Some(Self::one());
}
if self.cost.into_inner() >= 0.0 {
Some(Self::one())
} else {
None
}
}
pub fn approx_eq(&self, other: &Self, epsilon: f64) -> bool {
(self.cost.into_inner() - other.cost.into_inner()).abs() <= epsilon
}
pub fn natural_less(&self, other: &Self) -> Option<bool> {
Some(self.cost < other.cost)
}
pub fn to_bytes(&self) -> Vec<u8> {
self.cost.into_inner().to_le_bytes().to_vec()
}
pub fn prune(&mut self, max_alternatives: usize) {
if self.sequences.len() > max_alternatives {
self.sequences.truncate(max_alternatives);
}
}
pub fn deduplicate(&mut self) {
if self.sequences.len() <= 1 {
return;
}
self.sequences.sort();
self.sequences.dedup();
}
pub fn apply(&self, _input: &str) -> Option<String> {
let seq = self.sequences.first()?;
let mut output = String::new();
for op in seq {
match op {
EditOp::Copy(c) => output.push(*c),
EditOp::Insert(c) => output.push(*c),
EditOp::Delete(_) => {} EditOp::Substitute { to, .. } => output.push(*to),
EditOp::Transpose { a, b } => {
output.push(*b);
output.push(*a);
}
}
}
Some(output)
}
pub fn describe(&self) -> String {
match self.sequences.first() {
None => "unreachable".to_string(),
Some(seq) if seq.is_empty() => "identity".to_string(),
Some(seq) => seq
.iter()
.map(|op| op.to_string())
.collect::<Vec<_>>()
.join(" "),
}
}
pub fn operation_counts(&self) -> EditOpCounts {
let mut counts = EditOpCounts::default();
if let Some(seq) = self.sequences.first() {
for op in seq {
match op {
EditOp::Copy(_) => counts.copies += 1,
EditOp::Insert(_) => counts.insertions += 1,
EditOp::Delete(_) => counts.deletions += 1,
EditOp::Substitute { .. } => counts.substitutions += 1,
EditOp::Transpose { .. } => counts.transpositions += 1,
}
}
}
counts
}
pub fn quantize(&self, epsilon: f64) -> i64 {
let v = self.cost.into_inner();
if v.is_nan() {
i64::MIN
} else if v.is_infinite() {
if v > 0.0 {
i64::MAX
} else {
i64::MIN + 1
}
} else {
(v / epsilon).round() as i64
}
}
}
/// Tally of each kind of edit operation found in a sequence.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct EditOpCounts {
    /// Number of `Copy` operations.
    pub copies: usize,
    /// Number of `Insert` operations.
    pub insertions: usize,
    /// Number of `Delete` operations.
    pub deletions: usize,
    /// Number of `Substitute` operations.
    pub substitutions: usize,
    /// Number of `Transpose` operations.
    pub transpositions: usize,
}

impl EditOpCounts {
    /// Number of operations other than copies.
    #[inline]
    pub fn edit_distance(&self) -> usize {
        let Self {
            insertions,
            deletions,
            substitutions,
            transpositions,
            ..
        } = *self;
        insertions + deletions + substitutions + transpositions
    }

    /// Number of operations of any kind, copies included.
    #[inline]
    pub fn total(&self) -> usize {
        self.copies + self.edit_distance()
    }
}
/// Two weights are equal when their costs are equal; the stored sequences
/// play no part in the comparison.
impl PartialEq for EditWeight {
    fn eq(&self, rhs: &Self) -> bool {
        PartialEq::eq(&self.cost, &rhs.cost)
    }
}
// Marker impl: equality is delegated entirely to the `OrderedFloat` cost.
impl Eq for EditWeight {}
/// Hashes the cost only, matching the cost-only `PartialEq` implementation.
impl Hash for EditWeight {
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        Hash::hash(&self.cost, hasher);
    }
}
/// Always comparable: defers to the total order defined by `Ord`.
impl PartialOrd for EditWeight {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}
/// Orders weights by cost alone, using the total order of `OrderedFloat`.
impl Ord for EditWeight {
    fn cmp(&self, rhs: &Self) -> Ordering {
        Ord::cmp(&self.cost, &rhs.cost)
    }
}
impl Default for EditWeight {
#[inline]
fn default() -> Self {
Self::one()
}
}
impl std::ops::Add for EditWeight {
type Output = Self;
#[inline]
fn add(self, other: Self) -> Self {
self.plus(&other)
}
}
impl std::ops::Add<&EditWeight> for EditWeight {
type Output = Self;
#[inline]
fn add(self, other: &Self) -> Self {
self.plus(other)
}
}
impl std::ops::Mul for EditWeight {
type Output = Self;
#[inline]
fn mul(self, other: Self) -> Self {
self.times(&other)
}
}
impl std::ops::Mul<&EditWeight> for EditWeight {
type Output = Self;
#[inline]
fn mul(self, other: &Self) -> Self {
self.times(other)
}
}
impl std::ops::AddAssign for EditWeight {
#[inline]
fn add_assign(&mut self, other: Self) {
*self = self.plus(&other);
}
}
impl std::ops::MulAssign for EditWeight {
#[inline]
fn mul_assign(&mut self, other: Self) {
*self = self.times(&other);
}
}
/// Serializes a summary of the weight as a two-field struct named
/// `EditWeight`: the cost and the number of alternatives. The sequences
/// themselves are not written.
#[cfg(feature = "serde")]
impl serde::Serialize for EditWeight {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let cost: f64 = self.cost.into_inner();
        let alternatives: usize = self.sequences.len();

        let mut out = serializer.serialize_struct("EditWeight", 2)?;
        out.serialize_field("cost", &cost)?;
        out.serialize_field("alternatives", &alternatives)?;
        out.end()
    }
}
/// Per-operation cost table used by `weight_for` to build single-operation
/// [`EditWeight`]s. Copies always cost `0.0` and have no field here.
#[derive(Clone, Debug)]
pub struct EditWeightBuilder {
/// Cost assigned to `EditOp::Insert`.
pub insert_cost: f64,
/// Cost assigned to `EditOp::Delete`.
pub delete_cost: f64,
/// Cost assigned to `EditOp::Substitute`.
pub substitute_cost: f64,
/// Cost assigned to `EditOp::Transpose`.
pub transpose_cost: f64,
}
impl Default for EditWeightBuilder {
fn default() -> Self {
EditWeightBuilder {
insert_cost: 1.0,
delete_cost: 1.0,
substitute_cost: 1.0,
transpose_cost: 1.0,
}
}
}
impl EditWeightBuilder {
    /// Starts from the default cost table.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the builder with the insertion cost replaced.
    pub fn insert_cost(self, cost: f64) -> Self {
        Self {
            insert_cost: cost,
            ..self
        }
    }

    /// Returns the builder with the deletion cost replaced.
    pub fn delete_cost(self, cost: f64) -> Self {
        Self {
            delete_cost: cost,
            ..self
        }
    }

    /// Returns the builder with the substitution cost replaced.
    pub fn substitute_cost(self, cost: f64) -> Self {
        Self {
            substitute_cost: cost,
            ..self
        }
    }

    /// Returns the builder with the transposition cost replaced.
    pub fn transpose_cost(self, cost: f64) -> Self {
        Self {
            transpose_cost: cost,
            ..self
        }
    }

    /// Wraps `op` in a single-operation weight priced from this cost table;
    /// copies always cost `0.0`.
    pub fn weight_for(&self, op: EditOp) -> EditWeight {
        let price = match op {
            EditOp::Copy(_) => 0.0,
            EditOp::Insert(_) => self.insert_cost,
            EditOp::Delete(_) => self.delete_cost,
            EditOp::Substitute { .. } => self.substitute_cost,
            EditOp::Transpose { .. } => self.transpose_cost,
        };
        EditWeight::single(op, price)
    }
}
// Unit tests and proptest properties for `EditOp`, `EditWeight` and
// `EditWeightBuilder`.
#[cfg(test)]
mod tests {
use super::*;
// Helper: single-`Copy` weight with an explicit cost.
fn make_copy_weight(c: char, cost: f64) -> EditWeight {
EditWeight::single(EditOp::Copy(c), cost)
}
// Helper: single-`Insert` weight with an explicit cost.
fn make_insert_weight(c: char, cost: f64) -> EditWeight {
EditWeight::single(EditOp::Insert(c), cost)
}
// Helper: single-`Delete` weight with an explicit cost.
fn make_delete_weight(c: char, cost: f64) -> EditWeight {
EditWeight::single(EditOp::Delete(c), cost)
}
// Helper: single-`Substitute` weight with an explicit cost.
fn make_subst_weight(from: char, to: char, cost: f64) -> EditWeight {
EditWeight::single(EditOp::Substitute { from, to }, cost)
}
// `plus` keeps the smaller cost; `times` adds costs and yields one
// concatenated alternative.
#[test]
fn test_basic_operations() {
let a = make_subst_weight('a', 'b', 1.0);
let b = make_insert_weight('c', 2.0);
let sum = a.plus(&b);
assert_eq!(sum.cost(), 1.0);
let prod = a.times(&b);
assert_eq!(prod.cost(), 3.0);
assert_eq!(prod.num_alternatives(), 1);
}
// `zero()` is neutral for `plus`, `one()` is neutral for `times`.
#[test]
fn test_identity() {
let a = make_subst_weight('a', 'b', 1.0);
let sum = a.plus(&EditWeight::zero());
assert!(a.approx_eq(&sum, 1e-10));
let prod = a.times(&EditWeight::one());
assert!(a.approx_eq(&prod, 1e-10));
assert_eq!(prod.num_alternatives(), a.num_alternatives());
}
// Multiplying by `zero()` yields a zero weight.
#[test]
fn test_annihilation() {
let a = make_subst_weight('a', 'b', 1.0);
let prod = a.times(&EditWeight::zero());
assert!(prod.is_zero());
}
// Equal-cost operands of `plus` have their alternatives merged.
#[test]
fn test_sequence_merge() {
let a = make_subst_weight('a', 'b', 1.0);
let b = make_insert_weight('c', 1.0);
let sum = a.plus(&b);
assert_eq!(sum.cost(), 1.0);
assert_eq!(sum.num_alternatives(), 2); }
// Chained `times` concatenates operations in call order.
#[test]
fn test_sequence_concatenation() {
let copy_a = make_copy_weight('a', 0.0);
let subst_bc = make_subst_weight('b', 'c', 1.0);
let del_x = make_delete_weight('x', 1.0);
let chain = copy_a.times(&subst_bc).times(&del_x);
assert_eq!(chain.cost(), 2.0);
assert_eq!(chain.num_alternatives(), 1);
let seq: Vec<_> = chain
.sequences()
.next()
.expect("semiring/edit.rs: required value was None/Err")
.collect();
assert_eq!(seq.len(), 3);
assert!(matches!(seq[0], EditOp::Copy('a')));
assert!(matches!(seq[1], EditOp::Substitute { from: 'b', to: 'c' }));
assert!(matches!(seq[2], EditOp::Delete('x')));
}
// `star` returns `one()` for non-negative costs and `None` for negative ones.
#[test]
fn test_star() {
let one = EditWeight::one();
let star_one = one.star().expect("One star should converge");
assert!(star_one.is_one());
let positive = make_subst_weight('a', 'b', 1.0);
let star_pos = positive.star().expect("Positive cost star should converge");
assert!(star_pos.is_one());
let negative = EditWeight::new(EditSequence::new(), -1.0);
assert!(negative.star().is_none());
}
// `divide` subtracts costs and refuses to divide by `zero()`.
#[test]
fn test_division() {
let a = make_subst_weight('a', 'b', 5.0);
let b = make_insert_weight('c', 3.0);
let quotient = a.divide(&b).expect("Division should succeed");
assert_eq!(quotient.cost(), 2.0);
assert!(a.divide(&EditWeight::zero()).is_none());
}
// `describe` contains the `Display` form of every operation.
#[test]
fn test_describe() {
let seq = make_subst_weight('a', 'b', 1.0)
.times(&make_insert_weight('c', 1.0))
.times(&make_delete_weight('x', 1.0));
let desc = seq.describe();
assert!(desc.contains("a>b"));
assert!(desc.contains("+c"));
assert!(desc.contains("-x"));
}
// `operation_counts` tallies each kind; copies do not count as edits.
#[test]
fn test_operation_counts() {
let seq = make_copy_weight('a', 0.0)
.times(&make_subst_weight('b', 'c', 1.0))
.times(&make_insert_weight('d', 1.0))
.times(&make_delete_weight('e', 1.0));
let counts = seq.operation_counts();
assert_eq!(counts.copies, 1);
assert_eq!(counts.substitutions, 1);
assert_eq!(counts.insertions, 1);
assert_eq!(counts.deletions, 1);
assert_eq!(counts.transpositions, 0);
assert_eq!(counts.edit_distance(), 3);
}
// `prune` caps the number of stored alternatives.
#[test]
fn test_prune() {
let a = make_subst_weight('a', 'b', 1.0);
let b = make_subst_weight('c', 'd', 1.0);
let c = make_subst_weight('e', 'f', 1.0);
let merged = a.plus(&b).plus(&c);
assert_eq!(merged.num_alternatives(), 3);
let mut pruned = merged.clone();
pruned.prune(2);
assert_eq!(pruned.num_alternatives(), 2);
}
// Builder-configured costs are used by `weight_for`.
#[test]
fn test_builder() {
let builder = EditWeightBuilder::new()
.insert_cost(0.5)
.delete_cost(0.7)
.substitute_cost(0.9);
let ins = builder.weight_for(EditOp::Insert('a'));
assert_eq!(ins.cost(), 0.5);
let del = builder.weight_for(EditOp::Delete('b'));
assert_eq!(del.cost(), 0.7);
let sub = builder.weight_for(EditOp::Substitute { from: 'c', to: 'd' });
assert_eq!(sub.cost(), 0.9);
}
// Semiring laws on fixed samples. `approx_eq` compares costs only, so the
// stored sequences are not checked here.
#[test]
fn test_semiring_axioms() {
let a = make_subst_weight('a', 'b', 2.0);
let b = make_insert_weight('c', 3.0);
let c = make_delete_weight('x', 1.0);
assert!(a.plus(&EditWeight::zero()).approx_eq(&a, 1e-10));
assert!(EditWeight::zero().plus(&a).approx_eq(&a, 1e-10));
assert!(a.times(&EditWeight::one()).approx_eq(&a, 1e-10));
assert!(EditWeight::one().times(&a).approx_eq(&a, 1e-10));
assert!(a.plus(&b).approx_eq(&b.plus(&a), 1e-10));
assert!(a.plus(&b).plus(&c).approx_eq(&a.plus(&b.plus(&c)), 1e-10));
assert!(a
.times(&b)
.times(&c)
.approx_eq(&a.times(&b.times(&c)), 1e-10));
assert!(EditWeight::zero().times(&a).is_zero());
assert!(a.times(&EditWeight::zero()).is_zero());
}
// A copy followed by a transposition costs 1.0 and describes both steps.
#[test]
fn test_spelling_correction_example() {
let step1 = EditWeight::single(EditOp::Copy('t'), 0.0);
let step2 = EditWeight::single(EditOp::Transpose { a: 'e', b: 'h' }, 1.0);
let correction = step1.times(&step2);
assert_eq!(correction.cost(), 1.0);
let desc = correction.describe();
assert!(desc.contains("=t"));
assert!(desc.contains("~eh"));
}
// `+` and `*` (by value and by reference) match `plus` and `times`.
#[test]
fn test_operator_overloading() {
let a = make_subst_weight('a', 'b', 2.0);
let b = make_insert_weight('c', 3.0);
let sum = a.clone() + b.clone();
assert_eq!(sum.cost(), 2.0);
let prod = a.clone() * b.clone();
assert_eq!(prod.cost(), 5.0);
let sum_ref = a.clone() + &b;
assert_eq!(sum_ref.cost(), 2.0);
let prod_ref = a.clone() * &b;
assert_eq!(prod_ref.cost(), 5.0);
}
use proptest::prelude::*;
// Strategy producing any `EditOp` variant over arbitrary characters.
fn arb_edit_op() -> impl Strategy<Value = EditOp> {
prop_oneof![
any::<char>().prop_map(EditOp::Copy),
any::<char>().prop_map(EditOp::Insert),
any::<char>().prop_map(EditOp::Delete),
(any::<char>(), any::<char>()).prop_map(|(f, t)| EditOp::Substitute { from: f, to: t }),
(any::<char>(), any::<char>()).prop_map(|(a, b)| EditOp::Transpose { a, b }),
]
}
// Strategy producing single-operation weights with a cost in [0, 100).
fn arb_edit_weight() -> impl Strategy<Value = EditWeight> {
(arb_edit_op(), 0.0f64..100.0).prop_map(|(op, cost)| EditWeight::single(op, cost))
}
// Property tests; each property is checked on costs only, 100 cases each.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
// (a + b) + c and a + (b + c) have the same cost.
#[test]
fn proptest_plus_associative(
a in arb_edit_weight(),
b in arb_edit_weight(),
c in arb_edit_weight()
) {
let left = a.plus(&b).plus(&c);
let right = a.plus(&b.plus(&c));
prop_assert!((left.cost() - right.cost()).abs() < 1e-10);
}
// a + b and b + a have the same cost.
#[test]
fn proptest_plus_commutative(
a in arb_edit_weight(),
b in arb_edit_weight()
) {
let ab = a.plus(&b);
let ba = b.plus(&a);
prop_assert!((ab.cost() - ba.cost()).abs() < 1e-10);
}
// Adding `zero()` leaves the cost unchanged.
#[test]
fn proptest_plus_identity(a in arb_edit_weight()) {
let zero = EditWeight::zero();
let sum = a.plus(&zero);
prop_assert!(a.approx_eq(&sum, 1e-10));
}
// (a * b) * c and a * (b * c) have the same cost.
#[test]
fn proptest_times_associative(
a in arb_edit_weight(),
b in arb_edit_weight(),
c in arb_edit_weight()
) {
let left = a.times(&b).times(&c);
let right = a.times(&b.times(&c));
prop_assert!((left.cost() - right.cost()).abs() < 1e-10);
}
// Multiplying by `one()` on either side leaves the cost unchanged.
#[test]
fn proptest_times_identity(a in arb_edit_weight()) {
let one = EditWeight::one();
prop_assert!(a.times(&one).approx_eq(&a, 1e-10));
prop_assert!(one.times(&a).approx_eq(&a, 1e-10));
}
// Multiplying by `zero()` on either side yields a zero weight.
#[test]
fn proptest_zero_annihilation(a in arb_edit_weight()) {
let zero = EditWeight::zero();
prop_assert!(a.times(&zero).is_zero());
prop_assert!(zero.times(&a).is_zero());
}
// a * (b + c) and (a * b) + (a * c) have the same cost.
#[test]
fn proptest_left_distributivity(
a in arb_edit_weight(),
b in arb_edit_weight(),
c in arb_edit_weight()
) {
let left = a.times(&b.plus(&c));
let right = a.times(&b).plus(&a.times(&c));
prop_assert!((left.cost() - right.cost()).abs() < 1e-10);
}
// Non-negative inputs never produce a negative cost.
#[test]
fn proptest_cost_non_negative(a in arb_edit_weight(), b in arb_edit_weight()) {
let sum = a.plus(&b);
let prod = a.times(&b);
prop_assert!(sum.cost() >= 0.0);
prop_assert!(prod.cost() >= 0.0);
}
// The cost of a product is the sum of the operand costs.
#[test]
fn proptest_times_adds_costs(
op1 in arb_edit_op(),
cost1 in 0.0f64..100.0,
op2 in arb_edit_op(),
cost2 in 0.0f64..100.0
) {
let a = EditWeight::single(op1, cost1);
let b = EditWeight::single(op2, cost2);
let prod = a.times(&b);
prop_assert!((prod.cost() - (cost1 + cost2)).abs() < 1e-10);
}
// The cost of a sum is the minimum of the operand costs.
#[test]
fn proptest_plus_takes_minimum_cost(
op1 in arb_edit_op(),
cost1 in 0.0f64..100.0,
op2 in arb_edit_op(),
cost2 in 0.0f64..100.0
) {
let a = EditWeight::single(op1, cost1);
let b = EditWeight::single(op2, cost2);
let sum = a.plus(&b);
prop_assert!((sum.cost() - cost1.min(cost2)).abs() < 1e-10);
}
}
}