1use std::ffi::CString;
2use std::ffi::c_char;
3use std::ffi::c_void;
4use std::path::Path;
5use std::path::PathBuf;
6use std::ptr::null;
7use std::ptr::null_mut;
8use std::slice;
9
10use derive_builder::Builder;
11use diffusion_rs_sys::free_upscaler_ctx;
12use diffusion_rs_sys::new_upscaler_ctx;
13use diffusion_rs_sys::sd_ctx_params_t;
14use diffusion_rs_sys::sd_guidance_params_t;
15use diffusion_rs_sys::sd_image_t;
16use diffusion_rs_sys::sd_img_gen_params_t;
17use diffusion_rs_sys::sd_slg_params_t;
18use diffusion_rs_sys::upscaler_ctx_t;
19use image::ImageBuffer;
20use image::ImageError;
21use image::RgbImage;
22use libc::free;
23use thiserror::Error;
24
25use diffusion_rs_sys::free_sd_ctx;
26use diffusion_rs_sys::new_sd_ctx;
27use diffusion_rs_sys::sd_ctx_t;
28
29pub use diffusion_rs_sys::rng_type_t as RngFunction;
31
32pub use diffusion_rs_sys::sample_method_t as SampleMethod;
34
35pub use diffusion_rs_sys::schedule_t as Schedule;
37
38pub use diffusion_rs_sys::sd_type_t as WeightType;
40
41#[non_exhaustive]
42#[derive(Error, Debug)]
43pub enum DiffusionError {
45 #[error("The underling stablediffusion.cpp function returned NULL")]
46 Forward,
47 #[error(transparent)]
48 StoreImages(#[from] ImageError),
49 #[error("The underling upscaler model returned a NULL image")]
50 Upscaler,
51}
52
/// How many trailing CLIP layers to skip; the discriminant is handed to stable-diffusion.cpp
/// verbatim as `clip_skip`.
#[repr(i32)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ClipSkip {
    /// `0`: leave the choice to the native library (presumably the model's default — confirm).
    #[default]
    Unspecified = 0,
    /// `1`: skip nothing, as the name suggests.
    None = 1,
    /// `2`: skip the final CLIP layer, as the name suggests.
    OneLayer = 2,
}
64
65#[derive(Builder, Debug, Clone)]
66#[builder(
67 setter(into, strip_option),
68 build_fn(error = "ConfigBuilderError", validate = "Self::validate")
69)]
70pub struct ModelConfig {
71 #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
74 n_threads: i32,
75
76 #[builder(default = "Default::default()")]
78 upscale_model: Option<CLibPath>,
79
80 #[builder(default = "0")]
82 upscale_repeats: i32,
83
84 #[builder(default = "Default::default()")]
86 model: CLibPath,
87
88 #[builder(default = "Default::default()")]
90 diffusion_model: CLibPath,
91
92 #[builder(default = "Default::default()")]
94 clip_l: CLibPath,
95
96 #[builder(default = "Default::default()")]
98 clip_g: CLibPath,
99
100 #[builder(default = "Default::default()")]
102 t5xxl: CLibPath,
103
104 #[builder(default = "Default::default()")]
106 vae: CLibPath,
107
108 #[builder(default = "Default::default()")]
110 taesd: CLibPath,
111
112 #[builder(default = "Default::default()")]
114 control_net: CLibPath,
115
116 #[builder(default = "Default::default()")]
118 embeddings: CLibPath,
119
120 #[builder(default = "Default::default()")]
122 stacked_id_embd: CLibPath,
123
124 #[builder(default = "WeightType::SD_TYPE_COUNT")]
126 weight_type: WeightType,
127
128 #[builder(default = "Default::default()", setter(custom))]
130 lora_model: CLibPath,
131
132 #[builder(default = "None", private)]
134 prompt_suffix: Option<String>,
135
136 #[builder(default = "false")]
138 vae_tiling: bool,
139
140 #[builder(default = "RngFunction::CUDA_RNG")]
142 rng: RngFunction,
143
144 #[builder(default = "Schedule::DEFAULT")]
146 schedule: Schedule,
147
148 #[builder(default = "false")]
150 vae_on_cpu: bool,
151
152 #[builder(default = "false")]
154 clip_on_cpu: bool,
155
156 #[builder(default = "false")]
158 control_net_cpu: bool,
159
160 #[builder(default = "false")]
164 flash_attention: bool,
165
166 #[builder(default = "false")]
167 chroma_disable_dit_mask: bool,
168
169 #[builder(default = "false")]
170 chroma_enable_t5_mask: bool,
171
172 #[builder(default = "1")]
173 chroma_t5_mask_pad: i32,
174
175 #[builder(default = "None", private)]
176 upscaler_ctx: Option<*mut upscaler_ctx_t>,
177
178 #[builder(default = "None", private)]
179 diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
180}
181
182impl ModelConfigBuilder {
183 fn validate(&self) -> Result<(), ConfigBuilderError> {
184 self.validate_model()
185 }
186
187 fn validate_model(&self) -> Result<(), ConfigBuilderError> {
188 self.model
189 .as_ref()
190 .or(self.diffusion_model.as_ref())
191 .map(|_| ())
192 .ok_or(ConfigBuilderError::UninitializedField(
193 "Model OR DiffusionModel must be valorized",
194 ))
195 }
196
197 pub fn lora_model(&mut self, lora_model: &Path) -> &mut Self {
198 let folder = lora_model.parent().unwrap();
199 let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned();
200 self.prompt_suffix(format!("<lora:{file_name}:1>"));
201 self.lora_model = Some(folder.into());
202 self
203 }
204
205 pub fn n_threads(&mut self, value: i32) -> &mut Self {
206 self.n_threads = if value > 0 {
207 Some(value)
208 } else {
209 Some(num_cpus::get_physical() as i32)
210 };
211 self
212 }
213}
214
215impl ModelConfig {
216 unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
217 unsafe {
218 if self.upscale_model.is_none() || self.upscale_repeats == 0 {
219 None
220 } else {
221 if self.upscaler_ctx.is_none() {
222 let upscaler = new_upscaler_ctx(
223 self.upscale_model.as_ref().unwrap().as_ptr(),
224 self.n_threads,
225 );
226 self.upscaler_ctx = Some(upscaler);
227 }
228 self.upscaler_ctx
229 }
230 }
231 }
232
233 unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
234 unsafe {
235 if self.diffusion_ctx.is_none() {
236 let sd_ctx_params = sd_ctx_params_t {
237 model_path: self.model.as_ptr(),
238 clip_l_path: self.clip_l.as_ptr(),
239 clip_g_path: self.clip_g.as_ptr(),
240 t5xxl_path: self.t5xxl.as_ptr(),
241 diffusion_model_path: self.diffusion_model.as_ptr(),
242 vae_path: self.vae.as_ptr(),
243 taesd_path: self.taesd.as_ptr(),
244 control_net_path: self.control_net.as_ptr(),
245 lora_model_dir: self.lora_model.as_ptr(),
246 embedding_dir: self.embeddings.as_ptr(),
247 stacked_id_embed_dir: self.stacked_id_embd.as_ptr(),
248 vae_decode_only,
249 vae_tiling: self.vae_tiling,
250 free_params_immediately: false,
251 n_threads: self.n_threads,
252 wtype: self.weight_type,
253 rng_type: self.rng,
254 schedule: self.schedule,
255 keep_clip_on_cpu: self.clip_on_cpu,
256 keep_control_net_on_cpu: self.control_net_cpu,
257 keep_vae_on_cpu: self.vae_on_cpu,
258 diffusion_flash_attn: self.flash_attention,
259 chroma_use_dit_mask: !self.chroma_disable_dit_mask,
260 chroma_use_t5_mask: self.chroma_enable_t5_mask,
261 chroma_t5_mask_pad: self.chroma_t5_mask_pad,
262 };
263 let ctx = new_sd_ctx(&sd_ctx_params);
264 self.diffusion_ctx = Some((ctx, sd_ctx_params))
265 }
266 self.diffusion_ctx.unwrap().0
267 }
268 }
269}
270
271impl Drop for ModelConfig {
272 fn drop(&mut self) {
273 unsafe {
275 if let Some((sd_ctx, _)) = self.diffusion_ctx {
276 free_sd_ctx(sd_ctx);
277 }
278
279 if let Some(upscaler_ctx) = self.upscaler_ctx {
280 free_upscaler_ctx(upscaler_ctx);
281 }
282 }
283 }
284}
285
286#[derive(Builder, Debug, Clone)]
287#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
288pub struct Config {
290 #[builder(default = "Default::default()")]
292 input_id_images: CLibPath,
293
294 #[builder(default = "false")]
296 normalize_input: bool,
297
298 #[builder(default = "Default::default()")]
300 init_img: CLibPath,
301
302 #[builder(default = "Default::default()")]
304 control_image: CLibPath,
305
306 #[builder(default = "PathBuf::from(\"./output.png\")")]
308 output: PathBuf,
309
310 prompt: String,
312
313 #[builder(default = "\"\".into()")]
315 negative_prompt: CLibString,
316
317 #[builder(default = "7.0")]
319 cfg_scale: f32,
320
321 #[builder(default = "1.0")]
323 min_cfg_scale: f32,
324
325 #[builder(default = "3.5")]
327 guidance: f32,
328
329 #[builder(default = "0.75")]
331 strength: f32,
332
333 #[builder(default = "20.0")]
335 style_ratio: f32,
336
337 #[builder(default = "0.9")]
340 control_strength: f32,
341
342 #[builder(default = "512")]
344 height: i32,
345
346 #[builder(default = "512")]
348 width: i32,
349
350 #[builder(default = "SampleMethod::EULER_A")]
352 sampling_method: SampleMethod,
353
354 #[builder(default = "0.")]
356 eta: f32,
357
358 #[builder(default = "20")]
360 steps: i32,
361
362 #[builder(default = "42")]
364 seed: i64,
365
366 #[builder(default = "1")]
368 batch_count: i32,
369
370 #[builder(default = "ClipSkip::Unspecified")]
373 clip_skip: ClipSkip,
374
375 #[builder(default = "false")]
377 canny: bool,
378
379 #[builder(default = "0.")]
382 slg_scale: f32,
383
384 #[builder(default = "vec![7, 8, 9]")]
386 skip_layer: Vec<i32>,
387
388 #[builder(default = "0.01")]
390 skip_layer_start: f32,
391
392 #[builder(default = "0.2")]
394 skip_layer_end: f32,
395}
396
397impl ConfigBuilder {
398 fn validate(&self) -> Result<(), ConfigBuilderError> {
399 self.validate_output_dir()
400 }
401
402 fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
403 let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
404 let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
405 if is_dir == multiple_items {
406 Ok(())
407 } else {
408 Err(ConfigBuilderError::ValidationError(
409 "When batch_count > 1, ouput should point to folder and viceversa".to_owned(),
410 ))
411 }
412 }
413}
414
415impl From<Config> for ConfigBuilder {
416 fn from(value: Config) -> Self {
417 let mut builder = ConfigBuilder::default();
418 builder
419 .input_id_images(value.input_id_images)
420 .normalize_input(value.normalize_input)
421 .init_img(value.init_img)
422 .control_image(value.control_image)
423 .output(value.output)
424 .prompt(value.prompt)
425 .negative_prompt(value.negative_prompt)
426 .cfg_scale(value.cfg_scale)
427 .min_cfg_scale(value.min_cfg_scale)
428 .strength(value.strength)
429 .style_ratio(value.style_ratio)
430 .control_strength(value.control_strength)
431 .height(value.height)
432 .width(value.width)
433 .sampling_method(value.sampling_method)
434 .steps(value.steps)
435 .seed(value.seed)
436 .batch_count(value.batch_count)
437 .clip_skip(value.clip_skip)
438 .slg_scale(value.slg_scale)
439 .skip_layer(value.skip_layer)
440 .skip_layer_start(value.skip_layer_start)
441 .skip_layer_end(value.skip_layer_end)
442 .canny(value.canny);
443
444 builder
445 }
446}
447
/// Owned, NUL-terminated text handed to the native API (prompts, negative prompts).
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Pointer to the NUL-terminated bytes; valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<&str> for CLibString {
    fn from(value: &str) -> Self {
        Self(c_string_truncated_at_nul(value.as_bytes().to_vec()))
    }
}

