use crate::optim::*;
use std::collections::HashMap;
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
/// Single-precision SGD optimizer state exported to JavaScript.
pub struct WasmSGD {
    /// Step size; read and written through the learning-rate accessors.
    learning_rate: f32,
    /// Momentum coefficient; `with_momentum` sets it, `new` leaves it at `0.0`.
    momentum: f32,
    /// Weight-decay coefficient; both constructors initialise it to `0.0`.
    weight_decay: f32,
    /// Nesterov flag; both constructors initialise it to `false`.
    nesterov: bool,
    /// `f32` buffers keyed by string; created empty by both constructors.
    momentum_buffers: HashMap<String, Vec<f32>>,
}
#[wasm_bindgen]
impl WasmSGD {
    /// Creates an optimizer with the given learning rate and no momentum.
    #[wasm_bindgen(constructor)]
    pub fn new(learning_rate: f32) -> Self {
        Self::with_momentum(learning_rate, 0.0)
    }

    /// Creates an optimizer with the given learning rate and momentum coefficient.
    ///
    /// Weight decay starts at `0.0` and Nesterov momentum is disabled.
    #[wasm_bindgen]
    pub fn with_momentum(learning_rate: f32, momentum: f32) -> WasmSGD {
        Self {
            learning_rate,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            momentum_buffers: HashMap::new(),
        }
    }

    /// Applies one SGD update and returns the updated parameters.
    ///
    /// * `param_id` - key under which the momentum buffer for this parameter
    ///   is stored between calls.
    /// * `parameters` - current parameter values.
    /// * `gradients` - gradient for each parameter; must have the same length.
    ///
    /// If the two lengths differ, `parameters` is returned unchanged and no
    /// state is touched. Arithmetic is carried out in `f64` and rounded back
    /// to `f32`. With zero momentum and zero weight decay the update is
    /// exactly `p - learning_rate * g`.
    #[wasm_bindgen]
    pub fn step(&mut self, param_id: &str, parameters: Vec<f32>, gradients: Vec<f32>) -> Vec<f32> {
        if parameters.len() != gradients.len() {
            return parameters;
        }

        let lr = f64::from(self.learning_rate);
        let momentum = f64::from(self.momentum);
        let weight_decay = f64::from(self.weight_decay);
        let nesterov = self.nesterov;

        // Gradient with the optional L2 term folded in. The guard keeps the
        // zero-decay path free of any extra floating-point operation.
        let decayed = |p: f64, g: f64| -> f64 {
            if weight_decay != 0.0 {
                g + weight_decay * p
            } else {
                g
            }
        };

        if momentum == 0.0 {
            // Plain SGD needs no per-parameter state.
            return parameters
                .iter()
                .zip(&gradients)
                .map(|(&p, &g)| {
                    let p = f64::from(p);
                    (p - lr * decayed(p, f64::from(g))) as f32
                })
                .collect();
        }

        let buffer = self
            .momentum_buffers
            .entry(param_id.to_string())
            .or_default();
        // A buffer recorded for a different shape is meaningless for this
        // parameter; restart it from zero instead of indexing out of bounds.
        if buffer.len() != parameters.len() {
            buffer.clear();
            buffer.resize(parameters.len(), 0.0);
        }

        parameters
            .iter()
            .zip(&gradients)
            .zip(buffer.iter_mut())
            .map(|((&p, &g), v)| {
                let p = f64::from(p);
                let grad = decayed(p, f64::from(g));
                let velocity = momentum * f64::from(*v) + grad;
                *v = velocity as f32;
                let direction = if nesterov {
                    grad + momentum * velocity
                } else {
                    velocity
                };
                (p - lr * direction) as f32
            })
            .collect()
    }

    /// Returns the current learning rate.
    #[wasm_bindgen]
    pub fn get_learning_rate(&self) -> f32 {
        self.learning_rate
    }

