Skip to main content

diffusion_rs/
api.rs

1use std::cmp::max;
2use std::collections::HashMap;
3use std::ffi::CString;
4use std::ffi::c_char;
5use std::ffi::c_void;
6use std::fmt::Display;
7use std::path::Path;
8use std::path::PathBuf;
9use std::ptr::null;
10use std::ptr::null_mut;
11use std::slice;
12use std::sync::mpsc::Sender;
13
14use chrono::Local;
15use derive_builder::Builder;
16use diffusion_rs_sys::free_upscaler_ctx;
17use diffusion_rs_sys::generate_image;
18use diffusion_rs_sys::new_upscaler_ctx;
19use diffusion_rs_sys::sd_cache_mode_t;
20use diffusion_rs_sys::sd_cache_params_t;
21use diffusion_rs_sys::sd_ctx_params_t;
22use diffusion_rs_sys::sd_embedding_t;
23use diffusion_rs_sys::sd_get_default_sample_method;
24use diffusion_rs_sys::sd_get_default_scheduler;
25use diffusion_rs_sys::sd_guidance_params_t;
26use diffusion_rs_sys::sd_image_t;
27use diffusion_rs_sys::sd_img_gen_params_t;
28use diffusion_rs_sys::sd_img_gen_params_to_str;
29use diffusion_rs_sys::sd_lora_t;
30use diffusion_rs_sys::sd_pm_params_t;
31use diffusion_rs_sys::sd_sample_params_t;
32use diffusion_rs_sys::sd_set_preview_callback;
33use diffusion_rs_sys::sd_set_progress_callback;
34use diffusion_rs_sys::sd_slg_params_t;
35use diffusion_rs_sys::sd_tiling_params_t;
36use diffusion_rs_sys::upscaler_ctx_t;
37use image::ImageBuffer;
38use image::ImageError;
39use image::RgbImage;
40use libc::free;
41use little_exif::exif_tag::ExifTag;
42use little_exif::metadata::Metadata;
43use thiserror::Error;
44use walkdir::DirEntry;
45use walkdir::WalkDir;
46
47use diffusion_rs_sys::free_sd_ctx;
48use diffusion_rs_sys::new_sd_ctx;
49use diffusion_rs_sys::sd_ctx_t;
50
/// Specify the random number generator (RNG) implementation
52pub use diffusion_rs_sys::rng_type_t as RngFunction;
53
54/// Sampling methods
55pub use diffusion_rs_sys::sample_method_t as SampleMethod;
56
57/// Denoiser sigma schedule
58pub use diffusion_rs_sys::scheduler_t as Scheduler;
59
60/// Prediction override
61pub use diffusion_rs_sys::prediction_t as Prediction;
62
63/// Weight type
64pub use diffusion_rs_sys::sd_type_t as WeightType;
65
66/// Preview mode
67pub use diffusion_rs_sys::preview_t as PreviewType;
68
69/// Lora mode
70pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
71
/// File extensions (without the leading dot) that `ModelConfigBuilder::filter_valid_extensions`
/// accepts when walking a directory.
static VALID_EXT: [&str; 3] = ["gguf", "safetensors", "pt"];
73
#[allow(unused)]
#[derive(Debug)]
/// Progress message returned from [gen_img_with_progress]
///
/// NOTE(review): the field meanings below are inferred from their names only; confirm
/// against the progress callback that fills this struct in.
pub struct Progress {
    /// Presumably the step that has just been processed
    step: i32,
    /// Presumably the total number of steps
    steps: i32,
    /// Presumably a duration reported by the backend (unit not visible here)
    time: f32,
}
82
#[non_exhaustive]
#[derive(Error, Debug)]
/// Errors that can occur while forwarding models
pub enum DiffusionError {
    /// The native generation call returned NULL
    #[error("The underling stablediffusion.cpp function returned NULL")]
    Forward,
    /// An [ImageError], converted automatically through `From`
    #[error(transparent)]
    StoreImages(#[from] ImageError),
    /// A [std::io::Error], converted automatically through `From`
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// The upscaler model returned a NULL image
    #[error("The underling upscaler model returned a NULL image")]
    Upscaler,
}
96
#[repr(i32)]
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
/// Ignore the lower X layers of CLIP network
///
/// The enum is `#[repr(i32)]`, so each variant has the explicit integer value shown below.
pub enum ClipSkip {
    /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x
    #[default]
    Unspecified = 0,
    /// Value `1`: no layer is ignored
    None = 1,
    /// Value `2`: one layer is ignored
    OneLayer = 2,
}
108
/// Embeddings state kept by [ModelConfig]: the directory that was scanned, the owned
/// `(name, path)` C strings found in it, and the table of raw `sd_embedding_t` entries
/// whose pointers were taken from those strings (see `ModelConfigBuilder::embeddings`).
type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
/// Resolved LoRA files: the C path of each file paired with the [LoraSpec] that selected it.
type LoraStorage = Vec<(CLibPath, LoraSpec)>;
111
/// Specify the instructions for a Lora model
#[derive(Default, Debug, Clone)]
pub struct LoraSpec {
    /// Name of the Lora file without its extension; it is matched against the file stems
    /// found in the Lora directory by `ModelConfigBuilder::lora_models`
    pub file_name: String,
    /// Marks a high-noise Lora; `ModelConfigBuilder::lora_models` stores these after the
    /// standard ones
    pub is_high_noise: bool,
    /// Multiplier for this Lora.
    /// NOTE(review): how it is applied is not visible in this part of the file.
    pub multiplier: f32,
}
119
/// Parameters for UCache
///
/// Every field has a builder default, listed on the field.
#[derive(Builder, Debug, Clone)]
pub struct UCacheParams {
    /// Error threshold for reuse decision (default: 1.0)
    #[builder(default = "1.0")]
    threshold: f32,

    /// Start caching at this percent of steps (default: 0.15)
    #[builder(default = "0.15")]
    start: f32,

    /// Stop caching at this percent of steps (default: 0.95)
    #[builder(default = "0.95")]
    end: f32,

    /// Error decay rate (0-1) (default: 1.0)
    #[builder(default = "1.0")]
    decay: f32,

    /// Scale threshold by output norm (default: true)
    #[builder(default = "true")]
    relative: bool,

    /// Reset error after computing (default: true)
    /// true: Resets accumulated error after each computed step. More aggressive caching, works well with most samplers.
    /// false: Keeps error accumulated. More conservative, recommended for euler_a sampler
    #[builder(default = "true")]
    reset: bool,
}
149
/// Parameters for Easy Cache
///
/// Every field has a builder default, listed on the field.
#[derive(Builder, Debug, Clone)]
pub struct EasyCacheParams {
    /// Error threshold for reuse decision (default: 0.2)
    #[builder(default = "0.2")]
    threshold: f32,

    /// Start caching at this percent of steps (default: 0.15)
    #[builder(default = "0.15")]
    start: f32,

    /// Stop caching at this percent of steps (default: 0.95)
    #[builder(default = "0.95")]
    end: f32,
}
165
/// Parameters for Db Cache
#[derive(Builder, Debug, Clone)]
pub struct DbCacheParams {
    /// Front blocks to always compute (default: 8)
    #[builder(default = "8")]
    fn_blocks: i32,

