/// Adam optimizer state for one flat buffer of `f32` parameters.
///
/// Holds the per-element first and second moment estimates, the shared step
/// counter, and the hyper-parameters consumed by [`AdamParam::step`].
#[derive(Debug, Clone)]
pub struct AdamParam {
    /// Step size applied to the bias-corrected update.
    pub lr: f32,
    /// Exponential moving average of the gradient, one entry per parameter.
    m: Vec<f32>,
    /// Exponential moving average of the squared gradient, one entry per parameter.
    v: Vec<f32>,
    /// Number of `step` calls so far; drives the bias correction.
    t: u32,
    /// Decay rate of the first moment `m`.
    beta1: f32,
    /// Decay rate of the second moment `v`.
    beta2: f32,
    /// Added to the denominator of the update so it never divides by zero.
    eps: f32,
    /// Per-element gradient clipping bound; clipping is off unless this is `> 0.0`.
    pub max_grad: f32,
}

impl AdamParam {
    /// Creates optimizer state for `n` parameters with learning rate `lr`.
    ///
    /// Moments start at zero, `beta1 = 0.9`, `beta2 = 0.999`, `eps = 1e-15`,
    /// and gradient clipping is disabled (`max_grad = 0.0`).
    pub fn new(n: usize, lr: f32) -> Self {
        Self {
            lr,
            m: vec![0.0; n],
            v: vec![0.0; n],
            t: 0,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-15,
            max_grad: 0.0,
        }
    }

    /// Same as [`AdamParam::new`], but clamps every gradient element to
    /// `[-max_grad, max_grad]` before it is used (when `max_grad > 0.0`).
    pub fn with_clip(n: usize, lr: f32, max_grad: f32) -> Self {
        Self {
            max_grad,
            ..Self::new(n, lr)
        }
    }

    /// Applies one Adam update to `params` in place using `grads`.
    ///
    /// Elements whose gradient is NaN or infinite are left untouched: neither
    /// the parameter nor its moments change on this step. If `params` is
    /// longer than the stored moment buffers, the buffers are extended with
    /// zeros first, so newly appended parameters behave as if [`AdamParam::grow`]
    /// had been called. Gradient entries beyond `params.len()` are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `grads` is shorter than `params`. The check happens before
    /// any state is modified.
    pub fn step(&mut self, params: &mut [f32], grads: &[f32]) {
        let n = params.len();
        assert!(
            grads.len() >= n,
            "AdamParam::step: got {} gradients for {} parameters",
            grads.len(),
            n
        );

        // Extend (never shrink) the moment buffers so indexing cannot go out
        // of bounds after the parameter buffer has grown.
        if self.m.len() < n {
            self.m.resize(n, 0.0);
        }
        if self.v.len() < n {
            self.v.resize(n, 0.0);
        }

        self.t = self.t.saturating_add(1);
        // `powi` takes an `i32`; clamp so a huge step count cannot turn into
        // a negative exponent.
        let exponent = self.t.min(i32::MAX as u32) as i32;
        let bc1 = 1.0 - self.beta1.powi(exponent);
        let bc2 = 1.0 - self.beta2.powi(exponent);

        let (lr, beta1, beta2, eps, max_grad) =
            (self.lr, self.beta1, self.beta2, self.eps, self.max_grad);

        let moments = self.m.iter_mut().zip(self.v.iter_mut());
        for ((p, &raw), (m, v)) in params.iter_mut().zip(grads.iter()).zip(moments) {
            if !raw.is_finite() {
                continue;
            }
            let g = if max_grad > 0.0 {
                raw.clamp(-max_grad, max_grad)
            } else {
                raw
            };

            *m = beta1 * *m + (1.0 - beta1) * g;
            *v = beta2 * *v + (1.0 - beta2) * g * g;

            // Bias-corrected moment estimates.
            let m_hat = *m / bc1;
            let v_hat = *v / bc2;
            *p -= lr * m_hat / (v_hat.sqrt() + eps);
        }
    }

