1use std::ffi::CString;
2use std::ffi::c_char;
3use std::ffi::c_void;
4use std::path::Path;
5use std::path::PathBuf;
6use std::ptr::null;
7use std::slice;
8
9use derive_builder::Builder;
10use diffusion_rs_sys::free_upscaler_ctx;
11use diffusion_rs_sys::new_upscaler_ctx;
12use diffusion_rs_sys::sd_image_t;
13use diffusion_rs_sys::upscaler_ctx_t;
14use image::ImageBuffer;
15use image::ImageError;
16use image::RgbImage;
17use libc::free;
18use thiserror::Error;
19
20use diffusion_rs_sys::free_sd_ctx;
21use diffusion_rs_sys::new_sd_ctx;
22use diffusion_rs_sys::sd_ctx_t;
23
24pub use diffusion_rs_sys::rng_type_t as RngFunction;
26
27pub use diffusion_rs_sys::sample_method_t as SampleMethod;
29
30pub use diffusion_rs_sys::schedule_t as Schedule;
32
33pub use diffusion_rs_sys::sd_type_t as WeightType;
35
36#[non_exhaustive]
37#[derive(Error, Debug)]
38pub enum DiffusionError {
40 #[error("The underling stablediffusion.cpp function returned NULL")]
41 Forward,
42 #[error(transparent)]
43 StoreImages(#[from] ImageError),
44 #[error("The underling upscaler model returned a NULL image")]
45 Upscaler,
46}
47
/// CLIP layer-skip setting, forwarded to the native call as its `i32` discriminant.
///
/// `Unspecified` (0) is the default. NOTE(review): presumably 0 lets
/// stable-diffusion.cpp choose a model-dependent value — confirm.
#[non_exhaustive]
#[repr(i32)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum ClipSkip {
    /// No explicit choice (discriminant 0).
    #[default]
    Unspecified = 0,
    /// Discriminant 1.
    None = 1,
    /// Discriminant 2.
    OneLayer = 2,
}
59
60#[derive(Builder, Debug, Clone)]
61#[builder(
62 setter(into, strip_option),
63 build_fn(error = "ConfigBuilderError", validate = "Self::validate")
64)]
65pub struct ModelConfig {
66 #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
69 n_threads: i32,
70
71 #[builder(default = "Default::default()")]
73 upscale_model: Option<CLibPath>,
74
75 #[builder(default = "0")]
77 upscale_repeats: i32,
78
79 #[builder(default = "Default::default()")]
81 model: CLibPath,
82
83 #[builder(default = "Default::default()")]
85 diffusion_model: CLibPath,
86
87 #[builder(default = "Default::default()")]
89 clip_l: CLibPath,
90
91 #[builder(default = "Default::default()")]
93 clip_g: CLibPath,
94
95 #[builder(default = "Default::default()")]
97 t5xxl: CLibPath,
98
99 #[builder(default = "Default::default()")]
101 vae: CLibPath,
102
103 #[builder(default = "Default::default()")]
105 taesd: CLibPath,
106
107 #[builder(default = "Default::default()")]
109 control_net: CLibPath,
110
111 #[builder(default = "Default::default()")]
113 embeddings: CLibPath,
114
115 #[builder(default = "Default::default()")]
117 stacked_id_embd: CLibPath,
118
119 #[builder(default = "WeightType::SD_TYPE_COUNT")]
121 weight_type: WeightType,
122
123 #[builder(default = "Default::default()", setter(custom))]
125 lora_model: CLibPath,
126
127 #[builder(default = "None", private)]
129 prompt_suffix: Option<String>,
130
131 #[builder(default = "false")]
133 vae_tiling: bool,
134
135 #[builder(default = "RngFunction::CUDA_RNG")]
137 rng: RngFunction,
138
139 #[builder(default = "Schedule::DEFAULT")]
141 schedule: Schedule,
142
143 #[builder(default = "false")]
145 vae_on_cpu: bool,
146
147 #[builder(default = "false")]
149 clip_on_cpu: bool,
150
151 #[builder(default = "false")]
153 control_net_cpu: bool,
154
155 #[builder(default = "false")]
159 flash_attention: bool,
160
161 #[builder(default = "None", private)]
162 upscaler_ctx: Option<*mut upscaler_ctx_t>,
163
164 #[builder(default = "None", private)]
165 diffusion_ctx: Option<*mut sd_ctx_t>,
166}
167
168impl ModelConfigBuilder {
169 fn validate(&self) -> Result<(), ConfigBuilderError> {
170 self.validate_model()
171 }
172
173 fn validate_model(&self) -> Result<(), ConfigBuilderError> {
174 self.model
175 .as_ref()
176 .or(self.diffusion_model.as_ref())
177 .map(|_| ())
178 .ok_or(ConfigBuilderError::UninitializedField(
179 "Model OR DiffusionModel must be valorized",
180 ))
181 }
182
183 pub fn lora_model(&mut self, lora_model: &Path) -> &mut Self {
184 let folder = lora_model.parent().unwrap();
185 let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned();
186 self.prompt_suffix(format!("<lora:{file_name}:1>"));
187 self.lora_model = Some(folder.into());
188 self
189 }
190
191 pub fn n_threads(&mut self, value: i32) -> &mut Self {
192 self.n_threads = if value > 0 {
193 Some(value)
194 } else {
195 Some(num_cpus::get_physical() as i32)
196 };
197 self
198 }
199}
200
201impl ModelConfig {
202 unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
203 unsafe {
204 if self.upscale_model.is_none() || self.upscale_repeats == 0 {
205 None
206 } else {
207 if self.upscaler_ctx.is_none() {
208 let upscaler = new_upscaler_ctx(
209 self.upscale_model.as_ref().unwrap().as_ptr(),
210 self.n_threads,
211 );
212 self.upscaler_ctx = Some(upscaler);
213 }
214 self.upscaler_ctx
215 }
216 }
217 }
218
219 unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
220 unsafe {
221 if self.diffusion_ctx.is_none() {
222 let ctx = new_sd_ctx(
223 self.model.as_ptr(),
224 self.clip_l.as_ptr(),
225 self.clip_g.as_ptr(),
226 self.t5xxl.as_ptr(),
227 self.diffusion_model.as_ptr(),
228 self.vae.as_ptr(),
229 self.taesd.as_ptr(),
230 self.control_net.as_ptr(),
231 self.lora_model.as_ptr(),
232 self.embeddings.as_ptr(),
233 self.stacked_id_embd.as_ptr(),
234 vae_decode_only,
235 self.vae_tiling,
236 false,
237 self.n_threads,
238 self.weight_type,
239 self.rng,
240 self.schedule,
241 self.clip_on_cpu,
242 self.control_net_cpu,
243 self.vae_on_cpu,
244 self.flash_attention,
245 );
246 self.diffusion_ctx = Some(ctx)
247 }
248 self.diffusion_ctx.unwrap()
249 }
250 }
251}
252
253impl Drop for ModelConfig {
254 fn drop(&mut self) {
255 unsafe {
257 if let Some(sd_ctx) = self.diffusion_ctx {
258 free_sd_ctx(sd_ctx);
259 }
260
261 if let Some(upscaler_ctx) = self.upscaler_ctx {
262 free_upscaler_ctx(upscaler_ctx);
263 }
264 }
265 }
266}
267
268#[derive(Builder, Debug, Clone)]
269#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
270pub struct Config {
272 #[builder(default = "Default::default()")]
274 input_id_images: CLibPath,
275
276 #[builder(default = "false")]
278 normalize_input: bool,
279
280 #[builder(default = "Default::default()")]
282 init_img: CLibPath,
283
284 #[builder(default = "Default::default()")]
286 control_image: CLibPath,
287
288 #[builder(default = "PathBuf::from(\"./output.png\")")]
290 output: PathBuf,
291
292 prompt: String,
294
295 #[builder(default = "\"\".into()")]
297 negative_prompt: CLibString,
298
299 #[builder(default = "7.0")]
301 cfg_scale: f32,
302
303 #[builder(default = "3.5")]
305 guidance: f32,
306
307 #[builder(default = "0.75")]
309 strength: f32,
310
311 #[builder(default = "20.0")]
313 style_ratio: f32,
314
315 #[builder(default = "0.9")]
318 control_strength: f32,
319
320 #[builder(default = "512")]
322 height: i32,
323
324 #[builder(default = "512")]
326 width: i32,
327
328 #[builder(default = "SampleMethod::EULER_A")]
330 sampling_method: SampleMethod,
331
332 #[builder(default = "0.")]
334 eta: f32,
335
336 #[builder(default = "20")]
338 steps: i32,
339
340 #[builder(default = "42")]
342 seed: i64,
343
344 #[builder(default = "1")]
346 batch_count: i32,
347
348 #[builder(default = "ClipSkip::Unspecified")]
351 clip_skip: ClipSkip,
352
353 #[builder(default = "false")]
355 canny: bool,
356
357 #[builder(default = "0.")]
360 slg_scale: f32,
361
362 #[builder(default = "vec![7, 8, 9]")]
364 skip_layer: Vec<i32>,
365
366 #[builder(default = "0.01")]
368 skip_layer_start: f32,
369
370 #[builder(default = "0.2")]
372 skip_layer_end: f32,
373}
374
375impl ConfigBuilder {
376 fn validate(&self) -> Result<(), ConfigBuilderError> {
377 self.validate_output_dir()
378 }
379
380 fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
381 let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
382 let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
383 if is_dir == multiple_items {
384 Ok(())
385 } else {
386 Err(ConfigBuilderError::ValidationError(
387 "When batch_count > 1, ouput should point to folder and viceversa".to_owned(),
388 ))
389 }
390 }
391}
392
393impl From<Config> for ConfigBuilder {
394 fn from(value: Config) -> Self {
395 let mut builder = ConfigBuilder::default();
396 builder
397 .input_id_images(value.input_id_images)
398 .normalize_input(value.normalize_input)
399 .init_img(value.init_img)
400 .control_image(value.control_image)
401 .output(value.output)
402 .prompt(value.prompt)
403 .negative_prompt(value.negative_prompt)
404 .cfg_scale(value.cfg_scale)
405 .strength(value.strength)
406 .style_ratio(value.style_ratio)
407 .control_strength(value.control_strength)
408 .height(value.height)
409 .width(value.width)
410 .sampling_method(value.sampling_method)
411 .steps(value.steps)
412 .seed(value.seed)
413 .batch_count(value.batch_count)
414 .clip_skip(value.clip_skip)
415 .slg_scale(value.slg_scale)
416 .skip_layer(value.skip_layer)
417 .skip_layer_start(value.skip_layer_start)
418 .skip_layer_end(value.skip_layer_end)
419 .canny(value.canny);
420
421 builder
422 }
423}
424
/// Owned, NUL-terminated string handed to the C API.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Borrows the C string; the pointer is valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        let Self(inner) = self;
        inner.as_ptr()
    }
}

