Skip to main content

diffusion_rs/
api.rs

1use std::collections::HashMap;
2use std::ffi::CString;
3use std::ffi::c_char;
4use std::ffi::c_void;
5use std::path::Path;
6use std::path::PathBuf;
7use std::ptr::null;
8use std::ptr::null_mut;
9use std::slice;
10use std::str::FromStr;
11use std::sync::mpsc::Sender;
12
13use chrono::Local;
14use derive_builder::Builder;
15use diffusion_rs_sys::free_upscaler_ctx;
16use diffusion_rs_sys::generate_image;
17use diffusion_rs_sys::new_upscaler_ctx;
18use diffusion_rs_sys::sd_cache_mode_t;
19use diffusion_rs_sys::sd_cache_params_t;
20use diffusion_rs_sys::sd_ctx_params_t;
21use diffusion_rs_sys::sd_embedding_t;
22use diffusion_rs_sys::sd_get_default_sample_method;
23use diffusion_rs_sys::sd_get_default_scheduler;
24use diffusion_rs_sys::sd_guidance_params_t;
25use diffusion_rs_sys::sd_hires_params_t;
26use diffusion_rs_sys::sd_image_t;
27use diffusion_rs_sys::sd_img_gen_params_t;
28use diffusion_rs_sys::sd_img_gen_params_to_str;
29use diffusion_rs_sys::sd_lora_t;
30use diffusion_rs_sys::sd_pm_params_t;
31use diffusion_rs_sys::sd_sample_params_t;
32use diffusion_rs_sys::sd_set_preview_callback;
33use diffusion_rs_sys::sd_set_progress_callback;
34use diffusion_rs_sys::sd_slg_params_t;
35use diffusion_rs_sys::sd_tiling_params_t;
36use diffusion_rs_sys::upscaler_ctx_t;
37use image::ImageBuffer;
38use image::ImageError;
39use image::RgbImage;
40use libc::free;
41use little_exif::exif_tag::ExifTag;
42use little_exif::metadata::Metadata;
43use thiserror::Error;
44use walkdir::DirEntry;
45use walkdir::WalkDir;
46
47use diffusion_rs_sys::free_sd_ctx;
48use diffusion_rs_sys::new_sd_ctx;
49use diffusion_rs_sys::sd_ctx_t;
50
/// Random number generator (RNG) type
pub use diffusion_rs_sys::rng_type_t as RngFunction;

/// Sampling methods
pub use diffusion_rs_sys::sample_method_t as SampleMethod;

/// Denoiser sigma schedule
pub use diffusion_rs_sys::scheduler_t as Scheduler;

/// Prediction override
pub use diffusion_rs_sys::prediction_t as Prediction;

/// Weight type
pub use diffusion_rs_sys::sd_type_t as WeightType;

/// Preview mode
pub use diffusion_rs_sys::preview_t as PreviewType;

/// Lora apply mode
pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;

/// Hires fix upscaler
pub use diffusion_rs_sys::sd_hires_upscaler_t as Upscaler;

/// VAE latent format
pub use diffusion_rs_sys::sd_vae_format_t as VaeFormat;

/// File extensions (without the leading dot) accepted when scanning a directory
/// for model files; the comparison is an exact, case-sensitive match.
static VALID_EXT: [&str; 3] = ["gguf", "safetensors", "pt"];
79
// The fields are never read in the code visible here; the struct is only
// printed through its `Debug` impl, hence the `allow(unused)`.
#[allow(unused)]
#[derive(Debug)]
/// Progress message returned from [gen_img_with_progress]
pub struct Progress {
    // NOTE(review): presumably the step that has just been completed — confirm against the progress callback
    step: i32,
    // NOTE(review): presumably the total number of steps — confirm against the progress callback
    steps: i32,
    // NOTE(review): presumably a duration reported by the callback; the unit is not visible here — confirm
    time: f32,
}
88
#[non_exhaustive]
#[derive(Error, Debug)]
/// Errors that can occur while forwarding models
pub enum DiffusionError {
    /// The native generation call returned a NULL result.
    #[error("The underling stablediffusion.cpp function returned NULL")]
    Forward,
    /// An [ImageError] raised by the `image` crate, forwarded unchanged.
    #[error(transparent)]
    StoreImages(#[from] ImageError),
    /// A [std::io::Error], forwarded unchanged.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// The upscaler model returned a NULL image.
    #[error("The underling upscaler model returned a NULL image")]
    Upscaler,
}
102
#[non_exhaustive]
#[derive(Clone, Debug, strum::Display, strum::EnumIter)]
#[strum(serialize_all = "lowercase")]
/// Backend devices
///
/// The `Display` implementation renders the variant name in lowercase
/// (e.g. `CUDA0` becomes `cuda0`); that text is what ends up in the
/// `module=device` backend strings built by [ModelConfigBuilder].
pub enum BackendDevice {
    CPU,
    CUDA0,
    VULKAN0,
    METAL,
    GPU,
    AUTO,
    DISK,
    DEFAULT,
}
117
#[non_exhaustive]
#[derive(Clone, Debug, strum::Display, strum::EnumIter, PartialEq, Eq, Hash)]
#[strum(serialize_all = "lowercase")]
/// Module that can be bound to a specific [BackendDevice]
///
/// Used as the key of the backend maps accepted by [ModelConfigBuilder]; the
/// `Display` implementation renders the variant name in lowercase
/// (e.g. `Unet` becomes `unet`).
// NOTE(review): several variants look like aliases of the same component
// (e.g. `Te` / `Text` / `Textencoder`) — confirm against the native backend parser.
pub enum Module {
    Diffusion,
    Model,
    Unet,
    Dit,
    Te,
    Clip,
    Text,
    Textencoder,
    Textencoders,
    Conditioner,
    Cond,
    Llm,
    T5,
    T5xxl,
    ClipVision,
    Vision,
    Vae,
    Firststage,
    Autoencoder,
    Tae,
    Controlnet,
    Control,
    Photomaker,
    PhotomakerId,
    PmId,
    Photo,
    Upscaler,
    Esrgan,
    Hires,
}
153
#[repr(i32)]
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
/// Ignore the lower X layers of CLIP network
///
/// The enum is `#[repr(i32)]`, so each variant carries the explicit integer
/// discriminant listed below.
pub enum ClipSkip {
    /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x
    #[default]
    Unspecified = 0,
    /// Discriminant 1.
    None = 1,
    /// Discriminant 2.
    OneLayer = 2,
}
165
/// Embeddings state kept by [ModelConfig]:
/// 0. the directory that was scanned,
/// 1. the owned `(name, path)` C strings, one pair per embedding file found,
/// 2. the FFI descriptors whose `name`/`path` pointers refer to the strings in element 1.
type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
/// Resolved Lora models: the path of each Lora file paired with the spec that selected it.
type LoraStorage = Vec<(CLibPath, LoraSpec)>;
168
/// Specify the instructions for a Lora model
#[derive(Default, Debug, Clone)]
pub struct LoraSpec {
    /// Name of the Lora file without its extension; it is matched against the
    /// file stems found in the Lora model directory.
    pub file_name: String,
    /// Marks the Lora as a high-noise one; high-noise entries are stored after
    /// the standard ones.
    pub is_high_noise: bool,
    // NOTE(review): presumably the strength applied to this Lora — confirm where the value is consumed
    pub multiplier: f32,
}
176
/// Parameters for Spectrum Caching
///
/// Built through the generated `SpectrumCacheParamsBuilder`; every field has a
/// builder default, so setting a field is optional.
#[derive(Builder, Debug, Clone)]
pub struct SpectrumCacheParams {
    /// Chebyshev vs Taylor blend weight (0=Taylor, 1=Chebyshev) (default: 0.40)
    #[builder(default = "0.40")]
    w: f32,

    /// Chebyshev polynomial degree (default: 3)
    #[builder(default = "3")]
    m: i32,

    /// Ridge regression regularization (default: 1.0)
    #[builder(default = "1.0")]
    lam: f32,

    /// Initial window size (compute every N steps) (default: 2)
    #[builder(default = "2")]
    window: i32,

    /// Window growth per computed step after warmup (default: 0.50)
    #[builder(default = "0.50")]
    flex: f32,

    /// Steps to always compute before caching starts (default: 4)
    #[builder(default = "4")]
    warmup: i32,

