
cake_core/models/sd/sd.rs

use crate::cake::{Context, Forwarder};
use crate::models::sd::clip::Clip;
use crate::models::sd::safe_scheduler::SafeScheduler;
use crate::models::sd::sd_shardable::SDShardable;
use crate::models::sd::unet::UNet;
use crate::models::sd::vae::VAE;
use crate::models::{Generator, ImageGenerator};
use crate::{ImageGenerationArgs, SDArgs, StableDiffusionVersion};
use anyhow::{Error as E, Result};
use async_trait::async_trait;
use candle_core::{DType, Device, IndexOp, Tensor, D};
use candle_transformers::models::stable_diffusion::StableDiffusionConfig;
use hf_hub::api::sync::ApiBuilder;
use hf_hub::Cache;
use image::{ImageBuffer, Rgb};
use log::{debug, info};
use tokenizers::Tokenizer;

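/// Identifies one of the files that make up a Stable Diffusion pipeline.
/// Each variant resolves either to an explicit local path or to a file in the
/// matching Hugging Face Hub repository for the selected model version.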
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFile {
    Tokenizer,
    Tokenizer2,
    Clip,
    Clip2,
    Unet,
    Vae,
}

impl ModelFile {
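    /// Resolve the on-disk path for this model file. An explicit `filename`
    /// takes priority; otherwise the file is downloaded (or served from cache)
    /// through the Hugging Face Hub API under `<cache_dir>/hub`.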
    pub fn get(
        &self,
        filename: Option<String>,
        version: StableDiffusionVersion,
        use_f16: bool,
        cache_dir: String,
    ) -> Result<std::path::PathBuf> {
        match filename {
            Some(filename) => Ok(std::path::PathBuf::from(filename)),
            None => {
                let (repo, path) = match self {
                    Self::Tokenizer => {
                        let tokenizer_repo = match version {
                            StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
                                "openai/clip-vit-base-patch32"
                            }
                            StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
                                // This appears to match the patch32 version except
                                // for a very small difference in the split regex.
                                "openai/clip-vit-large-patch14"
                            }
                        };
                        (tokenizer_repo, "tokenizer.json")
                    }
                    Self::Tokenizer2 => {
                        ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json")
                    }
                    Self::Clip => (version.repo(), version.clip_file(use_f16)),
                    Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
                    Self::Unet => (version.repo(), version.unet_file(use_f16)),
                    Self::Vae => {
                        // Override for SDXL when using f16 weights.
                        // See https://github.com/huggingface/candle/issues/1060
                        if matches!(
                            version,
                            StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo,
                        ) && use_f16
                        {
                            (
                                "madebyollin/sdxl-vae-fp16-fix",
                                "diffusion_pytorch_model.safetensors",
                            )
                        } else {
                            (version.repo(), version.vae_file(use_f16))
                        }
                    }
                };
                let mut cache_path = std::path::PathBuf::from(cache_dir.as_str());
                cache_path.push("hub");

                debug!("Model cache dir: {:?}", cache_path);

                let cache = Cache::new(cache_path);
                let api = ApiBuilder::from_cache(cache).build()?;

                let filename = api.model(repo.to_string()).get(path)?;
                Ok(filename)
            }
        }
    }

    pub(crate) fn name(&self) -> &'static str {
        match *self {
            ModelFile::Tokenizer => "tokenizer",
            ModelFile::Tokenizer2 => "tokenizer_2",
            ModelFile::Clip => "clip",
            ModelFile::Clip2 => "clip2",
            ModelFile::Unet => "unet",
            ModelFile::Vae => "vae",
        }
    }
}

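/// A complete Stable Diffusion pipeline. Each heavy component (CLIP text
/// encoders, UNet, VAE) sits behind the `Forwarder` trait, so any of them can
/// run locally or be proxied to another node in the cluster.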
pub struct SD {
    tokenizer: Tokenizer,
    pad_id: u32,
    tokenizer_2: Option<Tokenizer>,
    pad_id_2: Option<u32>,
    text_model: Box<dyn Forwarder>,
    text_model_2: Option<Box<dyn Forwarder>>,
    vae: Box<dyn Forwarder>,
    unet: Box<dyn Forwarder>,
    sd_version: StableDiffusionVersion,
    sd_config: StableDiffusionConfig,
    context: Context,
}

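// `Generator::load` wires the pipeline together: for every component it first
// consults the cluster topology, and if a node claims that layer the component
// becomes a network client; otherwise the weights are loaded locally.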
#[async_trait]
impl Generator for SD {
    type Shardable = SDShardable;
    const MODEL_NAME: &'static str = "stable-diffusion";

    async fn load(context: &mut Context) -> Result<Option<Box<Self>>> {
        let SDArgs {
            tokenizer,
            tokenizer_2,
            sd_version,
            use_f16,
            width,
            height,
            sliced_attention_size,
            clip,
            clip2,
            vae,
            unet,
            use_flash_attention,
            ..
        } = &context.args.sd_args;

        let sd_config = match *sd_version {
            StableDiffusionVersion::V1_5 => {
                StableDiffusionConfig::v1_5(*sliced_attention_size, *height, *width)
            }
            StableDiffusionVersion::V2_1 => {
                StableDiffusionConfig::v2_1(*sliced_attention_size, *height, *width)
            }
            StableDiffusionVersion::Xl => {
                StableDiffusionConfig::sdxl(*sliced_attention_size, *height, *width)
            }
            StableDiffusionVersion::Turbo => {
                StableDiffusionConfig::sdxl_turbo(*sliced_attention_size, *height, *width)
            }
        };

        // Tokenizer
        info!("Loading the Tokenizer...");

        let tokenizer_file = ModelFile::Tokenizer;
        let tokenizer = tokenizer_file.get(
            tokenizer.clone(),
            *sd_version,
            *use_f16,
            context.args.model.clone(),
        )?;
        let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;

        let pad_id = match &sd_config.clip.pad_with {
            Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
            None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
        };

        info!("Tokenizer loaded!");

        // Tokenizer 2

        let mut tokenizer_2_option: Option<Tokenizer> = None;
        let mut pad_id_2: Option<u32> = None;

        if let StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo = sd_version {
            info!("Loading the Tokenizer 2...");

            let tokenizer_2_file = ModelFile::Tokenizer2;
            let tokenizer_2 = tokenizer_2_file.get(
                tokenizer_2.clone(),
                *sd_version,
                *use_f16,
                context.args.model.clone(),
            )?;
            let tokenizer_2 = Tokenizer::from_file(tokenizer_2).map_err(E::msg)?;

            if let Some(clip2) = &sd_config.clip2 {
                pad_id_2 = match &clip2.pad_with {
                    Some(padding) => {
                        Some(*tokenizer_2.get_vocab(true).get(padding.as_str()).unwrap())
                    }
                    None => Some(*tokenizer_2.get_vocab(true).get("<|endoftext|>").unwrap()),
                };
            }

            tokenizer_2_option = Some(tokenizer_2);

            info!("Tokenizer 2 loaded!");
        }

        // Clip
        info!("Loading the Clip text model.");

        let text_model: Box<dyn Forwarder>;

        if let Some((node_name, node)) = context.topology.get_node_for_layer(ModelFile::Clip.name())
        {
            info!("node {node_name} will serve Clip");
            text_model = Box::new(
                crate::cake::Client::new(
                    context.device.clone(),
                    &node.host,
                    ModelFile::Clip.name(),
                    context.args.cluster_key.as_deref(),
                )
                .await?,
            );
        } else {
            info!("Clip will be served locally");
            text_model = Clip::load_model(
                ModelFile::Clip,
                clip.clone(),
                *sd_version,
                *use_f16,
                &context.device,
                context.dtype,
                context.args.model.clone(),
                &sd_config.clip,
            )?;
        }

        info!("Clip text model loaded!");

        // Clip 2

        let mut text_model_2: Option<Box<dyn Forwarder>> = None;
        if let StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo = sd_version {
            info!("Loading the Clip 2 text model.");

            if let Some((node_name, node)) =
                context.topology.get_node_for_layer(ModelFile::Clip2.name())
            {
                info!("node {node_name} will serve clip2");
                text_model_2 = Some(Box::new(
                    crate::cake::Client::new(
                        context.device.clone(),
                        &node.host,
                        ModelFile::Clip2.name(),
                        context.args.cluster_key.as_deref(),
                    )
                    .await?,
                ));
            } else {
                info!("Clip 2 will be served locally");
                text_model_2 = Some(Clip::load_model(
                    ModelFile::Clip2,
                    clip2.clone(),
                    *sd_version,
                    *use_f16,
                    &context.device,
                    context.dtype,
                    context.args.model.clone(),
                    sd_config.clip2.as_ref().unwrap(),
                )?);
            }

            info!("Clip 2 text model loaded!");
        }

        // VAE
        info!("Loading the VAE...");

        let vae_model: Box<dyn Forwarder>;

        if let Some((node_name, node)) = context.topology.get_node_for_layer(ModelFile::Vae.name())
        {
            info!("node {node_name} will serve VAE");
            vae_model = Box::new(
                crate::cake::Client::new(
                    context.device.clone(),
                    &node.host,
                    ModelFile::Vae.name(),
                    context.args.cluster_key.as_deref(),
                )
                .await?,
            );
        } else {
            info!("VAE will be served locally");
            vae_model = VAE::load_model(
                vae.clone(),
                *sd_version,
                *use_f16,
                &context.device,
                context.dtype,
                context.args.model.clone(),
                &sd_config,
            )?;
        }

        info!("VAE loaded!");

        // Unet
        info!("Loading the UNet.");

        let unet_model: Box<dyn Forwarder>;
        if let Some((node_name, node)) = context.topology.get_node_for_layer(ModelFile::Unet.name())
        {
            info!("node {node_name} will serve UNet");
            unet_model = Box::new(
                crate::cake::Client::new(
                    context.device.clone(),
                    &node.host,
                    ModelFile::Unet.name(),
                    context.args.cluster_key.as_deref(),
                )
                .await?,
            );
        } else {
            info!("UNet will be served locally");
            unet_model = UNet::load_model(
                unet.clone(),
                *use_flash_attention,
                *sd_version,
                *use_f16,
                &context.device,
                context.dtype,
                context.args.model.clone(),
                &sd_config,
            )?;
        }

        info!("UNet loaded!");

        Ok(Some(Box::new(Self {
            tokenizer,
            sd_version: *sd_version,
            sd_config,
            pad_id,
            text_model,
            tokenizer_2: tokenizer_2_option,
            pad_id_2,
            text_model_2,
            vae: vae_model,
            unet: unet_model,
            context: context.clone(),
        })))
    }
}
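
// A minimal caller sketch (hypothetical; assumes an initialized `Context` built
// from CLI arguments and the cluster topology):
//
//     let mut ctx: Context = /* ... */;
//     if let Some(mut sd) = SD::load(&mut ctx).await? {
//         sd.generate_image(&args, |images| { /* consume frames */ }).await?;
//     }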

#[async_trait]
impl ImageGenerator for SD {
    async fn generate_image<F>(
        &mut self,
        args: &ImageGenerationArgs,
        mut callback: F,
    ) -> Result<(), anyhow::Error>
    where
        F: FnMut(Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>) + Send + 'static,
    {
        use tracing_chrome::ChromeLayerBuilder;
        use tracing_subscriber::prelude::*;

        let ImageGenerationArgs {
            image_prompt,
            uncond_prompt,
            n_steps,
            num_samples,
            bsize,
            tracing,
            guidance_scale,
            img2img,
            img2img_strength,
            image_seed,
            intermediary_images,
            ..
        } = args;

        let sd_version = self.sd_version;

        if !(0. ..=1.).contains(img2img_strength) {
            anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
        }

        let _guard = if *tracing {
            let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
            tracing_subscriber::registry().with(chrome_layer).init();
            Some(guard)
        } else {
            None
        };

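        // Version-specific defaults when not set explicitly: SD 1.5/2.1/XL use a
        // guidance scale of 7.5 and 30 steps, while SDXL-Turbo is tuned for
        // guidance-free sampling in very few steps (0.0 and a single step here).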
        let guidance_scale = match guidance_scale {
            Some(guidance_scale) => guidance_scale,
            None => &match sd_version {
                StableDiffusionVersion::V1_5
                | StableDiffusionVersion::V2_1
                | StableDiffusionVersion::Xl => 7.5,
                StableDiffusionVersion::Turbo => 0.,
            },
        };
        let n_steps = match n_steps {
            Some(n_steps) => n_steps,
            None => &match sd_version {
                StableDiffusionVersion::V1_5
                | StableDiffusionVersion::V2_1
                | StableDiffusionVersion::Xl => 30,
                StableDiffusionVersion::Turbo => 1,
            },
        };

        if let Some(seed) = image_seed {
            self.context.device.set_seed(*seed)?;
        }
        let use_guide_scale = guidance_scale > &1.0;

        let mut text_embeddings: Vec<Tensor> = Vec::new();

        let text_embeddings_1 = self
            .text_embeddings(image_prompt, uncond_prompt, use_guide_scale, true)
            .await?;

        text_embeddings.push(text_embeddings_1);

        if let StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo = sd_version {
            let text_embeddings_2 = self
                .text_embeddings(image_prompt, uncond_prompt, use_guide_scale, false)
                .await?;

            text_embeddings.push(text_embeddings_2);
        }

        let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
        let text_embeddings = text_embeddings.repeat((*bsize, 1, 1))?;
        debug!("{text_embeddings:?}");

        let init_latent_dist_sample = match &img2img {
            None => None,
            Some(image) => {
                let image = image_preprocess(image)?.to_device(&self.context.device)?;
                Some(VAE::encode(&mut self.vae, image, &mut self.context).await?)
            }
        };

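        // For img2img, skip an initial portion of the schedule: the stronger the
        // img2img strength, the more timesteps are actually denoised (t_start is
        // the index of the first timestep that runs).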
        let t_start = if img2img.is_some() {
            *n_steps - (*n_steps as f64 * img2img_strength) as usize
        } else {
            0
        };

        let vae_scale = match sd_version {
            StableDiffusionVersion::V1_5
            | StableDiffusionVersion::V2_1
            | StableDiffusionVersion::Xl => 0.18215,
            StableDiffusionVersion::Turbo => 0.13025,
        };

        let mut safe_scheduler = SafeScheduler {
            scheduler: self.sd_config.build_scheduler(*n_steps)?,
        };

        for idx in 0..(*num_samples) {
            let timesteps = safe_scheduler.scheduler.timesteps().to_vec();
            let latents = match &init_latent_dist_sample {
                Some(init_latent_dist) => {
                    let latents =
                        (init_latent_dist * vae_scale)?.to_device(&self.context.device)?;
                    if t_start < timesteps.len() {
                        let noise = latents.randn_like(0f64, 1f64)?;
                        safe_scheduler
                            .scheduler
                            .add_noise(&latents, noise, timesteps[t_start])?
                    } else {
                        latents
                    }
                }

                None => {
                    let latents = Tensor::randn(
                        0f32,
                        1f32,
                        (
                            *bsize,
                            4,
                            self.sd_config.height / 8,
                            self.sd_config.width / 8,
                        ),
                        &self.context.device,
                    )?;
                    // Scale the initial noise by the standard deviation required by the scheduler.
                    (latents * safe_scheduler.scheduler.init_noise_sigma())?
                }
            };

            let mut latents = latents.to_dtype(self.context.dtype)?;

            debug!("Starting sampling...");

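            // Denoising loop: each timestep optionally duplicates the latents for
            // classifier-free guidance, scales them for the scheduler, predicts
            // the noise with the UNet, then takes one scheduler step.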
            for (timestep_index, &timestep) in timesteps.iter().enumerate() {
                if timestep_index < t_start {
                    continue;
                }
                let start_time = std::time::Instant::now();
                let latent_model_input = if use_guide_scale {
                    Tensor::cat(&[&latents, &latents], 0)?
                } else {
                    latents.clone()
                };

                let latent_model_input = safe_scheduler
                    .scheduler
                    .scale_model_input(latent_model_input, timestep)?;

                debug!("UNet forwarding...");

                let noise_pred = UNet::forward_unpacked(
                    &mut self.unet,
                    latent_model_input,
                    text_embeddings.clone(),
                    timestep,
                    &mut self.context,
                )
                .await?;

                debug!("UNet forwarding completed!");

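                // Classifier-free guidance: the doubled batch holds the
                // unconditional prediction first and the text-conditioned one
                // second; combine them as uncond + scale * (text - uncond).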
                let noise_pred = if use_guide_scale {
                    debug!("Applying guidance scale...");

                    let noise_pred = noise_pred.chunk(2, 0)?;
                    let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);

                    (noise_pred_uncond
                        + ((noise_pred_text - noise_pred_uncond)? * *guidance_scale)?)?
                } else {
                    noise_pred
                };

                debug!("Scheduler stepping...");

                latents = safe_scheduler
                    .scheduler
                    .step(&noise_pred, timestep, &latents)?;

                let dt = start_time.elapsed().as_secs_f32();
                info!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);

                if *intermediary_images != 0 && timestep_index % *intermediary_images == 0 {
                    let intermediary_batched_images =
                        self.split_images(&latents, vae_scale, *bsize).await?;
                    callback(intermediary_batched_images);
                }
            }

            debug!(
                "Generating the final image for sample {}/{}.",
                idx + 1,
                num_samples
            );

            let batched_images = self.split_images(&latents, vae_scale, *bsize).await?;

            callback(batched_images);
        }

        Ok(())
    }
}

impl SD {
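    /// Decode a batch of latents with the VAE and convert every sample into an
    /// 8-bit RGB `ImageBuffer`. The decoder output in [-1, 1] is remapped to
    /// [0, 255] via `x / 2 + 0.5`, clamping, and scaling.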
    async fn split_images(
        &mut self,
        latents: &Tensor,
        vae_scale: f64,
        bsize: usize,
    ) -> Result<Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>>> {
        let mut images_vec = Vec::new();

        let scaled = (latents / vae_scale)?;
        let images = VAE::decode(&mut self.vae, scaled, &mut self.context).await?;
        let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
        let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
        for batch in 0..bsize {
            let image_tensor = images.i(batch)?;
            let (channel, height, width) = image_tensor.dims3()?;
            if channel != 3 {
                anyhow::bail!("split_images expects an input of shape (3, height, width)")
            }
            let image_tensor = image_tensor.permute((1, 2, 0))?.flatten_all()?;
            let pixels = image_tensor.to_vec1::<u8>()?;

            let image: ImageBuffer<image::Rgb<u8>, Vec<u8>> =
                match ImageBuffer::from_raw(width as u32, height as u32, pixels) {
                    Some(image) => image,
                    None => anyhow::bail!("Error splitting images"),
                };
            images_vec.push(image)
        }
        Ok(images_vec)
    }

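    /// Tokenize `prompt` (and, when guidance is enabled, `uncond_prompt`), pad
    /// both to the encoder's maximum sequence length, and run them through the
    /// selected CLIP text model. `first` picks the primary encoder; the second
    /// tokenizer/encoder pair only exists for SDXL and Turbo.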
    async fn text_embeddings(
        &mut self,
        prompt: &str,
        uncond_prompt: &str,
        use_guide_scale: bool,
        first: bool,
    ) -> Result<Tensor> {
        let tokenizer;
        let text_model;
        let pad_id;
        let max_token_embeddings;

        if first {
            tokenizer = &self.tokenizer;
            text_model = &mut self.text_model;
            pad_id = self.pad_id;
            max_token_embeddings = self.sd_config.clip.max_position_embeddings;
        } else {
            tokenizer = self.tokenizer_2.as_ref().unwrap();
            text_model = self.text_model_2.as_mut().unwrap();
            pad_id = self.pad_id_2.unwrap();
            max_token_embeddings = self
                .sd_config
                .clip2
                .as_ref()
                .unwrap()
                .max_position_embeddings;
        }

        info!("Running with prompt \"{prompt}\".");

        let mut tokens = tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        if tokens.len() > max_token_embeddings {
            anyhow::bail!(
                "the prompt is too long, {} > max-tokens ({})",
                tokens.len(),
                max_token_embeddings
            )
        }

        while tokens.len() < max_token_embeddings {
            tokens.push(pad_id)
        }

        let tokens = Tensor::new(tokens.as_slice(), &self.context.device)?.unsqueeze(0)?;

        let text_embeddings = text_model
            .forward_mut(&tokens, 0, 0, &mut self.context)
            .await?;

        let text_embeddings = if use_guide_scale {
            let mut uncond_tokens = tokenizer
                .encode(uncond_prompt, true)
                .map_err(E::msg)?
                .get_ids()
                .to_vec();
            if uncond_tokens.len() > max_token_embeddings {
                anyhow::bail!(
                    "the negative prompt is too long, {} > max-tokens ({})",
                    uncond_tokens.len(),
                    max_token_embeddings
                )
            }
            while uncond_tokens.len() < max_token_embeddings {
                uncond_tokens.push(pad_id)
            }

            let uncond_tokens =
                Tensor::new(uncond_tokens.as_slice(), &self.context.device)?.unsqueeze(0)?;

            info!("Clip forwarding...");
            let uncond_embeddings = text_model
                .forward_mut(&uncond_tokens, 0, 0, &mut self.context)
                .await?;

            Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(self.context.dtype)?
        } else {
            text_embeddings.to_dtype(self.context.dtype)?
        };

        Ok(text_embeddings)
    }
}

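/// Load an image for img2img conditioning: dimensions are rounded down to
/// multiples of 32, and pixel values are mapped from [0, 255] to [-1, 1] via
/// the affine transform `x * 2/255 - 1`.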
fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> Result<Tensor> {
    let img = image::ImageReader::open(path)?.decode()?;
    let (height, width) = (img.height() as usize, img.width() as usize);
    let height = height - height % 32;
    let width = width - width % 32;
    let img = img.resize_to_fill(
        width as u32,
        height as u32,
        image::imageops::FilterType::CatmullRom,
    );
    let img = img.to_rgb8();
    let img = img.into_raw();
    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
        .permute((2, 0, 1))?
        .to_dtype(DType::F32)?
        .affine(2. / 255., -1.)?
        .unsqueeze(0)?;
    Ok(img)
}