use super::layout::ProjectionLayout;
use super::projection::{Projection, ProjectionId};
use super::splat::SplatProjection;
use super::kuramoto::KuramotoProjection;
use super::expert::ExpertProjection;
use super::graph::GraphProjection;
use super::thermal::ThermalProjection;
/// 32-bit FNV-1a hash of `data`.
///
/// Each byte is XOR-ed into the running state before multiplying by the FNV
/// prime (the "1a" ordering); multiplication wraps modulo 2^32.
fn fnv1a(data: &[u8]) -> u32 {
    const OFFSET_BASIS: u32 = 0x811c_9dc5;
    const PRIME: u32 = 0x0100_0193;
    data.iter()
        .fold(OFFSET_BASIS, |state, &b| (state ^ u32::from(b)).wrapping_mul(PRIME))
}
/// A flat byte buffer holding the projections described by `layout`.
///
/// Each projection occupies a slot whose start is given by
/// `layout.offset_of(id)`; projections absent from the layout have no slot.
#[derive(Debug, Clone)]
pub struct ComputeAtom {
    /// Raw storage for every projection; `ComputeAtom::new` sizes it to
    /// `layout.stride` bytes, all zero.
    pub buffer: Vec<u8>,
    /// Maps projection ids to byte offsets inside `buffer`.
    pub layout: ProjectionLayout,
    /// FNV-1a hash of `buffer`, refreshed by `new` and `write_projection`.
    /// NOTE(review): the field is public, so mutating `buffer` directly leaves
    /// this value stale.
    pub shape_hash: u32,
}
impl ComputeAtom {
    /// Creates a zero-filled atom whose buffer is `layout.stride` bytes long.
    ///
    /// `shape_hash` is initialised to the FNV-1a hash of that zeroed buffer.
    pub fn new(layout: ProjectionLayout) -> Self {
        let buffer = vec![0u8; layout.stride];
        let mut atom = Self {
            buffer,
            layout,
            shape_hash: 0,
        };
        atom.recompute_hash();
        atom
    }

    /// Creates `n` identical, zero-filled atoms that all use `layout`.
    ///
    /// A single prototype is built (allocated and hashed) once and then cloned,
    /// rather than allocating and re-hashing a fresh buffer for every element.
    pub fn create_n(layout: &ProjectionLayout, n: usize) -> Vec<Self> {
        vec![Self::new(layout.clone()); n]
    }

    /// Decodes projection `P` from this atom's buffer.
    ///
    /// `P::read` receives the buffer from `P`'s offset to the end of the buffer.
    ///
    /// Returns `None` when the layout has no slot for `P`, or when the slot's
    /// offset lies past the end of the buffer (an inconsistent layout) — the
    /// latter used to panic on the out-of-range slice.
    pub fn read_projection<P: Projection>(&self) -> Option<P> {
        let offset = self.layout.offset_of(P::id())?;
        let bytes = self.buffer.get(offset..)?;
        Some(P::read(bytes))
    }

    /// Encodes `proj` into its slot and refreshes `shape_hash`.
    ///
    /// Returns `true` if the projection was written. Returns `false` (leaving
    /// the buffer and hash untouched) when the layout has no slot for `P`, or
    /// when the slot's offset lies past the end of the buffer.
    pub fn write_projection<P: Projection>(&mut self, proj: &P) -> bool {
        let offset = match self.layout.offset_of(P::id()) {
            Some(offset) => offset,
            None => return false,
        };
        let bytes = match self.buffer.get_mut(offset..) {
            Some(bytes) => bytes,
            None => return false,
        };
        proj.write(bytes);
        // The buffer changed, so the cached hash must be recomputed.
        self.recompute_hash();
        true
    }

