1use std::cmp::max;
2use std::collections::HashMap;
3use std::ffi::CString;
4use std::ffi::c_char;
5use std::ffi::c_void;
6use std::fmt::Display;
7use std::path::Path;
8use std::path::PathBuf;
9use std::ptr::null;
10use std::ptr::null_mut;
11use std::slice;
12
13use derive_builder::Builder;
14use diffusion_rs_sys::free_upscaler_ctx;
15use diffusion_rs_sys::new_upscaler_ctx;
16use diffusion_rs_sys::sd_cache_mode_t;
17use diffusion_rs_sys::sd_cache_params_t;
18use diffusion_rs_sys::sd_ctx_params_t;
19use diffusion_rs_sys::sd_embedding_t;
20use diffusion_rs_sys::sd_get_default_sample_method;
21use diffusion_rs_sys::sd_get_default_scheduler;
22use diffusion_rs_sys::sd_guidance_params_t;
23use diffusion_rs_sys::sd_image_t;
24use diffusion_rs_sys::sd_img_gen_params_t;
25use diffusion_rs_sys::sd_lora_t;
26use diffusion_rs_sys::sd_pm_params_t;
27use diffusion_rs_sys::sd_sample_params_t;
28use diffusion_rs_sys::sd_set_preview_callback;
29use diffusion_rs_sys::sd_slg_params_t;
30use diffusion_rs_sys::sd_tiling_params_t;
31use diffusion_rs_sys::upscaler_ctx_t;
32use image::ImageBuffer;
33use image::ImageError;
34use image::RgbImage;
35use libc::free;
36use thiserror::Error;
37use walkdir::DirEntry;
38use walkdir::WalkDir;
39
40use diffusion_rs_sys::free_sd_ctx;
41use diffusion_rs_sys::new_sd_ctx;
42use diffusion_rs_sys::sd_ctx_t;
43
44pub use diffusion_rs_sys::rng_type_t as RngFunction;
46
47pub use diffusion_rs_sys::sample_method_t as SampleMethod;
49
50pub use diffusion_rs_sys::scheduler_t as Scheduler;
52
53pub use diffusion_rs_sys::prediction_t as Prediction;
55
56pub use diffusion_rs_sys::sd_type_t as WeightType;
58
59pub use diffusion_rs_sys::preview_t as PreviewType;
61
62pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
64
/// File extensions (without the leading dot) accepted when scanning the embedding and LoRA
/// directories.
static VALID_EXT: [&str; 3] = ["pt", "safetensors", "gguf"];
66
67#[non_exhaustive]
68#[derive(Error, Debug)]
69pub enum DiffusionError {
71 #[error("The underling stablediffusion.cpp function returned NULL")]
72 Forward,
73 #[error(transparent)]
74 StoreImages(#[from] ImageError),
75 #[error("The underling upscaler model returned a NULL image")]
76 Upscaler,
77}
78
/// CLIP layer skipping selector.
///
/// The discriminant is what `gen_img` passes to the native library (`clip_skip as i32`).
#[repr(i32)]
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
pub enum ClipSkip {
    /// Sent as `0`; presumably lets the native library pick its own value (confirm upstream).
    #[default]
    Unspecified = 0,
    /// Sent as `1`.
    None = 1,
    /// Sent as `2`.
    OneLayer = 2,
}
90
/// Embedding state kept by `ModelConfig`: the scanned directory, the owned `(name, path)` C
/// strings, and the FFI table whose entries hold raw pointers into those strings.
type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
92
/// LoRA state kept by `ModelConfig`.
///
/// `data` and `loras_t` are parallel vectors: entry `i` of `loras_t` is the FFI view of entry `i`
/// of `data`, and its `path` is a raw pointer into a `CLibPath` C string.
// NOTE(review): `Clone` copies the raw pointers in `loras_t` verbatim, so a clone's pointers
// still refer to the strings of the value it was cloned from.
#[derive(Default, Debug, Clone)]
struct LoraStorage {
    /// Directory that was scanned for LoRA files.
    lora_model_dir: CLibPath,
    /// Owned `(path, file stem, multiplier)` per selected LoRA.
    data: Vec<(CLibPath, String, f32)>,
    /// FFI descriptors handed to the native library.
    loras_t: Vec<sd_lora_t>,
}
99
/// User-facing description of one LoRA to apply.
#[derive(Default, Debug, Clone)]
pub struct LoraSpec {
    /// File name without extension; matched against the file stems found in the LoRA directory.
    pub file_name: String,
    /// Copied into the `is_high_noise` flag of the native LoRA descriptor.
    pub is_high_noise: bool,
    /// Copied into the `multiplier` of the native LoRA descriptor.
    pub multiplier: f32,
}
107
/// Parameters consumed by `ConfigBuilder::ucache_caching`.
#[derive(Builder, Debug, Clone)]
pub struct UCacheParams {
    /// Written to `reuse_threshold`. Defaults to `1.0`.
    #[builder(default = "1.0")]
    threshold: f32,

    /// Written to `start_percent`. Defaults to `0.15`.
    #[builder(default = "0.15")]
    start: f32,

    /// Written to `end_percent`. Defaults to `0.95`.
    #[builder(default = "0.95")]
    end: f32,

    /// Written to `error_decay_rate`. Defaults to `1.0`.
    #[builder(default = "1.0")]
    decay: f32,

    /// Written to `use_relative_threshold`. Defaults to `true`.
    #[builder(default = "true")]
    relative: bool,

    /// Written to `reset_error_on_compute`. Defaults to `true`.
    #[builder(default = "true")]
    reset: bool,
}
137
/// Parameters consumed by `ConfigBuilder::easy_cache_caching`.
#[derive(Builder, Debug, Clone)]
pub struct EasyCacheParams {
    /// Written to `reuse_threshold`. Defaults to `0.2`.
    #[builder(default = "0.2")]
    threshold: f32,

    /// Written to `start_percent`. Defaults to `0.15`.
    #[builder(default = "0.15")]
    start: f32,

    /// Written to `end_percent`. Defaults to `0.95`.
    #[builder(default = "0.95")]
    end: f32,
}
153
/// Parameters consumed by `ConfigBuilder::db_cache_caching` and
/// `ConfigBuilder::cache_dit_caching`.
#[derive(Builder, Debug, Clone)]
pub struct DbCacheParams {
    /// Written to `Fn_compute_blocks`. Defaults to `8`.
    #[builder(default = "8")]
    fn_blocks: i32,

    /// Written to `Bn_compute_blocks`. Defaults to `0`.
    #[builder(default = "0")]
    bn_blocks: i32,

    /// Written to `residual_diff_threshold`. Defaults to `0.08`.
    #[builder(default = "0.08")]
    threshold: f32,

    /// Written to `max_warmup_steps`. Defaults to `8`.
    #[builder(default = "8")]
    warmup: i32,

    /// Preset (or custom string) rendered into the `scm_mask` C string.
    /// No builder default is declared, so the builder requires this field to be set.
    scm_mask: ScmPreset,

    /// Selects the value of the `scm_policy_dynamic` flag. Defaults to `ScmPolicy::default()`.
    #[builder(default = "ScmPolicy::default()")]
    scm_policy_dynamic: ScmPolicy,
}
180
/// Value of the native `scm_policy_dynamic` flag, expressed as an enum.
#[derive(Debug, Default, Clone)]
pub enum ScmPolicy {
    /// Maps to `scm_policy_dynamic = false`.
    Static,
    /// Maps to `scm_policy_dynamic = true` (the default).
    #[default]
    Dynamic,
}
190
/// Source of the SCM mask string: one of the built-in compute/cache bin tables, or a
/// caller-supplied string.
#[derive(Debug, Default, Clone)]
pub enum ScmPreset {
    /// Built-in table with the largest share of compute steps.
    Slow,
    /// Built-in table used by default.
    #[default]
    Medium,
    /// Built-in table with fewer compute steps than `Medium`.
    Fast,
    /// Built-in table with the smallest share of compute steps.
    Ultra,
    /// Mask string passed through verbatim (no scaling, no validation).
    Custom(String),
}
203
204impl ScmPreset {
205 fn to_vec_string(&self, steps: i32) -> String {
206 match self {
207 ScmPreset::Slow => ScmPresetBins {
208 compute_bins: vec![8, 3, 3, 2, 1, 1],
209 cache_bins: vec![1, 2, 2, 2, 3],
210 steps,
211 }
212 .to_string(),
213 ScmPreset::Medium => ScmPresetBins {
214 compute_bins: vec![6, 2, 2, 2, 2, 1],
215 cache_bins: vec![1, 3, 3, 3, 3],
216 steps,
217 }
218 .to_string(),
219 ScmPreset::Fast => ScmPresetBins {
220 compute_bins: vec![6, 1, 1, 1, 1, 1],
221 cache_bins: vec![1, 3, 4, 5, 4],
222 steps,
223 }
224 .to_string(),
225 ScmPreset::Ultra => ScmPresetBins {
226 compute_bins: vec![4, 1, 1, 1, 1],
227 cache_bins: vec![2, 5, 6, 7],
228 steps,
229 }
230 .to_string(),
231 ScmPreset::Custom(s) => s.clone(),
232 }
233 }
234}
235
/// Alternating run lengths of an SCM preset plus the step count the mask is rendered for.
///
/// A mask is built by interleaving runs: `compute_bins[0]` ones, `cache_bins[0]` zeros,
/// `compute_bins[1]` ones, and so on. The built-in tables in this file add up to 28 steps, which
/// is why 28 is the reference step count used for scaling.
#[derive(Debug, Clone)]
struct ScmPresetBins {
    /// Lengths of the runs emitted as `1`.
    compute_bins: Vec<i32>,
    /// Lengths of the runs emitted as `0`.
    cache_bins: Vec<i32>,
    /// Number of sampling steps the mask should cover.
    steps: i32,
}

impl ScmPresetBins {
    /// Returns the bins adapted to `steps`; a copy is returned untouched when `steps` is the
    /// 28-step reference or is not positive.
    fn maybe_scale(&self) -> ScmPresetBins {
        if self.steps == 28 || self.steps <= 0 {
            return self.clone();
        }
        self.scale()
    }

    /// Rescales every bin by `steps / 28`, rounding to the nearest integer and never going
    /// below 1.
    ///
    /// The previous implementation multiplied by `0.5` instead of rounding, which halved every
    /// run and produced masks covering only about half of the requested steps.
    fn scale(&self) -> ScmPresetBins {
        let factor = self.steps as f32 / 28.0;
        let rescale = |bins: &[i32]| -> Vec<i32> {
            bins.iter()
                .map(|&bin| max(1, (bin as f32 * factor).round() as i32))
                .collect()
        };
        ScmPresetBins {
            compute_bins: rescale(&self.compute_bins),
            cache_bins: rescale(&self.cache_bins),
            steps: self.steps,
        }
    }

    /// Appends up to `count` copies of `value`, never growing `mask` beyond `limit` entries.
    /// Non-positive counts append nothing.
    fn push_run(mask: &mut Vec<i32>, value: i32, count: i32, limit: usize) {
        let room = limit.saturating_sub(mask.len());
        let run = (count.max(0) as usize).min(room);
        mask.extend(std::iter::repeat(value).take(run));
    }

    /// Builds the 0/1 mask by interleaving compute runs (`1`) and cache runs (`0`).
    ///
    /// Generation stops once `steps` entries exist or both bin lists are exhausted, so the mask
    /// can be shorter than `steps`. The last entry is always forced to `1`.
    fn generate_vec_mask(&self) -> Vec<i32> {
        // Same cast as before: a negative `steps` wraps to a huge limit, so only bin exhaustion
        // ends the loop in that case.
        let limit = self.steps as usize;
        let mut mask = Vec::new();
        let mut compute_idx = 0;
        let mut cache_idx = 0;

        while mask.len() < limit {
            if let Some(&count) = self.compute_bins.get(compute_idx) {
                Self::push_run(&mut mask, 1, count, limit);
                compute_idx += 1;
            }
            // Each list is walked with its own cursor; reading `cache_bins` with the compute
            // cursor picked the wrong run and indexed past the end of the shorter list.
            if let Some(&count) = self.cache_bins.get(cache_idx) {
                Self::push_run(&mut mask, 0, count, limit);
                cache_idx += 1;
            }
            if compute_idx >= self.compute_bins.len() && cache_idx >= self.cache_bins.len() {
                break;
            }
        }
        if let Some(last) = mask.last_mut() {
            *last = 1;
        }
        mask
    }
}

impl Display for ScmPresetBins {
    /// Writes the (scaled) mask as comma separated `0`/`1` values, e.g. `1,1,0,1`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mask: String = self
            .maybe_scale()
            .generate_vec_mask()
            .iter()
            .map(|flag| flag.to_string())
            .collect::<Vec<_>>()
            .join(",");
        write!(f, "{mask}")
    }
}
317
/// Model-level configuration plus the lazily created native contexts.
///
/// Most fields are forwarded verbatim into `sd_ctx_params_t` the first time a diffusion context
/// is requested; a few are read by `gen_img` on every generation.
// NOTE(review): the type derives `Clone` while owning raw context pointers that `Drop` frees;
// cloning a value whose contexts are already created looks like it would free them twice.
#[derive(Builder, Debug, Clone)]
#[builder(
    setter(into, strip_option),
    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
)]
pub struct ModelConfig {
    /// Thread count passed to the native contexts. Defaults to the number of physical cores;
    /// the custom setter also falls back to that value for non-positive input.
    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
    n_threads: i32,