    /// Back blocks to always compute (default: 0)
    #[builder(default = "0")]
    bn_blocks: i32,

    /// L1 residual difference threshold (default: 0.08)
    #[builder(default = "0.08")]
    threshold: f32,

    /// Steps before caching starts (default: 8)
    #[builder(default = "8")]
    warmup: i32,

    /// Steps Computation Mask controls which steps can be cached.
    /// No builder default is declared for this field.
    scm_mask: ScmPreset,

    /// Scm Policy (default: `ScmPolicy::default()`)
    #[builder(default = "ScmPolicy::default()")]
    scm_policy_dynamic: ScmPolicy,
}
192
/// Steps Computation Mask Policy controls when to cache steps
///
/// [ScmPolicy::Dynamic] is the `Default` variant.
#[derive(Debug, Default, Clone)]
pub enum ScmPolicy {
    /// Always cache on cacheable steps
    Static,
    #[default]
    /// Check threshold before caching
    Dynamic,
}
202
203/// Steps Computation Mask Preset controls which steps can be cached
204#[derive(Debug, Default, Clone)]
205pub enum ScmPreset {
206    Slow,
207    #[default]
208    Medium,
209    Fast,
210    Ultra,
211    /// E.g.: "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1"
212    /// where 1 means compute, 0 means cache
213    Custom(String),
214}
215
216impl ScmPreset {
217    fn to_vec_string(&self, steps: i32) -> String {
218        match self {
219            ScmPreset::Slow => ScmPresetBins {
220                compute_bins: vec![8, 3, 3, 2, 1, 1],
221                cache_bins: vec![1, 2, 2, 2, 3],
222                steps,
223            }
224            .to_string(),
225            ScmPreset::Medium => ScmPresetBins {
226                compute_bins: vec![6, 2, 2, 2, 2, 1],
227                cache_bins: vec![1, 3, 3, 3, 3],
228                steps,
229            }
230            .to_string(),
231            ScmPreset::Fast => ScmPresetBins {
232                compute_bins: vec![6, 1, 1, 1, 1, 1],
233                cache_bins: vec![1, 3, 4, 5, 4],
234                steps,
235            }
236            .to_string(),
237            ScmPreset::Ultra => ScmPresetBins {
238                compute_bins: vec![4, 1, 1, 1, 1],
239                cache_bins: vec![2, 5, 6, 7],
240                steps,
241            }
242            .to_string(),
243            ScmPreset::Custom(s) => s.clone(),
244        }
245    }
246}
247
/// Bin table behind a named `ScmPreset`: alternating runs of "compute" and "cache" steps.
///
/// The tables are expressed for a 28-step schedule and are rescaled for other step counts.
#[derive(Debug, Clone)]
struct ScmPresetBins {
    /// Lengths of the successive compute runs (mask value `1`)
    compute_bins: Vec<i32>,
    /// Lengths of the successive cache runs (mask value `0`)
    cache_bins: Vec<i32>,
    /// Number of sampling steps the mask is generated for
    steps: i32,
}

impl ScmPresetBins {
    /// Step count the preset bin tables are written for.
    const REFERENCE_STEPS: i32 = 28;

    /// Returns the bins rescaled to `self.steps`, or an unchanged copy when the step count
    /// already matches the reference schedule or is not positive.
    fn maybe_scale(&self) -> ScmPresetBins {
        if self.steps == Self::REFERENCE_STEPS || self.steps <= 0 {
            return self.clone();
        }
        self.scale()
    }

    /// Rescales every bin by `steps / 28`, rounding to the nearest integer and never going
    /// below one step per bin.
    fn scale(&self) -> ScmPresetBins {
        let scale = self.steps as f32 / Self::REFERENCE_STEPS as f32;
        // Adding 0.5 before the truncating cast rounds to nearest. Multiplying by 0.5
        // instead would halve every bin and leave the mask far shorter than `steps`.
        let rescale = |bins: &[i32]| -> Vec<i32> {
            bins.iter()
                .map(|&b| max(1, (b as f32 * scale + 0.5) as i32))
                .collect()
        };
        ScmPresetBins {
            compute_bins: rescale(&self.compute_bins),
            cache_bins: rescale(&self.cache_bins),
            steps: self.steps,
        }
    }

    /// Expands the bins into a mask of at most `steps` entries (`1` = compute, `0` = cache).
    ///
    /// Runs are emitted alternately, compute first, each list consumed in order. Emission
    /// stops when the mask is full or both lists are exhausted, so the mask can be shorter
    /// than `steps`. The last entry is always forced to `1`. A non-positive `steps` yields
    /// an empty mask.
    fn generate_vec_mask(&self) -> Vec<i32> {
        // A negative step count must not wrap into a huge unsigned bound.
        let total = usize::try_from(self.steps).unwrap_or(0);
        let mut mask: Vec<i32> = Vec::new();

        // Appends `count` copies of `value`, clipped to the room left in the mask.
        let push_run = |mask: &mut Vec<i32>, value: i32, count: i32| {
            let room = total.saturating_sub(mask.len());
            let run = usize::try_from(count).unwrap_or(0).min(room);
            mask.extend(std::iter::repeat(value).take(run));
        };

        // Each list has its own cursor, so a cache run is always read from `cache_bins`
        // at the cache position (never at the compute position).
        let mut compute_runs = self.compute_bins.iter().copied();
        let mut cache_runs = self.cache_bins.iter().copied();

        while mask.len() < total {
            let compute = compute_runs.next();
            let cache = cache_runs.next();
            if compute.is_none() && cache.is_none() {
                break;
            }
            if let Some(count) = compute {
                push_run(&mut mask, 1, count);
            }
            if let Some(count) = cache {
                push_run(&mut mask, 0, count);
            }
        }

        if let Some(last) = mask.last_mut() {
            *last = 1;
        }
        mask
    }
}
316
317impl Display for ScmPresetBins {
318    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319        let mask: String = self
320            .maybe_scale()
321            .generate_vec_mask()
322            .iter()
323            .map(|x| x.to_string())
324            .collect::<Vec<_>>()
325            .join(",");
326        write!(f, "{mask}")
327    }
328}
329
/// Config struct for a specific diffusion model
///
/// Instances are created through the generated `ModelConfigBuilder`; `build` runs
/// `ModelConfigBuilder::validate`. Besides the user-facing settings, the struct caches the
/// native contexts it creates (`upscaler_ctx`, `diffusion_ctx`), which are released in `Drop`.
#[derive(Builder, Debug, Clone)]
#[builder(
    setter(into, strip_option),
    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
)]
pub struct ModelConfig {
    /// Number of threads to use during computation (default: number of physical CPU cores).
    /// If n_threads <= 0, then threads will be set to the number of CPU physical cores.
    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
    n_threads: i32,

    /// Whether to memory-map model (default: false)
    #[builder(default = "false")]
    enable_mmap: bool,

    /// Place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
    #[builder(default = "false")]
    offload_params_to_cpu: bool,

    /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
    #[builder(default = "Default::default()")]
    upscale_model: Option<CLibPath>,

    /// Run the ESRGAN upscaler this many times (default 1)
    #[builder(default = "1")]
    upscale_repeats: i32,

    /// Tile size for ESRGAN upscaler (default 128)
    #[builder(default = "128")]
    upscale_tile_size: i32,

