diffusion_rs/
api.rs

1use std::cmp::max;
2use std::collections::HashMap;
3use std::ffi::CString;
4use std::ffi::c_char;
5use std::ffi::c_void;
6use std::fmt::Display;
7use std::path::Path;
8use std::path::PathBuf;
9use std::ptr::null;
10use std::ptr::null_mut;
11use std::slice;
12use std::sync::mpsc::Sender;
13
14use chrono::Local;
15use derive_builder::Builder;
16use diffusion_rs_sys::free_upscaler_ctx;
17use diffusion_rs_sys::generate_image;
18use diffusion_rs_sys::new_upscaler_ctx;
19use diffusion_rs_sys::sd_cache_mode_t;
20use diffusion_rs_sys::sd_cache_params_t;
21use diffusion_rs_sys::sd_ctx_params_t;
22use diffusion_rs_sys::sd_embedding_t;
23use diffusion_rs_sys::sd_get_default_sample_method;
24use diffusion_rs_sys::sd_get_default_scheduler;
25use diffusion_rs_sys::sd_guidance_params_t;
26use diffusion_rs_sys::sd_image_t;
27use diffusion_rs_sys::sd_img_gen_params_t;
28use diffusion_rs_sys::sd_img_gen_params_to_str;
29use diffusion_rs_sys::sd_lora_t;
30use diffusion_rs_sys::sd_pm_params_t;
31use diffusion_rs_sys::sd_sample_params_t;
32use diffusion_rs_sys::sd_set_preview_callback;
33use diffusion_rs_sys::sd_set_progress_callback;
34use diffusion_rs_sys::sd_slg_params_t;
35use diffusion_rs_sys::sd_tiling_params_t;
36use diffusion_rs_sys::upscaler_ctx_t;
37use image::ImageBuffer;
38use image::ImageError;
39use image::RgbImage;
40use libc::free;
41use little_exif::exif_tag::ExifTag;
42use little_exif::metadata::Metadata;
43use thiserror::Error;
44use walkdir::DirEntry;
45use walkdir::WalkDir;
46
47use diffusion_rs_sys::free_sd_ctx;
48use diffusion_rs_sys::new_sd_ctx;
49use diffusion_rs_sys::sd_ctx_t;
50
51/// Specify the range function
52pub use diffusion_rs_sys::rng_type_t as RngFunction;
53
54/// Sampling methods
55pub use diffusion_rs_sys::sample_method_t as SampleMethod;
56
57/// Denoiser sigma schedule
58pub use diffusion_rs_sys::scheduler_t as Scheduler;
59
60/// Prediction override
61pub use diffusion_rs_sys::prediction_t as Prediction;
62
63/// Weight type
64pub use diffusion_rs_sys::sd_type_t as WeightType;
65
66/// Preview mode
67pub use diffusion_rs_sys::preview_t as PreviewType;
68
69/// Lora mode
70pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
71
/// File extensions accepted as model files when scanning the embedding and lora directories.
static VALID_EXT: [&str; 3] = ["gguf", "safetensors", "pt"];
73
/// Progress message returned from [gen_img_with_progress].
///
/// The fields are public so that receivers of the message can actually read them; previously
/// they were private and only observable through `Debug`.
#[derive(Debug)]
pub struct Progress {
    /// Current sampling step.
    pub step: i32,
    /// Total number of sampling steps.
    pub steps: i32,
    /// Time value reported by the native progress callback
    /// (NOTE(review): units not established here — presumably seconds, confirm).
    pub time: f32,
}
82
83#[non_exhaustive]
84#[derive(Error, Debug)]
85/// Error that can occurs while forwarding models
86pub enum DiffusionError {
87    #[error("The underling stablediffusion.cpp function returned NULL")]
88    Forward,
89    #[error(transparent)]
90    StoreImages(#[from] ImageError),
91    #[error(transparent)]
92    Io(#[from] std::io::Error),
93    #[error("The underling upscaler model returned a NULL image")]
94    Upscaler,
95}
96
/// How many of the last CLIP layers to ignore.
///
/// The discriminants are the raw values passed on: `1` ignores no layer, `2` ignores one layer,
/// and `0` leaves the choice to the model.
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
#[non_exhaustive]
#[repr(i32)]
pub enum ClipSkip {
    /// Let the model decide: will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x
    #[default]
    Unspecified = 0,
    /// Use every CLIP layer.
    None = 1,
    /// Ignore the last CLIP layer.
    OneLayer = 2,
}
108
109type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
110
111#[derive(Default, Debug, Clone)]
112struct LoraStorage {
113    lora_model_dir: CLibPath,
114    data: Vec<(CLibPath, String, f32)>,
115    loras_t: Vec<sd_lora_t>,
116}
117
/// Describes one lora to apply: which file, how strongly, and to which diffusion model.
#[derive(Default, Debug, Clone)]
pub struct LoraSpec {
    /// File stem (file name without extension) of the lora inside the lora directory.
    pub file_name: String,
    /// Mark the lora as targeting the high-noise diffusion model.
    pub is_high_noise: bool,
    /// Strength multiplier applied to the lora.
    pub multiplier: f32,
}
125
126/// Parameters for UCache
127#[derive(Builder, Debug, Clone)]
128pub struct UCacheParams {
129    /// Error threshold for reuse decision
130    #[builder(default = "1.0")]
131    threshold: f32,
132
133    /// Start caching at this percent of steps
134    #[builder(default = "0.15")]
135    start: f32,
136
137    /// Stop caching at this percent of steps
138    #[builder(default = "0.95")]
139    end: f32,
140
141    /// Error decay rate (0-1)
142    #[builder(default = "1.0")]
143    decay: f32,
144
145    /// Scale threshold by output norm
146    #[builder(default = "true")]
147    relative: bool,
148
149    /// Reset error after computing
150    /// true: Resets accumulated error after each computed step. More aggressive caching, works well with most samplers.
151    /// false: Keeps error accumulated. More conservative, recommended for euler_a sampler
152    #[builder(default = "true")]
153    reset: bool,
154}
155
156/// Parameters for Easy Cache
157#[derive(Builder, Debug, Clone)]
158pub struct EasyCacheParams {
159    /// Error threshold for reuse decision
160    #[builder(default = "0.2")]
161    threshold: f32,
162
163    /// Start caching at this percent of steps
164    #[builder(default = "0.15")]
165    start: f32,
166
167    /// Stop caching at this percent of steps
168    #[builder(default = "0.95")]
169    end: f32,
170}
171
172/// Parameters for Db Cache
173#[derive(Builder, Debug, Clone)]
174pub struct DbCacheParams {
175    /// Front blocks to always compute
176    #[builder(default = "8")]
177    fn_blocks: i32,
178
179    /// Back blocks to always compute
180    #[builder(default = "0")]
181    bn_blocks: i32,
182
183    /// L1 residual difference threshold
184    #[builder(default = "0.08")]
185    threshold: f32,
186
187    /// Steps before caching starts
188    #[builder(default = "8")]
189    warmup: i32,
190
191    /// Steps Computation Mask controls which steps can be cached
192    scm_mask: ScmPreset,
193
194    /// Scm Policy
195    #[builder(default = "ScmPolicy::default()")]
196    scm_policy_dynamic: ScmPolicy,
197}
198
/// Steps Computation Mask policy: decides when a step the mask marks as cacheable is cached.
#[derive(Debug, Default, Clone)]
pub enum ScmPolicy {
    /// Cache every cacheable step unconditionally.
    Static,
    /// Check the threshold before caching a cacheable step (default).
    #[default]
    Dynamic,
}
208
/// Steps Computation Mask preset: selects which sampling steps may be cached.
#[derive(Debug, Default, Clone)]
pub enum ScmPreset {
    /// Most conservative built-in preset (caches the fewest steps).
    Slow,
    /// Balanced built-in preset (default).
    #[default]
    Medium,
    /// More aggressive built-in preset.
    Fast,
    /// Most aggressive built-in preset (caches the most steps).
    Ultra,
    /// A user-supplied mask, passed through unchanged.
    /// E.g.: "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1"
    /// where 1 means compute, 0 means cache
    Custom(String),
}
221
222impl ScmPreset {
223    fn to_vec_string(&self, steps: i32) -> String {
224        match self {
225            ScmPreset::Slow => ScmPresetBins {
226                compute_bins: vec![8, 3, 3, 2, 1, 1],
227                cache_bins: vec![1, 2, 2, 2, 3],
228                steps,
229            }
230            .to_string(),
231            ScmPreset::Medium => ScmPresetBins {
232                compute_bins: vec![6, 2, 2, 2, 2, 1],
233                cache_bins: vec![1, 3, 3, 3, 3],
234                steps,
235            }
236            .to_string(),
237            ScmPreset::Fast => ScmPresetBins {
238                compute_bins: vec![6, 1, 1, 1, 1, 1],
239                cache_bins: vec![1, 3, 4, 5, 4],
240                steps,
241            }
242            .to_string(),
243            ScmPreset::Ultra => ScmPresetBins {
244                compute_bins: vec![4, 1, 1, 1, 1],
245                cache_bins: vec![2, 5, 6, 7],
246                steps,
247            }
248            .to_string(),
249            ScmPreset::Custom(s) => s.clone(),
250        }
251    }
252}
253
/// Compute/cache run lengths used to build a Steps Computation Mask.
///
/// The mask alternates `compute_bins[0]` computed steps (`1`), `cache_bins[0]` cacheable steps
/// (`0`), `compute_bins[1]` computed steps, and so on, truncated to `steps` entries.
#[derive(Debug, Clone)]
struct ScmPresetBins {
    /// Lengths of the runs of computed steps.
    compute_bins: Vec<i32>,
    /// Lengths of the runs of cacheable steps.
    cache_bins: Vec<i32>,
    /// Total number of sampling steps the mask must cover.
    steps: i32,
}

impl ScmPresetBins {
    /// Step count the built-in preset bins are calibrated for.
    const REFERENCE_STEPS: i32 = 28;

    /// Returns the bins rescaled to `steps`. Returns an unchanged copy when `steps` equals the
    /// reference step count or is not positive.
    fn maybe_scale(&self) -> ScmPresetBins {
        if self.steps == Self::REFERENCE_STEPS || self.steps <= 0 {
            return self.clone();
        }
        self.scale()
    }

    /// Scales every bin by `steps / 28`, rounding to the nearest integer and never going below 1.
    fn scale(&self) -> ScmPresetBins {
        let factor = self.steps as f32 / Self::REFERENCE_STEPS as f32;
        // Round half up (`+ 0.5`, then truncate). The former `* 0.5` halved every bin, so the
        // resulting mask no longer covered all the steps.
        let scale_bins = |bins: &[i32]| -> Vec<i32> {
            bins.iter()
                .map(|&bin| max(1, (bin as f32 * factor + 0.5) as i32))
                .collect()
        };
        ScmPresetBins {
            compute_bins: scale_bins(&self.compute_bins),
            cache_bins: scale_bins(&self.cache_bins),
            steps: self.steps,
        }
    }

    /// Expands the bins into a mask of at most `steps` entries (`1` = compute, `0` = cache).
    ///
    /// The runs are consumed pairwise (one compute run, then one cache run) until either the mask
    /// is full or both bin lists are exhausted. The last entry is always forced to `1`.
    fn generate_vec_mask(&self) -> Vec<i32> {
        // A non-positive step count yields an empty mask; a plain `as usize` cast would wrap a
        // negative count into a huge length.
        let total = usize::try_from(self.steps).unwrap_or(0);
        let mut mask: Vec<i32> = Vec::new();
        let mut compute_runs = self.compute_bins.iter();
        let mut cache_runs = self.cache_bins.iter();

        // Appends up to `count` copies of `value`, never growing the mask past `total`.
        let push_run = |mask: &mut Vec<i32>, value: i32, count: i32| {
            let room = total - mask.len();
            let run = usize::try_from(count).unwrap_or(0).min(room);
            mask.resize(mask.len() + run, value);
        };

        while mask.len() < total {
            // Each list advances with its own cursor; the old code indexed `cache_bins` with the
            // compute cursor, which skipped entries and panicked out of bounds.
            let compute = compute_runs.next();
            let cache = cache_runs.next();
            if compute.is_none() && cache.is_none() {
                break;
            }
            if let Some(&count) = compute {
                push_run(&mut mask, 1, count);
            }
            if let Some(&count) = cache {
                push_run(&mut mask, 0, count);
            }
        }

        // The final step is always computed.
        if let Some(last) = mask.last_mut() {
            *last = 1;
        }
        mask
    }
}

impl Display for ScmPresetBins {
    /// Formats the (possibly rescaled) mask as comma-separated values, e.g. `1,1,0,1`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mask: String = self
            .maybe_scale()
            .generate_vec_mask()
            .iter()
            .map(|x| x.to_string())
            .collect::<Vec<_>>()
            .join(",");
        write!(f, "{mask}")
    }
}
335
336/// Config struct for a specific diffusion model
337#[derive(Builder, Debug, Clone)]
338#[builder(
339    setter(into, strip_option),
340    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
341)]
342pub struct ModelConfig {
343    /// Number of threads to use during computation (default: 0).
344    /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores.
345    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
346    n_threads: i32,
347
348    /// Whether to memory-map model (default: false)
349    #[builder(default = "false")]
350    enable_mmap: bool,
351
352    /// Place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
353    #[builder(default = "false")]
354    offload_params_to_cpu: bool,
355
356    /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
357    #[builder(default = "Default::default()")]
358    upscale_model: Option<CLibPath>,
359
360    /// Run the ESRGAN upscaler this many times (default 1)
361    #[builder(default = "1")]
362    upscale_repeats: i32,
363
364    /// Tile size for ESRGAN upscaler (default 128)
365    #[builder(default = "128")]
366    upscale_tile_size: i32,
367
368    /// Path to full model
369    #[builder(default = "Default::default()")]
370    model: CLibPath,
371
372    /// Path to the standalone diffusion model
373    #[builder(default = "Default::default()")]
374    diffusion_model: CLibPath,
375
376    /// Path to the qwen2vl text encoder
377    #[builder(default = "Default::default()")]
378    llm: CLibPath,
379
380    /// Path to the qwen2vl vit
381    #[builder(default = "Default::default()")]
382    llm_vision: CLibPath,
383
384    /// Path to the clip-l text encoder
385    #[builder(default = "Default::default()")]
386    clip_l: CLibPath,
387
388    /// Path to the clip-g text encoder
389    #[builder(default = "Default::default()")]
390    clip_g: CLibPath,
391
392    /// Path to the clip-vision encoder
393    #[builder(default = "Default::default()")]
394    clip_vision: CLibPath,
395
396    /// Path to the t5xxl text encoder
397    #[builder(default = "Default::default()")]
398    t5xxl: CLibPath,
399
400    /// Path to vae
401    #[builder(default = "Default::default()")]
402    vae: CLibPath,
403
404    /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
405    #[builder(default = "Default::default()")]
406    taesd: CLibPath,
407
408    /// Path to control net model
409    #[builder(default = "Default::default()")]
410    control_net: CLibPath,
411
412    /// Path to embeddings
413    #[builder(default = "Default::default()", setter(custom))]
414    embeddings: EmbeddingsStorage,
415
416    /// Path to PHOTOMAKER model
417    #[builder(default = "Default::default()")]
418    photo_maker: CLibPath,
419
420    /// Path to PHOTOMAKER v2 id embed
421    #[builder(default = "Default::default()")]
422    pm_id_embed_path: CLibPath,
423
424    /// Weight type. If not specified, the default is the type of the weight file
425    #[builder(default = "WeightType::SD_TYPE_COUNT")]
426    weight_type: WeightType,
427
428    /// Lora model directory
429    #[builder(default = "Default::default()", setter(custom))]
430    lora_models: LoraStorage,
431
432    /// Path to the standalone high noise diffusion model
433    #[builder(default = "Default::default()")]
434    high_noise_diffusion_model: CLibPath,
435
436    /// Process vae in tiles to reduce memory usage (default: false)
437    #[builder(default = "false")]
438    vae_tiling: bool,
439
440    /// Tile size for vae tiling (default: 32x32)
441    #[builder(default = "(32,32)")]
442    vae_tile_size: (i32, i32),
443
444    /// Relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides vae_tile_size)
445    #[builder(default = "(0.,0.)")]
446    vae_relative_tile_size: (f32, f32),
447
448    /// Tile overlap for vae tiling, in fraction of tile size (default: 0.5)
449    #[builder(default = "0.5")]
450    vae_tile_overlap: f32,
451
452    /// RNG (default: CUDA)
453    #[builder(default = "RngFunction::CUDA_RNG")]
454    rng: RngFunction,
455
456    /// Sampler RNG. If [RngFunction::RNG_TYPE_COUNT] is used will default to rng value. (default: [RngFunction::RNG_TYPE_COUNT])",
457    #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
458    sampler_rng_type: RngFunction,
459
460    /// Denoiser sigma schedule (default: [Scheduler::SCHEDULER_COUNT]).
461    /// Will default to [Scheduler::EXPONENTIAL_SCHEDULER] if a denoiser is already instantiated.
462    /// Otherwise, [Scheduler::DISCRETE_SCHEDULER] is used.
463    #[builder(default = "Scheduler::SCHEDULER_COUNT")]
464    scheduler: Scheduler,
465
466    /// Custom sigma values for the sampler
467    #[builder(default = "Default::default()")]
468    sigmas: Vec<f32>,
469
470    /// Prediction type override (default: PREDICTION_COUNT)
471    #[builder(default = "Prediction::PREDICTION_COUNT")]
472    prediction: Prediction,
473
474    /// Keep vae in cpu (for low vram) (default: false)
475    #[builder(default = "false")]
476    vae_on_cpu: bool,
477
478    /// keep clip in cpu (for low vram) (default: false)
479    #[builder(default = "false")]
480    clip_on_cpu: bool,
481
482    /// Keep controlnet in cpu (for low vram) (default: false)
483    #[builder(default = "false")]
484    control_net_cpu: bool,
485
486    /// Use flash attention to reduce memory usage (for low vram).
487    // /For most backends, it slows things down, but for cuda it generally speeds it up too. At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal).
488    #[builder(default = "false")]
489    flash_attention: bool,
490
491    /// Disable dit mask for chroma
492    #[builder(default = "false")]
493    chroma_disable_dit_mask: bool,
494
495    /// Enable t5 mask for chroma
496    #[builder(default = "false")]
497    chroma_enable_t5_mask: bool,
498
499    /// t5 mask pad size of chroma
500    #[builder(default = "1")]
501    chroma_t5_mask_pad: i32,
502
503    /// Use qwen image zero cond true optimization
504    #[builder(default = "false")]
505    use_qwen_image_zero_cond_true: bool,
506
507    /// Use Conv2d direct in the diffusion model
508    /// This might crash if it is not supported by the backend.
509    #[builder(default = "false")]
510    diffusion_conv_direct: bool,
511
512    /// Use Conv2d direct in the vae model (should improve the performance)
513    /// This might crash if it is not supported by the backend.
514    #[builder(default = "false")]
515    vae_conv_direct: bool,
516
517    /// Force use of conv scale on sdxl vae
518    #[builder(default = "false")]
519    force_sdxl_vae_conv_scale: bool,
520
521    /// Shift value for Flow models like SD3.x or WAN (default: auto)
522    #[builder(default = "f32::INFINITY")]
523    flow_shift: f32,
524
525    /// Shift timestep for NitroFusion models, default: 0, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
526    #[builder(default = "0")]
527    timestep_shift: i32,
528
529    /// Prevents usage of taesd for decoding the final image
530    #[builder(default = "false")]
531    taesd_preview_only: bool,
532
533    /// In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used.The immediate mode may have precision and compatibility issues with quantized parameters, but it usually offers faster inference speed and, in some cases, lower memory usage. The at_runtime mode, on the other hand, is exactly the opposite
534    #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
535    lora_apply_mode: LoraModeType,
536
537    /// Enable circular padding for convolutions
538    #[builder(default = "false")]
539    circular: bool,
540
541    /// Enable circular RoPE wrapping on x-axis (width) only
542    #[builder(default = "false")]
543    circular_x: bool,
544
545    /// Enable circular RoPE wrapping on y-axis (height) only
546    #[builder(default = "false")]
547    circular_y: bool,
548
549    #[builder(default = "None", private)]
550    upscaler_ctx: Option<*mut upscaler_ctx_t>,
551
552    #[builder(default = "None", private)]
553    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
554}
555
556impl ModelConfigBuilder {
557    fn validate(&self) -> Result<(), ConfigBuilderError> {
558        self.validate_model()
559    }
560
561    fn validate_model(&self) -> Result<(), ConfigBuilderError> {
562        self.model
563            .as_ref()
564            .or(self.diffusion_model.as_ref())
565            .map(|_| ())
566            .ok_or(ConfigBuilderError::UninitializedField(
567                "Model OR DiffusionModel must be valorized",
568            ))
569    }
570
571    fn filter_valid_extensions(&self, path: &Path) -> impl Iterator<Item = DirEntry> {
572        WalkDir::new(path)
573            .into_iter()
574            .filter_map(|entry| entry.ok())
575            .filter(|entry| {
576                entry
577                    .path()
578                    .extension()
579                    .and_then(|ext| ext.to_str())
580                    .map(|ext_str| VALID_EXT.contains(&ext_str))
581                    .unwrap_or(false)
582            })
583    }
584
585    fn build_single_lora_storage(
586        spec: &LoraSpec,
587        is_high_noise: bool,
588        valid_loras: &HashMap<String, PathBuf>,
589    ) -> ((CLibPath, String, f32), sd_lora_t) {
590        let path = valid_loras.get(&spec.file_name).unwrap().as_path();
591        let c_path = CLibPath::from(path);
592        let lora = sd_lora_t {
593            is_high_noise,
594            multiplier: spec.multiplier,
595            path: c_path.as_ptr(),
596        };
597        let data = (c_path, spec.file_name.clone(), spec.multiplier);
598        (data, lora)
599    }
600
601    pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
602        let data: Vec<(CLibString, CLibPath)> = self
603            .filter_valid_extensions(embeddings_dir)
604            .map(|entry| {
605                let file_stem = entry
606                    .path()
607                    .file_stem()
608                    .and_then(|stem| stem.to_str())
609                    .unwrap_or_default()
610                    .to_owned();
611                (CLibString::from(file_stem), CLibPath::from(entry.path()))
612            })
613            .collect();
614        let data_pointer = data
615            .iter()
616            .map(|(name, path)| sd_embedding_t {
617                name: name.as_ptr(),
618                path: path.as_ptr(),
619            })
620            .collect();
621        self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
622        self
623    }
624
625    pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
626        let valid_loras: HashMap<String, PathBuf> = self
627            .filter_valid_extensions(lora_model_dir)
628            .map(|entry| {
629                let path = entry.path();
630                (
631                    path.file_stem()
632                        .and_then(|stem| stem.to_str())
633                        .unwrap_or_default()
634                        .to_owned(),
635                    path.to_path_buf(),
636                )
637            })
638            .collect();
639        let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
640        let standard = specs
641            .iter()
642            .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
643            .map(|s| Self::build_single_lora_storage(s, false, &valid_loras));
644        let high_noise = specs
645            .iter()
646            .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
647            .map(|s| Self::build_single_lora_storage(s, true, &valid_loras));
648
649        let mut data = Vec::new();
650        let mut loras_t = Vec::new();
651        for lora in standard.chain(high_noise) {
652            data.push(lora.0);
653            loras_t.push(lora.1);
654        }
655
656        self.lora_models = Some(LoraStorage {
657            lora_model_dir: lora_model_dir.into(),
658            data,
659            loras_t,
660        });
661        self
662    }
663
664    pub fn n_threads(&mut self, value: i32) -> &mut Self {
665        self.n_threads = if value > 0 {
666            Some(value)
667        } else {
668            Some(num_cpus::get_physical() as i32)
669        };
670        self
671    }
672}
673
674impl ModelConfig {
675    unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
676        unsafe {
677            if self.upscale_model.is_none() || self.upscale_repeats == 0 {
678                None
679            } else {
680                if self.upscaler_ctx.is_none() {
681                    let upscaler = new_upscaler_ctx(
682                        self.upscale_model.as_ref().unwrap().as_ptr(),
683                        self.offload_params_to_cpu,
684                        self.diffusion_conv_direct,
685                        self.n_threads,
686                        self.upscale_tile_size,
687                    );
688                    self.upscaler_ctx = Some(upscaler);
689                }
690                self.upscaler_ctx
691            }
692        }
693    }
694
695    unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
696        unsafe {
697            if self.diffusion_ctx.is_none() {
698                let sd_ctx_params = sd_ctx_params_t {
699                    model_path: self.model.as_ptr(),
700                    llm_path: self.llm.as_ptr(),
701                    llm_vision_path: self.llm_vision.as_ptr(),
702                    clip_l_path: self.clip_l.as_ptr(),
703                    clip_g_path: self.clip_g.as_ptr(),
704                    clip_vision_path: self.clip_vision.as_ptr(),
705                    high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
706                    t5xxl_path: self.t5xxl.as_ptr(),
707                    diffusion_model_path: self.diffusion_model.as_ptr(),
708                    vae_path: self.vae.as_ptr(),
709                    taesd_path: self.taesd.as_ptr(),
710                    control_net_path: self.control_net.as_ptr(),
711                    embeddings: self.embeddings.2.as_ptr(),
712                    embedding_count: self.embeddings.1.len() as u32,
713                    photo_maker_path: self.photo_maker.as_ptr(),
714                    vae_decode_only,
715                    free_params_immediately: false,
716                    n_threads: self.n_threads,
717                    wtype: self.weight_type,
718                    rng_type: self.rng,
719                    keep_clip_on_cpu: self.clip_on_cpu,
720                    keep_control_net_on_cpu: self.control_net_cpu,
721                    keep_vae_on_cpu: self.vae_on_cpu,
722                    diffusion_flash_attn: self.flash_attention,
723                    diffusion_conv_direct: self.diffusion_conv_direct,
724                    chroma_use_dit_mask: !self.chroma_disable_dit_mask,
725                    chroma_use_t5_mask: self.chroma_enable_t5_mask,
726                    chroma_t5_mask_pad: self.chroma_t5_mask_pad,
727                    vae_conv_direct: self.vae_conv_direct,
728                    offload_params_to_cpu: self.offload_params_to_cpu,
729                    flow_shift: self.flow_shift,
730                    prediction: self.prediction,
731                    force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
732                    tae_preview_only: self.taesd_preview_only,
733                    lora_apply_mode: self.lora_apply_mode,
734                    tensor_type_rules: null_mut(),
735                    sampler_rng_type: self.sampler_rng_type,
736                    circular_x: self.circular || self.circular_x,
737                    circular_y: self.circular || self.circular_y,
738                    qwen_image_zero_cond_t: self.use_qwen_image_zero_cond_true,
739                    enable_mmap: self.enable_mmap,
740                };
741                let ctx = new_sd_ctx(&sd_ctx_params);
742                self.diffusion_ctx = Some((ctx, sd_ctx_params))
743            }
744            self.diffusion_ctx.unwrap().0
745        }
746    }
747}
748
749impl Drop for ModelConfig {
750    fn drop(&mut self) {
751        //Cleanup CTX section
752        unsafe {
753            if let Some((sd_ctx, _)) = self.diffusion_ctx {
754                free_sd_ctx(sd_ctx);
755            }
756
757            if let Some(upscaler_ctx) = self.upscaler_ctx {
758                free_upscaler_ctx(upscaler_ctx);
759            }
760        }
761    }
762}
763
764impl From<ModelConfig> for ModelConfigBuilder {
765    fn from(value: ModelConfig) -> Self {
766        let mut builder = ModelConfigBuilder::default();
767        builder
768            .n_threads(value.n_threads)
769            .offload_params_to_cpu(value.offload_params_to_cpu)
770            .upscale_repeats(value.upscale_repeats)
771            .model(value.model.clone())
772            .diffusion_model(value.diffusion_model.clone())
773            .llm(value.llm.clone())
774            .llm_vision(value.llm_vision.clone())
775            .clip_l(value.clip_l.clone())
776            .clip_g(value.clip_g.clone())
777            .clip_vision(value.clip_vision.clone())
778            .t5xxl(value.t5xxl.clone())
779            .vae(value.vae.clone())
780            .taesd(value.taesd.clone())
781            .control_net(value.control_net.clone())
782            .embeddings(&value.embeddings.0)
783            .photo_maker(value.photo_maker.clone())
784            .pm_id_embed_path(value.pm_id_embed_path.clone())
785            .weight_type(value.weight_type)
786            .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
787            .vae_tiling(value.vae_tiling)
788            .vae_tile_size(value.vae_tile_size)
789            .vae_relative_tile_size(value.vae_relative_tile_size)
790            .vae_tile_overlap(value.vae_tile_overlap)
791            .rng(value.rng)
792            .sampler_rng_type(value.rng)
793            .scheduler(value.scheduler)
794            .sigmas(value.sigmas.clone())
795            .prediction(value.prediction)
796            .vae_on_cpu(value.vae_on_cpu)
797            .clip_on_cpu(value.clip_on_cpu)
798            .control_net(value.control_net.clone())
799            .control_net_cpu(value.control_net_cpu)
800            .flash_attention(value.flash_attention)
801            .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
802            .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
803            .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
804            .diffusion_conv_direct(value.diffusion_conv_direct)
805            .vae_conv_direct(value.vae_conv_direct)
806            .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
807            .flow_shift(value.flow_shift)
808            .timestep_shift(value.timestep_shift)
809            .taesd_preview_only(value.taesd_preview_only)
810            .lora_apply_mode(value.lora_apply_mode)
811            .circular(value.circular)
812            .circular_x(value.circular_x)
813            .circular_y(value.circular_y)
814            .use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true);
815
816        let lora_model_dir = Into::<PathBuf>::into(&value.lora_models.lora_model_dir);
817        let lora_specs = value
818            .lora_models
819            .data
820            .iter()
821            .map(|(_, name, multiplier)| LoraSpec {
822                file_name: name.clone(),
823                is_high_noise: false,
824                multiplier: *multiplier,
825            })
826            .collect();
827
828        builder.lora_models(&lora_model_dir, lora_specs);
829
830        if let Some(model) = &value.upscale_model {
831            builder.upscale_model(model.clone());
832        }
833        builder
834    }
835}
836
#[derive(Builder, Debug, Clone)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
/// Config struct common to all diffusion methods
///
/// Built with [`ConfigBuilder`]; `prompt` is the only field without a default. `build()`
/// additionally checks that `output` is a directory exactly when `batch_count > 1`.
pub struct Config {
    /// Path to PHOTOMAKER input id images dir
    /// NOTE(review): not read by `gen_img` in this file (it passes no id images) — confirm.
    #[builder(default = "Default::default()")]
    pm_id_images_dir: CLibPath,

