/// A small fully connected feed-forward network (multi-layer perceptron).
///
/// Layer `i` maps `dims[i].0` inputs to `dims[i].1` outputs. Its weights live in
/// `weights[i]` as a row-major `out_dim x in_dim` matrix (entry `[j * in_dim + k]`
/// connects input `k` to output `j`), and its biases in `biases[i]` (`out_dim` values).
#[derive(Debug, Clone)]
pub struct Mlp {
    /// Row-major weight matrix of each layer.
    pub weights: Vec<Vec<f32>>,
    /// Bias vector of each layer.
    pub biases: Vec<Vec<f32>>,
    /// `(in_dim, out_dim)` of each layer.
    pub dims: Vec<(usize, usize)>,
}
impl Mlp {
pub fn new(layer_sizes: &[usize]) -> Self {
let mut weights = Vec::new();
let mut biases = Vec::new();
let mut dims = Vec::new();
let mut seed = 42u64;
let mut rand = || -> f32 {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
let bits = ((seed >> 33) as u32) as f32 / u32::MAX as f32;
(bits - 0.5) * 2.0
};
for i in 0..layer_sizes.len() - 1 {
let in_dim = layer_sizes[i];
let out_dim = layer_sizes[i + 1];
dims.push((in_dim, out_dim));
let scale = (2.0 / in_dim as f32).sqrt();
let w: Vec<f32> = (0..out_dim * in_dim).map(|_| rand() * scale).collect();
let b: Vec<f32> = vec![0.0; out_dim];
weights.push(w);
biases.push(b);
}
Self {
weights,
biases,
dims,
}
}
pub fn forward(&self, input: &[f32]) -> Vec<f32> {
let mut x = input.to_vec();
for (i, ((in_dim, out_dim), (w, b))) in
self.dims.iter().zip(self.weights.iter().zip(self.biases.iter())).enumerate()
{
let mut y = vec![0.0f32; *out_dim];
for j in 0..*out_dim {
let mut sum = b[j];
for k in 0..*in_dim {
sum += w[j * in_dim + k] * x[k];
}
if i < self.dims.len() - 1 {
sum = sum.max(0.0);
}
y[j] = sum;
}
x = y;
}
x
}
pub fn num_params(&self) -> usize {
self.weights.iter().map(|w| w.len()).sum::<usize>()
+ self.biases.iter().map(|b| b.len()).sum::<usize>()
}
pub fn params_flat(&self) -> Vec<f32> {
let mut out = Vec::with_capacity(self.num_params());
for (w, b) in self.weights.iter().zip(self.biases.iter()) {
out.extend_from_slice(w);
out.extend_from_slice(b);
}
out
}
pub fn set_params_flat(&mut self, params: &[f32]) {
let mut offset = 0;
for (w, b) in self.weights.iter_mut().zip(self.biases.iter_mut()) {
let wlen = w.len();
w.copy_from_slice(¶ms[offset..offset + wlen]);
offset += wlen;
let blen = b.len();
b.copy_from_slice(¶ms[offset..offset + blen]);
offset += blen;
}
}
}
/// Width of the canonical MLP output per anchor: 3 position-offset values, 3 scale
/// values, 4 quaternion components, 1 opacity value and 3 SH coefficients.
pub const CANONICAL_OUTPUT_DIM: usize = 14;

/// Predicts per-anchor Gaussian attributes from feature vectors using a small MLP.
pub struct CanonicalPredictor {
    /// Network mapping one feature vector to `CANONICAL_OUTPUT_DIM` raw outputs.
    pub mlp: Mlp,
}
impl CanonicalPredictor {
    /// Creates a predictor for `feature_dim`-wide features.
    ///
    /// The MLP has two hidden layers of 64 ReLU units and outputs
    /// `CANONICAL_OUTPUT_DIM` values. The output-layer biases are seeded so that an
    /// untrained network starts from a fixed default:
    /// - channels 3..6 (scales) are set to -3.0;
    /// - channel 6 (quaternion `w`) is set to 1.0, which biases rotations toward identity;
    /// - channel 10 (opacity) is set to 2.0.
    ///
    /// NOTE(review): -3.0 and 2.0 look like log-scale and pre-sigmoid-logit values; the
    /// consumer of `PredictedGaussians` defines the activation — confirm.
    pub fn new(feature_dim: usize) -> Self {
        let mut mlp = Mlp::new(&[feature_dim, 64, 64, CANONICAL_OUTPUT_DIM]);
        let out_bias = mlp
            .biases
            .last_mut()
            .expect("an Mlp built from four layer sizes always has an output layer");
        out_bias[3..6].fill(-3.0);
        out_bias[6] = 1.0;
        out_bias[10] = 2.0;
        Self { mlp }
    }