    /// Path to full model
    #[builder(default = "Default::default()")]
    model: CLibPath,

    /// Path to the standalone diffusion model
    #[builder(default = "Default::default()")]
    diffusion_model: CLibPath,

    /// Path to the qwen2vl text encoder
    #[builder(default = "Default::default()")]
    llm: CLibPath,

    /// Path to the qwen2vl vit
    #[builder(default = "Default::default()")]
    llm_vision: CLibPath,

    /// Path to the clip-l text encoder
    #[builder(default = "Default::default()")]
    clip_l: CLibPath,

    /// Path to the clip-g text encoder
    #[builder(default = "Default::default()")]
    clip_g: CLibPath,

    /// Path to the clip-vision encoder
    #[builder(default = "Default::default()")]
    clip_vision: CLibPath,

    /// Path to the t5xxl text encoder
    #[builder(default = "Default::default()")]
    t5xxl: CLibPath,

    /// Path to vae
    #[builder(default = "Default::default()")]
    vae: CLibPath,

    /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
    #[builder(default = "Default::default()")]
    taesd: CLibPath,

    /// Path to control net model
    #[builder(default = "Default::default()")]
    control_net: CLibPath,

    /// Embeddings found in the embeddings directory: the directory itself, the owned
    /// name/path strings, and the raw pointer table built from them
    /// (set through `ModelConfigBuilder::embeddings`)
    #[builder(default = "Default::default()", setter(custom))]
    embeddings: EmbeddingsStorage,

    /// Path to PHOTOMAKER model
    #[builder(default = "Default::default()")]
    photo_maker: CLibPath,

    /// Path to PHOTOMAKER v2 id embed
    #[builder(default = "Default::default()")]
    pm_id_embed_path: CLibPath,

    /// Weight type. If not specified, the default is the type of the weight file
    #[builder(default = "WeightType::SD_TYPE_COUNT")]
    weight_type: WeightType,

    /// Lora models resolved from a Lora model directory
    /// (set through `ModelConfigBuilder::lora_models`)
    #[builder(default = "Default::default()", setter(custom))]
    lora_models: LoraStorage,

    /// Path to the standalone high noise diffusion model
    #[builder(default = "Default::default()")]
    high_noise_diffusion_model: CLibPath,

    /// Process vae in tiles to reduce memory usage (default: false)
    #[builder(default = "false")]
    vae_tiling: bool,

    /// Tile size for vae tiling (default: 32x32)
    #[builder(default = "(32,32)")]
    vae_tile_size: (i32, i32),

    /// Relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides vae_tile_size)
    #[builder(default = "(0.,0.)")]
    vae_relative_tile_size: (f32, f32),

    /// Tile overlap for vae tiling, in fraction of tile size (default: 0.5)
    #[builder(default = "0.5")]
    vae_tile_overlap: f32,

    /// RNG (default: CUDA)
    #[builder(default = "RngFunction::CUDA_RNG")]
    rng: RngFunction,

    /// Sampler RNG. If [RngFunction::RNG_TYPE_COUNT] is used will default to rng value. (default: [RngFunction::RNG_TYPE_COUNT])
    #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
    sampler_rng_type: RngFunction,

    /// Denoiser sigma schedule (default: [Scheduler::SCHEDULER_COUNT]).
    /// Will default to [Scheduler::EXPONENTIAL_SCHEDULER] if a denoiser is already instantiated.
    /// Otherwise, [Scheduler::DISCRETE_SCHEDULER] is used.
    #[builder(default = "Scheduler::SCHEDULER_COUNT")]
    scheduler: Scheduler,

    /// Custom sigma values for the sampler
    #[builder(default = "Default::default()")]
    sigmas: Vec<f32>,

    /// Prediction type override (default: PREDICTION_COUNT)
    #[builder(default = "Prediction::PREDICTION_COUNT")]
    prediction: Prediction,

    /// Keep vae in cpu (for low vram) (default: false)
    #[builder(default = "false")]
    vae_on_cpu: bool,

    /// Keep clip in cpu (for low vram) (default: false)
    #[builder(default = "false")]
    clip_on_cpu: bool,

    /// Keep controlnet in cpu (for low vram) (default: false)
    #[builder(default = "false")]
    control_net_cpu: bool,

    /// Use flash attention to reduce memory usage (model only).
    /// For most backends, it slows things down, but for cuda it generally speeds it up too. At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal).
    #[builder(default = "false")]
    diffusion_flash_attention: bool,

    /// Use flash attention to reduce memory usage.
    /// For most backends, it slows things down, but for cuda it generally speeds it up too. At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal).
    #[builder(default = "false")]
    flash_attention: bool,

    /// Disable dit mask for chroma
    #[builder(default = "false")]
    chroma_disable_dit_mask: bool,

    /// Enable t5 mask for chroma
    #[builder(default = "false")]
    chroma_enable_t5_mask: bool,

    /// t5 mask pad size of chroma
    #[builder(default = "1")]
    chroma_t5_mask_pad: i32,

    /// Use qwen image zero cond true optimization
    #[builder(default = "false")]
    use_qwen_image_zero_cond_true: bool,

    /// Use Conv2d direct in the diffusion model
    /// This might crash if it is not supported by the backend.
    #[builder(default = "false")]
    diffusion_conv_direct: bool,

    /// Use Conv2d direct in the vae model (should improve the performance)
    /// This might crash if it is not supported by the backend.
    #[builder(default = "false")]
    vae_conv_direct: bool,

    /// Force use of conv scale on sdxl vae
    #[builder(default = "false")]
    force_sdxl_vae_conv_scale: bool,

    /// Shift value for Flow models like SD3.x or WAN (default: auto, encoded as `f32::INFINITY`)
    #[builder(default = "f32::INFINITY")]
    flow_shift: f32,

    /// Shift timestep for NitroFusion models, default: 0, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
    #[builder(default = "0")]
    timestep_shift: i32,

    /// Prevents usage of taesd for decoding the final image
    #[builder(default = "false")]
    taesd_preview_only: bool,

    /// How Lora weights are applied (default: auto).
    /// In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used.
    /// The immediate mode may have precision and compatibility issues with quantized parameters, but it usually offers faster inference speed and, in some cases, lower memory usage.
    /// The at_runtime mode, on the other hand, is exactly the opposite.
    #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
    lora_apply_mode: LoraModeType,

    /// Enable circular padding for convolutions
    /// (when set, both `circular_x` and `circular_y` are passed as true to the native context)
    #[builder(default = "false")]
    circular: bool,

    /// Enable circular RoPE wrapping on x-axis (width) only
    #[builder(default = "false")]
    circular_x: bool,

    /// Enable circular RoPE wrapping on y-axis (height) only
    #[builder(default = "false")]
    circular_y: bool,

    /// Cached native upscaler context, created on first use by `ModelConfig::upscaler_ctx`
    /// and freed in `Drop`. The builder setter is private; the value always starts as `None`.
    #[builder(default = "None", private)]
    upscaler_ctx: Option<*mut upscaler_ctx_t>,

