diffusion_rs/
api.rs

1use std::cmp::max;
2use std::collections::HashMap;
3use std::ffi::CString;
4use std::ffi::c_char;
5use std::ffi::c_void;
6use std::fmt::Display;
7use std::path::Path;
8use std::path::PathBuf;
9use std::ptr::null;
10use std::ptr::null_mut;
11use std::slice;
12
13use chrono::Local;
14use derive_builder::Builder;
15use diffusion_rs_sys::free_upscaler_ctx;
16use diffusion_rs_sys::new_upscaler_ctx;
17use diffusion_rs_sys::sd_cache_mode_t;
18use diffusion_rs_sys::sd_cache_params_t;
19use diffusion_rs_sys::sd_ctx_params_t;
20use diffusion_rs_sys::sd_embedding_t;
21use diffusion_rs_sys::sd_get_default_sample_method;
22use diffusion_rs_sys::sd_get_default_scheduler;
23use diffusion_rs_sys::sd_guidance_params_t;
24use diffusion_rs_sys::sd_image_t;
25use diffusion_rs_sys::sd_img_gen_params_t;
26use diffusion_rs_sys::sd_lora_t;
27use diffusion_rs_sys::sd_pm_params_t;
28use diffusion_rs_sys::sd_sample_params_t;
29use diffusion_rs_sys::sd_set_preview_callback;
30use diffusion_rs_sys::sd_slg_params_t;
31use diffusion_rs_sys::sd_tiling_params_t;
32use diffusion_rs_sys::upscaler_ctx_t;
33use image::ImageBuffer;
34use image::ImageError;
35use image::RgbImage;
36use libc::free;
37use thiserror::Error;
38use walkdir::DirEntry;
39use walkdir::WalkDir;
40
41use diffusion_rs_sys::free_sd_ctx;
42use diffusion_rs_sys::new_sd_ctx;
43use diffusion_rs_sys::sd_ctx_t;
44
45/// Specify the range function
46pub use diffusion_rs_sys::rng_type_t as RngFunction;
47
48/// Sampling methods
49pub use diffusion_rs_sys::sample_method_t as SampleMethod;
50
51/// Denoiser sigma schedule
52pub use diffusion_rs_sys::scheduler_t as Scheduler;
53
54/// Prediction override
55pub use diffusion_rs_sys::prediction_t as Prediction;
56
57/// Weight type
58pub use diffusion_rs_sys::sd_type_t as WeightType;
59
60/// Preview mode
61pub use diffusion_rs_sys::preview_t as PreviewType;
62
63/// Lora mode
64pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
65
/// File extensions accepted when scanning embedding and LoRA directories.
static VALID_EXT: [&str; 3] = [
    "gguf",
    "safetensors",
    "pt",
];
67
68#[non_exhaustive]
69#[derive(Error, Debug)]
70/// Error that can occurs while forwarding models
71pub enum DiffusionError {
72    #[error("The underling stablediffusion.cpp function returned NULL")]
73    Forward,
74    #[error(transparent)]
75    StoreImages(#[from] ImageError),
76    #[error("The underling upscaler model returned a NULL image")]
77    Upscaler,
78}
79
/// How many of the last CLIP layers to ignore.
///
/// The discriminants are the raw values passed to the native library.
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
#[non_exhaustive]
#[repr(i32)]
pub enum ClipSkip {
    /// Let the model decide: behaves like [ClipSkip::None] for SD1.x and like
    /// [ClipSkip::OneLayer] for SD2.x.
    #[default]
    Unspecified = 0,
    /// Use every CLIP layer.
    None = 1,
    /// Skip the last CLIP layer.
    OneLayer = 2,
}
91
92type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
93
94#[derive(Default, Debug, Clone)]
95struct LoraStorage {
96    lora_model_dir: CLibPath,
97    data: Vec<(CLibPath, String, f32)>,
98    loras_t: Vec<sd_lora_t>,
99}
100
/// Describes one LoRA to apply: which file to load, how strongly, and to which model.
#[derive(Default, Debug, Clone)]
pub struct LoraSpec {
    /// File stem (name without extension) of the LoRA inside the LoRA directory.
    pub file_name: String,
    /// Apply the LoRA to the high-noise diffusion model instead of the standard one.
    pub is_high_noise: bool,
    /// Strength multiplier for the LoRA weights.
    pub multiplier: f32,
}
108
109/// Parameters for UCache
110#[derive(Builder, Debug, Clone)]
111pub struct UCacheParams {
112    /// Error threshold for reuse decision
113    #[builder(default = "1.0")]
114    threshold: f32,
115
116    /// Start caching at this percent of steps
117    #[builder(default = "0.15")]
118    start: f32,
119
120    /// Stop caching at this percent of steps
121    #[builder(default = "0.95")]
122    end: f32,
123
124    /// Error decay rate (0-1)
125    #[builder(default = "1.0")]
126    decay: f32,
127
128    /// Scale threshold by output norm
129    #[builder(default = "true")]
130    relative: bool,
131
132    /// Reset error after computing
133    /// true: Resets accumulated error after each computed step. More aggressive caching, works well with most samplers.
134    /// false: Keeps error accumulated. More conservative, recommended for euler_a sampler
135    #[builder(default = "true")]
136    reset: bool,
137}
138
139/// Parameters for Easy Cache
140#[derive(Builder, Debug, Clone)]
141pub struct EasyCacheParams {
142    /// Error threshold for reuse decision
143    #[builder(default = "0.2")]
144    threshold: f32,
145
146    /// Start caching at this percent of steps
147    #[builder(default = "0.15")]
148    start: f32,
149
150    /// Stop caching at this percent of steps
151    #[builder(default = "0.95")]
152    end: f32,
153}
154
155/// Parameters for Db Cache
156#[derive(Builder, Debug, Clone)]
157pub struct DbCacheParams {
158    /// Front blocks to always compute
159    #[builder(default = "8")]
160    fn_blocks: i32,
161
162    /// Back blocks to always compute
163    #[builder(default = "0")]
164    bn_blocks: i32,
165
166    /// L1 residual difference threshold
167    #[builder(default = "0.08")]
168    threshold: f32,
169
170    /// Steps before caching starts
171    #[builder(default = "8")]
172    warmup: i32,
173
174    /// Steps Computation Mask controls which steps can be cached
175    scm_mask: ScmPreset,
176
177    /// Scm Policy
178    #[builder(default = "ScmPolicy::default()")]
179    scm_policy_dynamic: ScmPolicy,
180}
181
/// Policy deciding whether a step marked as cacheable by the Steps Computation Mask is
/// actually served from the cache.
#[derive(Debug, Default, Clone)]
pub enum ScmPolicy {
    /// Every cacheable step is taken from the cache.
    Static,
    /// A cacheable step is taken from the cache only after the threshold check passes.
    #[default]
    Dynamic,
}
191
/// Steps Computation Mask preset: selects which denoising steps are always computed and which
/// may be served from the cache.
///
/// The built-in presets are tuned for 28 steps and rescaled for other step counts.
#[derive(Debug, Default, Clone)]
pub enum ScmPreset {
    /// Most conservative built-in preset (18 of 28 steps computed).
    Slow,
    /// Balanced built-in preset (15 of 28 steps computed).
    #[default]
    Medium,
    /// Aggressive built-in preset (11 of 28 steps computed).
    Fast,
    /// Most aggressive built-in preset (8 of 28 steps computed).
    Ultra,
    /// Explicit comma-separated mask, e.g. "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1",
    /// where 1 means compute and 0 means cache. Returned verbatim.
    Custom(String),
}

impl ScmPreset {
    /// Renders the preset as a comma-separated 0/1 mask for a run of `steps` steps.
    fn to_vec_string(&self, steps: i32) -> String {
        let (compute_bins, cache_bins) = match self {
            ScmPreset::Slow => (vec![8, 3, 3, 2, 1, 1], vec![1, 2, 2, 2, 3]),
            ScmPreset::Medium => (vec![6, 2, 2, 2, 2, 1], vec![1, 3, 3, 3, 3]),
            ScmPreset::Fast => (vec![6, 1, 1, 1, 1, 1], vec![1, 3, 4, 5, 4]),
            ScmPreset::Ultra => (vec![4, 1, 1, 1, 1], vec![2, 5, 6, 7]),
            ScmPreset::Custom(mask) => return mask.clone(),
        };
        ScmPresetBins {
            compute_bins,
            cache_bins,
            steps,
        }
        .to_string()
    }
}

/// Alternating run lengths describing a step mask: `compute_bins[i]` computed steps followed
/// by `cache_bins[i]` cached steps, repeated until `steps` entries have been produced.
#[derive(Debug, Clone)]
struct ScmPresetBins {
    compute_bins: Vec<i32>,
    cache_bins: Vec<i32>,
    steps: i32,
}

impl ScmPresetBins {
    /// Step count the built-in preset bins are tuned for.
    const REFERENCE_STEPS: i32 = 28;

    /// Returns the bins rescaled to `steps`, or an unchanged copy when `steps` already equals
    /// [`Self::REFERENCE_STEPS`] or is not positive.
    fn maybe_scale(&self) -> ScmPresetBins {
        if self.steps == Self::REFERENCE_STEPS || self.steps <= 0 {
            return self.clone();
        }
        self.scale()
    }

    /// Rescales every bin by `steps / REFERENCE_STEPS` so the bins still sum to roughly `steps`.
    fn scale(&self) -> ScmPresetBins {
        let factor = self.steps as f32 / Self::REFERENCE_STEPS as f32;
        ScmPresetBins {
            compute_bins: Self::rescale(&self.compute_bins, factor),
            cache_bins: Self::rescale(&self.cache_bins, factor),
            steps: self.steps,
        }
    }

    /// Multiplies each bin by `factor`, rounding to the nearest integer and keeping every run
    /// at least one step long.
    fn rescale(bins: &[i32], factor: f32) -> Vec<i32> {
        bins.iter()
            .map(|&bin| max(1, (bin as f32 * factor).round() as i32))
            .collect()
    }

    /// Expands the bins into a 0/1 mask of at most `steps` entries (1 = compute, 0 = cache).
    ///
    /// Compute and cache runs are emitted alternately, one bin of each per round, until the
    /// mask is full or both bin lists are exhausted. The last step is always computed.
    /// A non-positive `steps` yields an empty mask.
    fn generate_vec_mask(&self) -> Vec<i32> {
        let target = self.steps.max(0) as usize;
        let mut mask: Vec<i32> = Vec::with_capacity(target);
        let mut compute = self.compute_bins.iter();
        let mut cache = self.cache_bins.iter();

        // Appends up to `count` copies of `value` without exceeding `target`.
        let push_run = |mask: &mut Vec<i32>, value: i32, count: i32| {
            let room = target - mask.len();
            let run = (count.max(0) as usize).min(room);
            mask.resize(mask.len() + run, value);
        };

        while mask.len() < target {
            if let Some(&count) = compute.next() {
                push_run(&mut mask, 1, count);
            }
            // Each list keeps its own cursor: the cache run must come from `cache_bins`'
            // own position, not the compute cursor.
            if let Some(&count) = cache.next() {
                push_run(&mut mask, 0, count);
            }
            if compute.as_slice().is_empty() && cache.as_slice().is_empty() {
                break;
            }
        }
        if let Some(last) = mask.last_mut() {
            *last = 1;
        }
        mask
    }
}

impl Display for ScmPresetBins {
    /// Formats the (possibly rescaled) mask as comma-separated 0/1 values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mask: String = self
            .maybe_scale()
            .generate_vec_mask()
            .iter()
            .map(|x| x.to_string())
            .collect::<Vec<_>>()
            .join(",");
        write!(f, "{mask}")
    }
}
318
319/// Config struct for a specific diffusion model
320#[derive(Builder, Debug, Clone)]
321#[builder(
322    setter(into, strip_option),
323    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
324)]
325pub struct ModelConfig {
326    /// Number of threads to use during computation (default: 0).
327    /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores.
328    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
329    n_threads: i32,
330
331    /// Whether to memory-map model (default: false)
332    #[builder(default = "false")]
333    enable_mmap: bool,
334
335    /// Place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
336    #[builder(default = "false")]
337    offload_params_to_cpu: bool,
338
339    /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
340    #[builder(default = "Default::default()")]
341    upscale_model: Option<CLibPath>,
342
343    /// Run the ESRGAN upscaler this many times (default 1)
344    #[builder(default = "1")]
345    upscale_repeats: i32,
346
347    /// Tile size for ESRGAN upscaler (default 128)
348    #[builder(default = "128")]
349    upscale_tile_size: i32,
350
351    /// Path to full model
352    #[builder(default = "Default::default()")]
353    model: CLibPath,
354
355    /// Path to the standalone diffusion model
356    #[builder(default = "Default::default()")]
357    diffusion_model: CLibPath,
358
359    /// Path to the qwen2vl text encoder
360    #[builder(default = "Default::default()")]
361    llm: CLibPath,
362
363    /// Path to the qwen2vl vit
364    #[builder(default = "Default::default()")]
365    llm_vision: CLibPath,
366
367    /// Path to the clip-l text encoder
368    #[builder(default = "Default::default()")]
369    clip_l: CLibPath,
370
371    /// Path to the clip-g text encoder
372    #[builder(default = "Default::default()")]
373    clip_g: CLibPath,
374
375    /// Path to the clip-vision encoder
376    #[builder(default = "Default::default()")]
377    clip_vision: CLibPath,
378
379    /// Path to the t5xxl text encoder
380    #[builder(default = "Default::default()")]
381    t5xxl: CLibPath,
382
383    /// Path to vae
384    #[builder(default = "Default::default()")]
385    vae: CLibPath,
386
387    /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
388    #[builder(default = "Default::default()")]
389    taesd: CLibPath,
390
391    /// Path to control net model
392    #[builder(default = "Default::default()")]
393    control_net: CLibPath,
394
395    /// Path to embeddings
396    #[builder(default = "Default::default()", setter(custom))]
397    embeddings: EmbeddingsStorage,
398
399    /// Path to PHOTOMAKER model
400    #[builder(default = "Default::default()")]
401    photo_maker: CLibPath,
402
403    /// Path to PHOTOMAKER v2 id embed
404    #[builder(default = "Default::default()")]
405    pm_id_embed_path: CLibPath,
406
407    /// Weight type. If not specified, the default is the type of the weight file
408    #[builder(default = "WeightType::SD_TYPE_COUNT")]
409    weight_type: WeightType,
410
411    /// Lora model directory
412    #[builder(default = "Default::default()", setter(custom))]
413    lora_models: LoraStorage,
414
415    /// Path to the standalone high noise diffusion model
416    #[builder(default = "Default::default()")]
417    high_noise_diffusion_model: CLibPath,
418
419    /// Process vae in tiles to reduce memory usage (default: false)
420    #[builder(default = "false")]
421    vae_tiling: bool,
422
423    /// Tile size for vae tiling (default: 32x32)
424    #[builder(default = "(32,32)")]
425    vae_tile_size: (i32, i32),
426
427    /// Relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides vae_tile_size)
428    #[builder(default = "(0.,0.)")]
429    vae_relative_tile_size: (f32, f32),
430
431    /// Tile overlap for vae tiling, in fraction of tile size (default: 0.5)
432    #[builder(default = "0.5")]
433    vae_tile_overlap: f32,
434
435    /// RNG (default: CUDA)
436    #[builder(default = "RngFunction::CUDA_RNG")]
437    rng: RngFunction,
438
439    /// Sampler RNG. If [RngFunction::RNG_TYPE_COUNT] is used will default to rng value. (default: [RngFunction::RNG_TYPE_COUNT])",
440    #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
441    sampler_rng_type: RngFunction,
442
443    /// Denoiser sigma schedule (default: [Scheduler::SCHEDULER_COUNT]).
444    /// Will default to [Scheduler::EXPONENTIAL_SCHEDULER] if a denoiser is already instantiated.
445    /// Otherwise, [Scheduler::DISCRETE_SCHEDULER] is used.
446    #[builder(default = "Scheduler::SCHEDULER_COUNT")]
447    scheduler: Scheduler,
448
449    /// Custom sigma values for the sampler
450    #[builder(default = "Default::default()")]
451    sigmas: Vec<f32>,
452
453    /// Prediction type override (default: PREDICTION_COUNT)
454    #[builder(default = "Prediction::PREDICTION_COUNT")]
455    prediction: Prediction,
456
457    /// Keep vae in cpu (for low vram) (default: false)
458    #[builder(default = "false")]
459    vae_on_cpu: bool,
460
461    /// keep clip in cpu (for low vram) (default: false)
462    #[builder(default = "false")]
463    clip_on_cpu: bool,
464
465    /// Keep controlnet in cpu (for low vram) (default: false)
466    #[builder(default = "false")]
467    control_net_cpu: bool,
468
469    /// Use flash attention to reduce memory usage (for low vram).
470    // /For most backends, it slows things down, but for cuda it generally speeds it up too. At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal).
471    #[builder(default = "false")]
472    flash_attention: bool,
473
474    /// Disable dit mask for chroma
475    #[builder(default = "false")]
476    chroma_disable_dit_mask: bool,
477
478    /// Enable t5 mask for chroma
479    #[builder(default = "false")]
480    chroma_enable_t5_mask: bool,
481
482    /// t5 mask pad size of chroma
483    #[builder(default = "1")]
484    chroma_t5_mask_pad: i32,
485
486    /// Use qwen image zero cond true optimization
487    #[builder(default = "false")]
488    use_qwen_image_zero_cond_true: bool,
489
490    /// Use Conv2d direct in the diffusion model
491    /// This might crash if it is not supported by the backend.
492    #[builder(default = "false")]
493    diffusion_conv_direct: bool,
494
495    /// Use Conv2d direct in the vae model (should improve the performance)
496    /// This might crash if it is not supported by the backend.
497    #[builder(default = "false")]
498    vae_conv_direct: bool,
499
500    /// Force use of conv scale on sdxl vae
501    #[builder(default = "false")]
502    force_sdxl_vae_conv_scale: bool,
503
504    /// Shift value for Flow models like SD3.x or WAN (default: auto)
505    #[builder(default = "f32::INFINITY")]
506    flow_shift: f32,
507
508    /// Shift timestep for NitroFusion models, default: 0, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
509    #[builder(default = "0")]
510    timestep_shift: i32,
511
512    /// Prevents usage of taesd for decoding the final image
513    #[builder(default = "false")]
514    taesd_preview_only: bool,
515
516    /// In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used.The immediate mode may have precision and compatibility issues with quantized parameters, but it usually offers faster inference speed and, in some cases, lower memory usage. The at_runtime mode, on the other hand, is exactly the opposite
517    #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
518    lora_apply_mode: LoraModeType,
519
520    /// Enable circular padding for convolutions
521    #[builder(default = "false")]
522    circular: bool,
523
524    /// Enable circular RoPE wrapping on x-axis (width) only
525    #[builder(default = "false")]
526    circular_x: bool,
527
528    /// Enable circular RoPE wrapping on y-axis (height) only
529    #[builder(default = "false")]
530    circular_y: bool,
531
532    #[builder(default = "None", private)]
533    upscaler_ctx: Option<*mut upscaler_ctx_t>,
534
535    #[builder(default = "None", private)]
536    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
537}
538
539impl ModelConfigBuilder {
540    fn validate(&self) -> Result<(), ConfigBuilderError> {
541        self.validate_model()
542    }
543
544    fn validate_model(&self) -> Result<(), ConfigBuilderError> {
545        self.model
546            .as_ref()
547            .or(self.diffusion_model.as_ref())
548            .map(|_| ())
549            .ok_or(ConfigBuilderError::UninitializedField(
550                "Model OR DiffusionModel must be valorized",
551            ))
552    }
553
554    fn filter_valid_extensions(&self, path: &Path) -> impl Iterator<Item = DirEntry> {
555        WalkDir::new(path)
556            .into_iter()
557            .filter_map(|entry| entry.ok())
558            .filter(|entry| {
559                entry
560                    .path()
561                    .extension()
562                    .and_then(|ext| ext.to_str())
563                    .map(|ext_str| VALID_EXT.contains(&ext_str))
564                    .unwrap_or(false)
565            })
566    }
567
568    fn build_single_lora_storage(
569        spec: &LoraSpec,
570        is_high_noise: bool,
571        valid_loras: &HashMap<String, PathBuf>,
572    ) -> ((CLibPath, String, f32), sd_lora_t) {
573        let path = valid_loras.get(&spec.file_name).unwrap().as_path();
574        let c_path = CLibPath::from(path);
575        let lora = sd_lora_t {
576            is_high_noise,
577            multiplier: spec.multiplier,
578            path: c_path.as_ptr(),
579        };
580        let data = (c_path, spec.file_name.clone(), spec.multiplier);
581        (data, lora)
582    }
583
584    pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
585        let data: Vec<(CLibString, CLibPath)> = self
586            .filter_valid_extensions(embeddings_dir)
587            .map(|entry| {
588                let file_stem = entry
589                    .path()
590                    .file_stem()
591                    .and_then(|stem| stem.to_str())
592                    .unwrap_or_default()
593                    .to_owned();
594                (CLibString::from(file_stem), CLibPath::from(entry.path()))
595            })
596            .collect();
597        let data_pointer = data
598            .iter()
599            .map(|(name, path)| sd_embedding_t {
600                name: name.as_ptr(),
601                path: path.as_ptr(),
602            })
603            .collect();
604        self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
605        self
606    }
607
608    pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
609        let valid_loras: HashMap<String, PathBuf> = self
610            .filter_valid_extensions(lora_model_dir)
611            .map(|entry| {
612                let path = entry.path();
613                (
614                    path.file_stem()
615                        .and_then(|stem| stem.to_str())
616                        .unwrap_or_default()
617                        .to_owned(),
618                    path.to_path_buf(),
619                )
620            })
621            .collect();
622        let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
623        let standard = specs
624            .iter()
625            .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
626            .map(|s| Self::build_single_lora_storage(s, false, &valid_loras));
627        let high_noise = specs
628            .iter()
629            .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
630            .map(|s| Self::build_single_lora_storage(s, true, &valid_loras));
631
632        let mut data = Vec::new();
633        let mut loras_t = Vec::new();
634        for lora in standard.chain(high_noise) {
635            data.push(lora.0);
636            loras_t.push(lora.1);
637        }
638
639        self.lora_models = Some(LoraStorage {
640            lora_model_dir: lora_model_dir.into(),
641            data,
642            loras_t,
643        });
644        self
645    }
646
647    pub fn n_threads(&mut self, value: i32) -> &mut Self {
648        self.n_threads = if value > 0 {
649            Some(value)
650        } else {
651            Some(num_cpus::get_physical() as i32)
652        };
653        self
654    }
655}
656
657impl ModelConfig {
658    unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
659        unsafe {
660            if self.upscale_model.is_none() || self.upscale_repeats == 0 {
661                None
662            } else {
663                if self.upscaler_ctx.is_none() {
664                    let upscaler = new_upscaler_ctx(
665                        self.upscale_model.as_ref().unwrap().as_ptr(),
666                        self.offload_params_to_cpu,
667                        self.diffusion_conv_direct,
668                        self.n_threads,
669                        self.upscale_tile_size,
670                    );
671                    self.upscaler_ctx = Some(upscaler);
672                }
673                self.upscaler_ctx
674            }
675        }
676    }
677
678    unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
679        unsafe {
680            if self.diffusion_ctx.is_none() {
681                let sd_ctx_params = sd_ctx_params_t {
682                    model_path: self.model.as_ptr(),
683                    llm_path: self.llm.as_ptr(),
684                    llm_vision_path: self.llm_vision.as_ptr(),
685                    clip_l_path: self.clip_l.as_ptr(),
686                    clip_g_path: self.clip_g.as_ptr(),
687                    clip_vision_path: self.clip_vision.as_ptr(),
688                    high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
689                    t5xxl_path: self.t5xxl.as_ptr(),
690                    diffusion_model_path: self.diffusion_model.as_ptr(),
691                    vae_path: self.vae.as_ptr(),
692                    taesd_path: self.taesd.as_ptr(),
693                    control_net_path: self.control_net.as_ptr(),
694                    embeddings: self.embeddings.2.as_ptr(),
695                    embedding_count: self.embeddings.1.len() as u32,
696                    photo_maker_path: self.photo_maker.as_ptr(),
697                    vae_decode_only,
698                    free_params_immediately: false,
699                    n_threads: self.n_threads,
700                    wtype: self.weight_type,
701                    rng_type: self.rng,
702                    keep_clip_on_cpu: self.clip_on_cpu,
703                    keep_control_net_on_cpu: self.control_net_cpu,
704                    keep_vae_on_cpu: self.vae_on_cpu,
705                    diffusion_flash_attn: self.flash_attention,
706                    diffusion_conv_direct: self.diffusion_conv_direct,
707                    chroma_use_dit_mask: !self.chroma_disable_dit_mask,
708                    chroma_use_t5_mask: self.chroma_enable_t5_mask,
709                    chroma_t5_mask_pad: self.chroma_t5_mask_pad,
710                    vae_conv_direct: self.vae_conv_direct,
711                    offload_params_to_cpu: self.offload_params_to_cpu,
712                    flow_shift: self.flow_shift,
713                    prediction: self.prediction,
714                    force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
715                    tae_preview_only: self.taesd_preview_only,
716                    lora_apply_mode: self.lora_apply_mode,
717                    tensor_type_rules: null_mut(),
718                    sampler_rng_type: self.sampler_rng_type,
719                    circular_x: self.circular || self.circular_x,
720                    circular_y: self.circular || self.circular_y,
721                    qwen_image_zero_cond_t: self.use_qwen_image_zero_cond_true,
722                    enable_mmap: self.enable_mmap,
723                };
724                let ctx = new_sd_ctx(&sd_ctx_params);
725                self.diffusion_ctx = Some((ctx, sd_ctx_params))
726            }
727            self.diffusion_ctx.unwrap().0
728        }
729    }
730}
731
732impl Drop for ModelConfig {
733    fn drop(&mut self) {
734        //Cleanup CTX section
735        unsafe {
736            if let Some((sd_ctx, _)) = self.diffusion_ctx {
737                free_sd_ctx(sd_ctx);
738            }
739
740            if let Some(upscaler_ctx) = self.upscaler_ctx {
741                free_upscaler_ctx(upscaler_ctx);
742            }
743        }
744    }
745}
746
747impl From<ModelConfig> for ModelConfigBuilder {
748    fn from(value: ModelConfig) -> Self {
749        let mut builder = ModelConfigBuilder::default();
750        builder
751            .n_threads(value.n_threads)
752            .offload_params_to_cpu(value.offload_params_to_cpu)
753            .upscale_repeats(value.upscale_repeats)
754            .model(value.model.clone())
755            .diffusion_model(value.diffusion_model.clone())
756            .llm(value.llm.clone())
757            .llm_vision(value.llm_vision.clone())
758            .clip_l(value.clip_l.clone())
759            .clip_g(value.clip_g.clone())
760            .clip_vision(value.clip_vision.clone())
761            .t5xxl(value.t5xxl.clone())
762            .vae(value.vae.clone())
763            .taesd(value.taesd.clone())
764            .control_net(value.control_net.clone())
765            .embeddings(&value.embeddings.0)
766            .photo_maker(value.photo_maker.clone())
767            .pm_id_embed_path(value.pm_id_embed_path.clone())
768            .weight_type(value.weight_type)
769            .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
770            .vae_tiling(value.vae_tiling)
771            .vae_tile_size(value.vae_tile_size)
772            .vae_relative_tile_size(value.vae_relative_tile_size)
773            .vae_tile_overlap(value.vae_tile_overlap)
774            .rng(value.rng)
775            .sampler_rng_type(value.rng)
776            .scheduler(value.scheduler)
777            .sigmas(value.sigmas.clone())
778            .prediction(value.prediction)
779            .vae_on_cpu(value.vae_on_cpu)
780            .clip_on_cpu(value.clip_on_cpu)
781            .control_net(value.control_net.clone())
782            .control_net_cpu(value.control_net_cpu)
783            .flash_attention(value.flash_attention)
784            .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
785            .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
786            .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
787            .diffusion_conv_direct(value.diffusion_conv_direct)
788            .vae_conv_direct(value.vae_conv_direct)
789            .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
790            .flow_shift(value.flow_shift)
791            .timestep_shift(value.timestep_shift)
792            .taesd_preview_only(value.taesd_preview_only)
793            .lora_apply_mode(value.lora_apply_mode)
794            .circular(value.circular)
795            .circular_x(value.circular_x)
796            .circular_y(value.circular_y)
797            .use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true);
798
799        let lora_model_dir = Into::<PathBuf>::into(&value.lora_models.lora_model_dir);
800        let lora_specs = value
801            .lora_models
802            .data
803            .iter()
804            .map(|(_, name, multiplier)| LoraSpec {
805                file_name: name.clone(),
806                is_high_noise: false,
807                multiplier: *multiplier,
808            })
809            .collect();
810
811        builder.lora_models(&lora_model_dir, lora_specs);
812
813        if let Some(model) = &value.upscale_model {
814            builder.upscale_model(model.clone());
815        }
816        builder
817    }
818}
819
820#[derive(Builder, Debug, Clone)]
821#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
822/// Config struct common to all diffusion methods
823pub struct Config {
824    /// Path to PHOTOMAKER input id images dir
825    #[builder(default = "Default::default()")]
826    pm_id_images_dir: CLibPath,
827
828    /// Path to the input image, required by img2img
829    #[builder(default = "Default::default()")]
830    init_img: CLibPath,
831
832    /// Path to image condition, control net
833    #[builder(default = "Default::default()")]
834    control_image: CLibPath,
835
836    /// Path to write result image to (default: ./output.png)
837    #[builder(default = "PathBuf::from(\"./output.png\")")]
838    output: PathBuf,
839
840    /// Path to write result image to (default: ./output.png)
841    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
842    preview_output: PathBuf,
843
844    /// Preview method
845    #[builder(default = "PreviewType::PREVIEW_NONE")]
846    preview_mode: PreviewType,
847
848    /// Enables previewing noisy inputs of the models rather than the denoised outputs
849    #[builder(default = "false")]
850    preview_noisy: bool,
851
852    /// Interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)
853    #[builder(default = "1")]
854    preview_interval: i32,
855
856    /// The prompt to render
857    prompt: String,
858
859    /// The negative prompt (default: "")
860    #[builder(default = "\"\".into()")]
861    negative_prompt: CLibString,
862
863    /// Unconditional guidance scale (default: 7.0)
864    #[builder(default = "7.0")]
865    cfg_scale: f32,
866
867    /// Distilled guidance scale for models with guidance input (default: 3.5)
868    #[builder(default = "3.5")]
869    guidance: f32,
870
871    /// Strength for noising/unnoising (default: 0.75)
872    #[builder(default = "0.75")]
873    strength: f32,
874
875    /// Strength for keeping input identity (default: 20%)
876    #[builder(default = "20.0")]
877    pm_style_strength: f32,
878
879    /// Strength to apply Control Net (default: 0.9)
880    /// 1.0 corresponds to full destruction of information in init
881    #[builder(default = "0.9")]
882    control_strength: f32,
883
884    /// Image height, in pixel space (default: 512)
885    #[builder(default = "512")]
886    height: i32,
887
888    /// Image width, in pixel space (default: 512)
889    #[builder(default = "512")]
890    width: i32,
891
892    /// Sampling-method (default: [SampleMethod::SAMPLE_METHOD_COUNT]).
893    /// [SampleMethod::EULER_SAMPLE_METHOD] will be used for flux, sd3, wan, qwen_image.
894    /// Otherwise [SampleMethod::EULER_A_SAMPLE_METHOD] is used.
895    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
896    sampling_method: SampleMethod,
897
898    /// eta in DDIM, only for DDIM and TCD: (default: 0)
899    #[builder(default = "0.")]
900    eta: f32,
901
902    /// Number of sample steps (default: 20)
903    #[builder(default = "20")]
904    steps: i32,
905
906    /// RNG seed (default: 42, use random seed for < 0)
907    #[builder(default = "42")]
908    seed: i64,
909
910    /// Number of images to generate (default: 1)
911    #[builder(default = "1")]
912    batch_count: i32,
913
914    /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
915    /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
916    #[builder(default = "ClipSkip::Unspecified")]
917    clip_skip: ClipSkip,
918
919    /// Apply canny preprocessor (edge detection) (default: false)
920    #[builder(default = "false")]
921    canny: bool,
922
923    /// skip layer guidance (SLG) scale, only for DiT models: (default: 0)
924    /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium
925    #[builder(default = "0.")]
926    slg_scale: f32,
927
928    /// Layers to skip for SLG steps: (default: \[7,8,9\])
929    #[builder(default = "vec![7, 8, 9]")]
930    skip_layer: Vec<i32>,
931
932    /// SLG enabling point: (default: 0.01)
933    #[builder(default = "0.01")]
934    skip_layer_start: f32,
935
936    /// SLG disabling point: (default: 0.2)
937    #[builder(default = "0.2")]
938    skip_layer_end: f32,
939
940    /// Disable auto resize of ref images
941    #[builder(default = "false")]
942    disable_auto_resize_ref_image: bool,
943
944    #[builder(default = "Self::cache_init()", private)]
945    cache: sd_cache_params_t,
946
947    #[builder(default = "CLibString::default()", private)]
948    scm_mask: CLibString,
949}
950
951impl ConfigBuilder {
952    fn validate(&self) -> Result<(), ConfigBuilderError> {
953        self.validate_output_dir()
954    }
955
956    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
957        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
958        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
959        if is_dir == multiple_items {
960            Ok(())
961        } else {
962            Err(ConfigBuilderError::ValidationError(
963                "When batch_count > 1, output should point to folder and vice versa".to_owned(),
964            ))
965        }
966    }
967
968    fn cache_init() -> sd_cache_params_t {
969        sd_cache_params_t {
970            mode: sd_cache_mode_t::SD_CACHE_DISABLED,
971            reuse_threshold: 1.0,
972            start_percent: 0.15,
973            end_percent: 0.95,
974            error_decay_rate: 1.0,
975            use_relative_threshold: true,
976            reset_error_on_compute: true,
977            Fn_compute_blocks: 8,
978            Bn_compute_blocks: 0,
979            residual_diff_threshold: 0.08,
980            max_warmup_steps: 8,
981            max_cached_steps: -1,
982            max_continuous_cached_steps: -1,
983            taylorseer_n_derivatives: 1,
984            taylorseer_skip_interval: 1,
985            scm_mask: null(),
986            scm_policy_dynamic: true,
987        }
988    }
989
990    pub fn no_caching(&mut self) -> &mut Self {
991        let mut cache = Self::cache_init();
992        cache.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
993        self.cache = Some(cache);
994        self
995    }
996
997    pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
998        let mut cache = Self::cache_init();
999        cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
1000        cache.reuse_threshold = params.threshold;
1001        cache.start_percent = params.start;
1002        cache.end_percent = params.end;
1003        cache.error_decay_rate = params.decay;
1004        cache.use_relative_threshold = params.relative;
1005        cache.reset_error_on_compute = params.reset;
1006        self.cache = Some(cache);
1007        self
1008    }
1009
1010    pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
1011        let mut cache = Self::cache_init();
1012        cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
1013        cache.reuse_threshold = params.threshold;
1014        cache.start_percent = params.start;
1015        cache.end_percent = params.end;
1016        self.cache = Some(cache);
1017        self
1018    }
1019
1020    pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
1021        let mut cache = Self::cache_init();
1022        cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
1023        cache.Fn_compute_blocks = params.fn_blocks;
1024        cache.Bn_compute_blocks = params.bn_blocks;
1025        cache.residual_diff_threshold = params.threshold;
1026        cache.max_warmup_steps = params.warmup;
1027        cache.scm_policy_dynamic = match params.scm_policy_dynamic {
1028            ScmPolicy::Static => false,
1029            ScmPolicy::Dynamic => true,
1030        };
1031        self.scm_mask = Some(CLibString::from(
1032            params
1033                .scm_mask
1034                .to_vec_string(self.steps.unwrap_or_default()),
1035        ));
1036        cache.scm_mask = self.scm_mask.as_ref().unwrap().as_ptr();
1037
1038        self.cache = Some(cache);
1039        self
1040    }
1041
1042    pub fn taylor_seer_caching(&mut self) -> &mut Self {
1043        let mut cache = Self::cache_init();
1044        cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
1045        self.cache = Some(cache);
1046        self
1047    }
1048
1049    pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
1050        self.db_cache_caching(params).cache.unwrap().mode = sd_cache_mode_t::SD_CACHE_CACHE_DIT;
1051        self
1052    }
1053}
1054
1055impl From<Config> for ConfigBuilder {
1056    fn from(value: Config) -> Self {
1057        let mut builder = ConfigBuilder::default();
1058        let mut cache = value.cache;
1059        let scm_mask = value.scm_mask.clone();
1060        cache.scm_mask = scm_mask.as_ptr();
1061        builder
1062            .pm_id_images_dir(value.pm_id_images_dir)
1063            .init_img(value.init_img)
1064            .control_image(value.control_image)
1065            .output(value.output)
1066            .prompt(value.prompt)
1067            .negative_prompt(value.negative_prompt)
1068            .cfg_scale(value.cfg_scale)
1069            .strength(value.strength)
1070            .pm_style_strength(value.pm_style_strength)
1071            .control_strength(value.control_strength)
1072            .height(value.height)
1073            .width(value.width)
1074            .sampling_method(value.sampling_method)
1075            .steps(value.steps)
1076            .seed(value.seed)
1077            .batch_count(value.batch_count)
1078            .clip_skip(value.clip_skip)
1079            .slg_scale(value.slg_scale)
1080            .skip_layer(value.skip_layer)
1081            .skip_layer_start(value.skip_layer_start)
1082            .skip_layer_end(value.skip_layer_end)
1083            .canny(value.canny)
1084            .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
1085            .preview_output(value.preview_output)
1086            .preview_mode(value.preview_mode)
1087            .preview_noisy(value.preview_noisy)
1088            .preview_interval(value.preview_interval)
1089            .cache(cache)
1090            .scm_mask(scm_mask);
1091        builder
1092    }
1093}
1094
/// Owned, NUL-terminated string handed to the C API as `const char*`.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Pointer to the NUL-terminated contents; only valid while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Build a `CString` from arbitrary bytes.
    ///
    /// Interior NUL bytes cannot be represented in a C string. They are dropped
    /// instead of panicking, because the input is typically user-provided text
    /// such as a prompt.
    fn without_interior_nul(bytes: Vec<u8>) -> CString {
        match CString::new(bytes) {
            Ok(value) => value,
            Err(err) => {
                let mut bytes = err.into_vec();
                bytes.retain(|&byte| byte != 0);
                CString::new(bytes).expect("all NUL bytes were removed")
            }
        }
    }
}

