diffusion_rs/
api.rs

1use std::ffi::CString;
2use std::ffi::c_char;
3use std::ffi::c_void;
4use std::path::Path;
5use std::path::PathBuf;
6use std::ptr::null;
7use std::slice;
8
9use derive_builder::Builder;
10use diffusion_rs_sys::free_upscaler_ctx;
11use diffusion_rs_sys::new_upscaler_ctx;
12use diffusion_rs_sys::sd_image_t;
13use diffusion_rs_sys::upscaler_ctx_t;
14use image::ImageBuffer;
15use image::ImageError;
16use image::RgbImage;
17use libc::free;
18use thiserror::Error;
19
20use diffusion_rs_sys::free_sd_ctx;
21use diffusion_rs_sys::new_sd_ctx;
22use diffusion_rs_sys::sd_ctx_t;
23
/// Random number generator (RNG) type
25pub use diffusion_rs_sys::rng_type_t as RngFunction;
26
27/// Sampling methods
28pub use diffusion_rs_sys::sample_method_t as SampleMethod;
29
30/// Denoiser sigma schedule
31pub use diffusion_rs_sys::schedule_t as Schedule;
32
33/// Weight type
34pub use diffusion_rs_sys::sd_type_t as WeightType;
35
36#[non_exhaustive]
37#[derive(Error, Debug)]
38/// Error that can occurs while forwarding models
39pub enum DiffusionError {
40    #[error("The underling stablediffusion.cpp function returned NULL")]
41    Forward,
42    #[error(transparent)]
43    StoreImages(#[from] ImageError),
44    #[error("The underling upscaler model returned a NULL image")]
45    Upscaler,
46}
47
/// How many of the last CLIP layers to ignore when encoding the prompt.
///
/// The discriminant is the integer handed to the native library.
#[non_exhaustive]
#[repr(i32)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ClipSkip {
    /// Let the library decide: behaves like [ClipSkip::None] for SD1.x and
    /// like [ClipSkip::OneLayer] for SD2.x.
    #[default]
    Unspecified = 0,
    /// Do not ignore any CLIP layer.
    None = 1,
    /// Ignore the last CLIP layer.
    OneLayer = 2,
}
59
60#[derive(Builder, Debug, Clone)]
61#[builder(
62    setter(into, strip_option),
63    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
64)]
65pub struct ModelConfig {
66    /// Number of threads to use during computation (default: 0).
67    /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores.
68    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
69    n_threads: i32,
70
71    /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
72    #[builder(default = "Default::default()")]
73    upscale_model: Option<CLibPath>,
74
75    /// Run the ESRGAN upscaler this many times (default 1)
76    #[builder(default = "0")]
77    upscale_repeats: i32,
78
79    /// Path to full model
80    #[builder(default = "Default::default()")]
81    model: CLibPath,
82
83    /// Path to the standalone diffusion model
84    #[builder(default = "Default::default()")]
85    diffusion_model: CLibPath,
86
87    /// path to the clip-l text encoder
88    #[builder(default = "Default::default()")]
89    clip_l: CLibPath,
90
91    /// path to the clip-g text encoder
92    #[builder(default = "Default::default()")]
93    clip_g: CLibPath,
94
95    /// Path to the t5xxl text encoder
96    #[builder(default = "Default::default()")]
97    t5xxl: CLibPath,
98
99    /// Path to vae
100    #[builder(default = "Default::default()")]
101    vae: CLibPath,
102
103    /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
104    #[builder(default = "Default::default()")]
105    taesd: CLibPath,
106
107    /// Path to control net model
108    #[builder(default = "Default::default()")]
109    control_net: CLibPath,
110
111    /// Path to embeddings
112    #[builder(default = "Default::default()")]
113    embeddings: CLibPath,
114
115    /// Path to PHOTOMAKER stacked id embeddings
116    #[builder(default = "Default::default()")]
117    stacked_id_embd: CLibPath,
118
119    /// Weight type. If not specified, the default is the type of the weight file
120    #[builder(default = "WeightType::SD_TYPE_COUNT")]
121    weight_type: WeightType,
122
123    /// Lora model directory
124    #[builder(default = "Default::default()", setter(custom))]
125    lora_model: CLibPath,
126
127    /// Suffix that needs to be added to prompt (e.g. lora model)
128    #[builder(default = "None", private)]
129    prompt_suffix: Option<String>,
130
131    /// Process vae in tiles to reduce memory usage (default: false)
132    #[builder(default = "false")]
133    vae_tiling: bool,
134
135    /// RNG (default: CUDA)
136    #[builder(default = "RngFunction::CUDA_RNG")]
137    rng: RngFunction,
138
139    /// Denoiser sigma schedule (default: DEFAULT)
140    #[builder(default = "Schedule::DEFAULT")]
141    schedule: Schedule,
142
143    /// Keep vae in cpu (for low vram) (default: false)
144    #[builder(default = "false")]
145    vae_on_cpu: bool,
146
147    /// keep clip in cpu (for low vram) (default: false)
148    #[builder(default = "false")]
149    clip_on_cpu: bool,
150
151    /// Keep controlnet in cpu (for low vram) (default: false)
152    #[builder(default = "false")]
153    control_net_cpu: bool,
154
155    /// Use flash attention in the diffusion model (for low vram).
156    /// Might lower quality, since it implies converting k and v to f16.
157    /// This might crash if it is not supported by the backend.
158    #[builder(default = "false")]
159    flash_attention: bool,
160
161    #[builder(default = "true")]
162    chroma_use_dit_mask: bool,
163
164    #[builder(default = "false")]
165    chroma_use_t5_mask: bool,
166
167    #[builder(default = "1")]
168    chroma_t5_mask_pad: i32,
169
170    #[builder(default = "None", private)]
171    upscaler_ctx: Option<*mut upscaler_ctx_t>,
172
173    #[builder(default = "None", private)]
174    diffusion_ctx: Option<*mut sd_ctx_t>,
175}
176
177impl ModelConfigBuilder {
178    fn validate(&self) -> Result<(), ConfigBuilderError> {
179        self.validate_model()
180    }
181
182    fn validate_model(&self) -> Result<(), ConfigBuilderError> {
183        self.model
184            .as_ref()
185            .or(self.diffusion_model.as_ref())
186            .map(|_| ())
187            .ok_or(ConfigBuilderError::UninitializedField(
188                "Model OR DiffusionModel must be valorized",
189            ))
190    }
191
192    pub fn lora_model(&mut self, lora_model: &Path) -> &mut Self {
193        let folder = lora_model.parent().unwrap();
194        let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned();
195        self.prompt_suffix(format!("<lora:{file_name}:1>"));
196        self.lora_model = Some(folder.into());
197        self
198    }
199
200    pub fn n_threads(&mut self, value: i32) -> &mut Self {
201        self.n_threads = if value > 0 {
202            Some(value)
203        } else {
204            Some(num_cpus::get_physical() as i32)
205        };
206        self
207    }
208}
209
210impl ModelConfig {
211    unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
212        unsafe {
213            if self.upscale_model.is_none() || self.upscale_repeats == 0 {
214                None
215            } else {
216                if self.upscaler_ctx.is_none() {
217                    let upscaler = new_upscaler_ctx(
218                        self.upscale_model.as_ref().unwrap().as_ptr(),
219                        self.n_threads,
220                    );
221                    self.upscaler_ctx = Some(upscaler);
222                }
223                self.upscaler_ctx
224            }
225        }
226    }
227
228    unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
229        unsafe {
230            if self.diffusion_ctx.is_none() {
231                let ctx = new_sd_ctx(
232                    self.model.as_ptr(),
233                    self.clip_l.as_ptr(),
234                    self.clip_g.as_ptr(),
235                    self.t5xxl.as_ptr(),
236                    self.diffusion_model.as_ptr(),
237                    self.vae.as_ptr(),
238                    self.taesd.as_ptr(),
239                    self.control_net.as_ptr(),
240                    self.lora_model.as_ptr(),
241                    self.embeddings.as_ptr(),
242                    self.stacked_id_embd.as_ptr(),
243                    vae_decode_only,
244                    self.vae_tiling,
245                    false,
246                    self.n_threads,
247                    self.weight_type,
248                    self.rng,
249                    self.schedule,
250                    self.clip_on_cpu,
251                    self.control_net_cpu,
252                    self.vae_on_cpu,
253                    self.flash_attention,
254                    self.chroma_use_dit_mask,
255                    self.chroma_use_t5_mask,
256                    self.chroma_t5_mask_pad,
257                );
258                self.diffusion_ctx = Some(ctx)
259            }
260            self.diffusion_ctx.unwrap()
261        }
262    }
263}
264
265impl Drop for ModelConfig {
266    fn drop(&mut self) {
267        //Clean-up CTX section
268        unsafe {
269            if let Some(sd_ctx) = self.diffusion_ctx {
270                free_sd_ctx(sd_ctx);
271            }
272
273            if let Some(upscaler_ctx) = self.upscaler_ctx {
274                free_upscaler_ctx(upscaler_ctx);
275            }
276        }
277    }
278}
279
280#[derive(Builder, Debug, Clone)]
281#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
282/// Config struct common to all diffusion methods
283pub struct Config {
284    /// Path to PHOTOMAKER input id images dir
285    #[builder(default = "Default::default()")]
286    input_id_images: CLibPath,
287
288    /// Normalize PHOTOMAKER input id images
289    #[builder(default = "false")]
290    normalize_input: bool,
291
292    /// Path to the input image, required by img2img
293    #[builder(default = "Default::default()")]
294    init_img: CLibPath,
295
296    /// Path to image condition, control net
297    #[builder(default = "Default::default()")]
298    control_image: CLibPath,
299
300    /// Path to write result image to (default: ./output.png)
301    #[builder(default = "PathBuf::from(\"./output.png\")")]
302    output: PathBuf,
303
304    /// The prompt to render
305    prompt: String,
306
307    /// The negative prompt (default: "")
308    #[builder(default = "\"\".into()")]
309    negative_prompt: CLibString,
310
311    /// Unconditional guidance scale (default: 7.0)
312    #[builder(default = "7.0")]
313    cfg_scale: f32,
314
315    /// Guidance (default: 3.5)
316    #[builder(default = "3.5")]
317    guidance: f32,
318
319    /// Strength for noising/unnoising (default: 0.75)
320    #[builder(default = "0.75")]
321    strength: f32,
322
323    /// Strength for keeping input identity (default: 20%)
324    #[builder(default = "20.0")]
325    style_ratio: f32,
326
327    /// Strength to apply Control Net (default: 0.9)
328    /// 1.0 corresponds to full destruction of information in init
329    #[builder(default = "0.9")]
330    control_strength: f32,
331
332    /// Image height, in pixel space (default: 512)
333    #[builder(default = "512")]
334    height: i32,
335
336    /// Image width, in pixel space (default: 512)
337    #[builder(default = "512")]
338    width: i32,
339
340    /// Sampling-method (default: EULER_A)
341    #[builder(default = "SampleMethod::EULER_A")]
342    sampling_method: SampleMethod,
343
344    /// eta in DDIM, only for DDIM and TCD: (default: 0)
345    #[builder(default = "0.")]
346    eta: f32,
347
348    /// Number of sample steps (default: 20)
349    #[builder(default = "20")]
350    steps: i32,
351
352    /// RNG seed (default: 42, use random seed for < 0)
353    #[builder(default = "42")]
354    seed: i64,
355
356    /// Number of images to generate (default: 1)
357    #[builder(default = "1")]
358    batch_count: i32,
359
360    /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
361    /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
362    #[builder(default = "ClipSkip::Unspecified")]
363    clip_skip: ClipSkip,
364
365    /// Apply canny preprocessor (edge detection) (default: false)
366    #[builder(default = "false")]
367    canny: bool,
368
369    /// skip layer guidance (SLG) scale, only for DiT models: (default: 0)
370    /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium
371    #[builder(default = "0.")]
372    slg_scale: f32,
373
374    /// Layers to skip for SLG steps: (default: \[7,8,9\])
375    #[builder(default = "vec![7, 8, 9]")]
376    skip_layer: Vec<i32>,
377
378    /// SLG enabling point: (default: 0.01)
379    #[builder(default = "0.01")]
380    skip_layer_start: f32,
381
382    /// SLG disabling point: (default: 0.2)
383    #[builder(default = "0.2")]
384    skip_layer_end: f32,
385}
386
387impl ConfigBuilder {
388    fn validate(&self) -> Result<(), ConfigBuilderError> {
389        self.validate_output_dir()
390    }
391
392    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
393        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
394        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
395        if is_dir == multiple_items {
396            Ok(())
397        } else {
398            Err(ConfigBuilderError::ValidationError(
399                "When batch_count > 1, ouput should point to folder and viceversa".to_owned(),
400            ))
401        }
402    }
403}
404
405impl From<Config> for ConfigBuilder {
406    fn from(value: Config) -> Self {
407        let mut builder = ConfigBuilder::default();
408        builder
409            .input_id_images(value.input_id_images)
410            .normalize_input(value.normalize_input)
411            .init_img(value.init_img)
412            .control_image(value.control_image)
413            .output(value.output)
414            .prompt(value.prompt)
415            .negative_prompt(value.negative_prompt)
416            .cfg_scale(value.cfg_scale)
417            .strength(value.strength)
418            .style_ratio(value.style_ratio)
419            .control_strength(value.control_strength)
420            .height(value.height)
421            .width(value.width)
422            .sampling_method(value.sampling_method)
423            .steps(value.steps)
424            .seed(value.seed)
425            .batch_count(value.batch_count)
426            .clip_skip(value.clip_skip)
427            .slg_scale(value.slg_scale)
428            .skip_layer(value.skip_layer)
429            .skip_layer_start(value.skip_layer_start)
430            .skip_layer_end(value.skip_layer_end)
431            .canny(value.canny);
432
433        builder
434    }
435}
436
/// Owned, NUL-terminated string handed to the native library as `const char*`.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Borrow the string as a C pointer, valid for as long as `self` lives.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Build a C string from `bytes`, truncating at the first interior NUL
    /// instead of panicking.
    ///
    /// Text after an interior NUL could never be seen through a `const char*`,
    /// so truncating keeps exactly what the native side can read.
    fn from_bytes_truncated(bytes: Vec<u8>) -> Self {
        let c_string = CString::new(bytes).unwrap_or_else(|err| {
            let nul = err.nul_position();
            let mut prefix = err.into_vec();
            prefix.truncate(nul);
            CString::new(prefix).expect("the prefix before the first NUL contains no NUL")
        });
        Self(c_string)
    }
}