    /// Stop caching at this fraction of total steps (default: 0.9)
    #[builder(default = "0.9")]
    stop: f32,
}
208
/// Parameters for UCache
///
/// Built through the generated `UCacheParamsBuilder`; every field has a
/// builder default, so setting a field is optional.
#[derive(Builder, Debug, Clone)]
pub struct UCacheParams {
    /// Error threshold for reuse decision (default: 1.0)
    #[builder(default = "1.0")]
    threshold: f32,

    /// Start caching at this percent of steps (default: 0.15)
    #[builder(default = "0.15")]
    start: f32,

    /// Stop caching at this percent of steps (default: 0.95)
    #[builder(default = "0.95")]
    end: f32,

    /// Error decay rate (0-1) (default: 1.0)
    #[builder(default = "1.0")]
    decay: f32,

    /// Scale threshold by output norm (default: true)
    #[builder(default = "true")]
    relative: bool,

    /// Reset error after computing (default: true)
    /// true: Resets accumulated error after each computed step. More aggressive caching, works well with most samplers.
    /// false: Keeps error accumulated. More conservative, recommended for euler_a sampler
    #[builder(default = "true")]
    reset: bool,
}
238
/// Parameters for Easy Cache
///
/// Built through the generated `EasyCacheParamsBuilder`; every field has a
/// builder default, so setting a field is optional.
#[derive(Builder, Debug, Clone)]
pub struct EasyCacheParams {
    /// Error threshold for reuse decision (default: 0.2)
    #[builder(default = "0.2")]
    threshold: f32,

    /// Start caching at this percent of steps (default: 0.15)
    #[builder(default = "0.15")]
    start: f32,

    /// Stop caching at this percent of steps (default: 0.95)
    #[builder(default = "0.95")]
    end: f32,
}
254
/// Parameters for Db Cache
///
/// Built through the generated `DbCacheParamsBuilder`; every field has a
/// builder default, so setting a field is optional.
#[derive(Builder, Debug, Clone)]
pub struct DbCacheParams {
    /// Front blocks to always compute (default: 8)
    #[builder(default = "8")]
    fn_blocks: i32,

    /// Back blocks to always compute (default: 0)
    #[builder(default = "0")]
    bn_blocks: i32,

    /// L1 residual difference threshold (default: 0.08)
    #[builder(default = "0.08")]
    threshold: f32,

    /// Steps before caching starts (default: 8)
    #[builder(default = "8")]
    warmup: i32,

    /// Steps Computation Mask controls which steps can be cached
    /// E.g.: "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1"
    /// where 1 means compute, 0 means cache
    #[builder(default = "CLibString::default()")]
    scm_mask: CLibString,

    /// Steps Computation Mask policy (default: [ScmPolicy::Dynamic])
    #[builder(default = "ScmPolicy::default()")]
    scm_policy_dynamic: ScmPolicy,
}
284
/// Steps Computation Mask Policy controls when to cache steps
///
/// [ScmPolicy::Dynamic] is the default variant.
#[derive(Debug, Default, Clone)]
pub enum ScmPolicy {
    /// Always cache on cacheable steps
    Static,
    #[default]
    /// Check threshold before caching
    Dynamic,
}
294
/// Hires parameters
///
/// Built through the generated `HiresParamsBuilder`; every field has a
/// builder default, so setting a field is optional.
#[derive(Builder, Debug, Clone)]
pub struct HiresParams {
    /// highres fix target width, 0 to use scale (default: 0)
    #[builder(default = "0")]
    width: i32,
    /// highres fix target height, 0 to use scale (default: 0)
    #[builder(default = "0")]
    height: i32,
    /// highres fix step count (default: 0)
    // NOTE(review): the previous comment was a copy of the `width` one; the exact
    // meaning of 0 for the step count is not visible here — confirm.
    #[builder(default = "0")]
    steps: i32,
    /// highres fix upscaler tile size, reserved for model-backed upscalers (default: 128)
    #[builder(default = "128")]
    upscale_tile_size: i32,
    /// highres fix scale when sizes is not set (default: 2.0)
    #[builder(default = "2.0")]
    scale: f32,
    /// highres fix second pass denoising strength (default: 0.7)
    #[builder(default = "0.7")]
    denoising_strength: f32,
    /// Custom sigma values for the highres fix second pass (default: None)
    #[builder(default = "None")]
    hires_sigmas: Option<Vec<f32>>,
}
320
/// Config struct for a specific diffusion model
///
/// Instances are created through the generated `ModelConfigBuilder`; `build()`
/// runs `ModelConfigBuilder::validate`, which requires at least one of `model`
/// or `diffusion_model` to be set. The struct also caches the native diffusion
/// and upscaler contexts, which are released when the config is dropped.
#[derive(Builder, Debug)]
#[builder(
    setter(into, strip_option),
    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
)]
pub struct ModelConfig {
    /// Number of threads to use during computation (default: number of physical CPU cores).
    /// If the value passed to the setter is <= 0, the number of physical CPU cores is used instead.
    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
    n_threads: i32,

    /// Whether to memory-map model (default: false)
    #[builder(default = "false")]
    enable_mmap: bool,

    /// Maximum VRAM budget in GiB for graph-cut segmented execution. 0 disables graph splitting; -1 auto-detects free VRAM minus 1 GiB (default: -1)
    #[builder(default = "-1.0")]
    max_vram: f32,

    /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
    #[builder(default = "Default::default()")]
    upscale_model: Option<CLibPath>,

    /// Run the ESRGAN upscaler this many times (default 1)
    #[builder(default = "1")]
    upscale_repeats: i32,

    /// Tile size for ESRGAN upscaler (default 128)
    #[builder(default = "128")]
    upscale_tile_size: i32,

    /// Path to full model
    #[builder(default = "Default::default()")]
    model: CLibPath,

    /// Path to the standalone diffusion model
    #[builder(default = "Default::default()")]
    diffusion_model: CLibPath,

    /// Path to the standalone unconditional diffusion model, currently used by Ideogram4 CFG
    #[builder(default = "Default::default()")]
    unconditional_diffusion_model: CLibPath,

    /// Path to the qwen2vl text encoder
    #[builder(default = "Default::default()")]
    llm: CLibPath,

    /// Path to the qwen2vl vit
    #[builder(default = "Default::default()")]
    llm_vision: CLibPath,

    /// Path to the clip-l text encoder
    #[builder(default = "Default::default()")]
    clip_l: CLibPath,

    /// Path to the clip-g text encoder
    #[builder(default = "Default::default()")]
    clip_g: CLibPath,

    /// Path to the clip-vision encoder
    #[builder(default = "Default::default()")]
    clip_vision: CLibPath,

    /// Path to the t5xxl text encoder
    #[builder(default = "Default::default()")]
    t5xxl: CLibPath,

    /// Path to vae
    #[builder(default = "Default::default()")]
    vae: CLibPath,

    /// VAE latent format (default: auto). It should match the VAE latent layout used by the PiD checkpoint. This is important when using standalone VAE files because the PiD diffusion checkpoint alone does not identify the VAE format.
    #[builder(default = "VaeFormat::SD_VAE_FORMAT_AUTO")]
    vae_format: VaeFormat,

    /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
    #[builder(default = "Default::default()")]
    taesd: CLibPath,

    /// Path to control net model
    #[builder(default = "Default::default()")]
    control_net: CLibPath,

    /// Embeddings found in the configured embeddings directory (see [EmbeddingsStorage]).
    /// Set through the custom `embeddings` setter, which takes the directory path.
    #[builder(default = "Default::default()", setter(custom))]
    embeddings: EmbeddingsStorage,

    /// Path to PHOTOMAKER model
    #[builder(default = "Default::default()")]
    photo_maker: CLibPath,

    /// Path to PHOTOMAKER v2 id embed
    #[builder(default = "Default::default()")]
    pm_id_embed_path: CLibPath,

    /// Weight type. If not specified, the default is the type of the weight file
    #[builder(default = "WeightType::SD_TYPE_COUNT")]
    weight_type: WeightType,

    /// Lora models resolved from the Lora model directory (see [LoraStorage]).
    /// Set through the custom `lora_models` setter, which takes the directory and the [LoraSpec] list.
    #[builder(default = "Default::default()", setter(custom))]
    lora_models: LoraStorage,

    /// Path to the standalone high noise diffusion model
    #[builder(default = "Default::default()")]
    high_noise_diffusion_model: CLibPath,

