use super::activations::sigmoid_approximated;
use super::vector_math::VectorMath;
use super::weights::WEIGHTS_SCALE;
/// Upper bound on the number of GRU units; sizes the fixed-length state and
/// scratch buffers so no per-step heap allocation is needed.
pub(crate) const GRU_LAYER_MAX_UNITS: usize = 24;
/// A GRU has three gates: update (0), reset (1), and state/candidate (2).
const NUM_GRU_GATES: usize = 3;
/// A gated recurrent unit (GRU) layer whose quantized `i8` tensors have been
/// dequantized into gate-major `f32` form at construction time.
#[derive(Debug)]
pub(crate) struct GatedRecurrentLayer {
    /// Number of input features consumed per time step.
    input_size: usize,
    /// Number of GRU units; must not exceed `GRU_LAYER_MAX_UNITS`.
    output_size: usize,
    /// Preprocessed biases, one contiguous run of `output_size` per gate.
    bias: Vec<f32>,
    /// Preprocessed input weights, gate-major: `NUM_GRU_GATES` blocks of
    /// `output_size * input_size` values.
    weights: Vec<f32>,
    /// Preprocessed recurrent weights, gate-major: `NUM_GRU_GATES` blocks of
    /// `output_size * output_size` values.
    recurrent_weights: Vec<f32>,
    /// Backend used for the dot products (SIMD or scalar).
    vector_math: VectorMath,
    /// Hidden state; only the first `output_size` entries are meaningful.
    state: [f32; GRU_LAYER_MAX_UNITS],
}
impl GatedRecurrentLayer {
    /// Builds a GRU layer from quantized `i8` tensors.
    ///
    /// The raw tensors store the three gates interleaved per input row; they
    /// are dequantized and transposed by [`preprocess_gru_tensor`] so that each
    /// gate's weight rows become contiguous, letting the per-gate computations
    /// use plain sub-slices.
    ///
    /// Debug builds assert the tensor lengths match `input_size` /
    /// `output_size` and that `output_size` fits the fixed state buffer.
    pub(crate) fn new(
        input_size: usize,
        output_size: usize,
        bias: &[i8],
        weights: &[i8],
        recurrent_weights: &[i8],
        vector_math: VectorMath,
    ) -> Self {
        debug_assert!(output_size <= GRU_LAYER_MAX_UNITS);
        debug_assert_eq!(bias.len(), NUM_GRU_GATES * output_size);
        debug_assert_eq!(weights.len(), NUM_GRU_GATES * input_size * output_size);
        debug_assert_eq!(
            recurrent_weights.len(),
            NUM_GRU_GATES * output_size * output_size
        );
        // The state array starts zeroed, so there is no need for a mutable
        // binding plus an extra `reset()` call (the original did both).
        Self {
            input_size,
            output_size,
            bias: preprocess_gru_tensor(bias, output_size),
            weights: preprocess_gru_tensor(weights, output_size),
            recurrent_weights: preprocess_gru_tensor(recurrent_weights, output_size),
            vector_math,
            state: [0.0; GRU_LAYER_MAX_UNITS],
        }
    }
    /// Number of input features expected by [`Self::compute_output`].
    pub(crate) fn input_size(&self) -> usize {
        self.input_size
    }
    /// Current hidden state: the output of the most recent
    /// [`Self::compute_output`] call (all zeros after `new()` or `reset()`).
    pub(crate) fn output(&self) -> &[f32] {
        &self.state[..self.output_size]
    }
    /// Number of output units (length of the slice returned by
    /// [`Self::output`]).
    pub(crate) fn size(&self) -> usize {
        self.output_size
    }
    /// Clears the hidden state so the layer can process a new sequence.
    pub(crate) fn reset(&mut self) {
        self.state.fill(0.0);
    }
    /// Advances the recurrent state by one time step.
    ///
    /// Standard GRU recurrence with a ReLU candidate activation:
    /// ```text
    /// u = sigmoid(W_u x + U_u h + b_u)      // update gate (index 0)
    /// r = sigmoid(W_r x + U_r h + b_r)      // reset gate  (index 1)
    /// c = relu(W_c x + U_c (r ⊙ h) + b_c)   // candidate   (index 2)
    /// h' = u ⊙ h + (1 - u) ⊙ c
    /// ```
    pub(crate) fn compute_output(&mut self, input: &[f32]) {
        debug_assert_eq!(input.len(), self.input_size);
        // Per-gate strides into the gate-major preprocessed tensors.
        let stride_weights = self.input_size * self.output_size;
        let stride_recurrent = self.output_size * self.output_size;
        let mut update = [0.0_f32; GRU_LAYER_MAX_UNITS];
        self.compute_update_reset_gate(
            input,
            0, // gate 0: update
            stride_weights,
            stride_recurrent,
            &mut update,
        );
        let mut reset = [0.0_f32; GRU_LAYER_MAX_UNITS];
        self.compute_update_reset_gate(
            input,
            1, // gate 1: reset
            stride_weights,
            stride_recurrent,
            &mut reset,
        );
        self.compute_state_gate(input, &update, &reset, stride_weights, stride_recurrent);
    }
    /// Computes one sigmoid gate (update or reset, selected by `gate_index`)
    /// into `gate`; only the first `output_size` entries are written.
    fn compute_update_reset_gate(
        &self,
        input: &[f32],
        gate_index: usize,
        stride_weights: usize,
        stride_recurrent: usize,
        gate: &mut [f32; GRU_LAYER_MAX_UNITS],
    ) {
        let bias_offset = gate_index * self.output_size;
        let w_offset = gate_index * stride_weights;
        let r_offset = gate_index * stride_recurrent;
        let state = &self.state[..self.output_size];
        for (o, gate_val) in gate.iter_mut().enumerate().take(self.output_size) {
            // gate[o] = sigmoid(b[o] + W[o]·input + U[o]·state)
            let mut x = self.bias[bias_offset + o];
            x += self.vector_math.dot_product(
                input,
                &self.weights[w_offset + o * self.input_size..w_offset + (o + 1) * self.input_size],
            );
            x += self.vector_math.dot_product(
                state,
                &self.recurrent_weights
                    [r_offset + o * self.output_size..r_offset + (o + 1) * self.output_size],
            );
            *gate_val = sigmoid_approximated(x);
        }
    }
    /// Computes the candidate activation (gate index 2, ReLU) and blends it
    /// with the previous state via the update gate, writing the new state in
    /// place.
    fn compute_state_gate(
        &mut self,
        input: &[f32],
        update: &[f32; GRU_LAYER_MAX_UNITS],
        reset: &[f32; GRU_LAYER_MAX_UNITS],
        stride_weights: usize,
        stride_recurrent: usize,
    ) {
        let bias_offset = 2 * self.output_size;
        let w_offset = 2 * stride_weights;
        let r_offset = 2 * stride_recurrent;
        // Element-wise product r ⊙ h feeding the recurrent dot products.
        let mut reset_x_state = [0.0_f32; GRU_LAYER_MAX_UNITS];
        for o in 0..self.output_size {
            reset_x_state[o] = self.state[o] * reset[o];
        }
        for (o, &u) in update.iter().enumerate().take(self.output_size) {
            let mut x = self.bias[bias_offset + o];
            x += self.vector_math.dot_product(
                input,
                &self.weights[w_offset + o * self.input_size..w_offset + (o + 1) * self.input_size],
            );
            x += self.vector_math.dot_product(
                &reset_x_state[..self.output_size],
                &self.recurrent_weights
                    [r_offset + o * self.output_size..r_offset + (o + 1) * self.output_size],
            );
            // h'[o] = u·h[o] + (1 - u)·relu(x); safe to overwrite state[o]
            // here because reset_x_state already holds the old value.
            self.state[o] = u * self.state[o] + (1.0 - u) * x.max(0.0);
        }
    }
}
/// Dequantizes a gate-interleaved `i8` GRU tensor into a gate-major `f32`
/// tensor.
///
/// Source layout is `[n][NUM_GRU_GATES][output_size]` (the three gates
/// interleaved per input row); the result is `[NUM_GRU_GATES][output_size][n]`
/// so that each gate's weight rows are contiguous. Every element is scaled by
/// `WEIGHTS_SCALE`. The destination is filled strictly in order, so a single
/// pre-sized `Vec` with sequential pushes suffices.
fn preprocess_gru_tensor(tensor_src: &[i8], output_size: usize) -> Vec<f32> {
    let n = tensor_src.len() / (output_size * NUM_GRU_GATES);
    debug_assert_eq!(tensor_src.len(), n * output_size * NUM_GRU_GATES);
    let stride_src = NUM_GRU_GATES * output_size;
    let mut tensor_dst = Vec::with_capacity(tensor_src.len());
    for gate in 0..NUM_GRU_GATES {
        for unit in 0..output_size {
            for row in 0..n {
                let quantized = tensor_src[row * stride_src + gate * output_size + unit];
                tensor_dst.push(WEIGHTS_SCALE * f32::from(quantized));
            }
        }
    }
    tensor_dst
}
#[cfg(test)]
mod tests {
    use super::*;
    use sonora_simd::detect_backend;
    // Fixture dimensions: 5 input features, 4 GRU units.
    const GRU_INPUT_SIZE: usize = 5;
    const GRU_OUTPUT_SIZE: usize = 4;
    // Quantized fixture tensors in the raw gate-interleaved layout that
    // `GatedRecurrentLayer::new` expects (3 gates * 4 units = 12 biases, etc.).
    const GRU_BIAS: [i8; 12] = [96, -99, -81, -114, 49, 119, -118, 68, -76, 91, 121, 125];
    const GRU_WEIGHTS: [i8; 60] = [
        124, 9, 1, 116, -66, -21, -118, -110, 104, 75, -23, -51, -72, -111, 47, 93, 77, -98, 41, -8, 40, -23, -43, -107, 9, -73, 30, -32, -2, 64, -26, 91, -48, -24, -28, -104, 74, -46, 116, 15, 32, 52, -126, -38, -121, 12, -16, 110, -95, 66, -103, -35, -38, 3, -126, -61, 28, 98, -117, -43, ];
    const GRU_RECURRENT_WEIGHTS: [i8; 48] = [
        -3, 87, 50, 51, -22, 27, -39, 62, 31, -83, -52, -48, -6, 83, -19, 104, 105, 48, 23, 68, 23, 40, 7, -120, 64, -62, 117, 85, 51, -43, 54, -105, 120, 56, -128, -107, 39, 50, -17, -47, -117, 14, 108, 12, -7, -72, 103, -87, ];
    // A 4-step input sequence (4 steps * 5 features) and the matching expected
    // 4-step output sequence (4 steps * 4 units), used as golden values.
    const GRU_INPUT_SEQUENCE: [f32; 20] = [
        0.89395463, 0.93224651, 0.55788344, 0.32341808, 0.93355054, 0.13475326, 0.97370994,
        0.14253306, 0.93710381, 0.76093364, 0.65780413, 0.41657975, 0.49403164, 0.46843281,
        0.75138855, 0.24517593, 0.47657707, 0.57064998, 0.435184, 0.19319285,
    ];
    const GRU_EXPECTED_OUTPUT_SEQUENCE: [f32; 16] = [
        0.0239123, 0.5773077, 0.0, 0.0, 0.01282811, 0.64330572, 0.0, 0.04863098, 0.00781069,
        0.75267816, 0.0, 0.02579715, 0.00471378, 0.59162533, 0.11087593, 0.01334511,
    ];
    // Drives `gru` through the fixture sequence step by step and compares each
    // step's output against the golden values within a small tolerance.
    fn test_gated_recurrent_layer(mut gru: GatedRecurrentLayer) {
        let input_sequence_length = GRU_INPUT_SEQUENCE.len() / gru.input_size();
        let output_sequence_length = GRU_EXPECTED_OUTPUT_SEQUENCE.len() / gru.size();
        assert_eq!(input_sequence_length, output_sequence_length);
        gru.reset();
        for i in 0..input_sequence_length {
            let input_start = i * gru.input_size();
            let input_end = input_start + gru.input_size();
            gru.compute_output(&GRU_INPUT_SEQUENCE[input_start..input_end]);
            let output_start = i * gru.size();
            let expected = &GRU_EXPECTED_OUTPUT_SEQUENCE[output_start..output_start + gru.size()];
            let actual = gru.output();
            for (j, (&exp, &act)) in expected.iter().zip(actual.iter()).enumerate() {
                assert!(
                    (exp - act).abs() < 3e-6,
                    "step {i}, output[{j}]: expected {exp}, got {act}"
                );
            }
        }
    }
    // Runs the golden-value check with the best SIMD backend detected at
    // runtime.
    #[test]
    fn gated_recurrent_layer_output() {
        let vector_math = VectorMath::new(detect_backend());
        let gru = GatedRecurrentLayer::new(
            GRU_INPUT_SIZE,
            GRU_OUTPUT_SIZE,
            &GRU_BIAS,
            &GRU_WEIGHTS,
            &GRU_RECURRENT_WEIGHTS,
            vector_math,
        );
        test_gated_recurrent_layer(gru);
    }
    // Runs the same golden-value check with the scalar fallback backend, so
    // SIMD and scalar paths are both pinned to the same expected outputs.
    #[test]
    fn gated_recurrent_layer_scalar() {
        let vector_math = VectorMath::new(sonora_simd::SimdBackend::Scalar);
        let gru = GatedRecurrentLayer::new(
            GRU_INPUT_SIZE,
            GRU_OUTPUT_SIZE,
            &GRU_BIAS,
            &GRU_WEIGHTS,
            &GRU_RECURRENT_WEIGHTS,
            vector_math,
        );
        test_gated_recurrent_layer(gru);
    }
}