// diffusion_rs/
// api.rs

1use std::ffi::CString;
2use std::ffi::c_char;
3use std::ffi::c_void;
4use std::path::Path;
5use std::path::PathBuf;
6use std::ptr::null;
7use std::ptr::null_mut;
8use std::slice;
9
10use derive_builder::Builder;
11use diffusion_rs_sys::free_upscaler_ctx;
12use diffusion_rs_sys::new_upscaler_ctx;
13use diffusion_rs_sys::sd_ctx_params_t;
14use diffusion_rs_sys::sd_guidance_params_t;
15use diffusion_rs_sys::sd_image_t;
16use diffusion_rs_sys::sd_img_gen_params_t;
17use diffusion_rs_sys::sd_slg_params_t;
18use diffusion_rs_sys::upscaler_ctx_t;
19use image::ImageBuffer;
20use image::ImageError;
21use image::RgbImage;
22use libc::free;
23use thiserror::Error;
24
25use diffusion_rs_sys::free_sd_ctx;
26use diffusion_rs_sys::new_sd_ctx;
27use diffusion_rs_sys::sd_ctx_t;
28
29/// Specify the range function
30pub use diffusion_rs_sys::rng_type_t as RngFunction;
31
32/// Sampling methods
33pub use diffusion_rs_sys::sample_method_t as SampleMethod;
34
35/// Denoiser sigma schedule
36pub use diffusion_rs_sys::schedule_t as Schedule;
37
38/// Weight type
39pub use diffusion_rs_sys::sd_type_t as WeightType;
40
41#[non_exhaustive]
42#[derive(Error, Debug)]
43/// Error that can occurs while forwarding models
44pub enum DiffusionError {
45    #[error("The underling stablediffusion.cpp function returned NULL")]
46    Forward,
47    #[error(transparent)]
48    StoreImages(#[from] ImageError),
49    #[error("The underling upscaler model returned a NULL image")]
50    Upscaler,
51}
52
#[repr(i32)]
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
/// Ignore the lower X layers of CLIP network
pub enum ClipSkip {
    /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x
    #[default]
    Unspecified = 0,
    /// Do not skip any CLIP layers.
    None = 1,
    /// Skip one CLIP layer.
    OneLayer = 2,
}
64
65#[derive(Builder, Debug, Clone)]
66#[builder(
67    setter(into, strip_option),
68    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
69)]
70pub struct ModelConfig {
71    /// Number of threads to use during computation (default: 0).
72    /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores.
73    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
74    n_threads: i32,
75
76    /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
77    #[builder(default = "Default::default()")]
78    upscale_model: Option<CLibPath>,
79
80    /// Run the ESRGAN upscaler this many times (default 1)
81    #[builder(default = "0")]
82    upscale_repeats: i32,
83
84    /// Path to full model
85    #[builder(default = "Default::default()")]
86    model: CLibPath,
87
88    /// Path to the standalone diffusion model
89    #[builder(default = "Default::default()")]
90    diffusion_model: CLibPath,
91
92    /// path to the clip-l text encoder
93    #[builder(default = "Default::default()")]
94    clip_l: CLibPath,
95
96    /// path to the clip-g text encoder
97    #[builder(default = "Default::default()")]
98    clip_g: CLibPath,
99
100    /// Path to the t5xxl text encoder
101    #[builder(default = "Default::default()")]
102    t5xxl: CLibPath,
103
104    /// Path to vae
105    #[builder(default = "Default::default()")]
106    vae: CLibPath,
107
108    /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
109    #[builder(default = "Default::default()")]
110    taesd: CLibPath,
111
112    /// Path to control net model
113    #[builder(default = "Default::default()")]
114    control_net: CLibPath,
115
116    /// Path to embeddings
117    #[builder(default = "Default::default()")]
118    embeddings: CLibPath,
119
120    /// Path to PHOTOMAKER stacked id embeddings
121    #[builder(default = "Default::default()")]
122    stacked_id_embd: CLibPath,
123
124    /// Weight type. If not specified, the default is the type of the weight file
125    #[builder(default = "WeightType::SD_TYPE_COUNT")]
126    weight_type: WeightType,
127
128    /// Lora model directory
129    #[builder(default = "Default::default()", setter(custom))]
130    lora_model: CLibPath,
131
132    /// Suffix that needs to be added to prompt (e.g. lora model)
133    #[builder(default = "None", private)]
134    prompt_suffix: Option<String>,
135
136    /// Process vae in tiles to reduce memory usage (default: false)
137    #[builder(default = "false")]
138    vae_tiling: bool,
139
140    /// RNG (default: CUDA)
141    #[builder(default = "RngFunction::CUDA_RNG")]
142    rng: RngFunction,
143
144    /// Denoiser sigma schedule (default: DEFAULT)
145    #[builder(default = "Schedule::DEFAULT")]
146    schedule: Schedule,
147
148    /// Keep vae in cpu (for low vram) (default: false)
149    #[builder(default = "false")]
150    vae_on_cpu: bool,
151
152    /// keep clip in cpu (for low vram) (default: false)
153    #[builder(default = "false")]
154    clip_on_cpu: bool,
155
156    /// Keep controlnet in cpu (for low vram) (default: false)
157    #[builder(default = "false")]
158    control_net_cpu: bool,
159
160    /// Use flash attention in the diffusion model (for low vram).
161    /// Might lower quality, since it implies converting k and v to f16.
162    /// This might crash if it is not supported by the backend.
163    #[builder(default = "false")]
164    flash_attention: bool,
165
166    #[builder(default = "false")]
167    chroma_disable_dit_mask: bool,
168
169    #[builder(default = "false")]
170    chroma_enable_t5_mask: bool,
171
172    #[builder(default = "1")]
173    chroma_t5_mask_pad: i32,
174
175    #[builder(default = "None", private)]
176    upscaler_ctx: Option<*mut upscaler_ctx_t>,
177
178    #[builder(default = "None", private)]
179    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
180}
181
182impl ModelConfigBuilder {
183    fn validate(&self) -> Result<(), ConfigBuilderError> {
184        self.validate_model()
185    }
186
187    fn validate_model(&self) -> Result<(), ConfigBuilderError> {
188        self.model
189            .as_ref()
190            .or(self.diffusion_model.as_ref())
191            .map(|_| ())
192            .ok_or(ConfigBuilderError::UninitializedField(
193                "Model OR DiffusionModel must be valorized",
194            ))
195    }
196
197    pub fn lora_model(&mut self, lora_model: &Path) -> &mut Self {
198        let folder = lora_model.parent().unwrap();
199        let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned();
200        self.prompt_suffix(format!("<lora:{file_name}:1>"));
201        self.lora_model = Some(folder.into());
202        self
203    }
204
205    pub fn n_threads(&mut self, value: i32) -> &mut Self {
206        self.n_threads = if value > 0 {
207            Some(value)
208        } else {
209            Some(num_cpus::get_physical() as i32)
210        };
211        self
212    }
213}
214
impl ModelConfig {
    /// Lazily create and cache the native ESRGAN upscaler context.
    ///
    /// Returns `None` when no `upscale_model` is configured or when
    /// `upscale_repeats` is 0 (upscaling disabled).
    ///
    /// # Safety
    /// Calls into the native library. The returned raw pointer is owned
    /// by `self` and freed in [Drop].
    /// NOTE(review): `new_upscaler_ctx` could return NULL on failure; the
    /// value is cached and later used without a NULL check — confirm.
    unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
        unsafe {
            if self.upscale_model.is_none() || self.upscale_repeats == 0 {
                None
            } else {
                if self.upscaler_ctx.is_none() {
                    let upscaler = new_upscaler_ctx(
                        self.upscale_model.as_ref().unwrap().as_ptr(),
                        self.n_threads,
                    );
                    self.upscaler_ctx = Some(upscaler);
                }
                self.upscaler_ctx
            }
        }
    }

    /// Lazily create and cache the native stable-diffusion context.
    ///
    /// The params struct stores raw pointers into the path strings held
    /// by `self` and is kept alongside the context after creation.
    /// NOTE(review): the cache ignores `vae_decode_only`, so later calls
    /// with a different flag reuse the first context — confirm intended.
    ///
    /// # Safety
    /// Calls into the native library. The returned raw pointer is owned
    /// by `self` and freed in [Drop].
    unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
        unsafe {
            if self.diffusion_ctx.is_none() {
                let sd_ctx_params = sd_ctx_params_t {
                    model_path: self.model.as_ptr(),
                    clip_l_path: self.clip_l.as_ptr(),
                    clip_g_path: self.clip_g.as_ptr(),
                    t5xxl_path: self.t5xxl.as_ptr(),
                    diffusion_model_path: self.diffusion_model.as_ptr(),
                    vae_path: self.vae.as_ptr(),
                    taesd_path: self.taesd.as_ptr(),
                    control_net_path: self.control_net.as_ptr(),
                    lora_model_dir: self.lora_model.as_ptr(),
                    embedding_dir: self.embeddings.as_ptr(),
                    stacked_id_embed_dir: self.stacked_id_embd.as_ptr(),
                    vae_decode_only,
                    vae_tiling: self.vae_tiling,
                    // Keep weights alive so the context can be reused.
                    free_params_immediately: false,
                    n_threads: self.n_threads,
                    wtype: self.weight_type,
                    rng_type: self.rng,
                    schedule: self.schedule,
                    keep_clip_on_cpu: self.clip_on_cpu,
                    keep_control_net_on_cpu: self.control_net_cpu,
                    keep_vae_on_cpu: self.vae_on_cpu,
                    diffusion_flash_attn: self.flash_attention,
                    chroma_use_dit_mask: !self.chroma_disable_dit_mask,
                    chroma_use_t5_mask: self.chroma_enable_t5_mask,
                    chroma_t5_mask_pad: self.chroma_t5_mask_pad,
                };
                let ctx = new_sd_ctx(&sd_ctx_params);
                self.diffusion_ctx = Some((ctx, sd_ctx_params))
            }
            self.diffusion_ctx.unwrap().0
        }
    }
}
270
271impl Drop for ModelConfig {
272    fn drop(&mut self) {
273        //Clean-up CTX section
274        unsafe {
275            if let Some((sd_ctx, _)) = self.diffusion_ctx {
276                free_sd_ctx(sd_ctx);
277            }
278
279            if let Some(upscaler_ctx) = self.upscaler_ctx {
280                free_upscaler_ctx(upscaler_ctx);
281            }
282        }
283    }
284}
285
#[derive(Builder, Debug, Clone)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
/// Config struct common to all diffusion methods
pub struct Config {
    /// Path to PHOTOMAKER input id images dir
    #[builder(default = "Default::default()")]
    input_id_images: CLibPath,

