diffusion_rs/
api.rs

1use std::collections::HashMap;
2use std::ffi::CString;
3use std::ffi::c_char;
4use std::ffi::c_void;
5use std::path::Path;
6use std::path::PathBuf;
7use std::ptr::null_mut;
8use std::slice;
9
10use derive_builder::Builder;
11use diffusion_rs_sys::free_upscaler_ctx;
12use diffusion_rs_sys::new_upscaler_ctx;
13use diffusion_rs_sys::sd_ctx_params_t;
14use diffusion_rs_sys::sd_easycache_params_t;
15use diffusion_rs_sys::sd_embedding_t;
16use diffusion_rs_sys::sd_get_default_sample_method;
17use diffusion_rs_sys::sd_get_default_scheduler;
18use diffusion_rs_sys::sd_guidance_params_t;
19use diffusion_rs_sys::sd_image_t;
20use diffusion_rs_sys::sd_img_gen_params_t;
21use diffusion_rs_sys::sd_lora_t;
22use diffusion_rs_sys::sd_pm_params_t;
23use diffusion_rs_sys::sd_sample_params_t;
24use diffusion_rs_sys::sd_set_preview_callback;
25use diffusion_rs_sys::sd_slg_params_t;
26use diffusion_rs_sys::sd_tiling_params_t;
27use diffusion_rs_sys::upscaler_ctx_t;
28use image::ImageBuffer;
29use image::ImageError;
30use image::RgbImage;
31use libc::free;
32use thiserror::Error;
33use walkdir::DirEntry;
34use walkdir::WalkDir;
35
36use diffusion_rs_sys::free_sd_ctx;
37use diffusion_rs_sys::new_sd_ctx;
38use diffusion_rs_sys::sd_ctx_t;
39
/// Random number generator (RNG) implementation
41pub use diffusion_rs_sys::rng_type_t as RngFunction;
42
43/// Sampling methods
44pub use diffusion_rs_sys::sample_method_t as SampleMethod;
45
46/// Denoiser sigma schedule
47pub use diffusion_rs_sys::scheduler_t as Scheduler;
48
49/// Prediction override
50pub use diffusion_rs_sys::prediction_t as Prediction;
51
52/// Weight type
53pub use diffusion_rs_sys::sd_type_t as WeightType;
54
55/// Preview mode
56pub use diffusion_rs_sys::preview_t as PreviewType;
57
58/// Lora mode
59pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
60
/// File extensions (without the leading dot) that are recognised when
/// scanning the embedding and LoRA directories.
static VALID_EXT: [&str; 3] = [
    "pt",
    "safetensors",
    "gguf",
];
62
63#[non_exhaustive]
64#[derive(Error, Debug)]
65/// Error that can occurs while forwarding models
66pub enum DiffusionError {
67    #[error("The underling stablediffusion.cpp function returned NULL")]
68    Forward,
69    #[error(transparent)]
70    StoreImages(#[from] ImageError),
71    #[error("The underling upscaler model returned a NULL image")]
72    Upscaler,
73}
74
/// Number of final CLIP layers to ignore.
///
/// The discriminants are the integer values handed to the native library.
#[non_exhaustive]
#[repr(i32)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum ClipSkip {
    /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x
    #[default]
    Unspecified = 0,
    /// Ignore no layer.
    None = 1,
    /// Ignore one layer.
    OneLayer = 2,
}
86
87type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
88
89#[derive(Default, Debug, Clone)]
90struct LoraStorage {
91    lora_model_dir: CLibPath,
92    data: Vec<(CLibPath, String, f32)>,
93    loras_t: Vec<sd_lora_t>,
94}
95
/// Describes one LoRA model to apply.
#[derive(Clone, Debug, Default)]
pub struct LoraSpec {
    /// Name used to look the LoRA up among the files of the LoRA directory
    /// (compared against each file's stem).
    pub file_name: String,
    /// Whether this LoRA is flagged as a high-noise LoRA.
    pub is_high_noise: bool,
    /// Multiplier passed along with the LoRA.
    pub multiplier: f32,
}
103
104#[derive(Builder, Debug, Clone)]
105#[builder(
106    setter(into, strip_option),
107    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
108)]
109pub struct ModelConfig {
110    /// Number of threads to use during computation (default: 0).
111    /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores.
112    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
113    n_threads: i32,
114
115    /// Place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
116    #[builder(default = "false")]
117    offload_params_to_cpu: bool,
118
119    /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
120    #[builder(default = "Default::default()")]
121    upscale_model: Option<CLibPath>,
122
123    /// Run the ESRGAN upscaler this many times (default 1)
124    #[builder(default = "1")]
125    upscale_repeats: i32,
126
127    /// Tile size for ESRGAN upscaler (default 128)
128    #[builder(default = "128")]
129    upscale_tile_size: i32,
130
131    /// Path to full model
132    #[builder(default = "Default::default()")]
133    model: CLibPath,
134
135    /// Path to the standalone diffusion model
136    #[builder(default = "Default::default()")]
137    diffusion_model: CLibPath,
138
139    /// Path to the qwen2vl text encoder
140    #[builder(default = "Default::default()")]
141    llm: CLibPath,
142
143    /// Path to the qwen2vl vit
144    #[builder(default = "Default::default()")]
145    llm_vision: CLibPath,
146
147    /// Path to the clip-l text encoder
148    #[builder(default = "Default::default()")]
149    clip_l: CLibPath,
150
151    /// Path to the clip-g text encoder
152    #[builder(default = "Default::default()")]
153    clip_g: CLibPath,
154
155    /// Path to the clip-vision encoder
156    #[builder(default = "Default::default()")]
157    clip_vision: CLibPath,
158
159    /// Path to the t5xxl text encoder
160    #[builder(default = "Default::default()")]
161    t5xxl: CLibPath,
162
163    /// Path to vae
164    #[builder(default = "Default::default()")]
165    vae: CLibPath,
166
167    /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
168    #[builder(default = "Default::default()")]
169    taesd: CLibPath,
170
171    /// Path to control net model
172    #[builder(default = "Default::default()")]
173    control_net: CLibPath,
174
175    /// Path to embeddings
176    #[builder(default = "Default::default()", setter(custom))]
177    embeddings: EmbeddingsStorage,
178
179    /// Path to PHOTOMAKER model
180    #[builder(default = "Default::default()")]
181    photo_maker: CLibPath,
182
183    /// Path to PHOTOMAKER v2 id embed
184    #[builder(default = "Default::default()")]
185    pm_id_embed_path: CLibPath,
186
187    /// Weight type. If not specified, the default is the type of the weight file
188    #[builder(default = "WeightType::SD_TYPE_COUNT")]
189    weight_type: WeightType,
190
191    /// Lora model directory
192    #[builder(default = "Default::default()", setter(custom))]
193    lora_models: LoraStorage,
194
195    /// Path to the standalone high noise diffusion model
196    #[builder(default = "Default::default()")]
197    high_noise_diffusion_model: CLibPath,
198
199    /// Process vae in tiles to reduce memory usage (default: false)
200    #[builder(default = "false")]
201    vae_tiling: bool,
202
203    /// Tile size for vae tiling (default: 32x32)
204    #[builder(default = "(32,32)")]
205    vae_tile_size: (i32, i32),
206
207    /// Relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides vae_tile_size)
208    #[builder(default = "(0.,0.)")]
209    vae_relative_tile_size: (f32, f32),
210
211    /// Tile overlap for vae tiling, in fraction of tile size (default: 0.5)
212    #[builder(default = "0.5")]
213    vae_tile_overlap: f32,
214
215    /// RNG (default: CUDA)
216    #[builder(default = "RngFunction::CUDA_RNG")]
217    rng: RngFunction,
218
219    /// Sampler RNG. If [RngFunction::RNG_TYPE_COUNT] is used will default to rng value. (default: [RngFunction::RNG_TYPE_COUNT])",
220    #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
221    sampler_rng_type: RngFunction,
222
223    /// Denoiser sigma schedule (default: [Scheduler::SCHEDULER_COUNT]).
224    /// Will default to [Scheduler::EXPONENTIAL_SCHEDULER] if a denoiser is already instantiated.
225    /// Otherwise, [Scheduler::DISCRETE_SCHEDULER] is used.
226    #[builder(default = "Scheduler::SCHEDULER_COUNT")]
227    scheduler: Scheduler,
228
229    /// Prediction type override (default: PREDICTION_COUNT)
230    #[builder(default = "Prediction::PREDICTION_COUNT")]
231    prediction: Prediction,
232
233    /// Keep vae in cpu (for low vram) (default: false)
234    #[builder(default = "false")]
235    vae_on_cpu: bool,
236
237    /// keep clip in cpu (for low vram) (default: false)
238    #[builder(default = "false")]
239    clip_on_cpu: bool,
240
241    /// Keep controlnet in cpu (for low vram) (default: false)
242    #[builder(default = "false")]
243    control_net_cpu: bool,
244
245    /// Use flash attention to reduce memory usage (for low vram).
246    // /For most backends, it slows things down, but for cuda it generally speeds it up too. At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal).
247    #[builder(default = "false")]
248    flash_attention: bool,
249
250    /// Disable dit mask for chroma
251    #[builder(default = "false")]
252    chroma_disable_dit_mask: bool,
253
254    /// Enable t5 mask for chroma
255    #[builder(default = "false")]
256    chroma_enable_t5_mask: bool,
257
258    /// t5 mask pad size of chroma
259    #[builder(default = "1")]
260    chroma_t5_mask_pad: i32,
261
262    /// Use Conv2d direct in the diffusion model
263    /// This might crash if it is not supported by the backend.
264    #[builder(default = "false")]
265    diffusion_conv_direct: bool,
266
267    /// Use Conv2d direct in the vae model (should improve the performance)
268    /// This might crash if it is not supported by the backend.
269    #[builder(default = "false")]
270    vae_conv_direct: bool,
271
272    /// Force use of conv scale on sdxl vae
273    #[builder(default = "false")]
274    force_sdxl_vae_conv_scale: bool,
275
276    /// Shift value for Flow models like SD3.x or WAN (default: auto)
277    #[builder(default = "f32::INFINITY")]
278    flow_shift: f32,
279
280    /// Shift timestep for NitroFusion models, default: 0, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
281    #[builder(default = "0")]
282    timestep_shift: i32,
283
284    /// Prevents usage of taesd for decoding the final image
285    #[builder(default = "false")]
286    taesd_preview_only: bool,
287
288    /// In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used.The immediately mode may have precision and compatibility issues with quantized parameters, but it usually offers faster inference speed and, in some cases, lower memory usage. The at_runtime mode, on the other hand, is exactly the opposite
289    #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
290    lora_apply_mode: LoraModeType,
291
292    /// Enable easycache to achieve speedup (default: false)
293    #[builder(default = "false")]
294    easy_cache: bool,
295
296    /// Easycache reuse threashold (default: 0.2)
297    #[builder(default = "0.2")]
298    easy_cache_reuse_threshold: f32,
299
300    /// Easycache start percent (default: 0.15)
301    #[builder(default = "0.15")]
302    easy_cache_start_percent: f32,
303
304    /// Easycache end percent (default: 0.95)
305    #[builder(default = "0.95")]
306    easy_cache_end_percent: f32,
307
308    #[builder(default = "None", private)]
309    upscaler_ctx: Option<*mut upscaler_ctx_t>,
310
311    #[builder(default = "None", private)]
312    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
313}
314
315impl ModelConfigBuilder {
316    fn validate(&self) -> Result<(), ConfigBuilderError> {
317        self.validate_model()
318    }
319
320    fn validate_model(&self) -> Result<(), ConfigBuilderError> {
321        self.model
322            .as_ref()
323            .or(self.diffusion_model.as_ref())
324            .map(|_| ())
325            .ok_or(ConfigBuilderError::UninitializedField(
326                "Model OR DiffusionModel must be valorized",
327            ))
328    }
329
330    fn filter_valid_extensions(&self, path: &Path) -> impl Iterator<Item = DirEntry> {
331        WalkDir::new(path)
332            .into_iter()
333            .filter_map(|entry| entry.ok())
334            .filter(|entry| {
335                entry
336                    .path()
337                    .extension()
338                    .and_then(|ext| ext.to_str())
339                    .map(|ext_str| VALID_EXT.contains(&ext_str))
340                    .unwrap_or(false)
341            })
342    }
343
344    fn build_single_lora_storage(
345        spec: &LoraSpec,
346        is_high_noise: bool,
347        valid_loras: &HashMap<String, PathBuf>,
348    ) -> ((CLibPath, String, f32), sd_lora_t) {
349        let path = valid_loras.get(&spec.file_name).unwrap().as_path();
350        let c_path = CLibPath::from(path);
351        let lora = sd_lora_t {
352            is_high_noise,
353            multiplier: spec.multiplier,
354            path: c_path.as_ptr(),
355        };
356        let data = (c_path, spec.file_name.clone(), spec.multiplier);
357        (data, lora)
358    }
359
360    pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
361        let data: Vec<(CLibString, CLibPath)> = self
362            .filter_valid_extensions(embeddings_dir)
363            .map(|entry| {
364                let file_stem = entry
365                    .path()
366                    .file_stem()
367                    .and_then(|stem| stem.to_str())
368                    .unwrap_or_default()
369                    .to_owned();
370                (CLibString::from(file_stem), CLibPath::from(entry.path()))
371            })
372            .collect();
373        let data_pointer = data
374            .iter()
375            .map(|(name, path)| sd_embedding_t {
376                name: name.as_ptr(),
377                path: path.as_ptr(),
378            })
379            .collect();
380        self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
381        self
382    }
383
384    pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
385        let valid_loras: HashMap<String, PathBuf> = self
386            .filter_valid_extensions(lora_model_dir)
387            .map(|entry| {
388                let path = entry.path();
389                (
390                    path.file_stem()
391                        .and_then(|stem| stem.to_str())
392                        .unwrap_or_default()
393                        .to_owned(),
394                    path.to_path_buf(),
395                )
396            })
397            .collect();
398        let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
399        let standard = specs
400            .iter()
401            .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
402            .map(|s| Self::build_single_lora_storage(s, false, &valid_loras));
403        let high_noise = specs
404            .iter()
405            .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
406            .map(|s| Self::build_single_lora_storage(s, true, &valid_loras));
407
408        let mut data = Vec::new();
409        let mut loras_t = Vec::new();
410        for lora in standard.chain(high_noise) {
411            data.push(lora.0);
412            loras_t.push(lora.1);
413        }
414
415        self.lora_models = Some(LoraStorage {
416            lora_model_dir: lora_model_dir.into(),
417            data,
418            loras_t,
419        });
420        self
421    }
422
423    pub fn n_threads(&mut self, value: i32) -> &mut Self {
424        self.n_threads = if value > 0 {
425            Some(value)
426        } else {
427            Some(num_cpus::get_physical() as i32)
428        };
429        self
430    }
431}
432
433impl ModelConfig {
434    unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
435        unsafe {
436            if self.upscale_model.is_none() || self.upscale_repeats == 0 {
437                None
438            } else {
439                if self.upscaler_ctx.is_none() {
440                    let upscaler = new_upscaler_ctx(
441                        self.upscale_model.as_ref().unwrap().as_ptr(),
442                        self.offload_params_to_cpu,
443                        self.diffusion_conv_direct,
444                        self.n_threads,
445                        self.upscale_tile_size,
446                    );
447                    self.upscaler_ctx = Some(upscaler);
448                }
449                self.upscaler_ctx
450            }
451        }
452    }
453
454    unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
455        unsafe {
456            if self.diffusion_ctx.is_none() {
457                let sd_ctx_params = sd_ctx_params_t {
458                    model_path: self.model.as_ptr(),
459                    llm_path: self.llm.as_ptr(),
460                    llm_vision_path: self.llm_vision.as_ptr(),
461                    clip_l_path: self.clip_l.as_ptr(),
462                    clip_g_path: self.clip_g.as_ptr(),
463                    clip_vision_path: self.clip_vision.as_ptr(),
464                    high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
465                    t5xxl_path: self.t5xxl.as_ptr(),
466                    diffusion_model_path: self.diffusion_model.as_ptr(),
467                    vae_path: self.vae.as_ptr(),
468                    taesd_path: self.taesd.as_ptr(),
469                    control_net_path: self.control_net.as_ptr(),
470                    lora_model_dir: self.lora_models.lora_model_dir.as_ptr(),
471                    embeddings: self.embeddings.2.as_ptr(),
472                    embedding_count: self.embeddings.1.len() as u32,
473                    photo_maker_path: self.photo_maker.as_ptr(),
474                    vae_decode_only,
475                    free_params_immediately: false,
476                    n_threads: self.n_threads,
477                    wtype: self.weight_type,
478                    rng_type: self.rng,
479                    keep_clip_on_cpu: self.clip_on_cpu,
480                    keep_control_net_on_cpu: self.control_net_cpu,
481                    keep_vae_on_cpu: self.vae_on_cpu,
482                    diffusion_flash_attn: self.flash_attention,
483                    diffusion_conv_direct: self.diffusion_conv_direct,
484                    chroma_use_dit_mask: !self.chroma_disable_dit_mask,
485                    chroma_use_t5_mask: self.chroma_enable_t5_mask,
486                    chroma_t5_mask_pad: self.chroma_t5_mask_pad,
487                    vae_conv_direct: self.vae_conv_direct,
488                    offload_params_to_cpu: self.offload_params_to_cpu,
489                    flow_shift: self.flow_shift,
490                    prediction: self.prediction,
491                    force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
492                    tae_preview_only: self.taesd_preview_only,
493                    lora_apply_mode: self.lora_apply_mode,
494                    tensor_type_rules: null_mut(),
495                    sampler_rng_type: self.sampler_rng_type,
496                };
497                let ctx = new_sd_ctx(&sd_ctx_params);
498                self.diffusion_ctx = Some((ctx, sd_ctx_params))
499            }
500            self.diffusion_ctx.unwrap().0
501        }
502    }
503}
504
505impl Drop for ModelConfig {
506    fn drop(&mut self) {
507        //Clean-up CTX section
508        unsafe {
509            if let Some((sd_ctx, _)) = self.diffusion_ctx {
510                free_sd_ctx(sd_ctx);
511            }
512
513            if let Some(upscaler_ctx) = self.upscaler_ctx {
514                free_upscaler_ctx(upscaler_ctx);
515            }
516        }
517    }
518}
519
520impl From<ModelConfig> for ModelConfigBuilder {
521    fn from(value: ModelConfig) -> Self {
522        let mut builder = ModelConfigBuilder::default();
523        builder
524            .n_threads(value.n_threads)
525            .offload_params_to_cpu(value.offload_params_to_cpu)
526            .upscale_repeats(value.upscale_repeats)
527            .model(value.model.clone())
528            .diffusion_model(value.diffusion_model.clone())
529            .llm(value.llm.clone())
530            .llm_vision(value.llm_vision.clone())
531            .clip_l(value.clip_l.clone())
532            .clip_g(value.clip_g.clone())
533            .clip_vision(value.clip_vision.clone())
534            .t5xxl(value.t5xxl.clone())
535            .vae(value.vae.clone())
536            .taesd(value.taesd.clone())
537            .control_net(value.control_net.clone())
538            .embeddings(&value.embeddings.0)
539            .photo_maker(value.photo_maker.clone())
540            .pm_id_embed_path(value.pm_id_embed_path.clone())
541            .weight_type(value.weight_type)
542            .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
543            .vae_tiling(value.vae_tiling)
544            .vae_tile_size(value.vae_tile_size)
545            .vae_relative_tile_size(value.vae_relative_tile_size)
546            .vae_tile_overlap(value.vae_tile_overlap)
547            .rng(value.rng)
548            .sampler_rng_type(value.rng)
549            .scheduler(value.scheduler)
550            .prediction(value.prediction)
551            .vae_on_cpu(value.vae_on_cpu)
552            .clip_on_cpu(value.clip_on_cpu)
553            .control_net(value.control_net.clone())
554            .control_net_cpu(value.control_net_cpu)
555            .flash_attention(value.flash_attention)
556            .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
557            .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
558            .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
559            .diffusion_conv_direct(value.diffusion_conv_direct)
560            .vae_conv_direct(value.vae_conv_direct)
561            .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
562            .flow_shift(value.flow_shift)
563            .timestep_shift(value.timestep_shift)
564            .taesd_preview_only(value.taesd_preview_only)
565            .lora_apply_mode(value.lora_apply_mode);
566
567        let lora_model_dir = Into::<PathBuf>::into(&value.lora_models.lora_model_dir);
568        let lora_specs = value
569            .lora_models
570            .data
571            .iter()
572            .map(|(_, name, multiplier)| LoraSpec {
573                file_name: name.clone(),
574                is_high_noise: false,
575                multiplier: *multiplier,
576            })
577            .collect();
578
579        builder.lora_models(&lora_model_dir, lora_specs);
580
581        if let Some(model) = &value.upscale_model {
582            builder.upscale_model(model.clone());
583        }
584        builder
585    }
586}
587
588#[derive(Builder, Debug, Clone)]
589#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
590/// Config struct common to all diffusion methods
591pub struct Config {
592    /// Path to PHOTOMAKER input id images dir
593    #[builder(default = "Default::default()")]
594    pm_id_images_dir: CLibPath,
595
596    /// Path to the input image, required by img2img
597    #[builder(default = "Default::default()")]
598    init_img: CLibPath,
599
600    /// Path to image condition, control net
601    #[builder(default = "Default::default()")]
602    control_image: CLibPath,
603
604    /// Path to write result image to (default: ./output.png)
605    #[builder(default = "PathBuf::from(\"./output.png\")")]
606    output: PathBuf,
607
608    /// Path to write result image to (default: ./output.png)
609    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
610    preview_output: PathBuf,
611
612    /// Preview method
613    #[builder(default = "PreviewType::PREVIEW_NONE")]
614    preview_mode: PreviewType,
615
616    /// Enables previewing noisy inputs of the models rather than the denoised outputs
617    #[builder(default = "false")]
618    preview_noisy: bool,
619
620    /// Interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)
621    #[builder(default = "1")]
622    preview_interval: i32,
623
624    /// The prompt to render
625    prompt: String,
626
627    /// The negative prompt (default: "")
628    #[builder(default = "\"\".into()")]
629    negative_prompt: CLibString,
630
631    /// Unconditional guidance scale (default: 7.0)
632    #[builder(default = "7.0")]
633    cfg_scale: f32,
634
635    /// Distilled guidance scale for models with guidance input (default: 3.5)
636    #[builder(default = "3.5")]
637    guidance: f32,
638
639    /// Strength for noising/unnoising (default: 0.75)
640    #[builder(default = "0.75")]
641    strength: f32,
642
643    /// Strength for keeping input identity (default: 20%)
644    #[builder(default = "20.0")]
645    pm_style_strength: f32,
646
647    /// Strength to apply Control Net (default: 0.9)
648    /// 1.0 corresponds to full destruction of information in init
649    #[builder(default = "0.9")]
650    control_strength: f32,
651
652    /// Image height, in pixel space (default: 512)
653    #[builder(default = "512")]
654    height: i32,
655
656    /// Image width, in pixel space (default: 512)
657    #[builder(default = "512")]
658    width: i32,
659
660    /// Sampling-method (default: [SampleMethod::SAMPLE_METHOD_COUNT]).
661    /// [SampleMethod::EULER_SAMPLE_METHOD] will be used for flux, sd3, wan, qwen_image.
662    /// Otherwise [SampleMethod::EULER_A_SAMPLE_METHOD] is used.
663    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
664    sampling_method: SampleMethod,
665
666    /// eta in DDIM, only for DDIM and TCD: (default: 0)
667    #[builder(default = "0.")]
668    eta: f32,
669
670    /// Number of sample steps (default: 20)
671    #[builder(default = "20")]
672    steps: i32,
673
674    /// RNG seed (default: 42, use random seed for < 0)
675    #[builder(default = "42")]
676    seed: i64,
677
678    /// Number of images to generate (default: 1)
679    #[builder(default = "1")]
680    batch_count: i32,
681
682    /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
683    /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
684    #[builder(default = "ClipSkip::Unspecified")]
685    clip_skip: ClipSkip,
686
687    /// Apply canny preprocessor (edge detection) (default: false)
688    #[builder(default = "false")]
689    canny: bool,
690
691    /// skip layer guidance (SLG) scale, only for DiT models: (default: 0)
692    /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium
693    #[builder(default = "0.")]
694    slg_scale: f32,
695
696    /// Layers to skip for SLG steps: (default: \[7,8,9\])
697    #[builder(default = "vec![7, 8, 9]")]
698    skip_layer: Vec<i32>,
699
700    /// SLG enabling point: (default: 0.01)
701    #[builder(default = "0.01")]
702    skip_layer_start: f32,
703
704    /// SLG disabling point: (default: 0.2)
705    #[builder(default = "0.2")]
706    skip_layer_end: f32,
707
708    /// Disable auto resize of ref images
709    #[builder(default = "false")]
710    disable_auto_resize_ref_image: bool,
711}
712
713impl ConfigBuilder {
714    fn validate(&self) -> Result<(), ConfigBuilderError> {
715        self.validate_output_dir()
716    }
717
718    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
719        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
720        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
721        if is_dir == multiple_items {
722            Ok(())
723        } else {
724            Err(ConfigBuilderError::ValidationError(
725                "When batch_count > 1, output should point to folder and vice versa".to_owned(),
726            ))
727        }
728    }
729}
730
731impl From<Config> for ConfigBuilder {
732    fn from(value: Config) -> Self {
733        let mut builder = ConfigBuilder::default();
734        builder
735            .pm_id_images_dir(value.pm_id_images_dir)
736            .init_img(value.init_img)
737            .control_image(value.control_image)
738            .output(value.output)
739            .prompt(value.prompt)
740            .negative_prompt(value.negative_prompt)
741            .cfg_scale(value.cfg_scale)
742            .strength(value.strength)
743            .pm_style_strength(value.pm_style_strength)
744            .control_strength(value.control_strength)
745            .height(value.height)
746            .width(value.width)
747            .sampling_method(value.sampling_method)
748            .steps(value.steps)
749            .seed(value.seed)
750            .batch_count(value.batch_count)
751            .clip_skip(value.clip_skip)
752            .slg_scale(value.slg_scale)
753            .skip_layer(value.skip_layer)
754            .skip_layer_start(value.skip_layer_start)
755            .skip_layer_end(value.skip_layer_end)
756            .canny(value.canny)
757            .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
758            .preview_output(value.preview_output)
759            .preview_mode(value.preview_mode)
760            .preview_noisy(value.preview_noisy)
761            .preview_interval(value.preview_interval);
762        builder
763    }
764}
765
/// Owned, NUL-terminated string that can be handed to the native library.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Pointer to the NUL-terminated contents; valid for as long as `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Converts `bytes` into a C string, dropping any interior NUL byte.
    ///
    /// A C string cannot represent an interior NUL, and the `From`
    /// conversions cannot report an error, so the offending bytes are removed
    /// instead of panicking on user-provided text.
    fn from_bytes_lossy(bytes: Vec<u8>) -> Self {
        let c_string = CString::new(bytes).unwrap_or_else(|err| {
            let mut cleaned = err.into_vec();
            cleaned.retain(|byte| *byte != 0);
            CString::new(cleaned).expect("all NUL bytes were removed")
        });
        Self(c_string)
    }
}