    /// Predicts one Gaussian per anchor from that anchor's feature vector.
    ///
    /// `features` holds the per-anchor feature vectors back to back, `feature_dim` values
    /// each, in the same order as `anchor_positions`. Each MLP output is decoded as:
    /// - `[0..3]`: position offset, scaled by 0.1 and added to the anchor;
    /// - `[3..6]`: scales, passed through unchanged;
    /// - `[6..10]`: quaternion `(w, x, y, z)`, normalised to unit length. A near-zero
    ///   quaternion falls back to the identity rotation;
    /// - `[10]`: opacity, passed through unchanged;
    /// - `[11..14]`: three SH coefficients, appended to the flat `sh_coeffs` buffer.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - `feature_dim` differs from the MLP's input width;
    /// - `features` holds fewer than `anchor_positions.len() * feature_dim` values.
    pub fn predict(
        &self,
        features: &[f32],
        feature_dim: usize,
        anchor_positions: &[[f32; 3]],
    ) -> PredictedGaussians {
        // Raw network offsets are damped by this factor before being added to the anchor.
        const OFFSET_SCALE: f32 = 0.1;
        // Quaternions with a norm below this are treated as degenerate.
        const MIN_QUAT_NORM: f32 = 1e-8;

        let n = anchor_positions.len();
        if let Some(&(in_dim, _)) = self.mlp.dims.first() {
            assert_eq!(
                feature_dim, in_dim,
                "CanonicalPredictor::predict: feature_dim must equal the MLP input width"
            );
        }
        assert!(
            features.len() >= n * feature_dim,
            "CanonicalPredictor::predict: need {} feature values for {} anchors, got {}",
            n * feature_dim,
            n,
            features.len()
        );

        let mut positions = Vec::with_capacity(n);
        let mut scales = Vec::with_capacity(n);
        let mut rotations = Vec::with_capacity(n);
        let mut opacities = Vec::with_capacity(n);
        let mut sh_coeffs = Vec::with_capacity(n * 3);
        for (i, anchor) in anchor_positions.iter().enumerate() {
            let feat = &features[i * feature_dim..(i + 1) * feature_dim];
            let out = self.mlp.forward(feat);
            positions.push([
                anchor[0] + out[0] * OFFSET_SCALE,
                anchor[1] + out[1] * OFFSET_SCALE,
                anchor[2] + out[2] * OFFSET_SCALE,
            ]);
            scales.push([out[3], out[4], out[5]]);
            let (qw, qx, qy, qz) = (out[6], out[7], out[8], out[9]);
            let qnorm = (qw * qw + qx * qx + qy * qy + qz * qz).sqrt();
            // A (near-)zero quaternion has no meaningful direction, so emit the identity
            // rotation rather than a non-unit quaternion. A NaN norm fails the comparison
            // and propagates, so it stays visible.
            rotations.push(if qnorm < MIN_QUAT_NORM {
                [1.0, 0.0, 0.0, 0.0]
            } else {
                [qw / qnorm, qx / qnorm, qy / qnorm, qz / qnorm]
            });
            opacities.push(out[10]);
            sh_coeffs.extend_from_slice(&out[11..14]);
        }
        PredictedGaussians {
            positions,
            scales,
            rotations,
            opacities,
            sh_coeffs,
        }
    }
}
/// Per-Gaussian attributes, one entry per anchor, as produced by
/// `CanonicalPredictor::predict`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct PredictedGaussians {
    /// Anchor position plus the predicted offset.
    pub positions: Vec<[f32; 3]>,
    /// Raw per-axis scale outputs of the network.
    /// NOTE(review): possibly log-scale; confirm with the consumer.
    pub scales: Vec<[f32; 3]>,
    /// Normalised rotation quaternions `(w, x, y, z)`.
    pub rotations: Vec<[f32; 4]>,
    /// Raw opacity outputs of the network.
    pub opacities: Vec<f32>,
    /// Flattened SH coefficients, three consecutive values per Gaussian.
    pub sh_coeffs: Vec<f32>,
}
impl PredictedGaussians {
    /// Copies these predictions into a new `crate::splatting::GaussianCloud`.
    ///
    /// `count` is the number of positions, and `sh_degree` is always 0 here.
    /// NOTE(review): presumably degree 0 means DC-only SH (three values per Gaussian,
    /// matching `sh_coeffs`); confirm against `GaussianCloud`.
    pub fn to_cloud(&self) -> crate::splatting::GaussianCloud {
        let Self {
            positions,
            scales,
            rotations,
            opacities,
            sh_coeffs,
        } = self;
        crate::splatting::GaussianCloud {
            sh_degree: 0,
            count: positions.len(),
            positions: positions.clone(),
            scales: scales.clone(),
            rotations: rotations.clone(),
            opacities: opacities.clone(),
            sh_coeffs: sh_coeffs.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The MLP emits one value per unit of its final layer.
    #[test]
    fn test_mlp_forward() {
        let net = Mlp::new(&[4, 8, 3]);
        let out = net.forward(&[1.0, 0.5, -0.5, 0.0]);
        assert_eq!(out.len(), 3);
    }

    /// The predictor yields one position and three SH values per anchor.
    #[test]
    fn test_canonical_predict() {
        const FEATURE_DIM: usize = 32;
        let predictor = CanonicalPredictor::new(FEATURE_DIM);
        let anchors: [[f32; 3]; 3] = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let features = vec![0.1f32; FEATURE_DIM * anchors.len()];

        let result = predictor.predict(&features, FEATURE_DIM, &anchors);

        assert_eq!(result.positions.len(), anchors.len());
        assert_eq!(result.sh_coeffs.len(), 3 * anchors.len());
    }
}