    /// Path to the input image, required by img2img
    /// NOTE(review): `gen_img` in this file passes an empty init image and does not read this
    /// field — confirm whether it is wired up elsewhere.
    #[builder(default = "Default::default()")]
    init_img: CLibPath,

    /// Path to image condition, control net
    /// NOTE(review): `gen_img` in this file passes an empty control image and does not read
    /// this field — confirm whether it is wired up elsewhere.
    #[builder(default = "Default::default()")]
    control_image: CLibPath,

    /// Path to write result image to (default: ./output.png)
    ///
    /// When `batch_count > 1` this must be an existing directory; the generated files are
    /// then placed inside it with timestamped names.
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,

    /// Path to write the preview image to (default: ./preview_output.png)
    ///
    /// Only used when `preview_mode` is not `PREVIEW_NONE`.
    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
    preview_output: PathBuf,

    /// Preview method
    #[builder(default = "PreviewType::PREVIEW_NONE")]
    preview_mode: PreviewType,

    /// Enables previewing noisy inputs of the models rather than the denoised outputs
    #[builder(default = "false")]
    preview_noisy: bool,

    /// Interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)
    #[builder(default = "1")]
    preview_interval: i32,

    /// The prompt to render
    prompt: String,

    /// The negative prompt (default: "")
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,