    /// Returns the exact byte slice occupied by projection `id`.
    ///
    /// The slice starts at the layout offset and spans that projection type's
    /// `byte_size()`. Returns `None` when the layout has no slot for `id`, or
    /// when `offset + size` does not fit inside the buffer (instead of
    /// panicking).
    pub fn projection_bytes(&self, id: ProjectionId) -> Option<&[u8]> {
        let offset = self.layout.offset_of(id)?;
        let size = match id {
            ProjectionId::Splat => SplatProjection::byte_size(),
            ProjectionId::Kuramoto => KuramotoProjection::byte_size(),
            ProjectionId::Expert => ExpertProjection::byte_size(),
            ProjectionId::Graph => GraphProjection::byte_size(),
            ProjectionId::Thermal => ThermalProjection::byte_size(),
        };
        let end = offset.checked_add(size)?;
        self.buffer.get(offset..end)
    }

    /// Recomputes `shape_hash` as the FNV-1a hash of the whole buffer.
    fn recompute_hash(&mut self) {
        self.shape_hash = fnv1a(&self.buffer);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The full layout is 1612 bytes and exposes every projection.
    #[test]
    fn test_atom_full_layout() {
        let atom = ComputeAtom::new(ProjectionLayout::full());
        assert_eq!(atom.buffer.len(), 1612);
        assert!(atom.read_projection::<SplatProjection>().is_some());
        assert!(atom.read_projection::<KuramotoProjection>().is_some());
        assert!(atom.read_projection::<ExpertProjection>().is_some());
        assert!(atom.read_projection::<GraphProjection>().is_some());
        assert!(atom.read_projection::<ThermalProjection>().is_some());
    }

    /// The minimal layout carries Splat and Kuramoto but not Expert or Graph.
    #[test]
    fn test_atom_minimal_layout() {
        let atom = ComputeAtom::new(ProjectionLayout::minimal());
        assert!(atom.read_projection::<SplatProjection>().is_some());
        assert!(atom.read_projection::<KuramotoProjection>().is_some());
        assert!(atom.read_projection::<ExpertProjection>().is_none());
        assert!(atom.read_projection::<GraphProjection>().is_none());
    }

    /// Writing a projection and reading it back preserves every field.
    #[test]
    fn test_atom_write_read_roundtrip() {
        let mut atom = ComputeAtom::new(ProjectionLayout::full());
        let k = KuramotoProjection {
            theta: 1.5,
            omega: 0.3,
            coupling: 2.0,
        };
        assert!(atom.write_projection(&k));
        let restored = atom.read_projection::<KuramotoProjection>().unwrap();
        assert!((restored.theta - 1.5).abs() < 1e-6);
        assert!((restored.omega - 0.3).abs() < 1e-6);
        assert!((restored.coupling - 2.0).abs() < 1e-6);
    }

    /// A successful write with non-zero data changes the cached hash.
    #[test]
    fn test_atom_shape_hash_changes_on_write() {
        let mut atom = ComputeAtom::new(ProjectionLayout::full());
        let hash_before = atom.shape_hash;
        // Any non-zero payload works; 3.14 is avoided because it trips
        // clippy::approx_constant (deny-by-default).
        let k = KuramotoProjection {
            theta: 2.5,
            omega: 1.0,
            coupling: 1.0,
        };
        assert!(atom.write_projection(&k));
        assert_ne!(atom.shape_hash, hash_before);
    }

    /// `create_n` yields `n` identically-sized, identically-hashed atoms.
    #[test]
    fn test_create_n() {
        let layout = ProjectionLayout::minimal();
        let atoms = ComputeAtom::create_n(&layout, 100);
        assert_eq!(atoms.len(), 100);
        assert_eq!(atoms[0].buffer.len(), layout.stride);
        assert!(atoms.iter().all(|a| a.shape_hash == atoms[0].shape_hash));
        assert!(ComputeAtom::create_n(&layout, 0).is_empty());
    }

    /// `projection_bytes` returns exactly `byte_size()` bytes for present
    /// projections and `None` for absent ones.
    #[test]
    fn test_projection_bytes() {
        let full = ComputeAtom::new(ProjectionLayout::full());
        let bytes = full.projection_bytes(ProjectionId::Kuramoto).unwrap();
        assert_eq!(bytes.len(), KuramotoProjection::byte_size());

        let minimal = ComputeAtom::new(ProjectionLayout::minimal());
        assert!(minimal.projection_bytes(ProjectionId::Expert).is_none());
    }
}