1pub mod attention;
41pub mod clip;
42pub mod ddim;
43pub mod ddpm;
44pub mod embeddings;
45pub mod euler_ancestral_discrete;
46pub mod resnet;
47pub mod schedulers;
48pub mod unet_2d;
49pub mod unet_2d_blocks;
50pub mod uni_pc;
51pub mod utils;
52pub mod vae;
53
54use std::sync::Arc;
55
56use candle::{DType, Device, Result};
57use candle_nn as nn;
58
59use self::schedulers::{Scheduler, SchedulerConfig};
60
/// Top-level configuration for a Stable Diffusion pipeline: image geometry,
/// text-encoder (CLIP) settings, and the VAE/UNet/scheduler configs used by the
/// `build_*` helpers below.
#[derive(Clone, Debug)]
pub struct StableDiffusionConfig {
    /// Output image width in pixels; the preset constructors enforce a multiple of 8.
    pub width: usize,
    /// Output image height in pixels; the preset constructors enforce a multiple of 8.
    pub height: usize,
    /// Primary CLIP text-encoder configuration.
    pub clip: clip::Config,
    /// Second CLIP text-encoder configuration; `Some` only for the SDXL/SSD-1B
    /// presets, which use two text encoders.
    pub clip2: Option<clip::Config>,
    // VAE configuration consumed by `build_vae`.
    autoencoder: vae::AutoEncoderKLConfig,
    // UNet configuration consumed by `build_unet`/`build_unet_sharded`.
    unet: unet_2d::UNet2DConditionModelConfig,
    // Scheduler configuration (DDIM, Euler-ancestral, ...) consumed by `build_scheduler`.
    scheduler: Arc<dyn SchedulerConfig>,
}
71
72impl StableDiffusionConfig {
73 pub fn v1_5(
74 sliced_attention_size: Option<usize>,
75 height: Option<usize>,
76 width: Option<usize>,
77 ) -> Self {
78 let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
79 out_channels,
80 use_cross_attn,
81 attention_head_dim,
82 };
83 let unet = unet_2d::UNet2DConditionModelConfig {
85 blocks: vec![
86 bc(320, Some(1), 8),
87 bc(640, Some(1), 8),
88 bc(1280, Some(1), 8),
89 bc(1280, None, 8),
90 ],
91 center_input_sample: false,
92 cross_attention_dim: 768,
93 downsample_padding: 1,
94 flip_sin_to_cos: true,
95 freq_shift: 0.,
96 layers_per_block: 2,
97 mid_block_scale_factor: 1.,
98 norm_eps: 1e-5,
99 norm_num_groups: 32,
100 sliced_attention_size,
101 use_linear_projection: false,
102 };
103 let autoencoder = vae::AutoEncoderKLConfig {
104 block_out_channels: vec![128, 256, 512, 512],
105 layers_per_block: 2,
106 latent_channels: 4,
107 norm_num_groups: 32,
108 use_quant_conv: true,
109 use_post_quant_conv: true,
110 };
111 let height = if let Some(height) = height {
112 assert_eq!(height % 8, 0, "height has to be divisible by 8");
113 height
114 } else {
115 512
116 };
117
118 let width = if let Some(width) = width {
119 assert_eq!(width % 8, 0, "width has to be divisible by 8");
120 width
121 } else {
122 512
123 };
124
125 let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
126 prediction_type: schedulers::PredictionType::Epsilon,
127 ..Default::default()
128 });
129
130 StableDiffusionConfig {
131 width,
132 height,
133 clip: clip::Config::v1_5(),
134 clip2: None,
135 autoencoder,
136 scheduler,
137 unet,
138 }
139 }
140
141 fn v2_1_(
142 sliced_attention_size: Option<usize>,
143 height: Option<usize>,
144 width: Option<usize>,
145 prediction_type: schedulers::PredictionType,
146 ) -> Self {
147 let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
148 out_channels,
149 use_cross_attn,
150 attention_head_dim,
151 };
152 let unet = unet_2d::UNet2DConditionModelConfig {
154 blocks: vec![
155 bc(320, Some(1), 5),
156 bc(640, Some(1), 10),
157 bc(1280, Some(1), 20),
158 bc(1280, None, 20),
159 ],
160 center_input_sample: false,
161 cross_attention_dim: 1024,
162 downsample_padding: 1,
163 flip_sin_to_cos: true,
164 freq_shift: 0.,
165 layers_per_block: 2,
166 mid_block_scale_factor: 1.,
167 norm_eps: 1e-5,
168 norm_num_groups: 32,
169 sliced_attention_size,
170 use_linear_projection: true,
171 };
172 let autoencoder = vae::AutoEncoderKLConfig {
174 block_out_channels: vec![128, 256, 512, 512],
175 layers_per_block: 2,
176 latent_channels: 4,
177 norm_num_groups: 32,
178 use_quant_conv: true,
179 use_post_quant_conv: true,
180 };
181 let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
182 prediction_type,
183 ..Default::default()
184 });
185
186 let height = if let Some(height) = height {
187 assert_eq!(height % 8, 0, "height has to be divisible by 8");
188 height
189 } else {
190 768
191 };
192
193 let width = if let Some(width) = width {
194 assert_eq!(width % 8, 0, "width has to be divisible by 8");
195 width
196 } else {
197 768
198 };
199
200 StableDiffusionConfig {
201 width,
202 height,
203 clip: clip::Config::v2_1(),
204 clip2: None,
205 autoencoder,
206 scheduler,
207 unet,
208 }
209 }
210
211 pub fn v2_1(
212 sliced_attention_size: Option<usize>,
213 height: Option<usize>,
214 width: Option<usize>,
215 ) -> Self {
216 Self::v2_1_(
218 sliced_attention_size,
219 height,
220 width,
221 schedulers::PredictionType::VPrediction,
222 )
223 }
224
225 fn sdxl_(
226 sliced_attention_size: Option<usize>,
227 height: Option<usize>,
228 width: Option<usize>,
229 prediction_type: schedulers::PredictionType,
230 ) -> Self {
231 let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
232 out_channels,
233 use_cross_attn,
234 attention_head_dim,
235 };
236 let unet = unet_2d::UNet2DConditionModelConfig {
238 blocks: vec![
239 bc(320, None, 5),
240 bc(640, Some(2), 10),
241 bc(1280, Some(10), 20),
242 ],
243 center_input_sample: false,
244 cross_attention_dim: 2048,
245 downsample_padding: 1,
246 flip_sin_to_cos: true,
247 freq_shift: 0.,
248 layers_per_block: 2,
249 mid_block_scale_factor: 1.,
250 norm_eps: 1e-5,
251 norm_num_groups: 32,
252 sliced_attention_size,
253 use_linear_projection: true,
254 };
255 let autoencoder = vae::AutoEncoderKLConfig {
257 block_out_channels: vec![128, 256, 512, 512],
258 layers_per_block: 2,
259 latent_channels: 4,
260 norm_num_groups: 32,
261 use_quant_conv: true,
262 use_post_quant_conv: true,
263 };
264 let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
265 prediction_type,
266 ..Default::default()
267 });
268
269 let height = if let Some(height) = height {
270 assert_eq!(height % 8, 0, "height has to be divisible by 8");
271 height
272 } else {
273 1024
274 };
275
276 let width = if let Some(width) = width {
277 assert_eq!(width % 8, 0, "width has to be divisible by 8");
278 width
279 } else {
280 1024
281 };
282
283 StableDiffusionConfig {
284 width,
285 height,
286 clip: clip::Config::sdxl(),
287 clip2: Some(clip::Config::sdxl2()),
288 autoencoder,
289 scheduler,
290 unet,
291 }
292 }
293
294 fn sdxl_turbo_(
295 sliced_attention_size: Option<usize>,
296 height: Option<usize>,
297 width: Option<usize>,
298 prediction_type: schedulers::PredictionType,
299 ) -> Self {
300 let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
301 out_channels,
302 use_cross_attn,
303 attention_head_dim,
304 };
305 let unet = unet_2d::UNet2DConditionModelConfig {
307 blocks: vec![
308 bc(320, None, 5),
309 bc(640, Some(2), 10),
310 bc(1280, Some(10), 20),
311 ],
312 center_input_sample: false,
313 cross_attention_dim: 2048,
314 downsample_padding: 1,
315 flip_sin_to_cos: true,
316 freq_shift: 0.,
317 layers_per_block: 2,
318 mid_block_scale_factor: 1.,
319 norm_eps: 1e-5,
320 norm_num_groups: 32,
321 sliced_attention_size,
322 use_linear_projection: true,
323 };
324 let autoencoder = vae::AutoEncoderKLConfig {
326 block_out_channels: vec![128, 256, 512, 512],
327 layers_per_block: 2,
328 latent_channels: 4,
329 norm_num_groups: 32,
330 use_quant_conv: true,
331 use_post_quant_conv: true,
332 };
333 let scheduler = Arc::new(
334 euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
335 prediction_type,
336 timestep_spacing: schedulers::TimestepSpacing::Trailing,
337 ..Default::default()
338 },
339 );
340
341 let height = if let Some(height) = height {
342 assert_eq!(height % 8, 0, "height has to be divisible by 8");
343 height
344 } else {
345 512
346 };
347
348 let width = if let Some(width) = width {
349 assert_eq!(width % 8, 0, "width has to be divisible by 8");
350 width
351 } else {
352 512
353 };
354
355 Self {
356 width,
357 height,
358 clip: clip::Config::sdxl(),
359 clip2: Some(clip::Config::sdxl2()),
360 autoencoder,
361 scheduler,
362 unet,
363 }
364 }
365
366 pub fn sdxl(
367 sliced_attention_size: Option<usize>,
368 height: Option<usize>,
369 width: Option<usize>,
370 ) -> Self {
371 Self::sdxl_(
372 sliced_attention_size,
373 height,
374 width,
375 schedulers::PredictionType::Epsilon,
377 )
378 }
379
380 pub fn sdxl_turbo(
381 sliced_attention_size: Option<usize>,
382 height: Option<usize>,
383 width: Option<usize>,
384 ) -> Self {
385 Self::sdxl_turbo_(
386 sliced_attention_size,
387 height,
388 width,
389 schedulers::PredictionType::Epsilon,
391 )
392 }
393
394 pub fn ssd1b(
395 sliced_attention_size: Option<usize>,
396 height: Option<usize>,
397 width: Option<usize>,
398 ) -> Self {
399 let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
400 out_channels,
401 use_cross_attn,
402 attention_head_dim,
403 };
404 let unet = unet_2d::UNet2DConditionModelConfig {
406 blocks: vec![
407 bc(320, None, 5),
408 bc(640, Some(2), 10),
409 bc(1280, Some(10), 20),
410 ],
411 center_input_sample: false,
412 cross_attention_dim: 2048,
413 downsample_padding: 1,
414 flip_sin_to_cos: true,
415 freq_shift: 0.,
416 layers_per_block: 2,
417 mid_block_scale_factor: 1.,
418 norm_eps: 1e-5,
419 norm_num_groups: 32,
420 sliced_attention_size,
421 use_linear_projection: true,
422 };
423 let autoencoder = vae::AutoEncoderKLConfig {
425 block_out_channels: vec![128, 256, 512, 512],
426 layers_per_block: 2,
427 latent_channels: 4,
428 norm_num_groups: 32,
429 use_quant_conv: true,
430 use_post_quant_conv: true,
431 };
432 let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
433 ..Default::default()
434 });
435
436 let height = if let Some(height) = height {
437 assert_eq!(height % 8, 0, "height has to be divisible by 8");
438 height
439 } else {
440 1024
441 };
442
443 let width = if let Some(width) = width {
444 assert_eq!(width % 8, 0, "width has to be divisible by 8");
445 width
446 } else {
447 1024
448 };
449
450 Self {
451 width,
452 height,
453 clip: clip::Config::ssd1b(),
454 clip2: Some(clip::Config::ssd1b2()),
455 autoencoder,
456 scheduler,
457 unet,
458 }
459 }
460
461 pub fn build_vae<P: AsRef<std::path::Path>>(
462 &self,
463 vae_weights: P,
464 device: &Device,
465 dtype: DType,
466 ) -> Result<vae::AutoEncoderKL> {
467 let vs_ae =
468 unsafe { nn::VarBuilder::from_mmaped_safetensors(&[vae_weights], dtype, device)? };
469 let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
471 Ok(autoencoder)
472 }
473
474 pub fn build_unet<P: AsRef<std::path::Path>>(
475 &self,
476 unet_weights: P,
477 device: &Device,
478 in_channels: usize,
479 use_flash_attn: bool,
480 dtype: DType,
481 ) -> Result<unet_2d::UNet2DConditionModel> {
482 let vs_unet =
483 unsafe { nn::VarBuilder::from_mmaped_safetensors(&[unet_weights], dtype, device)? };
484 let unet = unet_2d::UNet2DConditionModel::new(
485 vs_unet,
486 in_channels,
487 4,
488 use_flash_attn,
489 self.unet.clone(),
490 )?;
491 Ok(unet)
492 }
493
494 pub fn build_unet_sharded<P: AsRef<std::path::Path>>(
495 &self,
496 unet_weight_files: &[P],
497 device: &Device,
498 in_channels: usize,
499 use_flash_attn: bool,
500 dtype: DType,
501 ) -> Result<unet_2d::UNet2DConditionModel> {
502 let vs_unet =
503 unsafe { nn::VarBuilder::from_mmaped_safetensors(unet_weight_files, dtype, device)? };
504 unet_2d::UNet2DConditionModel::new(
505 vs_unet,
506 in_channels,
507 4,
508 use_flash_attn,
509 self.unet.clone(),
510 )
511 }
512
513 pub fn build_scheduler(&self, n_steps: usize) -> Result<Box<dyn Scheduler>> {
514 self.scheduler.build(n_steps)
515 }
516}
517
518pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
519 clip: &clip::Config,
520 clip_weights: P,
521 device: &Device,
522 dtype: DType,
523) -> Result<clip::ClipTextTransformer> {
524 let vs = unsafe { nn::VarBuilder::from_mmaped_safetensors(&[clip_weights], dtype, device)? };
525 let text_model = clip::ClipTextTransformer::new(vs, clip)?;
526 Ok(text_model)
527}