use super::camera::Camera;
/// Back-propagates per-Gaussian gradients of the projected 2D conics and 2D
/// means to world-space positions, log-scales and raw rotation quaternions.
///
/// For each Gaussian the forward quantities are recomputed:
///
/// * `Σ = (R·S)(R·S)ᵀ`, with `R` built from the normalized quaternion and
///   `S = diag(exp(log_scale))`;
/// * `T = J·W`, where `W` is the rotation part of the view matrix and `J` is
///   `[[j00, 0, j02], [0, j11, j12]]`;
/// * the 2D covariance `[[a, b], [b, c]] = T Σ Tᵀ`, plus a `0.3` dilation on
///   the diagonal.
///
/// The conic gradient formulas below correspond to a conic stored as
/// `(c, -b, a) / det`. The 2D-mean gradients correspond to
/// `u = fx·x/depth`, `v = -fy·y/depth` (up to constant offsets), with
/// `depth = -z` in camera space. Both must agree with the forward pass, which
/// is not part of this function.
///
/// Gaussians with `radii[i] == 0`, non-positive depth, or a non-positive 2D
/// determinant are skipped and keep zero gradients.
///
/// # Arguments
/// * `positions` – centres that are transformed by `camera.view_matrix`.
/// * `log_scales` – per-axis scales in log space (`exp` is applied here).
/// * `rotations` – quaternions in `[w, x, y, z]` order. They are normalized
///   internally, so they need not be unit length.
/// * `camera` – provides the column-major `view_matrix` and the focal lengths
///   `intrinsics.fx` / `intrinsics.fy`.
/// * `radii` – a value of `0` marks a Gaussian to skip.
/// * `dL_dconics`, `dL_dmeans2d` – upstream gradients, one entry per Gaussian.
///
/// # Returns
/// `(dL_dpos, dL_dlog_scales, dL_drot)`, each of length `positions.len()`.
/// `dL_drot` is the gradient w.r.t. the raw entries of `rotations`, so it
/// includes the Jacobian of the internal normalization.
///
/// # Panics
/// Panics if `log_scales`, `rotations`, `radii`, `dL_dconics` or
/// `dL_dmeans2d` has fewer entries than `positions`.
// Gradient names follow the `dL_dx` math notation used throughout this file.
#[allow(non_snake_case)]
pub fn backward_projection(
    positions: &[[f32; 3]],
    log_scales: &[[f32; 3]],
    rotations: &[[f32; 4]],
    camera: &Camera,
    radii: &[u32],
    dL_dconics: &[[f32; 3]],
    dL_dmeans2d: &[[f32; 2]],
) -> (Vec<[f32; 3]>, Vec<[f32; 3]>, Vec<[f32; 4]>) {
    // Isotropic dilation added to the diagonal of the projected 2D covariance.
    // It must equal the value used by the forward projection.
    const COV2D_DILATION: f32 = 0.3;

    let n = positions.len();

    // Validate every per-Gaussian input once. A mismatched buffer then fails
    // with a descriptive message rather than an index-out-of-bounds in the loop.
    let require_len = |len: usize, name: &str| {
        assert!(
            len >= n,
            "backward_projection: `{}` has {} entries but `positions` has {}",
            name,
            len,
            n
        );
    };
    require_len(log_scales.len(), "log_scales");
    require_len(rotations.len(), "rotations");
    require_len(radii.len(), "radii");
    require_len(dL_dconics.len(), "dL_dconics");
    require_len(dL_dmeans2d.len(), "dL_dmeans2d");

    let mut dL_dpos = vec![[0.0f32; 3]; n];
    let mut dL_dscale = vec![[0.0f32; 3]; n];
    let mut dL_drot = vec![[0.0f32; 4]; n];

    // Upper-left 3x3 of the column-major view matrix, re-laid out so that
    // `w[r]` is row `r` of the rotation applied to positions.
    let view = &camera.view_matrix;
    let w = [
        [view[0], view[4], view[8]],
        [view[1], view[5], view[9]],
        [view[2], view[6], view[10]],
    ];
    let fx = camera.intrinsics.fx;
    let fy = camera.intrinsics.fy;

    for i in 0..n {
        if radii[i] == 0 {
            continue;
        }

        // ---- Recompute forward quantities ------------------------------
        let pos = positions[i];
        let s = [
            log_scales[i][0].exp(),
            log_scales[i][1].exp(),
            log_scales[i][2].exp(),
        ];
        let q = normalize4(rotations[i]);
        let pos_cam = mat4x4_transform(view, pos);
        // Depth is measured along -z of camera space; skip points behind it.
        let depth = -pos_cam[2];
        if depth <= 0.0 {
            continue;
        }
        let z2 = depth * depth;

        // M = R·S and Σ = M·Mᵀ.
        let r = quat_to_mat3(q);
        let rs = [
            [r[0][0] * s[0], r[0][1] * s[1], r[0][2] * s[2]],
            [r[1][0] * s[0], r[1][1] * s[1], r[1][2] * s[2]],
            [r[2][0] * s[0], r[2][1] * s[1], r[2][2] * s[2]],
        ];
        let cov3d = mat3_mul_transpose(rs, rs);

        // Non-zero entries of the projection Jacobian J.
        let j00 = fx / depth;
        let j11 = fy / depth;
        let j02 = -fx * pos_cam[0] / z2;
        let j12 = -fy * pos_cam[1] / z2;

        // Rows of T = J·W.
        let t0 = [
            j00 * w[0][0] + j02 * w[2][0],
            j00 * w[0][1] + j02 * w[2][1],
            j00 * w[0][2] + j02 * w[2][2],
        ];
        let t1 = [
            j11 * w[1][0] + j12 * w[2][0],
            j11 * w[1][1] + j12 * w[2][1],
            j11 * w[1][2] + j12 * w[2][2],
        ];

        // 2D covariance [[a, b], [b, c]] = T Σ Tᵀ + dilation·I.
        let v0 = mat3_vec(cov3d, t0);
        let v1 = mat3_vec(cov3d, t1);
        let a = dot3(t0, v0) + COV2D_DILATION;
        let b = dot3(t0, v1);
        let c = dot3(t1, v1) + COV2D_DILATION;
        let det = a * c - b * b;
        if det <= 0.0 {
            continue;
        }

        // ---- Conic (c, -b, a)/det  ->  (a, b, c) -----------------------
        let det2_inv = 1.0 / (det * det);
        let dL_cx = dL_dconics[i][0];
        let dL_cy = dL_dconics[i][1];
        let dL_cz = dL_dconics[i][2];
        let dL_da = det2_inv * (-c * c * dL_cx + b * c * dL_cy - b * b * dL_cz);
        let dL_db =
            det2_inv * (2.0 * b * c * dL_cx - (a * c + b * b) * dL_cy + 2.0 * a * b * dL_cz);
        let dL_dc_val = det2_inv * (-b * b * dL_cx + a * b * dL_cy - a * a * dL_cz);

        // ---- (a, b, c)  ->  Σ  (not symmetrized; handled below) --------
        let mut dL_dcov3d = [[0.0f32; 3]; 3];
        for p in 0..3 {
            for qq in 0..3 {
                dL_dcov3d[p][qq] = dL_da * t0[p] * t0[qq]
                    + dL_db * t0[p] * t1[qq]
                    + dL_dc_val * t1[p] * t1[qq];
            }
        }

        // ---- Σ = M·Mᵀ  ->  M:  dL/dM = (G + Gᵀ)·M ---------------------
        let mut dL_drs = [[0.0f32; 3]; 3];
        for row in 0..3 {
            for col in 0..3 {
                for k in 0..3 {
                    dL_drs[row][col] += (dL_dcov3d[row][k] + dL_dcov3d[k][row]) * rs[k][col];
                }
            }
        }

        // ---- M = R·diag(s)  ->  log-scales (ds/dlog_s = s) -------------
        for j_idx in 0..3 {
            let mut ds = 0.0;
            for k in 0..3 {
                ds += r[k][j_idx] * dL_drs[k][j_idx];
            }
            dL_dscale[i][j_idx] = ds * s[j_idx];
        }

        // ---- M = R·diag(s)  ->  R  ->  raw quaternion ------------------
        let mut dL_dr = [[0.0f32; 3]; 3];
        for row in 0..3 {
            for col in 0..3 {
                dL_dr[row][col] = dL_drs[row][col] * s[col];
            }
        }
        // `dR_dquat` yields the gradient w.r.t. the *normalized* quaternion
        // `q`. Chain it through `normalize4` to reach the raw `rotations[i]`.
        dL_drot[i] = normalize4_backward(rotations[i], dR_dquat(q, dL_dr));

        // ---- (a, b, c)  ->  T rows  ->  Jacobian entries ---------------
        let mut dL_dt0 = [0.0f32; 3];
        let mut dL_dt1 = [0.0f32; 3];
        for m in 0..3 {
            dL_dt0[m] = 2.0 * dL_da * v0[m] + dL_db * v1[m];
            dL_dt1[m] = dL_db * v0[m] + 2.0 * dL_dc_val * v1[m];
        }
        let dL_dj00 = dot3(dL_dt0, w[0]);
        let dL_dj11 = dot3(dL_dt1, w[1]);
        let dL_dj02 = dot3(dL_dt0, w[2]);
        let dL_dj12 = dot3(dL_dt1, w[2]);

        // ---- Jacobian entries  ->  camera-space position ---------------
        // Note d(depth)/dz = -1, which flips the sign of the z derivatives.
        let mut dL_dpos_cam = [0.0f32; 3];
        dL_dpos_cam[0] += dL_dj02 * (-fx / z2);
        dL_dpos_cam[1] += dL_dj12 * (-fy / z2);
        dL_dpos_cam[2] += dL_dj00 * (fx / z2)
            + dL_dj11 * (fy / z2)
            + dL_dj02 * (2.0 * j02 / depth)
            + dL_dj12 * (2.0 * j12 / depth);

        // ---- 2D mean  ->  camera-space position ------------------------
        let dL_dm = dL_dmeans2d[i];
        dL_dpos_cam[0] += dL_dm[0] * (fx / depth);
        dL_dpos_cam[1] += dL_dm[1] * (-fy / depth);
        dL_dpos_cam[2] +=
            dL_dm[0] * (fx * pos_cam[0] / z2) + dL_dm[1] * (-fy * pos_cam[1] / z2);

        // ---- Camera space  ->  input space (multiply by Wᵀ) ------------
        for j_idx in 0..3 {
            for k in 0..3 {
                dL_dpos[i][j_idx] += w[k][j_idx] * dL_dpos_cam[k];
            }
        }
    }

    (dL_dpos, dL_dscale, dL_drot)
}

