#![allow(dead_code)]
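//! In-memory model of an HDF5 weights export: typed datasets organized into a
//! tree of groups under a root, plus helpers for sizing, validation, and a
//! one-line JSON summary. The structure loosely mirrors Keras-style `.h5`
//! weight files (layer groups holding kernel/bias datasets).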
#[derive(Debug, Clone, PartialEq)]
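/// Element type of a dataset, restricted to the dtypes this exporter handles.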
pub enum Hdf5Dtype {
Float32,
Float64,
Int32,
Int64,
Uint8,
}
impl Hdf5Dtype {
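    /// Size in bytes of a single element of this dtype.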
pub fn itemsize(&self) -> usize {
match self {
Hdf5Dtype::Float32 | Hdf5Dtype::Int32 => 4,
Hdf5Dtype::Float64 | Hdf5Dtype::Int64 => 8,
Hdf5Dtype::Uint8 => 1,
}
}
}
#[derive(Debug, Clone)]
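/// A named n-dimensional array of a single dtype; `shape` lists its dimensions.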
pub struct Hdf5Dataset {
pub name: String,
pub dtype: Hdf5Dtype,
pub shape: Vec<usize>,
}
impl Hdf5Dataset {
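    /// Element count: the product of all dimensions. An empty `shape` (a scalar)
    /// and any shape containing a zero dimension both report one element, so
    /// `byte_size` never returns zero.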
pub fn numel(&self) -> usize {
self.shape.iter().product::<usize>().max(1)
}
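    /// Total storage for the dataset in bytes: element count times element size.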
pub fn byte_size(&self) -> usize {
self.numel() * self.dtype.itemsize()
}
}
#[derive(Debug, Clone)]
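/// A group in the HDF5 hierarchy: named, with datasets, nested subgroups, and
/// string key/value attributes.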
pub struct Hdf5Group {
pub name: String,
pub datasets: Vec<Hdf5Dataset>,
pub subgroups: Vec<Hdf5Group>,
pub attributes: Vec<(String, String)>,
}
impl Hdf5Group {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
datasets: vec![],
subgroups: vec![],
attributes: vec![],
}
}
pub fn add_dataset(&mut self, ds: Hdf5Dataset) {
self.datasets.push(ds);
}
pub fn add_subgroup(&mut self, sg: Hdf5Group) {
self.subgroups.push(sg);
}
pub fn add_attribute(&mut self, key: &str, val: &str) {
self.attributes.push((key.to_string(), val.to_string()));
}
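    /// Combined byte size of all datasets in this group and, recursively, in
    /// every subgroup.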
pub fn total_byte_size(&self) -> usize {
let ds_bytes: usize = self.datasets.iter().map(|d| d.byte_size()).sum();
let sub_bytes: usize = self.subgroups.iter().map(|g| g.total_byte_size()).sum();
ds_bytes + sub_bytes
}
}
#[derive(Debug, Clone)]
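/// Top-level description of a weights export: target filename, the root group
/// of the hierarchy, and an HDF5 version string (defaults to "1.12.0").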
pub struct Hdf5WeightsExport {
pub filename: String,
pub root: Hdf5Group,
pub hdf5_version: String,
}
impl Default for Hdf5WeightsExport {
fn default() -> Self {
Self {
filename: String::new(),
root: Hdf5Group::new("/"),
hdf5_version: "1.12.0".to_string(),
}
}
}
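/// Convenience constructor: a default export with the given target filename.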
pub fn new_hdf5_weights_export(filename: &str) -> Hdf5WeightsExport {
Hdf5WeightsExport {
filename: filename.to_string(),
..Default::default()
}
}
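/// Number of groups directly under the root (nested subgroups are not counted).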
pub fn hdf5_group_count(export: &Hdf5WeightsExport) -> usize {
export.root.subgroups.len()
}
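/// Estimated file size: total dataset bytes plus a fixed 2048-byte allowance
/// for file metadata overhead.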
pub fn hdf5_data_size(export: &Hdf5WeightsExport) -> usize {
export.root.total_byte_size() + 2048
}
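/// An export is considered valid when it has a filename and at least one group
/// under the root.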
pub fn validate_hdf5_weights(export: &Hdf5WeightsExport) -> bool {
!export.filename.is_empty() && !export.root.subgroups.is_empty()
}
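/// Hand-rolled one-line JSON summary of the export. Note that the filename and
/// version strings are interpolated as-is, with no JSON escaping.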
pub fn hdf5_summary_json(export: &Hdf5WeightsExport) -> String {
format!(
"{{\"file\":\"{}\",\"hdf5_version\":\"{}\",\"root_groups\":{},\"data_bytes\":{}}}",
export.filename,
export.hdf5_version,
hdf5_group_count(export),
hdf5_data_size(export)
)
}
#[cfg(test)]
mod tests {
use super::*;
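    // Builds a minimal export with a single layer group holding a kernel and a
    // bias dataset, similar to how a dense layer's weights are laid out.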
fn sample_export() -> Hdf5WeightsExport {
let mut e = new_hdf5_weights_export("model.h5");
let mut layer_group = Hdf5Group::new("dense_1");
layer_group.add_dataset(Hdf5Dataset {
name: "kernel".into(),
dtype: Hdf5Dtype::Float32,
shape: vec![512, 256],
});
layer_group.add_dataset(Hdf5Dataset {
name: "bias".into(),
dtype: Hdf5Dtype::Float32,
shape: vec![256],
});
layer_group.add_attribute("keras_version", "2.13.0");
e.root.add_subgroup(layer_group);
e
}
#[test]
fn filename_stored() {
let e = new_hdf5_weights_export("weights.h5");
assert_eq!(e.filename, "weights.h5");
}
#[test]
fn group_count() {
let e = sample_export();
assert_eq!(hdf5_group_count(&e), 1);
}
#[test]
fn validate_complete() {
let e = sample_export();
assert!(validate_hdf5_weights(&e));
}
#[test]
fn validate_empty_false() {
let e = new_hdf5_weights_export("model.h5");
assert!(!validate_hdf5_weights(&e));
}
#[test]
fn data_size_positive() {
let e = sample_export();
assert!(hdf5_data_size(&e) > 0);
}
#[test]
fn dtype_itemsize_f64() {
assert_eq!(Hdf5Dtype::Float64.itemsize(), 8);
}
#[test]
fn dataset_byte_size() {
let ds = Hdf5Dataset {
name: "w".into(),
dtype: Hdf5Dtype::Float32,
shape: vec![4, 4],
};
assert_eq!(ds.byte_size(), 64);
}
#[test]
fn summary_json_has_file() {
let e = sample_export();
let json = hdf5_summary_json(&e);
assert!(json.contains("model.h5"));
}
#[test]
fn attribute_stored() {
let e = sample_export();
let g = &e.root.subgroups[0];
assert_eq!(g.attributes[0].0, "keras_version");
}
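    // Additional check: total_byte_size should recurse into nested subgroups.
    // Group and dataset names here ("block_1", "gamma") are just placeholders.
    #[test]
    fn nested_group_byte_size() {
        let mut root = Hdf5Group::new("/");
        let mut child = Hdf5Group::new("block_1");
        child.add_dataset(Hdf5Dataset {
            name: "gamma".into(),
            dtype: Hdf5Dtype::Float64,
            shape: vec![8],
        });
        root.add_subgroup(child);
        // 8 elements * 8 bytes per f64 = 64 bytes.
        assert_eq!(root.total_byte_size(), 64);
    }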
}