use glam::*;
use half::f16;
/// Describes one storage representation for a Gaussian's 15 SH (presumably
/// spherical-harmonics) coefficients, each coefficient being a `Vec3`.
///
/// An implementor chooses the stored `Field` type and supplies the conversions
/// between that representation and the plain `[Vec3; 15]` form.
pub trait GaussianShConfig {
/// Identifier string for this configuration.
/// NOTE(review): how this string is consumed is not visible in this section.
const FEATURE: &'static str;
/// Stored representation of the coefficients; must be plain-old-data.
type Field: bytemuck::Pod + bytemuck::Zeroable;
/// Converts the 15 coefficients into the stored representation.
fn from_sh(sh: &[Vec3; 15]) -> Self::Field;
/// Converts the stored representation back into 15 coefficients.
/// Implementations may be lossy, or may not support this direction at all.
fn to_sh(field: &Self::Field) -> [Vec3; 15];
}
/// SH configuration that stores the 15 coefficients unmodified as `Vec3` values.
pub struct GaussianShSingleConfig;
impl GaussianShConfig for GaussianShSingleConfig {
const FEATURE: &'static str = "sh_single";
// The field has the same type as the input, so both conversions are plain copies.
type Field = [Vec3; 15];
/// Returns a copy of `sh`.
fn from_sh(sh: &[Vec3; 15]) -> Self::Field {
*sh
}
/// Returns a copy of `field`.
fn to_sh(field: &Self::Field) -> [Vec3; 15] {
*field
}
}
/// SH configuration that stores the coefficients as half-precision floats.
pub struct GaussianShHalfConfig;

impl GaussianShConfig for GaussianShHalfConfig {
    const FEATURE: &'static str = "sh_half";

    /// 45 `f16` components (15 coefficients × 3, in `x, y, z` order) followed by
    /// one zero padding element.
    // NOTE(review): the padding presumably keeps the field size a multiple of
    // 4 bytes (46 × 2 = 92) — confirm against the consumer of this layout.
    type Field = [f16; 3 * 15 + 1];

    /// Converts each component to `f16` and appends the zero padding element.
    ///
    /// The output is written straight into a fixed-size array, so there is no
    /// heap allocation and no runtime length check that could panic.
    fn from_sh(sh: &[Vec3; 15]) -> Self::Field {
        let mut field = [f16::from_f32(0.0); 3 * 15 + 1];
        // `chunks_exact_mut(3)` yields 15 full triples; the trailing padding
        // element is left over as the remainder and therefore stays zero.
        for (dst, coeff) in field.chunks_exact_mut(3).zip(sh) {
            dst.copy_from_slice(&coeff.to_array().map(f16::from_f32));
        }
        field
    }

    /// Widens the first 45 stored components back to `f32`, ignoring the padding.
    fn to_sh(field: &Self::Field) -> [Vec3; 15] {
        std::array::from_fn(|i| {
            let base = 3 * i;
            Vec3::new(
                f16::to_f32(field[base]),
                f16::to_f32(field[base + 1]),
                f16::to_f32(field[base + 2]),
            )
        })
    }
}
/// SH configuration that stores the coefficients as signed-normalized 8-bit
/// integers, i.e. each component in `[-1, 1]` is mapped to `[-127, 127]`.
pub struct GaussianShNorm8Config;

impl GaussianShConfig for GaussianShNorm8Config {
    const FEATURE: &'static str = "sh_norm8";

    /// 45 `i8` components (15 coefficients × 3, in `x, y, z` order) followed by
    /// three zero padding bytes.
    // NOTE(review): the padding presumably rounds the field up to 48 bytes so it
    // can be read as 4-byte words — confirm against the consumer of this layout.
    type Field = [i8; 3 * 15 + 3];

    /// Quantizes each component to snorm8.
    ///
    /// Components are clamped to `[-1, 1]` and rounded to the nearest step with
    /// `floor(0.5 + 127 * v)`, the same formula WGSL's `pack4x8snorm` uses. A
    /// plain `as i8` cast would truncate toward zero, which biases every value
    /// toward 0 and doubles the worst-case quantization error. NaN components
    /// are stored as 0.
    fn from_sh(sh: &[Vec3; 15]) -> Self::Field {
        let quantize = |v: f32| (0.5 + 127.0 * v.clamp(-1.0, 1.0)).floor() as i8;

        let mut field = [0i8; 3 * 15 + 3];
        // 48 bytes form 16 triples; zipping with the 15 coefficients leaves the
        // last triple (the padding) untouched at zero.
        for (dst, coeff) in field.chunks_exact_mut(3).zip(sh) {
            dst.copy_from_slice(&coeff.to_array().map(quantize));
        }
        field
    }

