burn_trellis 0.1.0

Trellis2 integration utilities and runtime scaffolding for burn_synth
Documentation
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};

use burn_trellis::hook_diff::HookSnapshot;
use burn_trellis::mesh::load_obj_mesh;
use burn_trellis::pipeline::{Trellis2Pipeline, Trellis2PipelineConfig, TrellisRunOptions};
use image::{ImageBuffer, Rgba};

/// End-to-end smoke test for the native pipeline.
///
/// Builds a minimal weights tree in a scratch directory, then checks three things:
/// `validate_runtime` accepts it, `infer_mesh_to_obj` writes a non-empty OBJ mesh, and the
/// requested hook file contains the `preprocess_image.output` tensor.
///
/// The weights tree holds a `pipeline.json` that maps the `shape` model to
/// `ckpts/shape_model`. It also holds placeholder `shape_model.json` (`{}`) and
/// `shape_model.bpk` (the bytes `fake`) files.
#[test]
fn native_pipeline_writes_mesh_and_preprocess_hook() {
    /// Removes the wrapped directory when dropped, so the scratch tree is cleaned up even
    /// when an assertion or `expect` below panics.
    struct TempDirGuard(PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best-effort: a cleanup failure must not mask the test's own outcome.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("clock drift")
        .as_nanos();
    // Include the process id so concurrently running test binaries never share a scratch dir.
    let pid = std::process::id();
    let root = std::env::temp_dir().join(format!("burn_trellis_native_test_{pid}_{nanos}"));
    // Declared before anything that may panic, so the directory is removed on every exit path.
    // It is also declared before `pipeline`, so the pipeline is dropped before the removal.
    let _cleanup = TempDirGuard(root.clone());
    let weights_root = root.join("weights");
    let ckpts = weights_root.join("ckpts");
    std::fs::create_dir_all(&ckpts).expect("failed to create weights directory");

    std::fs::write(
        weights_root.join("pipeline.json"),
        r#"{
            "args": {
                "models": {
                    "shape": "ckpts/shape_model"
                }
            }
        }"#,
    )
    .expect("failed to write pipeline json");
    // NOTE(review): the checkpoint files below are placeholders. The test expects both
    // validation and inference to succeed with them, so the native path presumably does not
    // deserialize these weights — confirm.
    std::fs::write(ckpts.join("shape_model.json"), "{}").expect("failed to write model json");
    std::fs::write(ckpts.join("shape_model.bpk"), "fake").expect("failed to write model bpk");

    // 8x8 RGBA input: an opaque 4x6 block (x in 2..=5, y in 1..=6) on a fully transparent
    // background.
    let input = root.join("input.png");
    let image = ImageBuffer::from_fn(8, 8, |x, y| {
        if (2..=5).contains(&x) && (1..=6).contains(&y) {
            Rgba([200u8, 120u8, 40u8, 255u8])
        } else {
            Rgba([0u8, 0u8, 0u8, 0u8])
        }
    });
    image.save(&input).expect("failed to save test image");

    let output = root.join("out.obj");
    let hook = root.join("hook.safetensors");

    let pipeline = Trellis2Pipeline::new(Trellis2PipelineConfig {
        weights_root,
        image_large_root: None,
    })
    .expect("pipeline should initialize");
    pipeline
        .validate_runtime()
        .expect("runtime assets should validate");

    pipeline
        .infer_mesh_to_obj(
            &input,
            &output,
            &TrellisRunOptions {
                hook_output: Some(hook.clone()),
                ..TrellisRunOptions::default()
            },
        )
        .expect("native pipeline should emit a mesh");
    assert!(output.exists(), "mesh output was not emitted");
    assert!(hook.exists(), "hook file was not emitted");

    let mesh = load_obj_mesh(&output).expect("failed to parse output OBJ");
    assert!(!mesh.vertices.is_empty(), "output mesh has no vertices");
    assert!(!mesh.faces.is_empty(), "output mesh has no faces");

    let snapshot = HookSnapshot::from_file(&hook).expect("failed to read hook");
    assert!(
        snapshot.tensors.contains_key("preprocess_image.output"),
        "preprocess hook tensor missing"
    );
}

/// Checks that `validate_runtime` rejects a weights tree whose `pipeline.json` references a
/// model checkpoint that is not on disk.
///
/// The `pipeline.json` written here points the `shape` model at `ckpts/missing_shape`, and no
/// `ckpts` directory is created. `Trellis2Pipeline::new` is expected to succeed. Only
/// `validate_runtime` should fail, with an error message that mentions both "incomplete" and
/// the missing model name.
#[test]
fn runtime_validation_reports_missing_assets() {
    /// Removes the wrapped directory when dropped, so the scratch tree is cleaned up even
    /// when an assertion or `expect` below panics.
    struct TempDirGuard(PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best-effort: a cleanup failure must not mask the test's own outcome.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("clock drift")
        .as_nanos();
    // Include the process id so concurrently running test binaries never share a scratch dir.
    let pid = std::process::id();
    let root = std::env::temp_dir().join(format!("burn_trellis_validate_test_{pid}_{nanos}"));
    // Declared before anything that may panic, so the directory is removed on every exit path.
    let _cleanup = TempDirGuard(root.clone());
    let weights_root = root.join("weights");
    std::fs::create_dir_all(&weights_root).expect("failed to create test directory");

    std::fs::write(
        weights_root.join("pipeline.json"),
        r#"{
            "args": {
                "models": {
                    "shape": "ckpts/missing_shape"
                }
            }
        }"#,
    )
    .expect("failed to write pipeline json");

    // `weights_root` is not needed after this point, so move it instead of copying the path.
    let pipeline = Trellis2Pipeline::new(Trellis2PipelineConfig {
        weights_root,
        image_large_root: None,
    })
    .expect("pipeline should initialize");

    let err = pipeline
        .validate_runtime()
        .expect_err("validation should fail when model assets are missing");
    let message = err.to_string();
    // Echo the actual message on failure; a bare `contains` assert would hide it.
    assert!(
        message.contains("incomplete"),
        "validation error should report incomplete assets, got: {message}"
    );
    assert!(
        message.contains("missing_shape"),
        "validation error should name the missing model, got: {message}"
    );
}