use super::holo_model::{HoloConfig, HoloModel};
use std::io::Write;
use std::path::Path;
const MAGIC: &[u8; 4] = b"HOLO";
const VERSION: u32 = 1;
pub fn save_holo(model: &HoloModel, path: &Path) -> Result<(), std::io::Error> {
let mut f = std::fs::File::create(path)?;
f.write_all(MAGIC)?;
f.write_all(&VERSION.to_le_bytes())?;
let config_json = config_to_json(&model.config);
let config_bytes = config_json.as_bytes();
f.write_all(&(config_bytes.len() as u32).to_le_bytes())?;
f.write_all(config_bytes)?;
f.write_all(&(model.anchors.len() as u32).to_le_bytes())?;
for anchor in &model.anchors {
for &v in anchor {
f.write_all(&v.to_le_bytes())?;
}
}
write_params(&mut f, &model.triplane.params_flat())?;
write_params(&mut f, &model.canonical.mlp.params_flat())?;
write_params(&mut f, &model.audio_encoder.params_flat())?;
write_params(&mut f, &model.cross_attention.params_flat())?;
write_params(&mut f, &model.deformation.params_flat())?;
Ok(())
}
pub fn load_holo(path: &Path) -> Result<HoloModel, HoloFormatError> {
let data = std::fs::read(path)?;
let mut offset = 0;
if &data[offset..offset + 4] != MAGIC {
return Err(HoloFormatError::BadMagic);
}
offset += 4;
let version = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
if version != VERSION {
return Err(HoloFormatError::UnsupportedVersion(version));
}
offset += 4;
let config_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize;
offset += 4;
let config_json = std::str::from_utf8(&data[offset..offset + config_len])
.map_err(|_| HoloFormatError::BadConfig)?;
let config = json_to_config(config_json).ok_or(HoloFormatError::BadConfig)?;
offset += config_len;
let num_anchors = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize;
offset += 4;
let mut anchors = Vec::with_capacity(num_anchors);
for _ in 0..num_anchors {
let x = f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
let y = f32::from_le_bytes(data[offset + 4..offset + 8].try_into().unwrap());
let z = f32::from_le_bytes(data[offset + 8..offset + 12].try_into().unwrap());
anchors.push([x, y, z]);
offset += 12;
}
let mut model = HoloModel::new(config);
model.anchors = anchors;
let tp = read_params(&data, &mut offset)?;
model.triplane.set_params_flat(&tp);
let cp = read_params(&data, &mut offset)?;
model.canonical.mlp.set_params_flat(&cp);
let ae = read_params(&data, &mut offset)?;
model.audio_encoder.set_params_flat(&ae);
let ca = read_params(&data, &mut offset)?;
model.cross_attention.set_params_flat(&ca);
let dp = read_params(&data, &mut offset)?;
model.deformation.set_params_flat(&dp);
Ok(model)
}
fn write_params(f: &mut std::fs::File, params: &[f32]) -> Result<(), std::io::Error> {
f.write_all(&(params.len() as u32).to_le_bytes())?;
for &v in params {
f.write_all(&v.to_le_bytes())?;
}
Ok(())
}
fn read_params(data: &[u8], offset: &mut usize) -> Result<Vec<f32>, HoloFormatError> {
let len = u32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap()) as usize;
*offset += 4;
let params: Vec<f32> = (0..len)
.map(|i| {
let o = *offset + i * 4;
f32::from_le_bytes(data[o..o + 4].try_into().unwrap())
})
.collect();
*offset += len * 4;
Ok(params)
}
fn config_to_json(config: &HoloConfig) -> String {
format!(
r#"{{"triplane_res":{},"feature_dim":{},"num_anchors":{},"mel_dim":{},"audio_window":{},"audio_latent_dim":{},"attn_head_dim":{}}}"#,
config.triplane_res,
config.feature_dim,
config.num_anchors,
config.mel_dim,
config.audio_window,
config.audio_latent_dim,
config.attn_head_dim,
)
}
fn json_to_config(json: &str) -> Option<HoloConfig> {
let get_u32 = |key: &str| -> Option<u32> {
let pattern = format!("\"{}\":", key);
let pos = json.find(&pattern)?;
let rest = &json[pos + pattern.len()..];
let end = rest.find(|c: char| c == ',' || c == '}')?;
rest[..end].trim().parse().ok()
};
Some(HoloConfig {
triplane_res: get_u32("triplane_res")?,
feature_dim: get_u32("feature_dim")? as usize,
num_anchors: get_u32("num_anchors")? as usize,
mel_dim: get_u32("mel_dim")? as usize,
audio_window: get_u32("audio_window")? as usize,
audio_latent_dim: get_u32("audio_latent_dim")? as usize,
attn_head_dim: get_u32("attn_head_dim")? as usize,
})
}
#[derive(Debug)]
pub enum HoloFormatError {
Io(std::io::Error),
BadMagic,
UnsupportedVersion(u32),
BadConfig,
}
impl From<std::io::Error> for HoloFormatError {
fn from(e: std::io::Error) -> Self {
HoloFormatError::Io(e)
}
}
impl std::fmt::Display for HoloFormatError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HoloFormatError::Io(e) => write!(f, "IO error: {}", e),
HoloFormatError::BadMagic => write!(f, "not a .holo file"),
HoloFormatError::UnsupportedVersion(v) => write!(f, "unsupported version: {}", v),
HoloFormatError::BadConfig => write!(f, "bad config in .holo file"),
}
}
}
impl std::error::Error for HoloFormatError {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_save_load_roundtrip() {
let config = HoloConfig {
num_anchors: 50,
triplane_res: 4,
feature_dim: 8,
..Default::default()
};
let model = HoloModel::new(config);
let tmp = std::env::temp_dir().join("test_model.holo");
save_holo(&model, &tmp).unwrap();
let loaded = load_holo(&tmp).unwrap();
assert_eq!(loaded.anchors.len(), model.anchors.len());
assert_eq!(loaded.config.triplane_res, 4);
assert_eq!(loaded.config.feature_dim, 8);
assert_eq!(loaded.config.num_anchors, 50);
assert_eq!(
loaded.triplane.params_flat(),
model.triplane.params_flat()
);
std::fs::remove_file(&tmp).ok();
}
}