    /// Forwarded as `offload_params_to_cpu` to both the diffusion and the upscaler context.
    #[builder(default = "false")]
    offload_params_to_cpu: bool,

    /// Path of the upscaler model; no upscaler context is created when unset.
    #[builder(default = "Default::default()")]
    upscale_model: Option<CLibPath>,

    /// Number of upscale passes applied to each generated image; `0` disables upscaling.
    #[builder(default = "1")]
    upscale_repeats: i32,

    /// Tile size passed to `new_upscaler_ctx`.
    #[builder(default = "128")]
    upscale_tile_size: i32,

    /// Forwarded as `model_path`. Either this or `diffusion_model` must be set.
    #[builder(default = "Default::default()")]
    model: CLibPath,

    /// Forwarded as `diffusion_model_path`. Either this or `model` must be set.
    #[builder(default = "Default::default()")]
    diffusion_model: CLibPath,

    /// Forwarded as `llm_path`.
    #[builder(default = "Default::default()")]
    llm: CLibPath,

    /// Forwarded as `llm_vision_path`.
    #[builder(default = "Default::default()")]
    llm_vision: CLibPath,

    /// Forwarded as `clip_l_path`.
    #[builder(default = "Default::default()")]
    clip_l: CLibPath,

    /// Forwarded as `clip_g_path`.
    #[builder(default = "Default::default()")]
    clip_g: CLibPath,

