Skip to main content

oxigaf_diffusion/
pipeline.rs

1//! Full multi-view diffusion pipeline.
2//!
3//! Orchestrates the CLIP encoder, U-Net, VAE, and DDIM scheduler to
4//! generate multi-view images from a single reference photo and camera poses.
5
6use std::path::Path;
7
8use candle_core::{DType, Device, Tensor};
9use candle_nn as nn;
10
11use crate::clip::{build_clip_encoder, ClipImageEncoder};
12use crate::config::DiffusionConfig;
13use crate::scheduler::{DdimScheduler, PredictionType};
14use crate::unet::MultiViewUNet;
15use crate::upsampler::LatentUpsampler;
16use crate::vae::Vae;
17use crate::DiffusionError;
18
/// Output of the multi-view diffusion pipeline, as produced by
/// [`MultiViewDiffusionPipeline::generate`].
#[derive(Debug)]
pub struct MultiViewOutput {
    /// Generated images, one per view, as `(3, H, W)` tensors in `[0, 1]`.
    // NOTE(review): the tensors live on whatever device the pipeline was
    // loaded with — callers that need CPU data must transfer explicitly.
    pub images: Vec<Tensor>,
    /// Width of each generated image, in pixels.
    pub width: u32,
    /// Height of each generated image, in pixels.
    pub height: u32,
}
29
/// The full multi-view diffusion pipeline.
///
/// Owns every model component plus the scheduler state; construct with
/// [`Self::load`] and run inference with [`Self::generate`].
pub struct MultiViewDiffusionPipeline {
    /// Denoising U-Net, conditioned on camera poses and IP-Adapter tokens.
    unet: MultiViewUNet,
    /// VAE used to decode final latents back into image space.
    vae: Vae,
    /// CLIP image encoder producing IP-Adapter tokens from the reference image.
    clip_encoder: ClipImageEncoder,
    /// DDIM scheduler driving the denoising loop.
    scheduler: DdimScheduler,
    /// Optional latent upsampler; present when `config.upsampler_mode` is set.
    upsampler: Option<LatentUpsampler>,
    /// Pipeline hyper-parameters (view count, latent dims, guidance scale, …).
    config: DiffusionConfig,
    /// Device all tensors are allocated on.
    device: Device,
}
40
41impl MultiViewDiffusionPipeline {
42    /// Load a pipeline from a directory of safetensors files.
43    ///
44    /// Expected files:
45    /// - `unet/diffusion_pytorch_model.safetensors`
46    /// - `vae/diffusion_pytorch_model.safetensors`
47    /// - `image_encoder/model.safetensors`
48    /// - `upsampler/diffusion_pytorch_model.safetensors` (optional, for SdX2 mode)
49    pub fn load(
50        config: DiffusionConfig,
51        weights_dir: &Path,
52        device: &Device,
53    ) -> std::result::Result<Self, DiffusionError> {
54        let dtype = DType::F32;
55
56        // Load U-Net weights
57        let unet_path = weights_dir.join("unet/diffusion_pytorch_model.safetensors");
58        let unet_data = std::fs::read(&unet_path)
59            .map_err(|e| DiffusionError::ModelLoad(format!("Failed to read U-Net weights: {e}")))?;
60        let unet_vb = nn::VarBuilder::from_buffered_safetensors(unet_data, dtype, device)
61            .map_err(|e| DiffusionError::ModelLoad(format!("U-Net VarBuilder: {e}")))?;
62        let unet = MultiViewUNet::new(unet_vb, &config)
63            .map_err(|e| DiffusionError::ModelLoad(format!("U-Net build: {e}")))?;
64
65        // Load VAE weights
66        let vae_path = weights_dir.join("vae/diffusion_pytorch_model.safetensors");
67        let vae_data = std::fs::read(&vae_path)
68            .map_err(|e| DiffusionError::ModelLoad(format!("Failed to read VAE weights: {e}")))?;
69        let vae_vb = nn::VarBuilder::from_buffered_safetensors(vae_data, dtype, device)
70            .map_err(|e| DiffusionError::ModelLoad(format!("VAE VarBuilder: {e}")))?;
71        let vae = Vae::new(vae_vb, config.latent_channels, config.vae_scale_factor)
72            .map_err(|e| DiffusionError::ModelLoad(format!("VAE build: {e}")))?;
73
74        // Load CLIP image encoder weights
75        let clip_path = weights_dir.join("image_encoder/model.safetensors");
76        let clip_data = std::fs::read(&clip_path)
77            .map_err(|e| DiffusionError::ModelLoad(format!("Failed to read CLIP weights: {e}")))?;
78        let clip_vb = nn::VarBuilder::from_buffered_safetensors(clip_data, dtype, device)
79            .map_err(|e| DiffusionError::ModelLoad(format!("CLIP VarBuilder: {e}")))?;
80        let clip_encoder = build_clip_encoder(clip_vb, &config)
81            .map_err(|e| DiffusionError::ModelLoad(format!("CLIP build: {e}")))?;
82
83        let scheduler = DdimScheduler::new(1000, PredictionType::VPrediction);
84
85        // Load upsampler if configured
86        let upsampler = if let Some(mode) = config.upsampler_mode {
87            let upsampler_path = weights_dir.join("upsampler");
88            Some(LatentUpsampler::load(mode, &upsampler_path, device)?)
89        } else {
90            None
91        };
92
93        Ok(Self {
94            unet,
95            vae,
96            clip_encoder,
97            scheduler,
98            upsampler,
99            config,
100            device: device.clone(),
101        })
102    }
103
104    /// Generate multi-view images from a reference image and camera poses.
105    ///
106    /// - `reference_image`: `(1, 3, 224, 224)` normalised image for CLIP.
107    /// - `normal_map_latents`: `(num_views, latent_channels, h, w)` encoded normal maps.
108    /// - `camera_poses`: `(num_views, pose_dim)` flattened extrinsics per view.
109    /// - `seed`: RNG seed for reproducibility.
110    ///
111    /// # Classifier-Free Guidance (CFG)
112    ///
113    /// This pipeline implements CFG for IP-Adapter conditioning:
114    /// - **Conditional pass**: Uses IP tokens from CLIP-encoded reference image
115    /// - **Unconditional pass**: Skips IP tokens (no reference conditioning)
116    /// - **Formula**: `pred = uncond + guidance_scale * (cond - uncond)`
117    ///
118    /// The `guidance_scale` parameter (from config) controls the strength of
119    /// conditioning. Typical values:
120    /// - `1.0` = no guidance (unconditional generation)
121    /// - `3.0-7.5` = balanced (default: 3.0 for GAF)
122    /// - `>10.0` = strong conditioning (may oversaturate)
123    ///
124    /// # Errors
125    ///
126    /// Returns `DiffusionError::Inference` if guidance_scale < 1.0 or if any
127    /// tensor operation fails during generation.
128    pub fn generate(
129        &mut self,
130        reference_image: &Tensor,
131        normal_map_latents: &Tensor,
132        camera_poses: &Tensor,
133        _seed: u64,
134    ) -> std::result::Result<MultiViewOutput, DiffusionError> {
135        let num_views = self.config.num_views;
136        let latent_size = self.config.latent_size;
137        let latent_ch = self.config.latent_channels;
138
139        // Validate guidance_scale
140        if self.config.guidance_scale < 1.0 {
141            return Err(DiffusionError::Inference(format!(
142                "guidance_scale must be >= 1.0, got {}",
143                self.config.guidance_scale
144            )));
145        }
146
147        // 1. Encode reference image with CLIP for IP-Adapter conditioning
148        let ip_tokens = self
149            .clip_encoder
150            .forward(reference_image)
151            .map_err(|e| DiffusionError::Inference(format!("CLIP encode: {e}")))?;
152        // Expand to all views: (1, seq, dim) -> (V, seq, dim)
153        let ip_tokens = ip_tokens
154            .repeat(&[num_views, 1, 1])
155            .map_err(|e| DiffusionError::Inference(format!("IP token expand: {e}")))?;
156
157        // 2. Prepare null text embedding (GAF doesn't use text conditioning)
158        let null_context = Tensor::zeros(
159            (num_views, 77, self.config.cross_attention_dim),
160            DType::F32,
161            &self.device,
162        )
163        .map_err(|e| DiffusionError::Inference(format!("null context: {e}")))?;
164
165        // 3. Prepare initial noise
166        let latent_shape = (num_views, latent_ch, latent_size, latent_size);
167        let mut latents = Tensor::randn(0f32, 1f32, latent_shape, &self.device)
168            .map_err(|e| DiffusionError::Inference(format!("noise init: {e}")))?;
169
170        // 4. Set scheduler timesteps
171        self.scheduler
172            .set_timesteps(self.config.num_inference_steps);
173        let timesteps = self.scheduler.timesteps().to_vec();
174
175        // 5. Denoising loop with Classifier-Free Guidance (CFG)
176        // We use separate forward passes for conditional and unconditional to
177        // simplify implementation and avoid tensor concatenation issues with
178        // IP-Adapter attention (which needs different shapes for cond/uncond).
179        for &t in &timesteps {
180            // Concatenate noise latents with normal-map latents
181            let model_input = Tensor::cat(&[&latents, normal_map_latents], 1)
182                .map_err(|e| DiffusionError::Inference(format!("concat: {e}")))?;
183
184            // Forward pass 1: Conditional (with IP-Adapter tokens)
185            // This provides identity-preserving conditioning from the reference image
186            let noise_pred_cond = self.unet.forward(
187                &model_input,
188                t,
189                Some(&null_context),
190                Some(camera_poses),
191                Some(&ip_tokens),
192            )?;
193
194            // Forward pass 2: Unconditional (without IP-Adapter tokens)
195            // This provides the baseline without reference conditioning
196            let noise_pred_uncond = self.unet.forward(
197                &model_input,
198                t,
199                Some(&null_context),
200                Some(camera_poses),
201                None, // Skip IP tokens for unconditional
202            )?;
203
204            // Apply CFG formula: pred = uncond + scale * (cond - uncond)
205            // This interpolates between unconditional and conditional predictions
206            let diff = (&noise_pred_cond - &noise_pred_uncond)
207                .map_err(|e| DiffusionError::Inference(format!("CFG diff: {e}")))?;
208            let noise_pred = (&noise_pred_uncond + (diff * self.config.guidance_scale))
209                .map_err(|e| DiffusionError::Inference(format!("CFG combine: {e}")))?;
210
211            // Scheduler step
212            latents = self
213                .scheduler
214                .step(&noise_pred, t, &latents)
215                .map_err(|e| DiffusionError::Inference(format!("scheduler step: {e}")))?;
216        }
217
218        // 6. Upsample latents if configured (32×32 → 64×64)
219        if let Some(ref mut upsampler) = self.upsampler {
220            latents = upsampler
221                .upsample(&latents, self.config.upsampler_steps)
222                .map_err(|e| DiffusionError::Inference(format!("Upsampler: {e}")))?;
223        }
224
225        // 8. Decode latents with VAE
226        let decoded = self
227            .vae
228            .decode(&latents)
229            .map_err(|e| DiffusionError::Inference(format!("VAE decode: {e}")))?;
230
231        // 9. Post-process: clamp to [0, 1]
232        let images = ((decoded + 1.0)
233            .map_err(|e| DiffusionError::Inference(format!("post +1: {e}")))?
234            * 0.5)
235            .map_err(|e| DiffusionError::Inference(format!("post *0.5: {e}")))?
236            .clamp(0.0, 1.0)
237            .map_err(|e| DiffusionError::Inference(format!("clamp: {e}")))?;
238
239        // Split into per-view tensors
240        let mut view_images = Vec::with_capacity(num_views);
241        for i in 0..num_views {
242            let img = images
243                .narrow(0, i, 1)
244                .and_then(|t| t.squeeze(0))
245                .map_err(|e| DiffusionError::Inference(format!("split view {i}: {e}")))?;
246            view_images.push(img);
247        }
248
249        // Calculate output size based on whether upsampling was used
250        let size = if self.upsampler.is_some() {
251            self.config.image_size as u32 * 2 // 512×512 with upsampling
252        } else {
253            self.config.image_size as u32 // 256×256 without upsampling
254        };
255        Ok(MultiViewOutput {
256            images: view_images,
257            width: size,
258            height: size,
259        })
260    }
261}