    /// Normalize PHOTOMAKER input id images
    #[builder(default = "false")]
    normalize_input: bool,

    /// Path to the input image, required by img2img
    #[builder(default = "Default::default()")]
    init_img: CLibPath,

    /// Path to image condition, control net
    #[builder(default = "Default::default()")]
    control_image: CLibPath,

    /// Path to write result image to (default: ./output.png).
    /// Must be an existing directory when batch_count > 1 (see validate).
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,

    /// The prompt to render (required)
    prompt: String,

    /// The negative prompt (default: "")
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,

    /// Unconditional guidance scale (default: 7.0)
    #[builder(default = "7.0")]
    cfg_scale: f32,

    /// Min guidance scale (default: 1.0)
    #[builder(default = "1.0")]
    min_cfg_scale: f32,

    /// Distilled guidance scale for models with guidance input (default: 3.5)
    #[builder(default = "3.5")]
    guidance: f32,

    /// Strength for noising/unnoising (default: 0.75)
    #[builder(default = "0.75")]
    strength: f32,

    /// Strength for keeping input identity (default: 20%)
    #[builder(default = "20.0")]
    style_ratio: f32,

    /// Strength to apply Control Net (default: 0.9)
    /// 1.0 corresponds to full destruction of information in init
    #[builder(default = "0.9")]
    control_strength: f32,