impl From<&str> for CLibString {
    /// Copies `value`; interior NUL bytes are dropped.
    fn from(value: &str) -> Self {
        Self::from_bytes_lossy(value.as_bytes().to_vec())
    }
}

impl From<String> for CLibString {
    /// Takes ownership of `value`; interior NUL bytes are dropped.
    fn from(value: String) -> Self {
        Self::from_bytes_lossy(value.into_bytes())
    }
}
786
/// Owned, NUL-terminated file-system path that can be handed to the native
/// library. The default value is the empty string.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Pointer to the NUL-terminated contents; valid for as long as `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Encodes `path` as a C string.
    ///
    /// On Unix the raw bytes of the path are used, so paths that are not
    /// valid UTF-8 are passed through unchanged instead of silently becoming
    /// the empty string. On other platforms such paths are converted lossily.
    /// A path containing an interior NUL byte cannot be represented and maps
    /// to the empty string instead of panicking.
    fn from_path(path: &Path) -> Self {
        #[cfg(unix)]
        let bytes: Vec<u8> = std::os::unix::ffi::OsStrExt::as_bytes(path.as_os_str()).to_vec();
        #[cfg(not(unix))]
        let bytes: Vec<u8> = path.to_string_lossy().into_owned().into_bytes();

        Self(CString::new(bytes).unwrap_or_default())
    }
}