    /// Dequantizes the first 45 stored bytes, ignoring the padding.
    ///
    /// Each byte is mapped with `max(q / 127, -1)`, so the otherwise
    /// out-of-range value `-128` decodes to exactly `-1.0`.
    fn to_sh(field: &Self::Field) -> [Vec3; 15] {
        let dequantize = |q: i8| (f32::from(q) / 127.0).max(-1.0);

        std::array::from_fn(|i| {
            let base = 3 * i;
            Vec3::new(
                dequantize(field[base]),
                dequantize(field[base + 1]),
                dequantize(field[base + 2]),
            )
        })
    }
}
/// SH configuration that stores no coefficients at all.
pub struct GaussianShNoneConfig;
impl GaussianShConfig for GaussianShNoneConfig {
const FEATURE: &'static str = "sh_none";
// Nothing is stored, so the field is the unit type.
type Field = ();
/// Discards the coefficients and returns the unit value.
fn from_sh(_sh: &[Vec3; 15]) -> Self::Field {}
/// Not supported: the coefficients were never stored, so they cannot be recovered.
///
/// # Panics
///
/// Always panics.
fn to_sh(_field: &Self::Field) -> [Vec3; 15] {
panic!("Cannot convert from SH None configuration")
}
}
/// Describes one storage representation for a Gaussian's 3D covariance, which
/// is supplied as a rotation quaternion plus a per-axis scale.
///
/// An implementor chooses the stored `Field` type and supplies the conversions
/// between that representation and the `(rotation, scale)` pair.
pub trait GaussianCov3dConfig {
/// Identifier string for this configuration.
/// NOTE(review): how this string is consumed is not visible in this section.
const FEATURE: &'static str;
/// Stored representation of the covariance; must be plain-old-data.
type Field: bytemuck::Pod + bytemuck::Zeroable;
/// Converts a rotation and scale into the stored representation.
fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field;
/// Converts the stored representation back into a rotation and scale.
/// Implementations may not support this direction at all.
fn to_rot_scale(field: &Self::Field) -> (Quat, Vec3);
}
/// Covariance configuration that stores the rotation quaternion and the scale
/// directly, without computing a covariance matrix.
pub struct GaussianCov3dRotScaleConfig;

impl GaussianCov3dConfig for GaussianCov3dRotScaleConfig {
    const FEATURE: &'static str = "cov3d_rot_scale";

    /// Layout: `[rot.x, rot.y, rot.z, rot.w, scale.x, scale.y, scale.z]`.
    type Field = [f32; 7];

    /// Packs the quaternion components followed by the scale components.
    fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field {
        let [qx, qy, qz, qw] = rot.to_array();
        let [sx, sy, sz] = scale.to_array();
        [qx, qy, qz, qw, sx, sy, sz]
    }

    /// Unpacks the stored components; the quaternion is rebuilt as-is, without
    /// any normalization.
    fn to_rot_scale(field: &Self::Field) -> (Quat, Vec3) {
        let [qx, qy, qz, qw, sx, sy, sz] = *field;
        (Quat::from_xyzw(qx, qy, qz, qw), Vec3::new(sx, sy, sz))
    }
}
/// Covariance configuration that stores the six unique entries of the symmetric
/// 3×3 covariance matrix as `f32`.
pub struct GaussianCov3dSingleConfig;

impl GaussianCov3dConfig for GaussianCov3dSingleConfig {
    const FEATURE: &'static str = "cov3d_single";

    /// Entries of the covariance taken from its columns, in the order
    /// `[x.x, x.y, x.z, y.y, y.z, z.z]`.
    type Field = [f32; 6];

    /// Computes `Σ = M · Mᵀ` with `M = R · S`, where `R` is the rotation matrix
    /// of `rot` and `S` the diagonal matrix of `scale`, then keeps one copy of
    /// each distinct entry (the matrix is symmetric).
    fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field {
        let rotation = Mat3::from_quat(rot);
        let scaling = Mat3::from_diagonal(scale);
        let transform = rotation * scaling;
        let covariance = transform * transform.transpose();

        let (col_x, col_y, col_z) = (covariance.x_axis, covariance.y_axis, covariance.z_axis);
        [col_x.x, col_x.y, col_x.z, col_y.y, col_y.z, col_z.z]
    }

    /// Not supported: this configuration does not recover the rotation and scale
    /// from the stored covariance.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn to_rot_scale(_field: &Self::Field) -> (Quat, Vec3) {
        panic!("Cannot convert from Cov3d Single configuration")
    }
}
pub struct GaussianCov3dHalfConfig;
impl GaussianCov3dConfig for GaussianCov3dHalfConfig {
const FEATURE: &'static str = "cov3d_half";
type Field = [f16; 6];
fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field {
GaussianCov3dSingleConfig::from_rot_scale(rot, scale).map(f16::from_f32)
}
fn to_rot_scale(_field: &Self::Field) -> (Quat, Vec3) {
panic!("Cannot convert from Cov3d Half configuration")
}
}