//! `burn_tripo` 0.1.0
//!
//! TripoSG(-scribble) implemented in burn.
//!
//! Documentation: scribble-conditioned TripoSG sampling pipeline.
use burn::{prelude::*, tensor::Distribution};

#[cfg(feature = "import")]
use crate::model::triposg::scheduler::RectifiedFlowSchedulerConfig;
use crate::model::triposg::{
    dit::TripoSGDiT,
    image_encoder::{DinoImageProcessor, TripoSGImageEncoder},
    scheduler::RectifiedFlowScheduler,
    vae::TripoSGVae,
};
use crate::pipeline::geometry::{
    FlashExtractConfig, HierarchicalExtractConfig, flash_extract_geometry,
    hierarchical_extract_geometry,
};
use crate::pipeline::mesh::{DenseGrid, Mesh, grid_to_mesh, sdf_to_mesh_diff_dmc};
use crate::pipeline::triposg::TripoSGPipelineOutput;

/// End-to-end TripoSG-scribble pipeline: preprocesses and encodes a
/// conditioning image, denoises shape latents with a rectified-flow DiT, and
/// decodes the latents through the shape VAE.
#[derive(Debug)]
pub struct TripoSGScribblePipeline<B: Backend> {
    // Shape VAE used to decode latents into field values / geometry queries.
    pub vae: TripoSGVae<B>,
    // Diffusion transformer predicting the flow/noise at each timestep.
    pub transformer: TripoSGDiT<B>,
    // Rectified-flow scheduler driving the denoising loop.
    pub scheduler: RectifiedFlowScheduler,
    // DINO-based image encoder producing conditioning embeddings.
    pub image_encoder: TripoSGImageEncoder<B>,
    // Preprocessor applied to raw images before encoding.
    pub image_processor: DinoImageProcessor,
}

/// Result of a mesh-producing sampling run (`sample_mesh*` methods).
#[derive(Debug)]
pub struct TripoSGScribbleMeshOutput<B: Backend> {
    // Final denoised latents, shape `[batch, num_tokens, num_channels]`.
    pub latents: Tensor<B, 3>,
    // Dense grid of decoded field values the mesh was extracted from.
    pub grid: DenseGrid,
    // Extracted surface; `None` when extraction produced no mesh
    // (presumably no zero crossing in the grid — see `grid_to_mesh`).
    pub mesh: Option<Mesh>,
}

impl<B: Backend> TripoSGScribblePipeline<B> {
    /// Assembles a pipeline from already-constructed components.
    pub fn new(
        vae: TripoSGVae<B>,
        transformer: TripoSGDiT<B>,
        scheduler: RectifiedFlowScheduler,
        image_encoder: TripoSGImageEncoder<B>,
        image_processor: DinoImageProcessor,
    ) -> Self {
        Self {
            vae,
            transformer,
            scheduler,
            image_encoder,
            image_processor,
        }
    }

    /// Runs DINO preprocessing followed by the image encoder, producing
    /// conditioning embeddings for the diffusion transformer.
    pub fn encode_image(&self, image: Tensor<B, 4>) -> Tensor<B, 3> {
        let image = self.image_processor.preprocess(image);
        self.image_encoder.forward(image)
    }

    /// Returns `latents` unchanged when the caller supplied them; otherwise
    /// samples fresh standard-normal latents of shape
    /// `[batch_size, num_tokens, num_channels]` on `device`.
    pub fn prepare_latents(
        &self,
        batch_size: usize,
        num_tokens: usize,
        num_channels: usize,
        device: &B::Device,
        latents: Option<Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        if let Some(latents) = latents {
            return latents;
        }
        // Pass the `usize` dimensions directly: `Shape` converts from
        // `[usize; D]`, and the previous `as i32` casts could silently
        // truncate (and were needless noise).
        Tensor::<B, 3>::random(
            [batch_size, num_tokens, num_channels],
            Distribution::Normal(0.0, 1.0),
            device,
        )
    }

    /// Full sampling entry point: encodes `image` and delegates to
    /// [`Self::sample_with_embeddings`].
    ///
    /// * `guidance_scale > 1.0` enables classifier-free guidance.
    /// * `query_coords`, when given, is decoded through the VAE after sampling.
    /// * `latents`, when given, seeds the denoising loop instead of fresh noise.
    #[allow(clippy::too_many_arguments)]
    pub fn sample(
        &mut self,
        image: Tensor<B, 4>,
        text_embeds: Tensor<B, 3>,
        num_inference_steps: usize,
        num_tokens: usize,
        guidance_scale: f32,
        query_coords: Option<Tensor<B, 3>>,
        latents: Option<Tensor<B, 3>>,
    ) -> TripoSGPipelineOutput<B> {
        let image_embeds = self.encode_image(image);
        self.sample_with_embeddings(
            text_embeds,
            image_embeds,
            num_inference_steps,
            num_tokens,
            guidance_scale,
            query_coords,
            latents,
        )
    }

