//! TripoSG(-scribble) implemented in burn.

use burn::{module::Param, nn, prelude::*, tensor::Distribution};

use super::hooks::HookRecorder;

/// Records `tensor` under `name` if a [`HookRecorder`] is attached; a no-op otherwise.
pub fn record_tensor<B: Backend, const D: usize>(
    hook: &mut Option<&mut HookRecorder>,
    name: &str,
    tensor: &Tensor<B, D>,
) {
    if let Some(hook) = hook.as_deref_mut() {
        hook.record_tensor(name, tensor);
    }
}

/// NeRF-style sinusoidal (frequency) positional embedding over per-point coordinates.
#[derive(Debug, Clone)]
pub struct FrequencyPositionalEmbedding {
    pub num_freq: usize,
    pub include_pi: bool,
}

impl FrequencyPositionalEmbedding {
    /// Output channels: the raw coordinates plus one sin/cos pair per channel per frequency.
    pub fn embed_dim(&self, input_dim: usize) -> usize {
        input_dim + input_dim * self.num_freq * 2
    }

    pub fn forward<B: Backend>(&self, coords: Tensor<B, 3>) -> Tensor<B, 3> {
        let scale_pi = if self.include_pi {
            core::f32::consts::PI
        } else {
            1.0
        };
        let device = coords.device();
        // Frequencies pi * 2^i (or plain 2^i when `include_pi` is false), i = 0..num_freq.
        let mut freq_values = Vec::with_capacity(self.num_freq);
        for freq in 0..self.num_freq {
            freq_values.push(scale_pi * 2_f32.powi(freq as i32));
        }
        let freqs = Tensor::<B, 1>::from_floats(freq_values.as_slice(), &device);

        let [b, n, c] = coords.shape().dims();
        let freqs = freqs
            .reshape([1, 1, 1, self.num_freq])
            .expand([b as i64, n as i64, c as i64, -1]);
        let scaled = coords.clone().unsqueeze_dim::<4>(3).mul(freqs);
        let scaled = scaled.reshape([b, n, c * self.num_freq]);

        let sin = scaled.clone().sin();
        let cos = scaled.cos();
        Tensor::cat(vec![coords, sin, cos], 2)
    }
}
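
// A minimal shape-check sketch, not part of the original crate: it assumes burn's
// `ndarray` backend is available as a dev-dependency and only exercises public API
// from this file.
#[cfg(test)]
mod frequency_embedding_sketch {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_output_width_matches_embed_dim() {
        let device = Default::default();
        // 3 channels, 8 frequencies -> 3 + 3 * 8 * 2 = 51 output channels.
        let embed = FrequencyPositionalEmbedding {
            num_freq: 8,
            include_pi: true,
        };
        let coords = Tensor::<NdArray, 3>::random(
            [2, 16, 3],
            Distribution::Uniform(-1.0, 1.0),
            &device,
        );
        let out = embed.forward(coords);
        assert_eq!(out.shape().dims::<3>(), [2, 16, embed.embed_dim(3)]);
    }
}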

/// Root-mean-square normalization with a learned per-channel scale
/// (no re-centering, no bias), in the style of diffusers' `RMSNorm`.
#[derive(Module, Debug)]
pub struct RmsNorm<B: Backend> {
    pub gamma: Param<Tensor<B, 1>>,
    epsilon: f32,
}

impl<B: Backend> RmsNorm<B> {
    pub fn new(d_model: usize, epsilon: f32, device: &B::Device) -> Self {
        let gamma = nn::Initializer::Ones.init([d_model], device);
        Self { gamma, epsilon }
    }

    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // y = x / sqrt(mean(x^2) + eps) * gamma, reduced over the last dimension.
        let variance = input.clone().powf_scalar(2.0).mean_dim(D - 1);
        let input_norm = input.mul(variance.add_scalar(self.epsilon).sqrt().recip());
        input_norm.mul(self.gamma.val().unsqueeze())
    }
}
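
// A small numeric sketch, not part of the original crate: with gamma = 1 the
// output of RMS norm has unit root-mean-square (up to epsilon). Assumes the
// `ndarray` backend is available for tests.
#[cfg(test)]
mod rms_norm_sketch {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn normalized_rows_have_unit_rms() {
        let device = Default::default();
        let norm = RmsNorm::<NdArray>::new(64, 1e-6, &device);
        let x = Tensor::<NdArray, 3>::random([2, 8, 64], Distribution::Normal(0.0, 2.0), &device);
        let y = norm.forward(x);
        // rms(y) over the channel dim should be ~1 for every (batch, token).
        let rms = y.powf_scalar(2.0).mean_dim(2).sqrt();
        let max_err = rms.sub_scalar(1.0).abs().max().into_scalar();
        assert!(max_err < 1e-3);
    }
}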

/// Multi-head attention used for both self- and cross-attention, with optional
/// context LayerNorm and optional per-head RMS normalization of q and k.
#[derive(Module, Debug)]
pub struct CrossAttention<B: Backend> {
    pub to_q: nn::Linear<B>,
    pub to_k: nn::Linear<B>,
    pub to_v: nn::Linear<B>,
    pub to_out: nn::Linear<B>,
    pub norm_cross: Option<nn::LayerNorm<B>>,
    pub norm_q: Option<RmsNorm<B>>,
    pub norm_k: Option<RmsNorm<B>>,
    pub num_heads: usize,
    pub head_dim: usize,
    pub scale: f32,
    pub is_cross_attention: bool,
    /// Reproduce TripoSG's fused-qkv head-split ordering (see `forward`).
    pub use_triposg_split: bool,
}

impl<B: Backend> CrossAttention<B> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        device: &B::Device,
        dim: usize,
        context_dim: usize,
        num_heads: usize,
        use_norm_cross: bool,
        qk_norm: bool,
        qkv_bias: bool,
        use_triposg_split: bool,
        is_cross_attention: bool,
    ) -> Self {
        let head_dim = dim / num_heads;
        let scale = (head_dim as f32).powf(-0.5);
        let to_q = nn::LinearConfig::new(dim, dim)
            .with_bias(qkv_bias)
            .init(device);
        let to_k = nn::LinearConfig::new(context_dim, dim)
            .with_bias(qkv_bias)
            .init(device);
        let to_v = nn::LinearConfig::new(context_dim, dim)
            .with_bias(qkv_bias)
            .init(device);
        let to_out = nn::LinearConfig::new(dim, dim).with_bias(true).init(device);

        let norm_cross = if use_norm_cross {
            nn::LayerNormConfig::new(context_dim).init(device).into()
        } else {
            None
        };
        let norm_q = if qk_norm {
            // Diffusers TripoSG attention uses qk_norm="rms_norm" with eps=1e-6.
            RmsNorm::new(head_dim, 1e-6, device).into()
        } else {
            None
        };
        let norm_k = if qk_norm {
            RmsNorm::new(head_dim, 1e-6, device).into()
        } else {
            None
        };

        Self {
            to_q,
            to_k,
            to_v,
            to_out,
            norm_cross,
            norm_q,
            norm_k,
            num_heads,
            head_dim,
            scale,
            is_cross_attention,
            use_triposg_split,
        }
    }

    pub fn forward(
        &self,
        x: Tensor<B, 3>,
        context: Tensor<B, 3>,
        mut hook: Option<&mut HookRecorder>,
        hook_prefix: &str,
    ) -> Tensor<B, 3> {
        let [b, n, c] = x.shape().dims();

        let context = if let Some(norm_cross) = &self.norm_cross {
            norm_cross.forward(context)
        } else {
            context
        };

        let q = self.to_q.forward(x);
        let k = self.to_k.forward(context.clone());
        let v = self.to_v.forward(context);

        record_tensor(&mut hook, &format!("{hook_prefix}.q"), &q);
        record_tensor(&mut hook, &format!("{hook_prefix}.k"), &k);
        record_tensor(&mut hook, &format!("{hook_prefix}.v"), &v);

        let context_len = k.shape().dims::<3>()[1];
        let is_cross_attention = self.is_cross_attention || context_len != n;

        let (q, k, v) = if self.use_triposg_split {
            if is_cross_attention {
                // Match the TripoSG reference layout: k and v are fused along the
                // channel dim and split per head, which orders channels differently
                // from reshaping each tensor to [b, m, heads, head_dim] on its own.
                let m = context_len;
                let q = q
                    .reshape([b, n, self.num_heads, self.head_dim])
                    .permute([0, 2, 1, 3]);
                let kv =
                    Tensor::cat(vec![k, v], 2).reshape([b, m, self.num_heads, self.head_dim * 2]);
                let k = kv
                    .clone()
                    .slice([0..b, 0..m, 0..self.num_heads, 0..self.head_dim])
                    .permute([0, 2, 1, 3]);
                let v = kv
                    .slice([
                        0..b,
                        0..m,
                        0..self.num_heads,
                        self.head_dim..(self.head_dim * 2),
                    ])
                    .permute([0, 2, 1, 3]);
                (q, k, v)
            } else {
                // Self-attention: fuse q, k, and v along the channel dim before the
                // per-head split, again matching the reference ordering.
                let qkv = Tensor::cat(vec![q, k, v], 2).reshape([
                    b,
                    n,
                    self.num_heads,
                    self.head_dim * 3,
                ]);
                let q = qkv
                    .clone()
                    .slice([0..b, 0..n, 0..self.num_heads, 0..self.head_dim])
                    .permute([0, 2, 1, 3]);
                let k = qkv
                    .clone()
                    .slice([
                        0..b,
                        0..n,
                        0..self.num_heads,
                        self.head_dim..(self.head_dim * 2),
                    ])
                    .permute([0, 2, 1, 3]);
                let v = qkv
                    .slice([
                        0..b,
                        0..n,
                        0..self.num_heads,
                        (self.head_dim * 2)..(self.head_dim * 3),
                    ])
                    .permute([0, 2, 1, 3]);
                (q, k, v)
            }
        } else {
            // Standard head split: [b, len, heads * head_dim] -> [b, heads, len, head_dim].
            let q = q
                .reshape([b, n, self.num_heads, self.head_dim])
                .permute([0, 2, 1, 3]);
            let k = k
                .reshape([b, context_len, self.num_heads, self.head_dim])
                .permute([0, 2, 1, 3]);
            let v = v
                .reshape([b, context_len, self.num_heads, self.head_dim])
                .permute([0, 2, 1, 3]);
            (q, k, v)
        };

        let q = if let Some(norm_q) = &self.norm_q {
            norm_q.forward(q)
        } else {
            q
        };
        let k = if let Some(norm_k) = &self.norm_k {
            norm_k.forward(k)
        } else {
            k
        };

        // Numerically stable softmax over the key dimension (max-subtracted).
        let attn_logits = q.matmul(k.swap_dims(2, 3)).mul_scalar(self.scale);
        let attn_max = attn_logits.clone().max_dim(3);
        let attn_exp = attn_logits.sub(attn_max).exp();
        let attn_denom = attn_exp.clone().sum_dim(3);
        let attn = attn_exp.div(attn_denom);
        record_tensor(&mut hook, &format!("{hook_prefix}.attn"), &attn);

        let out = attn.matmul(v).permute([0, 2, 1, 3]).reshape([b, n, c]);
        let out = self.to_out.forward(out);
        record_tensor(&mut hook, &format!("{hook_prefix}.out"), &out);
        out
    }
}
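
// A usage sketch, not part of the original crate: self-attention over a small
// sequence with no hook attached. Constructor arguments follow the signature
// above; the `ndarray` backend is assumed for tests.
#[cfg(test)]
mod cross_attention_sketch {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn self_attention_preserves_shape() {
        let device = Default::default();
        let attn = CrossAttention::<NdArray>::new(
            &device,
            64,    // dim
            64,    // context_dim (== dim for self-attention)
            4,     // num_heads
            false, // use_norm_cross
            true,  // qk_norm
            false, // qkv_bias
            false, // use_triposg_split
            false, // is_cross_attention
        );
        let x = Tensor::<NdArray, 3>::random([2, 16, 64], Distribution::Normal(0.0, 1.0), &device);
        let out = attn.forward(x.clone(), x, None, "attn0");
        assert_eq!(out.shape().dims::<3>(), [2, 16, 64]);
    }
}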

/// Two-layer GELU MLP: projection to `hidden_dim`, activation, projection back to `dim`.
#[derive(Module, Debug)]
pub struct FeedForward<B: Backend> {
    pub proj: nn::Linear<B>,
    pub out: nn::Linear<B>,
    activation: nn::Gelu,
    dropout: nn::Dropout,
}

impl<B: Backend> FeedForward<B> {
    pub fn new(device: &B::Device, dim: usize, hidden_dim: usize) -> Self {
        let proj = nn::LinearConfig::new(dim, hidden_dim)
            .with_bias(true)
            .init(device);
        let out = nn::LinearConfig::new(hidden_dim, dim)
            .with_bias(true)
            .init(device);
        let activation = nn::Gelu::new();
        // Dropout with p = 0.0 is a no-op; kept for structural parity with the reference.
        let dropout = nn::DropoutConfig::new(0.0).init();

        Self {
            proj,
            out,
            activation,
            dropout,
        }
    }

    pub fn forward(
        &self,
        x: Tensor<B, 3>,
        mut hook: Option<&mut HookRecorder>,
        hook_prefix: &str,
    ) -> Tensor<B, 3> {
        let x = self.proj.forward(x);
        let x = self.activation.forward(x);
        let x = self.dropout.forward(x);
        let x = self.out.forward(x);
        let x = self.dropout.forward(x);
        record_tensor(&mut hook, hook_prefix, &x);
        x
    }
}
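
// A usage sketch, not part of the original crate: the feed-forward block maps
// [b, n, dim] -> [b, n, dim] through a wider hidden layer. Assumes the
// `ndarray` backend for tests.
#[cfg(test)]
mod feed_forward_sketch {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn feed_forward_preserves_shape() {
        let device = Default::default();
        let ff = FeedForward::<NdArray>::new(&device, 64, 256);
        let x = Tensor::<NdArray, 3>::random([2, 16, 64], Distribution::Normal(0.0, 1.0), &device);
        let out = ff.forward(x, None, "ff0");
        assert_eq!(out.shape().dims::<3>(), [2, 16, 64]);
    }
}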

/// Diagonal Gaussian over latents, parameterized by per-element mean and log-variance.
#[derive(Debug, Clone)]
pub struct DiagonalGaussianDistribution<B: Backend> {
    pub mean: Tensor<B, 3>,
    pub logvar: Tensor<B, 3>,
}

impl<B: Backend> DiagonalGaussianDistribution<B> {
    pub fn new(mean: Tensor<B, 3>, logvar: Tensor<B, 3>) -> Self {
        Self { mean, logvar }
    }

    /// Draws a sample via the reparameterization trick: mean + exp(0.5 * logvar) * eps.
    pub fn sample(&self) -> Tensor<B, 3> {
        let std = self.logvar.clone().mul_scalar(0.5).exp();
        let noise =
            Tensor::<B, 3>::random(std.shape(), Distribution::Normal(0.0, 1.0), &std.device());
        self.mean.clone() + std * noise
    }

    /// The distribution's mode, which for a Gaussian is the mean.
    pub fn mode(&self) -> Tensor<B, 3> {
        self.mean.clone()
    }
}
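
// A reparameterization sketch, not part of the original crate: with zero mean
// and zero log-variance, `sample` draws from a standard normal and `mode`
// returns the mean exactly. Assumes the `ndarray` backend for tests.
#[cfg(test)]
mod gaussian_sketch {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn mode_is_the_mean() {
        let device = Default::default();
        let mean = Tensor::<NdArray, 3>::zeros([1, 4, 8], &device);
        let logvar = Tensor::<NdArray, 3>::zeros([1, 4, 8], &device);
        let dist = DiagonalGaussianDistribution::new(mean.clone(), logvar);
        assert_eq!(dist.mode().into_data(), mean.into_data());
        assert_eq!(dist.sample().shape().dims::<3>(), [1, 4, 8]);
    }
}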