impl From<&str> for CLibString {
    fn from(value: &str) -> Self {
        Self::from_bytes_truncated(value.as_bytes().to_vec())
    }
}

impl From<String> for CLibString {
    fn from(value: String) -> Self {
        Self::from_bytes_truncated(value.into_bytes())
    }
}
457
/// Owned, NUL-terminated file-system path handed to the native library as
/// `const char*`. The default value is the empty string.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Borrow the path as a C pointer, valid for as long as `self` lives.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Convert `path` into the byte string passed to the native library.
    ///
    /// On Unix the raw OS bytes are used, so paths that are not valid UTF-8
    /// are passed through unchanged instead of being silently replaced by an
    /// empty string. On other platforms the path is converted lossily to UTF-8.
    ///
    /// # Panics
    ///
    /// Panics if the path contains a NUL byte; no OS path can contain one, and
    /// truncating it could silently point at a different file.
    fn from_path(path: &Path) -> Self {
        #[cfg(unix)]
        let bytes = {
            use std::os::unix::ffi::OsStrExt;
            path.as_os_str().as_bytes().to_vec()
        };
        #[cfg(not(unix))]
        let bytes = path.to_string_lossy().into_owned().into_bytes();

        Self(CString::new(bytes).expect("paths passed to the native library cannot contain NUL bytes"))
    }
}