    /// Process vae in tiles to reduce memory usage (default: false)
    #[builder(default = "false")]
    vae_tiling: bool,

    /// Tile size for vae tiling (default: 32x32)
    #[builder(default = "(32,32)")]
    vae_tile_size: (i32, i32),

    /// Relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides vae_tile_size)
    #[builder(default = "(0.,0.)")]
    vae_relative_tile_size: (f32, f32),

    /// Tile overlap for vae tiling, in fraction of tile size (default: 0.5)
    #[builder(default = "0.5")]
    vae_tile_overlap: f32,

    /// RNG (default: CUDA)
    #[builder(default = "RngFunction::CUDA_RNG")]
    rng: RngFunction,

    /// Sampler RNG. If [RngFunction::RNG_TYPE_COUNT] is used will default to rng value. (default: [RngFunction::RNG_TYPE_COUNT])
    #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
    sampler_rng_type: RngFunction,

    /// Denoiser sigma schedule (default: [Scheduler::SCHEDULER_COUNT]).
    /// Will default to [Scheduler::EXPONENTIAL_SCHEDULER] if a denoiser is already instantiated.
    /// Otherwise, [Scheduler::DISCRETE_SCHEDULER] is used.
    #[builder(default = "Scheduler::SCHEDULER_COUNT")]
    scheduler: Scheduler,

    /// Custom sigma values for the sampler
    #[builder(default = "Default::default()")]
    sigmas: Vec<f32>,

    /// Prediction type override (default: PREDICTION_COUNT)
    #[builder(default = "Prediction::PREDICTION_COUNT")]
    prediction: Prediction,

    /// Use flash attention to reduce memory usage (model only).
    /// For most backends, it slows things down, but for cuda it generally speeds it up too. At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal).
    #[builder(default = "false")]
    diffusion_flash_attention: bool,

    /// Use flash attention to reduce memory usage.
    /// For most backends, it slows things down, but for cuda it generally speeds it up too. At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal).
    #[builder(default = "false")]
    flash_attention: bool,

    /// Disable dit mask for chroma
    #[builder(default = "false")]
    chroma_disable_dit_mask: bool,

    /// Enable t5 mask for chroma
    #[builder(default = "false")]
    chroma_enable_t5_mask: bool,

    /// t5 mask pad size of chroma
    #[builder(default = "1")]
    chroma_t5_mask_pad: i32,

    /// Use qwen image zero cond true optimization
    #[builder(default = "false")]
    use_qwen_image_zero_cond_true: bool,

    /// Use Conv2d direct in the diffusion model
    /// This might crash if it is not supported by the backend.
    #[builder(default = "false")]
    diffusion_conv_direct: bool,

    /// Use Conv2d direct in the vae model (should improve the performance)
    /// This might crash if it is not supported by the backend.
    #[builder(default = "false")]
    vae_conv_direct: bool,

    /// Force use of conv scale on sdxl vae
    #[builder(default = "false")]
    force_sdxl_vae_conv_scale: bool,

    /// Shift value for Flow models like SD3.x or WAN (default: auto)
    #[builder(default = "f32::INFINITY")]
    flow_shift: f32,

    /// Shift timestep for NitroFusion models, default: 0, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
    #[builder(default = "0")]
    timestep_shift: i32,

    /// Prevents usage of taesd for decoding the final image
    #[builder(default = "false")]
    taesd_preview_only: bool,

    /// In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used. The immediate mode may have precision and compatibility issues with quantized parameters, but it usually offers faster inference speed and, in some cases, lower memory usage. The at_runtime mode, on the other hand, is exactly the opposite
    #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
    lora_apply_mode: LoraModeType,

    /// Enable circular padding for convolutions
    #[builder(default = "false")]
    circular: bool,

    /// Enable circular RoPE wrapping on x-axis (width) only
    #[builder(default = "false")]
    circular_x: bool,

    /// Enable circular RoPE wrapping on y-axis (height) only
    #[builder(default = "false")]
    circular_y: bool,

    /// Hires fix parameters and upscaler model.
    /// The tuple holds the upscaler kind, its parameters and the optional path of a custom upscaler model.
    #[builder(default = "Self::hires_init()", setter(custom))]
    hires_params: (Upscaler, HiresParams, Option<CLibPath>),

    /// Extra parameters for sampling, currently used for SDXL sample params, in json string format
    #[builder(default = "CLibString::default()")]
    extra_sample_params: CLibString,

    /// Select the runtime backend used to execute model graphs.
    /// The tuple holds the map given to the setter and its comma-separated `module=device` rendering.
    #[builder(default = "(None, CLibString::default())", setter(custom))]
    backend: (Option<HashMap<Module, BackendDevice>>, CLibString),

    /// Select the backend used to allocate model parameters.
    /// The tuple holds the map given to the setter and its comma-separated `module=device` rendering.
    #[builder(default = "(None, CLibString::default())", setter(custom))]
    params_backend: (Option<HashMap<Module, BackendDevice>>, CLibString),

    /// Path to LTXAV embeddings connectors
    #[builder(default = "CLibPath::default()")]
    embeddings_connectors: CLibPath,

    /// Path to standalone LTX audio vae model
    #[builder(default = "CLibPath::default()")]
    audio_vae: CLibPath,

    /// Enable temporal tiling for LTX video VAE decode (default: true)
    #[builder(default = "true")]
    vae_temporal_tiling: bool,

    /// Extra VAE tiling args, key=value list.
    /// The tuple holds the map given to the setter and its comma-separated `key=value` rendering.
    // NOTE(review): the original comment ended mid-sentence ("LTX video VAE supports");
    // the list of supported keys still has to be filled in.
    #[builder(default = "(None, CLibString::default())", setter(custom))]
    extra_tiling_args: (Option<HashMap<String, String>>, CLibString),

    /// Enable residency+prefetch streaming on top of [ModelConfig::max_vram] (no effect without [ModelConfig::max_vram]; defaults to false)
    #[builder(default = "false")]
    stream_layers: bool,

    /// Native upscaler context, created on first use and freed on drop.
    #[builder(default = "None", private)]
    upscaler_ctx: Option<*mut upscaler_ctx_t>,

