use crate::cake::{Context, Forwarder};
use crate::models::sd::clip::Clip;
use crate::models::sd::safe_scheduler::SafeScheduler;
use crate::models::sd::sd_shardable::SDShardable;
use crate::models::sd::unet::UNet;
use crate::models::sd::vae::VAE;
use crate::models::{Generator, ImageGenerator};
use crate::{ImageGenerationArgs, SDArgs, StableDiffusionVersion};
use anyhow::{Error as E, Result};
use async_trait::async_trait;
use candle_core::{DType, Device, IndexOp, Tensor, D};
use candle_transformers::models::stable_diffusion::StableDiffusionConfig;
use hf_hub::api::sync::ApiBuilder;
use hf_hub::Cache;
use image::{ImageBuffer, Rgb};
use log::{debug, info};
use tokenizers::Tokenizer;

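/// The individual artifacts that make up a Stable Diffusion pipeline. Each
/// variant maps to a file that can be supplied explicitly or fetched from the
/// Hugging Face Hub.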
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFile {
    Tokenizer,
    Tokenizer2,
    Clip,
    Clip2,
    Unet,
    Vae,
}

impl ModelFile {
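    /// Resolve the path to this model file. An explicit `filename` takes
    /// precedence; otherwise the file is fetched (or reused) from the Hugging
    /// Face Hub cache under `<cache_dir>/hub`.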
    pub fn get(
        &self,
        filename: Option<String>,
        version: StableDiffusionVersion,
        use_f16: bool,
        cache_dir: String,
    ) -> Result<std::path::PathBuf> {
        match filename {
            Some(filename) => Ok(std::path::PathBuf::from(filename)),
            None => {
                let (repo, path) = match self {
                    Self::Tokenizer => {
                        let tokenizer_repo = match version {
                            StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
                                "openai/clip-vit-base-patch32"
                            }
                            StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
                                "openai/clip-vit-large-patch14"
                            }
                        };
                        (tokenizer_repo, "tokenizer.json")
                    }
                    Self::Tokenizer2 => {
                        ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json")
                    }
                    Self::Clip => (version.repo(), version.clip_file(use_f16)),
                    Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
                    Self::Unet => (version.repo(), version.unet_file(use_f16)),
                    Self::Vae => {
                        // The stock SDXL VAE is known to produce NaNs in f16, so
                        // fall back to the community fp16-safe weights.
                        if matches!(
                            version,
                            StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo
                        ) && use_f16
                        {
                            (
                                "madebyollin/sdxl-vae-fp16-fix",
                                "diffusion_pytorch_model.safetensors",
                            )
                        } else {
                            (version.repo(), version.vae_file(use_f16))
                        }
                    }
                };
                let mut cache_path = std::path::PathBuf::from(cache_dir.as_str());
                cache_path.push("hub");

                debug!("Model cache dir: {:?}", cache_path);

                let cache = Cache::new(cache_path);
                let api = ApiBuilder::from_cache(cache).build()?;

                let filename = api.model(repo.to_string()).get(path)?;
                Ok(filename)
            }
        }
    }

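    /// Layer name used to address this component in the cluster topology.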
    pub(crate) fn name(&self) -> &'static str {
        match *self {
            ModelFile::Tokenizer => "tokenizer",
            ModelFile::Tokenizer2 => "tokenizer_2",
            ModelFile::Clip => "clip",
            ModelFile::Clip2 => "clip2",
            ModelFile::Unet => "unet",
            ModelFile::Vae => "vae",
        }
    }
}

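/// A Stable Diffusion pipeline whose components (CLIP text encoders, VAE and
/// UNet) are `Forwarder`s: each one may run locally or be proxied to a remote
/// worker node.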
pub struct SD {
    tokenizer: Tokenizer,
    pad_id: u32,
    tokenizer_2: Option<Tokenizer>,
    pad_id_2: Option<u32>,
    text_model: Box<dyn Forwarder>,
    text_model_2: Option<Box<dyn Forwarder>>,
    vae: Box<dyn Forwarder>,
    unet: Box<dyn Forwarder>,
    sd_version: StableDiffusionVersion,
    sd_config: StableDiffusionConfig,
    context: Context,
}

#[async_trait]
impl Generator for SD {
    type Shardable = SDShardable;
    const MODEL_NAME: &'static str = "stable-diffusion";

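    /// Build the pipeline for the configured SD version: load the tokenizer(s),
    /// then resolve each component (CLIP, optional CLIP 2, VAE, UNet) either to
    /// a remote client, when the topology assigns it to a worker node, or to a
    /// locally loaded model.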
    async fn load(context: &mut Context) -> Result<Option<Box<Self>>> {
        let SDArgs {
            tokenizer,
            tokenizer_2,
            sd_version,
            use_f16,
            width,
            height,
            sliced_attention_size,
            clip,
            clip2,
            vae,
            unet,
            use_flash_attention,
            ..
        } = &context.args.sd_args;

        let sd_config = match *sd_version {
            StableDiffusionVersion::V1_5 => {
                StableDiffusionConfig::v1_5(*sliced_attention_size, *height, *width)
            }
            StableDiffusionVersion::V2_1 => {
                StableDiffusionConfig::v2_1(*sliced_attention_size, *height, *width)
            }
            StableDiffusionVersion::Xl => {
                StableDiffusionConfig::sdxl(*sliced_attention_size, *height, *width)
            }
            StableDiffusionVersion::Turbo => {
                StableDiffusionConfig::sdxl_turbo(*sliced_attention_size, *height, *width)
            }
        };

        info!("Loading the Tokenizer...");

        let tokenizer_file = ModelFile::Tokenizer;
        let tokenizer = tokenizer_file.get(
            tokenizer.clone(),
            *sd_version,
            *use_f16,
            context.args.model.clone(),
        )?;
        let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;

        let pad_id = match &sd_config.clip.pad_with {
            Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
            None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
        };

        info!("Tokenizer loaded!");

        let mut tokenizer_2_option: Option<Tokenizer> = None;
        let mut pad_id_2: Option<u32> = None;

        if let StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo = sd_version {
            info!("Loading the Tokenizer 2...");

            let tokenizer_2_file = ModelFile::Tokenizer2;
            let tokenizer_2 = tokenizer_2_file.get(
                tokenizer_2.clone(),
                *sd_version,
                *use_f16,
                context.args.model.clone(),
            )?;
            let tokenizer_2 = Tokenizer::from_file(tokenizer_2).map_err(E::msg)?;

            if let Some(clip2) = &sd_config.clip2 {
                pad_id_2 = match &clip2.pad_with {
                    Some(padding) => {
                        Some(*tokenizer_2.get_vocab(true).get(padding.as_str()).unwrap())
                    }
                    None => Some(*tokenizer_2.get_vocab(true).get("<|endoftext|>").unwrap()),
                };
            }

            tokenizer_2_option = Some(tokenizer_2);

            info!("Tokenizer 2 loaded!");
        }

        info!("Loading the Clip text model.");

        let text_model: Box<dyn Forwarder>;

        if let Some((node_name, node)) = context.topology.get_node_for_layer(ModelFile::Clip.name())
        {
            info!("node {node_name} will serve Clip");
            text_model = Box::new(
                crate::cake::Client::new(
                    context.device.clone(),
                    &node.host,
                    ModelFile::Clip.name(),
                    context.args.cluster_key.as_deref(),
                )
                .await?,
            );
        } else {
            info!("Clip will be served locally");
            text_model = Clip::load_model(
                ModelFile::Clip,
                clip.clone(),
                *sd_version,
                *use_f16,
                &context.device,
                context.dtype,
                context.args.model.clone(),
                &sd_config.clip,
            )?;
        }

        info!("Clip text model loaded!");

        let mut text_model_2: Option<Box<dyn Forwarder>> = None;
        if let StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo = sd_version {
            info!("Loading the Clip 2 text model.");

            if let Some((node_name, node)) =
                context.topology.get_node_for_layer(ModelFile::Clip2.name())
            {
                info!("node {node_name} will serve clip2");
                text_model_2 = Some(Box::new(
                    crate::cake::Client::new(
                        context.device.clone(),
                        &node.host,
                        ModelFile::Clip2.name(),
                        context.args.cluster_key.as_deref(),
                    )
                    .await?,
                ));
            } else {
                info!("Clip 2 will be served locally");
                text_model_2 = Some(Clip::load_model(
                    ModelFile::Clip2,
                    clip2.clone(),
                    *sd_version,
                    *use_f16,
                    &context.device,
                    context.dtype,
                    context.args.model.clone(),
                    sd_config.clip2.as_ref().unwrap(),
                )?);
            }

            info!("Clip 2 text model loaded!");
        }

        info!("Loading the VAE...");

        let vae_model: Box<dyn Forwarder>;

        if let Some((node_name, node)) = context.topology.get_node_for_layer(ModelFile::Vae.name())
        {
            info!("node {node_name} will serve VAE");
            vae_model = Box::new(
                crate::cake::Client::new(
                    context.device.clone(),
                    &node.host,
                    ModelFile::Vae.name(),
                    context.args.cluster_key.as_deref(),
                )
                .await?,
            );
        } else {
            info!("VAE will be served locally");
            vae_model = VAE::load_model(
                vae.clone(),
                *sd_version,
                *use_f16,
                &context.device,
                context.dtype,
                context.args.model.clone(),
                &sd_config,
            )?;
        }

        info!("VAE loaded!");

        info!("Loading the UNet.");

        let unet_model: Box<dyn Forwarder>;
        if let Some((node_name, node)) = context.topology.get_node_for_layer(ModelFile::Unet.name())
        {
            info!("node {node_name} will serve UNet");
            unet_model = Box::new(
                crate::cake::Client::new(
                    context.device.clone(),
                    &node.host,
                    ModelFile::Unet.name(),
                    context.args.cluster_key.as_deref(),
                )
                .await?,
            );
        } else {
            info!("UNet will be served locally");
            unet_model = UNet::load_model(
                unet.clone(),
                *use_flash_attention,
                *sd_version,
                *use_f16,
                &context.device,
                context.dtype,
                context.args.model.clone(),
                &sd_config,
            )?;
        }

        info!("UNet loaded!");

        Ok(Some(Box::new(Self {
            tokenizer,
            sd_version: *sd_version,
            sd_config,
            pad_id,
            text_model,
            tokenizer_2: tokenizer_2_option,
            pad_id_2,
            text_model_2,
            vae: vae_model,
            unet: unet_model,
            context: context.clone(),
        })))
    }
}

#[async_trait]
impl ImageGenerator for SD {
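    /// Run the full text-to-image (or img2img) sampling loop, invoking
    /// `callback` with intermediary images (if requested) and with the final
    /// decoded batch for every sample.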
    async fn generate_image<F>(
        &mut self,
        args: &ImageGenerationArgs,
        mut callback: F,
    ) -> Result<(), anyhow::Error>
    where
        F: FnMut(Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>) + Send + 'static,
    {
        use tracing_chrome::ChromeLayerBuilder;
        use tracing_subscriber::prelude::*;

        let ImageGenerationArgs {
            image_prompt,
            uncond_prompt,
            n_steps,
            num_samples,
            bsize,
            tracing,
            guidance_scale,
            img2img,
            img2img_strength,
            image_seed,
            intermediary_images,
            ..
        } = args;

        let sd_version = self.sd_version;

        if !(0. ..=1.).contains(img2img_strength) {
            anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
        }

        let _guard = if *tracing {
            let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
            tracing_subscriber::registry().with(chrome_layer).init();
            Some(guard)
        } else {
            None
        };

        let guidance_scale = match guidance_scale {
            Some(guidance_scale) => guidance_scale,
            None => &match sd_version {
                StableDiffusionVersion::V1_5
                | StableDiffusionVersion::V2_1
                | StableDiffusionVersion::Xl => 7.5,
                StableDiffusionVersion::Turbo => 0.,
            },
        };
        let n_steps = match n_steps {
            Some(n_steps) => n_steps,
            None => &match sd_version {
                StableDiffusionVersion::V1_5
                | StableDiffusionVersion::V2_1
                | StableDiffusionVersion::Xl => 30,
                StableDiffusionVersion::Turbo => 1,
            },
        };

        if let Some(seed) = image_seed {
            self.context.device.set_seed(*seed)?;
        }
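        // Classifier-free guidance only makes sense for scales above 1; Turbo
        // defaults to a scale of 0 and therefore skips the unconditional pass.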
        let use_guide_scale = guidance_scale > &1.0;

        let mut text_embeddings: Vec<Tensor> = Vec::new();

        let text_embeddings_1 = self
            .text_embeddings(image_prompt, uncond_prompt, use_guide_scale, true)
            .await?;

        text_embeddings.push(text_embeddings_1);

        if let StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo = sd_version {
            let text_embeddings_2 = self
                .text_embeddings(image_prompt, uncond_prompt, use_guide_scale, false)
                .await?;

            text_embeddings.push(text_embeddings_2);
        }

        let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
        let text_embeddings = text_embeddings.repeat((*bsize, 1, 1))?;
        debug!("{text_embeddings:?}");

        let init_latent_dist_sample = match &img2img {
            None => None,
            Some(image) => {
                let image = image_preprocess(image)?.to_device(&self.context.device)?;
                Some(VAE::encode(&mut self.vae, image, &mut self.context).await?)
            }
        };

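        // For img2img, skip the first part of the schedule: the higher the
        // strength, the more denoising steps are actually executed.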
        let t_start = if img2img.is_some() {
            *n_steps - (*n_steps as f64 * img2img_strength) as usize
        } else {
            0
        };

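        // Latent scaling factor applied around the VAE; Turbo uses a different
        // value than the other model families.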
        let vae_scale = match sd_version {
            StableDiffusionVersion::V1_5
            | StableDiffusionVersion::V2_1
            | StableDiffusionVersion::Xl => 0.18215,
            StableDiffusionVersion::Turbo => 0.13025,
        };

        let mut safe_scheduler = SafeScheduler {
            scheduler: self.sd_config.build_scheduler(*n_steps)?,
        };

        for idx in 0..(*num_samples) {
            let timesteps = safe_scheduler.scheduler.timesteps().to_vec();
            let latents = match &init_latent_dist_sample {
                Some(init_latent_dist) => {
                    let latents =
                        (init_latent_dist * vae_scale)?.to_device(&self.context.device)?;
                    if t_start < timesteps.len() {
                        let noise = latents.randn_like(0f64, 1f64)?;
                        safe_scheduler
                            .scheduler
                            .add_noise(&latents, noise, timesteps[t_start])?
                    } else {
                        latents
                    }
                }

                None => {
                    let latents = Tensor::randn(
                        0f32,
                        1f32,
                        (
                            *bsize,
                            4,
                            self.sd_config.height / 8,
                            self.sd_config.width / 8,
                        ),
                        &self.context.device,
                    )?;
                    (latents * safe_scheduler.scheduler.init_noise_sigma())?
                }
            };

            let mut latents = latents.to_dtype(self.context.dtype)?;

            debug!("Starting sampling...");

            for (timestep_index, &timestep) in timesteps.iter().enumerate() {
                if timestep_index < t_start {
                    continue;
                }
                let start_time = std::time::Instant::now();
                let latent_model_input = if use_guide_scale {
                    Tensor::cat(&[&latents, &latents], 0)?
                } else {
                    latents.clone()
                };

                let latent_model_input = safe_scheduler
                    .scheduler
                    .scale_model_input(latent_model_input, timestep)?;

                debug!("UNet forwarding...");

                let noise_pred = UNet::forward_unpacked(
                    &mut self.unet,
                    latent_model_input,
                    text_embeddings.clone(),
                    timestep,
                    &mut self.context,
                )
                .await?;

                debug!("UNet forwarding completed!");

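                // Classifier-free guidance: the batch holds the unconditional
                // and conditional predictions, combined as
                // uncond + scale * (text - uncond).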
                let noise_pred = if use_guide_scale {
                    debug!("Applying guidance scale...");

                    let noise_pred = noise_pred.chunk(2, 0)?;
                    let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);

                    (noise_pred_uncond
                        + ((noise_pred_text - noise_pred_uncond)? * *guidance_scale)?)?
                } else {
                    noise_pred
                };

                debug!("Scheduler stepping...");

                latents = safe_scheduler
                    .scheduler
                    .step(&noise_pred, timestep, &latents)?;

                let dt = start_time.elapsed().as_secs_f32();
                info!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);

                if *intermediary_images != 0 && timestep_index % *intermediary_images == 0 {
                    let intermediary_batched_images =
                        self.split_images(&latents, vae_scale, *bsize).await?;
                    callback(intermediary_batched_images);
                }
            }

            debug!(
                "Generating the final image for sample {}/{}.",
                idx + 1,
                num_samples
            );

            let batched_images = self.split_images(&latents, vae_scale, *bsize).await?;

            callback(batched_images);
        }

        Ok(())
    }
}

impl SD {
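    /// Decode a batch of latents through the VAE and convert each item of the
    /// batch into an 8-bit RGB image buffer.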
    async fn split_images(
        &mut self,
        latents: &Tensor,
        vae_scale: f64,
        bsize: usize,
    ) -> Result<Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>>> {
        let mut images_vec = Vec::new();

        let scaled = (latents / vae_scale)?;
        let images = VAE::decode(&mut self.vae, scaled, &mut self.context).await?;
        let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
        let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
        for batch in 0..bsize {
            let image_tensor = images.i(batch)?;
            let (channel, height, width) = image_tensor.dims3()?;
            if channel != 3 {
                anyhow::bail!("save_image expects an input of shape (3, height, width)")
            }
            let image_tensor = image_tensor.permute((1, 2, 0))?.flatten_all()?;
            let pixels = image_tensor.to_vec1::<u8>()?;

            let image: ImageBuffer<image::Rgb<u8>, Vec<u8>> =
                match ImageBuffer::from_raw(width as u32, height as u32, pixels) {
                    Some(image) => image,
                    None => anyhow::bail!("Error splitting images"),
                };
            images_vec.push(image)
        }
        Ok(images_vec)
    }

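    /// Tokenize the prompt (and, when classifier-free guidance is enabled, the
    /// negative prompt) and run it through the first or second CLIP text
    /// encoder, returning the stacked embeddings.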
    async fn text_embeddings(
        &mut self,
        prompt: &str,
        uncond_prompt: &str,
        use_guide_scale: bool,
        first: bool,
    ) -> Result<Tensor> {
        let tokenizer;
        let text_model;
        let pad_id;
        let max_token_embeddings;

        if first {
            tokenizer = &self.tokenizer;
            text_model = &mut self.text_model;
            pad_id = self.pad_id;
            max_token_embeddings = self.sd_config.clip.max_position_embeddings;
        } else {
            tokenizer = self.tokenizer_2.as_ref().unwrap();
            text_model = self.text_model_2.as_mut().unwrap();
            pad_id = self.pad_id_2.unwrap();
            max_token_embeddings = self
                .sd_config
                .clip2
                .as_ref()
                .unwrap()
                .max_position_embeddings;
        }

        info!("Running with prompt \"{prompt}\".");

        let mut tokens = tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        if tokens.len() > max_token_embeddings {
            anyhow::bail!(
                "the prompt is too long, {} > max-tokens ({})",
                tokens.len(),
                max_token_embeddings
            )
        }

        while tokens.len() < max_token_embeddings {
            tokens.push(pad_id)
        }

        let tokens = Tensor::new(tokens.as_slice(), &self.context.device)?.unsqueeze(0)?;

        let text_embeddings = text_model
            .forward_mut(&tokens, 0, 0, &mut self.context)
            .await?;

        let text_embeddings = if use_guide_scale {
            let mut uncond_tokens = tokenizer
                .encode(uncond_prompt, true)
                .map_err(E::msg)?
                .get_ids()
                .to_vec();
            if uncond_tokens.len() > max_token_embeddings {
                anyhow::bail!(
                    "the negative prompt is too long, {} > max-tokens ({})",
                    uncond_tokens.len(),
                    max_token_embeddings
                )
            }
            while uncond_tokens.len() < max_token_embeddings {
                uncond_tokens.push(pad_id)
            }

            let uncond_tokens =
                Tensor::new(uncond_tokens.as_slice(), &self.context.device)?.unsqueeze(0)?;

            info!("Clip forwarding...");
            let uncond_embeddings = text_model
                .forward_mut(&uncond_tokens, 0, 0, &mut self.context)
                .await?;

            Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(self.context.dtype)?
        } else {
            text_embeddings.to_dtype(self.context.dtype)?
        };

        Ok(text_embeddings)
    }
}

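/// Load an image for img2img: crop dimensions to multiples of 32, resize, and
/// normalize pixel values from [0, 255] to [-1, 1] as a (1, 3, H, W) tensor.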
fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> Result<Tensor> {
    let img = image::ImageReader::open(path)?.decode()?;
    let (height, width) = (img.height() as usize, img.width() as usize);
    let height = height - height % 32;
    let width = width - width % 32;
    let img = img.resize_to_fill(
        width as u32,
        height as u32,
        image::imageops::FilterType::CatmullRom,
    );
    let img = img.to_rgb8();
    let img = img.into_raw();
    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
        .permute((2, 0, 1))?
        .to_dtype(DType::F32)?
        .affine(2. / 255., -1.)?
        .unsqueeze(0)?;
    Ok(img)
}