    /// Rectified-flow denoising loop over precomputed text/image embeddings.
    ///
    /// With classifier-free guidance (`guidance_scale > 1.0`) the batch is
    /// doubled: the first half is conditioned on all-zero embeddings (the
    /// unconditional branch) and the second half on the real embeddings. The
    /// two predictions are recombined as `uncond + scale * (cond - uncond)`.
    #[allow(clippy::too_many_arguments)]
    pub fn sample_with_embeddings(
        &mut self,
        text_embeds: Tensor<B, 3>,
        image_embeds: Tensor<B, 3>,
        num_inference_steps: usize,
        num_tokens: usize,
        guidance_scale: f32,
        query_coords: Option<Tensor<B, 3>>,
        latents: Option<Tensor<B, 3>>,
    ) -> TripoSGPipelineOutput<B> {
        let device = text_embeds.device();
        let batch_size = text_embeds.shape().dims::<3>()[0];
        let do_guidance = guidance_scale > 1.0;

        // The unconditional branch uses zero embeddings (no learned null
        // token), stacked ahead of the conditional batch along dim 0.
        let text_embeds = if do_guidance {
            let zeros = Tensor::<B, 3>::zeros(text_embeds.shape(), &device);
            Tensor::cat(vec![zeros, text_embeds], 0)
        } else {
            text_embeds
        };

        let image_embeds = if do_guidance {
            let zeros = Tensor::<B, 3>::zeros(image_embeds.shape(), &device);
            Tensor::cat(vec![zeros, image_embeds], 0)
        } else {
            image_embeds
        };

        self.scheduler
            .set_timesteps(num_inference_steps, None, None, None)
            .expect("failed to set timesteps");

        let num_channels = self.transformer.config().in_channels;
        let mut latents =
            self.prepare_latents(batch_size, num_tokens, num_channels, &device, latents);

        let timesteps = self.scheduler.timesteps().to_vec();
        for &t in timesteps.iter() {
            // Duplicate the latents so the doubled (uncond + cond) batch can
            // run through the transformer in a single forward pass.
            let latent_model_input = if do_guidance {
                Tensor::cat(vec![latents.clone(), latents.clone()], 0)
            } else {
                latents.clone()
            };
            // Every sample in the (possibly doubled) batch shares timestep t.
            let model_batch = latent_model_input.shape().dims::<3>()[0];
            let timestep_values = vec![t; model_batch];
            let timestep = Tensor::<B, 1>::from_floats(timestep_values.as_slice(), &device);

            let mut noise_pred = self.transformer.forward(
                latent_model_input,
                timestep,
                text_embeds.clone(),
                Some(image_embeds.clone()),
                None,
            );

            if do_guidance {
                // Split the doubled batch back into its halves and blend.
                let half = batch_size;
                let noise_uncond =
                    noise_pred
                        .clone()
                        .slice([0..half, 0..num_tokens, 0..num_channels]);
                let noise_cond =
                    noise_pred.slice([half..(half * 2), 0..num_tokens, 0..num_channels]);
                noise_pred =
                    noise_uncond.clone() + (noise_cond - noise_uncond).mul_scalar(guidance_scale);
            }

            latents = self.scheduler.step(noise_pred, t, latents);
        }

        // Optionally decode the final latents at caller-provided query points.
        let decoded = query_coords.map(|coords| self.vae.decode(coords, latents.clone(), None));
        TripoSGPipelineOutput { latents, decoded }
    }

    /// Decodes `latents` into a dense `resolution^3` grid of field values
    /// covering `bounds` (assumed `[min_x, min_y, min_z, max_x, max_y, max_z]`
    /// — TODO confirm ordering against `DenseGrid` consumers), evaluating
    /// `chunk_size` query points per VAE call.
    ///
    /// `resolution` is clamped to at least 2 and `chunk_size` to at least 1
    /// to keep the grid and chunking well-formed.
    pub fn decode_grid(
        &self,
        latents: Tensor<B, 3>,
        bounds: [f32; 6],
        resolution: usize,
        chunk_size: usize,
    ) -> Result<DenseGrid, Box<dyn std::error::Error>> {
        let resolution = resolution.max(2);
        let chunk_size = chunk_size.max(1);
        let values = super::triposg::decode_grid_values(
            &latents, &self.vae, bounds, resolution, chunk_size,
        )?;

        Ok(DenseGrid {
            values,
            size: [resolution, resolution, resolution],
            bounds,
        })
    }

