use super::forward_score::forward_score;
use super::gradient::{backward, GradientAccumulator, GradientWfst};
use crate::semiring::{LogWeight, Semiring};
use crate::wfst::{MutableWfst, StateId, VectorWfst, Wfst};
/// A convolution kernel stored as a weighted finite-state transducer.
///
/// The transducer's arcs carry `LogWeight` weights; `kernel_size` records the
/// width the kernel was created with.
#[derive(Debug, Clone)]
pub struct WfstKernel<L: Clone + Send + Sync> {
    /// Transducer holding the kernel's arcs.
    pub fst: VectorWfst<L, LogWeight>,
    /// Width of the kernel as supplied at construction time.
    pub kernel_size: usize,
}
impl<L: Clone + Send + Sync + Default + Eq + std::hash::Hash> WfstKernel<L> {
    /// Builds a kernel shaped as a linear chain of `kernel_size + 1` states.
    ///
    /// The first state becomes the start state and the last one is made final
    /// with weight `LogWeight::one()`. Each pair of consecutive states is
    /// connected by `vocab_size` parallel arcs that carry no input or output
    /// label (`None`) and weight `LogWeight::new(init_weight)`.
    pub fn new(vocab_size: usize, kernel_size: usize, init_weight: f64) -> Self {
        let mut fst: VectorWfst<L, LogWeight> = VectorWfst::new();
        let chain: Vec<_> = (0..=kernel_size).map(|_| fst.add_state()).collect();

        fst.set_start(chain[0]);
        fst.set_final(chain[kernel_size], LogWeight::one());

        // `windows(2)` walks the (source, target) pairs of the chain in order.
        for step in chain.windows(2) {
            for _ in 0..vocab_size {
                fst.add_arc(step[0], None, None, step[1], LogWeight::new(init_weight));
            }
        }

        Self { fst, kernel_size }
    }

    /// Wraps an already constructed transducer as a kernel.
    ///
    /// `kernel_size` is stored as given; it is not checked against `fst`.
    pub fn from_wfst(fst: VectorWfst<L, LogWeight>, kernel_size: usize) -> Self {
        Self { fst, kernel_size }
    }
}
/// A window of the input encoded as a linear-chain transducer.
#[derive(Debug, Clone)]
pub struct ReceptiveField<L: Clone + Send + Sync> {
    /// Chain transducer with one arc per window position.
    pub fst: VectorWfst<L, LogWeight>,
    /// Start offset supplied by the caller when the field was built.
    pub start_pos: usize,
    /// Number of window positions (arcs) in the chain.
    pub size: usize,
}
impl<L: Clone + Send + Sync + Default + Eq + std::hash::Hash> ReceptiveField<L> {
pub fn from_hidden_states(hidden_states: &[(L, f64)], start_pos: usize) -> Self {
let size = hidden_states.len();
let mut fst = VectorWfst::new();
let mut states = Vec::with_capacity(size + 1);
for _ in 0..=size {
states.push(fst.add_state());
}
fst.set_start(states[0]);
fst.set_final(states[size], LogWeight::one());
for (i, (label, weight)) in hidden_states.iter().enumerate() {
fst.add_arc(
states[i],
Some(label.clone()),
Some(label.clone()),
states[i + 1],
LogWeight::new(*weight),
);
}
Self {
fst,
start_pos,
size,
}
}
pub fn from_weights(weights: &[f64], start_pos: usize) -> Self
where
L: Default,
{
let size = weights.len();
let mut fst = VectorWfst::new();
let mut states = Vec::with_capacity(size + 1);
for _ in 0..=size {
states.push(fst.add_state());
}
fst.set_start(states[0]);
fst.set_final(states[size], LogWeight::one());
for (i, &weight) in weights.iter().enumerate() {
fst.add_arc(states[i], None, None, states[i + 1], LogWeight::new(weight));
}
Self {
fst,
start_pos,
size,
}
}
}
/// Configuration of a WFST convolution layer.
#[derive(Debug, Clone)]
pub struct WfstConvConfig {
    /// Passed as the `vocab_size` of each kernel built by `WfstConvLayer::new`.
    pub input_channels: usize,
    /// Number of kernels built by `WfstConvLayer::new`.
    pub output_channels: usize,
    /// Number of input positions covered by one window.
    pub kernel_size: usize,
    /// Distance between the start offsets of consecutive windows.
    pub stride: usize,
    /// How many padded positions are added on each side of the input.
    pub padding: PaddingMode,
}
/// Padding strategy applied to both ends of the input sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingMode {
    /// No padding.
    Valid,
    /// `kernel_size / 2` padded positions on each side.
    Same,
    /// An explicit number of padded positions on each side.
    Custom(usize),
}
impl Default for WfstConvConfig {
fn default() -> Self {
Self {
input_channels: 256,
output_channels: 256,
kernel_size: 3,
stride: 1,
padding: PaddingMode::Same,
}
}
}
/// A convolution layer made of one WFST kernel per output channel.
#[derive(Debug, Clone)]
pub struct WfstConvLayer<L: Clone + Send + Sync> {
    /// The layer's kernels; each one produces one row of the layer output.
    pub kernels: Vec<WfstKernel<L>>,
    /// Kernel size, stride and padding used when applying the layer.
    pub config: WfstConvConfig,
}
impl<L: Clone + Send + Sync + Default + Eq + std::hash::Hash> WfstConvLayer<L> {
    /// Creates a layer with `config.output_channels` freshly built kernels.
    ///
    /// Every kernel is created with
    /// `WfstKernel::new(config.input_channels, config.kernel_size, 0.0)`.
    pub fn new(config: WfstConvConfig) -> Self {
        let kernels = (0..config.output_channels)
            .map(|_| WfstKernel::new(config.input_channels, config.kernel_size, 0.0))
            .collect();
        Self { kernels, config }
    }