    /// Image height, in pixel space (default: 512)
    #[builder(default = "512")]
    height: i32,

    /// Image width, in pixel space (default: 512)
    #[builder(default = "512")]
    width: i32,

    /// Sampling-method (default: EULER_A)
    #[builder(default = "SampleMethod::EULER_A")]
    sampling_method: SampleMethod,

    /// eta in DDIM, only for DDIM and TCD: (default: 0)
    #[builder(default = "0.")]
    eta: f32,

    /// Number of sample steps (default: 20)
    #[builder(default = "20")]
    steps: i32,

    /// RNG seed (default: 42, use random seed for < 0)
    #[builder(default = "42")]
    seed: i64,

    /// Number of images to generate (default: 1)
    #[builder(default = "1")]
    batch_count: i32,

    /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer
    /// (default: [ClipSkip::Unspecified], resolved to 1 for SD1.x, 2 for SD2.x)
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,

    /// Apply canny preprocessor (edge detection) (default: false)
    #[builder(default = "false")]
    canny: bool,

    /// skip layer guidance (SLG) scale, only for DiT models: (default: 0)
    /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium
    #[builder(default = "0.")]
    slg_scale: f32,

    /// Layers to skip for SLG steps: (default: \[7,8,9\])
    #[builder(default = "vec![7, 8, 9]")]
    skip_layer: Vec<i32>,

    /// SLG enabling point: (default: 0.01)
    #[builder(default = "0.01")]
    skip_layer_start: f32,

