use std::collections::HashMap;
use crate::semiring::Semiring;
use crate::wfst::{MutableWfst, StateId, VectorWfst, WeightedTransition, Wfst};
/// Integer label carried on the input/output side of phone lattice and pattern FST arcs.
pub type PhoneId = u32;
/// Frame index; used for the start/end bounds of a [`DysfluencySpan`].
pub type FrameIndex = usize;
/// The kinds of dysfluency this module can label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DysfluencyPattern {
    /// Identified as `"sound_repetition"`.
    SoundRepetition,
    /// Identified as `"syllable_repetition"`.
    SyllableRepetition,
    /// Identified as `"word_repetition"`.
    WordRepetition,
    /// Identified as `"block"`.
    Block,
    /// Identified as `"prolongation"`.
    Prolongation,
    /// Identified as `"interjection"`.
    Interjection,
}

impl DysfluencyPattern {
    /// Returns every variant, in declaration order.
    pub fn all() -> &'static [DysfluencyPattern] {
        static EVERY_PATTERN: [DysfluencyPattern; 6] = [
            DysfluencyPattern::SoundRepetition,
            DysfluencyPattern::SyllableRepetition,
            DysfluencyPattern::WordRepetition,
            DysfluencyPattern::Block,
            DysfluencyPattern::Prolongation,
            DysfluencyPattern::Interjection,
        ];
        &EVERY_PATTERN
    }

    /// Stable snake_case identifier for this pattern kind.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::SoundRepetition => "sound_repetition",
            Self::SyllableRepetition => "syllable_repetition",
            Self::WordRepetition => "word_repetition",
            Self::Block => "block",
            Self::Prolongation => "prolongation",
            Self::Interjection => "interjection",
        }
    }
}
/// A stretch of the scanned input matched by one dysfluency pattern.
#[derive(Debug, Clone)]
pub struct DysfluencySpan {
    /// Pattern kind that produced this span.
    pub pattern: DysfluencyPattern,
    /// Frame where the span begins.
    pub start_frame: FrameIndex,
    /// Frame where the span ends; [`DysfluencySpan::duration`] is measured up to here.
    pub end_frame: FrameIndex,
    /// Phone labels consumed while matching the pattern.
    pub phones: Vec<PhoneId>,
    /// Score attached to the match.
    pub score: f64,
    /// Repetition count, when one applies to this span.
    pub repetition_count: Option<usize>,
}

impl DysfluencySpan {
    /// Frames covered by the span: `end_frame - start_frame`, or `0` when
    /// `end_frame` is not past `start_frame` (never underflows).
    pub fn duration(&self) -> usize {
        if self.end_frame > self.start_frame {
            self.end_frame - self.start_frame
        } else {
            0
        }
    }
}
/// Tunable thresholds, weights and special phone ids for dysfluency detection.
#[derive(Debug, Clone)]
pub struct DysfluencyConfig {
    /// Minimum repetition count (default 2).
    pub min_repetitions: usize,
    /// Maximum repetition count (default 5).
    /// NOTE(review): not read by any pattern builder in this file — confirm intended use.
    pub max_repetitions: usize,
    /// Prolongation length threshold in frames (default 3).
    pub min_prolongation_frames: usize,
    /// Converted with `W::from` and placed on the arc that completes a
    /// repetition, prolongation or interjection match (default 1.0).
    pub detection_penalty: f64,
    /// Converted with `W::from` and placed on the block pattern's
    /// silence-to-speech arc (default 2.0).
    pub block_cost: f64,
    /// Phone ids accepted by the interjection pattern (default: none).
    pub interjection_phones: Vec<PhoneId>,
    /// Phone id treated as silence (default 0).
    pub silence_phone: PhoneId,
}

