diffusion_rs/
api.rs

1use std::ffi::CString;
2use std::ffi::c_char;
3use std::ffi::c_void;
4use std::path::Path;
5use std::path::PathBuf;
6use std::ptr::null;
7use std::slice;
8
9use derive_builder::Builder;
10use diffusion_rs_sys::free_upscaler_ctx;
11use diffusion_rs_sys::new_upscaler_ctx;
12use diffusion_rs_sys::sd_image_t;
13use diffusion_rs_sys::upscaler_ctx_t;
14use image::ImageBuffer;
15use image::ImageError;
16use image::RgbImage;
17use libc::free;
18use thiserror::Error;
19
20use diffusion_rs_sys::free_sd_ctx;
21use diffusion_rs_sys::new_sd_ctx;
22use diffusion_rs_sys::sd_ctx_t;
23
/// Specify the random number generator (RNG) function
25pub use diffusion_rs_sys::rng_type_t as RngFunction;
26
27/// Sampling methods
28pub use diffusion_rs_sys::sample_method_t as SampleMethod;
29
30/// Denoiser sigma schedule
31pub use diffusion_rs_sys::schedule_t as Schedule;
32
33/// Weight type
34pub use diffusion_rs_sys::sd_type_t as WeightType;
35
36#[non_exhaustive]
37#[derive(Error, Debug)]
38/// Error that can occurs while forwarding models
39pub enum DiffusionError {
40    #[error("The underling stablediffusion.cpp function returned NULL")]
41    Forward,
42    #[error(transparent)]
43    StoreImages(#[from] ImageError),
44    #[error("The underling upscaler model returned a NULL image")]
45    Upscaler,
46}
47
/// Ignore the lower X layers of CLIP network
///
/// The discriminants are the raw integer values handed to the native library.
#[non_exhaustive]
#[repr(i32)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ClipSkip {
    /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x
    #[default]
    Unspecified = 0,
    /// Ignore no layer
    None = 1,
    /// Ignore one layer
    OneLayer = 2,
}
59
60#[derive(Builder, Debug, Clone)]
61#[builder(
62    setter(into, strip_option),
63    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
64)]
65pub struct ModelConfig {
66    /// Number of threads to use during computation (default: 0).
67    /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores.
68    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
69    n_threads: i32,
70
71    /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
72    #[builder(default = "Default::default()")]
73    upscale_model: Option<CLibPath>,
74
75    /// Run the ESRGAN upscaler this many times (default 1)
76    #[builder(default = "0")]
77    upscale_repeats: i32,
78
79    /// Path to full model
80    #[builder(default = "Default::default()")]
81    model: CLibPath,
82
83    /// Path to the standalone diffusion model
84    #[builder(default = "Default::default()")]
85    diffusion_model: CLibPath,
86
87    /// path to the clip-l text encoder
88    #[builder(default = "Default::default()")]
89    clip_l: CLibPath,
90
91    /// path to the clip-g text encoder
92    #[builder(default = "Default::default()")]
93    clip_g: CLibPath,
94
95    /// Path to the t5xxl text encoder
96    #[builder(default = "Default::default()")]
97    t5xxl: CLibPath,
98
99    /// Path to vae
100    #[builder(default = "Default::default()")]
101    vae: CLibPath,
102
103    /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
104    #[builder(default = "Default::default()")]
105    taesd: CLibPath,
106
107    /// Path to control net model
108    #[builder(default = "Default::default()")]
109    control_net: CLibPath,
110
111    /// Path to embeddings
112    #[builder(default = "Default::default()")]
113    embeddings: CLibPath,
114
115    /// Path to PHOTOMAKER stacked id embeddings
116    #[builder(default = "Default::default()")]
117    stacked_id_embd: CLibPath,
118
119    /// Weight type. If not specified, the default is the type of the weight file
120    #[builder(default = "WeightType::SD_TYPE_COUNT")]
121    weight_type: WeightType,
122
123    /// Lora model directory
124    #[builder(default = "Default::default()", setter(custom))]
125    lora_model: CLibPath,
126
127    /// Suffix that needs to be added to prompt (e.g. lora model)
128    #[builder(default = "None", private)]
129    prompt_suffix: Option<String>,
130
131    /// Process vae in tiles to reduce memory usage (default: false)
132    #[builder(default = "false")]
133    vae_tiling: bool,
134
135    /// RNG (default: CUDA)
136    #[builder(default = "RngFunction::CUDA_RNG")]
137    rng: RngFunction,
138
139    /// Denoiser sigma schedule (default: DEFAULT)
140    #[builder(default = "Schedule::DEFAULT")]
141    schedule: Schedule,
142
143    /// Keep vae in cpu (for low vram) (default: false)
144    #[builder(default = "false")]
145    vae_on_cpu: bool,
146
147    /// keep clip in cpu (for low vram) (default: false)
148    #[builder(default = "false")]
149    clip_on_cpu: bool,
150
151    /// Keep controlnet in cpu (for low vram) (default: false)
152    #[builder(default = "false")]
153    control_net_cpu: bool,
154
155    /// Use flash attention in the diffusion model (for low vram).
156    /// Might lower quality, since it implies converting k and v to f16.
157    /// This might crash if it is not supported by the backend.
158    #[builder(default = "false")]
159    flash_attention: bool,
160
161    #[builder(default = "None", private)]
162    upscaler_ctx: Option<*mut upscaler_ctx_t>,
163
164    #[builder(default = "None", private)]
165    diffusion_ctx: Option<*mut sd_ctx_t>,
166}
167
168impl ModelConfigBuilder {
169    fn validate(&self) -> Result<(), ConfigBuilderError> {
170        self.validate_model()
171    }
172
173    fn validate_model(&self) -> Result<(), ConfigBuilderError> {
174        self.model
175            .as_ref()
176            .or(self.diffusion_model.as_ref())
177            .map(|_| ())
178            .ok_or(ConfigBuilderError::UninitializedField(
179                "Model OR DiffusionModel must be valorized",
180            ))
181    }
182
183    pub fn lora_model(&mut self, lora_model: &Path) -> &mut Self {
184        let folder = lora_model.parent().unwrap();
185        let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned();
186        self.prompt_suffix(format!("<lora:{file_name}:1>"));
187        self.lora_model = Some(folder.into());
188        self
189    }
190
191    pub fn n_threads(&mut self, value: i32) -> &mut Self {
192        self.n_threads = if value > 0 {
193            Some(value)
194        } else {
195            Some(num_cpus::get_physical() as i32)
196        };
197        self
198    }
199}
200
201impl ModelConfig {
202    unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
203        unsafe {
204            if self.upscale_model.is_none() || self.upscale_repeats == 0 {
205                None
206            } else {
207                if self.upscaler_ctx.is_none() {
208                    let upscaler = new_upscaler_ctx(
209                        self.upscale_model.as_ref().unwrap().as_ptr(),
210                        self.n_threads,
211                    );
212                    self.upscaler_ctx = Some(upscaler);
213                }
214                self.upscaler_ctx
215            }
216        }
217    }
218
219    unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
220        unsafe {
221            if self.diffusion_ctx.is_none() {
222                let ctx = new_sd_ctx(
223                    self.model.as_ptr(),
224                    self.clip_l.as_ptr(),
225                    self.clip_g.as_ptr(),
226                    self.t5xxl.as_ptr(),
227                    self.diffusion_model.as_ptr(),
228                    self.vae.as_ptr(),
229                    self.taesd.as_ptr(),
230                    self.control_net.as_ptr(),
231                    self.lora_model.as_ptr(),
232                    self.embeddings.as_ptr(),
233                    self.stacked_id_embd.as_ptr(),
234                    vae_decode_only,
235                    self.vae_tiling,
236                    false,
237                    self.n_threads,
238                    self.weight_type,
239                    self.rng,
240                    self.schedule,
241                    self.clip_on_cpu,
242                    self.control_net_cpu,
243                    self.vae_on_cpu,
244                    self.flash_attention,
245                );
246                self.diffusion_ctx = Some(ctx)
247            }
248            self.diffusion_ctx.unwrap()
249        }
250    }
251}
252
253impl Drop for ModelConfig {
254    fn drop(&mut self) {
255        //Clean-up CTX section
256        unsafe {
257            if let Some(sd_ctx) = self.diffusion_ctx {
258                free_sd_ctx(sd_ctx);
259            }
260
261            if let Some(upscaler_ctx) = self.upscaler_ctx {
262                free_upscaler_ctx(upscaler_ctx);
263            }
264        }
265    }
266}
267
268#[derive(Builder, Debug, Clone)]
269#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
270/// Config struct common to all diffusion methods
271pub struct Config {
272    /// Path to PHOTOMAKER input id images dir
273    #[builder(default = "Default::default()")]
274    input_id_images: CLibPath,
275
276    /// Normalize PHOTOMAKER input id images
277    #[builder(default = "false")]
278    normalize_input: bool,
279
280    /// Path to the input image, required by img2img
281    #[builder(default = "Default::default()")]
282    init_img: CLibPath,
283
284    /// Path to image condition, control net
285    #[builder(default = "Default::default()")]
286    control_image: CLibPath,
287
288    /// Path to write result image to (default: ./output.png)
289    #[builder(default = "PathBuf::from(\"./output.png\")")]
290    output: PathBuf,
291
292    /// The prompt to render
293    prompt: String,
294
295    /// The negative prompt (default: "")
296    #[builder(default = "\"\".into()")]
297    negative_prompt: CLibString,
298
299    /// Unconditional guidance scale (default: 7.0)
300    #[builder(default = "7.0")]
301    cfg_scale: f32,
302
303    /// Guidance (default: 3.5)
304    #[builder(default = "3.5")]
305    guidance: f32,
306
307    /// Strength for noising/unnoising (default: 0.75)
308    #[builder(default = "0.75")]
309    strength: f32,
310
311    /// Strength for keeping input identity (default: 20%)
312    #[builder(default = "20.0")]
313    style_ratio: f32,
314
315    /// Strength to apply Control Net (default: 0.9)
316    /// 1.0 corresponds to full destruction of information in init
317    #[builder(default = "0.9")]
318    control_strength: f32,
319
320    /// Image height, in pixel space (default: 512)
321    #[builder(default = "512")]
322    height: i32,
323
324    /// Image width, in pixel space (default: 512)
325    #[builder(default = "512")]
326    width: i32,
327
328    /// Sampling-method (default: EULER_A)
329    #[builder(default = "SampleMethod::EULER_A")]
330    sampling_method: SampleMethod,
331
332    /// eta in DDIM, only for DDIM and TCD: (default: 0)
333    #[builder(default = "0.")]
334    eta: f32,
335
336    /// Number of sample steps (default: 20)
337    #[builder(default = "20")]
338    steps: i32,
339
340    /// RNG seed (default: 42, use random seed for < 0)
341    #[builder(default = "42")]
342    seed: i64,
343
344    /// Number of images to generate (default: 1)
345    #[builder(default = "1")]
346    batch_count: i32,
347
348    /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
349    /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
350    #[builder(default = "ClipSkip::Unspecified")]
351    clip_skip: ClipSkip,
352
353    /// Apply canny preprocessor (edge detection) (default: false)
354    #[builder(default = "false")]
355    canny: bool,
356
357    /// skip layer guidance (SLG) scale, only for DiT models: (default: 0)
358    /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium
359    #[builder(default = "0.")]
360    slg_scale: f32,
361
362    /// Layers to skip for SLG steps: (default: \[7,8,9\])
363    #[builder(default = "vec![7, 8, 9]")]
364    skip_layer: Vec<i32>,
365
366    /// SLG enabling point: (default: 0.01)
367    #[builder(default = "0.01")]
368    skip_layer_start: f32,
369
370    /// SLG disabling point: (default: 0.2)
371    #[builder(default = "0.2")]
372    skip_layer_end: f32,
373}
374
375impl ConfigBuilder {
376    fn validate(&self) -> Result<(), ConfigBuilderError> {
377        self.validate_output_dir()
378    }
379
380    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
381        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
382        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
383        if is_dir == multiple_items {
384            Ok(())
385        } else {
386            Err(ConfigBuilderError::ValidationError(
387                "When batch_count > 1, ouput should point to folder and viceversa".to_owned(),
388            ))
389        }
390    }
391}
392
393impl From<Config> for ConfigBuilder {
394    fn from(value: Config) -> Self {
395        let mut builder = ConfigBuilder::default();
396        builder
397            .input_id_images(value.input_id_images)
398            .normalize_input(value.normalize_input)
399            .init_img(value.init_img)
400            .control_image(value.control_image)
401            .output(value.output)
402            .prompt(value.prompt)
403            .negative_prompt(value.negative_prompt)
404            .cfg_scale(value.cfg_scale)
405            .strength(value.strength)
406            .style_ratio(value.style_ratio)
407            .control_strength(value.control_strength)
408            .height(value.height)
409            .width(value.width)
410            .sampling_method(value.sampling_method)
411            .steps(value.steps)
412            .seed(value.seed)
413            .batch_count(value.batch_count)
414            .clip_skip(value.clip_skip)
415            .slg_scale(value.slg_scale)
416            .skip_layer(value.skip_layer)
417            .skip_layer_start(value.skip_layer_start)
418            .skip_layer_end(value.skip_layer_end)
419            .canny(value.canny);
420
421        builder
422    }
423}
424
/// Owned, NUL-terminated text that can be handed to the native library.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Builds the C string from arbitrary text without panicking.
    ///
    /// A C string cannot contain a NUL byte, so text with an interior NUL is
    /// truncated at the first NUL.
    fn from_text(value: impl Into<Vec<u8>>) -> Self {
        let c_string = CString::new(value).unwrap_or_else(|err| {
            let nul_at = err.nul_position();
            let mut bytes = err.into_vec();
            bytes.truncate(nul_at);
            CString::new(bytes).expect("bytes before the first NUL contain no NUL")
        });
        Self(c_string)
    }

    /// Pointer to the NUL-terminated bytes; valid for as long as `self` lives.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<&str> for CLibString {
    fn from(value: &str) -> Self {
        Self::from_text(value)
    }
}