    /// Native diffusion context together with the parameters it was created from; freed on drop.
    #[builder(default = "None", private)]
    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
}
577
578impl ModelConfigBuilder {
579    fn validate(&self) -> Result<(), ConfigBuilderError> {
580        self.validate_model()
581    }
582
583    fn validate_model(&self) -> Result<(), ConfigBuilderError> {
584        self.model
585            .as_ref()
586            .or(self.diffusion_model.as_ref())
587            .map(|_| ())
588            .ok_or(ConfigBuilderError::UninitializedField(
589                "Model OR DiffusionModel must be valorized",
590            ))
591    }
592
593    fn filter_valid_extensions(path: &Path) -> impl Iterator<Item = DirEntry> {
594        WalkDir::new(path)
595            .into_iter()
596            .filter_map(|entry| entry.ok())
597            .filter(|entry| {
598                entry
599                    .path()
600                    .extension()
601                    .and_then(|ext| ext.to_str())
602                    .map(|ext_str| VALID_EXT.contains(&ext_str))
603                    .unwrap_or(false)
604            })
605    }
606    fn build_single_lora_storage(
607        spec: &LoraSpec,
608        valid_loras: &HashMap<String, PathBuf>,
609    ) -> (CLibPath, LoraSpec) {
610        let path = valid_loras.get(&spec.file_name).unwrap().as_path();
611        let c_path = CLibPath::from(path);
612        (c_path, spec.clone())
613    }
614
615    pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
616        let data: Vec<(CLibString, CLibPath)> = Self::filter_valid_extensions(embeddings_dir)
617            .map(|entry| {
618                let file_stem = entry
619                    .path()
620                    .file_stem()
621                    .and_then(|stem| stem.to_str())
622                    .unwrap_or_default()
623                    .to_owned();
624                (CLibString::from(file_stem), CLibPath::from(entry.path()))
625            })
626            .collect();
627        let data_pointer = data
628            .iter()
629            .map(|(name, path)| sd_embedding_t {
630                name: name.as_ptr(),
631                path: path.as_ptr(),
632            })
633            .collect();
634        self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
635        self
636    }
637
638    pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
639        let valid_loras: HashMap<String, PathBuf> = Self::filter_valid_extensions(lora_model_dir)
640            .map(|entry| {
641                let path = entry.path();
642                (
643                    path.file_stem()
644                        .and_then(|stem| stem.to_str())
645                        .unwrap_or_default()
646                        .to_owned(),
647                    path.to_path_buf(),
648                )
649            })
650            .collect();
651        let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
652        let standard = specs
653            .iter()
654            .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
655            .map(|s| Self::build_single_lora_storage(s, &valid_loras));
656        let high_noise = specs
657            .iter()
658            .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
659            .map(|s| Self::build_single_lora_storage(s, &valid_loras));
660
661        self.lora_models_internal(standard.chain(high_noise).collect())
662    }
663
664    fn lora_models_internal(&mut self, lora_storage: LoraStorage) -> &mut Self {
665        self.lora_models = Some(lora_storage);
666        self
667    }
668
669    pub fn n_threads(&mut self, value: i32) -> &mut Self {
670        self.n_threads = if value > 0 {
671            Some(value)
672        } else {
673            Some(num_cpus::get_physical() as i32)
674        };
675        self
676    }
677
678    pub fn hires_params(
679        &mut self,
680        upscaler: Upscaler,
681        params: HiresParams,
682        custom_model: Option<&Path>,
683    ) -> &mut Self {
684        if upscaler == Upscaler::SD_HIRES_UPSCALER_COUNT
685            || (upscaler == Upscaler::SD_HIRES_UPSCALER_MODEL && custom_model.is_none())
686        {
687            panic!("Invalid combination for {upscaler:?} and {custom_model:?}")
688        }
689        self.hires_params = Some((upscaler, params, custom_model.map(Into::into)));
690
691        self
692    }
693
694    fn hires_init() -> (Upscaler, HiresParams, Option<CLibPath>) {
695        (
696            Upscaler::SD_HIRES_UPSCALER_NONE,
697            HiresParamsBuilder::default().build().unwrap(),
698            None,
699        )
700    }
701
702    pub fn backend(&mut self, backend_map: HashMap<Module, BackendDevice>) -> &mut Self {
703        let backend_str = backend_map
704            .iter()
705            .map(|(key, value)| format!("{}={}", key, value))
706            .collect::<Vec<String>>()
707            .join(",");
708        self.backend = Some((Some(backend_map), CLibString::from(backend_str)));
709        self
710    }
711
712    pub fn params_backend(&mut self, backend_map: HashMap<Module, BackendDevice>) -> &mut Self {
713        let params_backend_str = backend_map
714            .iter()
715            .map(|(key, value)| format!("{}={}", key, value))
716            .collect::<Vec<String>>()
717            .join(",");
718        self.params_backend = Some((Some(backend_map), CLibString::from(params_backend_str)));
719        self
720    }
721
722    pub fn extra_tiling_args(
723        &mut self,
724        extra_tiling_args_map: HashMap<String, String>,
725    ) -> &mut Self {
726        let extra_tiling_args_str = extra_tiling_args_map
727            .iter()
728            .map(|(key, value)| format!("{}={}", key, value))
729            .collect::<Vec<String>>()
730            .join(",");
731        self.extra_tiling_args = Some((
732            Some(extra_tiling_args_map),
733            CLibString::from(extra_tiling_args_str),
734        ));
735        self
736    }
737}
738
739impl ModelConfig {
740    unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
741        unsafe {
742            if self.upscale_model.is_none() || self.upscale_repeats == 0 {
743                None
744            } else {
745                if self.upscaler_ctx.is_none() {
746                    let upscaler = new_upscaler_ctx(
747                        self.upscale_model.as_ref().unwrap().as_ptr(),
748                        self.diffusion_conv_direct,
749                        self.n_threads,
750                        self.upscale_tile_size,
751                        self.backend.1.as_ptr(),
752                        self.params_backend.1.as_ptr(),
753                    );
754                    self.upscaler_ctx = Some(upscaler);
755                }
756                self.upscaler_ctx
757            }
758        }
759    }
760
761    unsafe fn diffusion_ctx(&mut self) -> *mut sd_ctx_t {
762        unsafe {
763            // This is required to support img2img after text2img generation
764            // otherwise the context is cached and won't have a decode graph
765            // leading to an assertion error in sdcpp
766            if let Some((sd_ctx, _)) = self.diffusion_ctx.as_ref() {
767                sd_set_progress_callback(None, null_mut());
768                free_sd_ctx(*sd_ctx);
769                self.diffusion_ctx = None;
770            }
771            if self.diffusion_ctx.is_none() {
772                let sd_ctx_params = sd_ctx_params_t {
773                    model_path: self.model.as_ptr(),
774                    llm_path: self.llm.as_ptr(),
775                    llm_vision_path: self.llm_vision.as_ptr(),
776                    clip_l_path: self.clip_l.as_ptr(),
777                    clip_g_path: self.clip_g.as_ptr(),
778                    clip_vision_path: self.clip_vision.as_ptr(),
779                    high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
780                    t5xxl_path: self.t5xxl.as_ptr(),
781                    diffusion_model_path: self.diffusion_model.as_ptr(),
782                    uncond_diffusion_model_path: self.unconditional_diffusion_model.as_ptr(),
783                    vae_path: self.vae.as_ptr(),
784                    taesd_path: self.taesd.as_ptr(),
785                    control_net_path: self.control_net.as_ptr(),
786                    embeddings: self.embeddings.2.as_ptr(),
787                    embedding_count: self.embeddings.1.len() as u32,
788                    photo_maker_path: self.photo_maker.as_ptr(),
789                    n_threads: self.n_threads,
790                    wtype: self.weight_type,
791                    rng_type: self.rng,
792                    diffusion_flash_attn: self.diffusion_flash_attention,
793                    flash_attn: self.flash_attention,
794                    diffusion_conv_direct: self.diffusion_conv_direct,
795                    chroma_use_dit_mask: !self.chroma_disable_dit_mask,
796                    chroma_use_t5_mask: self.chroma_enable_t5_mask,
797                    chroma_t5_mask_pad: self.chroma_t5_mask_pad,
798                    vae_conv_direct: self.vae_conv_direct,
799                    prediction: self.prediction,
800                    force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
801                    tae_preview_only: self.taesd_preview_only,
802                    lora_apply_mode: self.lora_apply_mode,
803                    tensor_type_rules: null_mut(),
804                    sampler_rng_type: self.sampler_rng_type,
805                    circular_x: self.circular || self.circular_x,
806                    circular_y: self.circular || self.circular_y,
807                    qwen_image_zero_cond_t: self.use_qwen_image_zero_cond_true,
808                    enable_mmap: self.enable_mmap,
809                    max_vram: self.max_vram,
810                    backend: self.backend.1.as_ptr(),
811                    params_backend: self.params_backend.1.as_ptr(),
812                    embeddings_connectors_path: self.embeddings_connectors.as_ptr(),
813                    audio_vae_path: self.audio_vae.as_ptr(),
814                    vae_format: self.vae_format,
815                    stream_layers: self.stream_layers,
816                };
817                let ctx = new_sd_ctx(&sd_ctx_params);
818                self.diffusion_ctx = Some((ctx, sd_ctx_params))
819            }
820            self.diffusion_ctx.unwrap().0
821        }
822    }
823}
824
825impl Drop for ModelConfig {
826    fn drop(&mut self) {
827        //Cleanup CTX section
828        unsafe {
829            if let Some((sd_ctx, _)) = self.diffusion_ctx {
830                free_sd_ctx(sd_ctx);
831            }
832
833            if let Some(upscaler_ctx) = self.upscaler_ctx {
834                free_upscaler_ctx(upscaler_ctx);
835            }
836        }
837    }
838}
839
840impl From<&ModelConfig> for ModelConfigBuilder {
841    fn from(value: &ModelConfig) -> Self {
842        let mut builder = ModelConfigBuilder::default();
843        let hires_path = value
844            .hires_params
845            .2
846            .clone()
847            .map(|f| PathBuf::from_str(&f.0.into_string().unwrap()).unwrap());
848        builder
849            .n_threads(value.n_threads)
850            .max_vram(value.max_vram)
851            .upscale_repeats(value.upscale_repeats)
852            .model(value.model.clone())
853            .diffusion_model(value.diffusion_model.clone())
854            .unconditional_diffusion_model(value.unconditional_diffusion_model.clone())
855            .llm(value.llm.clone())
856            .llm_vision(value.llm_vision.clone())
857            .clip_l(value.clip_l.clone())
858            .clip_g(value.clip_g.clone())
859            .clip_vision(value.clip_vision.clone())
860            .t5xxl(value.t5xxl.clone())
861            .vae(value.vae.clone())
862            .taesd(value.taesd.clone())
863            .control_net(value.control_net.clone())
864            .embeddings(&value.embeddings.0)
865            .photo_maker(value.photo_maker.clone())
866            .pm_id_embed_path(value.pm_id_embed_path.clone())
867            .weight_type(value.weight_type)
868            .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
869            .vae_tiling(value.vae_tiling)
870            .vae_tile_size(value.vae_tile_size)
871            .vae_relative_tile_size(value.vae_relative_tile_size)
872            .vae_tile_overlap(value.vae_tile_overlap)
873            .rng(value.rng)
874            .sampler_rng_type(value.rng)
875            .scheduler(value.scheduler)
876            .sigmas(value.sigmas.clone())
877            .prediction(value.prediction)
878            .control_net(value.control_net.clone())
879            .flash_attention(value.flash_attention)
880            .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
881            .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
882            .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
883            .diffusion_conv_direct(value.diffusion_conv_direct)
884            .vae_conv_direct(value.vae_conv_direct)
885            .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
886            .flow_shift(value.flow_shift)
887            .timestep_shift(value.timestep_shift)
888            .taesd_preview_only(value.taesd_preview_only)
889            .lora_apply_mode(value.lora_apply_mode)
890            .circular(value.circular)
891            .circular_x(value.circular_x)
892            .circular_y(value.circular_y)
893            .use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true)
894            .hires_params(
895                value.hires_params.0,
896                value.hires_params.1.clone(),
897                hires_path.as_deref(),
898            )
899            .extra_sample_params(value.extra_sample_params.clone())
900            .backend(value.backend.0.clone().unwrap_or_default())
901            .params_backend(value.params_backend.0.clone().unwrap_or_default())
902            .extra_tiling_args(value.extra_tiling_args.0.clone().unwrap_or_default());
903
904        builder.lora_models_internal(value.lora_models.clone());
905
906        if let Some(model) = &value.upscale_model {
907            builder.upscale_model(model.clone());
908        }
909        builder
910    }
911}
912
#[derive(Builder, Debug)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
/// Config struct common to all diffusion methods
pub struct Config {
    /// Path to PHOTOMAKER input id images dir
    #[builder(default = "Default::default()")]
    pm_id_images_dir: CLibPath,