    /// Cached native diffusion context together with the parameters it was created with,
    /// managed by `ModelConfig::diffusion_ctx` and freed in `Drop`. The builder setter is
    /// private; the value always starts as `None`.
    #[builder(default = "None", private)]
    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
}
554
555impl ModelConfigBuilder {
556    fn validate(&self) -> Result<(), ConfigBuilderError> {
557        self.validate_model()
558    }
559
560    fn validate_model(&self) -> Result<(), ConfigBuilderError> {
561        self.model
562            .as_ref()
563            .or(self.diffusion_model.as_ref())
564            .map(|_| ())
565            .ok_or(ConfigBuilderError::UninitializedField(
566                "Model OR DiffusionModel must be valorized",
567            ))
568    }
569
570    fn filter_valid_extensions(path: &Path) -> impl Iterator<Item = DirEntry> {
571        WalkDir::new(path)
572            .into_iter()
573            .filter_map(|entry| entry.ok())
574            .filter(|entry| {
575                entry
576                    .path()
577                    .extension()
578                    .and_then(|ext| ext.to_str())
579                    .map(|ext_str| VALID_EXT.contains(&ext_str))
580                    .unwrap_or(false)
581            })
582    }
583    fn build_single_lora_storage(
584        spec: &LoraSpec,
585        valid_loras: &HashMap<String, PathBuf>,
586    ) -> (CLibPath, LoraSpec) {
587        let path = valid_loras.get(&spec.file_name).unwrap().as_path();
588        let c_path = CLibPath::from(path);
589        (c_path, spec.clone())
590    }
591
592    pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
593        let data: Vec<(CLibString, CLibPath)> = Self::filter_valid_extensions(embeddings_dir)
594            .map(|entry| {
595                let file_stem = entry
596                    .path()
597                    .file_stem()
598                    .and_then(|stem| stem.to_str())
599                    .unwrap_or_default()
600                    .to_owned();
601                (CLibString::from(file_stem), CLibPath::from(entry.path()))
602            })
603            .collect();
604        let data_pointer = data
605            .iter()
606            .map(|(name, path)| sd_embedding_t {
607                name: name.as_ptr(),
608                path: path.as_ptr(),
609            })
610            .collect();
611        self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
612        self
613    }
614
615    pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
616        let valid_loras: HashMap<String, PathBuf> = Self::filter_valid_extensions(lora_model_dir)
617            .map(|entry| {
618                let path = entry.path();
619                (
620                    path.file_stem()
621                        .and_then(|stem| stem.to_str())
622                        .unwrap_or_default()
623                        .to_owned(),
624                    path.to_path_buf(),
625                )
626            })
627            .collect();
628        let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
629        let standard = specs
630            .iter()
631            .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
632            .map(|s| Self::build_single_lora_storage(s, &valid_loras));
633        let high_noise = specs
634            .iter()
635            .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
636            .map(|s| Self::build_single_lora_storage(s, &valid_loras));
637
638        self.lora_models_internal(standard.chain(high_noise).collect())
639    }
640
641    fn lora_models_internal(&mut self, lora_storage: LoraStorage) -> &mut Self {
642        self.lora_models = Some(lora_storage);
643        self
644    }
645
646    pub fn n_threads(&mut self, value: i32) -> &mut Self {
647        self.n_threads = if value > 0 {
648            Some(value)
649        } else {
650            Some(num_cpus::get_physical() as i32)
651        };
652        self
653    }
654}
655
656impl ModelConfig {
657    unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
658        unsafe {
659            if self.upscale_model.is_none() || self.upscale_repeats == 0 {
660                None
661            } else {
662                if self.upscaler_ctx.is_none() {
663                    let upscaler = new_upscaler_ctx(
664                        self.upscale_model.as_ref().unwrap().as_ptr(),
665                        self.offload_params_to_cpu,
666                        self.diffusion_conv_direct,
667                        self.n_threads,
668                        self.upscale_tile_size,
669                    );
670                    self.upscaler_ctx = Some(upscaler);
671                }
672                self.upscaler_ctx
673            }
674        }
675    }
676
677    unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
678        unsafe {
679            // This is required to support img2img after text2img generation
680            // otherwise the context is cached and won't have a decode graph
681            // leading to an assertion error in sdcpp
682            if let Some((sd_ctx, sd_ctx_params)) = self.diffusion_ctx.as_ref()
683                && sd_ctx_params.vae_decode_only != vae_decode_only
684            {
685                sd_set_progress_callback(None, null_mut());
686                free_sd_ctx(*sd_ctx);
687                self.diffusion_ctx = None;
688            }
689            if self.diffusion_ctx.is_none() {
690                let sd_ctx_params = sd_ctx_params_t {
691                    model_path: self.model.as_ptr(),
692                    llm_path: self.llm.as_ptr(),
693                    llm_vision_path: self.llm_vision.as_ptr(),
694                    clip_l_path: self.clip_l.as_ptr(),
695                    clip_g_path: self.clip_g.as_ptr(),
696                    clip_vision_path: self.clip_vision.as_ptr(),
697                    high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
698                    t5xxl_path: self.t5xxl.as_ptr(),
699                    diffusion_model_path: self.diffusion_model.as_ptr(),
700                    vae_path: self.vae.as_ptr(),
701                    taesd_path: self.taesd.as_ptr(),
702                    control_net_path: self.control_net.as_ptr(),
703                    embeddings: self.embeddings.2.as_ptr(),
704                    embedding_count: self.embeddings.1.len() as u32,
705                    photo_maker_path: self.photo_maker.as_ptr(),
706                    vae_decode_only,
707                    free_params_immediately: false,
708                    n_threads: self.n_threads,
709                    wtype: self.weight_type,
710                    rng_type: self.rng,
711                    keep_clip_on_cpu: self.clip_on_cpu,
712                    keep_control_net_on_cpu: self.control_net_cpu,
713                    keep_vae_on_cpu: self.vae_on_cpu,
714                    diffusion_flash_attn: self.diffusion_flash_attention,
715                    flash_attn: self.flash_attention,
716                    diffusion_conv_direct: self.diffusion_conv_direct,
717                    chroma_use_dit_mask: !self.chroma_disable_dit_mask,
718                    chroma_use_t5_mask: self.chroma_enable_t5_mask,
719                    chroma_t5_mask_pad: self.chroma_t5_mask_pad,
720                    vae_conv_direct: self.vae_conv_direct,
721                    offload_params_to_cpu: self.offload_params_to_cpu,
722                    prediction: self.prediction,
723                    force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
724                    tae_preview_only: self.taesd_preview_only,
725                    lora_apply_mode: self.lora_apply_mode,
726                    tensor_type_rules: null_mut(),
727                    sampler_rng_type: self.sampler_rng_type,
728                    circular_x: self.circular || self.circular_x,
729                    circular_y: self.circular || self.circular_y,
730                    qwen_image_zero_cond_t: self.use_qwen_image_zero_cond_true,
731                    enable_mmap: self.enable_mmap,
732                };
733                let ctx = new_sd_ctx(&sd_ctx_params);
734                self.diffusion_ctx = Some((ctx, sd_ctx_params))
735            }
736            self.diffusion_ctx.unwrap().0
737        }
738    }
739}
740
741impl Drop for ModelConfig {
742    fn drop(&mut self) {
743        //Cleanup CTX section
744        unsafe {
745            if let Some((sd_ctx, _)) = self.diffusion_ctx {
746                free_sd_ctx(sd_ctx);
747            }
748
749            if let Some(upscaler_ctx) = self.upscaler_ctx {
750                free_upscaler_ctx(upscaler_ctx);
751            }
752        }
753    }
754}
755
756impl From<ModelConfig> for ModelConfigBuilder {
757    fn from(value: ModelConfig) -> Self {
758        let mut builder = ModelConfigBuilder::default();
759        builder
760            .n_threads(value.n_threads)
761            .offload_params_to_cpu(value.offload_params_to_cpu)
762            .upscale_repeats(value.upscale_repeats)
763            .model(value.model.clone())
764            .diffusion_model(value.diffusion_model.clone())
765            .llm(value.llm.clone())
766            .llm_vision(value.llm_vision.clone())
767            .clip_l(value.clip_l.clone())
768            .clip_g(value.clip_g.clone())
769            .clip_vision(value.clip_vision.clone())
770            .t5xxl(value.t5xxl.clone())
771            .vae(value.vae.clone())
772            .taesd(value.taesd.clone())
773            .control_net(value.control_net.clone())
774            .embeddings(&value.embeddings.0)
775            .photo_maker(value.photo_maker.clone())
776            .pm_id_embed_path(value.pm_id_embed_path.clone())
777            .weight_type(value.weight_type)
778            .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
779            .vae_tiling(value.vae_tiling)
780            .vae_tile_size(value.vae_tile_size)
781            .vae_relative_tile_size(value.vae_relative_tile_size)
782            .vae_tile_overlap(value.vae_tile_overlap)
783            .rng(value.rng)
784            .sampler_rng_type(value.rng)
785            .scheduler(value.scheduler)
786            .sigmas(value.sigmas.clone())
787            .prediction(value.prediction)
788            .vae_on_cpu(value.vae_on_cpu)
789            .clip_on_cpu(value.clip_on_cpu)
790            .control_net(value.control_net.clone())
791            .control_net_cpu(value.control_net_cpu)
792            .flash_attention(value.flash_attention)
793            .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
794            .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
795            .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
796            .diffusion_conv_direct(value.diffusion_conv_direct)
797            .vae_conv_direct(value.vae_conv_direct)
798            .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
799            .flow_shift(value.flow_shift)
800            .timestep_shift(value.timestep_shift)
801            .taesd_preview_only(value.taesd_preview_only)
802            .lora_apply_mode(value.lora_apply_mode)
803            .circular(value.circular)
804            .circular_x(value.circular_x)
805            .circular_y(value.circular_y)
806            .use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true);
807
808        builder.lora_models_internal(value.lora_models.clone());
809
810        if let Some(model) = &value.upscale_model {
811            builder.upscale_model(model.clone());
812        }
813        builder
814    }
815}
816
#[derive(Builder, Debug, Clone)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
/// Config struct common to all diffusion methods
///
/// Instances are created through the generated `ConfigBuilder`. `prompt` is the only
/// field without a builder default, so it must always be provided. `build()` additionally
/// runs `ConfigBuilder::validate`.
pub struct Config {
    /// Path to PHOTOMAKER input id images dir
    #[builder(default = "Default::default()")]
    pm_id_images_dir: CLibPath,