    /// SLG disabling point: (default: 0.2)
    #[builder(default = "0.2")]
    skip_layer_end: f32,
}
396
397impl ConfigBuilder {
398    fn validate(&self) -> Result<(), ConfigBuilderError> {
399        self.validate_output_dir()
400    }
401
402    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
403        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
404        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
405        if is_dir == multiple_items {
406            Ok(())
407        } else {
408            Err(ConfigBuilderError::ValidationError(
409                "When batch_count > 1, ouput should point to folder and viceversa".to_owned(),
410            ))
411        }
412    }
413}
414
415impl From<Config> for ConfigBuilder {
416    fn from(value: Config) -> Self {
417        let mut builder = ConfigBuilder::default();
418        builder
419            .input_id_images(value.input_id_images)
420            .normalize_input(value.normalize_input)
421            .init_img(value.init_img)
422            .control_image(value.control_image)
423            .output(value.output)
424            .prompt(value.prompt)
425            .negative_prompt(value.negative_prompt)
426            .cfg_scale(value.cfg_scale)
427            .min_cfg_scale(value.min_cfg_scale)
428            .strength(value.strength)
429            .style_ratio(value.style_ratio)
430            .control_strength(value.control_strength)
431            .height(value.height)
432            .width(value.width)
433            .sampling_method(value.sampling_method)
434            .steps(value.steps)
435            .seed(value.seed)
436            .batch_count(value.batch_count)
437            .clip_skip(value.clip_skip)
438            .slg_scale(value.slg_scale)
439            .skip_layer(value.skip_layer)
440            .skip_layer_start(value.skip_layer_start)
441            .skip_layer_end(value.skip_layer_end)
442            .canny(value.canny);
443
444        builder
445    }
446}
447
/// Owned, NUL-terminated string for passing text across the C FFI boundary.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);
450
impl CLibString {
    /// Raw pointer to the NUL-terminated contents; valid while `self` lives.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}
456
457impl From<&str> for CLibString {
458    fn from(value: &str) -> Self {
459        Self(CString::new(value).unwrap())
460    }
461}
462
463impl From<String> for CLibString {
464    fn from(value: String) -> Self {
465        Self(CString::new(value).unwrap())
466    }
467}
468
/// Owned, NUL-terminated path string for passing paths across the C FFI boundary.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);
471
impl CLibPath {
    /// Raw pointer to the NUL-terminated contents; valid while `self` lives.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}
477
478impl From<PathBuf> for CLibPath {
479    fn from(value: PathBuf) -> Self {
480        Self(CString::new(value.to_str().unwrap_or_default()).unwrap())
481    }
482}
483
484impl From<&Path> for CLibPath {
485    fn from(value: &Path) -> Self {
486        Self(CString::new(value.to_str().unwrap_or_default()).unwrap())
487    }
488}
489
/// Compute the destination file path(s) for a generation run.
///
/// For a single image the given `path` is used verbatim; for larger
/// batches `path` is treated as a directory and one `output_{id}.png`
/// file name is produced per image (ids start at 1).
fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
    match batch_size {
        1 => vec![path.to_path_buf()],
        _ => (1..=batch_size)
            .map(|id| path.join(format!("output_{id}.png")))
            .collect(),
    }
}
499
/// Run the native upscaler `upscale_repeats` times over `data`.
///
/// When `upscaler_ctx` is `None` the input image is returned untouched.
/// After each successful pass the previous buffer — including the
/// caller-provided `data` buffer on the first pass — is released with
/// libc `free`, and the loop continues with the upscaled copy.
///
/// # Errors
/// Returns [DiffusionError::Upscaler] if the native call yields a NULL
/// data pointer; in that case the current buffer is not freed (leaked).
///
/// # Safety
/// `data.data` must point to a buffer allocated by the native library's
/// allocator, since it is released with libc `free`.
unsafe fn upscale(
    upscale_repeats: i32,
    upscaler_ctx: Option<*mut upscaler_ctx_t>,
    data: sd_image_t,
) -> Result<sd_image_t, DiffusionError> {
    unsafe {
        match upscaler_ctx {
            Some(upscaler_ctx) => {
                let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
                let mut current_image = data;
                for _ in 0..upscale_repeats {
                    let upscaled_image =
                        diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);

                    if upscaled_image.data.is_null() {
                        return Err(DiffusionError::Upscaler);
                    }

                    // Previous buffer is no longer needed; swap in the result.
                    free(current_image.data as *mut c_void);
                    current_image = upscaled_image;
                }
                Ok(current_image)
            }
            None => Ok(data),
        }
    }
}
527
/// Generate an image with a prompt
///
/// Lazily builds (or reuses) the native diffusion and upscaler contexts
/// held by `model_config`, runs `generate_image`, optionally upscales each
/// result, and saves one PNG per batch item to the paths produced by
/// [output_files] from `config.output`.
///
/// # Errors
/// [DiffusionError::Forward] if the native call returns NULL,
/// [DiffusionError::Upscaler] if upscaling fails, and
/// [DiffusionError::StoreImages] if a result cannot be saved.
pub fn gen_img(config: &mut Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
    // Append the lora tag recorded by ModelConfigBuilder::lora_model, if any.
    let prompt: CLibString = match &model_config.prompt_suffix {
        Some(suffix) => format!("{} {suffix}", &config.prompt),
        None => config.prompt.clone(),
    }
    .into();
    let files = output_files(&config.output, config.batch_count);
    unsafe {
        let sd_ctx = model_config.diffusion_ctx(true);
        let upscaler_ctx = model_config.upscaler_ctx();
        // Zero-sized init image with NULL data: no img2img input is wired up here.
        let init_image = sd_image_t {
            width: 0,
            height: 0,
            channel: 3,
            data: null_mut(),
        };
        // Single-channel mask with NULL data, sized like the requested output.
        let mask_image = sd_image_t {
            width: config.width as u32,
            height: config.height as u32,
            channel: 1,
            data: null_mut(),
        };
        let guidance = sd_guidance_params_t {
            txt_cfg: config.cfg_scale,
            // NOTE(review): img_cfg reuses cfg_scale — confirm that a
            // separate image-cfg setting is not needed.
            img_cfg: config.cfg_scale,
            min_cfg: config.min_cfg_scale,
            distilled_guidance: config.guidance,
            // Skip-layer-guidance parameters; scale 0 disables SLG.
            slg: sd_slg_params_t {
                layers: config.skip_layer.as_mut_ptr(),
                layer_count: config.skip_layer.len(),
                layer_start: config.skip_layer_start,
                layer_end: config.skip_layer_end,
                scale: config.slg_scale,
            },
        };

        let sd_img_gen_params = sd_img_gen_params_t {
            prompt: prompt.as_ptr(),
            negative_prompt: config.negative_prompt.as_ptr(),
            clip_skip: config.clip_skip as i32,
            guidance,
            init_image,
            ref_images: null_mut(),
            ref_images_count: 0,
            mask_image,
            width: config.width,
            height: config.height,
            sample_method: config.sampling_method,
            sample_steps: config.steps,
            eta: config.eta,
            strength: config.strength,
            seed: config.seed,
            batch_count: config.batch_count,
            control_cond: null(),
            control_strength: config.control_strength,
            style_strength: config.style_ratio,
            normalize_input: config.normalize_input,
            input_id_images_path: config.input_id_images.as_ptr(),
        };

        let slice = diffusion_rs_sys::generate_image(sd_ctx, &sd_img_gen_params);
        if slice.is_null() {
            return Err(DiffusionError::Forward);
        }
        // The native call returns one sd_image_t per batch item; pair each
        // with its output path.
        for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
            .iter()
            .zip(files)
        {
            match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
                Ok(img) => {
                    // Thx @wandbrandon
                    // Copy the raw RGB bytes into a Vec so the image crate owns them.
                    let len = (img.width * img.height * img.channel) as usize;
                    let buffer = slice::from_raw_parts(img.data, len).to_vec();
                    let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
                        .map(|img| RgbImage::from(img).save(path));
                    if let Some(Err(err)) = save_state {
                        return Err(DiffusionError::StoreImages(err));
                    }
                }
                Err(err) => {
                    return Err(err);
                }
            }
        }

        //Clean-up slice section
        // NOTE(review): this frees the array of sd_image_t structs but not
        // each image's `data` buffer — confirm ownership / possible leak.
        free(slice as *mut c_void);
        Ok(())
    }
}
619
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    /// Builder validation: a prompt is required for `Config`, a model path
    /// for `ModelConfig`, and `output`/`batch_count` must be consistent.
    #[test]
    fn test_required_args_txt2img() {
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // batch_count > 1 with a non-directory output must fail validation.
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // batch_count > 1 is accepted when output is an existing directory.
        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end smoke test: downloads real weights and runs generation
    /// twice (with upscaling). Ignored by default because it is slow and
    /// needs network access.
    #[ignore]
    #[test]
    fn test_txt2img() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let mut config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&mut config, &mut model_config).unwrap();
        // Round-trip the config through the builder to tweak it and rerun.
        let mut config2 = ConfigBuilder::from(config)
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&mut config2, &mut model_config).unwrap();
    }
}