    /// Unconditional guidance scale (default: 7.0)
    /// Used for both the text and the image CFG scale.
    #[builder(default = "7.0")]
    cfg_scale: f32,

    /// Distilled guidance scale for models with guidance input (default: 3.5)
    #[builder(default = "3.5")]
    guidance: f32,

    /// Strength for noising/unnoising (default: 0.75)
    #[builder(default = "0.75")]
    strength: f32,

    /// Strength for keeping input identity (default: 20%)
    #[builder(default = "20.0")]
    pm_style_strength: f32,

    /// Strength to apply Control Net (default: 0.9)
    /// 1.0 corresponds to full destruction of information in init
    #[builder(default = "0.9")]
    control_strength: f32,

    /// Image height, in pixel space (default: 512)
    #[builder(default = "512")]
    height: i32,

    /// Image width, in pixel space (default: 512)
    #[builder(default = "512")]
    width: i32,

    /// Sampling-method (default: [SampleMethod::SAMPLE_METHOD_COUNT]).
    /// [SampleMethod::EULER_SAMPLE_METHOD] will be used for flux, sd3, wan, qwen_image.
    /// Otherwise [SampleMethod::EULER_A_SAMPLE_METHOD] is used.
    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
    sampling_method: SampleMethod,

    /// eta in DDIM, only for DDIM and TCD: (default: 0)
    #[builder(default = "0.")]
    eta: f32,