    /// Forwarded as `clip_vision_path`.
    #[builder(default = "Default::default()")]
    clip_vision: CLibPath,

    /// Forwarded as `t5xxl_path`.
    #[builder(default = "Default::default()")]
    t5xxl: CLibPath,

    /// Forwarded as `vae_path`.
    #[builder(default = "Default::default()")]
    vae: CLibPath,

    /// Forwarded as `taesd_path`.
    #[builder(default = "Default::default()")]
    taesd: CLibPath,

    /// Forwarded as `control_net_path`.
    #[builder(default = "Default::default()")]
    control_net: CLibPath,

    /// Embeddings discovered by the custom `embeddings` setter (directory, owned strings, FFI
    /// table).
    #[builder(default = "Default::default()", setter(custom))]
    embeddings: EmbeddingsStorage,

    /// Forwarded as `photo_maker_path`.
    #[builder(default = "Default::default()")]
    photo_maker: CLibPath,

    /// Forwarded by `gen_img` as the PhotoMaker `id_embed_path`.
    #[builder(default = "Default::default()")]
    pm_id_embed_path: CLibPath,

    /// Forwarded as `wtype`. Defaults to `SD_TYPE_COUNT`.
    #[builder(default = "WeightType::SD_TYPE_COUNT")]
    weight_type: WeightType,

    /// LoRAs selected through the custom `lora_models` setter.
    #[builder(default = "Default::default()", setter(custom))]
    lora_models: LoraStorage,

    /// Forwarded as `high_noise_diffusion_model_path`.
    #[builder(default = "Default::default()")]
    high_noise_diffusion_model: CLibPath,

    /// Forwarded by `gen_img` as the VAE tiling `enabled` flag.
    #[builder(default = "false")]
    vae_tiling: bool,

    /// VAE tile size as `(x, y)`.
    #[builder(default = "(32,32)")]
    vae_tile_size: (i32, i32),

    /// Relative VAE tile size as `(x, y)`.
    #[builder(default = "(0.,0.)")]
    vae_relative_tile_size: (f32, f32),

    /// Forwarded as the VAE tiling `target_overlap`.
    #[builder(default = "0.5")]
    vae_tile_overlap: f32,

    /// Forwarded as `rng_type`. Defaults to `CUDA_RNG`.
    #[builder(default = "RngFunction::CUDA_RNG")]
    rng: RngFunction,

    /// Forwarded as `sampler_rng_type`. Defaults to `RNG_TYPE_COUNT`.
    #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
    sampler_rng_type: RngFunction,

    /// Scheduler used for sampling. With the default `SCHEDULER_COUNT`, `gen_img` asks the
    /// native library for its default scheduler instead.
    #[builder(default = "Scheduler::SCHEDULER_COUNT")]
    scheduler: Scheduler,

    /// Forwarded by `gen_img` as `custom_sigmas`; empty by default.
    #[builder(default = "Default::default()")]
    sigmas: Vec<f32>,

    /// Forwarded as `prediction`. Defaults to `PREDICTION_COUNT`.
    #[builder(default = "Prediction::PREDICTION_COUNT")]
    prediction: Prediction,

    /// Forwarded as `keep_vae_on_cpu`.
    #[builder(default = "false")]
    vae_on_cpu: bool,

    /// Forwarded as `keep_clip_on_cpu`.
    #[builder(default = "false")]
    clip_on_cpu: bool,