    /// Path to the input image, required by img2img
    #[builder(default = "Default::default()")]
    init_img: PathBuf,

    /// Path to the image used as a mask for img2img
    #[builder(default = "Default::default()")]
    mask_img: PathBuf,

    /// Path to image condition, control net
    #[builder(default = "Default::default()")]
    control_image: CLibPath,

    /// Paths to reference images for in-context conditioning (e.g. for Flux2)
    #[builder(default = "Default::default()")]
    ref_images: Vec<PathBuf>,

    /// Path to write result image to (default: ./output.png)
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,

    /// Path to write the preview image to (default: ./preview_output.png)
    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
    preview_output: PathBuf,

    /// Preview method
    #[builder(default = "PreviewType::PREVIEW_NONE")]
    preview_mode: PreviewType,

    /// Enables previewing noisy inputs of the models rather than the denoised outputs
    #[builder(default = "false")]
    preview_noisy: bool,

    /// Interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)
    #[builder(default = "1")]
    preview_interval: i32,

    /// The prompt to render
    prompt: String,

    /// The negative prompt (default: "")
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,

    /// Unconditional guidance scale (default: 7.0)
    #[builder(default = "7.0")]
    cfg_scale: f32,

    /// Distilled guidance scale for models with guidance input (default: 3.5)
    #[builder(default = "3.5")]
    guidance: f32,

    /// Strength for noising/unnoising (default: 0.75)
    #[builder(default = "0.75")]
    strength: f32,

    /// Strength for keeping input identity (default: 20%)
    #[builder(default = "20.0")]
    pm_style_strength: f32,

    /// Strength to apply Control Net (default: 0.9)
    /// 1.0 corresponds to full destruction of information in init
    #[builder(default = "0.9")]
    control_strength: f32,

    /// Image height, in pixel space (default: 512)
    #[builder(default = "512")]
    height: i32,

    /// Image width, in pixel space (default: 512)
    #[builder(default = "512")]
    width: i32,

    /// Sampling-method (default: [SampleMethod::SAMPLE_METHOD_COUNT]).
    /// [SampleMethod::EULER_SAMPLE_METHOD] will be used for flux, sd3, wan, qwen_image.
    /// Otherwise [SampleMethod::EULER_A_SAMPLE_METHOD] is used.
    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
    sampling_method: SampleMethod,

    /// eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
    #[builder(default = "0.")]
    eta: f32,

    /// Number of sample steps (default: 20)
    #[builder(default = "20")]
    steps: i32,

    /// RNG seed (default: -1, use random seed for < 0)
    #[builder(default = "-1")]
    seed: i64,

    /// Number of images to generate (default: 1)
    #[builder(default = "1")]
    batch_count: i32,

    /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
    /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,

    /// Apply canny preprocessor (edge detection) (default: false)
    #[builder(default = "false")]
    canny: bool,

    /// skip layer guidance (SLG) scale, only for DiT models: (default: 0)
    /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium
    #[builder(default = "0.")]
    slg_scale: f32,

    /// Layers to skip for SLG steps: (default: \[7,8,9\])
    #[builder(default = "vec![7, 8, 9]")]
    skip_layer: Vec<i32>,

    /// SLG enabling point: (default: 0.01)
    #[builder(default = "0.01")]
    skip_layer_start: f32,

    /// SLG disabling point: (default: 0.2)
    #[builder(default = "0.2")]
    skip_layer_end: f32,

    /// Disable auto resize of ref images
    #[builder(default = "false")]
    disable_auto_resize_ref_image: bool,

