1use std::ffi::CString;
2use std::ffi::c_char;
3use std::ffi::c_void;
4use std::path::Path;
5use std::path::PathBuf;
6use std::ptr::null;
7use std::slice;
8
9use derive_builder::Builder;
10use diffusion_rs_sys::free_upscaler_ctx;
11use diffusion_rs_sys::new_upscaler_ctx;
12use diffusion_rs_sys::sd_image_t;
13use diffusion_rs_sys::upscaler_ctx_t;
14use image::ImageBuffer;
15use image::ImageError;
16use image::RgbImage;
17use libc::free;
18use thiserror::Error;
19
20use diffusion_rs_sys::free_sd_ctx;
21use diffusion_rs_sys::new_sd_ctx;
22use diffusion_rs_sys::sd_ctx_t;
23
24pub use diffusion_rs_sys::rng_type_t as RngFunction;
26
27pub use diffusion_rs_sys::sample_method_t as SampleMethod;
29
30pub use diffusion_rs_sys::schedule_t as Schedule;
32
33pub use diffusion_rs_sys::sd_type_t as WeightType;
35
36#[non_exhaustive]
37#[derive(Error, Debug)]
38pub enum DiffusionError {
40 #[error("The underling stablediffusion.cpp function returned NULL")]
41 Forward,
42 #[error(transparent)]
43 StoreImages(#[from] ImageError),
44 #[error("The underling upscaler model returned a NULL image")]
45 Upscaler,
46}
47
/// CLIP layer-skip setting for generation.
///
/// The discriminant is what reaches the native library: `txt2img` passes
/// `clip_skip as i32` straight through as its `clip_skip` argument.
#[non_exhaustive]
#[repr(i32)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ClipSkip {
    /// Discriminant `0`; the default value.
    #[default]
    Unspecified = 0,
    /// Discriminant `1`.
    None = 1,
    /// Discriminant `2`.
    OneLayer = 2,
}
59
60#[derive(Builder, Debug, Clone)]
61#[builder(
62 setter(into, strip_option),
63 build_fn(error = "ConfigBuilderError", validate = "Self::validate")
64)]
65pub struct ModelConfig {
66 #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
69 n_threads: i32,
70
71 #[builder(default = "Default::default()")]
73 upscale_model: Option<CLibPath>,
74
75 #[builder(default = "0")]
77 upscale_repeats: i32,
78
79 #[builder(default = "Default::default()")]
81 model: CLibPath,
82
83 #[builder(default = "Default::default()")]
85 diffusion_model: CLibPath,
86
87 #[builder(default = "Default::default()")]
89 clip_l: CLibPath,
90
91 #[builder(default = "Default::default()")]
93 clip_g: CLibPath,
94
95 #[builder(default = "Default::default()")]
97 t5xxl: CLibPath,
98
99 #[builder(default = "Default::default()")]
101 vae: CLibPath,
102
103 #[builder(default = "Default::default()")]
105 taesd: CLibPath,
106
107 #[builder(default = "Default::default()")]
109 control_net: CLibPath,
110
111 #[builder(default = "Default::default()")]
113 embeddings: CLibPath,
114
115 #[builder(default = "Default::default()")]
117 stacked_id_embd: CLibPath,
118
119 #[builder(default = "WeightType::SD_TYPE_COUNT")]
121 weight_type: WeightType,
122
123 #[builder(default = "Default::default()", setter(custom))]
125 lora_model: CLibPath,
126
127 #[builder(default = "None", private)]
129 prompt_suffix: Option<String>,
130
131 #[builder(default = "false")]
133 vae_tiling: bool,
134
135 #[builder(default = "RngFunction::CUDA_RNG")]
137 rng: RngFunction,
138
139 #[builder(default = "Schedule::DEFAULT")]
141 schedule: Schedule,
142
143 #[builder(default = "false")]
145 vae_on_cpu: bool,
146
147 #[builder(default = "false")]
149 clip_on_cpu: bool,
150
151 #[builder(default = "false")]
153 control_net_cpu: bool,
154
155 #[builder(default = "false")]
159 flash_attention: bool,
160
161 #[builder(default = "true")]
162 chroma_use_dit_mask: bool,
163
164 #[builder(default = "false")]
165 chroma_use_t5_mask: bool,
166
167 #[builder(default = "1")]
168 chroma_t5_mask_pad: i32,
169
170 #[builder(default = "None", private)]
171 upscaler_ctx: Option<*mut upscaler_ctx_t>,
172
173 #[builder(default = "None", private)]
174 diffusion_ctx: Option<*mut sd_ctx_t>,
175}
176
177impl ModelConfigBuilder {
178 fn validate(&self) -> Result<(), ConfigBuilderError> {
179 self.validate_model()
180 }
181
182 fn validate_model(&self) -> Result<(), ConfigBuilderError> {
183 self.model
184 .as_ref()
185 .or(self.diffusion_model.as_ref())
186 .map(|_| ())
187 .ok_or(ConfigBuilderError::UninitializedField(
188 "Model OR DiffusionModel must be valorized",
189 ))
190 }
191
192 pub fn lora_model(&mut self, lora_model: &Path) -> &mut Self {
193 let folder = lora_model.parent().unwrap();
194 let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned();
195 self.prompt_suffix(format!("<lora:{file_name}:1>"));
196 self.lora_model = Some(folder.into());
197 self
198 }
199
200 pub fn n_threads(&mut self, value: i32) -> &mut Self {
201 self.n_threads = if value > 0 {
202 Some(value)
203 } else {
204 Some(num_cpus::get_physical() as i32)
205 };
206 self
207 }
208}
209
210impl ModelConfig {
211 unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
212 unsafe {
213 if self.upscale_model.is_none() || self.upscale_repeats == 0 {
214 None
215 } else {
216 if self.upscaler_ctx.is_none() {
217 let upscaler = new_upscaler_ctx(
218 self.upscale_model.as_ref().unwrap().as_ptr(),
219 self.n_threads,
220 );
221 self.upscaler_ctx = Some(upscaler);
222 }
223 self.upscaler_ctx
224 }
225 }
226 }
227
228 unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
229 unsafe {
230 if self.diffusion_ctx.is_none() {
231 let ctx = new_sd_ctx(
232 self.model.as_ptr(),
233 self.clip_l.as_ptr(),
234 self.clip_g.as_ptr(),
235 self.t5xxl.as_ptr(),
236 self.diffusion_model.as_ptr(),
237 self.vae.as_ptr(),
238 self.taesd.as_ptr(),
239 self.control_net.as_ptr(),
240 self.lora_model.as_ptr(),
241 self.embeddings.as_ptr(),
242 self.stacked_id_embd.as_ptr(),
243 vae_decode_only,
244 self.vae_tiling,
245 false,
246 self.n_threads,
247 self.weight_type,
248 self.rng,
249 self.schedule,
250 self.clip_on_cpu,
251 self.control_net_cpu,
252 self.vae_on_cpu,
253 self.flash_attention,
254 self.chroma_use_dit_mask,
255 self.chroma_use_t5_mask,
256 self.chroma_t5_mask_pad,
257 );
258 self.diffusion_ctx = Some(ctx)
259 }
260 self.diffusion_ctx.unwrap()
261 }
262 }
263}
264
265impl Drop for ModelConfig {
266 fn drop(&mut self) {
267 unsafe {
269 if let Some(sd_ctx) = self.diffusion_ctx {
270 free_sd_ctx(sd_ctx);
271 }
272
273 if let Some(upscaler_ctx) = self.upscaler_ctx {
274 free_upscaler_ctx(upscaler_ctx);
275 }
276 }
277 }
278}
279
280#[derive(Builder, Debug, Clone)]
281#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
282pub struct Config {
284 #[builder(default = "Default::default()")]
286 input_id_images: CLibPath,
287
288 #[builder(default = "false")]
290 normalize_input: bool,
291
292 #[builder(default = "Default::default()")]
294 init_img: CLibPath,
295
296 #[builder(default = "Default::default()")]
298 control_image: CLibPath,
299
300 #[builder(default = "PathBuf::from(\"./output.png\")")]
302 output: PathBuf,
303
304 prompt: String,
306
307 #[builder(default = "\"\".into()")]
309 negative_prompt: CLibString,
310
311 #[builder(default = "7.0")]
313 cfg_scale: f32,
314
315 #[builder(default = "3.5")]
317 guidance: f32,
318
319 #[builder(default = "0.75")]
321 strength: f32,
322
323 #[builder(default = "20.0")]
325 style_ratio: f32,
326
327 #[builder(default = "0.9")]
330 control_strength: f32,
331
332 #[builder(default = "512")]
334 height: i32,
335
336 #[builder(default = "512")]
338 width: i32,
339
340 #[builder(default = "SampleMethod::EULER_A")]
342 sampling_method: SampleMethod,
343
344 #[builder(default = "0.")]
346 eta: f32,
347
348 #[builder(default = "20")]
350 steps: i32,
351
352 #[builder(default = "42")]
354 seed: i64,
355
356 #[builder(default = "1")]
358 batch_count: i32,
359
360 #[builder(default = "ClipSkip::Unspecified")]
363 clip_skip: ClipSkip,
364
365 #[builder(default = "false")]
367 canny: bool,
368
369 #[builder(default = "0.")]
372 slg_scale: f32,
373
374 #[builder(default = "vec![7, 8, 9]")]
376 skip_layer: Vec<i32>,
377
378 #[builder(default = "0.01")]
380 skip_layer_start: f32,
381
382 #[builder(default = "0.2")]
384 skip_layer_end: f32,
385}
386
387impl ConfigBuilder {
388 fn validate(&self) -> Result<(), ConfigBuilderError> {
389 self.validate_output_dir()
390 }
391
392 fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
393 let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
394 let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
395 if is_dir == multiple_items {
396 Ok(())
397 } else {
398 Err(ConfigBuilderError::ValidationError(
399 "When batch_count > 1, ouput should point to folder and viceversa".to_owned(),
400 ))
401 }
402 }
403}
404
405impl From<Config> for ConfigBuilder {
406 fn from(value: Config) -> Self {
407 let mut builder = ConfigBuilder::default();
408 builder
409 .input_id_images(value.input_id_images)
410 .normalize_input(value.normalize_input)
411 .init_img(value.init_img)
412 .control_image(value.control_image)
413 .output(value.output)
414 .prompt(value.prompt)
415 .negative_prompt(value.negative_prompt)
416 .cfg_scale(value.cfg_scale)
417 .strength(value.strength)
418 .style_ratio(value.style_ratio)
419 .control_strength(value.control_strength)
420 .height(value.height)
421 .width(value.width)
422 .sampling_method(value.sampling_method)
423 .steps(value.steps)
424 .seed(value.seed)
425 .batch_count(value.batch_count)
426 .clip_skip(value.clip_skip)
427 .slg_scale(value.slg_scale)
428 .skip_layer(value.skip_layer)
429 .skip_layer_start(value.skip_layer_start)
430 .skip_layer_end(value.skip_layer_end)
431 .canny(value.canny);
432
433 builder
434 }
435}
436
/// Owned NUL-terminated string handed to stable-diffusion.cpp (e.g. prompts).
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Builds the C string, keeping only the bytes before the first NUL.
    ///
    /// A C consumer stops reading at the first NUL anyway, so truncating there mirrors what
    /// the native library would see. This avoids panicking on user-supplied text.
    fn new_truncated(bytes: impl Into<Vec<u8>>) -> Self {
        let mut bytes = bytes.into();
        if let Some(nul) = bytes.iter().position(|&b| b == 0) {
            bytes.truncate(nul);
        }
        Self(CString::new(bytes).expect("interior NUL bytes were truncated above"))
    }

    /// Pointer to the NUL-terminated contents; valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<&str> for CLibString {
    fn from(value: &str) -> Self {
        Self::new_truncated(value)
    }
}