    /// Replaces the learning rate used by subsequent `step` calls.
    #[wasm_bindgen]
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.learning_rate = lr;
    }
}
#[wasm_bindgen]
/// Double-precision SGD optimizer state exported to JavaScript.
pub struct SGDWasm {
    /// Step size multiplied into every update.
    learning_rate: f64,
    /// Velocity coefficient; `0.0` disables the momentum branch of `step`.
    momentum: f64,
    /// The gradient enters the velocity scaled by `1.0 - dampening`.
    dampening: f64,
    /// When non-zero, `weight_decay * param` is added to each gradient.
    weight_decay: f64,
    /// Selects the Nesterov form of the momentum update.
    nesterov: bool,
    /// Velocity vectors keyed by parameter name.
    velocity: HashMap<String, Vec<f64>>,
}
#[wasm_bindgen]
impl SGDWasm {
    /// Creates an SGD optimizer from its hyper-parameters; no velocity state
    /// exists until the first `step` with non-zero momentum.
    #[wasm_bindgen(constructor)]
    pub fn new(
        learning_rate: f64,
        momentum: f64,
        dampening: f64,
        weight_decay: f64,
        nesterov: bool,
    ) -> SGDWasm {
        SGDWasm {
            learning_rate,
            momentum,
            dampening,
            weight_decay,
            nesterov,
            velocity: HashMap::new(),
        }
    }

    /// Updates `params` in place from `gradients`.
    ///
    /// * `param_name` - key identifying the velocity buffer for this parameter.
    /// * `params` - parameter values, overwritten with the updated values.
    /// * `gradients` - one gradient per parameter.
    ///
    /// Does nothing when the two slices differ in length. For each element:
    /// `g' = g + weight_decay * p` (when weight decay is non-zero),
    /// `v = momentum * v + (1 - dampening) * g'` (when momentum is non-zero),
    /// and the parameter moves by `-learning_rate` times `v`, or
    /// `g' + momentum * v` with Nesterov, or `g'` without momentum.
    ///
    /// If the velocity stored under `param_name` has a different length than
    /// `params`, it is restarted from zero rather than indexed out of bounds.
    #[wasm_bindgen]
    pub fn step(&mut self, param_name: &str, params: &mut [f64], gradients: &[f64]) {
        if params.len() != gradients.len() {
            return;
        }

        let lr = self.learning_rate;
        let momentum = self.momentum;
        let dampening = self.dampening;
        let weight_decay = self.weight_decay;
        let nesterov = self.nesterov;

        let decayed = |p: f64, g: f64| -> f64 {
            if weight_decay != 0.0 {
                g + weight_decay * p
            } else {
                g
            }
        };

        if momentum == 0.0 {
            // No velocity is read or written without momentum, so skip the
            // map lookup and its key allocation entirely.
            for (p, &g) in params.iter_mut().zip(gradients) {
                *p -= lr * decayed(*p, g);
            }
            return;
        }

        let velocity = self.velocity.entry(param_name.to_string()).or_default();
        if velocity.len() != params.len() {
            velocity.clear();
            velocity.resize(params.len(), 0.0);
        }

        for ((p, &g), v) in params.iter_mut().zip(gradients).zip(velocity.iter_mut()) {
            let grad = decayed(*p, g);
            *v = momentum * *v + (1.0 - dampening) * grad;
            let direction = if nesterov { grad + momentum * *v } else { *v };
            *p -= lr * direction;
        }
    }

    /// Returns the current learning rate.
    #[wasm_bindgen]
    pub fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }

    /// Replaces the learning rate used by subsequent `step` calls.
    #[wasm_bindgen]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr;
    }

    /// Drops every stored velocity buffer.
    #[wasm_bindgen]
    pub fn reset_state(&mut self) {
        self.velocity.clear();
    }
}
/// Double-precision Adam optimizer exported to JavaScript.
///
/// Moment estimates and the bias-correction step counter are tracked
/// separately for every parameter name passed to `step`.
#[wasm_bindgen]
pub struct AdamWasm {
    /// Step size multiplied into every update.
    learning_rate: f64,
    /// Decay rate of the first-moment estimate.
    beta1: f64,
    /// Decay rate of the second-moment estimate.
    beta2: f64,
    /// Added to the square root of the second moment in the denominator.
    epsilon: f64,
    /// When non-zero, `weight_decay * param` is added to each gradient.
    weight_decay: f64,
    /// Total number of accepted `step` calls across all parameter names.
    step_count: u64,
    /// Number of updates applied to each parameter name; used for bias correction.
    step_counts: HashMap<String, u64>,
    /// First-moment estimates keyed by parameter name.
    m: HashMap<String, Vec<f64>>,
    /// Second-moment estimates keyed by parameter name.
    v: HashMap<String, Vec<f64>>,
}