impl From<PathBuf> for CLibPath {
    fn from(value: PathBuf) -> Self {
        Self::from_path(&value)
    }
}

impl From<&Path> for CLibPath {
    fn from(value: &Path) -> Self {
        Self::from_path(value)
    }
}
478
/// Compute where each image of a batch of `batch_size` images is written.
///
/// A single image goes straight to `path`; a larger batch goes to
/// `output_1.png`, `output_2.png`, … inside the `path` directory. A
/// non-positive `batch_size` yields no paths.
fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
    match batch_size {
        1 => vec![path.to_path_buf()],
        count => (1..=count)
            .map(|index| path.join(format!("output_{index}.png")))
            .collect(),
    }
}
488
489unsafe fn upscale(
490    upscale_repeats: i32,
491    upscaler_ctx: Option<*mut upscaler_ctx_t>,
492    data: sd_image_t,
493) -> Result<sd_image_t, DiffusionError> {
494    unsafe {
495        match upscaler_ctx {
496            Some(upscaler_ctx) => {
497                let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
498                let mut current_image = data;
499                for _ in 0..upscale_repeats {
500                    let upscaled_image =
501                        diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
502
503                    if upscaled_image.data.is_null() {
504                        return Err(DiffusionError::Upscaler);
505                    }
506
507                    free(current_image.data as *mut c_void);
508                    current_image = upscaled_image;
509                }
510                Ok(current_image)
511            }
512            None => Ok(data),
513        }
514    }
515}
516
517/// Generate an image with a prompt
518pub fn txt2img(config: &mut Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
519    let prompt: CLibString = match &model_config.prompt_suffix {
520        Some(suffix) => format!("{} {suffix}", &config.prompt),
521        None => config.prompt.clone(),
522    }
523    .into();
524    let files = output_files(&config.output, config.batch_count);
525    unsafe {
526        let sd_ctx = model_config.diffusion_ctx(true);
527        let upscaler_ctx = model_config.upscaler_ctx();
528
529        let slice = diffusion_rs_sys::txt2img(
530            sd_ctx,
531            prompt.as_ptr(),
532            config.negative_prompt.as_ptr(),
533            config.clip_skip as i32,
534            config.cfg_scale,
535            config.guidance,
536            config.eta,
537            config.width,
538            config.height,
539            config.sampling_method,
540            config.steps,
541            config.seed,
542            config.batch_count,
543            null(),
544            config.control_strength,
545            config.style_ratio,
546            config.normalize_input,
547            config.input_id_images.as_ptr(),
548            config.skip_layer.as_mut_ptr(),
549            config.skip_layer.len(),
550            config.slg_scale,
551            config.skip_layer_start,
552            config.skip_layer_end,
553        );
554        if slice.is_null() {
555            return Err(DiffusionError::Forward);
556        }
557        for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
558            .iter()
559            .zip(files)
560        {
561            match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
562                Ok(img) => {
563                    // Thx @wandbrandon
564                    let len = (img.width * img.height * img.channel) as usize;
565                    let buffer = slice::from_raw_parts(img.data, len).to_vec();
566                    let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
567                        .map(|img| RgbImage::from(img).save(path));
568                    if let Some(Err(err)) = save_state {
569                        return Err(DiffusionError::StoreImages(err));
570                    }
571                }
572                Err(err) => {
573                    return Err(err);
574                }
575            }
576        }
577
578        //Clean-up slice section
579        free(slice as *mut c_void);
580        Ok(())
581    }
582}
583
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, txt2img};

    const CAT_PROMPT: &str = "a lovely cat driving a sport car";
    const DUCK_BOTTLE_PROMPT: &str = "a lovely duck drinking water from a bottle";

    /// Builder validation: required fields and the batch_count/output rule.
    #[test]
    fn test_required_args_txt2img() {
        // Neither builder can be built without its required fields.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());

        // A model path alone is enough for the model configuration.
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        // A prompt alone is enough for a single-image configuration.
        ConfigBuilder::default().prompt(CAT_PROMPT).build().unwrap();

        // Several images require the output to be a directory.
        let many_into_file = ConfigBuilder::default()
            .prompt(CAT_PROMPT)
            .batch_count(10)
            .build();
        assert!(matches!(
            many_into_file,
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default().prompt(CAT_PROMPT).build().unwrap();

        // ...and succeed when it is one.
        ConfigBuilder::default()
            .prompt(DUCK_BOTTLE_PROMPT)
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end generation with upscaling; downloads large models.
    #[ignore]
    #[test]
    fn test_txt2img() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();

        let mut first = ConfigBuilder::default()
            .prompt(DUCK_BOTTLE_PROMPT)
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        txt2img(&mut first, &mut model_config).unwrap();

        // Reuse the first configuration, changing only prompt and output;
        // the cached model contexts are reused as well.
        let mut second = ConfigBuilder::from(first)
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        txt2img(&mut second, &mut model_config).unwrap();
    }
}