//! burn_trellis 0.1.0
//!
//! Trellis2 integration utilities and runtime scaffolding for burn_synth.
//! Documentation
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Display, Formatter};
use std::path::Path;

use half::{bf16, f16};
use safetensors::tensor::{Dtype, SafeTensors, TensorView};

/// A single captured tensor: its shape plus its elements widened/narrowed to `f32`.
#[derive(Debug, Clone)]
pub struct HookTensor {
    /// Dimension sizes as reported by the source tensor view.
    pub shape: Vec<usize>,
    /// Elements converted to `f32`, in the order they appear in the source byte buffer.
    pub data: Vec<f32>,
}

/// A named collection of hook tensors together with free-form string metadata.
///
/// Both maps are `BTreeMap`s, so iteration is always ordered by key.
#[derive(Debug, Clone, Default)]
pub struct HookSnapshot {
    /// Captured tensors keyed by tensor name.
    pub tensors: BTreeMap<String, HookTensor>,
    /// Key/value metadata copied from the safetensors header (empty when the header has none).
    pub metadata: BTreeMap<String, String>,
}

impl HookSnapshot {
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, HookDiffError> {
        let path = path.as_ref();
        let bytes = std::fs::read(path).map_err(HookDiffError::Io)?;
        let (_, header_metadata) =
            SafeTensors::read_metadata(&bytes).map_err(HookDiffError::SafeTensors)?;
        let safetensors = SafeTensors::deserialize(&bytes).map_err(HookDiffError::SafeTensors)?;

        let mut tensors = BTreeMap::new();
        for name in safetensors.names() {
            let view = safetensors
                .tensor(name)
                .map_err(HookDiffError::SafeTensors)?;
            let tensor = tensor_view_to_hook_tensor(&view)?;
            tensors.insert(name.to_string(), tensor);
        }

        let metadata = header_metadata
            .metadata()
            .as_ref()
            .map(|pairs| {
                pairs
                    .iter()
                    .map(|(key, value)| (key.clone(), value.clone()))
                    .collect::<BTreeMap<_, _>>()
            })
            .unwrap_or_default();

        Ok(Self { tensors, metadata })
    }
}

/// Errors produced while loading or decoding a hook snapshot.
#[derive(Debug)]
pub enum HookDiffError {
    /// The snapshot file could not be read.
    Io(std::io::Error),
    /// The safetensors header or tensor table could not be parsed.
    SafeTensors(safetensors::SafeTensorError),
    /// A tensor payload could not be converted to `f32`; the string describes why.
    Decode(String),
}

impl Display for HookDiffError {
    /// Writes a short category prefix followed by the underlying error's own message.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(cause) => f.write_fmt(format_args!("io error: {}", cause)),
            Self::SafeTensors(cause) => f.write_fmt(format_args!("safetensors error: {}", cause)),
            Self::Decode(detail) => f.write_fmt(format_args!("decode error: {}", detail)),
        }
    }
}

// Uses the trait's default methods: the wrapped error's message is already
// included in the `Display` output, and no separate `source()` is exposed.
impl std::error::Error for HookDiffError {}

/// Summary statistics of the element-wise difference `actual - reference`.
///
/// The `Default` value (all zeros) is what `compute_stats` returns for empty input.
#[derive(Debug, Clone, Copy, Default)]
pub struct MetricStats {
    /// Mean of the absolute differences.
    pub mean_abs: f32,
    /// Largest absolute difference.
    pub max_abs: f32,
    /// Root-mean-square of the differences.
    pub rmse: f32,
}

/// Outcome of comparing one reference tensor against the actual snapshot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookDiffStatus {
    /// The key exists in both snapshots with identical shapes. The values may
    /// still differ; consult the entry's stats for the numeric difference.
    Match,
    /// The key exists in the reference snapshot but not in the actual one.
    MissingInActual,
    /// The key exists in both snapshots but the shapes differ.
    ShapeMismatch,
}

