use super::conv::{Conv1d, MAX_BLOCK};
use super::layer::Layer;
/// A stack of [`Layer`]s bracketed by two [`Conv1d`] "rechannel" stages,
/// together with the scratch buffers the processing methods reuse between
/// calls.
///
/// Data flow, as implemented by `process_sample` / `process_block`:
/// input -> `rechannel` -> each layer in order -> array output, while a head
/// accumulator (seeded from the caller's `head_in`) is handed mutably to every
/// layer and finally passed through `head_rechannel` to produce the head
/// output.
#[derive(Debug, Clone)]
pub(super) struct LayerArray {
/// Convolution applied to the raw input first; built in `new` from
/// `input_size` to `channels` with `None` as its last constructor argument.
rechannel: Conv1d,
/// Layers run in order; each reads the current buffer, writes the next one,
/// and receives mutable access to the head accumulator.
layers: Vec<Layer>,
/// Convolution applied to the head accumulator after all layers have run;
/// built in `new` from `channels` to `head_size`.
head_rechannel: Conv1d,
/// Channel count used to size every scratch buffer; returned by `channels()`.
channels: usize,
/// Output size given to `head_rechannel`; returned by `head_size()`.
head_size: usize,
/// Per-sample buffer holding the current layer input (`channels` long).
/// Swapped with `nxt` after every layer.
cur: Vec<f32>,
/// Per-sample buffer each layer writes its output into (`channels` long).
nxt: Vec<f32>,
/// Per-sample head accumulator (`channels` long), overwritten with `head_in`
/// at the start of every `process_sample` call.
head_accum: Vec<f32>,
/// Block counterpart of `cur`, allocated as `channels * MAX_BLOCK`.
cur_blk: Vec<f32>,
/// Block counterpart of `nxt`, allocated as `channels * MAX_BLOCK`.
nxt_blk: Vec<f32>,
/// Block counterpart of `head_accum`, allocated as `channels * MAX_BLOCK`.
head_accum_blk: Vec<f32>,
}
impl LayerArray {
    /// Assembles a layer array from its weights and an already-built layer stack.
    ///
    /// * `input_size` / `channels` - sizes handed to the input rechannel convolution.
    /// * `head_size` - output size handed to the head rechannel convolution.
    /// * `rechannel_w` - weights for the input rechannel, which is constructed
    ///   with `None` as its last argument.
    /// * `layers` - layers to run, in order.
    /// * `head_w`, `head_b` - weights and optional last argument (presumably a
    ///   bias, going by the name) for the head rechannel.
    ///
    /// All scratch buffers are allocated here: `channels` elements for the
    /// per-sample path and `channels * MAX_BLOCK` for the block path.
    pub(super) fn new(
        input_size: usize,
        channels: usize,
        head_size: usize,
        rechannel_w: Vec<f32>,
        layers: Vec<Layer>,
        head_w: Vec<f32>,
        head_b: Option<Vec<f32>>,
    ) -> Self {
        // Both convolutions are built with `1, 1` as their third and fourth
        // arguments (NOTE(review): presumably kernel size and dilation; confirm
        // against `Conv1d::new`).
        let rechannel = Conv1d::new(input_size, channels, 1, 1, rechannel_w, None);
        let head_rechannel = Conv1d::new(channels, head_size, 1, 1, head_w, head_b);

        let sample_scratch = || vec![0.0; channels];
        let block_scratch = || vec![0.0; channels * MAX_BLOCK];

        Self {
            rechannel,
            layers,
            head_rechannel,
            channels,
            head_size,
            cur: sample_scratch(),
            nxt: sample_scratch(),
            head_accum: sample_scratch(),
            cur_blk: block_scratch(),
            nxt_blk: block_scratch(),
            head_accum_blk: block_scratch(),
        }
    }

    /// Channel count this array was constructed with.
    pub(super) fn channels(&self) -> usize {
        self.channels
    }

