diffusion_rs/
api.rs

1use std::cmp::max;
2use std::collections::HashMap;
3use std::ffi::CString;
4use std::ffi::c_char;
5use std::ffi::c_void;
6use std::fmt::Display;
7use std::path::Path;
8use std::path::PathBuf;
9use std::ptr::null;
10use std::ptr::null_mut;
11use std::slice;
12
13use derive_builder::Builder;
14use diffusion_rs_sys::free_upscaler_ctx;
15use diffusion_rs_sys::new_upscaler_ctx;
16use diffusion_rs_sys::sd_cache_mode_t;
17use diffusion_rs_sys::sd_cache_params_t;
18use diffusion_rs_sys::sd_ctx_params_t;
19use diffusion_rs_sys::sd_embedding_t;
20use diffusion_rs_sys::sd_get_default_sample_method;
21use diffusion_rs_sys::sd_get_default_scheduler;
22use diffusion_rs_sys::sd_guidance_params_t;
23use diffusion_rs_sys::sd_image_t;
24use diffusion_rs_sys::sd_img_gen_params_t;
25use diffusion_rs_sys::sd_lora_t;
26use diffusion_rs_sys::sd_pm_params_t;
27use diffusion_rs_sys::sd_sample_params_t;
28use diffusion_rs_sys::sd_set_preview_callback;
29use diffusion_rs_sys::sd_slg_params_t;
30use diffusion_rs_sys::sd_tiling_params_t;
31use diffusion_rs_sys::upscaler_ctx_t;
32use image::ImageBuffer;
33use image::ImageError;
34use image::RgbImage;
35use libc::free;
36use thiserror::Error;
37use walkdir::DirEntry;
38use walkdir::WalkDir;
39
40use diffusion_rs_sys::free_sd_ctx;
41use diffusion_rs_sys::new_sd_ctx;
42use diffusion_rs_sys::sd_ctx_t;
43
/// Random number generator type
45pub use diffusion_rs_sys::rng_type_t as RngFunction;
46
47/// Sampling methods
48pub use diffusion_rs_sys::sample_method_t as SampleMethod;
49
50/// Denoiser sigma schedule
51pub use diffusion_rs_sys::scheduler_t as Scheduler;
52
53/// Prediction override
54pub use diffusion_rs_sys::prediction_t as Prediction;
55
56/// Weight type
57pub use diffusion_rs_sys::sd_type_t as WeightType;
58
59/// Preview mode
60pub use diffusion_rs_sys::preview_t as PreviewType;
61
62/// Lora mode
63pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
64
/// File extensions accepted as model weights when scanning the embedding and LoRA
/// directories (compared against the entry's extension, case-sensitively).
static VALID_EXT: [&str; 3] = ["pt", "safetensors", "gguf"];
66
67#[non_exhaustive]
68#[derive(Error, Debug)]
69/// Error that can occurs while forwarding models
70pub enum DiffusionError {
71    #[error("The underling stablediffusion.cpp function returned NULL")]
72    Forward,
73    #[error(transparent)]
74    StoreImages(#[from] ImageError),
75    #[error("The underling upscaler model returned a NULL image")]
76    Upscaler,
77}
78
/// Number of final CLIP layers to ignore.
///
/// The discriminant is the value handed to the native library.
#[repr(i32)]
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
pub enum ClipSkip {
    /// Let the library choose: acts as [ClipSkip::None] for SD1.x and
    /// [ClipSkip::OneLayer] for SD2.x.
    #[default]
    Unspecified = 0,
    /// Ignore no layers.
    None = 1,
    /// Ignore the last layer.
    OneLayer = 2,
}
90
91type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
92
93#[derive(Default, Debug, Clone)]
94struct LoraStorage {
95    lora_model_dir: CLibPath,
96    data: Vec<(CLibPath, String, f32)>,
97    loras_t: Vec<sd_lora_t>,
98}
99
/// Describes one LoRA to load from the LoRA model directory.
#[derive(Default, Debug, Clone)]
pub struct LoraSpec {
    /// File stem (file name without extension) of the LoRA inside the LoRA directory.
    pub file_name: String,
    /// Forwarded as `sd_lora_t::is_high_noise`; high-noise LoRAs are listed after standard ones.
    pub is_high_noise: bool,
    /// Weight multiplier forwarded as `sd_lora_t::multiplier`.
    pub multiplier: f32,
}
107
108/// Parameters for UCache
109#[derive(Builder, Debug, Clone)]
110pub struct UCacheParams {
111    /// Error threshold for reuse decision
112    #[builder(default = "1.0")]
113    threshold: f32,
114
115    /// Start caching at this percent of steps
116    #[builder(default = "0.15")]
117    start: f32,
118
119    /// Stop caching at this percent of steps
120    #[builder(default = "0.95")]
121    end: f32,
122
123    /// Error decay rate (0-1)
124    #[builder(default = "1.0")]
125    decay: f32,
126
127    /// Scale threshold by output norm
128    #[builder(default = "true")]
129    relative: bool,
130
131    /// Reset error after computing
132    /// true: Resets accumulated error after each computed step. More aggressive caching, works well with most samplers.
133    /// false: Keeps error accumulated. More conservative, recommended for euler_a sampler
134    #[builder(default = "true")]
135    reset: bool,
136}
137
138/// Parameters for Easy Cache
139#[derive(Builder, Debug, Clone)]
140pub struct EasyCacheParams {
141    /// Error threshold for reuse decision
142    #[builder(default = "0.2")]
143    threshold: f32,
144
145    /// Start caching at this percent of steps
146    #[builder(default = "0.15")]
147    start: f32,
148
149    /// Stop caching at this percent of steps
150    #[builder(default = "0.95")]
151    end: f32,
152}
153
154/// Parameters for Db Cache
155#[derive(Builder, Debug, Clone)]
156pub struct DbCacheParams {
157    /// Front blocks to always compute
158    #[builder(default = "8")]
159    fn_blocks: i32,
160
161    /// Back blocks to always compute
162    #[builder(default = "0")]
163    bn_blocks: i32,
164
165    /// L1 residual difference threshold
166    #[builder(default = "0.08")]
167    threshold: f32,
168
169    /// Steps before caching starts
170    #[builder(default = "8")]
171    warmup: i32,
172
173    /// Steps Computation Mask controls which steps can be cached
174    scm_mask: ScmPreset,
175
176    /// Scm Policy
177    #[builder(default = "ScmPolicy::default()")]
178    scm_policy_dynamic: ScmPolicy,
179}
180
/// Policy deciding when a step allowed by the Steps Computation Mask is actually cached.
#[derive(Debug, Default, Clone)]
pub enum ScmPolicy {
    /// Cache every step the mask marks as cacheable.
    Static,
    /// Consult the threshold before caching a cacheable step.
    #[default]
    Dynamic,
}
190
191/// Steps Computation Mask Preset controls which steps can be cached
192#[derive(Debug, Default, Clone)]
193pub enum ScmPreset {
194    Slow,
195    #[default]
196    Medium,
197    Fast,
198    Ultra,
199    /// E.g.: "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1"
200    /// where 1 means compute, 0 means cache
201    Custom(String),
202}
203
204impl ScmPreset {
205    fn to_vec_string(&self, steps: i32) -> String {
206        match self {
207            ScmPreset::Slow => ScmPresetBins {
208                compute_bins: vec![8, 3, 3, 2, 1, 1],
209                cache_bins: vec![1, 2, 2, 2, 3],
210                steps,
211            }
212            .to_string(),
213            ScmPreset::Medium => ScmPresetBins {
214                compute_bins: vec![6, 2, 2, 2, 2, 1],
215                cache_bins: vec![1, 3, 3, 3, 3],
216                steps,
217            }
218            .to_string(),
219            ScmPreset::Fast => ScmPresetBins {
220                compute_bins: vec![6, 1, 1, 1, 1, 1],
221                cache_bins: vec![1, 3, 4, 5, 4],
222                steps,
223            }
224            .to_string(),
225            ScmPreset::Ultra => ScmPresetBins {
226                compute_bins: vec![4, 1, 1, 1, 1],
227                cache_bins: vec![2, 5, 6, 7],
228                steps,
229            }
230            .to_string(),
231            ScmPreset::Custom(s) => s.clone(),
232        }
233    }
234}
235
/// Run lengths describing a Steps Computation Mask, calibrated for a 28-step schedule.
///
/// The mask alternates runs of computed steps (`1`) and cached steps (`0`):
/// `compute_bins[0]` ones, then `cache_bins[0]` zeros, then `compute_bins[1]` ones, etc.
#[derive(Debug, Clone)]
struct ScmPresetBins {
    /// Lengths of the successive runs of computed steps.
    compute_bins: Vec<i32>,
    /// Lengths of the successive runs of cached steps.
    cache_bins: Vec<i32>,
    /// Number of sampling steps the mask has to cover.
    steps: i32,
}

impl ScmPresetBins {
    /// Step count the bins are calibrated for; other step counts are rescaled.
    const REFERENCE_STEPS: i32 = 28;

    /// Returns the bins rescaled to `self.steps`, or an unchanged copy when `steps`
    /// equals the reference count or is not positive.
    fn maybe_scale(&self) -> ScmPresetBins {
        if self.steps == Self::REFERENCE_STEPS || self.steps <= 0 {
            return self.clone();
        }
        self.scale()
    }

    /// Rescales every bin by `steps / 28`, rounding to the nearest integer and keeping
    /// each bin at least one step long.
    fn scale(&self) -> ScmPresetBins {
        let factor = self.steps as f32 / Self::REFERENCE_STEPS as f32;
        ScmPresetBins {
            compute_bins: Self::rescale_bins(&self.compute_bins, factor),
            cache_bins: Self::rescale_bins(&self.cache_bins, factor),
            steps: self.steps,
        }
    }

    /// Multiplies each bin by `factor`, rounds half up, and clamps the result to >= 1.
    fn rescale_bins(bins: &[i32], factor: f32) -> Vec<i32> {
        // `+ 0.5` then truncation rounds to nearest, so the scaled runs still add up to
        // roughly `steps` (multiplying by 0.5 would halve every run instead).
        bins.iter()
            .map(|&bin| max(1, (bin as f32 * factor + 0.5) as i32))
            .collect()
    }

    /// Expands the bins into a mask of at most `steps` entries (1 = compute, 0 = cache).
    ///
    /// Compute and cache runs are emitted alternately, each from its own bin list, and
    /// truncated once the mask holds `steps` entries. The mask may be shorter than `steps`
    /// when the bins run out. The last entry is always forced to 1 so the final step is
    /// computed. A non-positive `steps` yields an empty mask.
    fn generate_vec_mask(&self) -> Vec<i32> {
        // Clamp before casting: a negative i32 cast to usize would wrap to a huge length.
        let total = self.steps.max(0) as usize;
        let mut mask = Vec::new();
        let mut compute_runs = self.compute_bins.iter();
        let mut cache_runs = self.cache_bins.iter();

        while mask.len() < total {
            let compute = compute_runs.next();
            let cache = cache_runs.next();
            if compute.is_none() && cache.is_none() {
                break;
            }
            // Each list is advanced by its own iterator, so the i-th compute run is always
            // followed by the i-th cache run.
            for (run, value) in [(compute, 1), (cache, 0)] {
                if let Some(&len) = run {
                    let len = (len.max(0) as usize).min(total - mask.len());
                    mask.resize(mask.len() + len, value);
                }
            }
        }
        if let Some(last) = mask.last_mut() {
            *last = 1;
        }
        mask
    }
}

impl Display for ScmPresetBins {
    /// Writes the (rescaled) mask as comma-separated digits, e.g. `1,1,0,1`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mask = self
            .maybe_scale()
            .generate_vec_mask()
            .iter()
            .map(|bit| bit.to_string())
            .collect::<Vec<_>>()
            .join(",");
        f.write_str(&mask)
    }
}
317
318/// Config struct for a specific diffusion model
319#[derive(Builder, Debug, Clone)]
320#[builder(
321    setter(into, strip_option),
322    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
323)]
324pub struct ModelConfig {
325    /// Number of threads to use during computation (default: 0).
326    /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores.
327    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
328    n_threads: i32,
329
330    /// Place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
331    #[builder(default = "false")]
332    offload_params_to_cpu: bool,
333
334    /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
335    #[builder(default = "Default::default()")]
336    upscale_model: Option<CLibPath>,
337
338    /// Run the ESRGAN upscaler this many times (default 1)
339    #[builder(default = "1")]
340    upscale_repeats: i32,
341
342    /// Tile size for ESRGAN upscaler (default 128)
343    #[builder(default = "128")]
344    upscale_tile_size: i32,
345
346    /// Path to full model
347    #[builder(default = "Default::default()")]
348    model: CLibPath,
349
350    /// Path to the standalone diffusion model
351    #[builder(default = "Default::default()")]
352    diffusion_model: CLibPath,
353
354    /// Path to the qwen2vl text encoder
355    #[builder(default = "Default::default()")]
356    llm: CLibPath,
357
358    /// Path to the qwen2vl vit
359    #[builder(default = "Default::default()")]
360    llm_vision: CLibPath,
361
362    /// Path to the clip-l text encoder
363    #[builder(default = "Default::default()")]
364    clip_l: CLibPath,
365
366    /// Path to the clip-g text encoder
367    #[builder(default = "Default::default()")]
368    clip_g: CLibPath,
369
370    /// Path to the clip-vision encoder
371    #[builder(default = "Default::default()")]
372    clip_vision: CLibPath,
373
374    /// Path to the t5xxl text encoder
375    #[builder(default = "Default::default()")]
376    t5xxl: CLibPath,
377
378    /// Path to vae
379    #[builder(default = "Default::default()")]
380    vae: CLibPath,
381
382    /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
383    #[builder(default = "Default::default()")]
384    taesd: CLibPath,
385
386    /// Path to control net model
387    #[builder(default = "Default::default()")]
388    control_net: CLibPath,
389
390    /// Path to embeddings
391    #[builder(default = "Default::default()", setter(custom))]
392    embeddings: EmbeddingsStorage,
393
394    /// Path to PHOTOMAKER model
395    #[builder(default = "Default::default()")]
396    photo_maker: CLibPath,
397
398    /// Path to PHOTOMAKER v2 id embed
399    #[builder(default = "Default::default()")]
400    pm_id_embed_path: CLibPath,
401
402    /// Weight type. If not specified, the default is the type of the weight file
403    #[builder(default = "WeightType::SD_TYPE_COUNT")]
404    weight_type: WeightType,
405
406    /// Lora model directory
407    #[builder(default = "Default::default()", setter(custom))]
408    lora_models: LoraStorage,
409
410    /// Path to the standalone high noise diffusion model
411    #[builder(default = "Default::default()")]
412    high_noise_diffusion_model: CLibPath,
413
414    /// Process vae in tiles to reduce memory usage (default: false)
415    #[builder(default = "false")]
416    vae_tiling: bool,
417
418    /// Tile size for vae tiling (default: 32x32)
419    #[builder(default = "(32,32)")]
420    vae_tile_size: (i32, i32),
421
422    /// Relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides vae_tile_size)
423    #[builder(default = "(0.,0.)")]
424    vae_relative_tile_size: (f32, f32),
425
426    /// Tile overlap for vae tiling, in fraction of tile size (default: 0.5)
427    #[builder(default = "0.5")]
428    vae_tile_overlap: f32,
429
430    /// RNG (default: CUDA)
431    #[builder(default = "RngFunction::CUDA_RNG")]
432    rng: RngFunction,
433
434    /// Sampler RNG. If [RngFunction::RNG_TYPE_COUNT] is used will default to rng value. (default: [RngFunction::RNG_TYPE_COUNT])",
435    #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
436    sampler_rng_type: RngFunction,
437
438    /// Denoiser sigma schedule (default: [Scheduler::SCHEDULER_COUNT]).
439    /// Will default to [Scheduler::EXPONENTIAL_SCHEDULER] if a denoiser is already instantiated.
440    /// Otherwise, [Scheduler::DISCRETE_SCHEDULER] is used.
441    #[builder(default = "Scheduler::SCHEDULER_COUNT")]
442    scheduler: Scheduler,
443
444    /// Custom sigma values for the sampler
445    #[builder(default = "Default::default()")]
446    sigmas: Vec<f32>,
447
448    /// Prediction type override (default: PREDICTION_COUNT)
449    #[builder(default = "Prediction::PREDICTION_COUNT")]
450    prediction: Prediction,
451
452    /// Keep vae in cpu (for low vram) (default: false)
453    #[builder(default = "false")]
454    vae_on_cpu: bool,
455
456    /// keep clip in cpu (for low vram) (default: false)
457    #[builder(default = "false")]
458    clip_on_cpu: bool,
459
460    /// Keep controlnet in cpu (for low vram) (default: false)
461    #[builder(default = "false")]
462    control_net_cpu: bool,
463
464    /// Use flash attention to reduce memory usage (for low vram).
465    // /For most backends, it slows things down, but for cuda it generally speeds it up too. At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal).
466    #[builder(default = "false")]
467    flash_attention: bool,
468
469    /// Disable dit mask for chroma
470    #[builder(default = "false")]
471    chroma_disable_dit_mask: bool,
472
473    /// Enable t5 mask for chroma
474    #[builder(default = "false")]
475    chroma_enable_t5_mask: bool,
476
477    /// t5 mask pad size of chroma
478    #[builder(default = "1")]
479    chroma_t5_mask_pad: i32,
480
481    /// Use qwen image zero cond true optimization
482    #[builder(default = "false")]
483    use_qwen_image_zero_cond_true: bool,
484
485    /// Use Conv2d direct in the diffusion model
486    /// This might crash if it is not supported by the backend.
487    #[builder(default = "false")]
488    diffusion_conv_direct: bool,
489
490    /// Use Conv2d direct in the vae model (should improve the performance)
491    /// This might crash if it is not supported by the backend.
492    #[builder(default = "false")]
493    vae_conv_direct: bool,
494
495    /// Force use of conv scale on sdxl vae
496    #[builder(default = "false")]
497    force_sdxl_vae_conv_scale: bool,
498
499    /// Shift value for Flow models like SD3.x or WAN (default: auto)
500    #[builder(default = "f32::INFINITY")]
501    flow_shift: f32,
502
503    /// Shift timestep for NitroFusion models, default: 0, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
504    #[builder(default = "0")]
505    timestep_shift: i32,
506
507    /// Prevents usage of taesd for decoding the final image
508    #[builder(default = "false")]
509    taesd_preview_only: bool,
510
511    /// In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used.The immediate mode may have precision and compatibility issues with quantized parameters, but it usually offers faster inference speed and, in some cases, lower memory usage. The at_runtime mode, on the other hand, is exactly the opposite
512    #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
513    lora_apply_mode: LoraModeType,
514
515    /// Enable circular padding for convolutions
516    #[builder(default = "false")]
517    circular: bool,
518
519    /// Enable circular RoPE wrapping on x-axis (width) only
520    #[builder(default = "false")]
521    circular_x: bool,
522
523    /// Enable circular RoPE wrapping on y-axis (height) only
524    #[builder(default = "false")]
525    circular_y: bool,
526
527    #[builder(default = "None", private)]
528    upscaler_ctx: Option<*mut upscaler_ctx_t>,
529
530    #[builder(default = "None", private)]
531    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
532}
533
534impl ModelConfigBuilder {
535    fn validate(&self) -> Result<(), ConfigBuilderError> {
536        self.validate_model()
537    }
538
539    fn validate_model(&self) -> Result<(), ConfigBuilderError> {
540        self.model
541            .as_ref()
542            .or(self.diffusion_model.as_ref())
543            .map(|_| ())
544            .ok_or(ConfigBuilderError::UninitializedField(
545                "Model OR DiffusionModel must be valorized",
546            ))
547    }
548
549    fn filter_valid_extensions(&self, path: &Path) -> impl Iterator<Item = DirEntry> {
550        WalkDir::new(path)
551            .into_iter()
552            .filter_map(|entry| entry.ok())
553            .filter(|entry| {
554                entry
555                    .path()
556                    .extension()
557                    .and_then(|ext| ext.to_str())
558                    .map(|ext_str| VALID_EXT.contains(&ext_str))
559                    .unwrap_or(false)
560            })
561    }
562
563    fn build_single_lora_storage(
564        spec: &LoraSpec,
565        is_high_noise: bool,
566        valid_loras: &HashMap<String, PathBuf>,
567    ) -> ((CLibPath, String, f32), sd_lora_t) {
568        let path = valid_loras.get(&spec.file_name).unwrap().as_path();
569        let c_path = CLibPath::from(path);
570        let lora = sd_lora_t {
571            is_high_noise,
572            multiplier: spec.multiplier,
573            path: c_path.as_ptr(),
574        };
575        let data = (c_path, spec.file_name.clone(), spec.multiplier);
576        (data, lora)
577    }
578
579    pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
580        let data: Vec<(CLibString, CLibPath)> = self
581            .filter_valid_extensions(embeddings_dir)
582            .map(|entry| {
583                let file_stem = entry
584                    .path()
585                    .file_stem()
586                    .and_then(|stem| stem.to_str())
587                    .unwrap_or_default()
588                    .to_owned();
589                (CLibString::from(file_stem), CLibPath::from(entry.path()))
590            })
591            .collect();
592        let data_pointer = data
593            .iter()
594            .map(|(name, path)| sd_embedding_t {
595                name: name.as_ptr(),
596                path: path.as_ptr(),
597            })
598            .collect();
599        self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
600        self
601    }
602
603    pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
604        let valid_loras: HashMap<String, PathBuf> = self
605            .filter_valid_extensions(lora_model_dir)
606            .map(|entry| {
607                let path = entry.path();
608                (
609                    path.file_stem()
610                        .and_then(|stem| stem.to_str())
611                        .unwrap_or_default()
612                        .to_owned(),
613                    path.to_path_buf(),
614                )
615            })
616            .collect();
617        let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
618        let standard = specs
619            .iter()
620            .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
621            .map(|s| Self::build_single_lora_storage(s, false, &valid_loras));
622        let high_noise = specs
623            .iter()
624            .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
625            .map(|s| Self::build_single_lora_storage(s, true, &valid_loras));
626
627        let mut data = Vec::new();
628        let mut loras_t = Vec::new();
629        for lora in standard.chain(high_noise) {
630            data.push(lora.0);
631            loras_t.push(lora.1);
632        }
633
634        self.lora_models = Some(LoraStorage {
635            lora_model_dir: lora_model_dir.into(),
636            data,
637            loras_t,
638        });
639        self
640    }
641
642    pub fn n_threads(&mut self, value: i32) -> &mut Self {
643        self.n_threads = if value > 0 {
644            Some(value)
645        } else {
646            Some(num_cpus::get_physical() as i32)
647        };
648        self
649    }
650}
651
652impl ModelConfig {
653    unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
654        unsafe {
655            if self.upscale_model.is_none() || self.upscale_repeats == 0 {
656                None
657            } else {
658                if self.upscaler_ctx.is_none() {
659                    let upscaler = new_upscaler_ctx(
660                        self.upscale_model.as_ref().unwrap().as_ptr(),
661                        self.offload_params_to_cpu,
662                        self.diffusion_conv_direct,
663                        self.n_threads,
664                        self.upscale_tile_size,
665                    );
666                    self.upscaler_ctx = Some(upscaler);
667                }
668                self.upscaler_ctx
669            }
670        }
671    }
672
673    unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
674        unsafe {
675            if self.diffusion_ctx.is_none() {
676                let sd_ctx_params = sd_ctx_params_t {
677                    model_path: self.model.as_ptr(),
678                    llm_path: self.llm.as_ptr(),
679                    llm_vision_path: self.llm_vision.as_ptr(),
680                    clip_l_path: self.clip_l.as_ptr(),
681                    clip_g_path: self.clip_g.as_ptr(),
682                    clip_vision_path: self.clip_vision.as_ptr(),
683                    high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
684                    t5xxl_path: self.t5xxl.as_ptr(),
685                    diffusion_model_path: self.diffusion_model.as_ptr(),
686                    vae_path: self.vae.as_ptr(),
687                    taesd_path: self.taesd.as_ptr(),
688                    control_net_path: self.control_net.as_ptr(),
689                    embeddings: self.embeddings.2.as_ptr(),
690                    embedding_count: self.embeddings.1.len() as u32,
691                    photo_maker_path: self.photo_maker.as_ptr(),
692                    vae_decode_only,
693                    free_params_immediately: false,
694                    n_threads: self.n_threads,
695                    wtype: self.weight_type,
696                    rng_type: self.rng,
697                    keep_clip_on_cpu: self.clip_on_cpu,
698                    keep_control_net_on_cpu: self.control_net_cpu,
699                    keep_vae_on_cpu: self.vae_on_cpu,
700                    diffusion_flash_attn: self.flash_attention,
701                    diffusion_conv_direct: self.diffusion_conv_direct,
702                    chroma_use_dit_mask: !self.chroma_disable_dit_mask,
703                    chroma_use_t5_mask: self.chroma_enable_t5_mask,
704                    chroma_t5_mask_pad: self.chroma_t5_mask_pad,
705                    vae_conv_direct: self.vae_conv_direct,
706                    offload_params_to_cpu: self.offload_params_to_cpu,
707                    flow_shift: self.flow_shift,
708                    prediction: self.prediction,
709                    force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
710                    tae_preview_only: self.taesd_preview_only,
711                    lora_apply_mode: self.lora_apply_mode,
712                    tensor_type_rules: null_mut(),
713                    sampler_rng_type: self.sampler_rng_type,
714                    circular_x: self.circular || self.circular_x,
715                    circular_y: self.circular || self.circular_y,
716                    qwen_image_zero_cond_t: self.use_qwen_image_zero_cond_true,
717                };
718                let ctx = new_sd_ctx(&sd_ctx_params);
719                self.diffusion_ctx = Some((ctx, sd_ctx_params))
720            }
721            self.diffusion_ctx.unwrap().0
722        }
723    }
724}
725
726impl Drop for ModelConfig {
727    fn drop(&mut self) {
728        //Cleanup CTX section
729        unsafe {
730            if let Some((sd_ctx, _)) = self.diffusion_ctx {
731                free_sd_ctx(sd_ctx);
732            }
733
734            if let Some(upscaler_ctx) = self.upscaler_ctx {
735                free_upscaler_ctx(upscaler_ctx);
736            }
737        }
738    }
739}
740
741impl From<ModelConfig> for ModelConfigBuilder {
742    fn from(value: ModelConfig) -> Self {
743        let mut builder = ModelConfigBuilder::default();
744        builder
745            .n_threads(value.n_threads)
746            .offload_params_to_cpu(value.offload_params_to_cpu)
747            .upscale_repeats(value.upscale_repeats)
748            .model(value.model.clone())
749            .diffusion_model(value.diffusion_model.clone())
750            .llm(value.llm.clone())
751            .llm_vision(value.llm_vision.clone())
752            .clip_l(value.clip_l.clone())
753            .clip_g(value.clip_g.clone())
754            .clip_vision(value.clip_vision.clone())
755            .t5xxl(value.t5xxl.clone())
756            .vae(value.vae.clone())
757            .taesd(value.taesd.clone())
758            .control_net(value.control_net.clone())
759            .embeddings(&value.embeddings.0)
760            .photo_maker(value.photo_maker.clone())
761            .pm_id_embed_path(value.pm_id_embed_path.clone())
762            .weight_type(value.weight_type)
763            .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
764            .vae_tiling(value.vae_tiling)
765            .vae_tile_size(value.vae_tile_size)
766            .vae_relative_tile_size(value.vae_relative_tile_size)
767            .vae_tile_overlap(value.vae_tile_overlap)
768            .rng(value.rng)
769            .sampler_rng_type(value.rng)
770            .scheduler(value.scheduler)
771            .sigmas(value.sigmas.clone())
772            .prediction(value.prediction)
773            .vae_on_cpu(value.vae_on_cpu)
774            .clip_on_cpu(value.clip_on_cpu)
775            .control_net(value.control_net.clone())
776            .control_net_cpu(value.control_net_cpu)
777            .flash_attention(value.flash_attention)
778            .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
779            .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
780            .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
781            .diffusion_conv_direct(value.diffusion_conv_direct)
782            .vae_conv_direct(value.vae_conv_direct)
783            .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
784            .flow_shift(value.flow_shift)
785            .timestep_shift(value.timestep_shift)
786            .taesd_preview_only(value.taesd_preview_only)
787            .lora_apply_mode(value.lora_apply_mode)
788            .circular(value.circular)
789            .circular_x(value.circular_x)
790            .circular_y(value.circular_y)
791            .use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true);
792
793        let lora_model_dir = Into::<PathBuf>::into(&value.lora_models.lora_model_dir);
794        let lora_specs = value
795            .lora_models
796            .data
797            .iter()
798            .map(|(_, name, multiplier)| LoraSpec {
799                file_name: name.clone(),
800                is_high_noise: false,
801                multiplier: *multiplier,
802            })
803            .collect();
804
805        builder.lora_models(&lora_model_dir, lora_specs);
806
807        if let Some(model) = &value.upscale_model {
808            builder.upscale_model(model.clone());
809        }
810        builder
811    }
812}
813
814#[derive(Builder, Debug, Clone)]
815#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
816/// Config struct common to all diffusion methods
817pub struct Config {
818    /// Path to PHOTOMAKER input id images dir
819    #[builder(default = "Default::default()")]
820    pm_id_images_dir: CLibPath,
821
822    /// Path to the input image, required by img2img
823    #[builder(default = "Default::default()")]
824    init_img: CLibPath,
825
826    /// Path to image condition, control net
827    #[builder(default = "Default::default()")]
828    control_image: CLibPath,
829
830    /// Path to write result image to (default: ./output.png)
831    #[builder(default = "PathBuf::from(\"./output.png\")")]
832    output: PathBuf,
833
834    /// Path to write result image to (default: ./output.png)
835    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
836    preview_output: PathBuf,
837
838    /// Preview method
839    #[builder(default = "PreviewType::PREVIEW_NONE")]
840    preview_mode: PreviewType,
841
842    /// Enables previewing noisy inputs of the models rather than the denoised outputs
843    #[builder(default = "false")]
844    preview_noisy: bool,
845
846    /// Interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)
847    #[builder(default = "1")]
848    preview_interval: i32,
849
850    /// The prompt to render
851    prompt: String,
852
853    /// The negative prompt (default: "")
854    #[builder(default = "\"\".into()")]
855    negative_prompt: CLibString,
856
857    /// Unconditional guidance scale (default: 7.0)
858    #[builder(default = "7.0")]
859    cfg_scale: f32,
860
861    /// Distilled guidance scale for models with guidance input (default: 3.5)
862    #[builder(default = "3.5")]
863    guidance: f32,
864
865    /// Strength for noising/unnoising (default: 0.75)
866    #[builder(default = "0.75")]
867    strength: f32,
868
869    /// Strength for keeping input identity (default: 20%)
870    #[builder(default = "20.0")]
871    pm_style_strength: f32,
872
873    /// Strength to apply Control Net (default: 0.9)
874    /// 1.0 corresponds to full destruction of information in init
875    #[builder(default = "0.9")]
876    control_strength: f32,
877
878    /// Image height, in pixel space (default: 512)
879    #[builder(default = "512")]
880    height: i32,
881
882    /// Image width, in pixel space (default: 512)
883    #[builder(default = "512")]
884    width: i32,
885
886    /// Sampling-method (default: [SampleMethod::SAMPLE_METHOD_COUNT]).
887    /// [SampleMethod::EULER_SAMPLE_METHOD] will be used for flux, sd3, wan, qwen_image.
888    /// Otherwise [SampleMethod::EULER_A_SAMPLE_METHOD] is used.
889    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
890    sampling_method: SampleMethod,
891
892    /// eta in DDIM, only for DDIM and TCD: (default: 0)
893    #[builder(default = "0.")]
894    eta: f32,
895
896    /// Number of sample steps (default: 20)
897    #[builder(default = "20")]
898    steps: i32,
899
900    /// RNG seed (default: 42, use random seed for < 0)
901    #[builder(default = "42")]
902    seed: i64,
903
904    /// Number of images to generate (default: 1)
905    #[builder(default = "1")]
906    batch_count: i32,
907
908    /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
909    /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
910    #[builder(default = "ClipSkip::Unspecified")]
911    clip_skip: ClipSkip,
912
913    /// Apply canny preprocessor (edge detection) (default: false)
914    #[builder(default = "false")]
915    canny: bool,
916
917    /// skip layer guidance (SLG) scale, only for DiT models: (default: 0)
918    /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium
919    #[builder(default = "0.")]
920    slg_scale: f32,
921
922    /// Layers to skip for SLG steps: (default: \[7,8,9\])
923    #[builder(default = "vec![7, 8, 9]")]
924    skip_layer: Vec<i32>,
925
926    /// SLG enabling point: (default: 0.01)
927    #[builder(default = "0.01")]
928    skip_layer_start: f32,
929
930    /// SLG disabling point: (default: 0.2)
931    #[builder(default = "0.2")]
932    skip_layer_end: f32,
933
934    /// Disable auto resize of ref images
935    #[builder(default = "false")]
936    disable_auto_resize_ref_image: bool,
937
938    #[builder(default = "Self::cache_init()", private)]
939    cache: sd_cache_params_t,
940
941    #[builder(default = "CLibString::default()", private)]
942    scm_mask: CLibString,
943}
944
945impl ConfigBuilder {
946    fn validate(&self) -> Result<(), ConfigBuilderError> {
947        self.validate_output_dir()
948    }
949
950    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
951        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
952        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
953        if is_dir == multiple_items {
954            Ok(())
955        } else {
956            Err(ConfigBuilderError::ValidationError(
957                "When batch_count > 1, output should point to folder and vice versa".to_owned(),
958            ))
959        }
960    }
961
962    fn cache_init() -> sd_cache_params_t {
963        sd_cache_params_t {
964            mode: sd_cache_mode_t::SD_CACHE_DISABLED,
965            reuse_threshold: 1.0,
966            start_percent: 0.15,
967            end_percent: 0.95,
968            error_decay_rate: 1.0,
969            use_relative_threshold: true,
970            reset_error_on_compute: true,
971            Fn_compute_blocks: 8,
972            Bn_compute_blocks: 0,
973            residual_diff_threshold: 0.08,
974            max_warmup_steps: 8,
975            max_cached_steps: -1,
976            max_continuous_cached_steps: -1,
977            taylorseer_n_derivatives: 1,
978            taylorseer_skip_interval: 1,
979            scm_mask: null(),
980            scm_policy_dynamic: true,
981        }
982    }
983
984    pub fn no_caching(&mut self) -> &mut Self {
985        let mut cache = Self::cache_init();
986        cache.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
987        self.cache = Some(cache);
988        self
989    }
990
991    pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
992        let mut cache = Self::cache_init();
993        cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
994        cache.reuse_threshold = params.threshold;
995        cache.start_percent = params.start;
996        cache.end_percent = params.end;
997        cache.error_decay_rate = params.decay;
998        cache.use_relative_threshold = params.relative;
999        cache.reset_error_on_compute = params.reset;
1000        self.cache = Some(cache);
1001        self
1002    }
1003
1004    pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
1005        let mut cache = Self::cache_init();
1006        cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
1007        cache.reuse_threshold = params.threshold;
1008        cache.start_percent = params.start;
1009        cache.end_percent = params.end;
1010        self.cache = Some(cache);
1011        self
1012    }
1013
1014    pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
1015        let mut cache = Self::cache_init();
1016        cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
1017        cache.Fn_compute_blocks = params.fn_blocks;
1018        cache.Bn_compute_blocks = params.bn_blocks;
1019        cache.residual_diff_threshold = params.threshold;
1020        cache.max_warmup_steps = params.warmup;
1021        cache.scm_policy_dynamic = match params.scm_policy_dynamic {
1022            ScmPolicy::Static => false,
1023            ScmPolicy::Dynamic => true,
1024        };
1025        self.scm_mask = Some(CLibString::from(
1026            params
1027                .scm_mask
1028                .to_vec_string(self.steps.unwrap_or_default()),
1029        ));
1030        cache.scm_mask = self.scm_mask.as_ref().unwrap().as_ptr();
1031
1032        self.cache = Some(cache);
1033        self
1034    }
1035
1036    pub fn taylor_seer_caching(&mut self) -> &mut Self {
1037        let mut cache = Self::cache_init();
1038        cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
1039        self.cache = Some(cache);
1040        self
1041    }
1042
1043    pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
1044        self.db_cache_caching(params).cache.unwrap().mode = sd_cache_mode_t::SD_CACHE_CACHE_DIT;
1045        self
1046    }
1047}
1048
1049impl From<Config> for ConfigBuilder {
1050    fn from(value: Config) -> Self {
1051        let mut builder = ConfigBuilder::default();
1052        let mut cache = value.cache;
1053        let scm_mask = value.scm_mask.clone();
1054        cache.scm_mask = scm_mask.as_ptr();
1055        builder
1056            .pm_id_images_dir(value.pm_id_images_dir)
1057            .init_img(value.init_img)
1058            .control_image(value.control_image)
1059            .output(value.output)
1060            .prompt(value.prompt)
1061            .negative_prompt(value.negative_prompt)
1062            .cfg_scale(value.cfg_scale)
1063            .strength(value.strength)
1064            .pm_style_strength(value.pm_style_strength)
1065            .control_strength(value.control_strength)
1066            .height(value.height)
1067            .width(value.width)
1068            .sampling_method(value.sampling_method)
1069            .steps(value.steps)
1070            .seed(value.seed)
1071            .batch_count(value.batch_count)
1072            .clip_skip(value.clip_skip)
1073            .slg_scale(value.slg_scale)
1074            .skip_layer(value.skip_layer)
1075            .skip_layer_start(value.skip_layer_start)
1076            .skip_layer_end(value.skip_layer_end)
1077            .canny(value.canny)
1078            .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
1079            .preview_output(value.preview_output)
1080            .preview_mode(value.preview_mode)
1081            .preview_noisy(value.preview_noisy)
1082            .preview_interval(value.preview_interval)
1083            .cache(cache)
1084            .scm_mask(scm_mask);
1085        builder
1086    }
1087}
1088
/// Owned, NUL-terminated string handed to the C API by pointer.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Borrows the string as a C pointer.
    ///
    /// The pointer is only valid while `self` is alive and unmodified.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<&str> for CLibString {
    fn from(value: &str) -> Self {
        Self::from(value.to_owned())
    }
}