impl Default for DysfluencyConfig {
    /// Repetitions 2..=5, 3 prolongation frames, penalty 1.0, block cost 2.0,
    /// no interjection phones, silence phone 0.
    fn default() -> Self {
        let interjection_phones = Vec::new();
        DysfluencyConfig {
            silence_phone: 0,
            interjection_phones,
            block_cost: 2.0,
            detection_penalty: 1.0,
            min_prolongation_frames: 3,
            max_repetitions: 5,
            min_repetitions: 2,
        }
    }
}
/// Matches dysfluency pattern FSTs against a phone lattice.
///
/// Owns one pattern FST per built-in [`DysfluencyPattern`] it constructs, plus
/// the configuration and phone-vocabulary size those FSTs were built from.
#[derive(Debug)]
pub struct DysfluencyDetector<W: Semiring> {
    // Pattern FSTs keyed by the dysfluency kind each one detects.
    patterns: HashMap<DysfluencyPattern, VectorWfst<PhoneId, W>>,
    // Thresholds, weights and special phone ids used when building patterns.
    config: DysfluencyConfig,
    // Number of phone ids; the pattern builders iterate ids `1..vocab_size`.
    vocab_size: usize,
}
impl<W: Semiring + From<f64> + Clone> DysfluencyDetector<W> {
pub fn new(vocab_size: usize, config: DysfluencyConfig) -> Self {
let mut detector = Self {
patterns: HashMap::new(),
config,
vocab_size,
};
detector.build_patterns();
detector
}
pub fn with_vocab_size(vocab_size: usize) -> Self {
Self::new(vocab_size, DysfluencyConfig::default())
}
fn build_patterns(&mut self) {
self.patterns.insert(
DysfluencyPattern::SoundRepetition,
self.build_sound_repetition_pattern(),
);
self.patterns.insert(
DysfluencyPattern::Prolongation,
self.build_prolongation_pattern(),
);
self.patterns
.insert(DysfluencyPattern::Block, self.build_block_pattern());
self.patterns.insert(
DysfluencyPattern::Interjection,
self.build_interjection_pattern(),
);
}
fn build_sound_repetition_pattern(&self) -> VectorWfst<PhoneId, W> {
let mut fst: VectorWfst<PhoneId, W> = VectorWfst::new();
fst.add_states(3);
fst.set_start(0);
fst.set_final(2, W::one());
let penalty = W::from(self.config.detection_penalty);
for phone in 1..self.vocab_size as PhoneId {
if phone == self.config.silence_phone {
continue;
}
fst.add_transition(WeightedTransition {
from: 0,
input: Some(phone),
output: Some(phone),
to: 1,
weight: W::one(),
});
fst.add_transition(WeightedTransition {
from: 1,
input: Some(phone),
output: Some(phone),
to: 2,
weight: penalty.clone(),
});
fst.add_transition(WeightedTransition {
from: 2,
input: Some(phone),
output: Some(phone),
to: 2,
weight: W::one(),
});
}
fst.add_transition(WeightedTransition {
from: 2,
input: None, output: None,
to: 0,
weight: W::one(),
});
fst
}
fn build_prolongation_pattern(&self) -> VectorWfst<PhoneId, W> {
let mut fst: VectorWfst<PhoneId, W> = VectorWfst::new();
let num_states = self.config.min_prolongation_frames + 2;
fst.add_states(num_states);
fst.set_start(0);
fst.set_final((num_states - 1) as StateId, W::one());
let penalty = W::from(self.config.detection_penalty);
for phone in 1..self.vocab_size as PhoneId {
if phone == self.config.silence_phone {
continue;
}
for i in 0..self.config.min_prolongation_frames {
let from_state = i as StateId;
let to_state = (i + 1) as StateId;
fst.add_transition(WeightedTransition {
from: from_state,
input: Some(phone),
output: Some(phone),
to: to_state,
weight: W::one(),
});
}
let penultimate = self.config.min_prolongation_frames as StateId;
let final_state = (num_states - 1) as StateId;
fst.add_transition(WeightedTransition {
from: penultimate,
input: Some(phone),
output: Some(phone),
to: final_state,
weight: penalty.clone(),
});
fst.add_transition(WeightedTransition {
from: final_state,
input: Some(phone),
output: Some(phone),
to: final_state,
weight: W::one(),
});
}
fst
}
fn build_block_pattern(&self) -> VectorWfst<PhoneId, W> {
let mut fst: VectorWfst<PhoneId, W> = VectorWfst::new();
fst.add_states(3);
fst.set_start(0);
fst.set_final(2, W::one());
let block_cost = W::from(self.config.block_cost);
let silence = self.config.silence_phone;
for phone in 1..self.vocab_size as PhoneId {
if phone == silence {
continue;
}
fst.add_transition(WeightedTransition {
from: 0,
input: Some(phone),
output: Some(phone),
to: 0,
weight: W::one(),
});
}
fst.add_transition(WeightedTransition {
from: 0,
input: Some(silence),
output: Some(silence),
to: 1,
weight: W::one(),
});
fst.add_transition(WeightedTransition {
from: 1,
input: Some(silence),
output: Some(silence),
to: 1,
weight: W::one(),
});
for phone in 1..self.vocab_size as PhoneId {
if phone == silence {
continue;
}
fst.add_transition(WeightedTransition {
from: 1,
input: Some(phone),
output: Some(phone),
to: 2,
weight: block_cost.clone(),
});
}
for phone in 1..self.vocab_size as PhoneId {
fst.add_transition(WeightedTransition {
from: 2,
input: Some(phone),
output: Some(phone),
to: 2,
weight: W::one(),
});
}
fst
}
fn build_interjection_pattern(&self) -> VectorWfst<PhoneId, W> {
let mut fst: VectorWfst<PhoneId, W> = VectorWfst::new();
fst.add_states(2);
fst.set_start(0);
fst.set_final(1, W::one());
let penalty = W::from(self.config.detection_penalty);
for &phone in &self.config.interjection_phones {
fst.add_transition(WeightedTransition {
from: 0,
input: Some(phone),
output: Some(phone),
to: 1,
weight: penalty.clone(),
});
fst.add_transition(WeightedTransition {
from: 1,
input: Some(phone),
output: Some(phone),
to: 1,
weight: W::one(),
});
}
fst
}
pub fn detect(&self, lattice: &VectorWfst<PhoneId, W>) -> Vec<DysfluencySpan>
where
W: Into<f64>,
{
let mut spans = Vec::new();
for (&pattern_type, pattern_fst) in &self.patterns {
let detected = self.detect_pattern(lattice, pattern_fst, pattern_type);
spans.extend(detected);
}
spans.sort_by_key(|s| s.start_frame);
spans
}
fn detect_pattern(
&self,
lattice: &VectorWfst<PhoneId, W>,
pattern: &VectorWfst<PhoneId, W>,
pattern_type: DysfluencyPattern,
) -> Vec<DysfluencySpan>
where
W: Into<f64>,
{
let mut spans = Vec::new();
let start_state = lattice.start();
self.scan_for_pattern(
lattice,
pattern,
start_state,
pattern.start(),
0,
Vec::new(),
&mut spans,
pattern_type,
);
spans
}
#[allow(clippy::too_many_arguments)]
fn scan_for_pattern(
&self,
lattice: &VectorWfst<PhoneId, W>,
pattern: &VectorWfst<PhoneId, W>,
lattice_state: StateId,
pattern_state: StateId,
frame: FrameIndex,
phones: Vec<PhoneId>,
spans: &mut Vec<DysfluencySpan>,
pattern_type: DysfluencyPattern,
) where
W: Into<f64>,
{
if pattern.is_final(pattern_state) && !phones.is_empty() {
let score: f64 = pattern.final_weight(pattern_state).into();
spans.push(DysfluencySpan {
pattern: pattern_type,
start_frame: frame.saturating_sub(phones.len()),
end_frame: frame,
phones: phones.clone(),
score,
repetition_count: self.count_repetitions(&phones),
});
}
if frame > 1000 {
return;
}
for lat_tr in lattice.transitions(lattice_state) {
for pat_tr in pattern.transitions(pattern_state) {
let labels_match = match (lat_tr.input, pat_tr.input) {
(Some(l1), Some(l2)) => l1 == l2,
(None, None) => true, _ => false,
};
if labels_match {
let mut new_phones = phones.clone();
if let Some(phone) = lat_tr.input {
new_phones.push(phone);
}
if new_phones.len() <= 20 {
self.scan_for_pattern(
lattice,
pattern,
lat_tr.to,
pat_tr.to,
frame + 1,
new_phones,
spans,
pattern_type,
);
}
}
}
for pat_tr in pattern.transitions(pattern_state) {
if pat_tr.input.is_none() {
self.scan_for_pattern(
lattice,
pattern,
lattice_state,
pat_tr.to,
frame,
phones.clone(),
spans,
pattern_type,
);
}
}
}
}
fn count_repetitions(&self, phones: &[PhoneId]) -> Option<usize> {
if phones.is_empty() {
return None;
}
let first = phones[0];
let mut count = 0;
for &p in phones {
if p == first {
count += 1;
} else {
break;
}
}
if count >= 2 {
Some(count)
} else {
None
}
}
pub fn get_pattern(&self, pattern: DysfluencyPattern) -> Option<&VectorWfst<PhoneId, W>> {
self.patterns.get(&pattern)
}
pub fn config(&self) -> &DysfluencyConfig {
&self.config
}
}
impl<W: Semiring + From<f64> + Clone> Default for DysfluencyDetector<W> {
fn default() -> Self {
Self::with_vocab_size(100) }
}
/// Builder for a linear acceptor of repeated word ids (`u32` labels).
#[derive(Debug)]
pub struct WordRepetitionBuilder<W: Semiring> {
    // Word ids placed on every arc of the repetition chain.
    words: Vec<u32>,
    // First chain position that is marked final.
    min_reps: usize,
    // Number of chain arcs; the last chain position is `max_reps`.
    max_reps: usize,
    // Ties the builder to the weight type `W` without storing a weight.
    _phantom: std::marker::PhantomData<W>,
}
impl<W: Semiring + From<f64> + Clone> WordRepetitionBuilder<W> {
pub fn new() -> Self {
Self {
words: Vec::new(),
min_reps: 2,
max_reps: 5,
_phantom: std::marker::PhantomData,
}
}
pub fn add_word(mut self, word_id: u32) -> Self {
self.words.push(word_id);
self
}
pub fn min_repetitions(mut self, min: usize) -> Self {
self.min_reps = min;
self
}
pub fn max_repetitions(mut self, max: usize) -> Self {
self.max_reps = max;
self
}
pub fn build(self) -> VectorWfst<u32, W> {
let mut fst: VectorWfst<u32, W> = VectorWfst::new();
let num_states = self.max_reps + 1;
fst.add_states(num_states);
fst.set_start(0);
for i in self.min_reps..=self.max_reps {
fst.set_final(i as StateId, W::one());
}
for &word in &self.words {
for i in 0..self.max_reps {
fst.add_transition(WeightedTransition {
from: i as StateId,
input: Some(word),
output: Some(word),
to: (i + 1) as StateId,
weight: W::one(),
});
}
}
fst
}
}
impl<W: Semiring + From<f64> + Clone> Default for WordRepetitionBuilder<W> {
fn default() -> Self {
Self::new()
}
}
/// Builder for a phone-level acceptor of repeated syllables.
#[derive(Debug)]
pub struct SyllableRepetitionBuilder<W: Semiring> {
    // Syllables as phone-id sequences; empty syllables are skipped when building.
    syllables: Vec<Vec<PhoneId>>,
    // Repetition count configured via `min_repetitions` (default 2).
    min_reps: usize,
    // Ties the builder to the weight type `W` without storing a weight.
    _phantom: std::marker::PhantomData<W>,
}
impl<W: Semiring + From<f64> + Clone> SyllableRepetitionBuilder<W> {
    /// Starts an empty builder requiring two repetitions.
    pub fn new() -> Self {
        Self {
            syllables: Vec::new(),
            min_reps: 2,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Adds a syllable, given as its phone-id sequence.
    pub fn add_syllable(mut self, phones: Vec<PhoneId>) -> Self {
        self.syllables.push(phones);
        self
    }

    /// Sets how many back-to-back copies of a syllable must be matched.
    /// Values below 2 are treated as 2.
    pub fn min_repetitions(mut self, min: usize) -> Self {
        self.min_reps = min;
        self
    }

    /// Builds an acceptor with start state 0.
    ///
    /// Each non-empty syllable gets its own linear branch spelling the
    /// syllable `max(min_reps, 2)` times in a row. The branch's last state is
    /// final, and every arc has weight `W::one()`.
    ///
    /// Previously `min_reps` was ignored and exactly two copies were always
    /// built. The default of 2 yields the same FST as before.
    pub fn build(self) -> VectorWfst<PhoneId, W> {
        // Fewer than two copies is not a repetition.
        let copies = self.min_reps.max(2);
        let mut fst: VectorWfst<PhoneId, W> = VectorWfst::new();
        fst.add_state();
        fst.set_start(0);
        // Next state id to allocate; states are added sequentially after 0.
        let mut next_state: StateId = 1;
        for syllable in self.syllables.iter().filter(|s| !s.is_empty()) {
            // Every syllable branch starts at the shared start state.
            let mut prev: StateId = 0;
            for _ in 0..copies {
                for &phone in syllable {
                    fst.add_state();
                    fst.add_transition(WeightedTransition {
                        from: prev,
                        input: Some(phone),
                        output: Some(phone),
                        to: next_state,
                        weight: W::one(),
                    });
                    prev = next_state;
                    next_state += 1;
                }
            }
            fst.set_final(prev, W::one());
        }
        fst
    }
}
impl<W: Semiring + From<f64> + Clone> Default for SyllableRepetitionBuilder<W> {
fn default() -> Self {
Self::new()
}
}
// Unit tests for the dysfluency patterns, builders and detector.
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::TropicalWeight;
// `all()` lists every enum variant.
#[test]
fn test_dysfluency_pattern_all() {
let patterns = DysfluencyPattern::all();
assert_eq!(patterns.len(), 6);
}
// The constructor eagerly builds the detector-owned patterns.
#[test]
fn test_dysfluency_detector_creation() {
let detector = DysfluencyDetector::<TropicalWeight>::with_vocab_size(50);
assert!(detector
.get_pattern(DysfluencyPattern::SoundRepetition)
.is_some());
assert!(detector
.get_pattern(DysfluencyPattern::Prolongation)
.is_some());
assert!(detector.get_pattern(DysfluencyPattern::Block).is_some());
}
// Loose structural check only: enough states for a two-arc chain, start at 0.
#[test]
fn test_sound_repetition_pattern() {
let detector = DysfluencyDetector::<TropicalWeight>::with_vocab_size(10);
let pattern = detector
.get_pattern(DysfluencyPattern::SoundRepetition)
.expect("asr/dysfluency.rs: required value was None/Err");
assert!(pattern.num_states() >= 3);
assert_eq!(pattern.start(), 0);
}
// A 3-frame threshold needs at least 4 states in the prolongation pattern.
#[test]
fn test_prolongation_pattern() {
let config = DysfluencyConfig {
min_prolongation_frames: 3,
..Default::default()
};
let detector = DysfluencyDetector::<TropicalWeight>::new(10, config);
let pattern = detector
.get_pattern(DysfluencyPattern::Prolongation)
.expect("asr/dysfluency.rs: required value was None/Err");
assert!(pattern.num_states() >= 4);
}
// max_repetitions(3) gives a chain of states 0..=3; positions 2 and 3 are final.
#[test]
fn test_word_repetition_builder() {
let fst: VectorWfst<u32, TropicalWeight> = WordRepetitionBuilder::new()
.add_word(1)
.add_word(2)
.min_repetitions(2)
.max_repetitions(3)
.build();
assert_eq!(fst.num_states(), 4);
assert!(fst.is_final(2));
assert!(fst.is_final(3));
}
// Smoke test: a built syllable acceptor has states and starts at 0.
#[test]
fn test_syllable_repetition_builder() {
let fst: VectorWfst<PhoneId, TropicalWeight> = SyllableRepetitionBuilder::new()
.add_syllable(vec![1, 2]) .min_repetitions(2)
.build();
assert!(fst.num_states() > 0);
assert_eq!(fst.start(), 0);
}
// duration() is end_frame - start_frame.
#[test]
fn test_dysfluency_span() {
let span = DysfluencySpan {
pattern: DysfluencyPattern::SoundRepetition,
start_frame: 10,
end_frame: 15,
phones: vec![1, 1, 1],
score: 0.5,
repetition_count: Some(3),
};
assert_eq!(span.duration(), 5);
assert_eq!(span.repetition_count, Some(3));
}
// Pins the default thresholds.
#[test]
fn test_config_default() {
let config = DysfluencyConfig::default();
assert_eq!(config.min_repetitions, 2);
assert_eq!(config.max_repetitions, 5);
assert_eq!(config.min_prolongation_frames, 3);
}
// A lattice with no states yields no spans.
#[test]
fn test_detect_empty_lattice() {
let detector = DysfluencyDetector::<TropicalWeight>::with_vocab_size(10);
let lattice: VectorWfst<PhoneId, TropicalWeight> = VectorWfst::new();
let spans = detector.detect(&lattice);
assert!(spans.is_empty());
}
}