use super::activation::Activation;
use crate::model::GatingMode;
/// Configuration of a gating stage: which combination rule to use, the two
/// activations involved, and how many output channels it produces.
#[derive(Clone, Debug)]
pub(super) struct Gating {
    /// Selects how (and whether) `secondary` is combined with `primary`.
    mode: GatingMode,
    /// Activation applied to the first `bottleneck` input rows.
    primary: Activation,
    /// Activation applied to the second `bottleneck` input rows; only read in
    /// the `Gated` and `Blended` modes.
    secondary: Activation,
    /// Number of output rows (channels) produced by the stage.
    bottleneck: usize,
}
impl Gating {
    /// Builds a gating stage.
    ///
    /// * `mode` - how the two activations are combined.
    /// * `primary` - activation applied to the first `bottleneck` input rows.
    /// * `secondary` - activation applied to the second `bottleneck` input
    ///   rows (unused when `mode` is `GatingMode::None`).
    /// * `bottleneck` - number of output rows.
    pub(super) fn new(
        mode: GatingMode,
        primary: Activation,
        secondary: Activation,
        bottleneck: usize,
    ) -> Self {
        Self {
            mode,
            primary,
            secondary,
            bottleneck,
        }
    }

    /// Returns the configured gating mode.
    pub(super) fn mode(&self) -> GatingMode {
        self.mode
    }

    /// Number of input rows the stage reads: `bottleneck` when ungated,
    /// `2 * bottleneck` when a second set of rows drives the gate.
    pub(super) fn input_rows(&self) -> usize {
        match self.mode {
            GatingMode::None => self.bottleneck,
            GatingMode::Gated | GatingMode::Blended => 2 * self.bottleneck,
        }
    }

    /// Number of output rows the stage writes (test-only accessor).
    #[cfg(test)]
    pub(super) fn output_rows(&self) -> usize {
        self.bottleneck
    }

    /// Processes a single time step.
    ///
    /// `z` holds one value per input row: the first `bottleneck` entries feed
    /// `primary`; in the gated modes the next `bottleneck` entries feed
    /// `secondary`. The first `bottleneck` entries of `out` are overwritten.
    ///
    /// # Panics
    ///
    /// Panics if `z` has fewer than `input_rows()` elements or `out` has fewer
    /// than `bottleneck` elements. The check happens before anything is
    /// written, so `out` is left untouched on failure.
    pub(super) fn process_sample(&self, z: &[f32], out: &mut [f32]) {
        self.gate_elements(z, out, self.bottleneck);
    }

    /// Processes a block of `n` time steps.
    ///
    /// Both buffers are laid out row by row with `n` contiguous samples per
    /// row, i.e. element `(row, t)` lives at index `row * n + t`. The value
    /// rows occupy `z[..bottleneck * n]` and, in the gated modes, the gate rows
    /// occupy the following `bottleneck * n` elements. Because each plane is
    /// contiguous, the block is the same elementwise operation as a single
    /// sample, just over `bottleneck * n` elements.
    ///
    /// With `n == 0` nothing is read or written.
    ///
    /// # Panics
    ///
    /// Panics if `z` has fewer than `input_rows() * n` elements or `out` has
    /// fewer than `bottleneck * n` elements. The check happens before anything
    /// is written, so `out` is left untouched on failure.
    pub(super) fn process_block(&self, z: &[f32], out: &mut [f32], n: usize) {
        self.gate_elements(z, out, self.bottleneck * n);
    }