/// Comparison result for a single reference tensor.
#[derive(Debug, Clone)]
pub struct HookDiffEntry {
    /// Tensor name shared by both snapshots.
    pub key: String,
    /// How the actual tensor relates to the reference tensor.
    pub status: HookDiffStatus,
    /// Shape of the reference tensor.
    pub reference_shape: Vec<usize>,
    /// Shape of the actual tensor; `None` when the key is missing from the actual snapshot.
    pub actual_shape: Option<Vec<usize>>,
    /// Difference statistics; only populated when the status is `Match`.
    pub stats: Option<MetricStats>,
}

/// Full result of comparing two snapshots.
#[derive(Debug, Clone, Default)]
pub struct HookDiffReport {
    /// One entry per compared reference tensor, ordered by key.
    pub entries: Vec<HookDiffEntry>,
    /// Keys present in the actual snapshot but absent from the reference, sorted.
    pub extra_in_actual: Vec<String>,
}

pub fn compare_hook_snapshots(
    reference: &HookSnapshot,
    actual: &HookSnapshot,
    prefix: Option<&str>,
) -> HookDiffReport {
    let mut keys = BTreeSet::new();
    for key in reference.tensors.keys() {
        if prefix.is_none_or(|p| key.starts_with(p)) {
            keys.insert(key.clone());
        }
    }

    let mut entries = Vec::with_capacity(keys.len());
    for key in keys {
        let Some(reference_tensor) = reference.tensors.get(&key) else {
            continue;
        };

        let actual_tensor = actual.tensors.get(&key);
        match actual_tensor {
            None => entries.push(HookDiffEntry {
                key,
                status: HookDiffStatus::MissingInActual,
                reference_shape: reference_tensor.shape.clone(),
                actual_shape: None,
                stats: None,
            }),
            Some(actual_tensor) if actual_tensor.shape != reference_tensor.shape => {
                entries.push(HookDiffEntry {
                    key,
                    status: HookDiffStatus::ShapeMismatch,
                    reference_shape: reference_tensor.shape.clone(),
                    actual_shape: Some(actual_tensor.shape.clone()),
                    stats: None,
                });
            }
            Some(actual_tensor) => entries.push(HookDiffEntry {
                key,
                status: HookDiffStatus::Match,
                reference_shape: reference_tensor.shape.clone(),
                actual_shape: Some(actual_tensor.shape.clone()),
                stats: Some(compute_stats(&actual_tensor.data, &reference_tensor.data)),
            }),
        }
    }

    let mut extra_in_actual: Vec<String> = actual
        .tensors
        .keys()
        .filter(|key| !reference.tensors.contains_key(*key))
        .filter(|key| prefix.is_none_or(|p| key.starts_with(p)))
        .cloned()
        .collect();
    extra_in_actual.sort();

    HookDiffReport {
        entries,
        extra_in_actual,
    }
}

/// Computes mean absolute error, maximum absolute error and RMSE of
/// `actual - reference`, element by element.
///
/// Only the first `min(actual.len(), reference.len())` elements are compared;
/// any surplus in the longer slice is ignored. Empty input yields
/// `MetricStats::default()` (all zeros).
///
/// The running sums are kept in `f64`. With `f32` accumulators the totals
/// stop absorbing small terms once they grow large, which visibly skews the
/// mean and RMSE for tensors with around a million elements or more.
///
/// `max_abs` is tracked with `f32::max`, which ignores NaN operands, whereas a
/// NaN difference does propagate into `mean_abs` and `rmse`.
pub fn compute_stats(actual: &[f32], reference: &[f32]) -> MetricStats {
    let len = actual.len().min(reference.len());
    if len == 0 {
        return MetricStats::default();
    }

    let mut sum_abs = 0.0f64;
    let mut sum_sq = 0.0f64;
    let mut max_abs = 0.0f32;

    // `zip` stops at the shorter slice, matching `len` above.
    for (&a, &r) in actual.iter().zip(reference) {
        let diff = a - r;
        let abs = diff.abs();
        let wide = f64::from(diff);
        sum_abs += f64::from(abs);
        sum_sq += wide * wide;
        max_abs = max_abs.max(abs);
    }

    let n = len as f64;
    MetricStats {
        mean_abs: (sum_abs / n) as f32,
        max_abs,
        rmse: (sum_sq / n).sqrt() as f32,
    }
}