impl From<&str> for CLibString {
    fn from(value: &str) -> Self {
        Self(Self::without_interior_nul(value.as_bytes().to_vec()))
    }
}

impl From<String> for CLibString {
    fn from(value: String) -> Self {
        Self(Self::without_interior_nul(value.into_bytes()))
    }
}
1115
/// Owned, NUL-terminated filesystem path handed to the C API as `const char*`.
///
/// Paths that are not valid UTF-8 are stored as the empty string.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Raw pointer to the stored path; borrowed from `self`.
    fn as_ptr(&self) -> *const c_char {
        let CLibPath(inner) = self;
        inner.as_ptr()
    }
}

impl From<&Path> for CLibPath {
    fn from(value: &Path) -> Self {
        let text = match value.to_str() {
            Some(text) => text,
            None => "",
        };
        CLibPath(CString::new(text).unwrap())
    }
}

impl From<PathBuf> for CLibPath {
    fn from(value: PathBuf) -> Self {
        CLibPath::from(value.as_path())
    }
}

impl From<&CLibPath> for PathBuf {
    fn from(value: &CLibPath) -> Self {
        // Always UTF-8: the CString was built from a `&str` above.
        let text = value.0.to_str().unwrap();
        PathBuf::from(text)
    }
}
1142
1143fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
1144    let date = Local::now().format("%Y.%m.%d-%H.%M.%S");
1145    if batch_size == 1 {
1146        vec![path.into()]
1147    } else {
1148        (1..=batch_size)
1149            .map(|id| path.join(format!("output_{date}_{id}.png")))
1150            .collect()
1151    }
1152}
1153
1154unsafe fn upscale(
1155    upscale_repeats: i32,
1156    upscaler_ctx: Option<*mut upscaler_ctx_t>,
1157    data: sd_image_t,
1158) -> Result<sd_image_t, DiffusionError> {
1159    unsafe {
1160        match upscaler_ctx {
1161            Some(upscaler_ctx) => {
1162                let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
1163                let mut current_image = data;
1164                for _ in 0..upscale_repeats {
1165                    let upscaled_image =
1166                        diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
1167
1168                    if upscaled_image.data.is_null() {
1169                        return Err(DiffusionError::Upscaler);
1170                    }
1171
1172                    free(current_image.data as *mut c_void);
1173                    current_image = upscaled_image;
1174                }
1175                Ok(current_image)
1176            }
1177            None => Ok(data),
1178        }
1179    }
1180}
1181
1182/// Generate an image with a prompt
1183pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
1184    let prompt: CLibString = CLibString::from(config.prompt.as_str());
1185    let files = output_files(&config.output, config.batch_count);
1186    unsafe {
1187        let sd_ctx = model_config.diffusion_ctx(true);
1188        let upscaler_ctx = model_config.upscaler_ctx();
1189        let init_image = sd_image_t {
1190            width: 0,
1191            height: 0,
1192            channel: 3,
1193            data: null_mut(),
1194        };
1195        let mask_image = sd_image_t {
1196            width: config.width as u32,
1197            height: config.height as u32,
1198            channel: 1,
1199            data: null_mut(),
1200        };
1201        let mut layers = config.skip_layer.clone();
1202        let guidance = sd_guidance_params_t {
1203            txt_cfg: config.cfg_scale,
1204            img_cfg: config.cfg_scale,
1205            distilled_guidance: config.guidance,
1206            slg: sd_slg_params_t {
1207                layers: layers.as_mut_ptr(),
1208                layer_count: config.skip_layer.len(),
1209                layer_start: config.skip_layer_start,
1210                layer_end: config.skip_layer_end,
1211                scale: config.slg_scale,
1212            },
1213        };
1214        let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
1215            sd_get_default_scheduler(sd_ctx, config.sampling_method)
1216        } else {
1217            model_config.scheduler
1218        };
1219        let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
1220            sd_get_default_sample_method(sd_ctx)
1221        } else {
1222            config.sampling_method
1223        };
1224        let sample_params = sd_sample_params_t {
1225            guidance,
1226            sample_method,
1227            sample_steps: config.steps,
1228            eta: config.eta,
1229            scheduler,
1230            shifted_timestep: model_config.timestep_shift,
1231            custom_sigmas: model_config.sigmas.as_mut_ptr(),
1232            custom_sigmas_count: model_config.sigmas.len() as i32,
1233        };
1234        let control_image = sd_image_t {
1235            width: 0,
1236            height: 0,
1237            channel: 3,
1238            data: null_mut(),
1239        };
1240        let vae_tiling_params = sd_tiling_params_t {
1241            enabled: model_config.vae_tiling,
1242            tile_size_x: model_config.vae_tile_size.0,
1243            tile_size_y: model_config.vae_tile_size.1,
1244            target_overlap: model_config.vae_tile_overlap,
1245            rel_size_x: model_config.vae_relative_tile_size.0,
1246            rel_size_y: model_config.vae_relative_tile_size.1,
1247        };
1248        let pm_params = sd_pm_params_t {
1249            id_images: null_mut(),
1250            id_images_count: 0,
1251            id_embed_path: model_config.pm_id_embed_path.as_ptr(),
1252            style_strength: config.pm_style_strength,
1253        };
1254
1255        unsafe extern "C" fn save_preview_local(
1256            _step: ::std::os::raw::c_int,
1257            _frame_count: ::std::os::raw::c_int,
1258            frames: *mut sd_image_t,
1259            _is_noisy: bool,
1260            data: *mut ::std::os::raw::c_void,
1261        ) {
1262            unsafe {
1263                let path = &*data.cast::<PathBuf>();
1264                let _ = save_img(*frames, path);
1265            }
1266        }
1267
1268        if config.preview_mode != PreviewType::PREVIEW_NONE {
1269            let data = &config.preview_output as *const PathBuf;
1270
1271            sd_set_preview_callback(
1272                Some(save_preview_local),
1273                config.preview_mode,
1274                config.preview_interval,
1275                !config.preview_noisy,
1276                config.preview_noisy,
1277                data as *mut c_void,
1278            );
1279        }
1280
1281        let sd_img_gen_params = sd_img_gen_params_t {
1282            prompt: prompt.as_ptr(),
1283            negative_prompt: config.negative_prompt.as_ptr(),
1284            clip_skip: config.clip_skip as i32,
1285            init_image,
1286            ref_images: null_mut(),
1287            ref_images_count: 0,
1288            increase_ref_index: false,
1289            mask_image,
1290            width: config.width,
1291            height: config.height,
1292            sample_params,
1293            strength: config.strength,
1294            seed: config.seed,
1295            batch_count: config.batch_count,
1296            control_image,
1297            control_strength: config.control_strength,
1298            pm_params,
1299            vae_tiling_params,
1300            auto_resize_ref_image: config.disable_auto_resize_ref_image,
1301            cache: config.cache,
1302            loras: model_config.lora_models.loras_t.as_ptr(),
1303            lora_count: model_config.lora_models.loras_t.len() as u32,
1304        };
1305        let slice = diffusion_rs_sys::generate_image(sd_ctx, &sd_img_gen_params);
1306        let ret = {
1307            if slice.is_null() {
1308                return Err(DiffusionError::Forward);
1309            }
1310            for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
1311                .iter()
1312                .zip(files)
1313            {
1314                match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
1315                    Ok(img) => save_img(img, &path)?,
1316                    Err(err) => {
1317                        return Err(err);
1318                    }
1319                }
1320            }
1321            Ok(())
1322        };
1323        free(slice as *mut c_void);
1324        ret
1325    }
1326}
1327
1328fn save_img(img: sd_image_t, path: &Path) -> Result<(), DiffusionError> {
1329    // Thx @wandbrandon
1330    let len = (img.width * img.height * img.channel) as usize;
1331    let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1332    let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
1333        .map(|img| RgbImage::from(img).save(path));
1334    if let Some(Err(err)) = save_state {
1335        return Err(DiffusionError::StoreImages(err));
1336    }
1337    Ok(())
1338}
1339
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    /// Checks which builder arguments are mandatory and that the
    /// batch-count / output-directory rule is enforced.
    #[test]
    fn test_required_args_txt2img() {
        // Neither builder can be built without its required field.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        // A prompt alone is enough for a single-image config.
        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // batch_count > 1 requires `output` to be a directory...
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        // ...which the current directory satisfies.
        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end generation: downloads models from the HF hub and writes
    /// images into the working directory, hence `#[ignore]`.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&config, &mut model_config).unwrap();

        // Derive a variant of the first config through `ConfigBuilder::from`.
        let config2 = ConfigBuilder::from(config.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&config2, &mut model_config).unwrap();

        // Batch generation into a directory.
        let config3 = ConfigBuilder::from(config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();

        gen_img(&config3, &mut model_config).unwrap();
    }
}