    /// Path to the input image, required by img2img
    #[builder(default = "Default::default()")]
    init_img: PathBuf,

    /// Path to the image used as a mask for img2img
    #[builder(default = "Default::default()")]
    mask_img: PathBuf,

    /// Path to image condition, control net
    #[builder(default = "Default::default()")]
    control_image: CLibPath,

    /// Paths to reference images for in-context conditioning (e.g. for Flux2)
    #[builder(default = "Default::default()")]
    ref_images: Vec<PathBuf>,

    /// Path to write result image to (default: ./output.png)
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,

    /// Path to write the preview image to (default: ./preview_output.png)
    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
    preview_output: PathBuf,

    /// Preview method (default: [PreviewType::PREVIEW_NONE], i.e. no preview is written)
    #[builder(default = "PreviewType::PREVIEW_NONE")]
    preview_mode: PreviewType,

    /// Enables previewing noisy inputs of the models rather than the denoised outputs
    #[builder(default = "false")]
    preview_noisy: bool,

    /// Interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)
    #[builder(default = "1")]
    preview_interval: i32,

    /// The prompt to render
    prompt: String,

    /// The negative prompt (default: "")
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,

    /// Unconditional guidance scale (default: 7.0)
    #[builder(default = "7.0")]
    cfg_scale: f32,

    /// Distilled guidance scale for models with guidance input (default: 3.5)
    #[builder(default = "3.5")]
    guidance: f32,

    /// Strength for noising/unnoising (default: 0.75)
    #[builder(default = "0.75")]
    strength: f32,

    /// Strength for keeping input identity (default: 20%)
    #[builder(default = "20.0")]
    pm_style_strength: f32,

    /// Strength to apply Control Net (default: 0.9)
    /// 1.0 corresponds to full destruction of information in init
    #[builder(default = "0.9")]
    control_strength: f32,

    /// Image height, in pixel space (default: 512)
    #[builder(default = "512")]
    height: i32,

    /// Image width, in pixel space (default: 512)
    #[builder(default = "512")]
    width: i32,

    /// Sampling-method (default: [SampleMethod::SAMPLE_METHOD_COUNT]).
    /// [SampleMethod::EULER_SAMPLE_METHOD] will be used for flux, sd3, wan, qwen_image.
    /// Otherwise [SampleMethod::EULER_A_SAMPLE_METHOD] is used.
    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
    sampling_method: SampleMethod,

    /// eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
    #[builder(default = "0.")]
    eta: f32,

    /// Number of sample steps (default: 20)
    #[builder(default = "20")]
    steps: i32,

    /// RNG seed (default: 42, use random seed for < 0)
    #[builder(default = "42")]
    seed: i64,

    /// Number of images to generate (default: 1)
    #[builder(default = "1")]
    batch_count: i32,

    /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer
    /// (default: [ClipSkip::Unspecified])
    /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,

    /// Apply canny preprocessor (edge detection) (default: false)
    #[builder(default = "false")]
    canny: bool,

    /// skip layer guidance (SLG) scale, only for DiT models: (default: 0)
    /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium
    #[builder(default = "0.")]
    slg_scale: f32,

    /// Layers to skip for SLG steps: (default: \[7,8,9\])
    #[builder(default = "vec![7, 8, 9]")]
    skip_layer: Vec<i32>,

    /// SLG enabling point: (default: 0.01)
    #[builder(default = "0.01")]
    skip_layer_start: f32,

    /// SLG disabling point: (default: 0.2)
    #[builder(default = "0.2")]
    skip_layer_end: f32,

    /// Disable auto resize of ref images
    #[builder(default = "false")]
    disable_auto_resize_ref_image: bool,

    /// Raw cache parameters handed to the native library (default: caching disabled).
    /// The setter is private; use the `*_caching` methods on `ConfigBuilder` instead.
    #[builder(default = "Self::cache_init()", private)]
    cache: sd_cache_params_t,