/// Converts a safetensors view into an owned [`HookTensor`].
///
/// The payload is decoded first, so a decode failure is returned before the
/// shape is copied.
fn tensor_view_to_hook_tensor(view: &TensorView<'_>) -> Result<HookTensor, HookDiffError> {
    decode_tensor_data(view).map(|data| HookTensor {
        shape: view.shape().to_vec(),
        data,
    })
}

/// Copies the first `N` bytes of `chunk` into a fixed-size array.
///
/// Panics if `chunk` is shorter than `N`; callers pass chunks produced by
/// `chunks_exact` with the dtype's item size.
fn le_array<const N: usize>(chunk: &[u8]) -> [u8; N] {
    let mut buf = [0u8; N];
    buf.copy_from_slice(&chunk[..N]);
    buf
}

/// Decodes the little-endian payload of `view` into a flat `Vec<f32>`.
///
/// Supported dtypes are the ones `dtype_size` knows a width for: booleans
/// (mapped to `0.0`/`1.0`), 8/16/32/64-bit integers, `F16`, `BF16`, `F32` and
/// `F64`. Conversions to `f32` use `as` semantics, so wide integers and `F64`
/// values are rounded to the nearest representable `f32`.
///
/// # Errors
///
/// Returns [`HookDiffError::Decode`] when the dtype is unsupported or when the
/// payload length does not equal `numel * item_size`.
fn decode_tensor_data(view: &TensorView<'_>) -> Result<Vec<f32>, HookDiffError> {
    let dtype = view.dtype();
    let item_size = dtype_size(dtype);

    // `dtype_size` reports 0 for dtypes it does not know. Reject those before
    // touching the payload: `chunks_exact(0)` panics, and without this guard a
    // non-empty tensor would be reported as a byte-length mismatch instead of
    // as an unsupported dtype.
    if item_size == 0 {
        return Err(HookDiffError::Decode(format!(
            "unsupported safetensors dtype: {:?}",
            dtype
        )));
    }

    let numel = view
        .shape()
        .iter()
        .fold(1usize, |acc, &dim| acc.saturating_mul(dim));
    let bytes = view.data();
    let expected = numel.saturating_mul(item_size);

    if bytes.len() != expected {
        return Err(HookDiffError::Decode(format!(
            "tensor byte length mismatch for dtype {:?}: expected {}, got {}",
            dtype,
            expected,
            bytes.len()
        )));
    }

    let mut out = Vec::with_capacity(numel);
    for chunk in bytes.chunks_exact(item_size) {
        let value: f32 = match dtype {
            Dtype::BOOL => {
                if chunk[0] == 0 {
                    0.0
                } else {
                    1.0
                }
            }
            Dtype::U8 => f32::from(chunk[0]),
            Dtype::I8 => f32::from(i8::from_le_bytes([chunk[0]])),
            Dtype::I16 => f32::from(i16::from_le_bytes(le_array(chunk))),
            Dtype::U16 => f32::from(u16::from_le_bytes(le_array(chunk))),
            Dtype::I32 => i32::from_le_bytes(le_array(chunk)) as f32,
            Dtype::U32 => u32::from_le_bytes(le_array(chunk)) as f32,
            Dtype::I64 => i64::from_le_bytes(le_array(chunk)) as f32,
            Dtype::U64 => u64::from_le_bytes(le_array(chunk)) as f32,
            Dtype::F16 => f16::from_bits(u16::from_le_bytes(le_array(chunk))).to_f32(),
            Dtype::BF16 => bf16::from_bits(u16::from_le_bytes(le_array(chunk))).to_f32(),
            Dtype::F32 => f32::from_le_bytes(le_array(chunk)),
            Dtype::F64 => f64::from_le_bytes(le_array(chunk)) as f32,
            // Kept so the match stays exhaustive; only reachable if a dtype
            // gains a size in `dtype_size` without a decoding arm here.
            other => {
                return Err(HookDiffError::Decode(format!(
                    "unsupported safetensors dtype: {:?}",
                    other
                )));
            }
        };
        out.push(value);
    }
    Ok(out)
}

