burn_tripo 0.1.1

TripoSG (including the TripoSG-scribble variant) implemented in Burn
Documentation
use burn::prelude::Backend;

use crate::model::triposg::load_policy::{BpkPrecisionPreference, BurnpackLoadPolicy};
use crate::pipeline::mesh::Mesh;

/// Requested execution device for the DINO stage (presumably the DINO image encoder —
/// confirm against the pipeline).
///
/// `Auto` is turned into a concrete `Cpu`/`Gpu` choice by [`resolve_dino_backend`];
/// the other variants are passed through unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DinoBackendChoice {
    /// Let the active Burn backend type decide.
    Auto,
    /// Always use the CPU path.
    Cpu,
    /// Always use the GPU path.
    Gpu,
}

impl DinoBackendChoice {
    /// Resolves this choice against the Burn backend `B`.
    ///
    /// Method-call form of [`resolve_dino_backend`]: `Auto` is mapped to a concrete
    /// `Cpu`/`Gpu` value, explicit choices come back unchanged.
    pub fn resolve<B: Backend>(self) -> Self {
        let requested: Self = self;
        resolve_dino_backend::<B>(requested)
    }
}

/// Resolves a requested DINO backend choice for the Burn backend `B`.
///
/// `Auto` becomes `Gpu` when [`is_gpu_backend`] recognises `B` as a GPU backend, and `Cpu`
/// otherwise. The same rule applies on every target, including `wasm32`. Explicit `Cpu` and
/// `Gpu` requests are returned unchanged.
pub fn resolve_dino_backend<B: Backend>(requested: DinoBackendChoice) -> DinoBackendChoice {
    match requested {
        DinoBackendChoice::Auto if is_gpu_backend::<B>() => DinoBackendChoice::Gpu,
        DinoBackendChoice::Auto => DinoBackendChoice::Cpu,
        // Listed explicitly (no catch-all) so adding a variant forces this match to be revisited.
        explicit @ (DinoBackendChoice::Cpu | DinoBackendChoice::Gpu) => explicit,
    }
}

/// Returns `true` when `requested`, once resolved for backend `B`, selects the CPU path.
pub fn should_use_cpu_dino_backend<B: Backend>(requested: DinoBackendChoice) -> bool {
    let resolved = resolve_dino_backend::<B>(requested);
    resolved == DinoBackendChoice::Cpu
}

/// Returns `true` when the type name of `B`, lowercased (ASCII), contains `needle`.
///
/// `needle` must already be lowercase. `std::any::type_name` is documented as a best-effort
/// diagnostic string whose exact format is not guaranteed to be stable, so callers get a
/// heuristic rather than a type-level guarantee.
fn backend_type_name_contains<B: ?Sized>(needle: &str) -> bool {
    std::any::type_name::<B>()
        .to_ascii_lowercase()
        .contains(needle)
}

/// Heuristically reports whether `B` is a WGPU-based Burn backend.
///
/// Matches on the type name, so wrapper backends whose type name embeds the inner backend
/// (generic parameters are typically included) are matched as well.
pub fn is_wgpu_backend<B: Backend>() -> bool {
    backend_type_name_contains::<B>("wgpu")
}

/// Heuristically reports whether `B` is a CUDA-based Burn backend (type-name match).
pub fn is_cuda_backend<B: Backend>() -> bool {
    backend_type_name_contains::<B>("cuda")
}

/// Returns `true` when `B` looks like any supported GPU backend (WGPU or CUDA).
pub fn is_gpu_backend<B: Backend>() -> bool {
    is_wgpu_backend::<B>() || is_cuda_backend::<B>()
}

/// Runtime settings for running TripoSG in parity mode.
///
/// Usually obtained from [`triposg_runtime_profile`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TripoSGRuntimeParityProfile {
    /// Enables the strict DINO preprocessing path (exact behaviour lives in the consumer —
    /// TODO confirm).
    pub strict_dino_preprocess: bool,
    /// Enables strict RMBG interpolation (presumably to match a reference implementation —
    /// TODO confirm).
    pub strict_rmbg_interp: bool,
    /// Optional cap on input image dimension; `None` presumably means "no cap" — verify at
    /// call sites.
    pub max_image_dim: Option<usize>,
    /// Burnpack weight-loading policy, including its precision preference.
    pub burnpack_policy: BurnpackLoadPolicy,
}