    /// Creates a layer from existing kernels.
    ///
    /// The kernels are stored as given; they are not checked against `config`.
    pub fn from_kernels(kernels: Vec<WfstKernel<L>>, config: WfstConvConfig) -> Self {
        Self { kernels, config }
    }

    /// Returns the number of output positions for an input of `input_length`.
    ///
    /// The result is `(input_length + 2 * padding - kernel_size) / stride + 1`,
    /// where `padding` is `0` for [`PaddingMode::Valid`], `kernel_size / 2`
    /// for [`PaddingMode::Same`] and `p` for [`PaddingMode::Custom`]`(p)`.
    ///
    /// Returns `0` when the padded input is shorter than the kernel, and also
    /// when `stride` is `0` (a zero stride describes no valid window layout
    /// and would otherwise divide by zero).
    ///
    /// Note that with `PaddingMode::Same` and an even `kernel_size` the
    /// formula yields `input_length + 1` at stride 1, not `input_length`.
    pub fn output_length(&self, input_length: usize) -> usize {
        let padding = match self.config.padding {
            PaddingMode::Valid => 0,
            PaddingMode::Same => self.config.kernel_size / 2,
            PaddingMode::Custom(p) => p,
        };
        // Saturate instead of overflowing for extreme `Custom` paddings.
        let padded_length = padding.saturating_mul(2).saturating_add(input_length);

        match padded_length.checked_sub(self.config.kernel_size) {
            // `checked_div` returns `None` for a zero stride.
            Some(span) => span
                .checked_div(self.config.stride)
                .map_or(0, |steps| steps.saturating_add(1)),
            None => 0,
        }
    }