    /// C string holding the SCM mask referenced by `cache.scm_mask` (default: empty).
    /// The setter is private; it is filled in by `ConfigBuilder::db_cache_caching`.
    #[builder(default = "CLibString::default()", private)]
    scm_mask: CLibString,
}
955
956impl ConfigBuilder {
957    fn validate(&self) -> Result<(), ConfigBuilderError> {
958        self.validate_output_dir()
959    }
960
961    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
962        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
963        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
964        if is_dir == multiple_items {
965            Ok(())
966        } else {
967            Err(ConfigBuilderError::ValidationError(
968                "When batch_count > 1, output should point to folder and vice versa".to_owned(),
969            ))
970        }
971    }
972
973    fn cache_init() -> sd_cache_params_t {
974        sd_cache_params_t {
975            mode: sd_cache_mode_t::SD_CACHE_DISABLED,
976            reuse_threshold: 1.0,
977            start_percent: 0.15,
978            end_percent: 0.95,
979            error_decay_rate: 1.0,
980            use_relative_threshold: true,
981            reset_error_on_compute: true,
982            Fn_compute_blocks: 8,
983            Bn_compute_blocks: 0,
984            residual_diff_threshold: 0.08,
985            max_warmup_steps: 8,
986            max_cached_steps: -1,
987            max_continuous_cached_steps: -1,
988            taylorseer_n_derivatives: 1,
989            taylorseer_skip_interval: 1,
990            scm_mask: null(),
991            scm_policy_dynamic: true,
992        }
993    }
994
995    pub fn no_caching(&mut self) -> &mut Self {
996        let mut cache = Self::cache_init();
997        cache.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
998        self.cache = Some(cache);
999        self
1000    }
1001
1002    pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
1003        let mut cache = Self::cache_init();
1004        cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
1005        cache.reuse_threshold = params.threshold;
1006        cache.start_percent = params.start;
1007        cache.end_percent = params.end;
1008        cache.error_decay_rate = params.decay;
1009        cache.use_relative_threshold = params.relative;
1010        cache.reset_error_on_compute = params.reset;
1011        self.cache = Some(cache);
1012        self
1013    }
1014
1015    pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
1016        let mut cache = Self::cache_init();
1017        cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
1018        cache.reuse_threshold = params.threshold;
1019        cache.start_percent = params.start;
1020        cache.end_percent = params.end;
1021        self.cache = Some(cache);
1022        self
1023    }
1024
1025    pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
1026        let mut cache = Self::cache_init();
1027        cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
1028        cache.Fn_compute_blocks = params.fn_blocks;
1029        cache.Bn_compute_blocks = params.bn_blocks;
1030        cache.residual_diff_threshold = params.threshold;
1031        cache.max_warmup_steps = params.warmup;
1032        cache.scm_policy_dynamic = match params.scm_policy_dynamic {
1033            ScmPolicy::Static => false,
1034            ScmPolicy::Dynamic => true,
1035        };
1036        self.scm_mask = Some(CLibString::from(
1037            params
1038                .scm_mask
1039                .to_vec_string(self.steps.unwrap_or_default()),
1040        ));
1041        cache.scm_mask = self.scm_mask.as_ref().unwrap().as_ptr();
1042
1043        self.cache = Some(cache);
1044        self
1045    }
1046
1047    pub fn taylor_seer_caching(&mut self) -> &mut Self {
1048        let mut cache = Self::cache_init();
1049        cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
1050        self.cache = Some(cache);
1051        self
1052    }
1053
1054    pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
1055        self.db_cache_caching(params).cache.unwrap().mode = sd_cache_mode_t::SD_CACHE_CACHE_DIT;
1056        self
1057    }
1058}
1059
1060impl From<Config> for ConfigBuilder {
1061    fn from(value: Config) -> Self {
1062        let mut builder = ConfigBuilder::default();
1063        let mut cache = value.cache;
1064        let scm_mask = value.scm_mask.clone();
1065        cache.scm_mask = scm_mask.as_ptr();
1066        builder
1067            .pm_id_images_dir(value.pm_id_images_dir)
1068            .init_img(value.init_img)
1069            .mask_img(value.mask_img)
1070            .control_image(value.control_image)
1071            .ref_images(value.ref_images)
1072            .output(value.output)
1073            .prompt(value.prompt)
1074            .negative_prompt(value.negative_prompt)
1075            .cfg_scale(value.cfg_scale)
1076            .strength(value.strength)
1077            .pm_style_strength(value.pm_style_strength)
1078            .control_strength(value.control_strength)
1079            .height(value.height)
1080            .width(value.width)
1081            .sampling_method(value.sampling_method)
1082            .steps(value.steps)
1083            .seed(value.seed)
1084            .batch_count(value.batch_count)
1085            .clip_skip(value.clip_skip)
1086            .slg_scale(value.slg_scale)
1087            .skip_layer(value.skip_layer)
1088            .skip_layer_start(value.skip_layer_start)
1089            .skip_layer_end(value.skip_layer_end)
1090            .canny(value.canny)
1091            .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
1092            .preview_output(value.preview_output)
1093            .preview_mode(value.preview_mode)
1094            .preview_noisy(value.preview_noisy)
1095            .preview_interval(value.preview_interval)
1096            .cache(cache)
1097            .scm_mask(scm_mask);
1098        builder
1099    }
1100}
1101
/// Owned, NUL-terminated string that can be handed to the native library.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Pointer to the NUL-terminated bytes; valid for as long as `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Builds the C string from raw bytes without ever panicking.
    ///
    /// A C string cannot contain an interior NUL, so the input is cut at the first NUL
    /// byte. That is exactly what the C side would see if the bytes were passed through.
    fn from_bytes_truncating(bytes: impl Into<Vec<u8>>) -> Self {
        match CString::new(bytes) {
            Ok(text) => Self(text),
            Err(err) => {
                let nul = err.nul_position();
                let mut bytes = err.into_vec();
                bytes.truncate(nul);
                Self(CString::new(bytes).expect("no NUL byte precedes the first NUL byte"))
            }
        }
    }
}

impl From<&str> for CLibString {
    fn from(value: &str) -> Self {
        Self::from_bytes_truncating(value)
    }
}

impl From<String> for CLibString {
    fn from(value: String) -> Self {
        Self::from_bytes_truncating(value)
    }
}
1122
/// Owned, NUL-terminated path that can be handed to the native library.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Pointer to the NUL-terminated bytes; valid for as long as `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Shared conversion used by the `From` impls below.
    ///
    /// Paths that cannot be represented as a C string become the empty path instead of
    /// panicking: non-UTF-8 paths (as before) and paths containing a NUL byte.
    /// Truncating at the NUL is deliberately avoided because it could silently select
    /// a different file.
    fn from_path(path: &Path) -> Self {
        let text = path.to_str().unwrap_or_default();
        Self(CString::new(text).unwrap_or_default())
    }
}

impl From<PathBuf> for CLibPath {
    fn from(value: PathBuf) -> Self {
        Self::from_path(&value)
    }
}

impl From<&Path> for CLibPath {
    fn from(value: &Path) -> Self {
        Self::from_path(value)
    }
}