    /// Resizes both moment buffers to `new_len`; new entries start at zero.
    ///
    /// This is a plain resize: a `new_len` smaller than the current length
    /// truncates the buffers.
    pub fn grow(&mut self, new_len: usize) {
        self.m.resize(new_len, 0.0);
        self.v.resize(new_len, 0.0);
    }
}
/// Hyper-parameters for a training run.
///
/// NOTE(review): the code that consumes these fields is not in this part of
/// the file, so the per-field descriptions are inferred from the field names
/// and should be confirmed against the training loop.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainConfig {
    /// Learning rate, presumably for positions.
    pub lr_position: f32,
    /// Learning rate, presumably for scales.
    pub lr_scale: f32,
    /// Learning rate, presumably for rotations.
    pub lr_rotation: f32,
    /// Learning rate, presumably for opacities.
    pub lr_opacity: f32,
    /// Learning rate, presumably for spherical-harmonics coefficients (TODO confirm).
    pub lr_sh: f32,
    /// Weight of the L1 term in the combined loss (presumably; the remainder
    /// would go to the DSSIM term).
    pub lambda_l1: f32,
    /// Total number of training iterations (presumably).
    pub iterations: u32,
    /// First iteration at which densification may run (presumably).
    pub densify_from: u32,
    /// Last iteration at which densification may run (presumably).
    pub densify_until: u32,
    /// Number of iterations between densification passes (presumably).
    pub densify_interval: u32,
    /// Logit value used when opacities are reset (presumably).
    pub opacity_reset_logit: f32,
}

