#[cfg(feature = "alloc")]
use alloc::{vec, vec::Vec};
use serde::{Deserialize, Serialize};
use super::activation::Activation;
/// A dense layer described by a flat weight buffer, a bias vector and an
/// activation that is applied to every output.
///
/// The methods in this file address the weight that connects input `j` to
/// output `i` at `weights[i * input_dim + j]` (one contiguous run of
/// `input_dim` weights per output).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Layer {
/// Flat weight storage, indexed as `output_idx * input_dim + input_idx`.
pub weights: Vec<f32>,
/// One additive bias per output; read as `biases[output_idx]`.
pub biases: Vec<f32>,
/// Number of weights stored per output (the row stride of `weights`).
pub input_dim: usize,
/// Number of outputs produced by a forward pass.
pub output_dim: usize,
/// Activation applied to each output's weighted sum.
pub activation: Activation,
}
impl Layer {
    /// Creates a layer with deterministic pseudo-random weights and zero biases.
    ///
    /// Every weight lies in `[-scale, scale]` with
    /// `scale = sqrt(2 / (input_dim + output_dim))`. The values come from a
    /// linear congruential generator that always starts from the same seed, so
    /// two layers built with the same dimensions receive identical weights.
    #[must_use]
    pub fn new(input_dim: usize, output_dim: usize, activation: Activation) -> Self {
        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
        let weight_count = input_dim * output_dim;

        // Fixed seed: initialisation is reproducible by design.
        let mut seed = 12345u64;
        let weights: Vec<f32> = (0..weight_count)
            .map(|_| {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                // Map the generator state from [0, u64::MAX] onto [-1, 1].
                let rand_val = (seed as f32 / u64::MAX as f32) * 2.0 - 1.0;
                rand_val * scale
            })
            .collect();

        Self {
            weights,
            biases: vec![0.0; output_dim],
            input_dim,
            output_dim,
            activation,
        }
    }

    /// Runs the layer on `input` and returns one activated value per output.
    ///
    /// Only the first `input_dim` entries of `input` are used. If `input` is
    /// shorter than `input_dim`, the missing entries contribute nothing to the
    /// sums.
    ///
    /// # Panics
    ///
    /// Panics if `biases` holds fewer than `output_dim` values or if a weight
    /// needed for the computation lies outside `weights`.
    #[must_use]
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        (0..self.output_dim)
            .map(|row| self.activation.apply(self.pre_activation_at(row, input)))
            .collect()
    }

    /// Runs the layer on `input`, keeping the intermediate values.
    ///
    /// The returned [`LayerCache`] holds a copy of `input`, the weighted sums
    /// before the activation, and the activated outputs. Input length handling
    /// and panic conditions are the same as for [`Layer::forward`].
    #[must_use]
    pub fn forward_with_cache(&self, input: &[f32]) -> LayerCache {
        let (pre_activation, post_activation): (Vec<f32>, Vec<f32>) = (0..self.output_dim)
            .map(|row| {
                let sum = self.pre_activation_at(row, input);
                (sum, self.activation.apply(sum))
            })
            .unzip();

        LayerCache {
            input: input.to_vec(),
            pre_activation,
            post_activation,
        }
    }

    /// Returns the weight that connects input `input_idx` to output `output_idx`.
    ///
    /// # Panics
    ///
    /// Panics if `input_idx >= input_dim` or if the resulting position lies
    /// outside `weights`.
    #[must_use]
    pub fn get_weight(&self, output_idx: usize, input_idx: usize) -> f32 {
        self.weights[self.weight_index(output_idx, input_idx)]
    }

    /// Overwrites the weight that connects input `input_idx` to output `output_idx`.
    ///
    /// # Panics
    ///
    /// Panics if `input_idx >= input_dim` or if the resulting position lies
    /// outside `weights`.
    pub fn set_weight(&mut self, output_idx: usize, input_idx: usize, value: f32) {
        let index = self.weight_index(output_idx, input_idx);
        self.weights[index] = value;
    }

    /// Builds a layer from existing parameters.
    ///
    /// The values are stored as given. No check is made that `weights` holds
    /// `input_dim * output_dim` entries or that `biases` holds `output_dim`
    /// entries; inconsistent lengths surface later as panics in the methods
    /// that index these buffers.
    #[must_use]
    pub fn from_weights(
        input_dim: usize,
        output_dim: usize,
        activation: Activation,
        weights: Vec<f32>,
        biases: Vec<f32>,
    ) -> Self {
        Self {
            weights,
            biases,
            input_dim,
            output_dim,
            activation,
        }
    }

    /// Bias of output `row` plus the weighted sum over the usable part of `input`.
    fn pre_activation_at(&self, row: usize, input: &[f32]) -> f32 {
        let base = row * self.input_dim;
        input
            .iter()
            .take(self.input_dim)
            .enumerate()
            .fold(self.biases[row], |acc, (col, &x)| {
                acc + self.weights[base + col] * x
            })
    }

    /// Flat position of a weight, rejecting a column index that would spill
    /// into the next output's run of weights.
    fn weight_index(&self, output_idx: usize, input_idx: usize) -> usize {
        assert!(
            input_idx < self.input_dim,
            "input_idx {} is out of range for input_dim {}",
            input_idx,
            self.input_dim
        );
        output_idx * self.input_dim + input_idx
    }
}
/// Intermediate values recorded by `Layer::forward_with_cache`.
#[derive(Debug, Clone)]
pub struct LayerCache {
/// Copy of the input slice that was passed to the forward pass.
pub input: Vec<f32>,
/// Per-output sums (bias plus weighted inputs) before the activation.
pub pre_activation: Vec<f32>,
/// Per-output values after the activation has been applied.
pub post_activation: Vec<f32>,
}
/// Gradient buffers for a `Layer`; `LayerGradients::zeros` sizes them to match
/// the layer's `weights` and `biases`.
#[derive(Debug, Clone)]
pub struct LayerGradients {
/// One value per entry of the layer's `weights` when built with `zeros`.
pub weight_gradients: Vec<f32>,
/// One value per entry of the layer's `biases` when built with `zeros`.
pub bias_gradients: Vec<f32>,
}
impl LayerGradients {
    /// Creates all-zero gradient buffers with the same lengths as the
    /// `weights` and `biases` of `layer`.
    #[must_use]
    pub fn zeros(layer: &Layer) -> Self {
        let weight_count = layer.weights.len();
        let bias_count = layer.biases.len();
        Self {
            weight_gradients: vec![0.0; weight_count],
            bias_gradients: vec![0.0; bias_count],
        }
    }

    /// Adds the gradients of `other` onto `self`, element by element.
    ///
    /// Buffers are paired positionally; when the lengths differ, only the
    /// common prefix is updated.
    pub fn accumulate(&mut self, other: &Self) {
        Self::add_in_place(&mut self.weight_gradients, &other.weight_gradients);
        Self::add_in_place(&mut self.bias_gradients, &other.bias_gradients);
    }

    /// Multiplies every stored weight and bias gradient by `factor`.
    pub fn scale(&mut self, factor: f32) {
        self.weight_gradients
            .iter_mut()
            .chain(self.bias_gradients.iter_mut())
            .for_each(|gradient| *gradient *= factor);
    }

    /// Adds `delta` onto `target` pairwise, stopping at the shorter slice.
    fn add_in_place(target: &mut [f32], delta: &[f32]) {
        target
            .iter_mut()
            .zip(delta)
            .for_each(|(slot, increment)| *slot += *increment);
    }
}