impl From<String> for CLibString {
    fn from(value: String) -> Self {
        Self::from_text(value)
    }
}
445
/// Owned, NUL-terminated file system path that can be handed to the native
/// library. The default value is the empty path.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Builds the C string for `path` without panicking.
    ///
    /// On Unix the raw bytes of the path are used, so paths that are not
    /// valid UTF-8 are passed through unchanged. On other platforms a path
    /// that is not valid UTF-8 becomes the empty path. A path containing a
    /// NUL byte is truncated at the first NUL, since a C string cannot
    /// represent it.
    fn from_path(path: &Path) -> Self {
        #[cfg(unix)]
        let raw: Vec<u8> = {
            use std::os::unix::ffi::OsStrExt;
            path.as_os_str().as_bytes().to_vec()
        };
        #[cfg(not(unix))]
        let raw: Vec<u8> = path.to_str().unwrap_or_default().as_bytes().to_vec();

        let c_string = CString::new(raw).unwrap_or_else(|err| {
            let nul_at = err.nul_position();
            let mut bytes = err.into_vec();
            bytes.truncate(nul_at);
            CString::new(bytes).expect("bytes before the first NUL contain no NUL")
        });
        Self(c_string)
    }

    /// Pointer to the NUL-terminated bytes; valid for as long as `self` lives.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    fn from(value: PathBuf) -> Self {
        Self::from_path(&value)
    }
}