    /// Forwarded as `keep_control_net_on_cpu`.
    #[builder(default = "false")]
    control_net_cpu: bool,

    /// Forwarded as `diffusion_flash_attn`.
    #[builder(default = "false")]
    flash_attention: bool,

    /// Negated and forwarded as `chroma_use_dit_mask`.
    #[builder(default = "false")]
    chroma_disable_dit_mask: bool,

    /// Forwarded as `chroma_use_t5_mask`.
    #[builder(default = "false")]
    chroma_enable_t5_mask: bool,

    /// Forwarded as `chroma_t5_mask_pad`.
    #[builder(default = "1")]
    chroma_t5_mask_pad: i32,

    /// Forwarded as `qwen_image_zero_cond_t`.
    #[builder(default = "false")]
    use_qwen_image_zero_cond_true: bool,

    /// Forwarded as `diffusion_conv_direct`; also passed as the third argument of
    /// `new_upscaler_ctx`.
    #[builder(default = "false")]
    diffusion_conv_direct: bool,

    /// Forwarded as `vae_conv_direct`.
    #[builder(default = "false")]
    vae_conv_direct: bool,

    /// Forwarded as `force_sdxl_vae_conv_scale`.
    #[builder(default = "false")]
    force_sdxl_vae_conv_scale: bool,

    /// Forwarded as `flow_shift`. Defaults to `f32::INFINITY`.
    #[builder(default = "f32::INFINITY")]
    flow_shift: f32,

    /// Forwarded by `gen_img` as `shifted_timestep`.
    #[builder(default = "0")]
    timestep_shift: i32,

    /// Forwarded as `tae_preview_only`.
    #[builder(default = "false")]
    taesd_preview_only: bool,

    /// Forwarded as `lora_apply_mode`. Defaults to `LORA_APPLY_AUTO`.
    #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
    lora_apply_mode: LoraModeType,

    /// Enables both `circular_x` and `circular_y` on the native side.
    #[builder(default = "false")]
    circular: bool,

    /// Forwarded as `circular_x` (also implied by `circular`).
    #[builder(default = "false")]
    circular_x: bool,

    /// Forwarded as `circular_y` (also implied by `circular`).
    #[builder(default = "false")]
    circular_y: bool,

    /// Lazily created upscaler context; freed in `Drop`.
    #[builder(default = "None", private)]
    upscaler_ctx: Option<*mut upscaler_ctx_t>,