impl From<String> for CLibString {
    /// Converts without panicking: text after an interior NUL byte is dropped.
    ///
    /// A C reader would stop at that NUL anyway, so this keeps exactly what the C
    /// side can see instead of panicking on user-provided text such as prompts.
    fn from(value: String) -> Self {
        let mut bytes = value.into_bytes();
        if let Some(nul) = bytes.iter().position(|&b| b == 0) {
            bytes.truncate(nul);
        }
        Self(CString::new(bytes).expect("interior NUL bytes were removed above"))
    }
}
1109
/// Owned, NUL-terminated filesystem path handed to the C API by pointer.
///
/// The default value is the empty string, which path fields of the config use
/// when no path is given.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Borrows the path as a C pointer.
    ///
    /// The pointer is only valid while `self` is alive and unmodified.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    fn from(value: PathBuf) -> Self {
        Self::from(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    /// Converts via a lossy UTF-8 rendering of the path.
    ///
    /// Non-UTF-8 sequences become U+FFFD rather than collapsing the whole path to
    /// the empty string, which would be indistinguishable from "no path given".
    /// Text after an interior NUL is dropped, since a C reader never sees it.
    fn from(value: &Path) -> Self {
        let mut bytes = value.to_string_lossy().into_owned().into_bytes();
        if let Some(nul) = bytes.iter().position(|&b| b == 0) {
            bytes.truncate(nul);
        }
        Self(CString::new(bytes).expect("interior NUL bytes were removed above"))
    }
}