    /// Head output size this array was constructed with.
    pub(super) fn head_size(&self) -> usize {
        self.head_size
    }

    /// Runs one sample through the array.
    ///
    /// `input` is rechannelled into the working buffer, `head_in` seeds the
    /// head accumulator, and every layer is then applied in order. Afterwards
    /// the head accumulator is pushed through the head rechannel into
    /// `head_out`, and the last layer's output is copied to `array_out`.
    ///
    /// # Panics
    ///
    /// Panics if `head_in` or `array_out` is not exactly `channels` long
    /// (`copy_from_slice` requires equal lengths).
    pub(super) fn process_sample(
        &mut self,
        input: &[f32],
        condition: &[f32],
        head_in: &[f32],
        head_out: &mut [f32],
        array_out: &mut [f32],
    ) {
        self.rechannel.process_sample(input, &mut self.cur);
        self.head_accum.copy_from_slice(head_in);

        for layer in self.layers.iter_mut() {
            layer.process_sample(&self.cur, condition, &mut self.head_accum, &mut self.nxt);
            // What this layer just wrote becomes the next layer's input.
            std::mem::swap(&mut self.cur, &mut self.nxt);
        }

        self.head_rechannel
            .process_sample(&self.head_accum, head_out);
        array_out.copy_from_slice(&self.cur);
    }

    /// Block counterpart of [`LayerArray::process_sample`], handling `n`
    /// samples per call.
    ///
    /// Only the first `channels * n` elements of each block scratch buffer
    /// are used. The unit test in this file feeds and reads the block buffers
    /// channel-major (`c * n + t`).
    ///
    /// # Panics
    ///
    /// Panics if `channels * n` exceeds the scratch capacity
    /// (`channels * MAX_BLOCK`), or if `head_in` / `array_out` is not exactly
    /// `channels * n` long.
    pub(super) fn process_block(
        &mut self,
        input: &[f32],
        condition: &[f32],
        head_in: &[f32],
        head_out: &mut [f32],
        array_out: &mut [f32],
        n: usize,
    ) {
        // Length of the active prefix of every block scratch buffer.
        let span = self.channels * n;

        self.rechannel
            .process_block(input, &mut self.cur_blk[..span], n);
        self.head_accum_blk[..span].copy_from_slice(head_in);

        for layer in self.layers.iter_mut() {
            layer.process_block(
                &self.cur_blk[..span],
                condition,
                &mut self.head_accum_blk[..span],
                &mut self.nxt_blk[..span],
                n,
            );
            // Swap whole vectors (not just the active prefix) so the freshly
            // written block becomes the next layer's input.
            std::mem::swap(&mut self.cur_blk, &mut self.nxt_blk);
        }

        self.head_rechannel
            .process_block(&self.head_accum_blk[..span], head_out, n);
        array_out.copy_from_slice(&self.cur_blk[..span]);
    }

    /// Resets both rechannel convolutions and then every layer, in order.
    ///
    /// The scratch buffers owned by this struct are left as they are; the
    /// processing methods overwrite them at the start of each call.
    pub(super) fn reset(&mut self) {
        self.rechannel.reset();
        self.head_rechannel.reset();
        self.layers.iter_mut().for_each(|layer| {
            layer.reset();
        });
    }
}
#[cfg(test)]
mod tests {
    use super::super::layer::{Activation, Layer};
    use super::*;

    /// Builds a ReLU layer whose conditioning size, kernel size and dilation
    /// are all 1 (the same argument positions the Tanh layers below fill with
    /// their own sizes).
    fn relu_layer(
        channels: usize,
        conv_w: Vec<f32>,
        conv_b: Vec<f32>,
        mix_w: Vec<f32>,
        one_w: Vec<f32>,
        one_b: Vec<f32>,
    ) -> Layer {
        let (cond_size, kernel, dilation) = (1, 1, 1);
        Layer::new(
            channels,
            cond_size,
            kernel,
            dilation,
            Activation::Relu,
            false,
            conv_w,
            conv_b,
            mix_w,
            one_w,
            one_b,
        )
    }