/// Builds the default TripoSG parity profile.
///
/// Both strict flags are enabled and the Burnpack policy prefers F16 weights.
/// `fallback_max_image_dim` is kept only when it is a real size: `0` and `usize::MAX`
/// are normalised to `None`.
pub fn triposg_runtime_profile(
    fallback_max_image_dim: Option<usize>,
) -> TripoSGRuntimeParityProfile {
    let max_image_dim = match fallback_max_image_dim {
        Some(dim) if dim != 0 && dim != usize::MAX => Some(dim),
        _ => None,
    };

    TripoSGRuntimeParityProfile {
        strict_dino_preprocess: true,
        strict_rmbg_interp: true,
        max_image_dim,
        burnpack_policy: BurnpackLoadPolicy::default()
            .with_precision(BpkPrecisionPreference::PreferF16),
    }
}

pub fn should_prefer_f16_triposg_weights(parity: TripoSGRuntimeParityProfile) -> bool {
    parity.burnpack_policy.precision.prefer_f16()
}

/// Decimates `mesh` towards at most `target_faces` triangles using meshoptimizer.
///
/// A clone of the input is returned unchanged when `target_faces` is `0`, when the mesh
/// already has `target_faces` faces or fewer, or when it has no faces or no vertices.
///
/// The function first runs `meshopt::simplify` with a rising error limit, stopping at the
/// first limit that meets the index budget. If the budget is still not met, it falls back
/// to `meshopt::simplify_sloppy`. The result is then compacted so that only referenced,
/// deduplicated vertices remain and face indices point into the new vertex buffer.
///
/// # Errors
///
/// Returns `Err` when:
/// - a face references a vertex index outside `mesh.vertices`;
/// - the face budget is too small;
/// - simplification or remapping produces no triangles;
/// - the meshopt vertex adapter cannot be built.
pub fn decimate_tripo_mesh(mesh: &Mesh, target_faces: usize) -> Result<Mesh, String> {
    // Nothing to do: no budget requested, or the mesh already fits within it.
    if target_faces == 0 || mesh.faces.len() <= target_faces {
        return Ok(mesh.clone());
    }
    // Degenerate input is passed through untouched rather than rejected.
    if mesh.faces.is_empty() || mesh.vertices.is_empty() {
        return Ok(mesh.clone());
    }

    // meshoptimizer reads vertex data by index in native code and expects every index to be
    // in range. Validate here so malformed input becomes an `Err` instead of an
    // out-of-bounds read across the FFI boundary.
    let input_vertex_count = mesh.vertices.len();
    if let Some(bad_index) = mesh
        .faces
        .iter()
        .flatten()
        .copied()
        .find(|&index| index as usize >= input_vertex_count)
    {
        return Err(format!(
            "face references vertex {bad_index} but mesh has only {input_vertex_count} vertices"
        ));
    }

    // Flat triangle-list index buffer: three consecutive indices per face.
    let indices: Vec<u32> = mesh.faces.iter().flatten().copied().collect();
    let target_index_count = target_faces.saturating_mul(3).min(indices.len());
    if target_index_count < 3 {
        return Err("target face count too small for decimation".to_string());
    }

    let vertices_bytes = meshopt::typed_to_bytes(mesh.vertices.as_slice());
    let adapter =
        meshopt::VertexDataAdapter::new(vertices_bytes, std::mem::size_of::<[f32; 3]>(), 0)
            .map_err(|err| format!("meshopt vertex adapter: {err}"))?;

    // Start from low-error simplification and only relax if we can't reach target count.
    // This avoids aggressive topology collapse at moderate face budgets (e.g. 10k).
    let mut result_error = 0.0f32;
    let mut simplified = Vec::<u32>::new();
    for error_limit in [0.02f32, 0.05, 0.1, 0.25, 0.5, 1.0] {
        let mut stage_error = 0.0f32;
        let candidate = meshopt::simplify(
            &indices,
            &adapter,
            target_index_count,
            error_limit,
            meshopt::SimplifyOptions::None,
            Some(&mut stage_error),
        );
        // A stage that collapses to nothing is ignored; keep the last usable result.
        if candidate.len() < 3 {
            continue;
        }
        result_error = stage_error;
        simplified = candidate;
        if simplified.len() <= target_index_count {
            break;
        }
    }
    // Topology-preserving simplification could not meet the budget. Fall back to sloppy
    // simplification, with an error limit no lower than the last achieved error.
    if simplified.len() > target_index_count {
        simplified = meshopt::simplify_sloppy(
            &indices,
            &adapter,
            target_index_count,
            result_error.max(0.25),
            None,
        );
    }
    if simplified.len() < 3 {
        return Err("meshopt simplification produced empty mesh".to_string());
    }

    // Compact the vertex buffer to the vertices the simplified index buffer still uses,
    // and rewrite the indices to match.
    let (vertex_count, remap) =
        meshopt::generate_vertex_remap(mesh.vertices.as_slice(), Some(&simplified));
    let vertices = meshopt::remap_vertex_buffer(mesh.vertices.as_slice(), vertex_count, &remap);
    let indices = meshopt::remap_index_buffer(Some(&simplified), vertex_count, &remap);
    if indices.len() < 3 {
        return Err("meshopt remap produced empty mesh".to_string());
    }

    let faces = indices
        .chunks_exact(3)
        .map(|chunk| [chunk[0], chunk[1], chunk[2]])
        .collect::<Vec<[u32; 3]>>();
    Ok(Mesh { vertices, faces })
}

