use super::conv::{Conv1d, MAX_BLOCK};
use crate::error::Error;
/// Element-wise nonlinearity that a [`Layer`] applies to its pre-activation
/// values. Parsed from a textual name by [`Activation::from_name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum Activation {
/// Hyperbolic tangent (`f32::tanh`).
Tanh,
/// Rectifier: `max(x, 0.0)`. Spelled `"ReLU"` in the textual form.
Relu,
/// Logistic sigmoid, `1 / (1 + e^-x)`.
Sigmoid,
}
impl Activation {
pub(super) fn from_name(name: &str) -> Result<Self, Error> {
match name {
"Tanh" => Ok(Self::Tanh),
"ReLU" => Ok(Self::Relu),
"Sigmoid" => Ok(Self::Sigmoid),
other => Err(Error::UnsupportedActivation(other.to_string())),
}
}
#[inline]
fn apply(self, x: f32) -> f32 {
match self {
Self::Tanh => x.tanh(),
Self::Relu => x.max(0.0),
Self::Sigmoid => sigmoid(x),
}
}
}
/// Logistic sigmoid, `1 / (1 + e^-x)`.
///
/// Saturates cleanly: very negative inputs make the exponential overflow to
/// infinity and the result becomes `0.0`; very positive inputs make it
/// underflow and the result becomes `1.0`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    denominator.recip()
}
/// One layer of the network: a main convolution plus a condition mix-in,
/// an (optionally gated) activation, a contribution to the head accumulator,
/// and a final convolution whose output is added back onto the layer input.
///
/// Below, `mid` is the width chosen in `Layer::new`: `2 * channels` when
/// `gated` is set, otherwise `channels`.
#[derive(Debug, Clone)]
pub(super) struct Layer {
/// Main convolution over the layer input; built from
/// `(channels, mid, kernel, dilation)` with a bias vector.
conv: Conv1d,
/// Convolution applied to the conditioning input; built from
/// `(condition_size, mid, 1, 1)` without a bias.
mixin: Conv1d,
/// Convolution applied to the post-activation values; built from
/// `(channels, channels, 1, 1)` with a bias vector.
one_by_one: Conv1d,
/// Nonlinearity applied to the first `channels` pre-activation values.
activation: Activation,
/// When true, the second `channels` pre-activation values pass through a
/// sigmoid and multiply the activated first half.
gated: bool,
/// Number of values the layer consumes and produces per time step.
channels: usize,
/// Per-sample scratch for the `conv` output (length `mid`); the mix-in
/// result is added into it in place.
block: Vec<f32>,
/// Per-sample scratch for the `mixin` output (length `mid`).
mix: Vec<f32>,
/// Per-sample scratch for the post-activation values (length `channels`).
post: Vec<f32>,
/// Block-mode counterpart of `block` (length `mid * MAX_BLOCK`).
block_blk: Vec<f32>,
/// Block-mode counterpart of `mix` (length `mid * MAX_BLOCK`).
mix_blk: Vec<f32>,
/// Block-mode counterpart of `post` (length `channels * MAX_BLOCK`).
post_blk: Vec<f32>,
}
impl Layer {
    /// Builds a layer and pre-allocates all of its scratch buffers.
    ///
    /// The internal width `mid` is `2 * channels` when `gated` is true and
    /// `channels` otherwise. The weight/bias vectors are handed to
    /// `Conv1d::new` unchanged:
    ///
    /// * `conv_w` / `conv_b` — main convolution, `(channels, mid, kernel, dilation)`.
    /// * `mix_w` — condition mix-in, `(condition_size, mid, 1, 1)`, no bias.
    /// * `one_w` / `one_b` — output convolution, `(channels, channels, 1, 1)`.
    ///
    /// NOTE(review): the tuples above are presumably
    /// `(in_channels, out_channels, kernel, dilation)`; confirm against
    /// `Conv1d::new`, which also defines the expected weight layout.
    // One argument per tensor keeps call sites explicit, hence the long list.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn new(
        channels: usize,
        condition_size: usize,
        kernel: usize,
        dilation: usize,
        activation: Activation,
        gated: bool,
        conv_w: Vec<f32>,
        conv_b: Vec<f32>,
        mix_w: Vec<f32>,
        one_w: Vec<f32>,
        one_b: Vec<f32>,
    ) -> Self {
        // A gated layer needs a second set of rows to drive the sigmoid gate.
        let mid = if gated { 2 * channels } else { channels };

        let conv = Conv1d::new(channels, mid, kernel, dilation, conv_w, Some(conv_b));
        let mixin = Conv1d::new(condition_size, mid, 1, 1, mix_w, None);
        let one_by_one = Conv1d::new(channels, channels, 1, 1, one_w, Some(one_b));

        let zeros = |len: usize| vec![0.0_f32; len];

        Self {
            conv,
            mixin,
            one_by_one,
            activation,
            gated,
            channels,
            block: zeros(mid),
            mix: zeros(mid),
            post: zeros(channels),
            block_blk: zeros(mid * MAX_BLOCK),
            mix_blk: zeros(mid * MAX_BLOCK),
            post_blk: zeros(channels * MAX_BLOCK),
        }
    }

    /// Runs the layer for a single time step.
    ///
    /// * `input` — the layer input for this step; it is also added onto `out`
    ///   at the end as a residual connection.
    /// * `condition` — the conditioning values fed to the mix-in convolution.
    /// * `head_accum` — the post-activation values are **added** into this
    ///   slice (it is not overwritten).
    /// * `out` — written by the output convolution, then incremented by
    ///   `input`.
    ///
    /// The element-wise additions stop at the shorter of the two slices
    /// involved.
    pub(super) fn process_sample(
        &mut self,
        input: &[f32],
        condition: &[f32],
        head_accum: &mut [f32],
        out: &mut [f32],
    ) {
        self.conv.process_sample(input, &mut self.block);
        self.mixin.process_sample(condition, &mut self.mix);

        // Pre-activation = conv(input) + mixin(condition).
        self.block
            .iter_mut()
            .zip(self.mix.iter())
            .for_each(|(acc, &m)| *acc += m);

        let act = self.activation;
        if self.gated {
            // First `channels` entries carry the signal, the next `channels`
            // entries carry the gate.
            let (values, gates) = self.block.split_at(self.channels);
            for ((dst, &v), &g) in self.post.iter_mut().zip(values).zip(gates) {
                *dst = act.apply(v) * sigmoid(g);
            }
        } else {
            for (dst, &v) in self.post.iter_mut().zip(&self.block) {
                *dst = act.apply(v);
            }
        }

        head_accum
            .iter_mut()
            .zip(self.post.iter())
            .for_each(|(h, &p)| *h += p);

        self.one_by_one.process_sample(&self.post, out);

        // Residual connection.
        out.iter_mut()
            .zip(input.iter())
            .for_each(|(o, &x)| *o += x);
    }

    /// Runs the layer for `n` consecutive time steps at once.
    ///
    /// Buffers are laid out row by row: the value for row `c` at step `t`
    /// lives at index `c * n + t`. In gated mode the gate rows follow the
    /// `channels` signal rows inside the internal pre-activation buffer.
    ///
    /// `head_accum` is accumulated into and `out` receives the output
    /// convolution plus `input`, exactly as in [`Layer::process_sample`].
    ///
    /// # Panics
    ///
    /// Panics when `n` exceeds the scratch capacity allocated in
    /// [`Layer::new`] (`MAX_BLOCK` steps), because the scratch slices are
    /// taken with plain range indexing.
    pub(super) fn process_block(
        &mut self,
        input: &[f32],
        condition: &[f32],
        head_accum: &mut [f32],
        out: &mut [f32],
        n: usize,
    ) {
        let channels = self.channels;
        let act = self.activation;
        let mid = self.block.len();

        let block = &mut self.block_blk[..mid * n];
        let mix = &mut self.mix_blk[..mid * n];
        let post = &mut self.post_blk[..channels * n];

        self.conv.process_block(input, block, n);
        self.mixin.process_block(condition, mix, n);

        // Pre-activation = conv(input) + mixin(condition).
        block
            .iter_mut()
            .zip(mix.iter())
            .for_each(|(acc, &m)| *acc += m);

        if self.gated {
            // Rows 0..channels are the signal, rows channels..2*channels the
            // gate; both halves line up element for element with `post`.
            let (values, gates) = block.split_at(channels * n);
            for ((dst, &v), &g) in post.iter_mut().zip(values).zip(gates) {
                *dst = act.apply(v) * sigmoid(g);
            }
        } else {
            for (dst, &v) in post.iter_mut().zip(block.iter()) {
                *dst = act.apply(v);
            }
        }

        head_accum
            .iter_mut()
            .zip(post.iter())
            .for_each(|(h, &p)| *h += p);

        self.one_by_one.process_block(post, out, n);

        // Residual connection.
        out.iter_mut()
            .zip(input.iter())
            .for_each(|(o, &x)| *o += x);
    }

    /// Resets all three convolutions by calling `reset` on each.
    pub(super) fn reset(&mut self) {
        let Self {
            conv,
            mixin,
            one_by_one,
            ..
        } = self;
        conv.reset();
        mixin.reset();
        one_by_one.reset();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `got` is within `1e-6` of `want`.
    ///
    /// Layer outputs are sums of `f32` products, so the tests compare with a
    /// tolerance instead of relying on bit-exact rounding.
    #[track_caller]
    fn assert_close(got: f32, want: f32, what: &str) {
        assert!(
            (got - want).abs() < 1e-6,
            "{what}: got {got}, want {want}"
        );
    }

    /// Ungated ReLU layer: checks the residual output and that the head
    /// accumulator is added to rather than overwritten.
    #[test]
    fn relu_layer_residual_and_head_accumulate() {
        let mut layer = Layer::new(
            1,
            1,
            1,
            1,
            Activation::Relu,
            false,
            vec![2.0],
            vec![0.5],
            vec![1.0],
            vec![3.0],
            vec![0.1],
        );
        let mut head = vec![0.0_f32];
        let mut out = vec![0.0_f32];

        // Pre-activation is positive: post = 2.0, out = 3 * 2.0 + 0.1 + 0.5.
        layer.process_sample(&[0.5], &[0.5], &mut head, &mut out);
        assert_close(out[0], 6.6, "out after first sample");
        assert_close(head[0], 2.0, "head after first sample");

        // Pre-activation is negative: ReLU clamps post to 0, so the head
        // keeps its previous total and out = 0.1 + 0.5.
        layer.process_sample(&[0.5], &[-3.0], &mut head, &mut out);
        assert_close(out[0], 0.6, "out after second sample");
        assert_close(head[0], 2.0, "head after second sample");
    }

    /// Gated layer: the activated signal row is multiplied by the sigmoid of
    /// the gate row. A ReLU signal path keeps the expected values exact:
    /// relu(2.0) * sigmoid(0.0) = 2.0 * 0.5 = 1.0.
    #[test]
    fn gated_layer_multiplies_activation_by_sigmoid_gate() {
        let mut layer = Layer::new(
            1,
            1,
            1,
            1,
            Activation::Relu,
            true,
            vec![2.0, 0.0],
            vec![0.0, 0.0],
            vec![0.0, 0.0],
            vec![1.0],
            vec![0.0],
        );
        let mut head = vec![0.0_f32];
        let mut out = vec![0.0_f32];

        layer.process_sample(&[1.0], &[0.0], &mut head, &mut out);
        assert_close(head[0], 1.0, "head");
        // out = 1.0 * post + 0.0 + input.
        assert_close(out[0], 2.0, "out");
    }

    /// Ungated tanh layer compared against `f32::tanh` directly.
    #[test]
    fn tanh_activation_matches_reference_value() {
        let mut layer = Layer::new(
            1,
            1,
            1,
            1,
            Activation::Tanh,
            false,
            vec![2.0],
            vec![0.5],
            vec![1.0],
            vec![3.0],
            vec![0.1],
        );
        let mut head = vec![0.0_f32];
        let mut out = vec![0.0_f32];

        layer.process_sample(&[0.5], &[0.5], &mut head, &mut out);
        let post = 2.0_f32.tanh();
        assert_close(head[0], post, "head");
        assert_close(out[0], 3.0 * post + 0.6, "out");
    }

    /// Block processing must agree with feeding the same samples one at a
    /// time, across a block boundary and for both gated and ungated layers.
    #[test]
    fn process_block_equals_process_sample_loop() {
        for gated in [false, true] {
            let channels = 3usize;
            let cond_sz = 2usize;
            let kernel = 3usize;
            let dilation = 4usize;
            let mid = if gated { 2 * channels } else { channels };

            // Deterministic pseudo-random weights in roughly [-1, 1].
            let mk = |len: usize, salt: usize| -> Vec<f32> {
                (0..len)
                    .map(|i| (((i * 31 + salt * 7) % 29) as f32 - 14.0) * 0.07)
                    .collect()
            };
            let conv_w = mk(mid * channels * kernel, 1);
            let conv_b = mk(mid, 2);
            let mix_w = mk(mid * cond_sz, 3);
            let one_w = mk(channels * channels, 4);
            let one_b = mk(channels, 5);

            let total = 130usize;
            let inp: Vec<Vec<f32>> = (0..total)
                .map(|t| {
                    (0..channels)
                        .map(|c| ((t * 3 + c) as f32 * 0.21).sin())
                        .collect()
                })
                .collect();
            let cond: Vec<Vec<f32>> = (0..total)
                .map(|t| {
                    (0..cond_sz)
                        .map(|c| ((t * 5 + c) as f32 * 0.17).cos())
                        .collect()
                })
                .collect();
            // Non-zero starting head values prove the layer accumulates.
            let seed: Vec<Vec<f32>> = (0..total)
                .map(|t| (0..channels).map(|c| ((t + c) as f32) * 0.01).collect())
                .collect();

            let mk_layer = || {
                Layer::new(
                    channels,
                    cond_sz,
                    kernel,
                    dilation,
                    Activation::Tanh,
                    gated,
                    conv_w.clone(),
                    conv_b.clone(),
                    mix_w.clone(),
                    one_w.clone(),
                    one_b.clone(),
                )
            };

            // Reference: one sample at a time.
            let mut a = mk_layer();
            let mut out_ref = vec![vec![0.0; channels]; total];
            let mut head_ref = vec![vec![0.0; channels]; total];
            for t in 0..total {
                let mut head = seed[t].clone();
                let mut out = vec![0.0; channels];
                a.process_sample(&inp[t], &cond[t], &mut head, &mut out);
                out_ref[t] = out;
                head_ref[t] = head;
            }

            // Candidate: two uneven blocks, so state must carry over.
            let mut b = mk_layer();
            for (lo, len) in [(0usize, 70usize), (70, 60)] {
                let mut bin = vec![0.0; channels * len];
                let mut bcond = vec![0.0; cond_sz * len];
                let mut bhead = vec![0.0; channels * len];
                // Transpose into the row-major `[c * len + t]` block layout.
                for lt in 0..len {
                    for c in 0..channels {
                        bin[c * len + lt] = inp[lo + lt][c];
                        bhead[c * len + lt] = seed[lo + lt][c];
                    }
                    for c in 0..cond_sz {
                        bcond[c * len + lt] = cond[lo + lt][c];
                    }
                }
                let mut bout = vec![0.0; channels * len];
                b.process_block(&bin, &bcond, &mut bhead, &mut bout, len);

                for lt in 0..len {
                    for c in 0..channels {
                        let go = bout[c * len + lt];
                        let gh = bhead[c * len + lt];
                        assert!(
                            (go - out_ref[lo + lt][c]).abs() < 1e-5,
                            "gated={gated} t{} c{c} out: got {go}, want {}",
                            lo + lt,
                            out_ref[lo + lt][c]
                        );
                        assert!(
                            (gh - head_ref[lo + lt][c]).abs() < 1e-5,
                            "gated={gated} t{} c{c} head: got {gh}, want {}",
                            lo + lt,
                            head_ref[lo + lt][c]
                        );
                    }
                }
            }
        }
    }

    /// Every supported spelling parses; anything else is an error.
    #[test]
    fn unknown_activation_is_rejected() {
        assert!(Activation::from_name("Swish").is_err());
        assert_eq!(Activation::from_name("Tanh").unwrap(), Activation::Tanh);
        assert_eq!(Activation::from_name("ReLU").unwrap(), Activation::Relu);
        assert_eq!(
            Activation::from_name("Sigmoid").unwrap(),
            Activation::Sigmoid
        );
    }
}