use super::*;
/// Strategy used by `Network::initialize_weights` to assign initial biases and weights.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InitializationMethod {
    /// Xavier/Glorot normal initialization: values are drawn from a normal distribution with
    /// mean 0 and standard deviation `sqrt(2 / (fan_in + fan_out))`.
    Xavier,
}
impl Network {
    /// Builds a network made of an input layer, optional hidden layers and an output layer.
    ///
    /// * `input_neuron_amount` - neuron count of the input layer, which always uses
    ///   `ActivationFunction::Identity`.
    /// * `output_neuron_amount` - neuron count of the output layer.
    /// * `other_layer_amounts` - neuron counts of the hidden layers, in order (`None` = none).
    /// * `layer_activation_functions` - hidden layer `i` uses entry `i`; the output layer uses
    ///   the *last* entry. A `None` entry, a missing entry (vector shorter than expected), or
    ///   passing `None` for the whole vector all fall back to `ActivationFunction::Sigmoid`.
    ///
    /// `shape` records each layer's neuron count in layer order. No connections are created;
    /// call [`Network::all_connect`] for that.
    pub fn new(
        input_neuron_amount: usize,
        output_neuron_amount: usize,
        other_layer_amounts: Option<Vec<usize>>,
        layer_activation_functions: Option<Vec<Option<ActivationFunction>>>,
    ) -> Network {
        let hidden_amounts = other_layer_amounts.unwrap_or_default();
        let acts: Vec<ActivationFunction> = layer_activation_functions
            .unwrap_or_default()
            .into_iter()
            .map(|act| act.unwrap_or(ActivationFunction::Sigmoid))
            .collect();
        // Missing entries get the same default as explicit `None` entries instead of panicking.
        let act_for = |idx: usize| acts.get(idx).copied().unwrap_or(ActivationFunction::Sigmoid);

        // Input + hidden + output layers.
        let total_layers = hidden_amounts.len() + 2;
        let mut net = Network {
            layers: Vec::with_capacity(total_layers),
            connections: Vec::new(),
            shape: Vec::with_capacity(total_layers),
        };

        net.layers
            .push(Layer::new(input_neuron_amount, ActivationFunction::Identity));
        net.shape.push(input_neuron_amount);

        for (idx, &amount) in hidden_amounts.iter().enumerate() {
            net.layers.push(Layer::new(amount, act_for(idx)));
            net.shape.push(amount);
        }

        // The output layer takes the last configured activation (historical behavior).
        let output_act = acts
            .last()
            .copied()
            .unwrap_or(ActivationFunction::Sigmoid);
        net.layers.push(Layer::new(output_neuron_amount, output_act));
        net.shape.push(output_neuron_amount);
        net
    }

    /// Fully connects every pair of adjacent layers.
    ///
    /// For each layer `i` except the last, calls `add_connection((i, j), (i + 1, k), 0.0)` for
    /// every neuron `j` of layer `i` and every neuron `k` of layer `i + 1`, so all new
    /// connections start with weight `0.0`. Existing connections are left untouched; whether
    /// duplicates are rejected depends on `add_connection` (not visible here).
    pub fn all_connect(&mut self) {
        // Snapshot the layer sizes up front so `add_connection` can borrow `self` mutably
        // without cloning every layer (and its bias array).
        let sizes: Vec<_> = self.layers.iter().map(|layer| layer.neuron_amt).collect();
        for (i, pair) in sizes.windows(2).enumerate() {
            let (from_amt, to_amt) = (pair[0], pair[1]);
            for j in 0..from_amt {
                for k in 0..to_amt {
                    self.add_connection((i, j), (i + 1, k), 0.0);
                }
            }
        }
    }

    /// Overwrites every layer's biases and every connection's weight using `method`.
    ///
    /// Draws from `rand::thread_rng()`, so results differ between runs. Note that the input
    /// layer's biases are re-initialized as well.
    pub fn initialize_weights(&mut self, method: InitializationMethod) {
        let mut rng = rand::thread_rng();
        match method {
            InitializationMethod::Xavier => self.xavier_init(&mut rng),
        }
    }

    /// Xavier/Glorot normal initialization.
    ///
    /// Biases of layer `i` and weights of each connection are sampled from
    /// `Normal(0, sqrt(2 / (fan_in + fan_out)))`.
    /// * Biases: `fan_in` is the layer's own size and `fan_out` is the next layer's size. The
    ///   last layer uses the previous layer's size, and a lone layer uses its own size.
    /// * Connection weights: `fan_in` and `fan_out` are the sizes of the source and target
    ///   layers.
    ///
    /// Samples are drawn in a fixed order: all layer biases in layer order, then connections in
    /// storage order.
    fn xavier_init(&mut self, rng: &mut impl Rng) {
        // Glorot-normal standard deviation for the given fan-in / fan-out.
        fn std_dev(fan_in: usize, fan_out: usize) -> f64 {
            (2.0 / (fan_in as f64 + fan_out as f64)).sqrt()
        }

        let layer_count = self.shape.len();
        // `self.layers` and `self.shape` are disjoint fields, so no clone is needed here.
        for (i, layer) in self.layers.iter_mut().enumerate() {
            let fan_in = self.shape[i];
            let fan_out = if i + 1 < layer_count {
                self.shape[i + 1]
            } else if let Some(prev) = i.checked_sub(1) {
                // Last layer has no successor: mirror the preceding layer's size.
                self.shape[prev]
            } else {
                // Single-layer network: no neighbour exists, reuse the layer's own size.
                fan_in
            };
            let normal = Normal::new(0.0, std_dev(fan_in, fan_out)).unwrap();
            layer.bias = Array1::from_shape_fn(layer.neuron_amt, |_| normal.sample(rng));
        }

        for conn in self.connections.iter_mut() {
            let fan_in = self.shape[conn.in_neuron_id.0];
            let fan_out = self.shape[conn.out_neuron_id.0];
            let normal = Normal::new(0.0, std_dev(fan_in, fan_out)).unwrap();
            conn.weight = normal.sample(rng);
        }
    }
}