#[wasm_bindgen]
impl AdamWasm {
    /// Creates an Adam optimizer from its hyper-parameters with empty state.
    #[wasm_bindgen(constructor)]
    pub fn new(
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
        weight_decay: f64,
    ) -> AdamWasm {
        AdamWasm {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            step_count: 0,
            step_counts: HashMap::new(),
            m: HashMap::new(),
            v: HashMap::new(),
        }
    }

    /// Updates `params` in place from `gradients` using Adam.
    ///
    /// * `param_name` - key identifying the moment buffers for this parameter.
    /// * `params` - parameter values, overwritten with the updated values.
    /// * `gradients` - one gradient per parameter.
    ///
    /// Does nothing when the two slices differ in length.
    ///
    /// Bias correction uses the number of updates applied to *this*
    /// parameter name. A single shared counter would advance once per
    /// parameter tensor instead of once per optimisation step, which
    /// under-corrects every parameter after the first. When only one
    /// parameter name is ever used the result is unchanged.
    ///
    /// If the stored moments have a different length than `params`, the state
    /// for this name restarts from zero.
    #[wasm_bindgen]
    pub fn step(&mut self, param_name: &str, params: &mut [f64], gradients: &[f64]) {
        if params.len() != gradients.len() {
            return;
        }

        let len = params.len();
        let lr = self.learning_rate;
        let beta1 = self.beta1;
        let beta2 = self.beta2;
        let epsilon = self.epsilon;
        let weight_decay = self.weight_decay;

        self.step_count += 1;

        let m = self.m.entry(param_name.to_string()).or_default();
        let v = self.v.entry(param_name.to_string()).or_default();
        let t = self.step_counts.entry(param_name.to_string()).or_insert(0);
        if m.len() != len || v.len() != len {
            m.clear();
            m.resize(len, 0.0);
            v.clear();
            v.resize(len, 0.0);
            *t = 0;
        }
        *t += 1;

        // Clamp instead of wrapping: a wrapped (negative) exponent would turn
        // the correction factor negative and flip the sign of the update.
        let exponent = (*t).min(i32::MAX as u64) as i32;
        // The correction factors depend only on the step count, not on the element.
        let bias1 = 1.0 - beta1.powi(exponent);
        let bias2 = 1.0 - beta2.powi(exponent);

        for (((p, &g), m_i), v_i) in params
            .iter_mut()
            .zip(gradients)
            .zip(m.iter_mut())
            .zip(v.iter_mut())
        {
            let grad = if weight_decay != 0.0 {
                g + weight_decay * *p
            } else {
                g
            };
            *m_i = beta1 * *m_i + (1.0 - beta1) * grad;
            *v_i = beta2 * *v_i + (1.0 - beta2) * grad * grad;
            let m_hat = *m_i / bias1;
            let v_hat = *v_i / bias2;
            *p -= lr * m_hat / (v_hat.sqrt() + epsilon);
        }
    }

    /// Returns the current learning rate.
    #[wasm_bindgen]
    pub fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }

    /// Replaces the learning rate used by subsequent `step` calls.
    #[wasm_bindgen]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr;
    }

    /// Returns the total number of accepted `step` calls since construction
    /// or the last `reset_state`, counted across all parameter names.
    #[wasm_bindgen]
    pub fn get_step_count(&self) -> u64 {
        self.step_count
    }

    /// Clears all moment estimates and step counters.
    #[wasm_bindgen]
    pub fn reset_state(&mut self) {
        self.step_count = 0;
        self.step_counts.clear();
        self.m.clear();
        self.v.clear();
    }
}
#[wasm_bindgen]
/// Double-precision RMSprop optimizer state exported to JavaScript.
pub struct RMSpropWasm {
    /// Step size multiplied into every update.
    learning_rate: f64,
    /// Decay rate of the running average of squared gradients.
    alpha: f64,
    /// Added to the square root of the running average in the denominator.
    epsilon: f64,
    /// When non-zero, `weight_decay * param` is added to each gradient.
    weight_decay: f64,
    /// Momentum coefficient; values greater than `0.0` enable the momentum buffer.
    momentum: f64,
    /// Running averages of squared gradients keyed by parameter name.
    v: HashMap<String, Vec<f64>>,
    /// Momentum buffers keyed by parameter name.
    momentum_buffer: HashMap<String, Vec<f64>>,
}
#[wasm_bindgen]
impl RMSpropWasm {
    /// Creates an RMSprop optimizer from its hyper-parameters with empty state.
    #[wasm_bindgen(constructor)]
    pub fn new(
        learning_rate: f64,
        alpha: f64,
        epsilon: f64,
        weight_decay: f64,
        momentum: f64,
    ) -> RMSpropWasm {
        RMSpropWasm {
            learning_rate,
            alpha,
            epsilon,
            weight_decay,
            momentum,
            v: HashMap::new(),
            momentum_buffer: HashMap::new(),
        }
    }

