use super::conv::Conv1d;
use crate::error::Error;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum Activation {
Tanh,
Relu,
Sigmoid,
}
impl Activation {
pub(super) fn from_name(name: &str) -> Result<Self, Error> {
match name {
"Tanh" => Ok(Self::Tanh),
"ReLU" => Ok(Self::Relu),
"Sigmoid" => Ok(Self::Sigmoid),
other => Err(Error::UnsupportedActivation(other.to_string())),
}
}
#[inline]
fn apply(self, x: f32) -> f32 {
match self {
Self::Tanh => x.tanh(),
Self::Relu => x.max(0.0),
Self::Sigmoid => sigmoid(x),
}
}
}
/// Logistic sigmoid `1 / (1 + e^(-x))`.
///
/// Saturates without producing NaN for finite inputs: a large negative `x`
/// overflows `exp` to infinity and yields `0.0`, and a large positive `x` yields `1.0`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    denominator.recip()
}
/// A residual layer with an optional sigmoid gate.
///
/// For each sample, `process_sample` computes `z = conv(input) + mixin(condition)`
/// and applies the activation to it. When gated, the activated first half of `z`
/// is multiplied by `sigmoid` of the second half. The result is added into the
/// caller's head accumulator, and `one_by_one(post) + input` is written to `out`.
#[derive(Debug, Clone)]
pub(super) struct Layer {
    /// Main convolution (uses `kernel`/`dilation`); its output lands in `block`.
    conv: Conv1d,
    /// Kernel-1 projection of the condition (no bias); its output lands in `mix`.
    mixin: Conv1d,
    /// Kernel-1 `channels -> channels` projection (with bias) that produces `out`.
    one_by_one: Conv1d,
    /// Nonlinearity applied to the first `channels` entries of `block`.
    activation: Activation,
    /// When true, `block` holds `2 * channels` values and its upper half is the gate.
    gated: bool,
    /// Number of residual channels.
    channels: usize,
    /// Scratch: `conv` output plus `mix` (length `mid`).
    block: Vec<f32>,
    /// Scratch: `mixin` output (length `mid`).
    mix: Vec<f32>,
    /// Scratch: activated (and, if gated, gated) values (length `channels`).
    post: Vec<f32>,
}