impl From<String> for CLibString {
    fn from(value: String) -> Self {
        Self::new_truncated(value)
    }
}
457
/// Owned NUL-terminated path handed to stable-diffusion.cpp; empty (the default) means unset.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Pointer to the NUL-terminated path; valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    fn from(value: PathBuf) -> Self {
        Self::from(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    /// Converts a path to a C string.
    ///
    /// A path that cannot be represented as a UTF-8 C string (non-UTF-8, or containing a
    /// NUL byte) becomes the empty string, the same value as an unset path, instead of
    /// panicking.
    fn from(value: &Path) -> Self {
        let c_path = value
            .to_str()
            .and_then(|path| CString::new(path).ok())
            .unwrap_or_default();
        Self(c_path)
    }
}
478
/// Lists the file(s) a batch of `batch_size` images is written to.
///
/// A single image goes straight to `path`. Any other size treats `path` as a directory
/// holding `output_1.png` … `output_<batch_size>.png`, so non-positive sizes yield no paths.
fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
    match batch_size {
        1 => vec![path.to_path_buf()],
        count => (1..=count)
            .map(|index| path.join(format!("output_{index}.png")))
            .collect(),
    }
}
488
489unsafe fn upscale(
490 upscale_repeats: i32,
491 upscaler_ctx: Option<*mut upscaler_ctx_t>,
492 data: sd_image_t,
493) -> Result<sd_image_t, DiffusionError> {
494 unsafe {
495 match upscaler_ctx {
496 Some(upscaler_ctx) => {
497 let upscale_factor = 4; let mut current_image = data;
499 for _ in 0..upscale_repeats {
500 let upscaled_image =
501 diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
502
503 if upscaled_image.data.is_null() {
504 return Err(DiffusionError::Upscaler);
505 }
506
507 free(current_image.data as *mut c_void);
508 current_image = upscaled_image;
509 }
510 Ok(current_image)
511 }
512 None => Ok(data),
513 }
514 }
515}
516
517pub fn txt2img(config: &mut Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
519 let prompt: CLibString = match &model_config.prompt_suffix {
520 Some(suffix) => format!("{} {suffix}", &config.prompt),
521 None => config.prompt.clone(),
522 }
523 .into();
524 let files = output_files(&config.output, config.batch_count);
525 unsafe {
526 let sd_ctx = model_config.diffusion_ctx(true);
527 let upscaler_ctx = model_config.upscaler_ctx();
528
529 let slice = diffusion_rs_sys::txt2img(
530 sd_ctx,
531 prompt.as_ptr(),
532 config.negative_prompt.as_ptr(),
533 config.clip_skip as i32,
534 config.cfg_scale,
535 config.guidance,
536 config.eta,
537 config.width,
538 config.height,
539 config.sampling_method,
540 config.steps,
541 config.seed,
542 config.batch_count,
543 null(),
544 config.control_strength,
545 config.style_ratio,
546 config.normalize_input,
547 config.input_id_images.as_ptr(),
548 config.skip_layer.as_mut_ptr(),
549 config.skip_layer.len(),
550 config.slg_scale,
551 config.skip_layer_start,
552 config.skip_layer_end,
553 );
554 if slice.is_null() {
555 return Err(DiffusionError::Forward);
556 }
557 for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
558 .iter()
559 .zip(files)
560 {
561 match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
562 Ok(img) => {
563 let len = (img.width * img.height * img.channel) as usize;
565 let buffer = slice::from_raw_parts(img.data, len).to_vec();
566 let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
567 .map(|img| RgbImage::from(img).save(path));
568 if let Some(Err(err)) = save_state {
569 return Err(DiffusionError::StoreImages(err));
570 }
571 }
572 Err(err) => {
573 return Err(err);
574 }
575 }
576 }
577
578 free(slice as *mut c_void);
580 Ok(())
581 }
582}
583
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, txt2img};

    #[test]
    fn test_required_args_txt2img() {
        // `prompt` is required for Config; `model` or `diffusion_model` for ModelConfig.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // batch_count > 1 requires `output` to be a directory.
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end run: downloads real model weights, so it is ignored by default.
    #[ignore]
    #[test]
    fn test_txt2img() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let mut config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        txt2img(&mut config, &mut model_config).unwrap();
        // Reuse the same model (and its cached native context) with a tweaked config.
        let mut config2 = ConfigBuilder::from(config)
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        txt2img(&mut config2, &mut model_config).unwrap();
    }
}