    /// Number of sample steps (default: 20)
    #[builder(default = "20")]
    steps: i32,

    /// RNG seed (default: 42, use random seed for < 0)
    #[builder(default = "42")]
    seed: i64,

    /// Number of images to generate (default: 1)
    #[builder(default = "1")]
    batch_count: i32,

    /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
    /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,

    /// Apply canny preprocessor (edge detection) (default: false)
    /// NOTE(review): not read by `gen_img` in this file — confirm whether it is wired up.
    #[builder(default = "false")]
    canny: bool,

    /// skip layer guidance (SLG) scale, only for DiT models: (default: 0)
    /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium
    #[builder(default = "0.")]
    slg_scale: f32,

    /// Layers to skip for SLG steps: (default: \[7,8,9\])
    #[builder(default = "vec![7, 8, 9]")]
    skip_layer: Vec<i32>,

    /// SLG enabling point: (default: 0.01)
    #[builder(default = "0.01")]
    skip_layer_start: f32,

    /// SLG disabling point: (default: 0.2)
    #[builder(default = "0.2")]
    skip_layer_end: f32,

    /// Disable auto resize of ref images
    #[builder(default = "false")]
    disable_auto_resize_ref_image: bool,

    /// Step-caching parameters passed to the C library (default: caching disabled).
    /// The setter is private; configure it through the `*_caching` builder methods.
    #[builder(default = "Self::cache_init()", private)]
    cache: sd_cache_params_t,