    /// Updates `params` in place from `gradients` using RMSprop.
    ///
    /// * `param_name` - key identifying the state buffers for this parameter.
    /// * `params` - parameter values, overwritten with the updated values.
    /// * `gradients` - one gradient per parameter.
    ///
    /// Does nothing when the two slices differ in length. For each element
    /// the running average is `v = alpha * v + (1 - alpha) * g^2` and the
    /// scaled gradient is `g / (sqrt(v) + epsilon)`. With `momentum > 0` the
    /// scaled gradient is accumulated into a momentum buffer and the buffer
    /// is applied instead.
    ///
    /// State stored under `param_name` with a different length than `params`
    /// restarts from zero rather than being indexed out of bounds.
    #[wasm_bindgen]
    pub fn step(&mut self, param_name: &str, params: &mut [f64], gradients: &[f64]) {
        if params.len() != gradients.len() {
            return;
        }

        let len = params.len();
        let lr = self.learning_rate;
        let alpha = self.alpha;
        let epsilon = self.epsilon;
        let weight_decay = self.weight_decay;
        let momentum = self.momentum;

        let decayed = |p: f64, g: f64| -> f64 {
            if weight_decay != 0.0 {
                g + weight_decay * p
            } else {
                g
            }
        };

        let sq_avg = self.v.entry(param_name.to_string()).or_default();
        if sq_avg.len() != len {
            sq_avg.clear();
            sq_avg.resize(len, 0.0);
        }

        if momentum > 0.0 {
            // The momentum buffer is only materialised when it is used.
            let buffer = self
                .momentum_buffer
                .entry(param_name.to_string())
                .or_default();
            if buffer.len() != len {
                buffer.clear();
                buffer.resize(len, 0.0);
            }

            for (((p, &g), s), b) in params
                .iter_mut()
                .zip(gradients)
                .zip(sq_avg.iter_mut())
                .zip(buffer.iter_mut())
            {
                let grad = decayed(*p, g);
                *s = alpha * *s + (1.0 - alpha) * grad * grad;
                *b = momentum * *b + grad / ((*s).sqrt() + epsilon);
                *p -= lr * *b;
            }
        } else {
            for ((p, &g), s) in params.iter_mut().zip(gradients).zip(sq_avg.iter_mut()) {
                let grad = decayed(*p, g);
                *s = alpha * *s + (1.0 - alpha) * grad * grad;
                *p -= lr * (grad / ((*s).sqrt() + epsilon));
            }
        }
    }

    /// Returns the current learning rate.
    #[wasm_bindgen]
    pub fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }

    /// Replaces the learning rate used by subsequent `step` calls.
    #[wasm_bindgen]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr;
    }

    /// Drops every stored running average and momentum buffer.
    #[wasm_bindgen]
    pub fn reset_state(&mut self) {
        self.v.clear();
        self.momentum_buffer.clear();
    }
}
#[wasm_bindgen]
/// Double-precision AdaGrad optimizer state exported to JavaScript.
pub struct AdaGradWasm {
    /// Step size multiplied into every update.
    learning_rate: f64,
    /// Added to the square root of the accumulated squares in the denominator.
    epsilon: f64,
    /// When non-zero, `weight_decay * param` is added to each gradient.
    weight_decay: f64,
    /// Running sums of squared gradients keyed by parameter name.
    sum_sq_gradients: HashMap<String, Vec<f64>>,
}
#[wasm_bindgen]
impl AdaGradWasm {
    /// Creates an AdaGrad optimizer from its hyper-parameters with empty state.
    #[wasm_bindgen(constructor)]
    pub fn new(learning_rate: f64, epsilon: f64, weight_decay: f64) -> AdaGradWasm {
        AdaGradWasm {
            learning_rate,
            epsilon,
            weight_decay,
            sum_sq_gradients: HashMap::new(),
        }
    }