    /// Native cache parameters together with the optional SCM mask string that
    /// `scm_mask` points to at generation time. The generated setter is private;
    /// the value defaults to `ConfigBuilder::cache_init()` (caching disabled) and is
    /// replaced through the `*_caching` methods of [`ConfigBuilder`].
    #[builder(default = "Self::cache_init()", private)]
    cache: (sd_cache_params_t, Option<CLibString>),
}
1048
1049impl ConfigBuilder {
1050    fn validate(&self) -> Result<(), ConfigBuilderError> {
1051        self.validate_output_dir()
1052    }
1053
1054    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
1055        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
1056        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
1057        if is_dir == multiple_items {
1058            Ok(())
1059        } else {
1060            Err(ConfigBuilderError::ValidationError(
1061                "When batch_count > 1, output should point to folder and vice versa".to_owned(),
1062            ))
1063        }
1064    }
1065
1066    fn cache_init() -> (sd_cache_params_t, Option<CLibString>) {
1067        (
1068            sd_cache_params_t {
1069                mode: sd_cache_mode_t::SD_CACHE_DISABLED,
1070                reuse_threshold: 1.0,
1071                start_percent: 0.15,
1072                end_percent: 0.95,
1073                error_decay_rate: 1.0,
1074                use_relative_threshold: true,
1075                reset_error_on_compute: true,
1076                Fn_compute_blocks: 8,
1077                Bn_compute_blocks: 0,
1078                residual_diff_threshold: 0.08,
1079                max_warmup_steps: 8,
1080                max_cached_steps: -1,
1081                max_continuous_cached_steps: -1,
1082                taylorseer_n_derivatives: 1,
1083                taylorseer_skip_interval: 1,
1084                scm_mask: null(),
1085                scm_policy_dynamic: true,
1086                spectrum_w: 0.4,
1087                spectrum_m: 3,
1088                spectrum_lam: 1.0,
1089                spectrum_window_size: 2,
1090                spectrum_flex_window: 0.5,
1091                spectrum_warmup_steps: 4,
1092                spectrum_stop_percent: 0.9,
1093            },
1094            None,
1095        )
1096    }
1097
1098    pub fn no_caching(&mut self) -> &mut Self {
1099        let mut cache = Self::cache_init();
1100        cache.0.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
1101        self.cache = Some(cache);
1102        self
1103    }
1104
1105    pub fn spectrum_caching(&mut self, params: SpectrumCacheParams) -> &mut Self {
1106        let (mut cache, mask) = Self::cache_init();
1107        cache.mode = sd_cache_mode_t::SD_CACHE_SPECTRUM;
1108        cache.spectrum_w = params.w;
1109        cache.spectrum_m = params.m;
1110        cache.spectrum_lam = params.lam;
1111        cache.spectrum_window_size = params.window;
1112        cache.spectrum_flex_window = params.flex;
1113        cache.spectrum_warmup_steps = params.warmup;
1114        cache.spectrum_stop_percent = params.stop;
1115        self.cache = Some((cache, mask));
1116        self
1117    }
1118
1119    pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
1120        let (mut cache, mask) = Self::cache_init();
1121        cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
1122        cache.reuse_threshold = params.threshold;
1123        cache.start_percent = params.start;
1124        cache.end_percent = params.end;
1125        cache.error_decay_rate = params.decay;
1126        cache.use_relative_threshold = params.relative;
1127        cache.reset_error_on_compute = params.reset;
1128        self.cache = Some((cache, mask));
1129        self
1130    }
1131
1132    pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
1133        let (mut cache, mask) = Self::cache_init();
1134        cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
1135        cache.reuse_threshold = params.threshold;
1136        cache.start_percent = params.start;
1137        cache.end_percent = params.end;
1138        self.cache = Some((cache, mask));
1139        self
1140    }
1141
1142    pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
1143        let (mut cache, _) = Self::cache_init();
1144        cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
1145        cache.Fn_compute_blocks = params.fn_blocks;
1146        cache.Bn_compute_blocks = params.bn_blocks;
1147        cache.residual_diff_threshold = params.threshold;
1148        cache.max_warmup_steps = params.warmup;
1149        cache.scm_policy_dynamic = match params.scm_policy_dynamic {
1150            ScmPolicy::Static => false,
1151            ScmPolicy::Dynamic => true,
1152        };
1153        self.cache = Some((cache, Some(params.scm_mask)));
1154        self
1155    }
1156
1157    pub fn taylor_seer_caching(&mut self) -> &mut Self {
1158        let (mut cache, mask) = Self::cache_init();
1159        cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
1160        self.cache = Some((cache, mask));
1161        self
1162    }
1163
1164    pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
1165        self.db_cache_caching(params).cache.as_mut().unwrap().0.mode =
1166            sd_cache_mode_t::SD_CACHE_CACHE_DIT;
1167        self
1168    }
1169}
1170
1171impl From<&Config> for ConfigBuilder {
1172    fn from(value: &Config) -> Self {
1173        let mut builder = ConfigBuilder::default();
1174        builder
1175            .pm_id_images_dir(value.pm_id_images_dir.clone())
1176            .init_img(value.init_img.clone())
1177            .mask_img(value.mask_img.clone())
1178            .control_image(value.control_image.clone())
1179            .ref_images(value.ref_images.clone())
1180            .output(value.output.clone())
1181            .prompt(value.prompt.clone())
1182            .negative_prompt(value.negative_prompt.clone())
1183            .cfg_scale(value.cfg_scale)
1184            .strength(value.strength)
1185            .pm_style_strength(value.pm_style_strength)
1186            .control_strength(value.control_strength)
1187            .height(value.height)
1188            .width(value.width)
1189            .sampling_method(value.sampling_method)
1190            .steps(value.steps)
1191            .seed(value.seed)
1192            .batch_count(value.batch_count)
1193            .clip_skip(value.clip_skip)
1194            .slg_scale(value.slg_scale)
1195            .skip_layer(value.skip_layer.clone())
1196            .skip_layer_start(value.skip_layer_start)
1197            .skip_layer_end(value.skip_layer_end)
1198            .canny(value.canny)
1199            .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
1200            .preview_output(value.preview_output.clone())
1201            .preview_mode(value.preview_mode)
1202            .preview_noisy(value.preview_noisy)
1203            .preview_interval(value.preview_interval)
1204            .cache(value.cache.clone());
1205        builder
1206    }
1207}
1208
/// Owned, NUL-terminated string that can be handed to the C library via [`CLibString::as_ptr`].
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Pointer to the NUL-terminated bytes. It is only valid while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Builds the C string from raw bytes without ever failing.
    ///
    /// A C string cannot contain an interior NUL, so the input is cut at its first NUL byte
    /// (which is also where C code would stop reading it).
    fn from_bytes_lossy(bytes: Vec<u8>) -> Self {
        let c_string = CString::new(bytes).unwrap_or_else(|err| {
            let nul_at = err.nul_position();
            let mut bytes = err.into_vec();
            bytes.truncate(nul_at);
            CString::new(bytes).expect("no NUL byte exists before the first NUL position")
        });
        Self(c_string)
    }
}

impl From<&str> for CLibString {
    /// Converts the text, truncating at the first interior NUL instead of panicking.
    fn from(value: &str) -> Self {
        Self::from_bytes_lossy(value.as_bytes().to_vec())
    }
}

impl From<String> for CLibString {
    /// Converts the text, truncating at the first interior NUL instead of panicking.
    fn from(value: String) -> Self {
        Self::from_bytes_lossy(value.into_bytes())
    }
}
1229
/// Owned, NUL-terminated file system path that can be handed to the C library via
/// [`CLibPath::as_ptr`].
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Pointer to the NUL-terminated bytes. It is only valid while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Converts a path without ever failing.
    ///
    /// A path that is not valid UTF-8 becomes the empty string, and a path containing an
    /// interior NUL byte is cut at that byte because a C string cannot represent it.
    fn from_path_lossy(path: &Path) -> Self {
        let text = path.to_str().unwrap_or_default();
        let c_string = CString::new(text).unwrap_or_else(|err| {
            let nul_at = err.nul_position();
            let mut bytes = err.into_vec();
            bytes.truncate(nul_at);
            CString::new(bytes).expect("no NUL byte exists before the first NUL position")
        });
        Self(c_string)
    }
}

impl From<PathBuf> for CLibPath {
    /// See [`CLibPath::from_path_lossy`] for how unusual paths are handled.
    fn from(value: PathBuf) -> Self {
        Self::from_path_lossy(&value)
    }
}

impl From<&Path> for CLibPath {
    /// See [`CLibPath::from_path_lossy`] for how unusual paths are handled.
    fn from(value: &Path) -> Self {
        Self::from_path_lossy(value)
    }
}