impl From<String> for CLibString {
    fn from(value: String) -> Self {
        Self(c_string_truncated_at_nul(value.into_bytes()))
    }
}

/// Builds a `CString`, cutting the input at its first interior NUL byte.
///
/// A C consumer stops reading at that byte anyway, so this reproduces what the native side
/// would see instead of panicking on (possibly user-supplied) text.
fn c_string_truncated_at_nul(bytes: Vec<u8>) -> CString {
    CString::new(bytes).unwrap_or_else(|err| {
        let nul_at = err.nul_position();
        let mut head = err.into_vec();
        head.truncate(nul_at);
        CString::new(head).expect("no NUL byte precedes the first NUL position")
    })
}

/// Owned, NUL-terminated filesystem path handed to the native API (empty when not set).
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Pointer to the NUL-terminated path bytes; valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    fn from(value: PathBuf) -> Self {
        Self::from(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    /// # Panics
    /// Panics if the path contains an interior NUL byte, which no C string can carry.
    fn from(value: &Path) -> Self {
        // Unix paths are arbitrary bytes. Pass them through untouched so non-UTF-8 paths still
        // resolve instead of degrading to an empty ("not set") path.
        #[cfg(unix)]
        let bytes = {
            use std::os::unix::ffi::OsStrExt;
            value.as_os_str().as_bytes().to_vec()
        };
        // Elsewhere use UTF-8, replacing code units that cannot be represented.
        #[cfg(not(unix))]
        let bytes = value.to_string_lossy().into_owned().into_bytes();
        Self(
            CString::new(bytes)
                .expect("paths passed to stable-diffusion.cpp must not contain NUL bytes"),
        )
    }
}
489
/// Lists the files a batch of `batch_size` images is written to.
///
/// A single image goes to `path` itself. Otherwise `path` is treated as a folder holding
/// `output_1.png`, `output_2.png`, …; a non-positive size yields no files.
fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
    match batch_size {
        1 => vec![path.to_path_buf()],
        count => (1..=count)
            .map(|index| path.join(format!("output_{index}.png")))
            .collect(),
    }
}
499
500unsafe fn upscale(
501 upscale_repeats: i32,
502 upscaler_ctx: Option<*mut upscaler_ctx_t>,
503 data: sd_image_t,
504) -> Result<sd_image_t, DiffusionError> {
505 unsafe {
506 match upscaler_ctx {
507 Some(upscaler_ctx) => {
508 let upscale_factor = 4; let mut current_image = data;
510 for _ in 0..upscale_repeats {
511 let upscaled_image =
512 diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
513
514 if upscaled_image.data.is_null() {
515 return Err(DiffusionError::Upscaler);
516 }
517
518 free(current_image.data as *mut c_void);
519 current_image = upscaled_image;
520 }
521 Ok(current_image)
522 }
523 None => Ok(data),
524 }
525 }
526}
527
528pub fn gen_img(config: &mut Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
530 let prompt: CLibString = match &model_config.prompt_suffix {
531 Some(suffix) => format!("{} {suffix}", &config.prompt),
532 None => config.prompt.clone(),
533 }
534 .into();
535 let files = output_files(&config.output, config.batch_count);
536 unsafe {
537 let sd_ctx = model_config.diffusion_ctx(true);
538 let upscaler_ctx = model_config.upscaler_ctx();
539 let init_image = sd_image_t {
540 width: 0,
541 height: 0,
542 channel: 3,
543 data: null_mut(),
544 };
545 let mask_image = sd_image_t {
546 width: config.width as u32,
547 height: config.height as u32,
548 channel: 1,
549 data: null_mut(),
550 };
551 let guidance = sd_guidance_params_t {
552 txt_cfg: config.cfg_scale,
553 img_cfg: config.cfg_scale,
554 min_cfg: config.min_cfg_scale,
555 distilled_guidance: config.guidance,
556 slg: sd_slg_params_t {
557 layers: config.skip_layer.as_mut_ptr(),
558 layer_count: config.skip_layer.len(),
559 layer_start: config.skip_layer_start,
560 layer_end: config.skip_layer_end,
561 scale: config.slg_scale,
562 },
563 };
564
565 let sd_img_gen_params = sd_img_gen_params_t {
566 prompt: prompt.as_ptr(),
567 negative_prompt: config.negative_prompt.as_ptr(),
568 clip_skip: config.clip_skip as i32,
569 guidance,
570 init_image,
571 ref_images: null_mut(),
572 ref_images_count: 0,
573 mask_image,
574 width: config.width,
575 height: config.height,
576 sample_method: config.sampling_method,
577 sample_steps: config.steps,
578 eta: config.eta,
579 strength: config.strength,
580 seed: config.seed,
581 batch_count: config.batch_count,
582 control_cond: null(),
583 control_strength: config.control_strength,
584 style_strength: config.style_ratio,
585 normalize_input: config.normalize_input,
586 input_id_images_path: config.input_id_images.as_ptr(),
587 };
588
589 let slice = diffusion_rs_sys::generate_image(sd_ctx, &sd_img_gen_params);
590 if slice.is_null() {
591 return Err(DiffusionError::Forward);
592 }
593 for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
594 .iter()
595 .zip(files)
596 {
597 match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
598 Ok(img) => {
599 let len = (img.width * img.height * img.channel) as usize;
601 let buffer = slice::from_raw_parts(img.data, len).to_vec();
602 let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
603 .map(|img| RgbImage::from(img).save(path));
604 if let Some(Err(err)) = save_state {
605 return Err(DiffusionError::StoreImages(err));
606 }
607 }
608 Err(err) => {
609 return Err(err);
610 }
611 }
612 }
613
614 free(slice as *mut c_void);
616 Ok(())
617 }
618}
619
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    const CAT_PROMPT: &str = "a lovely cat driving a sport car";
    const DUCK_PROMPT: &str = "a lovely duck drinking water from a bottle";

    /// Required fields and the batch/output consistency rule of both builders.
    #[test]
    fn test_required_args_txt2img() {
        // With nothing set, neither builder can produce a value.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());

        // A checkpoint path alone satisfies the model requirement.
        let model_only = ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build();
        model_only.unwrap();

        // A prompt alone is enough for a single image written to the default file.
        let single_image = ConfigBuilder::default().prompt(CAT_PROMPT).build();
        single_image.unwrap();

        // A batch needs a folder as output; the default file path is rejected.
        let batch_into_file = ConfigBuilder::default()
            .prompt(CAT_PROMPT)
            .batch_count(10)
            .build();
        assert!(matches!(
            batch_into_file,
            Err(ConfigBuilderError::ValidationError(_))
        ));

        let batch_into_folder = ConfigBuilder::default()
            .prompt(DUCK_PROMPT)
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build();
        batch_into_folder.unwrap();
    }

    /// End-to-end generation with upscaling, reusing one model config for two prompts.
    #[ignore]
    #[test]
    fn test_txt2img() {
        let checkpoint =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let esrgan = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();

        let mut first = ConfigBuilder::default()
            .prompt(DUCK_PROMPT)
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut models = ModelConfigBuilder::default()
            .model(checkpoint)
            .upscale_model(esrgan)
            .upscale_repeats(1)
            .build()
            .unwrap();
        gen_img(&mut first, &mut models).unwrap();

        // Derive a second config from the first one, overriding prompt and destination.
        let mut second = ConfigBuilder::from(first)
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&mut second, &mut models).unwrap();
    }
}