//! burn_tripo 0.1.1
//!
//! TripoSG(-scribble) implemented in `burn`.
use burn::prelude::*;
use burn::tensor::FloatDType;

/// Configuration for [`RectifiedFlowScheduler`].
#[derive(Config, Debug)]
pub struct RectifiedFlowSchedulerConfig {
    /// Number of discrete timesteps used during training (schedule resolution).
    pub num_train_timesteps: usize,
    /// Static time-shift factor applied to sigmas when `use_dynamic_shifting`
    /// is false (1.0 = no shift).
    pub shift: f32,
    /// When true, `set_timesteps` requires a `mu` value and applies an
    /// exponential dynamic shift instead of the static `shift`.
    pub use_dynamic_shifting: bool,
}

impl RectifiedFlowSchedulerConfig {
    /// Preset matching the MIDI-3D release: 1000 training timesteps, a
    /// static shift of 2.0, and dynamic shifting disabled.
    pub fn midi_3d() -> Self {
        Self {
            use_dynamic_shifting: false,
            shift: 2.0,
            num_train_timesteps: 1000,
        }
    }

    /// Instantiate a [`RectifiedFlowScheduler`] from this configuration.
    pub fn init(&self) -> RectifiedFlowScheduler {
        RectifiedFlowScheduler::new(self.clone())
    }

    /// Parse a JSON scheduler config from raw bytes.
    ///
    /// Missing keys fall back to defaults (1000 timesteps, shift 1.0,
    /// dynamic shifting off).
    #[cfg(feature = "import")]
    pub fn from_config_bytes(bytes: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
        let file: RectifiedFlowSchedulerConfigFile = serde_json::from_slice(bytes)?;
        Ok(Self {
            num_train_timesteps: file.num_train_timesteps.unwrap_or(1000),
            shift: file.shift.unwrap_or(1.0),
            use_dynamic_shifting: file.use_dynamic_shifting.unwrap_or(false),
        })
    }

    /// Read and parse a JSON scheduler config from disk.
    #[cfg(feature = "import")]
    pub fn from_config_file(
        path: impl AsRef<std::path::Path>,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        Self::from_config_bytes(&std::fs::read(path)?)
    }
}

/// Rectified-flow (flow-matching) Euler scheduler.
///
/// Holds the current timestep/sigma schedule plus the mutable stepping
/// state (`step_index`) advanced by [`RectifiedFlowScheduler::step`].
#[derive(Debug, Clone)]
pub struct RectifiedFlowScheduler {
    pub config: RectifiedFlowSchedulerConfig,
    // Descending timestep values; each equals sigma * num_train_timesteps.
    timesteps: Vec<f32>,
    // Descending noise levels; after `set_timesteps` this has one extra
    // trailing 0.0 entry (the terminal sigma).
    sigmas: Vec<f32>,
    // Index into `sigmas` for the next `step` call; None until initialized.
    step_index: Option<usize>,
    // Optional externally-set starting index (e.g. for resumed sampling).
    begin_index: Option<usize>,
}

impl RectifiedFlowScheduler {
    /// Build a scheduler, precomputing the training-resolution
    /// timestep/sigma schedule from `config`.
    pub fn new(config: RectifiedFlowSchedulerConfig) -> Self {
        let (timesteps, sigmas) = build_training_schedule(&config);
        Self {
            config,
            timesteps,
            sigmas,
            step_index: None,
            begin_index: None,
        }
    }

    /// Current timestep schedule (descending).
    pub fn timesteps(&self) -> &[f32] {
        &self.timesteps
    }

    /// Current sigma schedule (descending; after `set_timesteps` it carries
    /// a trailing 0.0 terminal entry, so it is one longer than `timesteps`).
    pub fn sigmas(&self) -> &[f32] {
        &self.sigmas
    }

    /// Index of the next step, or `None` if stepping has not started.
    pub fn step_index(&self) -> Option<usize> {
        self.step_index
    }

    /// Externally-set starting index, if any.
    pub fn begin_index(&self) -> Option<usize> {
        self.begin_index
    }

    /// Force the first `step` call to start at `begin_index` instead of
    /// locating the index from the provided timestep.
    pub fn set_begin_index(&mut self, begin_index: usize) {
        self.begin_index = Some(begin_index);
    }