    /// Shared elementwise kernel behind `process_sample` and `process_block`.
    ///
    /// Writes `out[i]` for `i in 0..len` from the value `z[i]` and, in the
    /// gated modes, the gate input `z[len + i]`.
    fn gate_elements(&self, z: &[f32], out: &mut [f32], len: usize) {
        // Slice once up front: a too-short buffer fails here, before any
        // element is written, and the loops below need no per-element
        // bounds checks.
        let out = &mut out[..len];
        match self.mode {
            GatingMode::None => {
                for (o, &x) in out.iter_mut().zip(&z[..len]) {
                    *o = self.primary.apply(x);
                }
            }
            GatingMode::Gated => {
                let (values, gates) = z[..2 * len].split_at(len);
                for ((o, &x), &g) in out.iter_mut().zip(values).zip(gates) {
                    let v = self.primary.apply(x);
                    let s = self.secondary.apply(g);
                    *o = v * s;
                }
            }
            GatingMode::Blended => {
                let (values, gates) = z[..2 * len].split_at(len);
                for ((o, &x), &g) in out.iter_mut().zip(values).zip(gates) {
                    let alpha = self.secondary.apply(g);
                    let v = self.primary.apply(x);
                    // The `1 - alpha` share takes the raw, pre-activation value.
                    *o = alpha * v + (1.0 - alpha) * x;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference logistic sigmoid, computed independently of `Activation`, used
    /// to build expected values.
    fn sigmoid(x: f32) -> f32 {
        1.0_f32 / (1.0 + (-x).exp())
    }

    #[test]
    fn row_counts_follow_mode() {
        // (mode, bottleneck, expected input rows); output rows always equal
        // the bottleneck.
        let cases = [
            (GatingMode::None, 4_usize, 4_usize),
            (GatingMode::Gated, 4, 8),
            (GatingMode::Blended, 3, 6),
        ];
        for (mode, bottleneck, expected_inputs) in cases {
            let gating = Gating::new(mode, Activation::Tanh, Activation::Sigmoid, bottleneck);
            assert_eq!(gating.input_rows(), expected_inputs);
            assert_eq!(gating.output_rows(), bottleneck);
        }
    }

    #[test]
    fn process_sample_none_applies_primary_over_bn_rows() {
        let gating = Gating::new(GatingMode::None, Activation::Relu, Activation::Sigmoid, 2);
        let input = [2.0_f32, -3.0];
        let mut out = [0.0_f32; 2];
        gating.process_sample(&input, &mut out);
        assert_eq!(out, [2.0, 0.0]);
    }

    #[test]
    fn process_sample_gated_multiplies_primary_by_secondary() {
        let gating = Gating::new(GatingMode::Gated, Activation::Relu, Activation::Sigmoid, 1);
        // One channel: value 2.0, gate input 0.0.
        let input = [2.0_f32, 0.0];
        let mut out = [0.0_f32];
        gating.process_sample(&input, &mut out);
        assert_eq!(out, [1.0]);
    }

    #[test]
    fn process_sample_gated_tanh_sigmoid_reproduces_a1() {
        let gating = Gating::new(GatingMode::Gated, Activation::Tanh, Activation::Sigmoid, 2);
        // Values in the first half, gate inputs in the second half.
        let z = [0.5_f32, -0.5, 1.0, -1.0];
        let mut out = [0.0_f32; 2];
        gating.process_sample(&z, &mut out);

        let want0 = 0.5_f32.tanh() * sigmoid(1.0);
        let want1 = (-0.5_f32).tanh() * sigmoid(-1.0);
        assert!(
            (out[0] - want0).abs() < 1e-7,
            "out0={} want={}",
            out[0],
            want0
        );
        assert!(
            (out[1] - want1).abs() < 1e-7,
            "out1={} want={}",
            out[1],
            want1
        );
    }

    #[test]
    fn process_sample_blended_uses_raw_pre_activation_for_one_minus_alpha() {
        let gating = Gating::new(
            GatingMode::Blended,
            Activation::Relu,
            Activation::Sigmoid,
            1,
        );
        // Value -2.0 with gate input 0.0; the expected -1.0 can only come from
        // the raw value, since the activated share contributes nothing here.
        let input = [-2.0_f32, 0.0];
        let mut out = [0.0_f32];
        gating.process_sample(&input, &mut out);
        assert!((out[0] - (-1.0)).abs() < 1e-7, "out={}", out[0]);
    }

    #[test]
    fn process_sample_blended_two_channels_tanh_sigmoid() {
        let gating = Gating::new(
            GatingMode::Blended,
            Activation::Tanh,
            Activation::Sigmoid,
            2,
        );
        let z = [0.5_f32, -1.0, 0.0, 2.0];
        let mut out = [0.0_f32; 2];
        gating.process_sample(&z, &mut out);

        let a0 = sigmoid(0.0);
        let a1 = sigmoid(2.0);
        let want0 = a0 * 0.5_f32.tanh() + (1.0 - a0) * 0.5;
        let want1 = a1 * (-1.0_f32).tanh() - (1.0 - a1);
        assert!(
            (out[0] - want0).abs() < 1e-7,
            "out0={} want={}",
            out[0],
            want0
        );
        assert!(
            (out[1] - want1).abs() < 1e-7,
            "out1={} want={}",
            out[1],
            want1
        );
    }

    #[test]
    fn process_block_direct_gated_two_channels() {
        let gating = Gating::new(GatingMode::Gated, Activation::Relu, Activation::Sigmoid, 2);
        // Four rows of two samples each: two value rows, then two gate rows.
        let z = [1.0_f32, 2.0, 3.0, -1.0, 0.0, 0.0, 0.0, 0.0];
        let mut out = [0.0_f32; 4];
        gating.process_block(&z, &mut out, 2);
        assert_eq!(out, [0.5, 1.0, 1.5, 0.0]);
    }

    #[test]
    fn process_block_equals_process_sample_loop_all_modes() {
        /// Deterministic pseudo-varied input for row `r`, time step `t`.
        fn val(r: usize, t: usize) -> f32 {
            (((r * 7 + t * 13) % 23) as f32 - 11.0) * 0.13
        }

        for mode in [GatingMode::None, GatingMode::Gated, GatingMode::Blended] {
            let bn = 3usize;
            let n = 5usize;
            let g = Gating::new(mode, Activation::Tanh, Activation::Sigmoid, bn);
            let in_rows = g.input_rows();

            // Row-major input block: `in_rows` rows of `n` samples.
            let z: Vec<f32> = (0..in_rows)
                .flat_map(|r| (0..n).map(move |t| val(r, t)))
                .collect();

            // Reference result: gather each time step and run it through
            // `process_sample`, scattering back into the block layout.
            let mut want = vec![0.0_f32; bn * n];
            let mut column = vec![0.0_f32; in_rows];
            let mut sample_out = vec![0.0_f32; bn];
            for t in 0..n {
                for (r, slot) in column.iter_mut().enumerate() {
                    *slot = z[r * n + t];
                }
                g.process_sample(&column, &mut sample_out);
                for (c, &y) in sample_out.iter().enumerate() {
                    want[c * n + t] = y;
                }
            }

            let mut got = vec![0.0_f32; bn * n];
            g.process_block(&z, &mut got, n);

            for (i, (a, b)) in got.iter().zip(&want).enumerate() {
                assert!(
                    (a - b).abs() < 1e-6,
                    "mode={mode:?} idx{i}: block {a}, per-sample {b}"
                );
            }
        }
    }
}