#[cfg(test)]
mod tests {
    use burn::backend::NdArray;

    use super::{
        DinoBackendChoice, decimate_tripo_mesh, resolve_dino_backend,
        should_prefer_f16_triposg_weights, triposg_runtime_profile,
    };
    use crate::model::triposg::load_policy::BpkPrecisionPreference;
    use crate::pipeline::mesh::Mesh;

    /// Builds a flat, regular `cells` x `cells` grid in the z = 0 plane with two triangles
    /// per cell (so `2 * cells * cells` faces over `(cells + 1)^2` vertices).
    fn flat_grid(cells: usize) -> Mesh {
        let stride = cells + 1;
        let vertices: Vec<_> = (0..stride)
            .flat_map(|row| (0..stride).map(move |col| [col as f32, row as f32, 0.0]))
            .collect();
        let faces: Vec<_> = (0..cells)
            .flat_map(|row| (0..cells).map(move |col| (row * stride + col) as u32))
            .flat_map(|top_left| {
                let top_right = top_left + 1;
                let bottom_left = top_left + stride as u32;
                let bottom_right = bottom_left + 1;
                [
                    [top_left, top_right, bottom_right],
                    [top_left, bottom_right, bottom_left],
                ]
            })
            .collect();
        Mesh { vertices, faces }
    }

    #[test]
    fn auto_dino_backend_resolves_to_cpu_on_ndarray() {
        let resolved = resolve_dino_backend::<NdArray<f32>>(DinoBackendChoice::Auto);
        assert_eq!(resolved, DinoBackendChoice::Cpu);
    }

    #[test]
    fn parity_profile_has_strict_defaults() {
        let profile = triposg_runtime_profile(Some(777));
        assert!(profile.strict_dino_preprocess);
        assert!(profile.strict_rmbg_interp);
        assert_eq!(profile.max_image_dim, Some(777));
        assert!(profile.burnpack_policy.precision.prefer_f16());
    }

    #[test]
    fn default_profile_prefers_f16_weights() {
        assert!(should_prefer_f16_triposg_weights(triposg_runtime_profile(
            None
        )));
    }

    #[test]
    fn explicit_f16_profile_is_respected() {
        let base = triposg_runtime_profile(None);
        let profile = super::TripoSGRuntimeParityProfile {
            burnpack_policy: base
                .burnpack_policy
                .with_precision(BpkPrecisionPreference::PreferF16),
            ..base
        };
        assert!(should_prefer_f16_triposg_weights(profile));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn auto_dino_backend_uses_gpu_on_wgpu() {
        type WgpuBackend = burn_wgpu::Wgpu<f32, i32, u32>;
        let resolved = resolve_dino_backend::<WgpuBackend>(DinoBackendChoice::Auto);
        assert_eq!(resolved, DinoBackendChoice::Gpu);
    }

    #[test]
    fn decimation_reduces_faces_and_preserves_index_bounds() {
        const FACE_BUDGET: usize = 200;

        let mesh = flat_grid(24);
        let original_faces = mesh.faces.len();
        let decimated =
            decimate_tripo_mesh(&mesh, FACE_BUDGET).expect("decimation should succeed");

        assert!(decimated.faces.len() <= FACE_BUDGET);
        assert!(!decimated.faces.is_empty());
        assert!(decimated.faces.len() < original_faces);

        let vertex_count = decimated.vertices.len() as u32;
        assert!(decimated
            .faces
            .iter()
            .flatten()
            .all(|&index| index < vertex_count));
    }
}