use super::conv::Conv1d;
use super::layer::Layer;
/// A stack of [`Layer`]s bracketed by two [`Conv1d`] projections, processed one
/// sample at a time via `process_sample`.
///
/// The struct owns its scratch buffers so that per-sample processing reuses the
/// vectors allocated in `new` instead of creating new ones in this type's code.
#[derive(Debug, Clone)]
pub(super) struct LayerArray {
/// Bias-free convolution constructed with `input_size` and `channels`; applied
/// to the incoming sample first, writing into `cur`.
rechannel: Conv1d,
/// Layers run in order; each one reads `cur` and writes `nxt`.
layers: Vec<Layer>,
/// Convolution constructed with `channels` and `head_size`; applied to
/// `head_accum` to produce the head output.
head_rechannel: Conv1d,
/// Value returned by `channels()`; also the length given to every scratch buffer.
channels: usize,
/// Value returned by `head_size()`.
head_size: usize,
/// Scratch: input to the next layer (and, after the last layer, the array output).
cur: Vec<f32>,
/// Scratch: output of the layer currently running; swapped with `cur` afterwards.
nxt: Vec<f32>,
/// Scratch: seeded from the caller's `head_in` each sample and handed mutably
/// to every layer before being projected by `head_rechannel`.
head_accum: Vec<f32>,
}
impl LayerArray {
    /// Builds a layer array from already-loaded weights.
    ///
    /// * `rechannel_w` feeds a bias-free [`Conv1d`] constructed with `input_size`
    ///   and `channels`.
    /// * `head_w` / `head_b` feed a second [`Conv1d`] constructed with `channels`
    ///   and `head_size`.
    /// * `layers` are stored as given and run in order by [`Self::process_sample`].
    ///
    /// The three scratch buffers are allocated here, each with `channels`
    /// zero-initialised elements.
    pub(super) fn new(
        input_size: usize,
        channels: usize,
        head_size: usize,
        rechannel_w: Vec<f32>,
        layers: Vec<Layer>,
        head_w: Vec<f32>,
        head_b: Option<Vec<f32>>,
    ) -> Self {
        Self {
            rechannel: Conv1d::new(input_size, channels, 1, 1, rechannel_w, None),
            layers,
            head_rechannel: Conv1d::new(channels, head_size, 1, 1, head_w, head_b),
            channels,
            head_size,
            cur: vec![0.0; channels],
            nxt: vec![0.0; channels],
            head_accum: vec![0.0; channels],
        }
    }

    /// Returns the `channels` value this array was constructed with.
    pub(super) fn channels(&self) -> usize {
        self.channels
    }

    /// Returns the `head_size` value this array was constructed with.
    pub(super) fn head_size(&self) -> usize {
        self.head_size
    }

    /// Pushes one sample through the array.
    ///
    /// Steps, in order:
    /// 1. `input` is passed through the input rechannel convolution into the
    ///    current-activation buffer.
    /// 2. The head accumulator is overwritten with `head_in`.
    /// 3. Every layer runs in sequence, receiving the current activation,
    ///    `condition`, the head accumulator (mutably) and an output buffer; the
    ///    output buffer then becomes the current activation for the next layer.
    /// 4. The head accumulator is passed through the head rechannel convolution
    ///    into `head_out`.
    /// 5. The final activation is copied into `array_out`.
    ///
    /// # Panics
    ///
    /// Panics if `head_in` does not have the same length as the internal head
    /// accumulator (allocated with `channels` elements), or if `array_out` does
    /// not have the same length as the final activation buffer; both copies use
    /// `copy_from_slice`, which requires equal lengths.
    pub(super) fn process_sample(
        &mut self,
        input: &[f32],
        condition: &[f32],
        head_in: &[f32],
        head_out: &mut [f32],
        array_out: &mut [f32],
    ) {
        self.rechannel.process_sample(input, &mut self.cur);
        self.head_accum.copy_from_slice(head_in);

        // Borrowing `layers` mutably while touching `cur`, `nxt` and `head_accum`
        // is fine: they are disjoint fields of `self`.
        for layer in &mut self.layers {
            layer.process_sample(&self.cur, condition, &mut self.head_accum, &mut self.nxt);
            // Ping-pong the buffers so this layer's output feeds the next one
            // without copying.
            std::mem::swap(&mut self.cur, &mut self.nxt);
        }

        self.head_rechannel
            .process_sample(&self.head_accum, head_out);
        array_out.copy_from_slice(&self.cur);
    }