impl From<&Path> for CLibPath {
    fn from(value: &Path) -> Self {
        Self::from_path(value)
    }
}
466
/// Lists the files the generated images are written to.
///
/// For a batch of exactly one image `path` itself is the target; for any
/// other batch size `path` is treated as a directory and the targets are
/// `output_1.png`, `output_2.png`, ... inside it (none at all when the batch
/// size is below 1).
fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
    match batch_size {
        1 => vec![path.to_path_buf()],
        count => {
            let mut targets = Vec::new();
            for index in 1..=count {
                targets.push(path.join(format!("output_{index}.png")));
            }
            targets
        }
    }
}
476
477unsafe fn upscale(
478    upscale_repeats: i32,
479    upscaler_ctx: Option<*mut upscaler_ctx_t>,
480    data: sd_image_t,
481) -> Result<sd_image_t, DiffusionError> {
482    unsafe {
483        match upscaler_ctx {
484            Some(upscaler_ctx) => {
485                let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
486                let mut current_image = data;
487                for _ in 0..upscale_repeats {
488                    let upscaled_image =
489                        diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
490
491                    if upscaled_image.data.is_null() {
492                        return Err(DiffusionError::Upscaler);
493                    }
494
495                    free(current_image.data as *mut c_void);
496                    current_image = upscaled_image;
497                }
498                Ok(current_image)
499            }
500            None => Ok(data),
501        }
502    }
503}
504
505/// Generate an image with a prompt
506pub fn txt2img(config: &mut Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
507    let prompt: CLibString = match &model_config.prompt_suffix {
508        Some(suffix) => format!("{} {suffix}", &config.prompt),
509        None => config.prompt.clone(),
510    }
511    .into();
512    let files = output_files(&config.output, config.batch_count);
513    unsafe {
514        let sd_ctx = model_config.diffusion_ctx(true);
515        let upscaler_ctx = model_config.upscaler_ctx();
516
517        let slice = diffusion_rs_sys::txt2img(
518            sd_ctx,
519            prompt.as_ptr(),
520            config.negative_prompt.as_ptr(),
521            config.clip_skip as i32,
522            config.cfg_scale,
523            config.guidance,
524            config.eta,
525            config.width,
526            config.height,
527            config.sampling_method,
528            config.steps,
529            config.seed,
530            config.batch_count,
531            null(),
532            config.control_strength,
533            config.style_ratio,
534            config.normalize_input,
535            config.input_id_images.as_ptr(),
536            config.skip_layer.as_mut_ptr(),
537            config.skip_layer.len(),
538            config.slg_scale,
539            config.skip_layer_start,
540            config.skip_layer_end,
541        );
542        if slice.is_null() {
543            return Err(DiffusionError::Forward);
544        }
545        for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
546            .iter()
547            .zip(files)
548        {
549            match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
550                Ok(img) => {
551                    // Thx @wandbrandon
552                    let len = (img.width * img.height * img.channel) as usize;
553                    let buffer = slice::from_raw_parts(img.data, len).to_vec();
554                    let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
555                        .map(|img| RgbImage::from(img).save(path));
556                    if let Some(Err(err)) = save_state {
557                        return Err(DiffusionError::StoreImages(err));
558                    }
559                }
560                Err(err) => {
561                    return Err(err);
562                }
563            }
564        }
565
566        //Clean-up slice section
567        free(slice as *mut c_void);
568        Ok(())
569    }
570}
571
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::{ConfigBuilder, txt2img};
    use crate::api::{ConfigBuilderError, ModelConfigBuilder};
    use crate::util::download_file_hf_hub;

    /// Builders reject missing required fields and mismatched
    /// `batch_count`/`output` combinations, and accept valid ones.
    #[test]
    fn test_required_args_txt2img() {
        // Neither builder is usable without its required field.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());

        let model_only = ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build();
        model_only.unwrap();

        let prompt_only = ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build();
        prompt_only.unwrap();

        // Several images need a directory as output; the default is a file.
        let batch_without_dir = ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .batch_count(10)
            .build();
        assert!(matches!(
            batch_without_dir,
            Err(ConfigBuilderError::ValidationError(_))
        ));

        let prompt_only_again = ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build();
        prompt_only_again.unwrap();

        let batch_with_dir = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build();
        batch_with_dir.unwrap();
    }

    /// End-to-end generation with upscaling; downloads the models, hence
    /// ignored by default.
    #[ignore]
    #[test]
    fn test_txt2img() {
        let sd_weights =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let esrgan_weights = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();

        let mut first_run = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut models = ModelConfigBuilder::default()
            .model(sd_weights)
            .upscale_model(esrgan_weights)
            .upscale_repeats(1)
            .build()
            .unwrap();
        txt2img(&mut first_run, &mut models).unwrap();

        // Second run reuses the model config with a tweaked copy of the
        // first configuration.
        let mut second_run = ConfigBuilder::from(first_run)
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        txt2img(&mut second_run, &mut models).unwrap();
    }
}