impl From<&CLibPath> for PathBuf {
    /// Turns the stored C string back into a path. Invalid UTF-8 (not produced by the
    /// conversions above) would be replaced lossily rather than causing a panic.
    fn from(value: &CLibPath) -> Self {
        PathBuf::from(value.0.to_string_lossy().into_owned())
    }
}
1256
1257fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
1258    let date = Local::now().format("%Y.%m.%d-%H.%M.%S");
1259    if batch_size == 1 {
1260        vec![path.into()]
1261    } else {
1262        (1..=batch_size)
1263            .map(|id| path.join(format!("output_{date}_{id}.png")))
1264            .collect()
1265    }
1266}
1267
1268unsafe fn upscale(
1269    upscale_repeats: i32,
1270    upscaler_ctx: Option<*mut upscaler_ctx_t>,
1271    data: sd_image_t,
1272) -> Result<sd_image_t, DiffusionError> {
1273    unsafe {
1274        match upscaler_ctx {
1275            Some(upscaler_ctx) => {
1276                let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
1277                let mut current_image = data;
1278                for _ in 0..upscale_repeats {
1279                    let upscaled_image =
1280                        diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
1281
1282                    if upscaled_image.data.is_null() {
1283                        return Err(DiffusionError::Upscaler);
1284                    }
1285
1286                    free(current_image.data as *mut c_void);
1287                    current_image = upscaled_image;
1288                }
1289                Ok(current_image)
1290            }
1291            None => Ok(data),
1292        }
1293    }
1294}
1295
1296/// Generate an image and receive update via queue
1297pub fn gen_img_with_progress(
1298    config: &Config,
1299    model_config: &mut ModelConfig,
1300    sender: Sender<Progress>,
1301) -> Result<(), DiffusionError> {
1302    gen_img_maybe_progress(config, model_config, Some(sender))
1303}
1304
1305/// Generate an image
1306pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
1307    gen_img_maybe_progress(config, model_config, None)
1308}
1309
1310fn gen_img_maybe_progress(
1311    config: &Config,
1312    model_config: &mut ModelConfig,
1313    mut sender: Option<Sender<Progress>>,
1314) -> Result<(), DiffusionError> {
1315    let prompt: CLibString = CLibString::from(config.prompt.as_str());
1316    let files = output_files(&config.output, config.batch_count);
1317    unsafe {
1318        let has_init_image = config.init_img.exists();
1319        let has_mask_image = config.mask_img.exists();
1320
1321        let sd_ctx = model_config.diffusion_ctx();
1322        let upscaler_ctx = model_config.upscaler_ctx();
1323
1324        let mut init_image = sd_image_t {
1325            width: 0,
1326            height: 0,
1327            channel: 3,
1328            data: std::ptr::null_mut(),
1329        };
1330        let mut mask_image = sd_image_t {
1331            width: config.width as u32,
1332            height: config.height as u32,
1333            channel: 1,
1334            data: null_mut(),
1335        };
1336        let mut layers = config.skip_layer.clone();
1337        let guidance = sd_guidance_params_t {
1338            txt_cfg: config.cfg_scale,
1339            img_cfg: config.cfg_scale,
1340            distilled_guidance: config.guidance,
1341            slg: sd_slg_params_t {
1342                layers: layers.as_mut_ptr(),
1343                layer_count: config.skip_layer.len(),
1344                layer_start: config.skip_layer_start,
1345                layer_end: config.skip_layer_end,
1346                scale: config.slg_scale,
1347            },
1348        };
1349        let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
1350            sd_get_default_scheduler(sd_ctx, config.sampling_method)
1351        } else {
1352            model_config.scheduler
1353        };
1354        let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
1355            sd_get_default_sample_method(sd_ctx)
1356        } else {
1357            config.sampling_method
1358        };
1359        let sample_params = sd_sample_params_t {
1360            guidance,
1361            sample_method,
1362            sample_steps: config.steps,
1363            eta: config.eta,
1364            scheduler,
1365            shifted_timestep: model_config.timestep_shift,
1366            custom_sigmas: model_config.sigmas.as_mut_ptr(),
1367            custom_sigmas_count: model_config.sigmas.len() as i32,
1368            flow_shift: model_config.flow_shift,
1369            extra_sample_args: model_config.extra_sample_params.as_ptr(),
1370        };
1371        let control_image = sd_image_t {
1372            width: 0,
1373            height: 0,
1374            channel: 3,
1375            data: null_mut(),
1376        };
1377        let vae_tiling_params = sd_tiling_params_t {
1378            enabled: model_config.vae_tiling,
1379            tile_size_x: model_config.vae_tile_size.0,
1380            tile_size_y: model_config.vae_tile_size.1,
1381            target_overlap: model_config.vae_tile_overlap,
1382            rel_size_x: model_config.vae_relative_tile_size.0,
1383            rel_size_y: model_config.vae_relative_tile_size.1,
1384            temporal_tiling: model_config.vae_temporal_tiling,
1385            extra_tiling_args: model_config.extra_tiling_args.1.as_ptr(),
1386        };
1387        let pm_params = sd_pm_params_t {
1388            id_images: null_mut(),
1389            id_images_count: 0,
1390            id_embed_path: model_config.pm_id_embed_path.as_ptr(),
1391            style_strength: config.pm_style_strength,
1392        };
1393
1394        // Declare the buffer in the function scope so it outlives the match block
1395        let mut image_buffer: Vec<u8> = Vec::new();
1396        let mut mask_buffer: Vec<u8> = Vec::new();
1397
1398        if has_init_image {
1399            let img = image::open(&config.init_img)?;
1400            image_buffer = img.to_rgb8().into_raw();
1401
1402            init_image = sd_image_t {
1403                width: img.width(),
1404                height: img.height(),
1405                channel: 3,
1406                data: image_buffer.as_mut_ptr(),
1407            }
1408        }
1409
1410        if has_mask_image {
1411            let img = image::open(&config.mask_img)?;
1412            // Masks have to have single channel luminosity information only
1413            mask_buffer = img.to_luma8().into_raw();
1414
1415            mask_image = sd_image_t {
1416                width: img.width(),
1417                height: img.height(),
1418                channel: 1,
1419                data: mask_buffer.as_mut_ptr(),
1420            }
1421        }
1422
1423        // stable-diffusion.cpp assumes that a mask is also included with the img2img flow
1424        // if a mask is not provided, we use a flat white mask meaning all of the image is in scope
1425        // otherwise generate_image throws a sigsegv when it tries to assign the mask
1426        if !image_buffer.is_empty() && mask_buffer.is_empty() {
1427            let img: ImageBuffer<image::Luma<u8>, Vec<u8>> =
1428                ImageBuffer::from_pixel(init_image.width, init_image.height, image::Luma([255]));
1429            mask_buffer = img.into_raw();
1430            mask_image = sd_image_t {
1431                width: init_image.width,
1432                height: init_image.height,
1433                channel: 1,
1434                data: mask_buffer.as_mut_ptr(),
1435            }
1436        }
1437
1438        let mut ref_image_list = Vec::new();
1439        let mut ref_pixel_storage = Vec::new();
1440        for ref_path in &config.ref_images {
1441            if ref_path.exists() {
1442                let img = image::open(ref_path)?;
1443                let image_data = img.to_rgb8().into_raw();
1444
1445                ref_pixel_storage.push(image_data);
1446                let storage_ref = ref_pixel_storage.last_mut().unwrap();
1447                ref_image_list.push(sd_image_t {
1448                    width: img.width(),
1449                    height: img.height(),
1450                    channel: 3,
1451                    data: storage_ref.as_mut_ptr(),
1452                });
1453            }
1454        }
1455
1456        let num_ref_images = ref_image_list.len();
1457        let ref_image_ptr = if num_ref_images > 0 {
1458            ref_image_list.as_mut_ptr()
1459        } else {
1460            null_mut()
1461        };
1462
1463        unsafe extern "C" fn save_preview_local(
1464            _step: ::std::os::raw::c_int,
1465            _frame_count: ::std::os::raw::c_int,
1466            frames: *mut sd_image_t,
1467            _is_noisy: bool,
1468            data: *mut ::std::os::raw::c_void,
1469        ) {
1470            unsafe {
1471                let path = &*data.cast::<PathBuf>();
1472                let _ = save_img(*frames, path, None);
1473            }
1474        }
1475
1476        if config.preview_mode != PreviewType::PREVIEW_NONE {
1477            let data = &config.preview_output as *const PathBuf;
1478
1479            sd_set_preview_callback(
1480                Some(save_preview_local),
1481                config.preview_mode,
1482                config.preview_interval,
1483                !config.preview_noisy,
1484                config.preview_noisy,
1485                data as *mut c_void,
1486            );
1487        }
1488
1489        if sender.is_some() {
1490            unsafe extern "C" fn progress_callback(
1491                step: ::std::os::raw::c_int,
1492                steps: ::std::os::raw::c_int,
1493                time: f32,
1494                data: *mut ::std::os::raw::c_void,
1495            ) {
1496                unsafe {
1497                    let sender = &*data.cast::<Option<Sender<Progress>>>();
1498
1499                    if let Some(sender) = sender {
1500                        let _ = sender.send(Progress { step, steps, time });
1501                    }
1502                }
1503            }
1504            let sender_ptr: *mut c_void = &mut sender as *mut _ as *mut c_void;
1505            sd_set_progress_callback(Some(progress_callback), sender_ptr);
1506        }
1507
1508        let loras: Vec<sd_lora_t> = model_config
1509            .lora_models
1510            .iter()
1511            .map(|(c_path, spec)| sd_lora_t {
1512                is_high_noise: spec.is_high_noise,
1513                multiplier: spec.multiplier,
1514                path: c_path.as_ptr(),
1515            })
1516            .collect();
1517
1518        let mut cache = config.cache.0;
1519        if let Some(scm_mask) = &config.cache.1 {
1520            cache.scm_mask = scm_mask.as_ptr();
1521        }
1522
1523        let mut hires_path = null();
1524        let mut hires_sigmas = null_mut();
1525        let mut hires_sigmas_count = 0;
1526        if let Some(path) = &model_config.hires_params.2 {
1527            hires_path = path.as_ptr();
1528        }
1529        if let Some(sigmas) = &mut model_config.hires_params.1.hires_sigmas {
1530            hires_sigmas = sigmas.as_mut_ptr();
1531            hires_sigmas_count = sigmas.len() as i32;
1532        }
1533
1534        let hires = sd_hires_params_t {
1535            enabled: model_config.hires_params.0 != Upscaler::SD_HIRES_UPSCALER_NONE,
1536            upscaler: model_config.hires_params.0,
1537            model_path: hires_path,
1538            scale: model_config.hires_params.1.scale,
1539            target_width: model_config.hires_params.1.width,
1540            target_height: model_config.hires_params.1.height,
1541            steps: model_config.hires_params.1.steps,
1542            denoising_strength: model_config.hires_params.1.denoising_strength,
1543            upscale_tile_size: model_config.hires_params.1.upscale_tile_size,
1544            custom_sigmas: hires_sigmas,
1545            custom_sigmas_count: hires_sigmas_count,
1546        };
1547
1548        let sd_img_gen_params = sd_img_gen_params_t {
1549            prompt: prompt.as_ptr(),
1550            negative_prompt: config.negative_prompt.as_ptr(),
1551            clip_skip: config.clip_skip as i32,
1552            init_image,
1553            ref_images: ref_image_ptr,
1554            ref_images_count: num_ref_images as i32,
1555            increase_ref_index: false,
1556            mask_image,
1557            width: config.width,
1558            height: config.height,
1559            sample_params,
1560            strength: config.strength,
1561            seed: config.seed,
1562            batch_count: config.batch_count,
1563            control_image,
1564            control_strength: config.control_strength,
1565            pm_params,
1566            vae_tiling_params,
1567            auto_resize_ref_image: config.disable_auto_resize_ref_image,
1568            cache,
1569            loras: loras.as_ptr(),
1570            lora_count: loras.len() as u32,
1571            hires,
1572        };
1573
1574        let params_str = CString::from_raw(sd_img_gen_params_to_str(&sd_img_gen_params))
1575            .into_string()
1576            .unwrap();
1577
1578        let slice = generate_image(sd_ctx, &sd_img_gen_params);
1579        let ret = {
1580            if slice.is_null() {
1581                return Err(DiffusionError::Forward);
1582            }
1583            for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
1584                .iter()
1585                .zip(files)
1586            {
1587                // img.data will be null on OOM or other generation errors,
1588                // in which case we skip saving and just return an error
1589                if img.data.is_null() {
1590                    return Err(DiffusionError::Forward);
1591                }
1592                match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
1593                    Ok(img) => save_img(img, &path, Some(&params_str))?,
1594                    Err(err) => {
1595                        return Err(err);
1596                    }
1597                }
1598            }
1599            Ok(())
1600        };
1601        free(slice as *mut c_void);
1602        ret
1603    }
1604}
1605
1606fn save_img(img: sd_image_t, path: &Path, params: Option<&str>) -> Result<(), DiffusionError> {
1607    // Thx @wandbrandon
1608    let len = (img.width * img.height * img.channel) as usize;
1609    let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1610    let save_state = ImageBuffer::from_raw(img.width, img.height, buffer).map(|img| {
1611        RgbImage::from(img)
1612            .save(path)
1613            .map_err(DiffusionError::StoreImages)
1614    });
1615    if let Some(Err(err)) = save_state {
1616        return Err(err);
1617    }
1618    if let Some(params) = params {
1619        let mut metadata = Metadata::new();
1620        metadata.set_tag(ExifTag::ImageDescription(params.to_string()));
1621        metadata.write_to_file(path)?;
1622    }
1623    Ok(())
1624}
1625
1626#[cfg(test)]
1627mod tests {
1628    use image::{DynamicImage, ImageBuffer, Rgba};
1629    use std::path::PathBuf;
1630
1631    use crate::{
1632        api::{ConfigBuilderError, ModelConfigBuilder},
1633        util::download_file_hf_hub,
1634    };
1635
1636    use super::{ConfigBuilder, gen_img};
1637
1638    #[test]
1639    fn test_required_args_txt2img() {
1640        assert!(ConfigBuilder::default().build().is_err());
1641        assert!(ModelConfigBuilder::default().build().is_err());
1642        ModelConfigBuilder::default()
1643            .model(PathBuf::from("./test.ckpt"))
1644            .build()
1645            .unwrap();
1646
1647        ConfigBuilder::default()
1648            .prompt("a lovely cat driving a sport car")
1649            .build()
1650            .unwrap();
1651
1652        assert!(matches!(
1653            ConfigBuilder::default()
1654                .prompt("a lovely cat driving a sport car")
1655                .batch_count(10)
1656                .build(),
1657            Err(ConfigBuilderError::ValidationError(_))
1658        ));
1659
1660        ConfigBuilder::default()
1661            .prompt("a lovely cat driving a sport car")
1662            .build()
1663            .unwrap();
1664
1665        ConfigBuilder::default()
1666            .prompt("a lovely duck drinking water from a bottle")
1667            .batch_count(2)
1668            .output(PathBuf::from("./"))
1669            .build()
1670            .unwrap();
1671    }
1672
1673    #[ignore]
1674    #[test]
1675    fn test_img2img_gen() {
1676        let model_path =
1677            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
1678                .unwrap();
1679        let gen_img_output = "./output_img.png";
1680        let config = ConfigBuilder::default()
1681            .prompt("A high quality 3d texture")
1682            .output(PathBuf::from(gen_img_output))
1683            .batch_count(1)
1684            .build()
1685            .unwrap();
1686
1687        let mut model_config = ModelConfigBuilder::default()
1688            .model(model_path)
1689            .build()
1690            .unwrap();
1691
1692        gen_img(&config, &mut model_config).unwrap();
1693
1694        // 2. Create conditioning image: Gradient square
1695        let mut cond = ImageBuffer::new(512, 512);
1696        for (x, y, pixel) in cond.enumerate_pixels_mut() {
1697            let r = (x as f32 / 512.0 * 255.0) as u8;
1698            let g = (y as f32 / 512.0 * 255.0) as u8;
1699            let b = 127;
1700            *pixel = Rgba([r, g, b, 255]);
1701        }
1702        let cond_path = "test_cond_image.png";
1703        DynamicImage::ImageRgba8(cond)
1704            .save(cond_path)
1705            .expect("Failed to save reference image");
1706
1707        // 3. Call refinement using the generated image as input
1708        let refine_prompt = "PBR texture map, matching the lighting and micro-detail density of the reference image.";
1709        let img2img_config = ConfigBuilder::default()
1710            .prompt(refine_prompt)
1711            .output(PathBuf::from("./output_img_ref.png"))
1712            .ref_images(vec![PathBuf::from(cond_path)])
1713            .init_img(PathBuf::from(gen_img_output))
1714            .batch_count(1)
1715            .build()
1716            .unwrap();
1717        gen_img(&img2img_config, &mut model_config).unwrap();
1718
1719        // 4. Ensure decoder only mode works after img2img generation
1720        gen_img(&config, &mut model_config).unwrap();
1721    }
1722
1723    #[ignore]
1724    #[test]
1725    fn test_img_gen() {
1726        let model_path =
1727            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
1728                .unwrap();
1729
1730        let upscaler_path = download_file_hf_hub(
1731            "ximso/RealESRGAN_x4plus_anime_6B",
1732            "RealESRGAN_x4plus_anime_6B.pth",
1733        )
1734        .unwrap();
1735        let config = ConfigBuilder::default()
1736            .prompt("a lovely duck drinking water from a bottle")
1737            .output(PathBuf::from("./output_1.png"))
1738            .batch_count(1)
1739            .build()
1740            .unwrap();
1741        let mut model_config = ModelConfigBuilder::default()
1742            .model(model_path)
1743            .upscale_model(upscaler_path)
1744            .upscale_repeats(1)
1745            .build()
1746            .unwrap();
1747
1748        gen_img(&config, &mut model_config).unwrap();
1749        let config2 = ConfigBuilder::from(&config)
1750            .prompt("a lovely duck drinking water from a straw")
1751            .output(PathBuf::from("./output_2.png"))
1752            .build()
1753            .unwrap();
1754        gen_img(&config2, &mut model_config).unwrap();
1755
1756        let config3 = ConfigBuilder::from(&config)
1757            .prompt("a lovely dog drinking water from a starbucks cup")
1758            .batch_count(2)
1759            .output(PathBuf::from("./"))
1760            .build()
1761            .unwrap();
1762
1763        gen_img(&config3, &mut model_config).unwrap();
1764    }
1765}