    /// Lazily created diffusion context together with the parameters it was created from;
    /// freed in `Drop`.
    #[builder(default = "None", private)]
    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
}
533
534impl ModelConfigBuilder {
535 fn validate(&self) -> Result<(), ConfigBuilderError> {
536 self.validate_model()
537 }
538
539 fn validate_model(&self) -> Result<(), ConfigBuilderError> {
540 self.model
541 .as_ref()
542 .or(self.diffusion_model.as_ref())
543 .map(|_| ())
544 .ok_or(ConfigBuilderError::UninitializedField(
545 "Model OR DiffusionModel must be valorized",
546 ))
547 }
548
549 fn filter_valid_extensions(&self, path: &Path) -> impl Iterator<Item = DirEntry> {
550 WalkDir::new(path)
551 .into_iter()
552 .filter_map(|entry| entry.ok())
553 .filter(|entry| {
554 entry
555 .path()
556 .extension()
557 .and_then(|ext| ext.to_str())
558 .map(|ext_str| VALID_EXT.contains(&ext_str))
559 .unwrap_or(false)
560 })
561 }
562
563 fn build_single_lora_storage(
564 spec: &LoraSpec,
565 is_high_noise: bool,
566 valid_loras: &HashMap<String, PathBuf>,
567 ) -> ((CLibPath, String, f32), sd_lora_t) {
568 let path = valid_loras.get(&spec.file_name).unwrap().as_path();
569 let c_path = CLibPath::from(path);
570 let lora = sd_lora_t {
571 is_high_noise,
572 multiplier: spec.multiplier,
573 path: c_path.as_ptr(),
574 };
575 let data = (c_path, spec.file_name.clone(), spec.multiplier);
576 (data, lora)
577 }
578
579 pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
580 let data: Vec<(CLibString, CLibPath)> = self
581 .filter_valid_extensions(embeddings_dir)
582 .map(|entry| {
583 let file_stem = entry
584 .path()
585 .file_stem()
586 .and_then(|stem| stem.to_str())
587 .unwrap_or_default()
588 .to_owned();
589 (CLibString::from(file_stem), CLibPath::from(entry.path()))
590 })
591 .collect();
592 let data_pointer = data
593 .iter()
594 .map(|(name, path)| sd_embedding_t {
595 name: name.as_ptr(),
596 path: path.as_ptr(),
597 })
598 .collect();
599 self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
600 self
601 }
602
603 pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
604 let valid_loras: HashMap<String, PathBuf> = self
605 .filter_valid_extensions(lora_model_dir)
606 .map(|entry| {
607 let path = entry.path();
608 (
609 path.file_stem()
610 .and_then(|stem| stem.to_str())
611 .unwrap_or_default()
612 .to_owned(),
613 path.to_path_buf(),
614 )
615 })
616 .collect();
617 let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
618 let standard = specs
619 .iter()
620 .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
621 .map(|s| Self::build_single_lora_storage(s, false, &valid_loras));
622 let high_noise = specs
623 .iter()
624 .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
625 .map(|s| Self::build_single_lora_storage(s, true, &valid_loras));
626
627 let mut data = Vec::new();
628 let mut loras_t = Vec::new();
629 for lora in standard.chain(high_noise) {
630 data.push(lora.0);
631 loras_t.push(lora.1);
632 }
633
634 self.lora_models = Some(LoraStorage {
635 lora_model_dir: lora_model_dir.into(),
636 data,
637 loras_t,
638 });
639 self
640 }
641
642 pub fn n_threads(&mut self, value: i32) -> &mut Self {
643 self.n_threads = if value > 0 {
644 Some(value)
645 } else {
646 Some(num_cpus::get_physical() as i32)
647 };
648 self
649 }
650}
651
652impl ModelConfig {
653 unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
654 unsafe {
655 if self.upscale_model.is_none() || self.upscale_repeats == 0 {
656 None
657 } else {
658 if self.upscaler_ctx.is_none() {
659 let upscaler = new_upscaler_ctx(
660 self.upscale_model.as_ref().unwrap().as_ptr(),
661 self.offload_params_to_cpu,
662 self.diffusion_conv_direct,
663 self.n_threads,
664 self.upscale_tile_size,
665 );
666 self.upscaler_ctx = Some(upscaler);
667 }
668 self.upscaler_ctx
669 }
670 }
671 }
672
673 unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
674 unsafe {
675 if self.diffusion_ctx.is_none() {
676 let sd_ctx_params = sd_ctx_params_t {
677 model_path: self.model.as_ptr(),
678 llm_path: self.llm.as_ptr(),
679 llm_vision_path: self.llm_vision.as_ptr(),
680 clip_l_path: self.clip_l.as_ptr(),
681 clip_g_path: self.clip_g.as_ptr(),
682 clip_vision_path: self.clip_vision.as_ptr(),
683 high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
684 t5xxl_path: self.t5xxl.as_ptr(),
685 diffusion_model_path: self.diffusion_model.as_ptr(),
686 vae_path: self.vae.as_ptr(),
687 taesd_path: self.taesd.as_ptr(),
688 control_net_path: self.control_net.as_ptr(),
689 embeddings: self.embeddings.2.as_ptr(),
690 embedding_count: self.embeddings.1.len() as u32,
691 photo_maker_path: self.photo_maker.as_ptr(),
692 vae_decode_only,
693 free_params_immediately: false,
694 n_threads: self.n_threads,
695 wtype: self.weight_type,
696 rng_type: self.rng,
697 keep_clip_on_cpu: self.clip_on_cpu,
698 keep_control_net_on_cpu: self.control_net_cpu,
699 keep_vae_on_cpu: self.vae_on_cpu,
700 diffusion_flash_attn: self.flash_attention,
701 diffusion_conv_direct: self.diffusion_conv_direct,
702 chroma_use_dit_mask: !self.chroma_disable_dit_mask,
703 chroma_use_t5_mask: self.chroma_enable_t5_mask,
704 chroma_t5_mask_pad: self.chroma_t5_mask_pad,
705 vae_conv_direct: self.vae_conv_direct,
706 offload_params_to_cpu: self.offload_params_to_cpu,
707 flow_shift: self.flow_shift,
708 prediction: self.prediction,
709 force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
710 tae_preview_only: self.taesd_preview_only,
711 lora_apply_mode: self.lora_apply_mode,
712 tensor_type_rules: null_mut(),
713 sampler_rng_type: self.sampler_rng_type,
714 circular_x: self.circular || self.circular_x,
715 circular_y: self.circular || self.circular_y,
716 qwen_image_zero_cond_t: self.use_qwen_image_zero_cond_true,
717 };
718 let ctx = new_sd_ctx(&sd_ctx_params);
719 self.diffusion_ctx = Some((ctx, sd_ctx_params))
720 }
721 self.diffusion_ctx.unwrap().0
722 }
723 }
724}
725
726impl Drop for ModelConfig {
727 fn drop(&mut self) {
728 unsafe {
730 if let Some((sd_ctx, _)) = self.diffusion_ctx {
731 free_sd_ctx(sd_ctx);
732 }
733
734 if let Some(upscaler_ctx) = self.upscaler_ctx {
735 free_upscaler_ctx(upscaler_ctx);
736 }
737 }
738 }
739}
740
741impl From<ModelConfig> for ModelConfigBuilder {
742 fn from(value: ModelConfig) -> Self {
743 let mut builder = ModelConfigBuilder::default();
744 builder
745 .n_threads(value.n_threads)
746 .offload_params_to_cpu(value.offload_params_to_cpu)
747 .upscale_repeats(value.upscale_repeats)
748 .model(value.model.clone())
749 .diffusion_model(value.diffusion_model.clone())
750 .llm(value.llm.clone())
751 .llm_vision(value.llm_vision.clone())
752 .clip_l(value.clip_l.clone())
753 .clip_g(value.clip_g.clone())
754 .clip_vision(value.clip_vision.clone())
755 .t5xxl(value.t5xxl.clone())
756 .vae(value.vae.clone())
757 .taesd(value.taesd.clone())
758 .control_net(value.control_net.clone())
759 .embeddings(&value.embeddings.0)
760 .photo_maker(value.photo_maker.clone())
761 .pm_id_embed_path(value.pm_id_embed_path.clone())
762 .weight_type(value.weight_type)
763 .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
764 .vae_tiling(value.vae_tiling)
765 .vae_tile_size(value.vae_tile_size)
766 .vae_relative_tile_size(value.vae_relative_tile_size)
767 .vae_tile_overlap(value.vae_tile_overlap)
768 .rng(value.rng)
769 .sampler_rng_type(value.rng)
770 .scheduler(value.scheduler)
771 .sigmas(value.sigmas.clone())
772 .prediction(value.prediction)
773 .vae_on_cpu(value.vae_on_cpu)
774 .clip_on_cpu(value.clip_on_cpu)
775 .control_net(value.control_net.clone())
776 .control_net_cpu(value.control_net_cpu)
777 .flash_attention(value.flash_attention)
778 .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
779 .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
780 .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
781 .diffusion_conv_direct(value.diffusion_conv_direct)
782 .vae_conv_direct(value.vae_conv_direct)
783 .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
784 .flow_shift(value.flow_shift)
785 .timestep_shift(value.timestep_shift)
786 .taesd_preview_only(value.taesd_preview_only)
787 .lora_apply_mode(value.lora_apply_mode)
788 .circular(value.circular)
789 .circular_x(value.circular_x)
790 .circular_y(value.circular_y)
791 .use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true);
792
793 let lora_model_dir = Into::<PathBuf>::into(&value.lora_models.lora_model_dir);
794 let lora_specs = value
795 .lora_models
796 .data
797 .iter()
798 .map(|(_, name, multiplier)| LoraSpec {
799 file_name: name.clone(),
800 is_high_noise: false,
801 multiplier: *multiplier,
802 })
803 .collect();
804
805 builder.lora_models(&lora_model_dir, lora_specs);
806
807 if let Some(model) = &value.upscale_model {
808 builder.upscale_model(model.clone());
809 }
810 builder
811 }
812}
813
/// Per-generation configuration consumed by `gen_img`.
#[derive(Builder, Debug, Clone)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
pub struct Config {
    /// Directory of PhotoMaker id images. Not read by `gen_img` in this file.
    #[builder(default = "Default::default()")]
    pm_id_images_dir: CLibPath,