impl Layer {
    /// Builds a layer from its weight tensors.
    ///
    /// * `channels`: width of the residual path and of the head contribution.
    /// * `condition_size`: width of the conditioning vector fed to the mixin.
    /// * `kernel`, `dilation`: shape of the main convolution.
    /// * `activation`: nonlinearity for the (first half of the) pre-activation.
    /// * `gated`: when set, the main and mixin convolutions produce `2 * channels`
    ///   values. The first half goes through `activation` and the second half through
    ///   a sigmoid gate. Otherwise they produce `channels` values.
    /// * `conv_w`, `conv_b`: main convolution weights and bias.
    /// * `mix_w`: mixin weights (the mixin has no bias).
    /// * `one_w`, `one_b`: weights and bias of the `channels -> channels` output projection.
    ///
    /// Weight lengths are not checked here; any validation is left to `Conv1d::new`.
    // Each weight/bias tensor is passed as its own argument so the constructor mirrors
    // the tensor list one-to-one; grouping them would only move the arity elsewhere.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn new(
        channels: usize,
        condition_size: usize,
        kernel: usize,
        dilation: usize,
        activation: Activation,
        gated: bool,
        conv_w: Vec<f32>,
        conv_b: Vec<f32>,
        mix_w: Vec<f32>,
        one_w: Vec<f32>,
        one_b: Vec<f32>,
    ) -> Self {
        // Gated layers need a second `channels`-wide half for the gate.
        let mid = if gated { 2 * channels } else { channels };
        Self {
            conv: Conv1d::new(channels, mid, kernel, dilation, conv_w, Some(conv_b)),
            mixin: Conv1d::new(condition_size, mid, 1, 1, mix_w, None),
            one_by_one: Conv1d::new(channels, channels, 1, 1, one_w, Some(one_b)),
            activation,
            gated,
            channels,
            block: vec![0.0; mid],
            mix: vec![0.0; mid],
            post: vec![0.0; channels],
        }
    }

    /// Processes one time step.
    ///
    /// * `input`: the layer input for this step; it is also added to `out` as the residual.
    /// * `condition`: conditioning values fed through the mixin projection.
    /// * `head_accum`: the post-activation values are added into this slice.
    /// * `out`: receives `one_by_one(post) + input`.
    ///
    /// The element-wise sums use `zip`, so they stop at the shorter of the two slices.
    pub(super) fn process_sample(
        &mut self,
        input: &[f32],
        condition: &[f32],
        head_accum: &mut [f32],
        out: &mut [f32],
    ) {
        self.conv.process_sample(input, &mut self.block);
        self.mixin.process_sample(condition, &mut self.mix);
        for (b, m) in self.block.iter_mut().zip(&self.mix) {
            *b += *m;
        }

        let act = self.activation;
        if self.gated {
            // `block` is `[activation path | gate]`, each half `channels` wide. Splitting
            // once avoids the `c + channels` index arithmetic and per-element bounds checks.
            let (path, gate) = self.block.split_at(self.channels);
            for ((p, &a), &g) in self.post.iter_mut().zip(path).zip(gate) {
                *p = act.apply(a) * sigmoid(g);
            }
        } else {
            for (p, &b) in self.post.iter_mut().zip(&self.block) {
                *p = act.apply(b);
            }
        }

        for (h, p) in head_accum.iter_mut().zip(&self.post) {
            *h += *p;
        }
        self.one_by_one.process_sample(&self.post, out);
        // Residual connection: the layer's output is its input plus the projection.
        for (o, x) in out.iter_mut().zip(input) {
            *o += *x;
        }
    }

    /// Calls `reset` on all three convolutions (presumably clearing any input history
    /// they keep).
    ///
    /// The scratch buffers need no clearing: they are rewritten on every
    /// `process_sample` call before being read.
    pub(super) fn reset(&mut self) {
        self.conv.reset();
        self.mixin.reset();
        self.one_by_one.reset();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within `1e-6` of `expected`.
    ///
    /// Exact `f32` equality would tie these tests to the summation order used
    /// inside `Conv1d`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn relu_layer_residual_and_head_accumulate() {
        let mut layer = Layer::new(
            1,
            1,
            1,
            1,
            Activation::Relu,
            false,
            vec![2.0],
            vec![0.5],
            vec![1.0],
            vec![3.0],
            vec![0.1],
        );
        let mut head = vec![0.0];
        let mut out = vec![0.0];

        // z = 2*0.5 + 0.5 (conv) + 1*0.5 (mixin) = 2.0, relu(2.0) = 2.0.
        layer.process_sample(&[0.5], &[0.5], &mut head, &mut out);
        assert_close(out[0], 3.0 * 2.0 + 0.1 + 0.5);
        assert_close(head[0], 2.0);

        // z = 1.5 - 3.0 = -1.5, relu(-1.5) = 0.0, so the head is unchanged.
        layer.process_sample(&[0.5], &[-3.0], &mut head, &mut out);
        assert_close(out[0], 0.1 + 0.5);
        assert_close(head[0], 2.0);
    }

    #[test]
    fn gated_relu_layer_multiplies_activation_by_sigmoid_gate() {
        let mut layer = Layer::new(
            1,
            1,
            1,
            1,
            Activation::Relu,
            true,
            vec![2.0, 0.0],
            vec![0.0, 0.0],
            vec![0.0, 0.0],
            vec![1.0],
            vec![0.0],
        );
        let mut head = vec![0.0];
        let mut out = vec![0.0];

        // path = relu(2.0) = 2.0, gate = sigmoid(0.0) = 0.5, post = 1.0.
        layer.process_sample(&[1.0], &[0.0], &mut head, &mut out);
        assert_close(head[0], 1.0);
        assert_close(out[0], 1.0 + 1.0);
    }

    #[test]
    fn gated_tanh_layer_multiplies_tanh_path_by_sigmoid_gate() {
        let mut layer = Layer::new(
            1,
            1,
            1,
            1,
            Activation::Tanh,
            true,
            vec![2.0, 0.0],
            vec![0.0, 0.0],
            vec![0.0, 0.0],
            vec![1.0],
            vec![0.0],
        );
        let mut head = vec![0.0];
        let mut out = vec![0.0];

        layer.process_sample(&[1.0], &[0.0], &mut head, &mut out);
        let post = 2.0_f32.tanh() * 0.5;
        assert_close(head[0], post);
        assert_close(out[0], post + 1.0);
    }

    #[test]
    fn tanh_activation_matches_reference_value() {
        let mut layer = Layer::new(
            1,
            1,
            1,
            1,
            Activation::Tanh,
            false,
            vec![2.0],
            vec![0.5],
            vec![1.0],
            vec![3.0],
            vec![0.1],
        );
        let mut head = vec![0.0];
        let mut out = vec![0.0];

        layer.process_sample(&[0.5], &[0.5], &mut head, &mut out);
        let post = 2.0_f32.tanh();
        assert_close(head[0], post);
        assert_close(out[0], 3.0 * post + 0.6);
    }

    #[test]
    fn known_activation_names_are_parsed() {
        assert_eq!(Activation::from_name("Tanh").unwrap(), Activation::Tanh);
        assert_eq!(Activation::from_name("ReLU").unwrap(), Activation::Relu);
        assert_eq!(Activation::from_name("Sigmoid").unwrap(), Activation::Sigmoid);
        // Matching is case-sensitive.
        assert!(Activation::from_name("relu").is_err());
    }

    #[test]
    fn unknown_activation_is_rejected() {
        assert!(matches!(
            Activation::from_name("Swish"),
            Err(Error::UnsupportedActivation(ref name)) if name.as_str() == "Swish"
        ));
    }

    #[test]
    fn sigmoid_saturates_without_nan() {
        assert_eq!(sigmoid(0.0), 0.5);
        assert_eq!(sigmoid(-100.0), 0.0);
        assert_eq!(sigmoid(100.0), 1.0);
    }
}