// A small two-layer MLP (ReLU hidden layer) trained with per-sample SGD,
// supporting either Adam or classical momentum updates.
use ndarray::{Array1, Array2, Axis};
use rand::seq::SliceRandom;
use rand::{thread_rng, Rng};
use rayon::prelude::*;
/// Two-layer MLP with a ReLU hidden layer. Optimizer state is kept inline:
/// `v*` holds momentum velocities, `m*`/`s*` hold Adam's first and second
/// moment estimates, and `t` is the Adam timestep.
pub struct OptimizedMlp {
    w1: Array2<f32>,
    b1: Array1<f32>,
    w2: Array2<f32>,
    b2: Array1<f32>,
    // Momentum velocities.
    vw1: Array2<f32>,
    vb1: Array1<f32>,
    vw2: Array2<f32>,
    vb2: Array1<f32>,
    // Adam first-moment estimates.
    mw1: Array2<f32>,
    mb1: Array1<f32>,
    mw2: Array2<f32>,
    mb2: Array1<f32>,
    // Adam second-moment estimates.
    sw1: Array2<f32>,
    sb1: Array1<f32>,
    sw2: Array2<f32>,
    sb2: Array1<f32>,
    // Adam timestep (kept as f32 so it can feed `powf` directly).
    t: f32,
    output_dim: usize,
    // Selects Adam (`true`) or classical momentum (`false`) in `backward`.
    use_adam: bool,
}
impl OptimizedMlp {
    pub fn new(input: usize, hidden: usize, output: usize) -> Self {
        let mut rng = thread_rng();
        // Uniform variant of He initialization: weights drawn from
        // [-scale/2, scale/2] with scale = sqrt(2 / fan_in).
        let scale1 = (2.0 / input as f32).sqrt();
        let scale2 = (2.0 / hidden as f32).sqrt();
        let w1 = Array2::from_shape_fn((hidden, input), |_| {
            rng.gen::<f32>() * scale1 - scale1 / 2.0
        });
        let w2 = Array2::from_shape_fn((output, hidden), |_| {
            rng.gen::<f32>() * scale2 - scale2 / 2.0
        });
        Self {
            w1,
            b1: Array1::zeros(hidden),
            w2,
            b2: Array1::zeros(output),
            // All optimizer state starts at zero.
            vw1: Array2::zeros((hidden, input)),
            vb1: Array1::zeros(hidden),
            vw2: Array2::zeros((output, hidden)),
            vb2: Array1::zeros(output),
            mw1: Array2::zeros((hidden, input)),
            mb1: Array1::zeros(hidden),
            mw2: Array2::zeros((output, hidden)),
            mb2: Array1::zeros(output),
            sw1: Array2::zeros((hidden, input)),
            sb1: Array1::zeros(hidden),
            sw2: Array2::zeros((output, hidden)),
            sb2: Array1::zeros(output),
            t: 0.0,
            output_dim: output,
            use_adam: true,
        }
    }
    /// Forward pass; returns (pre-activation output, hidden ReLU activations).
    pub fn forward(&self, x: &Array1<f32>) -> (Array1<f32>, Array1<f32>) {
        let z1 = self.w1.dot(x) + &self.b1;
        let h = z1.mapv(|v| v.max(0.0)); // ReLU
        let output = self.w2.dot(&h) + &self.b2;
        (output, h)
    }
    /// Single-sample backward pass. With `output_dim == 1` this is squared-error
    /// regression; otherwise `y_true` is treated as a 3-way label encoded by
    /// thresholds (< -0.25 -> class 0, > 0.25 -> class 2, else class 1) and
    /// trained with softmax cross-entropy.
    pub fn backward(&mut self, x: &Array1<f32>, y_true: f32, lr: f32) {
        let (output, h) = self.forward(x);
        let grad_output = if self.output_dim == 1 {
            Array1::from_elem(1, output[0] - y_true)
        } else {
            // Numerically stable softmax: subtract the max logit before exp.
            let max_logit = output.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let exp_out = output.mapv(|v| (v - max_logit).exp());
            let softmax = &exp_out / exp_out.sum();
            let class = if y_true < -0.25 {
                0
            } else if y_true > 0.25 {
                2
            } else {
                1
            };
            let mut target = Array1::zeros(self.output_dim);
            if class < self.output_dim {
                target[class] = 1.0;
            }
            // Gradient of cross-entropy with respect to the logits.
            softmax - target
        };
        // Outer products give the weight gradients (relies on ndarray 0.15+
        // co-broadcasting of (n, 1) * (1, m)), so no reshape is needed.
        let grad_w2 = grad_output.clone().insert_axis(Axis(1)) * h.clone().insert_axis(Axis(0));
        let grad_b2 = grad_output.clone();
        let grad_h = self.w2.t().dot(&grad_output);
        // ReLU derivative: h > 0 exactly where z1 > 0, so z1 need not be recomputed.
        let grad_z1 = grad_h * h.mapv(|v| if v > 0.0 { 1.0 } else { 0.0 });
        let grad_w1 = grad_z1.clone().insert_axis(Axis(1)) * x.clone().insert_axis(Axis(0));
        let grad_b1 = grad_z1;
        if self.use_adam {
            self.adam_update(grad_w1, grad_b1, grad_w2, grad_b2, lr);
        } else {
            self.momentum_update(grad_w1, grad_b1, grad_w2, grad_b2, lr);
        }
    }
    /// Adam update with bias-corrected first and second moment estimates
    /// (Kingma & Ba, 2015).
    fn adam_update(
        &mut self,
        grad_w1: Array2<f32>,
        grad_b1: Array1<f32>,
        grad_w2: Array2<f32>,
        grad_b2: Array1<f32>,
        lr: f32,
    ) {
        let beta1 = 0.9;
        let beta2 = 0.999;
        let epsilon = 1e-8;
        self.t += 1.0;
        // First moments: exponential moving averages of the gradients.
        self.mw1 = &self.mw1 * beta1 + &grad_w1 * (1.0 - beta1);
        self.mb1 = &self.mb1 * beta1 + &grad_b1 * (1.0 - beta1);
        self.mw2 = &self.mw2 * beta1 + &grad_w2 * (1.0 - beta1);
        self.mb2 = &self.mb2 * beta1 + &grad_b2 * (1.0 - beta1);
        // Second moments: exponential moving averages of the squared gradients.
        self.sw1 = &self.sw1 * beta2 + grad_w1.mapv(|x| x * x) * (1.0 - beta2);
        self.sb1 = &self.sb1 * beta2 + grad_b1.mapv(|x| x * x) * (1.0 - beta2);
        self.sw2 = &self.sw2 * beta2 + grad_w2.mapv(|x| x * x) * (1.0 - beta2);
        self.sb2 = &self.sb2 * beta2 + grad_b2.mapv(|x| x * x) * (1.0 - beta2);
        // Correct the zero-initialization bias of both moment estimates.
        let bias_correction1 = 1.0 - beta1.powf(self.t);
        let bias_correction2 = 1.0 - beta2.powf(self.t);
        self.w1 = &self.w1
            - lr * &self.mw1 / bias_correction1 / ((&self.sw1 / bias_correction2).mapv(f32::sqrt) + epsilon);
        self.b1 = &self.b1
            - lr * &self.mb1 / bias_correction1 / ((&self.sb1 / bias_correction2).mapv(f32::sqrt) + epsilon);
        self.w2 = &self.w2
            - lr * &self.mw2 / bias_correction1 / ((&self.sw2 / bias_correction2).mapv(f32::sqrt) + epsilon);
        self.b2 = &self.b2
            - lr * &self.mb2 / bias_correction1 / ((&self.sb2 / bias_correction2).mapv(f32::sqrt) + epsilon);
    }
    /// Classical momentum: v <- momentum * v - lr * grad, then w <- w + v.
    fn momentum_update(
        &mut self,
        grad_w1: Array2<f32>,
        grad_b1: Array1<f32>,
        grad_w2: Array2<f32>,
        grad_b2: Array1<f32>,
        lr: f32,
    ) {
        let momentum = 0.9;
        self.vw1 = &self.vw1 * momentum - &grad_w1 * lr;
        self.vb1 = &self.vb1 * momentum - &grad_b1 * lr;
        self.vw2 = &self.vw2 * momentum - &grad_w2 * lr;
        self.vb2 = &self.vb2 * momentum - &grad_b2 * lr;
        self.w1 = &self.w1 + &self.vw1;
        self.b1 = &self.b1 + &self.vb1;
        self.w2 = &self.w2 + &self.vw2;
        self.b2 = &self.b2 + &self.vb2;
    }
    /// Per-sample SGD over shuffled data (despite the name, this also drives
    /// the classification path when `output_dim > 1`).
    pub fn train_regression(&mut self, x: &[Vec<f32>], y: &[f32], epochs: usize, lr: f32) {
        let mut rng = thread_rng();
        for _ in 0..epochs {
            let mut indices: Vec<usize> = (0..x.len()).collect();
            indices.shuffle(&mut rng);
            for &i in &indices {
                let x_arr = Array1::from_vec(x[i].clone());
                self.backward(&x_arr, y[i], lr);
            }
        }
    }
    /// Mini-batch style training. Note this still applies one update per
    /// sample, scaling the learning rate by the batch length; it approximates
    /// (but is not identical to) averaging gradients over the batch.
    pub fn train_batch(
        &mut self,
        x: &[Vec<f32>],
        y: &[f32],
        epochs: usize,
        lr: f32,
        batch_size: usize,
    ) {
        let mut rng = thread_rng();
        for _ in 0..epochs {
            let mut indices: Vec<usize> = (0..x.len()).collect();
            indices.shuffle(&mut rng);
            for batch_start in (0..x.len()).step_by(batch_size) {
                let batch_end = (batch_start + batch_size).min(x.len());
                // Scale by the actual batch length so a short final batch is
                // not under-weighted.
                let scaled_lr = lr / (batch_end - batch_start) as f32;
                for &idx in &indices[batch_start..batch_end] {
                    let x_arr = Array1::from_vec(x[idx].clone());
                    self.backward(&x_arr, y[idx], scaled_lr);
                }
            }
        }
    }
    /// Parallel regression predictions (first output component).
    pub fn predict_reg(&self, x: &[Vec<f32>]) -> Vec<f32> {
        x.par_iter()
            .map(|xi| {
                let (out, _) = self.forward(&Array1::from_vec(xi.clone()));
                out[0]
            })
            .collect()
    }
    /// Parallel 3-class prediction. With a softmax head this is an argmax over
    /// the first three logits (softmax is monotonic, so exponentiating and
    /// normalizing cannot change the argmax); with a scalar head it falls back
    /// to the same +-0.25 thresholds used to encode labels in `backward`.
    pub fn predict_cls3(&self, x: &[Vec<f32>]) -> Vec<usize> {
        x.par_iter()
            .map(|xi| {
                let (out, _) = self.forward(&Array1::from_vec(xi.clone()));
                if self.output_dim >= 3 {
                    let mut best = 0;
                    let mut best_val = out[0];
                    for i in 1..3.min(out.len()) {
                        if out[i] > best_val {
                            best_val = out[i];
                            best = i;
                        }
                    }
                    best
                } else {
                    let val = out[0];
                    if val < -0.25 {
                        0
                    } else if val > 0.25 {
                        2
                    } else {
                        1
                    }
                }
            })
            .collect()
    }
}
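
// A minimal usage sketch written as a smoke test. The fixture data and the
// error bound below are made up for illustration; the sketch assumes the same
// crate versions as the imports above (ndarray 0.15+ for co-broadcasting,
// rand 0.8, rayon 1.x).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn learns_a_linear_map() {
        // Toy regression target: y = 2 * x0 - 1 on inputs in [0, 1).
        let x: Vec<Vec<f32>> = (0..64).map(|i| vec![i as f32 / 64.0]).collect();
        let y: Vec<f32> = x.iter().map(|v| 2.0 * v[0] - 1.0).collect();
        let mut mlp = OptimizedMlp::new(1, 16, 1);
        mlp.train_regression(&x, &y, 200, 0.01);
        let preds = mlp.predict_reg(&x);
        let mse: f32 = preds
            .iter()
            .zip(&y)
            .map(|(p, t)| (p - t) * (p - t))
            .sum::<f32>()
            / y.len() as f32;
        // Loose bound: only checks that training reduced the error to
        // something sensible, not that it found an exact fit.
        assert!(mse < 0.05, "mse too high: {mse}");
    }
}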