    /// Path of an init image. Not read by `gen_img` in this file (an empty init image is sent).
    #[builder(default = "Default::default()")]
    init_img: CLibPath,

    /// Path of a control image. Not read by `gen_img` in this file (an empty control image is
    /// sent).
    #[builder(default = "Default::default()")]
    control_image: CLibPath,

    /// Output file when `batch_count` is 1, output directory otherwise.
    /// Defaults to `./output.png`.
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,

    /// File the preview callback writes to. Defaults to `./preview_output.png`.
    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
    preview_output: PathBuf,

    /// Preview mode; the preview callback is only registered when this differs from
    /// `PREVIEW_NONE`.
    #[builder(default = "PreviewType::PREVIEW_NONE")]
    preview_mode: PreviewType,

    /// Passed to `sd_set_preview_callback` as the fifth argument and, negated, as the fourth.
    #[builder(default = "false")]
    preview_noisy: bool,

    /// Preview interval passed to `sd_set_preview_callback`.
    #[builder(default = "1")]
    preview_interval: i32,

    /// Text prompt. Required; also used to name the files of a batch.
    prompt: String,

    /// Negative prompt. Empty by default.
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,

    /// Forwarded as both `txt_cfg` and `img_cfg`.
    #[builder(default = "7.0")]
    cfg_scale: f32,

    /// Forwarded as `distilled_guidance`.
    #[builder(default = "3.5")]
    guidance: f32,

    /// Forwarded as `strength`.
    #[builder(default = "0.75")]
    strength: f32,

    /// Forwarded as the PhotoMaker `style_strength`.
    #[builder(default = "20.0")]
    pm_style_strength: f32,

    /// Forwarded as `control_strength`.
    #[builder(default = "0.9")]
    control_strength: f32,

    /// Image height in pixels.
    #[builder(default = "512")]
    height: i32,

    /// Image width in pixels.
    #[builder(default = "512")]
    width: i32,

    /// Sampling method. With the default `SAMPLE_METHOD_COUNT`, `gen_img` asks the native
    /// library for its default method instead.
    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
    sampling_method: SampleMethod,

    /// Forwarded as `eta`.
    #[builder(default = "0.")]
    eta: f32,

    /// Number of sampling steps.
    #[builder(default = "20")]
    steps: i32,

    /// RNG seed.
    #[builder(default = "42")]
    seed: i64,

    /// Number of images to generate. Validation requires `output` to be a directory exactly
    /// when this is greater than 1.
    #[builder(default = "1")]
    batch_count: i32,

    /// CLIP skip selector, sent as its integer discriminant.
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,

    /// Not read by `gen_img` in this file.
    #[builder(default = "false")]
    canny: bool,

    /// Forwarded as the SLG `scale`.
    #[builder(default = "0.")]
    slg_scale: f32,

    /// Forwarded as the SLG `layers`.
    #[builder(default = "vec![7, 8, 9]")]
    skip_layer: Vec<i32>,

    /// Forwarded as the SLG `layer_start`.
    #[builder(default = "0.01")]
    skip_layer_start: f32,

    /// Forwarded as the SLG `layer_end`.
    #[builder(default = "0.2")]
    skip_layer_end: f32,

    /// Controls the `auto_resize_ref_image` generation parameter set in `gen_img`.
    #[builder(default = "false")]
    disable_auto_resize_ref_image: bool,

    /// Native cache parameters; set through the `*_caching` builder methods.
    #[builder(default = "Self::cache_init()", private)]
    cache: sd_cache_params_t,