/// Vector-Jacobian product of `normalize4`: maps a gradient `grad` w.r.t.
/// the unit quaternion `q = raw / |raw|` to a gradient w.r.t. `raw`.
///
/// The result is `(grad - (grad·q) q) / |raw|`. The component of `grad`
/// along `q` is removed because rescaling `raw` leaves `q` unchanged. Below
/// the same `1e-10` length threshold that `normalize4` uses, its output is a
/// constant, so the gradient is zero there.
fn normalize4_backward(raw: [f32; 4], grad: [f32; 4]) -> [f32; 4] {
    let len = (raw[0] * raw[0] + raw[1] * raw[1] + raw[2] * raw[2] + raw[3] * raw[3]).sqrt();
    if len < 1e-10 {
        return [0.0; 4];
    }
    let q = [raw[0] / len, raw[1] / len, raw[2] / len, raw[3] / len];
    let radial = grad[0] * q[0] + grad[1] * q[1] + grad[2] * q[2] + grad[3] * q[3];
    [
        (grad[0] - radial * q[0]) / len,
        (grad[1] - radial * q[1]) / len,
        (grad[2] - radial * q[2]) / len,
        (grad[3] - radial * q[3]) / len,
    ]
}
/// Scales the quaternion `q` to unit length.
///
/// If its length is below `1e-10`, the identity quaternion `[1, 0, 0, 0]` is
/// returned instead of dividing by a (near-)zero norm.
fn normalize4(q: [f32; 4]) -> [f32; 4] {
    let [q0, q1, q2, q3] = q;
    let norm = (q0 * q0 + q1 * q1 + q2 * q2 + q3 * q3).sqrt();
    if norm < 1e-10 {
        [1.0, 0.0, 0.0, 0.0]
    } else {
        q.map(|component| component / norm)
    }
}
/// Transforms the point `p` by the column-major 4x4 matrix `m`, treating `p`
/// as homogeneous with `w = 1`.
///
/// Only the first three output rows are produced; the fourth (projective)
/// row of `m` is never read.
fn mat4x4_transform(m: &[f32; 16], p: [f32; 3]) -> [f32; 3] {
    let [px, py, pz] = p;
    // Column `c` of `m` occupies indices `4c..4c + 4`.
    [0usize, 1, 2].map(|row| m[row] * px + m[row + 4] * py + m[row + 8] * pz + m[row + 12])
}
/// Builds a 3x3 rotation matrix from the quaternion `q = [w, x, y, z]`.
///
/// No normalization is done here; the result is a rotation only for a unit
/// quaternion. Entry `[0][1]` is `2(xy + wz)`, so this layout is the
/// transpose of the textbook matrix whose `[0][1]` is `2(xy - wz)`.
fn quat_to_mat3(q: [f32; 4]) -> [[f32; 3]; 3] {
    let [w, x, y, z] = q;
    // Doubled components, so every product below already carries the factor 2.
    let (dx, dy, dz) = (x + x, y + y, z + z);
    let (sq_x, sq_y, sq_z) = (x * dx, y * dy, z * dz);
    let (x_y, x_z, y_z) = (x * dy, x * dz, y * dz);
    let (w_x, w_y, w_z) = (w * dx, w * dy, w * dz);

    let row0 = [1.0 - sq_y - sq_z, x_y + w_z, x_z - w_y];
    let row1 = [x_y - w_z, 1.0 - sq_x - sq_z, y_z + w_x];
    let row2 = [x_z + w_y, y_z - w_x, 1.0 - sq_x - sq_y];
    [row0, row1, row2]
}
/// Back-propagates a gradient w.r.t. the matrix produced by `quat_to_mat3`
/// to the quaternion that was passed to it.
///
/// `q` is `[w, x, y, z]`, the same order `quat_to_mat3` reads.
/// `dL_dR[r][c]` is the gradient of the loss w.r.t. entry `R[r][c]`. Each
/// returned component `k` is the sum over all entries of
/// `dL_dR[r][c] * dR[r][c]/dq_k`. The partial derivatives of
/// `quat_to_mat3`'s entries are written out by hand, and entries whose
/// derivative is zero are omitted.
///
/// Returns `[dL/dw, dL/dx, dL/dy, dL/dz]`. This is the gradient w.r.t. `q`
/// exactly as given. If `q` came from normalizing another quaternion, the
/// normalization Jacobian is not included here.
// Non-snake-case names mirror the math notation (`dL_dR`).
#[allow(non_snake_case)]
fn dR_dquat(q: [f32; 4], dL_dR: [[f32; 3]; 3]) -> [f32; 4] {
let (w, x, y, z) = (q[0], q[1], q[2], q[3]);
let m = dL_dR;
// dR/dw: only the off-diagonal entries depend on w (the diagonals do not).
let dL_dw = 2.0
* (m[0][1] * z + m[0][2] * (-y) + m[1][0] * (-z) + m[1][2] * x + m[2][0] * y
+ m[2][1] * (-x));
// dR/dx: every entry except R[0][0] (= 1 - 2y² - 2z²) depends on x.
let dL_dx = 2.0
* (m[0][1] * y + m[0][2] * z + m[1][0] * y + m[1][1] * (-2.0 * x) + m[1][2] * w
+ m[2][0] * z
+ m[2][1] * (-w)
+ m[2][2] * (-2.0 * x));
// dR/dy: every entry except R[1][1] (= 1 - 2x² - 2z²) depends on y.
let dL_dy = 2.0
* (m[0][0] * (-2.0 * y) + m[0][1] * x + m[0][2] * (-w) + m[1][0] * x + m[1][2] * z
+ m[2][0] * w
+ m[2][1] * z
+ m[2][2] * (-2.0 * y));
// dR/dz: every entry except R[2][2] (= 1 - 2x² - 2y²) depends on z.
let dL_dz = 2.0
* (m[0][0] * (-2.0 * z) + m[0][1] * w + m[0][2] * x + m[1][0] * (-w)
+ m[1][1] * (-2.0 * z)
+ m[1][2] * y
+ m[2][0] * x
+ m[2][1] * y);
[dL_dw, dL_dx, dL_dy, dL_dz]
}
/// Dot product of two 3-vectors, summed left to right.
fn dot3(a: [f32; 3], b: [f32; 3]) -> f32 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}
/// Multiplies the 3x3 matrix `m` (indexed `m[row][col]`) by the column
/// vector `v`.
fn mat3_vec(m: [[f32; 3]; 3], v: [f32; 3]) -> [f32; 3] {
    let [vx, vy, vz] = v;
    m.map(|row| row[0] * vx + row[1] * vy + row[2] * vz)
}
/// Computes `a · bᵀ`: entry `[i][j]` is the dot product of row `i` of `a`
/// with row `j` of `b`, accumulated from `0.0`.
fn mat3_mul_transpose(a: [[f32; 3]; 3], b: [[f32; 3]; 3]) -> [[f32; 3]; 3] {
    a.map(|a_row| {
        b.map(|b_row| {
            a_row
                .iter()
                .zip(b_row.iter())
                .fold(0.0f32, |acc, (lhs, rhs)| acc + lhs * rhs)
        })
    })
}