    /// Pushes a single sample through `array` using one-element output
    /// buffers and returns `(head_out, array_out)`.
    fn run_one_sample(
        array: &mut LayerArray,
        input: &[f32],
        condition: &[f32],
        head_in: &[f32],
    ) -> (Vec<f32>, Vec<f32>) {
        let mut head_out = vec![0.0];
        let mut array_out = vec![0.0];
        array.process_sample(input, condition, head_in, &mut head_out, &mut array_out);
        (head_out, array_out)
    }

    /// The block path must reproduce the per-sample path (within 1e-5) when
    /// the same 120-sample signal is fed in blocks of 33, 1 and 86 samples.
    #[test]
    fn process_block_equals_process_sample_loop() {
        const INPUT_SIZE: usize = 1;
        const CHANNELS: usize = 3;
        const HEAD_SIZE: usize = 2;
        const COND_SIZE: usize = 1;
        const KERNEL: usize = 3;
        const TOTAL: usize = 120;

        /// Deterministic pseudo-weights in roughly [-0.54, 0.54].
        fn weights(len: usize, salt: usize) -> Vec<f32> {
            (0..len)
                .map(|i| (((i * 13 + salt * 5) % 19) as f32 - 9.0) * 0.06)
                .collect()
        }

        fn tanh_layer(dilation: usize, salt: usize) -> Layer {
            Layer::new(
                CHANNELS,
                COND_SIZE,
                KERNEL,
                dilation,
                Activation::Tanh,
                false,
                weights(CHANNELS * CHANNELS * KERNEL, salt),
                weights(CHANNELS, salt + 1),
                weights(CHANNELS * COND_SIZE, salt + 2),
                weights(CHANNELS * CHANNELS, salt + 3),
                weights(CHANNELS, salt + 4),
            )
        }

        fn build_array() -> LayerArray {
            LayerArray::new(
                INPUT_SIZE,
                CHANNELS,
                HEAD_SIZE,
                weights(CHANNELS * INPUT_SIZE, 20),
                vec![tanh_layer(1, 30), tanh_layer(2, 40), tanh_layer(4, 50)],
                weights(HEAD_SIZE * CHANNELS, 60),
                Some(weights(HEAD_SIZE, 70)),
            )
        }

        /// One `width`-long frame per time step, filled from `value(t, c)`.
        fn frames(
            count: usize,
            width: usize,
            value: impl Fn(usize, usize) -> f32,
        ) -> Vec<Vec<f32>> {
            (0..count)
                .map(|t| (0..width).map(|c| value(t, c)).collect())
                .collect()
        }

        /// Packs per-sample frames channel-major: element `c * len + t`.
        fn planar(frames: &[Vec<f32>], width: usize) -> Vec<f32> {
            (0..width)
                .flat_map(move |c| frames.iter().map(move |frame| frame[c]))
                .collect()
        }

        let inp = frames(TOTAL, INPUT_SIZE, |t, c| ((t + c) as f32 * 0.23).sin());
        let cond = frames(TOTAL, COND_SIZE, |t, c| ((t + c) as f32 * 0.19).cos());
        let head_in = frames(TOTAL, CHANNELS, |t, c| ((t * 2 + c) as f32) * 0.013);

        // Reference run: one sample at a time.
        let mut per_sample = build_array();
        let mut head_ref: Vec<Vec<f32>> = Vec::with_capacity(TOTAL);
        let mut out_ref: Vec<Vec<f32>> = Vec::with_capacity(TOTAL);
        for ((x, c), h) in inp.iter().zip(&cond).zip(&head_in) {
            let mut ho = vec![0.0; HEAD_SIZE];
            let mut ao = vec![0.0; CHANNELS];
            per_sample.process_sample(x, c, h, &mut ho, &mut ao);
            head_ref.push(ho);
            out_ref.push(ao);
        }

        // Block run on a fresh, identically-weighted array.
        let mut blockwise = build_array();
        let mut lo = 0usize;
        for &len in &[33usize, 1, 86] {
            let hi = lo + len;
            let bin = planar(&inp[lo..hi], INPUT_SIZE);
            let bcond = planar(&cond[lo..hi], COND_SIZE);
            let bhead = planar(&head_in[lo..hi], CHANNELS);
            let mut bho = vec![0.0; HEAD_SIZE * len];
            let mut bao = vec![0.0; CHANNELS * len];
            blockwise.process_block(&bin, &bcond, &bhead, &mut bho, &mut bao, len);

            let expected = head_ref[lo..hi].iter().zip(&out_ref[lo..hi]);
            for (lt, (want_head, want_out)) in expected.enumerate() {
                for (c, &want) in want_head.iter().enumerate() {
                    let g = bho[c * len + lt];
                    assert!(
                        (g - want).abs() < 1e-5,
                        "t{} head c{c}: got {g}, want {}",
                        lo + lt,
                        want
                    );
                }
                for (c, &want) in want_out.iter().enumerate() {
                    let g = bao[c * len + lt];
                    assert!(
                        (g - want).abs() < 1e-5,
                        "t{} out c{c}: got {g}, want {}",
                        lo + lt,
                        want
                    );
                }
            }
            lo = hi;
        }
    }