    /// Owned storage for the SCM mask string referenced by `cache.scm_mask`.
    #[builder(default = "CLibString::default()", private)]
    scm_mask: CLibString,
}
944
945impl ConfigBuilder {
946 fn validate(&self) -> Result<(), ConfigBuilderError> {
947 self.validate_output_dir()
948 }
949
950 fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
951 let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
952 let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
953 if is_dir == multiple_items {
954 Ok(())
955 } else {
956 Err(ConfigBuilderError::ValidationError(
957 "When batch_count > 1, output should point to folder and vice versa".to_owned(),
958 ))
959 }
960 }
961
962 fn cache_init() -> sd_cache_params_t {
963 sd_cache_params_t {
964 mode: sd_cache_mode_t::SD_CACHE_DISABLED,
965 reuse_threshold: 1.0,
966 start_percent: 0.15,
967 end_percent: 0.95,
968 error_decay_rate: 1.0,
969 use_relative_threshold: true,
970 reset_error_on_compute: true,
971 Fn_compute_blocks: 8,
972 Bn_compute_blocks: 0,
973 residual_diff_threshold: 0.08,
974 max_warmup_steps: 8,
975 max_cached_steps: -1,
976 max_continuous_cached_steps: -1,
977 taylorseer_n_derivatives: 1,
978 taylorseer_skip_interval: 1,
979 scm_mask: null(),
980 scm_policy_dynamic: true,
981 }
982 }
983
984 pub fn no_caching(&mut self) -> &mut Self {
985 let mut cache = Self::cache_init();
986 cache.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
987 self.cache = Some(cache);
988 self
989 }
990
991 pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
992 let mut cache = Self::cache_init();
993 cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
994 cache.reuse_threshold = params.threshold;
995 cache.start_percent = params.start;
996 cache.end_percent = params.end;
997 cache.error_decay_rate = params.decay;
998 cache.use_relative_threshold = params.relative;
999 cache.reset_error_on_compute = params.reset;
1000 self.cache = Some(cache);
1001 self
1002 }
1003
1004 pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
1005 let mut cache = Self::cache_init();
1006 cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
1007 cache.reuse_threshold = params.threshold;
1008 cache.start_percent = params.start;
1009 cache.end_percent = params.end;
1010 self.cache = Some(cache);
1011 self
1012 }
1013
1014 pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
1015 let mut cache = Self::cache_init();
1016 cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
1017 cache.Fn_compute_blocks = params.fn_blocks;
1018 cache.Bn_compute_blocks = params.bn_blocks;
1019 cache.residual_diff_threshold = params.threshold;
1020 cache.max_warmup_steps = params.warmup;
1021 cache.scm_policy_dynamic = match params.scm_policy_dynamic {
1022 ScmPolicy::Static => false,
1023 ScmPolicy::Dynamic => true,
1024 };
1025 self.scm_mask = Some(CLibString::from(
1026 params
1027 .scm_mask
1028 .to_vec_string(self.steps.unwrap_or_default()),
1029 ));
1030 cache.scm_mask = self.scm_mask.as_ref().unwrap().as_ptr();
1031
1032 self.cache = Some(cache);
1033 self
1034 }
1035
1036 pub fn taylor_seer_caching(&mut self) -> &mut Self {
1037 let mut cache = Self::cache_init();
1038 cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
1039 self.cache = Some(cache);
1040 self
1041 }
1042
1043 pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
1044 self.db_cache_caching(params).cache.unwrap().mode = sd_cache_mode_t::SD_CACHE_CACHE_DIT;
1045 self
1046 }
1047}
1048
1049impl From<Config> for ConfigBuilder {
1050 fn from(value: Config) -> Self {
1051 let mut builder = ConfigBuilder::default();
1052 let mut cache = value.cache;
1053 let scm_mask = value.scm_mask.clone();
1054 cache.scm_mask = scm_mask.as_ptr();
1055 builder
1056 .pm_id_images_dir(value.pm_id_images_dir)
1057 .init_img(value.init_img)
1058 .control_image(value.control_image)
1059 .output(value.output)
1060 .prompt(value.prompt)
1061 .negative_prompt(value.negative_prompt)
1062 .cfg_scale(value.cfg_scale)
1063 .strength(value.strength)
1064 .pm_style_strength(value.pm_style_strength)
1065 .control_strength(value.control_strength)
1066 .height(value.height)
1067 .width(value.width)
1068 .sampling_method(value.sampling_method)
1069 .steps(value.steps)
1070 .seed(value.seed)
1071 .batch_count(value.batch_count)
1072 .clip_skip(value.clip_skip)
1073 .slg_scale(value.slg_scale)
1074 .skip_layer(value.skip_layer)
1075 .skip_layer_start(value.skip_layer_start)
1076 .skip_layer_end(value.skip_layer_end)
1077 .canny(value.canny)
1078 .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
1079 .preview_output(value.preview_output)
1080 .preview_mode(value.preview_mode)
1081 .preview_noisy(value.preview_noisy)
1082 .preview_interval(value.preview_interval)
1083 .cache(cache)
1084 .scm_mask(scm_mask);
1085 builder
1086 }
1087}
1088
/// Owned, NUL-terminated string whose pointer can be handed to the native library.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Pointer to the NUL-terminated bytes; valid for as long as `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        let Self(inner) = self;
        inner.as_ptr()
    }
}

impl From<&str> for CLibString {
    /// Copies `value` into a C string. Panics if `value` contains an interior NUL byte.
    fn from(value: &str) -> Self {
        let owned = CString::new(value).unwrap();
        CLibString(owned)
    }
}