    /// Recompute the inference schedule.
    ///
    /// At most one of `timesteps` / `sigmas` may be supplied; custom
    /// timesteps are converted to sigmas by dividing by
    /// `num_train_timesteps`. With neither, an evenly spaced schedule of
    /// `num_inference_steps` sigmas is built. The configured shift (static
    /// or dynamic via `mu`) is then applied, `timesteps` is rebuilt as
    /// `sigma * num_train_timesteps`, a terminal 0.0 sigma is appended, and
    /// the stepping state is reset.
    ///
    /// # Errors
    /// Returns `Err` if both `timesteps` and `sigmas` are given, or if
    /// dynamic shifting is enabled but `mu` is `None`.
    pub fn set_timesteps(
        &mut self,
        num_inference_steps: usize,
        timesteps: Option<Vec<f32>>,
        sigmas: Option<Vec<f32>>,
        mu: Option<f32>,
    ) -> Result<(), String> {
        if timesteps.is_some() && sigmas.is_some() {
            return Err("Only one of timesteps or sigmas can be passed".to_string());
        }

        let mut sigmas = if let Some(custom_sigmas) = sigmas {
            custom_sigmas
        } else if let Some(custom_timesteps) = timesteps {
            custom_timesteps
                .iter()
                .map(|t| self.t_to_sigma(*t))
                .collect()
        } else {
            build_inference_sigmas(self.config.num_train_timesteps, num_inference_steps)
        };

        if self.config.use_dynamic_shifting {
            let mu = mu.ok_or_else(|| {
                "mu must be provided when use_dynamic_shifting is enabled".to_string()
            })?;
            sigmas = sigmas
                .into_iter()
                .map(|sigma| self.time_shift_dynamic(mu, 1.0, sigma))
                .collect();
        } else {
            sigmas = sigmas
                .into_iter()
                .map(|sigma| self.time_shift(sigma))
                .collect();
        }

        // Timesteps mirror the (shifted) sigmas, scaled back to the
        // training-timestep range.
        self.timesteps = sigmas
            .iter()
            .map(|sigma| sigma * self.config.num_train_timesteps as f32)
            .collect();

        self.sigmas = sigmas;
        // Terminal sigma: the final step integrates down to 0 noise.
        self.sigmas.push(0.0);

        self.step_index = None;
        self.begin_index = None;
        Ok(())
    }

    /// Perform one Euler step of the rectified-flow ODE:
    /// `prev_sample = sample + (sigma - sigma_next) * model_output`.
    ///
    /// NOTE(review): diffusers' FlowMatchEuler uses `(sigma_next - sigma)`;
    /// the sign here presumably matches TripoSG's velocity convention —
    /// confirm against the reference implementation.
    ///
    /// On the first call the step index is initialized from `timestep`
    /// (or from `begin_index` if set); it is advanced by one afterwards,
    /// clamped to the end of the sigma schedule. The update is computed in
    /// f32 even for reduced-precision inputs, then cast back.
    pub fn step<B: Backend>(
        &mut self,
        model_output: Tensor<B, 3>,
        timestep: f32,
        sample: Tensor<B, 3>,
    ) -> Tensor<B, 3> {
        if self.step_index.is_none() {
            self.init_step_index(timestep);
        }
        // If the timestep was not found in the schedule, fall back to index 0.
        let step_index = self.step_index.unwrap_or(0);
        let sigma = *self.sigmas.get(step_index).unwrap_or(&0.0);
        let sigma_next = *self.sigmas.get(step_index + 1).unwrap_or(&0.0);

        let output_dtype: FloatDType = model_output.dtype().into();
        let needs_cast = output_dtype != FloatDType::F32;
        let sample_f32 = if needs_cast {
            sample.cast(FloatDType::F32)
        } else {
            sample
        };
        let model_f32 = if needs_cast {
            model_output.cast(FloatDType::F32)
        } else {
            model_output
        };
        let prev_sample = sample_f32.add(model_f32.mul_scalar(sigma - sigma_next));
        let prev_sample = if needs_cast {
            prev_sample.cast(output_dtype)
        } else {
            prev_sample
        };
        // Advance, clamping so repeated calls at the end stay in bounds.
        let max_step_index = self.sigmas.len().saturating_sub(1);
        let next_step_index = (step_index + 1).min(max_step_index);
        self.step_index = Some(next_step_index);
        prev_sample
    }

    /// Linearly interpolate samples toward noise for flow matching:
    /// `sigma * noise + (1 - sigma) * original`, with
    /// `sigma = timestep / num_train_timesteps` broadcast over the last two
    /// dims (shape `(batch)` -> `(batch, 1, 1)`).
    pub fn scale_noise<B: Backend>(
        &self,
        original_samples: Tensor<B, 3>,
        noise: Tensor<B, 3>,
        timesteps: Tensor<B, 1>,
    ) -> Tensor<B, 3> {
        let sigmas = timesteps
            .div_scalar(self.config.num_train_timesteps as f32)
            .unsqueeze_dim::<2>(1)
            .unsqueeze_dim::<3>(2);
        let scaled = original_samples
            .clone()
            .sub(sigmas.clone() * original_samples);
        scaled.add(sigmas * noise)
    }