    /// Samples latents, decodes them on a dense grid, and extracts a mesh at
    /// the 0.0 iso-level. Errors come from grid decoding only.
    #[allow(clippy::too_many_arguments)]
    pub fn sample_mesh(
        &mut self,
        image: Tensor<B, 4>,
        text_embeds: Tensor<B, 3>,
        num_inference_steps: usize,
        num_tokens: usize,
        guidance_scale: f32,
        bounds: [f32; 6],
        resolution: usize,
        chunk_size: usize,
        latents: Option<Tensor<B, 3>>,
    ) -> Result<TripoSGScribbleMeshOutput<B>, Box<dyn std::error::Error>> {
        let output = self.sample(
            image,
            text_embeds,
            num_inference_steps,
            num_tokens,
            guidance_scale,
            None,
            latents,
        );
        let grid = self.decode_grid(output.latents.clone(), bounds, resolution, chunk_size)?;
        let mesh = grid_to_mesh(&grid, 0.0);
        Ok(TripoSGScribbleMeshOutput {
            latents: output.latents,
            grid,
            mesh,
        })
    }

    /// Like [`Self::sample_mesh`], but decodes the latents with the
    /// coarse-to-fine hierarchical extractor configured by `config`.
    #[allow(clippy::too_many_arguments)]
    pub fn sample_mesh_hierarchical(
        &mut self,
        image: Tensor<B, 4>,
        text_embeds: Tensor<B, 3>,
        num_inference_steps: usize,
        num_tokens: usize,
        guidance_scale: f32,
        config: &HierarchicalExtractConfig,
        latents: Option<Tensor<B, 3>>,
    ) -> Result<TripoSGScribbleMeshOutput<B>, Box<dyn std::error::Error>> {
        let output = self.sample(
            image,
            text_embeds,
            num_inference_steps,
            num_tokens,
            guidance_scale,
            None,
            latents,
        );
        let grid = hierarchical_extract_geometry(output.latents.clone(), &self.vae, config)?;
        let mesh = grid_to_mesh(&grid, 0.0);
        Ok(TripoSGScribbleMeshOutput {
            latents: output.latents,
            grid,
            mesh,
        })
    }

    /// Like [`Self::sample_mesh`], but uses the flash extractor and meshes
    /// the resulting field with differentiable dual marching cubes.
    #[allow(clippy::too_many_arguments)]
    pub fn sample_mesh_flash(
        &mut self,
        image: Tensor<B, 4>,
        text_embeds: Tensor<B, 3>,
        num_inference_steps: usize,
        num_tokens: usize,
        guidance_scale: f32,
        config: &FlashExtractConfig,
        latents: Option<Tensor<B, 3>>,
    ) -> Result<TripoSGScribbleMeshOutput<B>, Box<dyn std::error::Error>> {
        let output = self.sample(
            image,
            text_embeds,
            num_inference_steps,
            num_tokens,
            guidance_scale,
            None,
            latents,
        );
        let grid = flash_extract_geometry(output.latents.clone(), &self.vae, config)?;
        let mesh = sdf_to_mesh_diff_dmc(&grid);
        Ok(TripoSGScribbleMeshOutput {
            latents: output.latents,
            grid,
            mesh,
        })
    }
}

