1use std::ffi::CString;
2use std::ffi::c_char;
3use std::ffi::c_void;
4use std::path::Path;
5use std::path::PathBuf;
6use std::ptr::null_mut;
7use std::slice;
8
9use derive_builder::Builder;
10use diffusion_rs_sys::free_upscaler_ctx;
11use diffusion_rs_sys::new_upscaler_ctx;
12use diffusion_rs_sys::sd_ctx_params_t;
13use diffusion_rs_sys::sd_guidance_params_t;
14use diffusion_rs_sys::sd_image_t;
15use diffusion_rs_sys::sd_img_gen_params_t;
16use diffusion_rs_sys::sd_sample_params_t;
17use diffusion_rs_sys::sd_slg_params_t;
18use diffusion_rs_sys::upscaler_ctx_t;
19use image::ImageBuffer;
20use image::ImageError;
21use image::RgbImage;
22use libc::free;
23use thiserror::Error;
24
25use diffusion_rs_sys::free_sd_ctx;
26use diffusion_rs_sys::new_sd_ctx;
27use diffusion_rs_sys::sd_ctx_t;
28
29pub use diffusion_rs_sys::rng_type_t as RngFunction;
31
32pub use diffusion_rs_sys::sample_method_t as SampleMethod;
34
35pub use diffusion_rs_sys::scheduler_t as Scheduler;
37
38pub use diffusion_rs_sys::sd_type_t as WeightType;
40
41#[non_exhaustive]
42#[derive(Error, Debug)]
43pub enum DiffusionError {
45 #[error("The underling stablediffusion.cpp function returned NULL")]
46 Forward,
47 #[error(transparent)]
48 StoreImages(#[from] ImageError),
49 #[error("The underling upscaler model returned a NULL image")]
50 Upscaler,
51}
52
/// CLIP-skip setting for a generation, forwarded to stable-diffusion.cpp as its `i32` value.
///
/// The explicit discriminants are the raw integers sent over FFI.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[non_exhaustive]
#[repr(i32)]
pub enum ClipSkip {
    /// Raw value `0`; the default. Presumably lets the library pick per model — TODO confirm.
    #[default]
    Unspecified = 0,
    /// Raw value `1`.
    None = 1,
    /// Raw value `2`.
    OneLayer = 2,
}
64
65#[derive(Builder, Debug, Clone)]
66#[builder(
67 setter(into, strip_option),
68 build_fn(error = "ConfigBuilderError", validate = "Self::validate")
69)]
70pub struct ModelConfig {
71 #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
74 n_threads: i32,
75
76 #[builder(default = "false")]
78 offload_params_to_cpu: bool,
79
80 #[builder(default = "Default::default()")]
82 upscale_model: Option<CLibPath>,
83
84 #[builder(default = "0")]
86 upscale_repeats: i32,
87
88 #[builder(default = "Default::default()")]
90 model: CLibPath,
91
92 #[builder(default = "Default::default()")]
94 diffusion_model: CLibPath,
95
96 #[builder(default = "Default::default()")]
98 clip_l: CLibPath,
99
100 #[builder(default = "Default::default()")]
102 clip_g: CLibPath,
103
104 #[builder(default = "Default::default()")]
106 clip_vision: CLibPath,
107
108 #[builder(default = "Default::default()")]
110 t5xxl: CLibPath,
111
112 #[builder(default = "Default::default()")]
114 vae: CLibPath,
115
116 #[builder(default = "Default::default()")]
118 taesd: CLibPath,
119
120 #[builder(default = "Default::default()")]
122 control_net: CLibPath,
123
124 #[builder(default = "Default::default()")]
126 embeddings: CLibPath,
127
128 #[builder(default = "Default::default()")]
130 stacked_id_embd: CLibPath,
131
132 #[builder(default = "WeightType::SD_TYPE_COUNT")]
134 weight_type: WeightType,
135
136 #[builder(default = "Default::default()", setter(custom))]
138 lora_model: CLibPath,
139
140 #[builder(default = "Default::default()")]
142 high_noise_diffusion_model: CLibPath,
143
144 #[builder(default = "None", private)]
146 prompt_suffix: Option<String>,
147
148 #[builder(default = "false")]
150 vae_tiling: bool,
151
152 #[builder(default = "RngFunction::CUDA_RNG")]
154 rng: RngFunction,
155
156 #[builder(default = "Scheduler::DEFAULT")]
158 scheduler: Scheduler,
159
160 #[builder(default = "false")]
162 vae_on_cpu: bool,
163
164 #[builder(default = "false")]
166 clip_on_cpu: bool,
167
168 #[builder(default = "false")]
170 control_net_cpu: bool,
171
172 #[builder(default = "false")]
176 flash_attention: bool,
177
178 #[builder(default = "false")]
180 chroma_disable_dit_mask: bool,
181
182 #[builder(default = "false")]
184 chroma_enable_t5_mask: bool,
185
186 #[builder(default = "1")]
188 chroma_t5_mask_pad: i32,
189
190 #[builder(default = "false")]
193 diffusion_conv_direct: bool,
194
195 #[builder(default = "false")]
198 vae_conv_direct: bool,
199
200 #[builder(default = "f32::INFINITY")]
202 flow_shift: f32,
203
204 #[builder(default = "None", private)]
205 upscaler_ctx: Option<*mut upscaler_ctx_t>,
206
207 #[builder(default = "None", private)]
208 diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
209}
210
211impl ModelConfigBuilder {
212 fn validate(&self) -> Result<(), ConfigBuilderError> {
213 self.validate_model()
214 }
215
216 fn validate_model(&self) -> Result<(), ConfigBuilderError> {
217 self.model
218 .as_ref()
219 .or(self.diffusion_model.as_ref())
220 .map(|_| ())
221 .ok_or(ConfigBuilderError::UninitializedField(
222 "Model OR DiffusionModel must be valorized",
223 ))
224 }
225
226 pub fn lora_model(&mut self, lora_model: &Path) -> &mut Self {
227 let folder = lora_model.parent().unwrap();
228 let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned();
229 self.prompt_suffix(format!("<lora:{file_name}:1>"));
230 self.lora_model = Some(folder.into());
231 self
232 }
233
234 pub fn n_threads(&mut self, value: i32) -> &mut Self {
235 self.n_threads = if value > 0 {
236 Some(value)
237 } else {
238 Some(num_cpus::get_physical() as i32)
239 };
240 self
241 }
242}
243
244impl ModelConfig {
245 unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
246 unsafe {
247 if self.upscale_model.is_none() || self.upscale_repeats == 0 {
248 None
249 } else {
250 if self.upscaler_ctx.is_none() {
251 let upscaler = new_upscaler_ctx(
252 self.upscale_model.as_ref().unwrap().as_ptr(),
253 self.offload_params_to_cpu,
254 self.diffusion_conv_direct,
255 self.n_threads,
256 );
257 self.upscaler_ctx = Some(upscaler);
258 }
259 self.upscaler_ctx
260 }
261 }
262 }
263
264 unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
265 unsafe {
266 if self.diffusion_ctx.is_none() {
267 let sd_ctx_params = sd_ctx_params_t {
268 model_path: self.model.as_ptr(),
269 clip_l_path: self.clip_l.as_ptr(),
270 clip_g_path: self.clip_g.as_ptr(),
271 clip_vision_path: self.clip_vision.as_ptr(),
272 high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
273 t5xxl_path: self.t5xxl.as_ptr(),
274 diffusion_model_path: self.diffusion_model.as_ptr(),
275 vae_path: self.vae.as_ptr(),
276 taesd_path: self.taesd.as_ptr(),
277 control_net_path: self.control_net.as_ptr(),
278 lora_model_dir: self.lora_model.as_ptr(),
279 embedding_dir: self.embeddings.as_ptr(),
280 stacked_id_embed_dir: self.stacked_id_embd.as_ptr(),
281 vae_decode_only,
282 vae_tiling: self.vae_tiling,
283 free_params_immediately: false,
284 n_threads: self.n_threads,
285 wtype: self.weight_type,
286 rng_type: self.rng,
287 keep_clip_on_cpu: self.clip_on_cpu,
288 keep_control_net_on_cpu: self.control_net_cpu,
289 keep_vae_on_cpu: self.vae_on_cpu,
290 diffusion_flash_attn: self.flash_attention,
291 diffusion_conv_direct: self.diffusion_conv_direct,
292 chroma_use_dit_mask: !self.chroma_disable_dit_mask,
293 chroma_use_t5_mask: self.chroma_enable_t5_mask,
294 chroma_t5_mask_pad: self.chroma_t5_mask_pad,
295 vae_conv_direct: self.vae_conv_direct,
296 offload_params_to_cpu: self.offload_params_to_cpu,
297 flow_shift: self.flow_shift,
298 };
299 let ctx = new_sd_ctx(&sd_ctx_params);
300 self.diffusion_ctx = Some((ctx, sd_ctx_params))
301 }
302 self.diffusion_ctx.unwrap().0
303 }
304 }
305}
306
307impl Drop for ModelConfig {
308 fn drop(&mut self) {
309 unsafe {
311 if let Some((sd_ctx, _)) = self.diffusion_ctx {
312 free_sd_ctx(sd_ctx);
313 }
314
315 if let Some(upscaler_ctx) = self.upscaler_ctx {
316 free_upscaler_ctx(upscaler_ctx);
317 }
318 }
319 }
320}
321
322#[derive(Builder, Debug, Clone)]
323#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
324pub struct Config {
326 #[builder(default = "Default::default()")]
328 input_id_images: CLibPath,
329
330 #[builder(default = "false")]
332 normalize_input: bool,
333
334 #[builder(default = "Default::default()")]
336 init_img: CLibPath,
337
338 #[builder(default = "Default::default()")]
340 control_image: CLibPath,
341
342 #[builder(default = "PathBuf::from(\"./output.png\")")]
344 output: PathBuf,
345
346 prompt: String,
348
349 #[builder(default = "\"\".into()")]
351 negative_prompt: CLibString,
352
353 #[builder(default = "7.0")]
355 cfg_scale: f32,
356
357 #[builder(default = "3.5")]
359 guidance: f32,
360
361 #[builder(default = "0.75")]
363 strength: f32,
364
365 #[builder(default = "20.0")]
367 style_ratio: f32,
368
369 #[builder(default = "0.9")]
372 control_strength: f32,
373
374 #[builder(default = "512")]
376 height: i32,
377
378 #[builder(default = "512")]
380 width: i32,
381
382 #[builder(default = "SampleMethod::EULER_A")]
384 sampling_method: SampleMethod,
385
386 #[builder(default = "0.")]
388 eta: f32,
389
390 #[builder(default = "20")]
392 steps: i32,
393
394 #[builder(default = "42")]
396 seed: i64,
397
398 #[builder(default = "1")]
400 batch_count: i32,
401
402 #[builder(default = "ClipSkip::Unspecified")]
405 clip_skip: ClipSkip,
406
407 #[builder(default = "false")]
409 canny: bool,
410
411 #[builder(default = "0.")]
414 slg_scale: f32,
415
416 #[builder(default = "vec![7, 8, 9]")]
418 skip_layer: Vec<i32>,
419
420 #[builder(default = "0.01")]
422 skip_layer_start: f32,
423
424 #[builder(default = "0.2")]
426 skip_layer_end: f32,
427}
428
429impl ConfigBuilder {
430 fn validate(&self) -> Result<(), ConfigBuilderError> {
431 self.validate_output_dir()
432 }
433
434 fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
435 let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
436 let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
437 if is_dir == multiple_items {
438 Ok(())
439 } else {
440 Err(ConfigBuilderError::ValidationError(
441 "When batch_count > 1, output should point to folder and vice versa".to_owned(),
442 ))
443 }
444 }
445}
446
447impl From<Config> for ConfigBuilder {
448 fn from(value: Config) -> Self {
449 let mut builder = ConfigBuilder::default();
450 builder
451 .input_id_images(value.input_id_images)
452 .normalize_input(value.normalize_input)
453 .init_img(value.init_img)
454 .control_image(value.control_image)
455 .output(value.output)
456 .prompt(value.prompt)
457 .negative_prompt(value.negative_prompt)
458 .cfg_scale(value.cfg_scale)
459 .strength(value.strength)
460 .style_ratio(value.style_ratio)
461 .control_strength(value.control_strength)
462 .height(value.height)
463 .width(value.width)
464 .sampling_method(value.sampling_method)
465 .steps(value.steps)
466 .seed(value.seed)
467 .batch_count(value.batch_count)
468 .clip_skip(value.clip_skip)
469 .slg_scale(value.slg_scale)
470 .skip_layer(value.skip_layer)
471 .skip_layer_start(value.skip_layer_start)
472 .skip_layer_end(value.skip_layer_end)
473 .canny(value.canny);
474
475 builder
476 }
477}
478
/// Owned, NUL-terminated copy of a Rust string for passing to the C API.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Borrows the string as a C pointer; valid only while `self` is alive and unmodified.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<&str> for CLibString {
    /// Copies `value`; see the `From<String>` impl for how interior NUL bytes are handled.
    fn from(value: &str) -> Self {
        Self::from(value.to_owned())
    }
}

