oxigaf_diffusion/
camera.rs1use candle_core::{Result, Tensor};
8use candle_nn as nn;
9use candle_nn::Module;
10
11#[derive(Debug)]
13pub struct CameraEmbedding {
14 linear1: nn::Linear,
15 linear2: nn::Linear,
16}
17
18impl CameraEmbedding {
19 pub fn new(vs: nn::VarBuilder, pose_dim: usize, embed_dim: usize) -> Result<Self> {
24 let linear1 = nn::linear(pose_dim, embed_dim, vs.pp("linear1"))?;
25 let linear2 = nn::linear(embed_dim, embed_dim, vs.pp("linear2"))?;
26 Ok(Self { linear1, linear2 })
27 }
28}
29
30impl Module for CameraEmbedding {
31 fn forward(&self, pose: &Tensor) -> Result<Tensor> {
32 let h = self.linear1.forward(pose)?.silu()?;
33 self.linear2.forward(&h)
34 }
35}
36
37pub fn timestep_embedding(timesteps: &Tensor, dim: usize) -> Result<Tensor> {
42 let half = dim / 2;
43 let device = timesteps.device();
44 let dtype = candle_core::DType::F32;
45
46 let exponent = (Tensor::arange(0u32, half as u32, device)?.to_dtype(dtype)?
48 * (-f64::ln(10_000.0) / half as f64))?;
49 let freqs = exponent.exp()?;
50
51 let ts = timesteps.to_dtype(dtype)?.unsqueeze(1)?; let args = ts.broadcast_mul(&freqs.unsqueeze(0)?)?; let cos = args.cos()?;
56 let sin = args.sin()?;
57 Tensor::cat(&[cos, sin], 1) }
59
60#[derive(Debug)]
62pub struct TimestepEmbedding {
63 linear1: nn::Linear,
64 linear2: nn::Linear,
65}
66
67impl TimestepEmbedding {
68 pub fn new(vs: nn::VarBuilder, in_dim: usize, out_dim: usize) -> Result<Self> {
69 let linear1 = nn::linear(in_dim, out_dim, vs.pp("linear_1"))?;
70 let linear2 = nn::linear(out_dim, out_dim, vs.pp("linear_2"))?;
71 Ok(Self { linear1, linear2 })
72 }
73}
74
75impl Module for TimestepEmbedding {
76 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
77 let h = self.linear1.forward(xs)?.silu()?;
78 self.linear2.forward(&h)
79 }
80}