#[cfg(feature = "import")]
impl<B: Backend> TripoSGScribblePipeline<B> {
    /// Loads every pipeline component from a weights directory laid out like a
    /// diffusers export (`vae/`, `transformer/`, `scheduler/`,
    /// `image_encoder_dinov2/` subdirectories).
    ///
    /// Config files are read when present and fall back to built-in presets;
    /// several known DiT configs are tried in order until one loads.
    ///
    /// # Errors
    /// Returns an error when a component fails to load, or when no known DiT
    /// config matches the checkpoint (the error aggregates every attempt).
    pub fn from_pretrained(
        weights_root: impl AsRef<std::path::Path>,
        device: &B::Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        use crate::model::triposg::dit::import::{
            load_triposg_dit, load_triposg_dit_from_safetensors,
        };
        use crate::model::triposg::image_encoder::import::{
            load_dinov2_processor, load_triposg_dinov2, load_triposg_dinov2_from_safetensors,
        };
        use crate::model::triposg::vae::import::{
            load_triposg_vae, load_triposg_vae_from_safetensors,
        };

        let root = weights_root.as_ref();
        let vae_path = root.join("vae/diffusion_pytorch_model.safetensors");
        let dit_path = root.join("transformer/diffusion_pytorch_model.safetensors");
        let scheduler_path = root.join("scheduler/scheduler_config.json");
        let dino_path = root.join("image_encoder_dinov2/model.safetensors");
        // NOTE(review): hard-coded to the non-safetensors loaders even though
        // every path above ends in `.safetensors` — presumably the `load_*`
        // functions detect the format themselves; confirm before changing.
        let use_safetensors = false;

        // VAE: per-checkpoint config when available, otherwise the MIDI-3D preset.
        let vae_config_path = root.join("vae/config.json");
        let vae_config =
            crate::model::triposg::vae::TripoSGVaeConfig::from_config_file(vae_config_path)
                .unwrap_or_else(|_| crate::model::triposg::vae::TripoSGVaeConfig::midi_3d());
        let vae = if use_safetensors {
            load_triposg_vae_from_safetensors(&vae_config, device, &vae_path)?
        } else {
            load_triposg_vae(&vae_config, device, &vae_path)?
        };

        // DiT: build an ordered candidate list — the on-disk config first,
        // then the known presets as fallbacks (deduplicated below).
        let dit_config_path = root.join("transformer/config.json");
        let mut dit_configs = Vec::new();
        if let Ok(config) =
            crate::model::triposg::dit::TripoSGDiTConfig::from_config_file(&dit_config_path)
        {
            dit_configs.push(config);
        } else if dit_path.exists() {
            // A checkpoint without a config file is assumed to be the
            // pretrained TripoSG layout.
            dit_configs.push(crate::model::triposg::dit::TripoSGDiTConfig::triposg_pretrained());
        } else {
            dit_configs.push(crate::model::triposg::dit::TripoSGDiTConfig::midi_3d());
        }

        push_unique_dit_config(
            &mut dit_configs,
            crate::model::triposg::dit::TripoSGDiTConfig::triposg_pretrained(),
        );
        push_unique_dit_config(
            &mut dit_configs,
            crate::model::triposg::dit::TripoSGDiTConfig::midi_3d(),
        );

        // Try each candidate config until one loads; keep a description of
        // every failure so the final error is actionable.
        let mut load_errors = Vec::new();
        let mut dit = None;
        for config in dit_configs {
            let loaded = if use_safetensors {
                load_triposg_dit_from_safetensors(&config, device, &dit_path)
            } else {
                load_triposg_dit(&config, device, &dit_path)
            };
            match loaded {
                Ok(model) => {
                    dit = Some(model);
                    break;
                }
                Err(err) => {
                    load_errors.push(format!(
                        "cross_attention_dim={} cross_attention_2_dim={:?}: {err}",
                        config.cross_attention_dim, config.cross_attention_2_dim
                    ));
                }
            }
        }
        let Some(dit) = dit else {
            return Err(format!(
                "failed to load TripoSG DiT with all known configs:\n{}",
                load_errors.join("\n")
            )
            .into());
        };

        // Scheduler: config file when readable, otherwise the MIDI-3D preset.
        let scheduler_config = RectifiedFlowSchedulerConfig::from_config_file(scheduler_path)
            .unwrap_or_else(|_| RectifiedFlowSchedulerConfig::midi_3d());
        let scheduler = RectifiedFlowScheduler::new(scheduler_config);

        let image_encoder = if use_safetensors {
            load_triposg_dinov2_from_safetensors(device, &dino_path)?
        } else {
            load_triposg_dinov2(device, &dino_path)?
        };
        // Strict preprocessing so inputs must match the processor's expected
        // layout exactly (see `DinoImageProcessor::set_strict_preprocess`).
        let mut image_processor = load_dinov2_processor(root)?;
        image_processor.set_strict_preprocess(true);

        Ok(Self::new(
            vae,
            dit,
            scheduler,
            image_encoder,
            image_processor,
        ))
    }
}

#[cfg(feature = "import")]
/// Appends `candidate` to `configs` unless an entry with the same
/// architecture-defining fields is already present.
fn push_unique_dit_config(
    configs: &mut Vec<crate::model::triposg::dit::TripoSGDiTConfig>,
    candidate: crate::model::triposg::dit::TripoSGDiTConfig,
) {
    // Two configs count as duplicates when every field that determines the
    // model architecture matches.
    let is_duplicate = configs.iter().any(|existing| {
        existing.in_channels == candidate.in_channels
            && existing.width == candidate.width
            && existing.num_layers == candidate.num_layers
            && existing.num_attention_heads == candidate.num_attention_heads
            && existing.cross_attention_dim == candidate.cross_attention_dim
            && existing.cross_attention_2_dim == candidate.cross_attention_2_dim
    });
    if is_duplicate {
        return;
    }
    configs.push(candidate);
}