    /// Identity; kept for API parity with schedulers that rescale inputs.
    pub fn scale_model_input<B: Backend>(&self, latents: Tensor<B, 3>) -> Tensor<B, 3> {
        latents
    }

    // Resolve the starting step index: an explicit `begin_index` wins,
    // otherwise search the timestep schedule.
    fn init_step_index(&mut self, timestep: f32) {
        if let Some(begin_index) = self.begin_index {
            self.step_index = Some(begin_index);
            return;
        }
        self.step_index = self.index_for_timestep(timestep);
    }

    // Find the schedule index whose timestep equals `timestep` exactly
    // (float equality — callers are expected to pass values taken from
    // `timesteps()`). When duplicates exist the second occurrence is
    // chosen, presumably mirroring diffusers' index_for_timestep — confirm.
    fn index_for_timestep(&self, timestep: f32) -> Option<usize> {
        let indices: Vec<usize> = self
            .timesteps
            .iter()
            .enumerate()
            .filter_map(|(idx, value)| if *value == timestep { Some(idx) } else { None })
            .collect();
        if indices.is_empty() {
            return None;
        }
        if indices.len() > 1 {
            indices.get(1).copied()
        } else {
            indices.first().copied()
        }
    }

    // Static time shift: t' = s*t / (1 + (s-1)*t), maps [0,1] -> [0,1].
    fn time_shift(&self, t: f32) -> f32 {
        self.config.shift * t / (1.0 + (self.config.shift - 1.0) * t)
    }

    // Dynamic (exponential) time shift: e^mu / (e^mu + (1/t - 1)^sigma).
    // Note: diverges at t == 0; callers pass sigmas in (0, 1].
    fn time_shift_dynamic(&self, mu: f32, sigma: f32, t: f32) -> f32 {
        let exp_mu = mu.exp();
        exp_mu / (exp_mu + (1.0 / t - 1.0).powf(sigma))
    }

    // Map a timestep back to its normalized sigma in [0, 1].
    fn t_to_sigma(&self, timestep: f32) -> f32 {
        timestep / self.config.num_train_timesteps as f32
    }
}

/// JSON shape accepted by `RectifiedFlowSchedulerConfig::from_config_bytes`.
/// Every field is optional so missing keys fall back to defaults.
#[cfg(feature = "import")]
#[derive(serde::Deserialize)]
struct RectifiedFlowSchedulerConfigFile {
    num_train_timesteps: Option<usize>,
    shift: Option<f32>,
    use_dynamic_shifting: Option<bool>,
}

/// Build the training-resolution schedule: for each of the
/// `num_train_timesteps` steps (descending), a normalized sigma in
/// (0, 1] — statically shifted unless dynamic shifting is configured —
/// and its matching timestep `sigma * num_train_timesteps`.
///
/// Returns `(timesteps, sigmas)`, both of length `num_train_timesteps`
/// (clamped to at least 1).
fn build_training_schedule(config: &RectifiedFlowSchedulerConfig) -> (Vec<f32>, Vec<f32>) {
    let num_steps = config.num_train_timesteps.max(1);
    let n = num_steps as f32;
    (0..num_steps)
        .map(|i| {
            let t = (1.0 - i as f32 / n) * n;
            let raw = t / n;
            // Dynamic shifting is deferred to `set_timesteps` (which has mu);
            // only the static shift is baked into the training schedule.
            let sigma = if config.use_dynamic_shifting {
                raw
            } else {
                config.shift * raw / (1.0 + (config.shift - 1.0) * raw)
            };
            (sigma * n, sigma)
        })
        .unzip()
}

/// Build an evenly spaced descending sigma schedule for inference:
/// `num_inference_steps` values from 1.0 down to `1 / steps` (exclusive of
/// zero). `num_inference_steps` is clamped to at least 1.
fn build_inference_sigmas(num_train_timesteps: usize, num_inference_steps: usize) -> Vec<f32> {
    let steps = num_inference_steps.max(1);
    let train = num_train_timesteps as f32;
    (0..steps)
        .map(|i| {
            let t = (1.0 - i as f32 / steps as f32) * train;
            t / train
        })
        .collect()
}