impl From<&CLibPath> for PathBuf {
    fn from(value: &CLibPath) -> Self {
        // The wrapped string is always built from a `&str`, so this never replaces bytes.
        PathBuf::from(value.0.to_string_lossy().into_owned())
    }
}
1149
1150fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
1151    let date = Local::now().format("%Y.%m.%d-%H.%M.%S");
1152    if batch_size == 1 {
1153        vec![path.into()]
1154    } else {
1155        (1..=batch_size)
1156            .map(|id| path.join(format!("output_{date}_{id}.png")))
1157            .collect()
1158    }
1159}
1160
1161unsafe fn upscale(
1162    upscale_repeats: i32,
1163    upscaler_ctx: Option<*mut upscaler_ctx_t>,
1164    data: sd_image_t,
1165) -> Result<sd_image_t, DiffusionError> {
1166    unsafe {
1167        match upscaler_ctx {
1168            Some(upscaler_ctx) => {
1169                let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
1170                let mut current_image = data;
1171                for _ in 0..upscale_repeats {
1172                    let upscaled_image =
1173                        diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
1174
1175                    if upscaled_image.data.is_null() {
1176                        return Err(DiffusionError::Upscaler);
1177                    }
1178
1179                    free(current_image.data as *mut c_void);
1180                    current_image = upscaled_image;
1181                }
1182                Ok(current_image)
1183            }
1184            None => Ok(data),
1185        }
1186    }
1187}
1188
1189/// Generate an image and receive update via queue
1190pub fn gen_img_with_progress(
1191    config: &Config,
1192    model_config: &mut ModelConfig,
1193    sender: Sender<Progress>,
1194) -> Result<(), DiffusionError> {
1195    gen_img_maybe_progress(config, model_config, Some(sender))
1196}
1197
1198/// Generate an image
1199pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
1200    gen_img_maybe_progress(config, model_config, None)
1201}
1202
1203fn gen_img_maybe_progress(
1204    config: &Config,
1205    model_config: &mut ModelConfig,
1206    mut sender: Option<Sender<Progress>>,
1207) -> Result<(), DiffusionError> {
1208    let prompt: CLibString = CLibString::from(config.prompt.as_str());
1209    let files = output_files(&config.output, config.batch_count);
1210    unsafe {
1211        let has_init_image = config.init_img.exists();
1212        let has_mask_image = config.mask_img.exists();
1213
1214        let is_decode_only = !has_init_image;
1215        let sd_ctx = model_config.diffusion_ctx(is_decode_only);
1216        let upscaler_ctx = model_config.upscaler_ctx();
1217
1218        let mut init_image = sd_image_t {
1219            width: 0,
1220            height: 0,
1221            channel: 3,
1222            data: std::ptr::null_mut(),
1223        };
1224        let mut mask_image = sd_image_t {
1225            width: config.width as u32,
1226            height: config.height as u32,
1227            channel: 1,
1228            data: null_mut(),
1229        };
1230        let mut layers = config.skip_layer.clone();
1231        let guidance = sd_guidance_params_t {
1232            txt_cfg: config.cfg_scale,
1233            img_cfg: config.cfg_scale,
1234            distilled_guidance: config.guidance,
1235            slg: sd_slg_params_t {
1236                layers: layers.as_mut_ptr(),
1237                layer_count: config.skip_layer.len(),
1238                layer_start: config.skip_layer_start,
1239                layer_end: config.skip_layer_end,
1240                scale: config.slg_scale,
1241            },
1242        };
1243        let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
1244            sd_get_default_scheduler(sd_ctx, config.sampling_method)
1245        } else {
1246            model_config.scheduler
1247        };
1248        let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
1249            sd_get_default_sample_method(sd_ctx)
1250        } else {
1251            config.sampling_method
1252        };
1253        let sample_params = sd_sample_params_t {
1254            guidance,
1255            sample_method,
1256            sample_steps: config.steps,
1257            eta: config.eta,
1258            scheduler,
1259            shifted_timestep: model_config.timestep_shift,
1260            custom_sigmas: model_config.sigmas.as_mut_ptr(),
1261            custom_sigmas_count: model_config.sigmas.len() as i32,
1262            flow_shift: model_config.flow_shift,
1263        };
1264        let control_image = sd_image_t {
1265            width: 0,
1266            height: 0,
1267            channel: 3,
1268            data: null_mut(),
1269        };
1270        let vae_tiling_params = sd_tiling_params_t {
1271            enabled: model_config.vae_tiling,
1272            tile_size_x: model_config.vae_tile_size.0,
1273            tile_size_y: model_config.vae_tile_size.1,
1274            target_overlap: model_config.vae_tile_overlap,
1275            rel_size_x: model_config.vae_relative_tile_size.0,
1276            rel_size_y: model_config.vae_relative_tile_size.1,
1277        };
1278        let pm_params = sd_pm_params_t {
1279            id_images: null_mut(),
1280            id_images_count: 0,
1281            id_embed_path: model_config.pm_id_embed_path.as_ptr(),
1282            style_strength: config.pm_style_strength,
1283        };
1284
1285        // Declare the buffer in the function scope so it outlives the match block
1286        let mut image_buffer: Vec<u8> = Vec::new();
1287        let mut mask_buffer: Vec<u8> = Vec::new();
1288
1289        if has_init_image {
1290            let img = image::open(&config.init_img)?;
1291            image_buffer = img.to_rgb8().into_raw();
1292
1293            init_image = sd_image_t {
1294                width: img.width(),
1295                height: img.height(),
1296                channel: 3,
1297                data: image_buffer.as_mut_ptr(),
1298            }
1299        }
1300
1301        if has_mask_image {
1302            let img = image::open(&config.mask_img)?;
1303            // Masks have to have single channel luminosity information only
1304            mask_buffer = img.to_luma8().into_raw();
1305
1306            mask_image = sd_image_t {
1307                width: img.width(),
1308                height: img.height(),
1309                channel: 1,
1310                data: mask_buffer.as_mut_ptr(),
1311            }
1312        }
1313
1314        // stable-diffusion.cpp assumes that a mask is also included with the img2img flow
1315        // if a mask is not provided, we use a flat white mask meaning all of the image is in scope
1316        // otherwise generate_image throws a sigsegv when it tries to assign the mask
1317        if !image_buffer.is_empty() && mask_buffer.is_empty() {
1318            let img: ImageBuffer<image::Luma<u8>, Vec<u8>> =
1319                ImageBuffer::from_pixel(init_image.width, init_image.height, image::Luma([255]));
1320            mask_buffer = img.into_raw();
1321            mask_image = sd_image_t {
1322                width: init_image.width,
1323                height: init_image.height,
1324                channel: 1,
1325                data: mask_buffer.as_mut_ptr(),
1326            }
1327        }
1328
1329        let mut ref_image_list = Vec::new();
1330        let mut ref_pixel_storage = Vec::new();
1331        for ref_path in &config.ref_images {
1332            if ref_path.exists() {
1333                let img = image::open(ref_path)?;
1334                let image_data = img.to_rgb8().into_raw();
1335
1336                ref_pixel_storage.push(image_data);
1337                let storage_ref = ref_pixel_storage.last_mut().unwrap();
1338                ref_image_list.push(sd_image_t {
1339                    width: img.width(),
1340                    height: img.height(),
1341                    channel: 3,
1342                    data: storage_ref.as_mut_ptr(),
1343                });
1344            }
1345        }
1346
1347        let num_ref_images = ref_image_list.len();
1348        let ref_image_ptr = if num_ref_images > 0 {
1349            ref_image_list.as_mut_ptr()
1350        } else {
1351            null_mut()
1352        };
1353
1354        unsafe extern "C" fn save_preview_local(
1355            _step: ::std::os::raw::c_int,
1356            _frame_count: ::std::os::raw::c_int,
1357            frames: *mut sd_image_t,
1358            _is_noisy: bool,
1359            data: *mut ::std::os::raw::c_void,
1360        ) {
1361            unsafe {
1362                let path = &*data.cast::<PathBuf>();
1363                let _ = save_img(*frames, path, None);
1364            }
1365        }
1366
1367        if config.preview_mode != PreviewType::PREVIEW_NONE {
1368            let data = &config.preview_output as *const PathBuf;
1369
1370            sd_set_preview_callback(
1371                Some(save_preview_local),
1372                config.preview_mode,
1373                config.preview_interval,
1374                !config.preview_noisy,
1375                config.preview_noisy,
1376                data as *mut c_void,
1377            );
1378        }
1379
1380        if sender.is_some() {
1381            unsafe extern "C" fn progress_callback(
1382                step: ::std::os::raw::c_int,
1383                steps: ::std::os::raw::c_int,
1384                time: f32,
1385                data: *mut ::std::os::raw::c_void,
1386            ) {
1387                unsafe {
1388                    let sender = &*data.cast::<Option<Sender<Progress>>>();
1389
1390                    if let Some(sender) = sender {
1391                        let _ = sender.send(Progress { step, steps, time });
1392                    }
1393                }
1394            }
1395            let sender_ptr: *mut c_void = &mut sender as *mut _ as *mut c_void;
1396            sd_set_progress_callback(Some(progress_callback), sender_ptr);
1397        }
1398
1399        let loras: Vec<sd_lora_t> = model_config
1400            .lora_models
1401            .iter()
1402            .map(|(c_path, spec)| sd_lora_t {
1403                is_high_noise: spec.is_high_noise,
1404                multiplier: spec.multiplier,
1405                path: c_path.as_ptr(),
1406            })
1407            .collect();
1408
1409        let sd_img_gen_params = sd_img_gen_params_t {
1410            prompt: prompt.as_ptr(),
1411            negative_prompt: config.negative_prompt.as_ptr(),
1412            clip_skip: config.clip_skip as i32,
1413            init_image,
1414            ref_images: ref_image_ptr,
1415            ref_images_count: num_ref_images as i32,
1416            increase_ref_index: false,
1417            mask_image,
1418            width: config.width,
1419            height: config.height,
1420            sample_params,
1421            strength: config.strength,
1422            seed: config.seed,
1423            batch_count: config.batch_count,
1424            control_image,
1425            control_strength: config.control_strength,
1426            pm_params,
1427            vae_tiling_params,
1428            auto_resize_ref_image: config.disable_auto_resize_ref_image,
1429            cache: config.cache,
1430            loras: loras.as_ptr(),
1431            lora_count: loras.len() as u32,
1432        };
1433
1434        let params_str = CString::from_raw(sd_img_gen_params_to_str(&sd_img_gen_params))
1435            .into_string()
1436            .unwrap();
1437
1438        let slice = generate_image(sd_ctx, &sd_img_gen_params);
1439        let ret = {
1440            if slice.is_null() {
1441                return Err(DiffusionError::Forward);
1442            }
1443            for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
1444                .iter()
1445                .zip(files)
1446            {
1447                match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
1448                    Ok(img) => save_img(img, &path, Some(&params_str))?,
1449                    Err(err) => {
1450                        return Err(err);
1451                    }
1452                }
1453            }
1454            Ok(())
1455        };
1456        free(slice as *mut c_void);
1457        ret
1458    }
1459}
1460
1461fn save_img(img: sd_image_t, path: &Path, params: Option<&str>) -> Result<(), DiffusionError> {
1462    // Thx @wandbrandon
1463    let len = (img.width * img.height * img.channel) as usize;
1464    let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1465    let save_state = ImageBuffer::from_raw(img.width, img.height, buffer).map(|img| {
1466        RgbImage::from(img)
1467            .save(path)
1468            .map_err(DiffusionError::StoreImages)
1469    });
1470    if let Some(Err(err)) = save_state {
1471        return Err(err);
1472    }
1473    if let Some(params) = params {
1474        let mut metadata = Metadata::new();
1475        metadata.set_tag(ExifTag::ImageDescription(params.to_string()));
1476        metadata.write_to_file(path)?;
1477    }
1478    Ok(())
1479}
1480
// Unit and integration tests for the config builders and the generation entry points.
// The `#[ignore]`d tests download model weights and run a full generation.
#[cfg(test)]
mod tests {
    use image::{DynamicImage, ImageBuffer, Rgba};
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    /// Checks which builder fields are mandatory and that the output/batch_count
    /// validation is enforced by `build()`.
    #[test]
    fn test_required_args_txt2img() {
        // Neither builder can be built without its required fields.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // More than one image with the default (file) output must fail validation.
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // More than one image is accepted once the output points to a directory.
        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// Runs txt2img, then img2img on the generated image with a reference image, then
    /// txt2img again with the same model config.
    #[ignore]
    #[test]
    fn test_img2img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let gen_img_output = "./output_img.png";
        let config = ConfigBuilder::default()
            .prompt("A high quality 3d texture")
            .output(PathBuf::from(gen_img_output))
            .batch_count(1)
            .build()
            .unwrap();

        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .build()
            .unwrap();

        // 1. Generate the image that later serves as the img2img input
        gen_img(&config, &mut model_config).unwrap();

        // 2. Create conditioning image: Gradient square
        let mut cond = ImageBuffer::new(512, 512);
        for (x, y, pixel) in cond.enumerate_pixels_mut() {
            let r = (x as f32 / 512.0 * 255.0) as u8;
            let g = (y as f32 / 512.0 * 255.0) as u8;
            let b = 127;
            *pixel = Rgba([r, g, b, 255]);
        }
        let cond_path = "test_cond_image.png";
        DynamicImage::ImageRgba8(cond)
            .save(cond_path)
            .expect("Failed to save reference image");

        // 3. Call refinement using the generated image as input
        let refine_prompt = "PBR texture map, matching the lighting and micro-detail density of the reference image.";
        let img2img_config = ConfigBuilder::default()
            .prompt(refine_prompt)
            .output(PathBuf::from("./output_img_ref.png"))
            .ref_images(vec![PathBuf::from(cond_path)])
            .init_img(PathBuf::from(gen_img_output))
            .batch_count(1)
            .build()
            .unwrap();
        gen_img(&img2img_config, &mut model_config).unwrap();

        // 4. Ensure decoder only mode works after img2img generation
        gen_img(&config, &mut model_config).unwrap();
    }

    /// Runs txt2img with an upscaler, then re-uses the config through
    /// `ConfigBuilder::from` for a second single image and for a batch of two images.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&config, &mut model_config).unwrap();
        // Same settings, different prompt and output file.
        let config2 = ConfigBuilder::from(config.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&config2, &mut model_config).unwrap();

        // Batch of two images, so the output has to be a directory.
        let config3 = ConfigBuilder::from(config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();

        gen_img(&config3, &mut model_config).unwrap();
    }
}