// candle_transformers/models/stable_diffusion/mod.rs

//! Stable Diffusion
//!
//! Stable Diffusion is a latent text-to-image diffusion model capable of
//! generating photo-realistic images given any text input.
//!
//! - 💻 [Original Repository](https://github.com/CompVis/stable-diffusion)
//! - 🤗 [Hugging Face](https://huggingface.co/runwayml/stable-diffusion-v1-5)
//! - The default scheduler for the v1.5, v2.1 and XL 1.0 versions is the Denoising Diffusion Implicit Model scheduler (DDIM). The original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim). The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
//!
//!
//! # Example
//!
//! <div align=center>
//!   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" alt="rusty robot holding a candle" width=320>
//! </div>
//!
//! _"A rusty robot holding a fire torch in its hand."_ Generated by Stable Diffusion XL using Rust and [candle](https://github.com/huggingface/candle).
//!
//! ```bash
//! # example running with cuda
//! # see the candle-examples/examples/stable-diffusion for all options
//! cargo run --example stable-diffusion --release --features=cuda,cudnn \
//!     -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
//!
//! # with sd-turbo
//! cargo run --example stable-diffusion --release --features=cuda,cudnn \
//!     -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" \
//!     --sd-version turbo
//!
//! # with flash attention.
//! # feature flag: `--features flash-attn`
//! # cli flag: `--use-flash-attn`.
//! # flash-attention-v2 is only compatible with Ampere, Ada, \
//! # or Hopper GPUs (e.g., A100/H100, RTX 3090/4090).
//! cargo run --example stable-diffusion --release --features=cuda,cudnn \
//!     -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" \
//!     --use-flash-attn
//! ```
39
40pub mod attention;
41pub mod clip;
42pub mod ddim;
43pub mod ddpm;
44pub mod embeddings;
45pub mod euler_ancestral_discrete;
46pub mod resnet;
47pub mod schedulers;
48pub mod unet_2d;
49pub mod unet_2d_blocks;
50pub mod uni_pc;
51pub mod utils;
52pub mod vae;
53
54use std::sync::Arc;
55
56use candle::{DType, Device, Result};
57use candle_nn as nn;
58
59use self::schedulers::{Scheduler, SchedulerConfig};
60
61#[derive(Clone, Debug)]
62pub struct StableDiffusionConfig {
63    pub width: usize,
64    pub height: usize,
65    pub clip: clip::Config,
66    pub clip2: Option<clip::Config>,
67    autoencoder: vae::AutoEncoderKLConfig,
68    unet: unet_2d::UNet2DConditionModelConfig,
69    scheduler: Arc<dyn SchedulerConfig>,
70}
71
72impl StableDiffusionConfig {
73    pub fn v1_5(
74        sliced_attention_size: Option<usize>,
75        height: Option<usize>,
76        width: Option<usize>,
77    ) -> Self {
78        let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
79            out_channels,
80            use_cross_attn,
81            attention_head_dim,
82        };
83        // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
84        let unet = unet_2d::UNet2DConditionModelConfig {
85            blocks: vec![
86                bc(320, Some(1), 8),
87                bc(640, Some(1), 8),
88                bc(1280, Some(1), 8),
89                bc(1280, None, 8),
90            ],
91            center_input_sample: false,
92            cross_attention_dim: 768,
93            downsample_padding: 1,
94            flip_sin_to_cos: true,
95            freq_shift: 0.,
96            layers_per_block: 2,
97            mid_block_scale_factor: 1.,
98            norm_eps: 1e-5,
99            norm_num_groups: 32,
100            sliced_attention_size,
101            use_linear_projection: false,
102        };
103        let autoencoder = vae::AutoEncoderKLConfig {
104            block_out_channels: vec![128, 256, 512, 512],
105            layers_per_block: 2,
106            latent_channels: 4,
107            norm_num_groups: 32,
108            use_quant_conv: true,
109            use_post_quant_conv: true,
110        };
111        let height = if let Some(height) = height {
112            assert_eq!(height % 8, 0, "height has to be divisible by 8");
113            height
114        } else {
115            512
116        };
117
118        let width = if let Some(width) = width {
119            assert_eq!(width % 8, 0, "width has to be divisible by 8");
120            width
121        } else {
122            512
123        };
124
125        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
126            prediction_type: schedulers::PredictionType::Epsilon,
127            ..Default::default()
128        });
129
130        StableDiffusionConfig {
131            width,
132            height,
133            clip: clip::Config::v1_5(),
134            clip2: None,
135            autoencoder,
136            scheduler,
137            unet,
138        }
139    }
140
141    fn v2_1_(
142        sliced_attention_size: Option<usize>,
143        height: Option<usize>,
144        width: Option<usize>,
145        prediction_type: schedulers::PredictionType,
146    ) -> Self {
147        let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
148            out_channels,
149            use_cross_attn,
150            attention_head_dim,
151        };
152        // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
153        let unet = unet_2d::UNet2DConditionModelConfig {
154            blocks: vec![
155                bc(320, Some(1), 5),
156                bc(640, Some(1), 10),
157                bc(1280, Some(1), 20),
158                bc(1280, None, 20),
159            ],
160            center_input_sample: false,
161            cross_attention_dim: 1024,
162            downsample_padding: 1,
163            flip_sin_to_cos: true,
164            freq_shift: 0.,
165            layers_per_block: 2,
166            mid_block_scale_factor: 1.,
167            norm_eps: 1e-5,
168            norm_num_groups: 32,
169            sliced_attention_size,
170            use_linear_projection: true,
171        };
172        // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json
173        let autoencoder = vae::AutoEncoderKLConfig {
174            block_out_channels: vec![128, 256, 512, 512],
175            layers_per_block: 2,
176            latent_channels: 4,
177            norm_num_groups: 32,
178            use_quant_conv: true,
179            use_post_quant_conv: true,
180        };
181        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
182            prediction_type,
183            ..Default::default()
184        });
185
186        let height = if let Some(height) = height {
187            assert_eq!(height % 8, 0, "height has to be divisible by 8");
188            height
189        } else {
190            768
191        };
192
193        let width = if let Some(width) = width {
194            assert_eq!(width % 8, 0, "width has to be divisible by 8");
195            width
196        } else {
197            768
198        };
199
200        StableDiffusionConfig {
201            width,
202            height,
203            clip: clip::Config::v2_1(),
204            clip2: None,
205            autoencoder,
206            scheduler,
207            unet,
208        }
209    }
210
211    pub fn v2_1(
212        sliced_attention_size: Option<usize>,
213        height: Option<usize>,
214        width: Option<usize>,
215    ) -> Self {
216        // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json
217        Self::v2_1_(
218            sliced_attention_size,
219            height,
220            width,
221            schedulers::PredictionType::VPrediction,
222        )
223    }
224
225    fn sdxl_(
226        sliced_attention_size: Option<usize>,
227        height: Option<usize>,
228        width: Option<usize>,
229        prediction_type: schedulers::PredictionType,
230    ) -> Self {
231        let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
232            out_channels,
233            use_cross_attn,
234            attention_head_dim,
235        };
236        // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
237        let unet = unet_2d::UNet2DConditionModelConfig {
238            blocks: vec![
239                bc(320, None, 5),
240                bc(640, Some(2), 10),
241                bc(1280, Some(10), 20),
242            ],
243            center_input_sample: false,
244            cross_attention_dim: 2048,
245            downsample_padding: 1,
246            flip_sin_to_cos: true,
247            freq_shift: 0.,
248            layers_per_block: 2,
249            mid_block_scale_factor: 1.,
250            norm_eps: 1e-5,
251            norm_num_groups: 32,
252            sliced_attention_size,
253            use_linear_projection: true,
254        };
255        // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
256        let autoencoder = vae::AutoEncoderKLConfig {
257            block_out_channels: vec![128, 256, 512, 512],
258            layers_per_block: 2,
259            latent_channels: 4,
260            norm_num_groups: 32,
261            use_quant_conv: true,
262            use_post_quant_conv: true,
263        };
264        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
265            prediction_type,
266            ..Default::default()
267        });
268
269        let height = if let Some(height) = height {
270            assert_eq!(height % 8, 0, "height has to be divisible by 8");
271            height
272        } else {
273            1024
274        };
275
276        let width = if let Some(width) = width {
277            assert_eq!(width % 8, 0, "width has to be divisible by 8");
278            width
279        } else {
280            1024
281        };
282
283        StableDiffusionConfig {
284            width,
285            height,
286            clip: clip::Config::sdxl(),
287            clip2: Some(clip::Config::sdxl2()),
288            autoencoder,
289            scheduler,
290            unet,
291        }
292    }
293
294    fn sdxl_turbo_(
295        sliced_attention_size: Option<usize>,
296        height: Option<usize>,
297        width: Option<usize>,
298        prediction_type: schedulers::PredictionType,
299    ) -> Self {
300        let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
301            out_channels,
302            use_cross_attn,
303            attention_head_dim,
304        };
305        // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json
306        let unet = unet_2d::UNet2DConditionModelConfig {
307            blocks: vec![
308                bc(320, None, 5),
309                bc(640, Some(2), 10),
310                bc(1280, Some(10), 20),
311            ],
312            center_input_sample: false,
313            cross_attention_dim: 2048,
314            downsample_padding: 1,
315            flip_sin_to_cos: true,
316            freq_shift: 0.,
317            layers_per_block: 2,
318            mid_block_scale_factor: 1.,
319            norm_eps: 1e-5,
320            norm_num_groups: 32,
321            sliced_attention_size,
322            use_linear_projection: true,
323        };
324        // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json
325        let autoencoder = vae::AutoEncoderKLConfig {
326            block_out_channels: vec![128, 256, 512, 512],
327            layers_per_block: 2,
328            latent_channels: 4,
329            norm_num_groups: 32,
330            use_quant_conv: true,
331            use_post_quant_conv: true,
332        };
333        let scheduler = Arc::new(
334            euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
335                prediction_type,
336                timestep_spacing: schedulers::TimestepSpacing::Trailing,
337                ..Default::default()
338            },
339        );
340
341        let height = if let Some(height) = height {
342            assert_eq!(height % 8, 0, "height has to be divisible by 8");
343            height
344        } else {
345            512
346        };
347
348        let width = if let Some(width) = width {
349            assert_eq!(width % 8, 0, "width has to be divisible by 8");
350            width
351        } else {
352            512
353        };
354
355        Self {
356            width,
357            height,
358            clip: clip::Config::sdxl(),
359            clip2: Some(clip::Config::sdxl2()),
360            autoencoder,
361            scheduler,
362            unet,
363        }
364    }
365
366    pub fn sdxl(
367        sliced_attention_size: Option<usize>,
368        height: Option<usize>,
369        width: Option<usize>,
370    ) -> Self {
371        Self::sdxl_(
372            sliced_attention_size,
373            height,
374            width,
375            // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
376            schedulers::PredictionType::Epsilon,
377        )
378    }
379
380    pub fn sdxl_turbo(
381        sliced_attention_size: Option<usize>,
382        height: Option<usize>,
383        width: Option<usize>,
384    ) -> Self {
385        Self::sdxl_turbo_(
386            sliced_attention_size,
387            height,
388            width,
389            // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json
390            schedulers::PredictionType::Epsilon,
391        )
392    }
393
394    pub fn ssd1b(
395        sliced_attention_size: Option<usize>,
396        height: Option<usize>,
397        width: Option<usize>,
398    ) -> Self {
399        let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
400            out_channels,
401            use_cross_attn,
402            attention_head_dim,
403        };
404        // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
405        let unet = unet_2d::UNet2DConditionModelConfig {
406            blocks: vec![
407                bc(320, None, 5),
408                bc(640, Some(2), 10),
409                bc(1280, Some(10), 20),
410            ],
411            center_input_sample: false,
412            cross_attention_dim: 2048,
413            downsample_padding: 1,
414            flip_sin_to_cos: true,
415            freq_shift: 0.,
416            layers_per_block: 2,
417            mid_block_scale_factor: 1.,
418            norm_eps: 1e-5,
419            norm_num_groups: 32,
420            sliced_attention_size,
421            use_linear_projection: true,
422        };
423        // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
424        let autoencoder = vae::AutoEncoderKLConfig {
425            block_out_channels: vec![128, 256, 512, 512],
426            layers_per_block: 2,
427            latent_channels: 4,
428            norm_num_groups: 32,
429            use_quant_conv: true,
430            use_post_quant_conv: true,
431        };
432        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
433            ..Default::default()
434        });
435
436        let height = if let Some(height) = height {
437            assert_eq!(height % 8, 0, "height has to be divisible by 8");
438            height
439        } else {
440            1024
441        };
442
443        let width = if let Some(width) = width {
444            assert_eq!(width % 8, 0, "width has to be divisible by 8");
445            width
446        } else {
447            1024
448        };
449
450        Self {
451            width,
452            height,
453            clip: clip::Config::ssd1b(),
454            clip2: Some(clip::Config::ssd1b2()),
455            autoencoder,
456            scheduler,
457            unet,
458        }
459    }
460
461    pub fn build_vae<P: AsRef<std::path::Path>>(
462        &self,
463        vae_weights: P,
464        device: &Device,
465        dtype: DType,
466    ) -> Result<vae::AutoEncoderKL> {
467        let vs_ae =
468            unsafe { nn::VarBuilder::from_mmaped_safetensors(&[vae_weights], dtype, device)? };
469        // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
470        let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
471        Ok(autoencoder)
472    }
473
474    pub fn build_unet<P: AsRef<std::path::Path>>(
475        &self,
476        unet_weights: P,
477        device: &Device,
478        in_channels: usize,
479        use_flash_attn: bool,
480        dtype: DType,
481    ) -> Result<unet_2d::UNet2DConditionModel> {
482        let vs_unet =
483            unsafe { nn::VarBuilder::from_mmaped_safetensors(&[unet_weights], dtype, device)? };
484        let unet = unet_2d::UNet2DConditionModel::new(
485            vs_unet,
486            in_channels,
487            4,
488            use_flash_attn,
489            self.unet.clone(),
490        )?;
491        Ok(unet)
492    }
493
494    pub fn build_unet_sharded<P: AsRef<std::path::Path>>(
495        &self,
496        unet_weight_files: &[P],
497        device: &Device,
498        in_channels: usize,
499        use_flash_attn: bool,
500        dtype: DType,
501    ) -> Result<unet_2d::UNet2DConditionModel> {
502        let vs_unet =
503            unsafe { nn::VarBuilder::from_mmaped_safetensors(unet_weight_files, dtype, device)? };
504        unet_2d::UNet2DConditionModel::new(
505            vs_unet,
506            in_channels,
507            4,
508            use_flash_attn,
509            self.unet.clone(),
510        )
511    }
512
513    pub fn build_scheduler(&self, n_steps: usize) -> Result<Box<dyn Scheduler>> {
514        self.scheduler.build(n_steps)
515    }
516}
517
518pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
519    clip: &clip::Config,
520    clip_weights: P,
521    device: &Device,
522    dtype: DType,
523) -> Result<clip::ClipTextTransformer> {
524    let vs = unsafe { nn::VarBuilder::from_mmaped_safetensors(&[clip_weights], dtype, device)? };
525    let text_model = clip::ClipTextTransformer::new(vs, clip)?;
526    Ok(text_model)
527}