impl Default for TrainConfig {
    /// Returns the built-in default hyper-parameters.
    fn default() -> Self {
        Self {
            lr_position: 1.6e-4,
            lr_scale: 5e-3,
            lr_rotation: 1e-3,
            lr_opacity: 5e-2,
            lr_sh: 2.5e-3,
            lambda_l1: 0.8,
            iterations: 7000,
            densify_from: 500,
            densify_until: 5000,
            densify_interval: 100,
            opacity_reset_logit: -2.0,
        }
    }
}
/// Mean absolute error between `rendered` and `target`, with its gradient
/// with respect to `rendered`.
///
/// Returns `(loss, grad)`. The loss is the sum of `|rendered[i] - target[i]|`
/// divided by `rendered.len()`. Each gradient entry is `+1/n`, `-1/n`, or
/// `0.0` when the two values are exactly equal. A NaN difference yields a NaN
/// gradient entry. Elements are paired with `zip`, so if the slices differ in
/// length only the common prefix is visited and `grad` has that length.
///
/// Empty input returns `(0.0, vec![])`.
pub fn l1_loss_grad(rendered: &[f32], target: &[f32]) -> (f32, Vec<f32>) {
    if rendered.is_empty() {
        // Avoids 0/0 = NaN for the loss.
        return (0.0, Vec::new());
    }

    let n = rendered.len() as f32;
    let mut loss = 0.0f32;
    let mut grad = Vec::with_capacity(rendered.len());
    for (&r, &t) in rendered.iter().zip(target.iter()) {
        let diff = r - t;
        loss += diff.abs();
        // `f32::signum` maps +0.0 to 1.0 and -0.0 to -1.0, which would push a
        // perfectly matched value away from its target; use 0 there instead.
        let sign = if diff == 0.0 { 0.0 } else { diff.signum() };
        grad.push(sign / n);
    }
    (loss / n, grad)
}
/// Structural-dissimilarity loss `(1 - mean SSIM) / 2` between two images,
/// with an approximate gradient with respect to `rendered`.
///
/// Both buffers are indexed as `(y * width + x) * 3 + c`, i.e. row-major with
/// three interleaved channels. For every pixel and channel, SSIM is computed
/// over the window of radius 5 around the pixel, cropped at the image border.
///
/// Returns `(dssim, grad)` where `grad` has `rendered.len()` entries.
///
/// The gradient is an approximation, not the exact derivative of the returned
/// loss: for each window it only uses the derivative of SSIM with respect to
/// the rendered window mean, attributes it entirely to the window's centre
/// element, and normalises by the number of pixels.
///
/// If `width` or `height` is zero the result is a zero loss and an all-zero
/// gradient.
///
/// # Panics
///
/// Panics if either buffer holds fewer than `width * height * 3` values.
pub fn dssim_loss_grad(
    rendered: &[f32],
    target: &[f32],
    width: u32,
    height: u32,
) -> (f32, Vec<f32>) {
    // Stabilising constants of the SSIM formula.
    const C1: f32 = 0.01 * 0.01;
    const C2: f32 = 0.03 * 0.03;
    // Half-width of the square window centred on each pixel.
    const RADIUS: usize = 5;
    const CHANNELS: usize = 3;

    let w = width as usize;
    let h = height as usize;
    let mut grad = vec![0.0f32; rendered.len()];

    // Without this guard an empty image divides by zero: the loss becomes NaN
    // and the infinite scale turns every zero gradient entry into NaN.
    if w == 0 || h == 0 {
        return (0.0, grad);
    }

    let samples = w * h * CHANNELS;
    assert!(
        rendered.len() >= samples && target.len() >= samples,
        "dssim_loss_grad: a {}x{} image needs {} values, got rendered = {}, target = {}",
        width,
        height,
        samples,
        rendered.len(),
        target.len()
    );

    let mut total_ssim = 0.0f32;
    for py in 0..h {
        let y0 = py.saturating_sub(RADIUS);
        let y1 = (py + RADIUS + 1).min(h);
        for px in 0..w {
            let x0 = px.saturating_sub(RADIUS);
            let x1 = (px + RADIUS + 1).min(w);
            let area = ((y1 - y0) * (x1 - x0)) as f32;

            for c in 0..CHANNELS {
                // First pass: window means.
                let mut mu_r = 0.0f32;
                let mut mu_t = 0.0f32;
                for y in y0..y1 {
                    for x in x0..x1 {
                        let idx = (y * w + x) * CHANNELS + c;
                        mu_r += rendered[idx];
                        mu_t += target[idx];
                    }
                }
                mu_r /= area;
                mu_t /= area;

                // Second pass: window variances and covariance.
                let mut sig_rr = 0.0f32;
                let mut sig_tt = 0.0f32;
                let mut sig_rt = 0.0f32;
                for y in y0..y1 {
                    for x in x0..x1 {
                        let idx = (y * w + x) * CHANNELS + c;
                        let dr = rendered[idx] - mu_r;
                        let dt = target[idx] - mu_t;
                        sig_rr += dr * dr;
                        sig_tt += dt * dt;
                        sig_rt += dr * dt;
                    }
                }
                sig_rr /= area;
                sig_tt /= area;
                sig_rt /= area;

                let num = (2.0 * mu_r * mu_t + C1) * (2.0 * sig_rt + C2);
                let den = (mu_r * mu_r + mu_t * mu_t + C1) * (sig_rr + sig_tt + C2);
                total_ssim += num / den;

                // Quotient rule on SSIM = num / den, differentiating only
                // through the rendered mean.
                let d_num = 2.0 * mu_t * (2.0 * sig_rt + C2);
                let d_den = 2.0 * mu_r * (sig_rr + sig_tt + C2);
                let d_ssim_d_mu_r = (d_num * den - num * d_den) / (den * den);

                let idx_center = (py * w + px) * CHANNELS + c;
                // Negated because the loss decreases as SSIM increases.
                grad[idx_center] += -d_ssim_d_mu_r / area;
            }
        }
    }

    // One SSIM value was accumulated per (pixel, channel) pair.
    let count = samples as f32;
    let dssim = (1.0 - total_ssim / count) / 2.0;
    // `count / 3` is the number of pixels.
    let scale = 1.0 / (count / 3.0);
    for g in grad.iter_mut() {
        *g *= scale;
    }
    (dssim, grad)
}
/// Weighted blend of the L1 and DSSIM losses and of their gradients.
///
/// Evaluates `l1_loss_grad` and `dssim_loss_grad` on the same pair of images
/// and mixes both the scalar losses and the per-element gradients as
/// `lambda_l1 * l1 + (1 - lambda_l1) * dssim`.
///
/// Returns `(loss, grad)`. The two gradient vectors are paired element by
/// element, so `grad` is as long as the shorter of them.
pub fn combined_loss_grad(
    rendered: &[f32],
    target: &[f32],
    width: u32,
    height: u32,
    lambda_l1: f32,
) -> (f32, Vec<f32>) {
    // Weight left over for the DSSIM term.
    let dssim_weight = 1.0 - lambda_l1;

    let (l1_value, l1_grad) = l1_loss_grad(rendered, target);
    let (dssim_value, dssim_grad) = dssim_loss_grad(rendered, target, width, height);

    let mut blended = Vec::with_capacity(l1_grad.len().min(dssim_grad.len()));
    for (g_l1, g_dssim) in l1_grad.iter().zip(dssim_grad.iter()) {
        blended.push(lambda_l1 * g_l1 + dssim_weight * g_dssim);
    }

    (lambda_l1 * l1_value + dssim_weight * dssim_value, blended)
}