use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use crate::{Activation, Error, Init, Layer, Mlp, Result};
/// Description of one layer queued in an [`MlpBuilder`]: its output width and
/// the activation attached to it. The layer's input width is not stored here;
/// `MlpBuilder::build_with_rng` takes it from the preceding layer (or from the
/// builder's `input_dim` for the first layer).
#[derive(Debug, Clone, Copy)]
struct LayerSpec {
/// Number of outputs of the layer (validated to be non-zero by `add_layer`).
out_dim: usize,
/// Activation passed to `Layer::new_with_rng` for this layer.
activation: Activation,
}
/// Builder that records an input dimension plus an ordered list of layer
/// specifications and turns them into an `Mlp` via `build_with_seed` or
/// `build_with_rng`.
#[derive(Debug, Clone)]
pub struct MlpBuilder {
/// Input width of the first layer (validated to be non-zero by `new`).
input_dim: usize,
/// Layer specifications in the order they were added.
layers: Vec<LayerSpec>,
}
impl MlpBuilder {
pub fn new(input_dim: usize) -> Result<Self> {
if input_dim == 0 {
return Err(Error::InvalidConfig("input_dim must be > 0".to_owned()));
}
Ok(Self {
input_dim,
layers: Vec::new(),
})
}
pub fn from_sizes(sizes: &[usize], activations: &[Activation]) -> Result<Self> {
if sizes.len() < 2 {
return Err(Error::InvalidConfig(
"sizes must include input and output dims".to_owned(),
));
}
if sizes.contains(&0) {
return Err(Error::InvalidConfig(
"all layer sizes must be > 0".to_owned(),
));
}
if activations.len() != sizes.len() - 1 {
return Err(Error::InvalidConfig(format!(
"activations length {} does not match sizes.len() - 1 ({})",
activations.len(),
sizes.len() - 1
)));
}
let mut b = Self::new(sizes[0])?;
for (out_dim, &act) in sizes[1..].iter().zip(activations) {
b = b.add_layer(*out_dim, act)?;
}
Ok(b)
}
pub fn add_layer(mut self, out_dim: usize, activation: Activation) -> Result<Self> {
if out_dim == 0 {
return Err(Error::InvalidConfig("layer out_dim must be > 0".to_owned()));
}
activation.validate()?;
self.layers.push(LayerSpec {
out_dim,
activation,
});
Ok(self)
}
pub fn build_with_seed(self, seed: u64) -> Result<Mlp> {
let mut rng = StdRng::seed_from_u64(seed);
self.build_with_rng(&mut rng)
}
pub fn build_with_rng<R: Rng + ?Sized>(self, rng: &mut R) -> Result<Mlp> {
if self.layers.is_empty() {
return Err(Error::InvalidConfig(
"mlp must have at least one layer".to_owned(),
));
}
let mut layers = Vec::with_capacity(self.layers.len());
let mut in_dim = self.input_dim;
for spec in self.layers {
let init = default_init_for_activation(spec.activation);
let layer = Layer::new_with_rng(in_dim, spec.out_dim, init, spec.activation, rng)?;
layers.push(layer);
in_dim = spec.out_dim;
}
Ok(Mlp::from_layers(layers))
}
}
/// Picks the weight initializer the builder uses for a layer with the given
/// activation: `Init::He` for the ReLU family, `Init::Xavier` for everything
/// else.
#[inline]
fn default_init_for_activation(act: Activation) -> Init {
    // One arm per variant and no catch-all, so adding a new `Activation`
    // variant forces an explicit choice here.
    match act {
        // Rectifier-style activations.
        Activation::ReLU => Init::He,
        Activation::LeakyReLU { .. } => Init::He,
        // Remaining activations.
        Activation::Identity => Init::Xavier,
        Activation::Sigmoid => Init::Xavier,
        Activation::Tanh => Init::Xavier,
    }
}