// oxigaf_diffusion/pipeline.rs
use std::path::Path;
7
8use candle_core::{DType, Device, Tensor};
9use candle_nn as nn;
10
11use crate::clip::{build_clip_encoder, ClipImageEncoder};
12use crate::config::DiffusionConfig;
13use crate::scheduler::{DdimScheduler, PredictionType};
14use crate::unet::MultiViewUNet;
15use crate::upsampler::LatentUpsampler;
16use crate::vae::Vae;
17use crate::DiffusionError;
18
/// Output of a multi-view generation pass.
#[derive(Debug)]
pub struct MultiViewOutput {
    /// One decoded image tensor per generated view, values clamped to [0, 1].
    pub images: Vec<Tensor>,
    /// Pixel width of each output image.
    pub width: u32,
    /// Pixel height of each output image.
    pub height: u32,
}
29
/// Multi-view diffusion pipeline: a U-Net denoiser conditioned on CLIP image
/// tokens and camera poses, a VAE decoder, a DDIM scheduler, and an optional
/// latent upsampler.
pub struct MultiViewDiffusionPipeline {
    // Denoising network; receives latents concatenated with normal-map latents.
    unet: MultiViewUNet,
    // Decodes final latents into image space.
    vae: Vae,
    // Encodes the reference image into IP (image prompt) tokens.
    clip_encoder: ClipImageEncoder,
    // DDIM sampling schedule/stepper.
    scheduler: DdimScheduler,
    // Optional latent-space super-resolution stage, loaded only if configured.
    upsampler: Option<LatentUpsampler>,
    config: DiffusionConfig,
    // Device all tensors are created on.
    device: Device,
}
40
41impl MultiViewDiffusionPipeline {
42 pub fn load(
50 config: DiffusionConfig,
51 weights_dir: &Path,
52 device: &Device,
53 ) -> std::result::Result<Self, DiffusionError> {
54 let dtype = DType::F32;
55
56 let unet_path = weights_dir.join("unet/diffusion_pytorch_model.safetensors");
58 let unet_data = std::fs::read(&unet_path)
59 .map_err(|e| DiffusionError::ModelLoad(format!("Failed to read U-Net weights: {e}")))?;
60 let unet_vb = nn::VarBuilder::from_buffered_safetensors(unet_data, dtype, device)
61 .map_err(|e| DiffusionError::ModelLoad(format!("U-Net VarBuilder: {e}")))?;
62 let unet = MultiViewUNet::new(unet_vb, &config)
63 .map_err(|e| DiffusionError::ModelLoad(format!("U-Net build: {e}")))?;
64
65 let vae_path = weights_dir.join("vae/diffusion_pytorch_model.safetensors");
67 let vae_data = std::fs::read(&vae_path)
68 .map_err(|e| DiffusionError::ModelLoad(format!("Failed to read VAE weights: {e}")))?;
69 let vae_vb = nn::VarBuilder::from_buffered_safetensors(vae_data, dtype, device)
70 .map_err(|e| DiffusionError::ModelLoad(format!("VAE VarBuilder: {e}")))?;
71 let vae = Vae::new(vae_vb, config.latent_channels, config.vae_scale_factor)
72 .map_err(|e| DiffusionError::ModelLoad(format!("VAE build: {e}")))?;
73
74 let clip_path = weights_dir.join("image_encoder/model.safetensors");
76 let clip_data = std::fs::read(&clip_path)
77 .map_err(|e| DiffusionError::ModelLoad(format!("Failed to read CLIP weights: {e}")))?;
78 let clip_vb = nn::VarBuilder::from_buffered_safetensors(clip_data, dtype, device)
79 .map_err(|e| DiffusionError::ModelLoad(format!("CLIP VarBuilder: {e}")))?;
80 let clip_encoder = build_clip_encoder(clip_vb, &config)
81 .map_err(|e| DiffusionError::ModelLoad(format!("CLIP build: {e}")))?;
82
83 let scheduler = DdimScheduler::new(1000, PredictionType::VPrediction);
84
85 let upsampler = if let Some(mode) = config.upsampler_mode {
87 let upsampler_path = weights_dir.join("upsampler");
88 Some(LatentUpsampler::load(mode, &upsampler_path, device)?)
89 } else {
90 None
91 };
92
93 Ok(Self {
94 unet,
95 vae,
96 clip_encoder,
97 scheduler,
98 upsampler,
99 config,
100 device: device.clone(),
101 })
102 }
103
104 pub fn generate(
129 &mut self,
130 reference_image: &Tensor,
131 normal_map_latents: &Tensor,
132 camera_poses: &Tensor,
133 _seed: u64,
134 ) -> std::result::Result<MultiViewOutput, DiffusionError> {
135 let num_views = self.config.num_views;
136 let latent_size = self.config.latent_size;
137 let latent_ch = self.config.latent_channels;
138
139 if self.config.guidance_scale < 1.0 {
141 return Err(DiffusionError::Inference(format!(
142 "guidance_scale must be >= 1.0, got {}",
143 self.config.guidance_scale
144 )));
145 }
146
147 let ip_tokens = self
149 .clip_encoder
150 .forward(reference_image)
151 .map_err(|e| DiffusionError::Inference(format!("CLIP encode: {e}")))?;
152 let ip_tokens = ip_tokens
154 .repeat(&[num_views, 1, 1])
155 .map_err(|e| DiffusionError::Inference(format!("IP token expand: {e}")))?;
156
157 let null_context = Tensor::zeros(
159 (num_views, 77, self.config.cross_attention_dim),
160 DType::F32,
161 &self.device,
162 )
163 .map_err(|e| DiffusionError::Inference(format!("null context: {e}")))?;
164
165 let latent_shape = (num_views, latent_ch, latent_size, latent_size);
167 let mut latents = Tensor::randn(0f32, 1f32, latent_shape, &self.device)
168 .map_err(|e| DiffusionError::Inference(format!("noise init: {e}")))?;
169
170 self.scheduler
172 .set_timesteps(self.config.num_inference_steps);
173 let timesteps = self.scheduler.timesteps().to_vec();
174
175 for &t in ×teps {
180 let model_input = Tensor::cat(&[&latents, normal_map_latents], 1)
182 .map_err(|e| DiffusionError::Inference(format!("concat: {e}")))?;
183
184 let noise_pred_cond = self.unet.forward(
187 &model_input,
188 t,
189 Some(&null_context),
190 Some(camera_poses),
191 Some(&ip_tokens),
192 )?;
193
194 let noise_pred_uncond = self.unet.forward(
197 &model_input,
198 t,
199 Some(&null_context),
200 Some(camera_poses),
201 None, )?;
203
204 let diff = (&noise_pred_cond - &noise_pred_uncond)
207 .map_err(|e| DiffusionError::Inference(format!("CFG diff: {e}")))?;
208 let noise_pred = (&noise_pred_uncond + (diff * self.config.guidance_scale))
209 .map_err(|e| DiffusionError::Inference(format!("CFG combine: {e}")))?;
210
211 latents = self
213 .scheduler
214 .step(&noise_pred, t, &latents)
215 .map_err(|e| DiffusionError::Inference(format!("scheduler step: {e}")))?;
216 }
217
218 if let Some(ref mut upsampler) = self.upsampler {
220 latents = upsampler
221 .upsample(&latents, self.config.upsampler_steps)
222 .map_err(|e| DiffusionError::Inference(format!("Upsampler: {e}")))?;
223 }
224
225 let decoded = self
227 .vae
228 .decode(&latents)
229 .map_err(|e| DiffusionError::Inference(format!("VAE decode: {e}")))?;
230
231 let images = ((decoded + 1.0)
233 .map_err(|e| DiffusionError::Inference(format!("post +1: {e}")))?
234 * 0.5)
235 .map_err(|e| DiffusionError::Inference(format!("post *0.5: {e}")))?
236 .clamp(0.0, 1.0)
237 .map_err(|e| DiffusionError::Inference(format!("clamp: {e}")))?;
238
239 let mut view_images = Vec::with_capacity(num_views);
241 for i in 0..num_views {
242 let img = images
243 .narrow(0, i, 1)
244 .and_then(|t| t.squeeze(0))
245 .map_err(|e| DiffusionError::Inference(format!("split view {i}: {e}")))?;
246 view_images.push(img);
247 }
248
249 let size = if self.upsampler.is_some() {
251 self.config.image_size as u32 * 2 } else {
253 self.config.image_size as u32 };
255 Ok(MultiViewOutput {
256 images: view_images,
257 width: size,
258 height: size,
259 })
260 }
261}