impl From<String> for CLibString {
    /// Converts `value` into a C string. Panics if `value` contains an interior NUL byte.
    fn from(value: String) -> Self {
        CString::new(value).map(CLibString).unwrap()
    }
}
1109
/// Owned, NUL-terminated file system path whose pointer can be handed to the native library.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Pointer to the NUL-terminated bytes; valid for as long as `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        let Self(inner) = self;
        inner.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    /// Same conversion as the `&Path` one, applied to an owned path.
    fn from(value: PathBuf) -> Self {
        CLibPath::from(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    /// Converts the path to a C string. Paths that are not valid UTF-8 become the empty
    /// string; an interior NUL byte panics.
    fn from(value: &Path) -> Self {
        let text = value.to_str().unwrap_or_default();
        CLibPath(CString::new(text).unwrap())
    }
}

impl From<&CLibPath> for PathBuf {
    /// Converts the stored C string back into a path.
    fn from(value: &CLibPath) -> Self {
        let text = value.0.to_str().unwrap();
        PathBuf::from(text)
    }
}
1136
/// Computes the output file of every image in a batch.
///
/// A batch of exactly one image is written to `path` itself; any other size treats `path` as a
/// directory and yields `<prompt>_<n>.png` inside it for `n` in `1..=batch_size` (nothing for
/// sizes below 1).
fn output_files(path: &Path, prompt: &str, batch_size: i32) -> Vec<PathBuf> {
    match batch_size {
        1 => vec![path.to_path_buf()],
        count => (1..=count)
            .map(|id| path.join(format!("{prompt}_{id}.png")))
            .collect(),
    }
}
1146
1147unsafe fn upscale(
1148 upscale_repeats: i32,
1149 upscaler_ctx: Option<*mut upscaler_ctx_t>,
1150 data: sd_image_t,
1151) -> Result<sd_image_t, DiffusionError> {
1152 unsafe {
1153 match upscaler_ctx {
1154 Some(upscaler_ctx) => {
1155 let upscale_factor = 4; let mut current_image = data;
1157 for _ in 0..upscale_repeats {
1158 let upscaled_image =
1159 diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
1160
1161 if upscaled_image.data.is_null() {
1162 return Err(DiffusionError::Upscaler);
1163 }
1164
1165 free(current_image.data as *mut c_void);
1166 current_image = upscaled_image;
1167 }
1168 Ok(current_image)
1169 }
1170 None => Ok(data),
1171 }
1172 }
1173}
1174
1175pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
1177 let prompt: CLibString = CLibString::from(config.prompt.as_str());
1178 let files = output_files(&config.output, &config.prompt, config.batch_count);
1179 unsafe {
1180 let sd_ctx = model_config.diffusion_ctx(true);
1181 let upscaler_ctx = model_config.upscaler_ctx();
1182 let init_image = sd_image_t {
1183 width: 0,
1184 height: 0,
1185 channel: 3,
1186 data: null_mut(),
1187 };
1188 let mask_image = sd_image_t {
1189 width: config.width as u32,
1190 height: config.height as u32,
1191 channel: 1,
1192 data: null_mut(),
1193 };
1194 let mut layers = config.skip_layer.clone();
1195 let guidance = sd_guidance_params_t {
1196 txt_cfg: config.cfg_scale,
1197 img_cfg: config.cfg_scale,
1198 distilled_guidance: config.guidance,
1199 slg: sd_slg_params_t {
1200 layers: layers.as_mut_ptr(),
1201 layer_count: config.skip_layer.len(),
1202 layer_start: config.skip_layer_start,
1203 layer_end: config.skip_layer_end,
1204 scale: config.slg_scale,
1205 },
1206 };
1207 let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
1208 sd_get_default_scheduler(sd_ctx, config.sampling_method)
1209 } else {
1210 model_config.scheduler
1211 };
1212 let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
1213 sd_get_default_sample_method(sd_ctx)
1214 } else {
1215 config.sampling_method
1216 };
1217 let sample_params = sd_sample_params_t {
1218 guidance,
1219 sample_method,
1220 sample_steps: config.steps,
1221 eta: config.eta,
1222 scheduler,
1223 shifted_timestep: model_config.timestep_shift,
1224 custom_sigmas: model_config.sigmas.as_mut_ptr(),
1225 custom_sigmas_count: model_config.sigmas.len() as i32,
1226 };
1227 let control_image = sd_image_t {
1228 width: 0,
1229 height: 0,
1230 channel: 3,
1231 data: null_mut(),
1232 };
1233 let vae_tiling_params = sd_tiling_params_t {
1234 enabled: model_config.vae_tiling,
1235 tile_size_x: model_config.vae_tile_size.0,
1236 tile_size_y: model_config.vae_tile_size.1,
1237 target_overlap: model_config.vae_tile_overlap,
1238 rel_size_x: model_config.vae_relative_tile_size.0,
1239 rel_size_y: model_config.vae_relative_tile_size.1,
1240 };
1241 let pm_params = sd_pm_params_t {
1242 id_images: null_mut(),
1243 id_images_count: 0,
1244 id_embed_path: model_config.pm_id_embed_path.as_ptr(),
1245 style_strength: config.pm_style_strength,
1246 };
1247
1248 unsafe extern "C" fn save_preview_local(
1249 _step: ::std::os::raw::c_int,
1250 _frame_count: ::std::os::raw::c_int,
1251 frames: *mut sd_image_t,
1252 _is_noisy: bool,
1253 data: *mut ::std::os::raw::c_void,
1254 ) {
1255 unsafe {
1256 let path = &*data.cast::<PathBuf>();
1257 let _ = save_img(*frames, path);
1258 }
1259 }
1260
1261 if config.preview_mode != PreviewType::PREVIEW_NONE {
1262 let data = &config.preview_output as *const PathBuf;
1263
1264 sd_set_preview_callback(
1265 Some(save_preview_local),
1266 config.preview_mode,
1267 config.preview_interval,
1268 !config.preview_noisy,
1269 config.preview_noisy,
1270 data as *mut c_void,
1271 );
1272 }
1273
1274 let sd_img_gen_params = sd_img_gen_params_t {
1275 prompt: prompt.as_ptr(),
1276 negative_prompt: config.negative_prompt.as_ptr(),
1277 clip_skip: config.clip_skip as i32,
1278 init_image,
1279 ref_images: null_mut(),
1280 ref_images_count: 0,
1281 increase_ref_index: false,
1282 mask_image,
1283 width: config.width,
1284 height: config.height,
1285 sample_params,
1286 strength: config.strength,
1287 seed: config.seed,
1288 batch_count: config.batch_count,
1289 control_image,
1290 control_strength: config.control_strength,
1291 pm_params,
1292 vae_tiling_params,
1293 auto_resize_ref_image: config.disable_auto_resize_ref_image,
1294 cache: config.cache,
1295 loras: model_config.lora_models.loras_t.as_ptr(),
1296 lora_count: model_config.lora_models.loras_t.len() as u32,
1297 };
1298 let slice = diffusion_rs_sys::generate_image(sd_ctx, &sd_img_gen_params);
1299 let ret = {
1300 if slice.is_null() {
1301 return Err(DiffusionError::Forward);
1302 }
1303 for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
1304 .iter()
1305 .zip(files)
1306 {
1307 match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
1308 Ok(img) => save_img(img, &path)?,
1309 Err(err) => {
1310 return Err(err);
1311 }
1312 }
1313 }
1314 Ok(())
1315 };
1316 free(slice as *mut c_void);
1317 ret
1318 }
1319}
1320
1321fn save_img(img: sd_image_t, path: &Path) -> Result<(), DiffusionError> {
1322 let len = (img.width * img.height * img.channel) as usize;
1324 let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1325 let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
1326 .map(|img| RgbImage::from(img).save(path));
1327 if let Some(Err(err)) = save_state {
1328 return Err(DiffusionError::StoreImages(err));
1329 }
1330 Ok(())
1331}
1332
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    /// Checks the builders' required fields and the `output` / `batch_count` validation rule.
    #[test]
    fn test_required_args_txt2img() {
        // Neither builder can be built without its required fields.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());
        // A model path alone is enough for the model configuration.
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        // A prompt alone is enough for the generation configuration.
        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // More than one image with a file-like output must be rejected.
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // More than one image with a directory output is accepted.
        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end generation test; ignored by default because it downloads model files and
    /// runs the native library.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&config, &mut model_config).unwrap();
        // Reuse the same model configuration (and cached contexts) with a modified config.
        let config2 = ConfigBuilder::from(config.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&config2, &mut model_config).unwrap();

        // Batch generation into a directory.
        let config3 = ConfigBuilder::from(config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();

        gen_img(&config3, &mut model_config).unwrap();
    }
}