    /// Owned storage for the SCM mask string whose raw pointer is stored in
    /// `cache.scm_mask` (filled by `db_cache_caching`).
    #[builder(default = "CLibString::default()", private)]
    scm_mask: CLibString,
}
967
968impl ConfigBuilder {
969    fn validate(&self) -> Result<(), ConfigBuilderError> {
970        self.validate_output_dir()
971    }
972
973    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
974        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
975        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
976        if is_dir == multiple_items {
977            Ok(())
978        } else {
979            Err(ConfigBuilderError::ValidationError(
980                "When batch_count > 1, output should point to folder and vice versa".to_owned(),
981            ))
982        }
983    }
984
985    fn cache_init() -> sd_cache_params_t {
986        sd_cache_params_t {
987            mode: sd_cache_mode_t::SD_CACHE_DISABLED,
988            reuse_threshold: 1.0,
989            start_percent: 0.15,
990            end_percent: 0.95,
991            error_decay_rate: 1.0,
992            use_relative_threshold: true,
993            reset_error_on_compute: true,
994            Fn_compute_blocks: 8,
995            Bn_compute_blocks: 0,
996            residual_diff_threshold: 0.08,
997            max_warmup_steps: 8,
998            max_cached_steps: -1,
999            max_continuous_cached_steps: -1,
1000            taylorseer_n_derivatives: 1,
1001            taylorseer_skip_interval: 1,
1002            scm_mask: null(),
1003            scm_policy_dynamic: true,
1004        }
1005    }
1006
1007    pub fn no_caching(&mut self) -> &mut Self {
1008        let mut cache = Self::cache_init();
1009        cache.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
1010        self.cache = Some(cache);
1011        self
1012    }
1013
1014    pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
1015        let mut cache = Self::cache_init();
1016        cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
1017        cache.reuse_threshold = params.threshold;
1018        cache.start_percent = params.start;
1019        cache.end_percent = params.end;
1020        cache.error_decay_rate = params.decay;
1021        cache.use_relative_threshold = params.relative;
1022        cache.reset_error_on_compute = params.reset;
1023        self.cache = Some(cache);
1024        self
1025    }
1026
1027    pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
1028        let mut cache = Self::cache_init();
1029        cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
1030        cache.reuse_threshold = params.threshold;
1031        cache.start_percent = params.start;
1032        cache.end_percent = params.end;
1033        self.cache = Some(cache);
1034        self
1035    }
1036
1037    pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
1038        let mut cache = Self::cache_init();
1039        cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
1040        cache.Fn_compute_blocks = params.fn_blocks;
1041        cache.Bn_compute_blocks = params.bn_blocks;
1042        cache.residual_diff_threshold = params.threshold;
1043        cache.max_warmup_steps = params.warmup;
1044        cache.scm_policy_dynamic = match params.scm_policy_dynamic {
1045            ScmPolicy::Static => false,
1046            ScmPolicy::Dynamic => true,
1047        };
1048        self.scm_mask = Some(CLibString::from(
1049            params
1050                .scm_mask
1051                .to_vec_string(self.steps.unwrap_or_default()),
1052        ));
1053        cache.scm_mask = self.scm_mask.as_ref().unwrap().as_ptr();
1054
1055        self.cache = Some(cache);
1056        self
1057    }
1058
1059    pub fn taylor_seer_caching(&mut self) -> &mut Self {
1060        let mut cache = Self::cache_init();
1061        cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
1062        self.cache = Some(cache);
1063        self
1064    }
1065
1066    pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
1067        self.db_cache_caching(params).cache.unwrap().mode = sd_cache_mode_t::SD_CACHE_CACHE_DIT;
1068        self
1069    }
1070}
1071
1072impl From<Config> for ConfigBuilder {
1073    fn from(value: Config) -> Self {
1074        let mut builder = ConfigBuilder::default();
1075        let mut cache = value.cache;
1076        let scm_mask = value.scm_mask.clone();
1077        cache.scm_mask = scm_mask.as_ptr();
1078        builder
1079            .pm_id_images_dir(value.pm_id_images_dir)
1080            .init_img(value.init_img)
1081            .control_image(value.control_image)
1082            .output(value.output)
1083            .prompt(value.prompt)
1084            .negative_prompt(value.negative_prompt)
1085            .cfg_scale(value.cfg_scale)
1086            .strength(value.strength)
1087            .pm_style_strength(value.pm_style_strength)
1088            .control_strength(value.control_strength)
1089            .height(value.height)
1090            .width(value.width)
1091            .sampling_method(value.sampling_method)
1092            .steps(value.steps)
1093            .seed(value.seed)
1094            .batch_count(value.batch_count)
1095            .clip_skip(value.clip_skip)
1096            .slg_scale(value.slg_scale)
1097            .skip_layer(value.skip_layer)
1098            .skip_layer_start(value.skip_layer_start)
1099            .skip_layer_end(value.skip_layer_end)
1100            .canny(value.canny)
1101            .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
1102            .preview_output(value.preview_output)
1103            .preview_mode(value.preview_mode)
1104            .preview_noisy(value.preview_noisy)
1105            .preview_interval(value.preview_interval)
1106            .cache(cache)
1107            .scm_mask(scm_mask);
1108        builder
1109    }
1110}
1111
/// Owned, NUL-terminated copy of a Rust string that can be handed to the C API.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Borrows the string as a C pointer. The pointer is valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        let CLibString(inner) = self;
        inner.as_ptr()
    }
}

