use super::projection::{Projection, ProjectionId};
/// Number of `f32` components in a splat embedding vector.
pub const EMBEDDING_DIM: usize = 384;
/// Serialized size of one splat embedding in bytes: one 4-byte `f32` per component.
const SPLAT_BYTES: usize = EMBEDDING_DIM * std::mem::size_of::<f32>();
/// Projection holding a fixed-length embedding vector of `EMBEDDING_DIM` `f32` components.
///
/// `PartialEq` compares components with IEEE-754 float semantics, so a projection that
/// contains a `NaN` component never compares equal, not even to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct SplatProjection {
    /// The embedding components.
    pub embedding: [f32; EMBEDDING_DIM],
}
impl Default for SplatProjection {
fn default() -> Self {
Self {
embedding: [0.0; EMBEDDING_DIM],
}
}
}
impl Projection for SplatProjection {
    /// Serialized size in bytes: `EMBEDDING_DIM` values of 4 bytes each (`SPLAT_BYTES`).
    fn byte_size() -> usize {
        SPLAT_BYTES
    }

    /// Identifies this projection kind as `ProjectionId::Splat`.
    fn id() -> ProjectionId {
        ProjectionId::Splat
    }

    /// Decodes an embedding from the first `SPLAT_BYTES` bytes of `buf`.
    ///
    /// Component `i` is read as a little-endian `f32` from bytes `4*i .. 4*i + 4`.
    /// Any bytes past `SPLAT_BYTES` are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `buf.len() < SPLAT_BYTES`.
    fn read(buf: &[u8]) -> Self {
        assert!(buf.len() >= SPLAT_BYTES, "SplatProjection: buffer too small");
        let mut embedding = [0.0f32; EMBEDDING_DIM];
        // chunks_exact guarantees 4-byte chunks, so the indexing below cannot go out of
        // bounds, and the zip pairs each component with exactly one chunk.
        for (value, chunk) in embedding
            .iter_mut()
            .zip(buf[..SPLAT_BYTES].chunks_exact(4))
        {
            *value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
        }
        Self { embedding }
    }

    /// Encodes the embedding into the first `SPLAT_BYTES` bytes of `buf`.
    ///
    /// Each component is written as a little-endian `f32`. Bytes past `SPLAT_BYTES`
    /// are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `buf.len() < SPLAT_BYTES`.
    fn write(&self, buf: &mut [u8]) {
        assert!(buf.len() >= SPLAT_BYTES, "SplatProjection: buffer too small");
        for (chunk, value) in buf[..SPLAT_BYTES]
            .chunks_exact_mut(4)
            .zip(self.embedding.iter())
        {
            chunk.copy_from_slice(&value.to_le_bytes());
        }
    }

    /// 32-bit FNV-1a hash of the first `min(16, EMBEDDING_DIM)` components.
    ///
    /// Each component's raw IEEE-754 bit pattern is hashed as 4 little-endian bytes.
    /// Because raw bits are used, `0.0` and `-0.0` (and `NaN`s with different payloads)
    /// hash differently. Components past the prefix do not affect the result.
    fn shape_hash_contribution(&self) -> u32 {
        // Standard 32-bit FNV-1a parameters.
        const FNV_OFFSET_BASIS: u32 = 0x811c_9dc5;
        const FNV_PRIME: u32 = 0x0100_0193;

        self.embedding[..EMBEDDING_DIM.min(16)]
            .iter()
            .flat_map(|v| v.to_bits().to_le_bytes())
            .fold(FNV_OFFSET_BASIS, |hash, byte| {
                (hash ^ u32::from(byte)).wrapping_mul(FNV_PRIME)
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_splat_byte_size() {
        // 384 components * 4 bytes per f32.
        assert_eq!(SplatProjection::byte_size(), 1536);
    }

    #[test]
    fn test_splat_roundtrip() {
        let mut proj = SplatProjection::default();
        proj.embedding[0] = 1.0;
        proj.embedding[100] = -0.5;
        proj.embedding[383] = 0.42;
        let mut buf = vec![0u8; SplatProjection::byte_size()];
        proj.write(&mut buf);
        let restored = SplatProjection::read(&buf);
        // The little-endian encoding is lossless, so every component must survive
        // bit-for-bit, not only the ones that were set.
        for (i, (a, b)) in proj
            .embedding
            .iter()
            .zip(restored.embedding.iter())
            .enumerate()
        {
            assert_eq!(a.to_bits(), b.to_bits(), "component {} differs", i);
        }
    }

    #[test]
    fn test_splat_write_is_little_endian() {
        let mut proj = SplatProjection::default();
        proj.embedding[1] = 1.0;
        let mut buf = vec![0u8; SplatProjection::byte_size()];
        proj.write(&mut buf);
        assert_eq!(&buf[4..8], &1.0f32.to_le_bytes()[..]);
        assert!(buf[..4].iter().all(|&b| b == 0));
    }

    #[test]
    #[should_panic(expected = "buffer too small")]
    fn test_splat_read_rejects_short_buffer() {
        let buf = vec![0u8; SplatProjection::byte_size() - 1];
        let _ = SplatProjection::read(&buf);
    }

    #[test]
    #[should_panic(expected = "buffer too small")]
    fn test_splat_write_rejects_short_buffer() {
        let proj = SplatProjection::default();
        let mut buf = vec![0u8; SplatProjection::byte_size() - 1];
        proj.write(&mut buf);
    }

    #[test]
    fn test_splat_shape_hash_varies() {
        let a = SplatProjection::default();
        let mut b = SplatProjection::default();
        b.embedding[0] = 1.0;
        assert_ne!(a.shape_hash_contribution(), b.shape_hash_contribution());
    }

    #[test]
    fn test_splat_shape_hash_is_deterministic() {
        let mut a = SplatProjection::default();
        a.embedding[7] = 3.25;
        let b = a.clone();
        assert_eq!(a.shape_hash_contribution(), b.shape_hash_contribution());
    }
}