use crate::Error;
use crate::domain::DomainId;
use crate::object::{Shape, Tensor};
/// A trainable tensor together with its optimizer state.
///
/// Every constructor in the `impl Parameter` block allocates `m` and `v` with the
/// same shape, domain and element count as `data`, zero-filled, with `step` at 0.
#[derive(Debug, Clone)]
pub struct Parameter {
/// Current parameter values; updated in place by the `*_step` methods.
pub data: Tensor<f32>,
/// First-moment buffer for `adamw_step`. `sgd_momentum_step` reuses this same
/// buffer as its velocity, so the two optimizers must not be mixed on one parameter.
pub m: Tensor<f32>,
/// Second-moment (squared-gradient) buffer for `adamw_step`.
pub v: Tensor<f32>,
/// Number of successful `adamw_step` calls; drives Adam's bias correction.
/// The SGD methods do not advance it.
pub step: u32,
}
impl Parameter {
pub fn zeros(shape: Shape, domain: DomainId) -> Self {
let n = numel(&shape);
Self {
data: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
m: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
v: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
step: 0,
}
}
pub fn uniform(shape: Shape, lo: f32, hi: f32, seed: u32, domain: DomainId) -> Self {
let n = numel(&shape);
let mut data = Vec::with_capacity(n);
let mut state: u32 = seed.wrapping_add(1);
for _ in 0..n {
state ^= state << 13;
state ^= state >> 17;
state ^= state << 5;
let frac = state as f32 / u32::MAX as f32;
data.push(lo + (hi - lo) * frac);
}
Self {
data: Tensor::dense_cpu(domain.clone(), shape.clone(), data),
m: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
v: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
step: 0,
}
}
pub fn from_tensor(t: Tensor<f32>) -> Self {
let n = t.data.len();
let shape = t.meta.shape.clone();
let domain = t.meta.domain.clone();
Self {
data: t,
m: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
v: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
step: 0,
}
}
pub fn numel(&self) -> usize {
self.data.data.len()
}
pub fn adamw_step(
&mut self,
grad: &Tensor<f32>,
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
) -> Result<(), Error> {
if grad.data.len() != self.data.data.len() {
return Err(Error::shape(format!(
"adamw_step grad length {} != parameter length {}",
grad.data.len(),
self.data.data.len()
)));
}
if self.step == u32::MAX {
return Err(Error::backend("adamw_step parameter step counter overflow"));
}
self.step += 1;
let t = self.step as f32;
let bc1 = 1.0 - beta1.powf(t);
let bc2 = 1.0 - beta2.powf(t);
for i in 0..self.data.data.len() {
let g = grad.data[i] + weight_decay * self.data.data[i];
self.m.data[i] = beta1 * self.m.data[i] + (1.0 - beta1) * g;
self.v.data[i] = beta2 * self.v.data[i] + (1.0 - beta2) * g * g;
let m_hat = self.m.data[i] / bc1;
let v_hat = self.v.data[i] / bc2;
self.data.data[i] -= lr * m_hat / (v_hat.sqrt() + eps);
}
Ok(())
}
pub fn sgd_step(&mut self, grad: &Tensor<f32>, lr: f32) -> Result<(), Error> {
if grad.data.len() != self.data.data.len() {
return Err(Error::shape(format!(
"sgd_step grad length {} != parameter length {}",
grad.data.len(),
self.data.data.len()
)));
}
for (theta, g) in self.data.data.iter_mut().zip(grad.data.iter()) {
*theta -= lr * g;
}
Ok(())
}
pub fn sgd_momentum_step(
&mut self,
grad: &Tensor<f32>,
lr: f32,
momentum: f32,
) -> Result<(), Error> {
if grad.data.len() != self.data.data.len() {
return Err(Error::shape(format!(
"sgd_momentum_step grad length {} != parameter length {}",
grad.data.len(),
self.data.data.len()
)));
}
for ((theta, g), momentum_buf) in self
.data
.data
.iter_mut()
.zip(grad.data.iter())
.zip(self.m.data.iter_mut())
{
*momentum_buf = momentum * (*momentum_buf) + *g;
*theta -= lr * *momentum_buf;
}
Ok(())
}
}
/// Number of elements implied by `shape`.
///
/// Returns the product of all static dimensions, which is 1 for a shape with no
/// dimensions. If any dimension is not `Dim::Static`, returns 1 instead. The
/// `Parameter` constructors then allocate a single-element buffer.
///
/// # Panics
///
/// Panics if the product of the static dimensions overflows `usize`.
fn numel(shape: &Shape) -> usize {
    let mut n = 1usize;
    for d in &shape.dims {
        match d {
            crate::object::Dim::Static(v) => {
                // Checked multiplication: release builds would otherwise wrap silently
                // and hand callers a wrong, smaller buffer size.
                n = n
                    .checked_mul(*v)
                    .expect("shape element count overflows usize");
            }
            _ => return 1,
        }
    }
    n
}