impl From<&str> for CLibString {
    /// # Panics
    /// If `value` contains an interior NUL byte.
    fn from(value: &str) -> Self {
        CString::new(value).map(CLibString).unwrap()
    }
}

impl From<String> for CLibString {
    /// # Panics
    /// If `value` contains an interior NUL byte.
    fn from(value: String) -> Self {
        CString::new(value).map(CLibString).unwrap()
    }
}
1132
/// Owned, NUL-terminated copy of a filesystem path that can be handed to the C API.
///
/// Paths that are not valid UTF-8 are stored as the empty string.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Borrows the path as a C pointer. The pointer is valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        let CLibPath(inner) = self;
        inner.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    fn from(value: PathBuf) -> Self {
        CLibPath::from(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    /// # Panics
    /// If the path contains an interior NUL byte.
    fn from(value: &Path) -> Self {
        let utf8 = value.to_str().unwrap_or_default();
        CString::new(utf8).map(CLibPath).unwrap()
    }
}

impl From<&CLibPath> for PathBuf {
    /// # Panics
    /// If the stored bytes are not valid UTF-8. This cannot happen for values built through
    /// the `From` impls above.
    fn from(value: &CLibPath) -> Self {
        let utf8 = value.0.to_str().unwrap();
        PathBuf::from(utf8)
    }
}
1159
1160fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
1161    let date = Local::now().format("%Y.%m.%d-%H.%M.%S");
1162    if batch_size == 1 {
1163        vec![path.into()]
1164    } else {
1165        (1..=batch_size)
1166            .map(|id| path.join(format!("output_{date}_{id}.png")))
1167            .collect()
1168    }
1169}
1170
1171unsafe fn upscale(
1172    upscale_repeats: i32,
1173    upscaler_ctx: Option<*mut upscaler_ctx_t>,
1174    data: sd_image_t,
1175) -> Result<sd_image_t, DiffusionError> {
1176    unsafe {
1177        match upscaler_ctx {
1178            Some(upscaler_ctx) => {
1179                let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
1180                let mut current_image = data;
1181                for _ in 0..upscale_repeats {
1182                    let upscaled_image =
1183                        diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
1184
1185                    if upscaled_image.data.is_null() {
1186                        return Err(DiffusionError::Upscaler);
1187                    }
1188
1189                    free(current_image.data as *mut c_void);
1190                    current_image = upscaled_image;
1191                }
1192                Ok(current_image)
1193            }
1194            None => Ok(data),
1195        }
1196    }
1197}
1198
1199/// Generate an image and receive update via queue
1200pub fn gen_img_with_progress(
1201    config: &Config,
1202    model_config: &mut ModelConfig,
1203    sender: Sender<Progress>,
1204) -> Result<(), DiffusionError> {
1205    gen_img_maybe_progress(config, model_config, Some(sender))
1206}
1207
1208/// Generate an image
1209pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
1210    gen_img_maybe_progress(config, model_config, None)
1211}
1212
1213fn gen_img_maybe_progress(
1214    config: &Config,
1215    model_config: &mut ModelConfig,
1216    mut sender: Option<Sender<Progress>>,
1217) -> Result<(), DiffusionError> {
1218    let prompt: CLibString = CLibString::from(config.prompt.as_str());
1219    let files = output_files(&config.output, config.batch_count);
1220    unsafe {
1221        let sd_ctx = model_config.diffusion_ctx(true);
1222        let upscaler_ctx = model_config.upscaler_ctx();
1223        let init_image = sd_image_t {
1224            width: 0,
1225            height: 0,
1226            channel: 3,
1227            data: null_mut(),
1228        };
1229        let mask_image = sd_image_t {
1230            width: config.width as u32,
1231            height: config.height as u32,
1232            channel: 1,
1233            data: null_mut(),
1234        };
1235        let mut layers = config.skip_layer.clone();
1236        let guidance = sd_guidance_params_t {
1237            txt_cfg: config.cfg_scale,
1238            img_cfg: config.cfg_scale,
1239            distilled_guidance: config.guidance,
1240            slg: sd_slg_params_t {
1241                layers: layers.as_mut_ptr(),
1242                layer_count: config.skip_layer.len(),
1243                layer_start: config.skip_layer_start,
1244                layer_end: config.skip_layer_end,
1245                scale: config.slg_scale,
1246            },
1247        };
1248        let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
1249            sd_get_default_scheduler(sd_ctx, config.sampling_method)
1250        } else {
1251            model_config.scheduler
1252        };
1253        let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
1254            sd_get_default_sample_method(sd_ctx)
1255        } else {
1256            config.sampling_method
1257        };
1258        let sample_params = sd_sample_params_t {
1259            guidance,
1260            sample_method,
1261            sample_steps: config.steps,
1262            eta: config.eta,
1263            scheduler,
1264            shifted_timestep: model_config.timestep_shift,
1265            custom_sigmas: model_config.sigmas.as_mut_ptr(),
1266            custom_sigmas_count: model_config.sigmas.len() as i32,
1267        };
1268        let control_image = sd_image_t {
1269            width: 0,
1270            height: 0,
1271            channel: 3,
1272            data: null_mut(),
1273        };
1274        let vae_tiling_params = sd_tiling_params_t {
1275            enabled: model_config.vae_tiling,
1276            tile_size_x: model_config.vae_tile_size.0,
1277            tile_size_y: model_config.vae_tile_size.1,
1278            target_overlap: model_config.vae_tile_overlap,
1279            rel_size_x: model_config.vae_relative_tile_size.0,
1280            rel_size_y: model_config.vae_relative_tile_size.1,
1281        };
1282        let pm_params = sd_pm_params_t {
1283            id_images: null_mut(),
1284            id_images_count: 0,
1285            id_embed_path: model_config.pm_id_embed_path.as_ptr(),
1286            style_strength: config.pm_style_strength,
1287        };
1288
1289        unsafe extern "C" fn save_preview_local(
1290            _step: ::std::os::raw::c_int,
1291            _frame_count: ::std::os::raw::c_int,
1292            frames: *mut sd_image_t,
1293            _is_noisy: bool,
1294            data: *mut ::std::os::raw::c_void,
1295        ) {
1296            unsafe {
1297                let path = &*data.cast::<PathBuf>();
1298                let _ = save_img(*frames, path, None);
1299            }
1300        }
1301
1302        if config.preview_mode != PreviewType::PREVIEW_NONE {
1303            let data = &config.preview_output as *const PathBuf;
1304
1305            sd_set_preview_callback(
1306                Some(save_preview_local),
1307                config.preview_mode,
1308                config.preview_interval,
1309                !config.preview_noisy,
1310                config.preview_noisy,
1311                data as *mut c_void,
1312            );
1313        }
1314
1315        if sender.is_some() {
1316            unsafe extern "C" fn progress_callback(
1317                step: ::std::os::raw::c_int,
1318                steps: ::std::os::raw::c_int,
1319                time: f32,
1320                data: *mut ::std::os::raw::c_void,
1321            ) {
1322                unsafe {
1323                    let sender = &*data.cast::<Option<Sender<Progress>>>();
1324
1325                    if let Some(sender) = sender {
1326                        let _ = sender.send(Progress { step, steps, time });
1327                    }
1328                }
1329            }
1330            let sender_ptr: *mut c_void = &mut sender as *mut _ as *mut c_void;
1331            sd_set_progress_callback(Some(progress_callback), sender_ptr);
1332        }
1333
1334        let sd_img_gen_params = sd_img_gen_params_t {
1335            prompt: prompt.as_ptr(),
1336            negative_prompt: config.negative_prompt.as_ptr(),
1337            clip_skip: config.clip_skip as i32,
1338            init_image,
1339            ref_images: null_mut(),
1340            ref_images_count: 0,
1341            increase_ref_index: false,
1342            mask_image,
1343            width: config.width,
1344            height: config.height,
1345            sample_params,
1346            strength: config.strength,
1347            seed: config.seed,
1348            batch_count: config.batch_count,
1349            control_image,
1350            control_strength: config.control_strength,
1351            pm_params,
1352            vae_tiling_params,
1353            auto_resize_ref_image: config.disable_auto_resize_ref_image,
1354            cache: config.cache,
1355            loras: model_config.lora_models.loras_t.as_ptr(),
1356            lora_count: model_config.lora_models.loras_t.len() as u32,
1357        };
1358
1359        let params_str = CString::from_raw(sd_img_gen_params_to_str(&sd_img_gen_params))
1360            .into_string()
1361            .unwrap();
1362
1363        let slice = generate_image(sd_ctx, &sd_img_gen_params);
1364        let ret = {
1365            if slice.is_null() {
1366                return Err(DiffusionError::Forward);
1367            }
1368            for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
1369                .iter()
1370                .zip(files)
1371            {
1372                match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
1373                    Ok(img) => save_img(img, &path, Some(&params_str))?,
1374                    Err(err) => {
1375                        return Err(err);
1376                    }
1377                }
1378            }
1379            Ok(())
1380        };
1381        free(slice as *mut c_void);
1382        ret
1383    }
1384}
1385
1386fn save_img(img: sd_image_t, path: &Path, params: Option<&str>) -> Result<(), DiffusionError> {
1387    // Thx @wandbrandon
1388    let len = (img.width * img.height * img.channel) as usize;
1389    let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1390    let save_state = ImageBuffer::from_raw(img.width, img.height, buffer).map(|img| {
1391        RgbImage::from(img)
1392            .save(path)
1393            .map_err(DiffusionError::StoreImages)
1394    });
1395    if let Some(Err(err)) = save_state {
1396        return Err(err);
1397    }
1398    if let Some(params) = params {
1399        let mut metadata = Metadata::new();
1400        metadata.set_tag(ExifTag::ImageDescription(params.to_string()));
1401        metadata.write_to_file(path)?;
1402    }
1403    Ok(())
1404}
1405
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    /// Builder validation: required fields and the batch_count / output-directory pairing.
    #[test]
    fn test_required_args_txt2img() {
        // Config has no default for `prompt`; ModelConfig builds once `model` is set.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // batch_count > 1 requires `output` to be a directory...
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        // ...and a directory `output` requires batch_count > 1.
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .output(PathBuf::from("./"))
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end generation. Ignored by default because it downloads model weights from
    /// the Hugging Face hub and runs full inference.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&config, &mut model_config).unwrap();
        let config2 = ConfigBuilder::from(config.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&config2, &mut model_config).unwrap();

        let config3 = ConfigBuilder::from(config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();

        gen_img(&config3, &mut model_config).unwrap();
    }
}