    /// Returns the total number of arcs across all kernels.
    pub fn num_parameters(&self) -> usize {
        self.kernels.iter().map(|k| count_arcs(&k.fst)).sum()
    }
}
/// Result of a forward pass that also keeps the per-window gradient WFSTs.
#[derive(Debug, Clone)]
pub struct WfstConvOutput {
    /// Scores indexed as `features[kernel][output_position]`.
    pub features: Vec<Vec<f64>>,
    /// Gradient WFSTs indexed as `gradient_wfsts[kernel][output_position]`.
    pub gradient_wfsts: Vec<Vec<GradientWfst<u32>>>,
}
/// Runs the layer over `input` and returns one row of scores per kernel.
///
/// `input` is a sequence of frames; each frame is reduced to the sum of its
/// values. For every output position a window of `kernel_size` frame sums is
/// taken starting at `out_pos * stride` (in padded coordinates, padded
/// positions contribute `0.0`), turned into a chain with
/// `ReceptiveField::from_weights`, and scored with
/// `compute_receptive_field_score` for each kernel.
///
/// The returned vector has `layer.kernels.len()` rows of
/// `layer.output_length(input.len())` entries each.
pub fn wfst_conv_forward<L: Clone + Send + Sync + Default + Eq + std::hash::Hash>(
    layer: &WfstConvLayer<L>,
    input: &[Vec<f64>],
) -> Vec<Vec<f64>> {
    let input_length = input.len();
    let output_length = layer.output_length(input_length);
    let num_kernels = layer.kernels.len();
    let mut output = vec![vec![0.0; output_length]; num_kernels];
    if output_length == 0 || num_kernels == 0 {
        return output;
    }

    let padding = match layer.config.padding {
        PaddingMode::Valid => 0,
        PaddingMode::Same => layer.config.kernel_size / 2,
        PaddingMode::Custom(p) => p,
    };
    let kernel_size = layer.config.kernel_size;
    let stride = layer.config.stride;

    // The per-frame sums do not depend on the kernel or on the window, so
    // they are computed once instead of once per (kernel, window, offset).
    let frame_sums: Vec<f64> = input
        .iter()
        .map(|frame| frame.iter().sum::<f64>())
        .collect();

    // Maps a padded coordinate to its frame sum; positions that fall into the
    // left or right padding yield 0.0.
    let frame_weight = |in_pos: usize| -> f64 {
        in_pos
            .checked_sub(padding)
            .and_then(|pos| frame_sums.get(pos))
            .copied()
            .unwrap_or(0.0)
    };

    let mut rf_weights: Vec<f64> = Vec::with_capacity(kernel_size);
    for out_pos in 0..output_length {
        let in_start = out_pos * stride;
        rf_weights.clear();
        rf_weights.extend((in_start..in_start + kernel_size).map(&frame_weight));

        // The receptive field is shared by all kernels at this position.
        let rf: ReceptiveField<L> = ReceptiveField::from_weights(&rf_weights, in_start);
        for (row, kernel) in output.iter_mut().zip(&layer.kernels) {
            row[out_pos] = compute_receptive_field_score(&rf, kernel);
        }
    }
    output
}
pub fn wfst_conv_forward_with_gradients<L: Clone + Send + Sync + Default + Eq + std::hash::Hash>(
layer: &WfstConvLayer<L>,
input: &[Vec<f64>],
) -> WfstConvOutput {
let input_length = input.len();
let output_length = layer.output_length(input_length);
let num_kernels = layer.kernels.len();
let mut features = vec![vec![0.0; output_length]; num_kernels];
let mut gradient_wfsts: Vec<Vec<GradientWfst<u32>>> = (0..num_kernels)
.map(|_| Vec::with_capacity(output_length))
.collect();
if output_length == 0 {
return WfstConvOutput {
features,
gradient_wfsts,
};
}
let padding = match layer.config.padding {
PaddingMode::Valid => 0,
PaddingMode::Same => layer.config.kernel_size / 2,
PaddingMode::Custom(p) => p,
};
for kernel_idx in 0..num_kernels {
let kernel = &layer.kernels[kernel_idx];
for out_pos in 0..output_length {
let in_start = out_pos * layer.config.stride;
let mut rf_weights = Vec::with_capacity(layer.config.kernel_size);
for k in 0..layer.config.kernel_size {
let in_pos = in_start + k;
if in_pos < padding || in_pos >= padding + input_length {
rf_weights.push(0.0);
} else {
let actual_pos = in_pos - padding;
let weight: f64 = input[actual_pos].iter().sum();
rf_weights.push(weight);
}
}
let rf: ReceptiveField<L> = ReceptiveField::from_weights(&rf_weights, in_start);
let grad_fst = GradientWfst::from_wfst(&rf.fst);
let score = forward_score(&grad_fst);
features[kernel_idx][out_pos] = score.value();
gradient_wfsts[kernel_idx].push(GradientWfst::from_wfst(&u32_view(&rf.fst, kernel)));
}
}
WfstConvOutput {
features,
gradient_wfsts,
}
}
fn u32_view<L: Clone + Send + Sync>(
rf: &crate::wfst::VectorWfst<L, crate::semiring::LogWeight>,
_kernel: &WfstKernel<L>,
) -> crate::wfst::VectorWfst<u32, crate::semiring::LogWeight> {
use crate::wfst::{MutableWfst, StateId, Wfst};
let mut out: crate::wfst::VectorWfst<u32, crate::semiring::LogWeight> =
crate::wfst::VectorWfst::new();
for _ in 0..rf.num_states() {
out.add_state();
}
if rf.start() != crate::wfst::NO_STATE {
out.set_start(rf.start());
}
for state in 0..rf.num_states() as StateId {
if rf.is_final(state) {
out.set_final(state, rf.final_weight(state));
}
for (idx, arc) in rf.transitions(state).iter().enumerate() {
out.add_arc(
arc.from,
Some(idx as u32),
Some(idx as u32),
arc.to,
arc.weight,
);
}
}
out
}
/// Returns the forward score of the receptive field's transducer.
///
/// The field's `fst` is wrapped in a `GradientWfst`, scored with
/// `forward_score`, and the score's `value()` is returned. `_kernel` is not
/// used.
fn compute_receptive_field_score<L: Clone + Send + Sync>(
    rf: &ReceptiveField<L>,
    _kernel: &WfstKernel<L>,
) -> f64 {
    forward_score(&GradientWfst::from_wfst(&rf.fst)).value()
}
/// Back-propagates `output_grad` through the layer.
///
/// Returns `(input_grad, kernel_grads)`:
/// * `input_grad` has one row per input frame and `input[0].len()` columns
///   (zero columns when `input` is empty);
/// * `kernel_grads` has one `GradientAccumulator` per kernel.
///
/// For every kernel and output position the window's chain is rebuilt as in
/// the forward pass, `backward` is run on it, and each arc gradient is added
/// to every channel of the input frame the arc covers, scaled by
/// `output_grad[kernel][out_pos] / input_channels`. Arcs that cover padded
/// positions are skipped. The window's gradients are then merged into the
/// kernel's accumulator.
///
/// # Panics
///
/// Panics if `output_grad` has fewer than `layer.kernels.len()` rows or a row
/// shorter than `layer.output_length(input.len())`.
///
/// NOTE(review): the forward pass feeds the *sum* of a frame's channels into
/// the chain, yet the input gradient here is divided by `input_channels`;
/// confirm which of the two is intended.
/// NOTE(review): the gradients merged into `kernel_grads` are not scaled by
/// `output_grad` and do not depend on the kernel; confirm this is intended.
pub fn wfst_conv_backward<L: Clone + Send + Sync + Default + Eq + std::hash::Hash>(
    layer: &WfstConvLayer<L>,
    input: &[Vec<f64>],
    output_grad: &[Vec<f64>],
) -> (Vec<Vec<f64>>, Vec<GradientAccumulator>) {
    let input_length = input.len();
    let input_channels = input.first().map_or(0, Vec::len);
    let output_length = layer.output_length(input_length);
    let num_kernels = layer.kernels.len();
    let mut input_grad = vec![vec![0.0; input_channels]; input_length];
    let mut kernel_grads: Vec<GradientAccumulator> = layer
        .kernels
        .iter()
        .map(|k| GradientAccumulator::with_capacity(count_arcs(&k.fst)))
        .collect();
    if output_length == 0 {
        return (input_grad, kernel_grads);
    }

    let padding = match layer.config.padding {
        PaddingMode::Valid => 0,
        PaddingMode::Same => layer.config.kernel_size / 2,
        PaddingMode::Custom(p) => p,
    };
    let kernel_size = layer.config.kernel_size;
    let stride = layer.config.stride;

    // Frame sums are independent of kernel and window; compute them once.
    let frame_sums: Vec<f64> = input
        .iter()
        .map(|frame| frame.iter().sum::<f64>())
        .collect();

    // Maps a padded coordinate to its frame sum, 0.0 inside the padding.
    let frame_weight = |in_pos: usize| -> f64 {
        in_pos
            .checked_sub(padding)
            .and_then(|pos| frame_sums.get(pos))
            .copied()
            .unwrap_or(0.0)
    };

    let mut rf_weights: Vec<f64> = Vec::with_capacity(kernel_size);
    // Kernel-major iteration is kept so that contributions are accumulated
    // into `input_grad` in the same floating-point order as before.
    for kernel_idx in 0..num_kernels {
        for out_pos in 0..output_length {
            let in_start = out_pos * stride;
            let out_grad = output_grad[kernel_idx][out_pos];

            rf_weights.clear();
            rf_weights.extend((in_start..in_start + kernel_size).map(&frame_weight));

            let rf: ReceptiveField<L> = ReceptiveField::from_weights(&rf_weights, in_start);
            let grad_fst = GradientWfst::from_wfst(&rf.fst);
            let _ = forward_score(&grad_fst);
            let rf_grads = backward(&grad_fst);

            for arc_grad in &rf_grads.arc_gradients {
                // In the chain, the arc's source state index is its offset
                // within the window.
                let in_pos = in_start + arc_grad.arc.from as usize;
                let frame_grad = match in_pos
                    .checked_sub(padding)
                    .and_then(|pos| input_grad.get_mut(pos))
                {
                    Some(frame_grad) => frame_grad,
                    // The arc covers a padded position: nothing to update.
                    None => continue,
                };
                let share = out_grad * arc_grad.gradient / input_channels as f64;
                for slot in frame_grad.iter_mut() {
                    *slot += share;
                }
            }
            kernel_grads[kernel_idx].merge(&rf_grads);
        }
    }
    (input_grad, kernel_grads)
}
/// Returns the total number of transitions over all states of `fst`.
fn count_arcs<L: Clone + Send + Sync, W: Semiring>(fst: &VectorWfst<L, W>) -> usize {
    (0..fst.num_states() as StateId)
        .map(|state| fst.transitions(state).len())
        .sum::<usize>()
}
/// Size summary of a WFST convolution layer.
#[derive(Debug, Clone, Default)]
pub struct WfstConvStats {
    /// Number of kernels in the layer.
    pub num_kernels: usize,
    /// Configured kernel width.
    pub kernel_size: usize,
    /// Total number of arcs across all kernels.
    pub num_parameters: usize,
    /// `input_channels * output_channels * kernel_size` for the same config.
    pub equiv_traditional_params: usize,
    /// `equiv_traditional_params / num_parameters`, or `0.0` without arcs.
    pub reduction_ratio: f64,
}
impl<L: Clone + Send + Sync + Default + Eq + std::hash::Hash> WfstConvLayer<L> {
pub fn stats(&self) -> WfstConvStats {
let num_parameters = self.num_parameters();
let equiv_traditional =
self.config.input_channels * self.config.output_channels * self.config.kernel_size;
WfstConvStats {
num_kernels: self.kernels.len(),
kernel_size: self.config.kernel_size,
num_parameters,
equiv_traditional_params: equiv_traditional,
reduction_ratio: if num_parameters > 0 {
equiv_traditional as f64 / num_parameters as f64
} else {
0.0
},
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `u32`-labelled layer from individual configuration values.
    fn build_layer(
        input_channels: usize,
        output_channels: usize,
        kernel_size: usize,
        stride: usize,
        padding: PaddingMode,
    ) -> WfstConvLayer<u32> {
        WfstConvLayer::<u32>::new(WfstConvConfig {
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding,
        })
    }

    #[test]
    fn test_wfst_kernel_creation() {
        let kernel = WfstKernel::<u32>::new(10, 3, 0.0);
        assert_eq!(kernel.kernel_size, 3);
        assert!(kernel.fst.num_states() > 0);
    }

    #[test]
    fn test_receptive_field_from_weights() {
        let rf = ReceptiveField::<u32>::from_weights(&[1.0, 2.0, 3.0], 0);
        assert_eq!(rf.size, 3);
        assert_eq!(rf.start_pos, 0);
        // Three arcs need four chained states.
        assert_eq!(rf.fst.num_states(), 4);
    }

    #[test]
    fn test_wfst_conv_config_default() {
        let config = WfstConvConfig::default();
        assert_eq!(config.kernel_size, 3);
        assert_eq!(config.stride, 1);
    }

    #[test]
    fn test_wfst_conv_layer_creation() {
        let layer = build_layer(10, 5, 3, 1, PaddingMode::Valid);
        assert_eq!(layer.kernels.len(), 5);
    }

    #[test]
    fn test_output_length_valid_padding() {
        let layer = build_layer(10, 5, 3, 1, PaddingMode::Valid);
        assert_eq!(layer.output_length(10), 8);
    }

    #[test]
    fn test_output_length_same_padding() {
        let layer = build_layer(10, 5, 3, 1, PaddingMode::Same);
        assert_eq!(layer.output_length(10), 10);
    }

    #[test]
    fn test_wfst_conv_forward() {
        let layer = build_layer(2, 2, 2, 1, PaddingMode::Valid);
        let input = vec![vec![1.0, 0.5], vec![0.5, 1.0], vec![1.0, 0.5]];
        let output = wfst_conv_forward(&layer, &input);
        // One row per kernel, one column per output position.
        assert_eq!(output.len(), 2);
        assert_eq!(output[0].len(), 2);
    }

    #[test]
    fn test_wfst_conv_stats() {
        let layer = build_layer(256, 256, 3, 1, PaddingMode::Same);
        let stats = layer.stats();
        assert_eq!(stats.num_kernels, 256);
        assert_eq!(stats.kernel_size, 3);
        assert_eq!(stats.equiv_traditional_params, 196608);
    }

    #[test]
    fn test_padding_mode_custom() {
        let layer = build_layer(10, 5, 3, 1, PaddingMode::Custom(2));
        assert_eq!(layer.output_length(10), 12);
    }

    #[test]
    fn test_stride_greater_than_one() {
        let layer = build_layer(10, 5, 3, 2, PaddingMode::Valid);
        assert_eq!(layer.output_length(10), 4);
    }
}