impl From<String> for CLibString {
    /// Converts `value`, truncating it at the first interior NUL byte.
    ///
    /// C code stops reading at the first NUL anyway, so keeping exactly that prefix matches
    /// what the library would see, without panicking on user-supplied text.
    fn from(value: String) -> Self {
        let c_string = CString::new(value).unwrap_or_else(|err| {
            let nul = err.nul_position();
            let mut bytes = err.into_vec();
            bytes.truncate(nul);
            CString::new(bytes).expect("the prefix before the first NUL contains no NUL")
        });
        Self(c_string)
    }
}
499
/// Owned, NUL-terminated copy of a filesystem path for passing to the C API.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Borrows the path as a C pointer; valid only while `self` is alive and unmodified.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    /// Same conversion as `From<&Path>`.
    fn from(value: PathBuf) -> Self {
        Self::from(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    /// Converts a path for C.
    ///
    /// On Unix the raw OS bytes are kept, so non-UTF-8 file names still reach the C side
    /// intact. Elsewhere a non-UTF-8 path becomes an empty string.
    ///
    /// # Panics
    /// If the path contains an interior NUL byte, which no real file path can.
    fn from(value: &Path) -> Self {
        #[cfg(unix)]
        let bytes = {
            use std::os::unix::ffi::OsStrExt;
            value.as_os_str().as_bytes().to_vec()
        };
        #[cfg(not(unix))]
        let bytes = value.to_str().unwrap_or_default().as_bytes().to_vec();
        Self(CString::new(bytes).expect("file paths cannot contain NUL bytes"))
    }
}
520
/// Computes the output paths for a batch of `batch_size` images.
///
/// A single image is written to `path` itself. For larger batches `path` is a directory and
/// each image is named `{prompt}_{id}.png`, with ids starting at 1. The prompt is made safe for
/// use as a file name first:
/// - path separators (`/`, `\`) and control characters become `_`;
/// - it is truncated on a character boundary so each name stays within 255 bytes.
///
/// A non-positive `batch_size` yields no paths.
fn output_files(path: &Path, prompt: &str, batch_size: i32) -> Vec<PathBuf> {
    /// Common file-name length limit (e.g. ext4 and most Unix filesystems), in bytes.
    const MAX_FILE_NAME_BYTES: usize = 255;

    /// Longest prefix of `text` that fits in `max_bytes` without splitting a character.
    fn truncate_to_char_boundary(text: &str, max_bytes: usize) -> &str {
        if text.len() <= max_bytes {
            return text;
        }
        let mut end = max_bytes;
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        &text[..end]
    }

    if batch_size == 1 {
        return vec![path.to_path_buf()];
    }

    // Separators would move the file into other (even absolute) directories, because
    // `Path::join` replaces the base when handed an absolute path.
    let stem: String = prompt
        .chars()
        .map(|c| {
            if matches!(c, '/' | '\\') || c.is_control() {
                '_'
            } else {
                c
            }
        })
        .collect();

    (1..=batch_size)
        .map(|id| {
            let suffix = format!("_{id}.png");
            let stem = truncate_to_char_boundary(&stem, MAX_FILE_NAME_BYTES - suffix.len());
            path.join(format!("{stem}{suffix}"))
        })
        .collect()
}
530
531unsafe fn upscale(
532 upscale_repeats: i32,
533 upscaler_ctx: Option<*mut upscaler_ctx_t>,
534 data: sd_image_t,
535) -> Result<sd_image_t, DiffusionError> {
536 unsafe {
537 match upscaler_ctx {
538 Some(upscaler_ctx) => {
539 let upscale_factor = 4; let mut current_image = data;
541 for _ in 0..upscale_repeats {
542 let upscaled_image =
543 diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
544
545 if upscaled_image.data.is_null() {
546 return Err(DiffusionError::Upscaler);
547 }
548
549 free(current_image.data as *mut c_void);
550 current_image = upscaled_image;
551 }
552 Ok(current_image)
553 }
554 None => Ok(data),
555 }
556 }
557}
558
559pub fn gen_img(config: &mut Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
561 let prompt: CLibString = match &model_config.prompt_suffix {
562 Some(suffix) => format!("{} {suffix}", &config.prompt),
563 None => config.prompt.clone(),
564 }
565 .into();
566 let files = output_files(&config.output, &config.prompt, config.batch_count);
567 unsafe {
568 let sd_ctx = model_config.diffusion_ctx(true);
569 let upscaler_ctx = model_config.upscaler_ctx();
570 let init_image = sd_image_t {
571 width: 0,
572 height: 0,
573 channel: 3,
574 data: null_mut(),
575 };
576 let mask_image = sd_image_t {
577 width: config.width as u32,
578 height: config.height as u32,
579 channel: 1,
580 data: null_mut(),
581 };
582 let guidance = sd_guidance_params_t {
583 txt_cfg: config.cfg_scale,
584 img_cfg: config.cfg_scale,
585 distilled_guidance: config.guidance,
586 slg: sd_slg_params_t {
587 layers: config.skip_layer.as_mut_ptr(),
588 layer_count: config.skip_layer.len(),
589 layer_start: config.skip_layer_start,
590 layer_end: config.skip_layer_end,
591 scale: config.slg_scale,
592 },
593 };
594 let sample_params = sd_sample_params_t {
595 guidance,
596 sample_method: config.sampling_method,
597 sample_steps: config.steps,
598 eta: config.eta,
599 scheduler: model_config.scheduler,
600 };
601 let control_image = sd_image_t {
602 width: 0,
603 height: 0,
604 channel: 3,
605 data: null_mut(),
606 };
607
608 let sd_img_gen_params = sd_img_gen_params_t {
609 prompt: prompt.as_ptr(),
610 negative_prompt: config.negative_prompt.as_ptr(),
611 clip_skip: config.clip_skip as i32,
612 init_image,
613 ref_images: null_mut(),
614 ref_images_count: 0,
615 increase_ref_index: false,
616 mask_image,
617 width: config.width,
618 height: config.height,
619 sample_params,
620 strength: config.strength,
621 seed: config.seed,
622 batch_count: config.batch_count,
623 control_image,
624 control_strength: config.control_strength,
625 style_strength: config.style_ratio,
626 normalize_input: config.normalize_input,
627 input_id_images_path: config.input_id_images.as_ptr(),
628 };
629
630 let slice = diffusion_rs_sys::generate_image(sd_ctx, &sd_img_gen_params);
631 if slice.is_null() {
632 return Err(DiffusionError::Forward);
633 }
634 for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
635 .iter()
636 .zip(files)
637 {
638 match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
639 Ok(img) => {
640 let len = (img.width * img.height * img.channel) as usize;
642 let buffer = slice::from_raw_parts(img.data, len).to_vec();
643 let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
644 .map(|img| RgbImage::from(img).save(path));
645 if let Some(Err(err)) = save_state {
646 return Err(DiffusionError::StoreImages(err));
647 }
648 }
649 Err(err) => {
650 return Err(err);
651 }
652 }
653 }
654
655 free(slice as *mut c_void);
657 Ok(())
658 }
659}
660
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    /// Builder validation: required fields and the batch/output consistency rule.
    #[test]
    fn test_required_args_txt2img() {
        // `prompt` is mandatory for Config.
        assert!(ConfigBuilder::default().build().is_err());
        // ModelConfig needs `model` or `diffusion_model`.
        assert!(matches!(
            ModelConfigBuilder::default().build(),
            Err(ConfigBuilderError::UninitializedField(_))
        ));

        // Either a full checkpoint or a standalone diffusion model satisfies that rule.
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();
        ModelConfigBuilder::default()
            .diffusion_model(PathBuf::from("./test.gguf"))
            .build()
            .unwrap();

        // A single image with the default output file is valid.
        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // Several images require `output` to be a directory...
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        // ...which `./` is.
        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end generation with real models downloaded from the Hugging Face Hub.
    /// Run manually with `cargo test -- --ignored`.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let mut config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&mut config, &mut model_config).unwrap();

        // Reuses the cached model with a derived config.
        let mut config2 = ConfigBuilder::from(config.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&mut config2, &mut model_config).unwrap();

        // Batch generation into a directory.
        let mut config3 = ConfigBuilder::from(config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();

        gen_img(&mut config3, &mut model_config).unwrap();
    }
}