impl From<String> for CLibString {
    /// # Panics
    /// If `value` contains an interior NUL byte.
    fn from(value: String) -> Self {
        Self(CString::new(value).unwrap())
    }
}

impl From<&str> for CLibString {
    /// Copies `value` and converts it like the `String` implementation.
    ///
    /// # Panics
    /// If `value` contains an interior NUL byte.
    fn from(value: &str) -> Self {
        Self::from(value.to_owned())
    }
}
445
/// Owned, NUL-terminated filesystem path handed to the C API.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Borrows the C string; the pointer is valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    /// Same conversion as `From<&Path>`.
    fn from(value: PathBuf) -> Self {
        Self::from(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    /// Converts `value` to a C string.
    ///
    /// On Unix the raw OS bytes are passed through unchanged, so non-UTF-8
    /// paths survive intact. Elsewhere invalid UTF-8 is replaced lossily.
    /// Previously such paths silently became the empty string, which made a
    /// configured model look unset.
    ///
    /// # Panics
    /// If the path contains an interior NUL byte.
    fn from(value: &Path) -> Self {
        #[cfg(unix)]
        let bytes = {
            use std::os::unix::ffi::OsStrExt;
            value.as_os_str().as_bytes().to_vec()
        };
        #[cfg(not(unix))]
        let bytes = value.to_string_lossy().into_owned().into_bytes();
        Self(CString::new(bytes).expect("path contains an interior NUL byte"))
    }
}
466
/// Lists the files `txt2img` writes for a batch of `batch_size` images.
///
/// A single image goes to `path` itself. Any other size treats `path` as a
/// directory holding `output_1.png`, `output_2.png`, …; a size of 0 or less
/// yields an empty list.
fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
    match batch_size {
        1 => vec![path.to_path_buf()],
        count => (1..=count)
            .map(|index| path.join(format!("output_{index}.png")))
            .collect(),
    }
}
476
477unsafe fn upscale(
478 upscale_repeats: i32,
479 upscaler_ctx: Option<*mut upscaler_ctx_t>,
480 data: sd_image_t,
481) -> Result<sd_image_t, DiffusionError> {
482 unsafe {
483 match upscaler_ctx {
484 Some(upscaler_ctx) => {
485 let upscale_factor = 4; let mut current_image = data;
487 for _ in 0..upscale_repeats {
488 let upscaled_image =
489 diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
490
491 if upscaled_image.data.is_null() {
492 return Err(DiffusionError::Upscaler);
493 }
494
495 free(current_image.data as *mut c_void);
496 current_image = upscaled_image;
497 }
498 Ok(current_image)
499 }
500 None => Ok(data),
501 }
502 }
503}
504
505pub fn txt2img(config: &mut Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
507 let prompt: CLibString = match &model_config.prompt_suffix {
508 Some(suffix) => format!("{} {suffix}", &config.prompt),
509 None => config.prompt.clone(),
510 }
511 .into();
512 let files = output_files(&config.output, config.batch_count);
513 unsafe {
514 let sd_ctx = model_config.diffusion_ctx(true);
515 let upscaler_ctx = model_config.upscaler_ctx();
516
517 let slice = diffusion_rs_sys::txt2img(
518 sd_ctx,
519 prompt.as_ptr(),
520 config.negative_prompt.as_ptr(),
521 config.clip_skip as i32,
522 config.cfg_scale,
523 config.guidance,
524 config.eta,
525 config.width,
526 config.height,
527 config.sampling_method,
528 config.steps,
529 config.seed,
530 config.batch_count,
531 null(),
532 config.control_strength,
533 config.style_ratio,
534 config.normalize_input,
535 config.input_id_images.as_ptr(),
536 config.skip_layer.as_mut_ptr(),
537 config.skip_layer.len(),
538 config.slg_scale,
539 config.skip_layer_start,
540 config.skip_layer_end,
541 );
542 if slice.is_null() {
543 return Err(DiffusionError::Forward);
544 }
545 for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
546 .iter()
547 .zip(files)
548 {
549 match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
550 Ok(img) => {
551 let len = (img.width * img.height * img.channel) as usize;
553 let buffer = slice::from_raw_parts(img.data, len).to_vec();
554 let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
555 .map(|img| RgbImage::from(img).save(path));
556 if let Some(Err(err)) = save_state {
557 return Err(DiffusionError::StoreImages(err));
558 }
559 }
560 Err(err) => {
561 return Err(err);
562 }
563 }
564 }
565
566 free(slice as *mut c_void);
568 Ok(())
569 }
570}
571
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::{ConfigBuilder, ConfigBuilderError, ModelConfigBuilder, txt2img};
    use crate::util::download_file_hf_hub;

    const CAT_PROMPT: &str = "a lovely cat driving a sport car";

    /// Checks which builder inputs are mandatory and how `output` and
    /// `batch_count` interact during validation.
    #[test]
    fn test_required_args_txt2img() {
        // Neither builder succeeds without its mandatory inputs.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());

        // A model path alone satisfies the model configuration.
        let model_only = ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build();
        model_only.unwrap();

        // A prompt alone satisfies a single-image configuration.
        let single = ConfigBuilder::default().prompt(CAT_PROMPT).build();
        single.unwrap();

        // A batch into the default (file) output path is rejected.
        let batch_into_file = ConfigBuilder::default()
            .prompt(CAT_PROMPT)
            .batch_count(10)
            .build();
        assert!(matches!(
            batch_into_file,
            Err(ConfigBuilderError::ValidationError(_))
        ));

        let single_again = ConfigBuilder::default().prompt(CAT_PROMPT).build();
        single_again.unwrap();

        // A batch into an existing directory is accepted.
        let batch_into_dir = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build();
        batch_into_dir.unwrap();
    }

    /// End-to-end generation with an upscaler; downloads real models, so it is ignored by default.
    #[ignore]
    #[test]
    fn test_txt2img() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();

        let mut first = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        txt2img(&mut first, &mut model_config).unwrap();

        // Reuse the first configuration, changing only prompt and destination.
        let mut second = ConfigBuilder::from(first)
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        txt2img(&mut second, &mut model_config).unwrap();
    }
}