impl From<PathBuf> for CLibPath {
    fn from(value: PathBuf) -> Self {
        Self::from_path(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    fn from(value: &Path) -> Self {
        Self::from_path(value)
    }
}

impl From<&CLibPath> for PathBuf {
    /// Converts the stored C string back into a path without panicking on
    /// contents that are not valid UTF-8.
    fn from(value: &CLibPath) -> Self {
        #[cfg(unix)]
        let path = PathBuf::from(
            <std::ffi::OsStr as std::os::unix::ffi::OsStrExt>::from_bytes(value.0.to_bytes()),
        );
        #[cfg(not(unix))]
        let path = PathBuf::from(value.0.to_string_lossy().into_owned());

        path
    }
}
813
/// Upper bound, in bytes, for the prompt-derived part of a generated file
/// name, leaving room for the `_<id>.png` suffix within common 255-byte
/// file-name limits.
const MAX_PROMPT_STEM_BYTES: usize = 200;

/// Turns `prompt` into a string that is safe to use as a single file-name
/// component: path separators and control characters become `_` (on Windows
/// also `:`), and the result is cut at a character boundary once it would
/// exceed [`MAX_PROMPT_STEM_BYTES`].
fn sanitize_file_stem(prompt: &str) -> String {
    let mut stem = String::with_capacity(prompt.len().min(MAX_PROMPT_STEM_BYTES));
    for ch in prompt.chars() {
        let unsafe_char =
            std::path::is_separator(ch) || ch.is_control() || (cfg!(windows) && ch == ':');
        let ch = if unsafe_char { '_' } else { ch };
        if stem.len() + ch.len_utf8() > MAX_PROMPT_STEM_BYTES {
            break;
        }
        stem.push(ch);
    }
    stem
}

/// Computes the files the generated images are written to.
///
/// With `batch_size == 1` the single output is `path` itself. Otherwise
/// `path` is treated as a directory and one `<prompt>_<id>.png` entry per
/// image (ids starting at 1) is returned; a `batch_size` below 1 yields no
/// files. The prompt is sanitised first so that a prompt containing a path
/// separator, or one that looks like an absolute path, can neither create
/// nested paths nor make `Path::join` discard `path`.
fn output_files(path: &Path, prompt: &str, batch_size: i32) -> Vec<PathBuf> {
    if batch_size == 1 {
        return vec![path.to_path_buf()];
    }
    let stem = sanitize_file_stem(prompt);
    (1..=batch_size)
        .map(|id| path.join(format!("{stem}_{id}.png")))
        .collect()
}
823
824unsafe fn upscale(
825    upscale_repeats: i32,
826    upscaler_ctx: Option<*mut upscaler_ctx_t>,
827    data: sd_image_t,
828) -> Result<sd_image_t, DiffusionError> {
829    unsafe {
830        match upscaler_ctx {
831            Some(upscaler_ctx) => {
832                let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
833                let mut current_image = data;
834                for _ in 0..upscale_repeats {
835                    let upscaled_image =
836                        diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
837
838                    if upscaled_image.data.is_null() {
839                        return Err(DiffusionError::Upscaler);
840                    }
841
842                    free(current_image.data as *mut c_void);
843                    current_image = upscaled_image;
844                }
845                Ok(current_image)
846            }
847            None => Ok(data),
848        }
849    }
850}
851
852/// Generate an image with a prompt
853pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
854    let prompt: CLibString = CLibString::from(config.prompt.as_str());
855    let files = output_files(&config.output, &config.prompt, config.batch_count);
856    unsafe {
857        let sd_ctx = model_config.diffusion_ctx(true);
858        let upscaler_ctx = model_config.upscaler_ctx();
859        let init_image = sd_image_t {
860            width: 0,
861            height: 0,
862            channel: 3,
863            data: null_mut(),
864        };
865        let mask_image = sd_image_t {
866            width: config.width as u32,
867            height: config.height as u32,
868            channel: 1,
869            data: null_mut(),
870        };
871        let mut layers = config.skip_layer.clone();
872        let guidance = sd_guidance_params_t {
873            txt_cfg: config.cfg_scale,
874            img_cfg: config.cfg_scale,
875            distilled_guidance: config.guidance,
876            slg: sd_slg_params_t {
877                layers: layers.as_mut_ptr(),
878                layer_count: config.skip_layer.len(),
879                layer_start: config.skip_layer_start,
880                layer_end: config.skip_layer_end,
881                scale: config.slg_scale,
882            },
883        };
884        let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
885            sd_get_default_scheduler(sd_ctx)
886        } else {
887            model_config.scheduler
888        };
889        let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
890            sd_get_default_sample_method(sd_ctx)
891        } else {
892            config.sampling_method
893        };
894        let sample_params = sd_sample_params_t {
895            guidance,
896            sample_method,
897            sample_steps: config.steps,
898            eta: config.eta,
899            scheduler,
900            shifted_timestep: model_config.timestep_shift,
901        };
902        let control_image = sd_image_t {
903            width: 0,
904            height: 0,
905            channel: 3,
906            data: null_mut(),
907        };
908        let vae_tiling_params = sd_tiling_params_t {
909            enabled: model_config.vae_tiling,
910            tile_size_x: model_config.vae_tile_size.0,
911            tile_size_y: model_config.vae_tile_size.1,
912            target_overlap: model_config.vae_tile_overlap,
913            rel_size_x: model_config.vae_relative_tile_size.0,
914            rel_size_y: model_config.vae_relative_tile_size.1,
915        };
916        let pm_params = sd_pm_params_t {
917            id_images: null_mut(),
918            id_images_count: 0,
919            id_embed_path: model_config.pm_id_embed_path.as_ptr(),
920            style_strength: config.pm_style_strength,
921        };
922
923        unsafe extern "C" fn save_preview_local(
924            _step: ::std::os::raw::c_int,
925            _frame_count: ::std::os::raw::c_int,
926            frames: *mut sd_image_t,
927            _is_noisy: bool,
928            data: *mut ::std::os::raw::c_void,
929        ) {
930            unsafe {
931                let path = &*data.cast::<PathBuf>();
932                let _ = save_img(*frames, path);
933            }
934        }
935
936        if config.preview_mode != PreviewType::PREVIEW_NONE {
937            let data = &config.preview_output as *const PathBuf;
938
939            sd_set_preview_callback(
940                Some(save_preview_local),
941                config.preview_mode,
942                config.preview_interval,
943                !config.preview_noisy,
944                config.preview_noisy,
945                data as *mut c_void,
946            );
947        }
948
949        let easy_cache = sd_easycache_params_t {
950            enabled: model_config.easy_cache,
951            reuse_threshold: model_config.easy_cache_reuse_threshold,
952            start_percent: model_config.easy_cache_start_percent,
953            end_percent: model_config.easy_cache_end_percent,
954        };
955
956        let sd_img_gen_params = sd_img_gen_params_t {
957            prompt: prompt.as_ptr(),
958            negative_prompt: config.negative_prompt.as_ptr(),
959            clip_skip: config.clip_skip as i32,
960            init_image,
961            ref_images: null_mut(),
962            ref_images_count: 0,
963            increase_ref_index: false,
964            mask_image,
965            width: config.width,
966            height: config.height,
967            sample_params,
968            strength: config.strength,
969            seed: config.seed,
970            batch_count: config.batch_count,
971            control_image,
972            control_strength: config.control_strength,
973            pm_params,
974            vae_tiling_params,
975            auto_resize_ref_image: config.disable_auto_resize_ref_image,
976            easycache: easy_cache,
977            loras: model_config.lora_models.loras_t.as_ptr(),
978            lora_count: model_config.lora_models.loras_t.len() as u32,
979        };
980        let slice = diffusion_rs_sys::generate_image(sd_ctx, &sd_img_gen_params);
981        let ret = {
982            if slice.is_null() {
983                return Err(DiffusionError::Forward);
984            }
985            for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
986                .iter()
987                .zip(files)
988            {
989                match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
990                    Ok(img) => save_img(img, &path)?,
991                    Err(err) => {
992                        return Err(err);
993                    }
994                }
995            }
996            Ok(())
997        };
998        free(slice as *mut c_void);
999        ret
1000    }
1001}
1002
1003fn save_img(img: sd_image_t, path: &Path) -> Result<(), DiffusionError> {
1004    // Thx @wandbrandon
1005    let len = (img.width * img.height * img.channel) as usize;
1006    let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1007    let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
1008        .map(|img| RgbImage::from(img).save(path));
1009    if let Some(Err(err)) = save_state {
1010        return Err(DiffusionError::StoreImages(err));
1011    }
1012    Ok(())
1013}
1014
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    /// Builders must reject missing required fields and invalid combinations,
    /// and accept the minimal valid ones.
    #[test]
    fn test_required_args_txt2img() {
        let cat_prompt = "a lovely cat driving a sport car";

        // Nothing set: both builders refuse to build.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());

        // A model path alone is enough for the model configuration.
        let mut model_builder = ModelConfigBuilder::default();
        model_builder.model(PathBuf::from("./test.ckpt"));
        model_builder.build().unwrap();

        // A prompt alone is enough for the generation configuration.
        ConfigBuilder::default().prompt(cat_prompt).build().unwrap();

        // A batch of 10 with the default output must fail validation.
        let oversized_batch = ConfigBuilder::default()
            .prompt(cat_prompt)
            .batch_count(10)
            .build();
        assert!(matches!(
            oversized_batch,
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default().prompt(cat_prompt).build().unwrap();

        // A batch of 2 is accepted when the output points at "./".
        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end generation with upscaling; ignored by default because it
    /// downloads model files.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let sd_weights =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let esrgan_weights = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();

        let base_config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model = ModelConfigBuilder::default()
            .model(sd_weights)
            .upscale_model(esrgan_weights)
            .upscale_repeats(1)
            .build()
            .unwrap();

        // First single image.
        gen_img(&base_config, &mut model).unwrap();

        // Second single image, reusing the same model configuration.
        let straw_config = ConfigBuilder::from(base_config.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&straw_config, &mut model).unwrap();

        // Batch of two images written into a directory.
        let batch_config = ConfigBuilder::from(base_config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
        gen_img(&batch_config, &mut model).unwrap();
    }
}