impl From<&CLibPath> for PathBuf {
    fn from(value: &CLibPath) -> Self {
        // Lossy instead of `unwrap()`: never panics, even for non-UTF-8 bytes.
        PathBuf::from(value.0.to_string_lossy().into_owned())
    }
}
1136
/// Computes where generated images are written.
///
/// With `batch_size == 1`, `path` itself is the output file. Otherwise `path` is
/// treated as a directory holding one `<prompt>_<n>.png` per image, with `n`
/// counting from 1. A `batch_size` below 1 yields no paths.
///
/// The prompt is first turned into a single, portable file-name stem:
/// - path separators, control characters and (on Windows) the reserved characters
///   `<>:"|?*` become `_`;
/// - the stem is cut on a character boundary so the whole name stays within the
///   common 255-byte file-name limit.
fn output_files(path: &Path, prompt: &str, batch_size: i32) -> Vec<PathBuf> {
    /// Longest stem in bytes: 255 minus room for `_<i32 index>.png` (1 + 10 + 4).
    const MAX_STEM_BYTES: usize = 240;

    /// Maps `prompt` to a stem that stays a single component inside `path`.
    fn file_stem(prompt: &str) -> String {
        let mut stem = String::with_capacity(prompt.len().min(MAX_STEM_BYTES));
        for ch in prompt.chars() {
            let replace = std::path::is_separator(ch)
                || ch.is_control()
                || (cfg!(windows) && matches!(ch, '<' | '>' | ':' | '"' | '|' | '?' | '*'));
            let ch = if replace { '_' } else { ch };
            if stem.len() + ch.len_utf8() > MAX_STEM_BYTES {
                break;
            }
            stem.push(ch);
        }
        stem
    }

    if batch_size == 1 {
        return vec![path.to_path_buf()];
    }
    let stem = file_stem(prompt);
    (1..=batch_size)
        .map(|id| path.join(format!("{stem}_{id}.png")))
        .collect()
}
1146
1147unsafe fn upscale(
1148    upscale_repeats: i32,
1149    upscaler_ctx: Option<*mut upscaler_ctx_t>,
1150    data: sd_image_t,
1151) -> Result<sd_image_t, DiffusionError> {
1152    unsafe {
1153        match upscaler_ctx {
1154            Some(upscaler_ctx) => {
1155                let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
1156                let mut current_image = data;
1157                for _ in 0..upscale_repeats {
1158                    let upscaled_image =
1159                        diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
1160
1161                    if upscaled_image.data.is_null() {
1162                        return Err(DiffusionError::Upscaler);
1163                    }
1164
1165                    free(current_image.data as *mut c_void);
1166                    current_image = upscaled_image;
1167                }
1168                Ok(current_image)
1169            }
1170            None => Ok(data),
1171        }
1172    }
1173}
1174
1175/// Generate an image with a prompt
1176pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
1177    let prompt: CLibString = CLibString::from(config.prompt.as_str());
1178    let files = output_files(&config.output, &config.prompt, config.batch_count);
1179    unsafe {
1180        let sd_ctx = model_config.diffusion_ctx(true);
1181        let upscaler_ctx = model_config.upscaler_ctx();
1182        let init_image = sd_image_t {
1183            width: 0,
1184            height: 0,
1185            channel: 3,
1186            data: null_mut(),
1187        };
1188        let mask_image = sd_image_t {
1189            width: config.width as u32,
1190            height: config.height as u32,
1191            channel: 1,
1192            data: null_mut(),
1193        };
1194        let mut layers = config.skip_layer.clone();
1195        let guidance = sd_guidance_params_t {
1196            txt_cfg: config.cfg_scale,
1197            img_cfg: config.cfg_scale,
1198            distilled_guidance: config.guidance,
1199            slg: sd_slg_params_t {
1200                layers: layers.as_mut_ptr(),
1201                layer_count: config.skip_layer.len(),
1202                layer_start: config.skip_layer_start,
1203                layer_end: config.skip_layer_end,
1204                scale: config.slg_scale,
1205            },
1206        };
1207        let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
1208            sd_get_default_scheduler(sd_ctx, config.sampling_method)
1209        } else {
1210            model_config.scheduler
1211        };
1212        let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
1213            sd_get_default_sample_method(sd_ctx)
1214        } else {
1215            config.sampling_method
1216        };
1217        let sample_params = sd_sample_params_t {
1218            guidance,
1219            sample_method,
1220            sample_steps: config.steps,
1221            eta: config.eta,
1222            scheduler,
1223            shifted_timestep: model_config.timestep_shift,
1224            custom_sigmas: model_config.sigmas.as_mut_ptr(),
1225            custom_sigmas_count: model_config.sigmas.len() as i32,
1226        };
1227        let control_image = sd_image_t {
1228            width: 0,
1229            height: 0,
1230            channel: 3,
1231            data: null_mut(),
1232        };
1233        let vae_tiling_params = sd_tiling_params_t {
1234            enabled: model_config.vae_tiling,
1235            tile_size_x: model_config.vae_tile_size.0,
1236            tile_size_y: model_config.vae_tile_size.1,
1237            target_overlap: model_config.vae_tile_overlap,
1238            rel_size_x: model_config.vae_relative_tile_size.0,
1239            rel_size_y: model_config.vae_relative_tile_size.1,
1240        };
1241        let pm_params = sd_pm_params_t {
1242            id_images: null_mut(),
1243            id_images_count: 0,
1244            id_embed_path: model_config.pm_id_embed_path.as_ptr(),
1245            style_strength: config.pm_style_strength,
1246        };
1247
1248        unsafe extern "C" fn save_preview_local(
1249            _step: ::std::os::raw::c_int,
1250            _frame_count: ::std::os::raw::c_int,
1251            frames: *mut sd_image_t,
1252            _is_noisy: bool,
1253            data: *mut ::std::os::raw::c_void,
1254        ) {
1255            unsafe {
1256                let path = &*data.cast::<PathBuf>();
1257                let _ = save_img(*frames, path);
1258            }
1259        }
1260
1261        if config.preview_mode != PreviewType::PREVIEW_NONE {
1262            let data = &config.preview_output as *const PathBuf;
1263
1264            sd_set_preview_callback(
1265                Some(save_preview_local),
1266                config.preview_mode,
1267                config.preview_interval,
1268                !config.preview_noisy,
1269                config.preview_noisy,
1270                data as *mut c_void,
1271            );
1272        }
1273
1274        let sd_img_gen_params = sd_img_gen_params_t {
1275            prompt: prompt.as_ptr(),
1276            negative_prompt: config.negative_prompt.as_ptr(),
1277            clip_skip: config.clip_skip as i32,
1278            init_image,
1279            ref_images: null_mut(),
1280            ref_images_count: 0,
1281            increase_ref_index: false,
1282            mask_image,
1283            width: config.width,
1284            height: config.height,
1285            sample_params,
1286            strength: config.strength,
1287            seed: config.seed,
1288            batch_count: config.batch_count,
1289            control_image,
1290            control_strength: config.control_strength,
1291            pm_params,
1292            vae_tiling_params,
1293            auto_resize_ref_image: config.disable_auto_resize_ref_image,
1294            cache: config.cache,
1295            loras: model_config.lora_models.loras_t.as_ptr(),
1296            lora_count: model_config.lora_models.loras_t.len() as u32,
1297        };
1298        let slice = diffusion_rs_sys::generate_image(sd_ctx, &sd_img_gen_params);
1299        let ret = {
1300            if slice.is_null() {
1301                return Err(DiffusionError::Forward);
1302            }
1303            for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
1304                .iter()
1305                .zip(files)
1306            {
1307                match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
1308                    Ok(img) => save_img(img, &path)?,
1309                    Err(err) => {
1310                        return Err(err);
1311                    }
1312                }
1313            }
1314            Ok(())
1315        };
1316        free(slice as *mut c_void);
1317        ret
1318    }
1319}
1320
1321fn save_img(img: sd_image_t, path: &Path) -> Result<(), DiffusionError> {
1322    // Thx @wandbrandon
1323    let len = (img.width * img.height * img.channel) as usize;
1324    let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1325    let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
1326        .map(|img| RgbImage::from(img).save(path));
1327    if let Some(Err(err)) = save_state {
1328        return Err(DiffusionError::StoreImages(err));
1329    }
1330    Ok(())
1331}
1332
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    const CAR_PROMPT: &str = "a lovely cat driving a sport car";
    const BOTTLE_PROMPT: &str = "a lovely duck drinking water from a bottle";

    /// Mandatory builder fields and the `batch_count` / `output` pairing rule.
    #[test]
    fn test_required_args_txt2img() {
        // With nothing set, neither builder may succeed.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());

        // A model path alone is enough for a model configuration.
        let model_path = PathBuf::from("./test.ckpt");
        ModelConfigBuilder::default().model(model_path).build().unwrap();

        // A prompt alone is enough for a single-image configuration.
        ConfigBuilder::default().prompt(CAR_PROMPT).build().unwrap();

        // Several images cannot target the default single output file.
        let batch_into_file = ConfigBuilder::default()
            .prompt(CAR_PROMPT)
            .batch_count(10)
            .build();
        assert!(matches!(
            batch_into_file,
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default().prompt(CAR_PROMPT).build().unwrap();

        // Several images into an existing directory are accepted.
        let output_dir = PathBuf::from("./");
        ConfigBuilder::default()
            .prompt(BOTTLE_PROMPT)
            .batch_count(2)
            .output(output_dir)
            .build()
            .unwrap();
    }

    /// End-to-end generation; downloads real models, so it only runs with `--ignored`.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();

        let base = ConfigBuilder::default()
            .prompt(BOTTLE_PROMPT)
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&base, &mut model_config).unwrap();

        // Variants of `base` built by overriding a few fields.
        let straw = ConfigBuilder::from(base.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&straw, &mut model_config).unwrap();

        let cup_batch = ConfigBuilder::from(base)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
        gen_img(&cup_batch, &mut model_config).unwrap();
    }
}