Skip to main content

oxigaf_diffusion/
camera.rs

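//! Camera-pose conditioning MLP.
//!
//! Embeds the flattened 4×3 extrinsics matrix (12 floats) into the same
//! dimension as the U-Net time embedding so it can be added to the timestep
//! conditioning signal.
//!
//! This module also provides the sinusoidal timestep embedding and the MLP
//! that projects it to the conditioning width.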
6
7use candle_core::{Result, Tensor};
8use candle_nn as nn;
9use candle_nn::Module;
10
11/// MLP that lifts a flat camera-pose vector to the time-embedding dimension.
12#[derive(Debug)]
13pub struct CameraEmbedding {
14    linear1: nn::Linear,
15    linear2: nn::Linear,
16}
17
18impl CameraEmbedding {
19    /// Create a new camera embedding MLP.
20    ///
21    /// - `pose_dim`: input dimension (typically 12 for a flattened 4×3 matrix).
22    /// - `embed_dim`: output dimension (should match the time-embedding dim).
23    pub fn new(vs: nn::VarBuilder, pose_dim: usize, embed_dim: usize) -> Result<Self> {
24        let linear1 = nn::linear(pose_dim, embed_dim, vs.pp("linear1"))?;
25        let linear2 = nn::linear(embed_dim, embed_dim, vs.pp("linear2"))?;
26        Ok(Self { linear1, linear2 })
27    }
28}
29
30impl Module for CameraEmbedding {
31    fn forward(&self, pose: &Tensor) -> Result<Tensor> {
32        let h = self.linear1.forward(pose)?.silu()?;
33        self.linear2.forward(&h)
34    }
35}
36
37/// Build sinusoidal timestep embeddings (same as Stable Diffusion).
38///
39/// - `timesteps`: 1-D tensor of shape `(B,)` containing integer timesteps.
40/// - `dim`: embedding dimension (must be even).
41pub fn timestep_embedding(timesteps: &Tensor, dim: usize) -> Result<Tensor> {
42    let half = dim / 2;
43    let device = timesteps.device();
44    let dtype = candle_core::DType::F32;
45
46    // freq = exp(-ln(10000) * i / half)  for i in 0..half
47    let exponent = (Tensor::arange(0u32, half as u32, device)?.to_dtype(dtype)?
48        * (-f64::ln(10_000.0) / half as f64))?;
49    let freqs = exponent.exp()?;
50
51    // timesteps might be float already; ensure f32
52    let ts = timesteps.to_dtype(dtype)?.unsqueeze(1)?; // (B, 1)
53    let args = ts.broadcast_mul(&freqs.unsqueeze(0)?)?; // (B, half)
54
55    let cos = args.cos()?;
56    let sin = args.sin()?;
57    Tensor::cat(&[cos, sin], 1) // (B, dim)
58}
59
60/// Timestep-embedding MLP (projects sinusoidal embeddings to a wider space).
61#[derive(Debug)]
62pub struct TimestepEmbedding {
63    linear1: nn::Linear,
64    linear2: nn::Linear,
65}
66
67impl TimestepEmbedding {
68    pub fn new(vs: nn::VarBuilder, in_dim: usize, out_dim: usize) -> Result<Self> {
69        let linear1 = nn::linear(in_dim, out_dim, vs.pp("linear_1"))?;
70        let linear2 = nn::linear(out_dim, out_dim, vs.pp("linear_2"))?;
71        Ok(Self { linear1, linear2 })
72    }
73}
74
75impl Module for TimestepEmbedding {
76    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
77        let h = self.linear1.forward(xs)?.silu()?;
78        self.linear2.forward(&h)
79    }
80}