burn_trellis 0.1.0

Trellis2 integration utilities and runtime scaffolding for burn_synth
Documentation
use std::path::PathBuf;

use burn_trellis::hook_diff::{HookSnapshot, compute_stats};
use burn_trellis::preprocess::{PreprocessConfig, preprocess_image_path};

/// Zero tolerance shared by every diff statistic below: any statistic strictly greater
/// than this value makes the parity test fail, i.e. the test demands exact agreement.
const EXACT_MATCH_TOLERANCE: f32 = 0.0;

/// Upper bound on the mean absolute difference between Rust output and the reference.
const MEAN_ABS_TOLERANCE: f32 = EXACT_MATCH_TOLERANCE;
/// Upper bound on the largest single-element absolute difference.
const MAX_ABS_TOLERANCE: f32 = EXACT_MATCH_TOLERANCE;
/// Upper bound on the root-mean-square difference.
const RMSE_TOLERANCE: f32 = EXACT_MATCH_TOLERANCE;

/// Checks the Rust preprocessing pipeline against a reference hook dump.
///
/// Runs `preprocess_image_path` with `PreprocessConfig::default()` on
/// `assets/hooks/trellis2_preprocess_input.png`. It compares the RGB buffer (each element
/// cast to `f32`) with the `preprocess_image.output` tensor stored in
/// `assets/hooks/trellis2_preprocess_reference.safetensors`.
///
/// When either fixture file is missing, the test prints a notice and returns `Ok(())`.
/// Parity is therefore only enforced where the fixtures are present.
///
/// # Errors
///
/// Returns an error if any of the following happens:
///
/// - preprocessing or snapshot loading fails;
/// - the reference tensor is absent;
/// - shapes or element counts disagree;
/// - the output is empty;
/// - any diff statistic exceeds its tolerance or is NaN.
#[test]
fn trellis2_preprocess_matches_reference_hook() -> Result<(), Box<dyn std::error::Error>> {
    let root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    let input_image = root.join("assets/hooks/trellis2_preprocess_input.png");
    let reference_hook = root.join("assets/hooks/trellis2_preprocess_reference.safetensors");
    if !input_image.exists() || !reference_hook.exists() {
        eprintln!(
            "Skipping preprocess hook correctness test: missing input or reference hook files."
        );
        return Ok(());
    }

    let output = preprocess_image_path(&input_image, PreprocessConfig::default())?;
    // Expected shape is [height, width, 3]; the reference tensor must match it exactly.
    let output_shape = vec![output.height as usize, output.width as usize, 3usize];
    let output_data: Vec<f32> = output.rgb.iter().map(|&v| v as f32).collect();

    let reference = HookSnapshot::from_file(&reference_hook)?;
    let reference_tensor = reference
        .tensors
        .get("preprocess_image.output")
        .ok_or("missing preprocess_image.output tensor in reference hook")?;

    if reference_tensor.shape != output_shape {
        return Err(format!(
            "shape mismatch: reference {:?} vs burn {:?}",
            reference_tensor.shape, output_shape
        )
        .into());
    }
    if reference_tensor.data.len() != output_data.len() {
        return Err(format!(
            "data length mismatch: reference {} vs burn {}",
            reference_tensor.data.len(),
            output_data.len()
        )
        .into());
    }
    // Comparing zero elements proves nothing about parity, so treat it as a failure
    // rather than letting whatever `compute_stats` reports for empty input decide.
    if output_data.is_empty() {
        return Err("preprocess output is empty; nothing to compare against the reference".into());
    }

    let stats = compute_stats(&output_data, &reference_tensor.data);
    let summary = format!(
        "mean_abs={:.6e}, max_abs={:.6e}, rmse={:.6e}",
        stats.mean_abs, stats.max_abs, stats.rmse
    );
    eprintln!("preprocess_image.output diff: {summary}");

    // `value > tolerance` is false for NaN, so a plain `>` check would let a NaN
    // statistic pass silently; count NaN as exceeding the tolerance.
    let exceeds = |value: f32, tolerance: f32| value.is_nan() || value > tolerance;
    if exceeds(stats.mean_abs, MEAN_ABS_TOLERANCE)
        || exceeds(stats.max_abs, MAX_ABS_TOLERANCE)
        || exceeds(stats.rmse, RMSE_TOLERANCE)
    {
        return Err(format!("preprocess hook mismatch: {summary}").into());
    }
    Ok(())
}