    /// Resets both rechannel convolutions and every layer.
    ///
    /// The scratch buffers are left untouched by this method.
    pub(super) fn reset(&mut self) {
        self.rechannel.reset();
        self.head_rechannel.reset();
        for layer in &mut self.layers {
            layer.reset();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::layer::{Activation, Layer};
    use super::*;

    /// Absolute tolerance for comparing `f32` results. Expected values such as
    /// `6.6` are not exactly representable, so bit-exact equality would make the
    /// tests depend on the operation order inside the layers.
    const TOLERANCE: f32 = 1e-5;

    /// Asserts that two slices have the same length and agree element-wise
    /// within [`TOLERANCE`].
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: got {:?}, expected {:?}",
            actual,
            expected
        );
        for (index, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() <= TOLERANCE,
                "element {}: got {}, expected {}",
                index,
                got,
                want
            );
        }
    }

    /// Builds a ReLU [`Layer`] with the given weights; the remaining
    /// `Layer::new` arguments are fixed to the values used by every test here.
    fn relu_layer(
        channels: usize,
        conv_w: Vec<f32>,
        conv_b: Vec<f32>,
        mix_w: Vec<f32>,
        one_w: Vec<f32>,
        one_b: Vec<f32>,
    ) -> Layer {
        Layer::new(
            channels,
            1,
            1,
            1,
            Activation::Relu,
            false,
            conv_w,
            conv_b,
            mix_w,
            one_w,
            one_b,
        )
    }

    #[test]
    fn single_layer_array_rechannels_and_projects_head() {
        let layer = relu_layer(1, vec![2.0], vec![0.5], vec![1.0], vec![3.0], vec![0.1]);
        let mut array = LayerArray::new(1, 1, 1, vec![1.0], vec![layer], vec![0.5], None);
        let mut head_out = vec![0.0];
        let mut array_out = vec![0.0];

        array.process_sample(&[0.5], &[0.5], &[0.0], &mut head_out, &mut array_out);

        assert_close(&array_out, &[6.6]);
        assert_close(&head_out, &[1.0]);
    }

    #[test]
    fn head_accumulates_across_stacked_layers() {
        let l0 = relu_layer(1, vec![1.0], vec![0.0], vec![0.0], vec![1.0], vec![0.0]);
        let l1 = relu_layer(1, vec![1.0], vec![0.0], vec![0.0], vec![1.0], vec![0.0]);
        let mut array = LayerArray::new(1, 1, 1, vec![1.0], vec![l0, l1], vec![1.0], None);
        let mut head_out = vec![0.0];
        let mut array_out = vec![0.0];

        array.process_sample(&[2.0], &[0.0], &[0.0], &mut head_out, &mut array_out);

        assert_close(&array_out, &[8.0]);
        assert_close(&head_out, &[6.0]);
    }

    #[test]
    fn incoming_head_is_carried_and_head_bias_applies() {
        let layer = relu_layer(1, vec![1.0], vec![0.0], vec![0.0], vec![1.0], vec![0.0]);
        let mut array =
            LayerArray::new(1, 1, 1, vec![1.0], vec![layer], vec![2.0], Some(vec![1.0]));
        let mut head_out = vec![0.0];
        let mut array_out = vec![0.0];

        // A non-zero `head_in` must survive into the head projection.
        array.process_sample(&[-5.0], &[0.0], &[10.0], &mut head_out, &mut array_out);

        assert_close(&head_out, &[21.0]);
        assert_close(&array_out, &[-5.0]);
    }

    #[test]
    fn channels_and_head_size_reported() {
        let layer = relu_layer(1, vec![1.0], vec![0.0], vec![0.0], vec![1.0], vec![0.0]);
        let array = LayerArray::new(1, 1, 2, vec![1.0], vec![layer], vec![1.0, 1.0], None);

        assert_eq!(array.channels(), 1);
        assert_eq!(array.head_size(), 2);
    }
}