diffusion_rs/
api.rs

use std::ffi::CString;
use std::ffi::c_char;
use std::ffi::c_void;
use std::path::Path;
use std::path::PathBuf;
use std::ptr::null_mut;
use std::slice;

use derive_builder::Builder;
use diffusion_rs_sys::free_upscaler_ctx;
use diffusion_rs_sys::new_upscaler_ctx;
use diffusion_rs_sys::sd_ctx_params_t;
use diffusion_rs_sys::sd_guidance_params_t;
use diffusion_rs_sys::sd_image_t;
use diffusion_rs_sys::sd_img_gen_params_t;
use diffusion_rs_sys::sd_sample_params_t;
use diffusion_rs_sys::sd_slg_params_t;
use diffusion_rs_sys::upscaler_ctx_t;
use image::ImageBuffer;
use image::ImageError;
use image::RgbImage;
use libc::free;
use thiserror::Error;

use diffusion_rs_sys::free_sd_ctx;
use diffusion_rs_sys::new_sd_ctx;
use diffusion_rs_sys::sd_ctx_t;

/// Specify the RNG (random number generator) implementation
pub use diffusion_rs_sys::rng_type_t as RngFunction;

/// Sampling methods
pub use diffusion_rs_sys::sample_method_t as SampleMethod;

/// Denoiser sigma schedule
pub use diffusion_rs_sys::scheduler_t as Scheduler;

/// Weight type
pub use diffusion_rs_sys::sd_type_t as WeightType;

#[non_exhaustive]
#[derive(Error, Debug)]
/// Errors that can occur while running models
pub enum DiffusionError {
    #[error("The underlying stable-diffusion.cpp function returned NULL")]
    Forward,
    #[error(transparent)]
    StoreImages(#[from] ImageError),
    #[error("The underlying upscaler model returned a NULL image")]
    Upscaler,
}

#[repr(i32)]
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
/// Ignore the last layers of the CLIP network
pub enum ClipSkip {
    /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x
    #[default]
    Unspecified = 0,
    None = 1,
    OneLayer = 2,
}

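/// Model and backend configuration, built via [`ModelConfigBuilder`].
///
/// A minimal construction sketch, assuming the crate is exposed as `diffusion_rs::api`
/// (the checkpoint path is a placeholder, not a file shipped with this crate):
///
/// ```no_run
/// # use diffusion_rs::api::ModelConfigBuilder;
/// # use std::path::PathBuf;
/// let model_config = ModelConfigBuilder::default()
///     .model(PathBuf::from("./sd-v1-4.ckpt"))
///     .build()
///     .expect("either model or diffusion_model must be set");
/// ```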
#[derive(Builder, Debug, Clone)]
#[builder(
    setter(into, strip_option),
    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
)]
pub struct ModelConfig {
    /// Number of threads to use during computation (default: number of CPU physical cores).
    /// If n_threads <= 0, the number of CPU physical cores is used instead.
    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
    n_threads: i32,

    /// Place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
    #[builder(default = "false")]
    offload_params_to_cpu: bool,

    /// Path to the ESRGAN model. Upscales images after generation; only RealESRGAN_x4plus_anime_6B is supported for now
    #[builder(default = "Default::default()")]
    upscale_model: Option<CLibPath>,

    /// Run the ESRGAN upscaler this many times (default: 1)
    #[builder(default = "1")]
    upscale_repeats: i32,

    /// Path to full model
    #[builder(default = "Default::default()")]
    model: CLibPath,

    /// Path to the standalone diffusion model
    #[builder(default = "Default::default()")]
    diffusion_model: CLibPath,

    /// Path to the clip-l text encoder
    #[builder(default = "Default::default()")]
    clip_l: CLibPath,

    /// Path to the clip-g text encoder
    #[builder(default = "Default::default()")]
    clip_g: CLibPath,

    /// Path to the clip-vision encoder
    #[builder(default = "Default::default()")]
    clip_vision: CLibPath,

    /// Path to the t5xxl text encoder
    #[builder(default = "Default::default()")]
    t5xxl: CLibPath,

    /// Path to vae
    #[builder(default = "Default::default()")]
    vae: CLibPath,

    /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
    #[builder(default = "Default::default()")]
    taesd: CLibPath,

    /// Path to control net model
    #[builder(default = "Default::default()")]
    control_net: CLibPath,

    /// Path to embeddings
    #[builder(default = "Default::default()")]
    embeddings: CLibPath,

    /// Path to PHOTOMAKER stacked id embeddings
    #[builder(default = "Default::default()")]
    stacked_id_embd: CLibPath,

    /// Weight type. If not specified, the default is the type of the weight file
    #[builder(default = "WeightType::SD_TYPE_COUNT")]
    weight_type: WeightType,

    /// Lora model directory
    #[builder(default = "Default::default()", setter(custom))]
    lora_model: CLibPath,

    /// Path to the standalone high noise diffusion model
    #[builder(default = "Default::default()")]
    high_noise_diffusion_model: CLibPath,

    /// Suffix appended to the prompt (e.g. to activate a LoRA model)
    #[builder(default = "None", private)]
    prompt_suffix: Option<String>,

    /// Process the VAE in tiles to reduce memory usage (default: false)
    #[builder(default = "false")]
    vae_tiling: bool,

    /// RNG (default: CUDA)
    #[builder(default = "RngFunction::CUDA_RNG")]
    rng: RngFunction,

    /// Denoiser sigma schedule (default: DEFAULT)
    #[builder(default = "Scheduler::DEFAULT")]
    scheduler: Scheduler,

    /// Keep the VAE on the CPU (for low VRAM) (default: false)
    #[builder(default = "false")]
    vae_on_cpu: bool,

    /// Keep CLIP on the CPU (for low VRAM) (default: false)
    #[builder(default = "false")]
    clip_on_cpu: bool,

    /// Keep ControlNet on the CPU (for low VRAM) (default: false)
    #[builder(default = "false")]
    control_net_cpu: bool,

    /// Use flash attention in the diffusion model (for low VRAM).
    /// Might lower quality, since it implies converting k and v to f16.
    /// This might crash if it is not supported by the backend.
    #[builder(default = "false")]
    flash_attention: bool,

    /// Disable the DiT mask for Chroma
    #[builder(default = "false")]
    chroma_disable_dit_mask: bool,

    /// Enable the T5 mask for Chroma
    #[builder(default = "false")]
    chroma_enable_t5_mask: bool,

    /// T5 mask pad size for Chroma
    #[builder(default = "1")]
    chroma_t5_mask_pad: i32,

    /// Use Conv2d direct in the diffusion model.
    /// This might crash if it is not supported by the backend.
    #[builder(default = "false")]
    diffusion_conv_direct: bool,

    /// Use Conv2d direct in the VAE model (should improve performance).
    /// This might crash if it is not supported by the backend.
    #[builder(default = "false")]
    vae_conv_direct: bool,

    /// Shift value for flow models like SD3.x or WAN (default: auto)
    #[builder(default = "f32::INFINITY")]
    flow_shift: f32,

    #[builder(default = "None", private)]
    upscaler_ctx: Option<*mut upscaler_ctx_t>,

    #[builder(default = "None", private)]
    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
}

impl ModelConfigBuilder {
    fn validate(&self) -> Result<(), ConfigBuilderError> {
        self.validate_model()
    }

    fn validate_model(&self) -> Result<(), ConfigBuilderError> {
        self.model
            .as_ref()
            .or(self.diffusion_model.as_ref())
            .map(|_| ())
            .ok_or(ConfigBuilderError::UninitializedField(
                "Either model or diffusion_model must be set",
            ))
    }

    pub fn lora_model(&mut self, lora_model: &Path) -> &mut Self {
        let folder = lora_model.parent().unwrap();
        let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned();
        self.prompt_suffix(format!("<lora:{file_name}:1>"));
        self.lora_model = Some(folder.into());
        self
    }

    pub fn n_threads(&mut self, value: i32) -> &mut Self {
        self.n_threads = if value > 0 {
            Some(value)
        } else {
            Some(num_cpus::get_physical() as i32)
        };
        self
    }
}

impl ModelConfig {
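    /// Lazily creates the ESRGAN upscaler context on first use and caches it on the config.
    /// Returns `None` when no upscale model is configured or `upscale_repeats` is 0.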
    unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
        unsafe {
            if self.upscale_model.is_none() || self.upscale_repeats == 0 {
                None
            } else {
                if self.upscaler_ctx.is_none() {
                    let upscaler = new_upscaler_ctx(
                        self.upscale_model.as_ref().unwrap().as_ptr(),
                        self.offload_params_to_cpu,
                        self.diffusion_conv_direct,
                        self.n_threads,
                    );
                    self.upscaler_ctx = Some(upscaler);
                }
                self.upscaler_ctx
            }
        }
    }

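    /// Lazily creates the underlying stable-diffusion.cpp context (`sd_ctx_t`) on first use,
    /// caches it together with its parameters, and returns the raw pointer on later calls.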
    unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
        unsafe {
            if self.diffusion_ctx.is_none() {
                let sd_ctx_params = sd_ctx_params_t {
                    model_path: self.model.as_ptr(),
                    clip_l_path: self.clip_l.as_ptr(),
                    clip_g_path: self.clip_g.as_ptr(),
                    clip_vision_path: self.clip_vision.as_ptr(),
                    high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
                    t5xxl_path: self.t5xxl.as_ptr(),
                    diffusion_model_path: self.diffusion_model.as_ptr(),
                    vae_path: self.vae.as_ptr(),
                    taesd_path: self.taesd.as_ptr(),
                    control_net_path: self.control_net.as_ptr(),
                    lora_model_dir: self.lora_model.as_ptr(),
                    embedding_dir: self.embeddings.as_ptr(),
                    stacked_id_embed_dir: self.stacked_id_embd.as_ptr(),
                    vae_decode_only,
                    vae_tiling: self.vae_tiling,
                    free_params_immediately: false,
                    n_threads: self.n_threads,
                    wtype: self.weight_type,
                    rng_type: self.rng,
                    keep_clip_on_cpu: self.clip_on_cpu,
                    keep_control_net_on_cpu: self.control_net_cpu,
                    keep_vae_on_cpu: self.vae_on_cpu,
                    diffusion_flash_attn: self.flash_attention,
                    diffusion_conv_direct: self.diffusion_conv_direct,
                    chroma_use_dit_mask: !self.chroma_disable_dit_mask,
                    chroma_use_t5_mask: self.chroma_enable_t5_mask,
                    chroma_t5_mask_pad: self.chroma_t5_mask_pad,
                    vae_conv_direct: self.vae_conv_direct,
                    offload_params_to_cpu: self.offload_params_to_cpu,
                    flow_shift: self.flow_shift,
                };
                let ctx = new_sd_ctx(&sd_ctx_params);
                self.diffusion_ctx = Some((ctx, sd_ctx_params));
            }
            self.diffusion_ctx.unwrap().0
        }
    }
}

impl Drop for ModelConfig {
    fn drop(&mut self) {
        // Clean-up CTX section
        unsafe {
            if let Some((sd_ctx, _)) = self.diffusion_ctx {
                free_sd_ctx(sd_ctx);
            }

            if let Some(upscaler_ctx) = self.upscaler_ctx {
                free_upscaler_ctx(upscaler_ctx);
            }
        }
    }
}

#[derive(Builder, Debug, Clone)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
/// Config struct common to all diffusion methods
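///
/// A minimal construction sketch, assuming the crate is exposed as `diffusion_rs::api`
/// (only `prompt` is required; `output` must be a directory when `batch_count > 1`):
///
/// ```no_run
/// # use diffusion_rs::api::ConfigBuilder;
/// # use std::path::PathBuf;
/// let config = ConfigBuilder::default()
///     .prompt("a lovely cat driving a sport car")
///     .output(PathBuf::from("./output.png"))
///     .build()
///     .expect("prompt is set and output matches batch_count");
/// ```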
pub struct Config {
    /// Path to PHOTOMAKER input id images dir
    #[builder(default = "Default::default()")]
    input_id_images: CLibPath,

    /// Normalize PHOTOMAKER input id images
    #[builder(default = "false")]
    normalize_input: bool,

    /// Path to the input image, required by img2img
    #[builder(default = "Default::default()")]
    init_img: CLibPath,

    /// Path to image condition, control net
    #[builder(default = "Default::default()")]
    control_image: CLibPath,

    /// Path to write result image to (default: ./output.png)
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,

    /// The prompt to render
    prompt: String,

    /// The negative prompt (default: "")
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,

    /// Unconditional guidance scale (default: 7.0)
    #[builder(default = "7.0")]
    cfg_scale: f32,

    /// Distilled guidance scale for models with guidance input (default: 3.5)
    #[builder(default = "3.5")]
    guidance: f32,

    /// Strength for noising/unnoising (default: 0.75)
    #[builder(default = "0.75")]
    strength: f32,

    /// Strength for keeping input identity (default: 20%)
    #[builder(default = "20.0")]
    style_ratio: f32,

    /// Strength to apply Control Net (default: 0.9)
    /// 1.0 corresponds to full destruction of information in the init image
    #[builder(default = "0.9")]
    control_strength: f32,

    /// Image height, in pixel space (default: 512)
    #[builder(default = "512")]
    height: i32,

    /// Image width, in pixel space (default: 512)
    #[builder(default = "512")]
    width: i32,

    /// Sampling method (default: EULER_A)
    #[builder(default = "SampleMethod::EULER_A")]
    sampling_method: SampleMethod,

    /// eta in DDIM, only for DDIM and TCD (default: 0)
    #[builder(default = "0.")]
    eta: f32,

    /// Number of sample steps (default: 20)
    #[builder(default = "20")]
    steps: i32,

    /// RNG seed (default: 42; use a negative value for a random seed)
    #[builder(default = "42")]
    seed: i64,

    /// Number of images to generate (default: 1)
    #[builder(default = "1")]
    batch_count: i32,

    /// Ignore the last layers of the CLIP network; 1 ignores none, 2 ignores one layer (default: unspecified)
    /// Unspecified (<= 0) resolves to 1 for SD1.x and 2 for SD2.x
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,

    /// Apply canny preprocessor (edge detection) (default: false)
    #[builder(default = "false")]
    canny: bool,

    /// Skip layer guidance (SLG) scale, only for DiT models (default: 0)
    /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium
    #[builder(default = "0.")]
    slg_scale: f32,

    /// Layers to skip for SLG steps (default: \[7,8,9\])
    #[builder(default = "vec![7, 8, 9]")]
    skip_layer: Vec<i32>,

    /// SLG enabling point (default: 0.01)
    #[builder(default = "0.01")]
    skip_layer_start: f32,

    /// SLG disabling point (default: 0.2)
    #[builder(default = "0.2")]
    skip_layer_end: f32,
}

impl ConfigBuilder {
    fn validate(&self) -> Result<(), ConfigBuilderError> {
        self.validate_output_dir()
    }

    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
        if is_dir == multiple_items {
            Ok(())
        } else {
            Err(ConfigBuilderError::ValidationError(
                "When batch_count > 1, output must point to a folder, and vice versa".to_owned(),
            ))
        }
    }
}

impl From<Config> for ConfigBuilder {
    fn from(value: Config) -> Self {
        let mut builder = ConfigBuilder::default();
        builder
            .input_id_images(value.input_id_images)
            .normalize_input(value.normalize_input)
            .init_img(value.init_img)
            .control_image(value.control_image)
            .output(value.output)
            .prompt(value.prompt)
            .negative_prompt(value.negative_prompt)
            .cfg_scale(value.cfg_scale)
            .guidance(value.guidance)
            .strength(value.strength)
            .style_ratio(value.style_ratio)
            .control_strength(value.control_strength)
            .height(value.height)
            .width(value.width)
            .sampling_method(value.sampling_method)
            .eta(value.eta)
            .steps(value.steps)
            .seed(value.seed)
            .batch_count(value.batch_count)
            .clip_skip(value.clip_skip)
            .slg_scale(value.slg_scale)
            .skip_layer(value.skip_layer)
            .skip_layer_start(value.skip_layer_start)
            .skip_layer_end(value.skip_layer_end)
            .canny(value.canny);

        builder
    }
}

#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<&str> for CLibString {
    fn from(value: &str) -> Self {
        Self(CString::new(value).unwrap())
    }
}

impl From<String> for CLibString {
    fn from(value: String) -> Self {
        Self(CString::new(value).unwrap())
    }
}

#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    fn from(value: PathBuf) -> Self {
        Self(CString::new(value.to_str().unwrap_or_default()).unwrap())
    }
}

impl From<&Path> for CLibPath {
    fn from(value: &Path) -> Self {
        Self(CString::new(value.to_str().unwrap_or_default()).unwrap())
    }
}

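/// Resolves the output path(s): a single file path when `batch_size == 1`, otherwise one
/// `{prompt}_{id}.png` file per generated image inside the output directory.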
fn output_files(path: &Path, prompt: &str, batch_size: i32) -> Vec<PathBuf> {
    if batch_size == 1 {
        vec![path.into()]
    } else {
        (1..=batch_size)
            .map(|id| path.join(format!("{prompt}_{id}.png")))
            .collect()
    }
}

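/// Runs the ESRGAN upscaler `upscale_repeats` times on `data`, freeing the previous buffer
/// at each step; returns the input unchanged when no upscaler context is available.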
unsafe fn upscale(
    upscale_repeats: i32,
    upscaler_ctx: Option<*mut upscaler_ctx_t>,
    data: sd_image_t,
) -> Result<sd_image_t, DiffusionError> {
    unsafe {
        match upscaler_ctx {
            Some(upscaler_ctx) => {
                let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
                let mut current_image = data;
                for _ in 0..upscale_repeats {
                    let upscaled_image =
                        diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);

                    if upscaled_image.data.is_null() {
                        return Err(DiffusionError::Upscaler);
                    }

                    free(current_image.data as *mut c_void);
                    current_image = upscaled_image;
                }
                Ok(current_image)
            }
            None => Ok(data),
        }
    }
}

/// Generate image(s) from a prompt using the given generation and model configuration
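///
/// A usage sketch mirroring the tests below (the checkpoint path is a placeholder, and the
/// crate is assumed to be exposed as `diffusion_rs::api`):
///
/// ```no_run
/// # use diffusion_rs::api::{ConfigBuilder, ModelConfigBuilder, gen_img};
/// # use std::path::PathBuf;
/// let mut model_config = ModelConfigBuilder::default()
///     .model(PathBuf::from("./sd-v1-4.ckpt"))
///     .build()
///     .unwrap();
/// let mut config = ConfigBuilder::default()
///     .prompt("a lovely duck drinking water from a bottle")
///     .output(PathBuf::from("./output.png"))
///     .build()
///     .unwrap();
/// gen_img(&mut config, &mut model_config).unwrap();
/// ```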
pub fn gen_img(config: &mut Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
    let prompt: CLibString = match &model_config.prompt_suffix {
        Some(suffix) => format!("{} {suffix}", &config.prompt),
        None => config.prompt.clone(),
    }
    .into();
    let files = output_files(&config.output, &config.prompt, config.batch_count);
    unsafe {
        let sd_ctx = model_config.diffusion_ctx(true);
        let upscaler_ctx = model_config.upscaler_ctx();
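        // The init and mask images below carry null data pointers: no input image is
        // provided, so this call behaves as plain txt2img.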
        let init_image = sd_image_t {
            width: 0,
            height: 0,
            channel: 3,
            data: null_mut(),
        };
        let mask_image = sd_image_t {
            width: config.width as u32,
            height: config.height as u32,
            channel: 1,
            data: null_mut(),
        };
        let guidance = sd_guidance_params_t {
            txt_cfg: config.cfg_scale,
            img_cfg: config.cfg_scale,
            distilled_guidance: config.guidance,
            slg: sd_slg_params_t {
                layers: config.skip_layer.as_mut_ptr(),
                layer_count: config.skip_layer.len(),
                layer_start: config.skip_layer_start,
                layer_end: config.skip_layer_end,
                scale: config.slg_scale,
            },
        };
        let sample_params = sd_sample_params_t {
            guidance,
            sample_method: config.sampling_method,
            sample_steps: config.steps,
            eta: config.eta,
            scheduler: model_config.scheduler,
        };
        let control_image = sd_image_t {
            width: 0,
            height: 0,
            channel: 3,
            data: null_mut(),
        };

        let sd_img_gen_params = sd_img_gen_params_t {
            prompt: prompt.as_ptr(),
            negative_prompt: config.negative_prompt.as_ptr(),
            clip_skip: config.clip_skip as i32,
            init_image,
            ref_images: null_mut(),
            ref_images_count: 0,
            increase_ref_index: false,
            mask_image,
            width: config.width,
            height: config.height,
            sample_params,
            strength: config.strength,
            seed: config.seed,
            batch_count: config.batch_count,
            control_image,
            control_strength: config.control_strength,
            style_strength: config.style_ratio,
            normalize_input: config.normalize_input,
            input_id_images_path: config.input_id_images.as_ptr(),
        };

        let slice = diffusion_rs_sys::generate_image(sd_ctx, &sd_img_gen_params);
        if slice.is_null() {
            return Err(DiffusionError::Forward);
        }
        for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
            .iter()
            .zip(files)
        {
            match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
                Ok(img) => {
                    // Thx @wandbrandon
                    let len = (img.width * img.height * img.channel) as usize;
                    let buffer = slice::from_raw_parts(img.data, len).to_vec();
                    let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
                        .map(|img| RgbImage::from(img).save(path));
                    if let Some(Err(err)) = save_state {
                        return Err(DiffusionError::StoreImages(err));
                    }
                }
                Err(err) => {
                    return Err(err);
                }
            }
        }

        // Clean-up slice section
        free(slice as *mut c_void);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    #[test]
    fn test_required_args_txt2img() {
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let mut config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&mut config, &mut model_config).unwrap();
        let mut config2 = ConfigBuilder::from(config.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&mut config2, &mut model_config).unwrap();

        let mut config3 = ConfigBuilder::from(config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();

        gen_img(&mut config3, &mut model_config).unwrap();
    }
}