/// Returns the width in bytes of one element of `dtype`.
///
/// Any dtype not listed here yields `0`, which callers must treat as
/// "unsupported" rather than as a real element size.
fn dtype_size(dtype: Dtype) -> usize {
    match dtype {
        Dtype::F64 | Dtype::I64 | Dtype::U64 => 8,
        Dtype::F32 | Dtype::I32 | Dtype::U32 => 4,
        Dtype::F16 | Dtype::BF16 | Dtype::I16 | Dtype::U16 => 2,
        Dtype::I8 | Dtype::U8 | Dtype::BOOL => 1,
        _ => 0,
    }
}

#[cfg(test)]
mod tests {
    use super::{
        HookDiffStatus, HookSnapshot, HookTensor, compare_hook_snapshots, compute_stats,
    };

    /// Builds a tensor from borrowed shape and data slices.
    fn tensor(shape: &[usize], data: &[f32]) -> HookTensor {
        HookTensor {
            shape: shape.to_vec(),
            data: data.to_vec(),
        }
    }

    /// Builds a snapshot (without metadata) from `(name, tensor)` pairs.
    fn snapshot(items: Vec<(&str, HookTensor)>) -> HookSnapshot {
        let mut snapshot = HookSnapshot::default();
        for (name, tensor) in items {
            snapshot.tensors.insert(name.to_string(), tensor);
        }
        snapshot
    }

    #[test]
    fn compute_stats_reports_expected_values() {
        let stats = compute_stats(&[1.0, 2.0, 3.0], &[1.5, 1.0, 3.0]);
        assert!((stats.mean_abs - 0.5).abs() < 1e-6);
        assert!((stats.max_abs - 1.0).abs() < 1e-6);
        assert!((stats.rmse - (1.25f32 / 3.0f32).sqrt()).abs() < 1e-6);
    }

    #[test]
    fn compare_reports_missing_and_shape_mismatch() {
        let reference = snapshot(vec![
            ("a", tensor(&[2], &[0.0, 1.0])),
            ("b", tensor(&[1, 2], &[0.0, 1.0])),
            // Present only in the reference, so it must be reported as missing.
            ("c", tensor(&[1], &[3.0])),
        ]);
        let actual = snapshot(vec![
            ("a", tensor(&[2], &[0.0, 2.0])),
            ("b", tensor(&[2, 1], &[0.0, 1.0])),
            ("extra", tensor(&[1], &[0.0])),
        ]);

        let report = compare_hook_snapshots(&reference, &actual, None);
        assert_eq!(report.entries.len(), 3);
        assert_eq!(report.extra_in_actual, vec!["extra".to_string()]);

        assert_eq!(report.entries[0].key, "a");
        assert_eq!(report.entries[0].status, HookDiffStatus::Match);
        assert!(report.entries[0].stats.is_some());

        assert_eq!(report.entries[1].key, "b");
        assert_eq!(report.entries[1].status, HookDiffStatus::ShapeMismatch);
        assert_eq!(report.entries[1].actual_shape, Some(vec![2, 1]));
        assert!(report.entries[1].stats.is_none());

        assert_eq!(report.entries[2].key, "c");
        assert_eq!(report.entries[2].status, HookDiffStatus::MissingInActual);
        assert!(report.entries[2].actual_shape.is_none());
        assert!(report.entries[2].stats.is_none());
    }

    #[test]
    fn compare_honors_prefix_filter() {
        let reference = snapshot(vec![
            ("enc.a", tensor(&[1], &[1.0])),
            ("dec.b", tensor(&[1], &[1.0])),
        ]);
        let actual = snapshot(vec![
            ("enc.a", tensor(&[1], &[1.0])),
            ("enc.extra", tensor(&[1], &[0.0])),
            ("dec.extra", tensor(&[1], &[0.0])),
        ]);

        let report = compare_hook_snapshots(&reference, &actual, Some("enc."));
        assert_eq!(report.entries.len(), 1);
        assert_eq!(report.entries[0].key, "enc.a");
        assert_eq!(report.entries[0].status, HookDiffStatus::Match);
        assert_eq!(report.extra_in_actual, vec!["enc.extra".to_string()]);
    }
}