    /// Updates `params` in place from `gradients` using AdaGrad.
    ///
    /// * `param_name` - key identifying the accumulator for this parameter.
    /// * `params` - parameter values, overwritten with the updated values.
    /// * `gradients` - one gradient per parameter.
    ///
    /// Does nothing when the two slices differ in length. Each element adds
    /// the squared gradient to its accumulator and then moves by
    /// `learning_rate * g / (sqrt(accumulator) + epsilon)`.
    ///
    /// An accumulator stored under `param_name` with a different length than
    /// `params` restarts from zero rather than being indexed out of bounds.
    #[wasm_bindgen]
    pub fn step(&mut self, param_name: &str, params: &mut [f64], gradients: &[f64]) {
        if params.len() != gradients.len() {
            return;
        }

        let lr = self.learning_rate;
        let epsilon = self.epsilon;
        let weight_decay = self.weight_decay;

        let sum_sq = self
            .sum_sq_gradients
            .entry(param_name.to_string())
            .or_default();
        if sum_sq.len() != params.len() {
            sum_sq.clear();
            sum_sq.resize(params.len(), 0.0);
        }

        for ((p, &g), acc) in params.iter_mut().zip(gradients).zip(sum_sq.iter_mut()) {
            let grad = if weight_decay != 0.0 {
                g + weight_decay * *p
            } else {
                g
            };
            *acc += grad * grad;
            *p -= lr * grad / ((*acc).sqrt() + epsilon);
        }
    }

    /// Returns the current learning rate.
    #[wasm_bindgen]
    pub fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }

    /// Replaces the learning rate used by subsequent `step` calls.
    #[wasm_bindgen]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr;
    }

    /// Drops every stored squared-gradient accumulator.
    #[wasm_bindgen]
    pub fn reset_state(&mut self) {
        self.sum_sq_gradients.clear();
    }
}
#[wasm_bindgen]
/// Staircase exponential decay: `initial_lr * decay_rate^(step / decay_steps)`,
/// where the division is integer division.
///
/// * `initial_lr` - learning rate at step 0.
/// * `step` - current step.
/// * `decay_rate` - factor applied once per completed interval.
/// * `decay_steps` - interval length in steps.
///
/// A `decay_steps` of `0` means no interval ever completes, so `initial_lr`
/// is returned unchanged instead of panicking on a division by zero.
pub fn learning_rate_schedule_wasm(
    initial_lr: f64,
    step: u64,
    decay_rate: f64,
    decay_steps: u64,
) -> f64 {
    if decay_steps == 0 {
        return initial_lr;
    }
    // Clamp instead of wrapping: a wrapped (negative) exponent would turn
    // decay into growth for very large step counts.
    let intervals = (step / decay_steps).min(i32::MAX as u64) as i32;
    initial_lr * decay_rate.powi(intervals)
}
#[wasm_bindgen]
/// Cosine annealing from `initial_lr` down to `0.01 * initial_lr`.
///
/// Returns
/// `min_lr + (initial_lr - min_lr) * 0.5 * (1 + cos(pi * current_step / total_steps))`
/// with `min_lr = 0.01 * initial_lr`.
///
/// * `initial_lr` - learning rate at step 0.
/// * `current_step` - current step.
/// * `total_steps` - step at which the minimum is reached.
///
/// A `total_steps` of `0` is treated as `1`, so the result stays finite
/// instead of becoming NaN from a division by zero.
pub fn cosine_annealing_wasm(initial_lr: f64, current_step: u64, total_steps: u64) -> f64 {
    let min_lr = initial_lr * 0.01;
    let horizon = total_steps.max(1) as f64;
    let phase = (current_step as f64 * std::f64::consts::PI) / horizon;
    min_lr + (initial_lr - min_lr) * 0.5 * (1.0 + phase.cos())
}