    /// A one-layer, one-channel array yields the pinned array and head outputs.
    #[test]
    fn single_layer_array_rechannels_and_projects_head() {
        let layer = relu_layer(1, vec![2.0], vec![0.5], vec![1.0], vec![3.0], vec![0.1]);
        let mut array = LayerArray::new(1, 1, 1, vec![1.0], vec![layer], vec![0.5], None);
        let (head_out, array_out) = run_one_sample(&mut array, &[0.5], &[0.5], &[0.0]);
        assert_eq!(array_out, vec![6.6]);
        assert_eq!(head_out, vec![1.0]);
    }

    /// Two stacked layers both contribute to the head output.
    #[test]
    fn head_accumulates_across_stacked_layers() {
        let unit_layer = || relu_layer(1, vec![1.0], vec![0.0], vec![0.0], vec![1.0], vec![0.0]);
        let stack = vec![unit_layer(), unit_layer()];
        let mut array = LayerArray::new(1, 1, 1, vec![1.0], stack, vec![1.0], None);
        let (head_out, array_out) = run_one_sample(&mut array, &[2.0], &[0.0], &[0.0]);
        assert_eq!(array_out, vec![8.0]);
        assert_eq!(head_out, vec![6.0]);
    }

    /// The caller-supplied head input reaches the head output, and the
    /// optional head argument passed as `Some(vec![1.0])` takes effect.
    #[test]
    fn incoming_head_is_carried_and_head_bias_applies() {
        let layer = relu_layer(1, vec![1.0], vec![0.0], vec![0.0], vec![1.0], vec![0.0]);
        let mut array =
            LayerArray::new(1, 1, 1, vec![1.0], vec![layer], vec![2.0], Some(vec![1.0]));
        let (head_out, array_out) = run_one_sample(&mut array, &[-5.0], &[0.0], &[10.0]);
        assert_eq!(head_out, vec![21.0]);
        assert_eq!(array_out, vec![-5.0]);
    }

    /// The accessors report the sizes passed to the constructor.
    #[test]
    fn channels_and_head_size_reported() {
        let layer = relu_layer(1, vec![1.0], vec![0.0], vec![0.0], vec![1.0], vec![0.0]);
        let array = LayerArray::new(1, 1, 2, vec![1.0], vec![layer], vec![1.0, 1.0], None);
        assert_eq!(array.channels(), 1);
        assert_eq!(array.head_size(), 2);
    }
}