1use std::cmp::max;
2use std::collections::HashMap;
3use std::ffi::CString;
4use std::ffi::c_char;
5use std::ffi::c_void;
6use std::fmt::Display;
7use std::path::Path;
8use std::path::PathBuf;
9use std::ptr::null;
10use std::ptr::null_mut;
11use std::slice;
12
13use chrono::Local;
14use derive_builder::Builder;
15use diffusion_rs_sys::free_upscaler_ctx;
16use diffusion_rs_sys::new_upscaler_ctx;
17use diffusion_rs_sys::sd_cache_mode_t;
18use diffusion_rs_sys::sd_cache_params_t;
19use diffusion_rs_sys::sd_ctx_params_t;
20use diffusion_rs_sys::sd_embedding_t;
21use diffusion_rs_sys::sd_get_default_sample_method;
22use diffusion_rs_sys::sd_get_default_scheduler;
23use diffusion_rs_sys::sd_guidance_params_t;
24use diffusion_rs_sys::sd_image_t;
25use diffusion_rs_sys::sd_img_gen_params_t;
26use diffusion_rs_sys::sd_lora_t;
27use diffusion_rs_sys::sd_pm_params_t;
28use diffusion_rs_sys::sd_sample_params_t;
29use diffusion_rs_sys::sd_set_preview_callback;
30use diffusion_rs_sys::sd_slg_params_t;
31use diffusion_rs_sys::sd_tiling_params_t;
32use diffusion_rs_sys::upscaler_ctx_t;
33use image::ImageBuffer;
34use image::ImageError;
35use image::RgbImage;
36use libc::free;
37use thiserror::Error;
38use walkdir::DirEntry;
39use walkdir::WalkDir;
40
41use diffusion_rs_sys::free_sd_ctx;
42use diffusion_rs_sys::new_sd_ctx;
43use diffusion_rs_sys::sd_ctx_t;
44
45pub use diffusion_rs_sys::rng_type_t as RngFunction;
47
48pub use diffusion_rs_sys::sample_method_t as SampleMethod;
50
51pub use diffusion_rs_sys::scheduler_t as Scheduler;
53
54pub use diffusion_rs_sys::prediction_t as Prediction;
56
57pub use diffusion_rs_sys::sd_type_t as WeightType;
59
60pub use diffusion_rs_sys::preview_t as PreviewType;
62
63pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
65
/// File extensions (without the dot, compared case-sensitively) that are accepted
/// when the builder scans a directory for embedding or LoRA files.
static VALID_EXT: [&str; 3] = ["gguf", "safetensors", "pt"];
67
68#[non_exhaustive]
69#[derive(Error, Debug)]
70pub enum DiffusionError {
72 #[error("The underling stablediffusion.cpp function returned NULL")]
73 Forward,
74 #[error(transparent)]
75 StoreImages(#[from] ImageError),
76 #[error("The underling upscaler model returned a NULL image")]
77 Upscaler,
78}
79
/// CLIP skip setting; the discriminant is cast to `i32` and forwarded as the
/// native `clip_skip` generation parameter.
#[repr(i32)]
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
pub enum ClipSkip {
    /// Forwarded as `0`; this is the default.
    #[default]
    Unspecified = 0,
    /// Forwarded as `1`.
    None = 1,
    /// Forwarded as `2`.
    OneLayer = 2,
}
91
/// Embedding bookkeeping kept by the model configuration:
/// 0. the directory that was scanned,
/// 1. the owned `(name, path)` C strings, one pair per embedding file found,
/// 2. the FFI structs whose pointers were taken from the strings in element 1.
type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
93
/// LoRA bookkeeping kept by the model configuration.
///
/// `data` and `loras_t` are filled in lockstep: entry `i` of `loras_t` was
/// created from entry `i` of `data`.
#[derive(Default, Debug, Clone)]
struct LoraStorage {
    /// Directory that was scanned for LoRA files.
    lora_model_dir: CLibPath,
    /// Owned `(file path, file stem, multiplier)` for every selected LoRA.
    data: Vec<(CLibPath, String, f32)>,
    /// FFI structs handed to the native library; `path` was taken from the
    /// matching `CLibPath` in `data`.
    // NOTE(review): the derived `Clone` copies these raw pointers verbatim, so a
    // clone's `loras_t` still refers to the strings of the value it was cloned from.
    loras_t: Vec<sd_lora_t>,
}
100
/// Describes one LoRA the caller wants applied.
#[derive(Default, Debug, Clone)]
pub struct LoraSpec {
    /// File stem (name without extension) looked up among the files found in the
    /// LoRA directory.
    pub file_name: String,
    /// Forwarded as the `is_high_noise` flag of the native LoRA descriptor.
    pub is_high_noise: bool,
    /// Forwarded as the `multiplier` of the native LoRA descriptor.
    pub multiplier: f32,
}
108
/// Parameters consumed by `ConfigBuilder::ucache_caching`.
#[derive(Builder, Debug, Clone)]
pub struct UCacheParams {
    /// Written to `reuse_threshold`; defaults to `1.0`.
    #[builder(default = "1.0")]
    threshold: f32,

    /// Written to `start_percent`; defaults to `0.15`.
    #[builder(default = "0.15")]
    start: f32,

    /// Written to `end_percent`; defaults to `0.95`.
    #[builder(default = "0.95")]
    end: f32,

    /// Written to `error_decay_rate`; defaults to `1.0`.
    #[builder(default = "1.0")]
    decay: f32,

    /// Written to `use_relative_threshold`; defaults to `true`.
    #[builder(default = "true")]
    relative: bool,

    /// Written to `reset_error_on_compute`; defaults to `true`.
    #[builder(default = "true")]
    reset: bool,
}
138
/// Parameters consumed by `ConfigBuilder::easy_cache_caching`.
#[derive(Builder, Debug, Clone)]
pub struct EasyCacheParams {
    /// Written to `reuse_threshold`; defaults to `0.2`.
    #[builder(default = "0.2")]
    threshold: f32,

    /// Written to `start_percent`; defaults to `0.15`.
    #[builder(default = "0.15")]
    start: f32,

    /// Written to `end_percent`; defaults to `0.95`.
    #[builder(default = "0.95")]
    end: f32,
}
154
/// Parameters consumed by `ConfigBuilder::db_cache_caching` and
/// `ConfigBuilder::cache_dit_caching`.
#[derive(Builder, Debug, Clone)]
pub struct DbCacheParams {
    /// Written to `Fn_compute_blocks`; defaults to `8`.
    #[builder(default = "8")]
    fn_blocks: i32,

    /// Written to `Bn_compute_blocks`; defaults to `0`.
    #[builder(default = "0")]
    bn_blocks: i32,

    /// Written to `residual_diff_threshold`; defaults to `0.08`.
    #[builder(default = "0.08")]
    threshold: f32,

    /// Written to `max_warmup_steps`; defaults to `8`.
    #[builder(default = "8")]
    warmup: i32,

    /// Preset rendered into the `scm_mask` string. No builder default is declared,
    /// so the generated builder requires this field to be set.
    scm_mask: ScmPreset,

    /// Mapped to the boolean `scm_policy_dynamic`; defaults to `ScmPolicy::default()`.
    #[builder(default = "ScmPolicy::default()")]
    scm_policy_dynamic: ScmPolicy,
}
181
/// Selects the value written to the native `scm_policy_dynamic` flag.
#[derive(Debug, Default, Clone)]
pub enum ScmPolicy {
    /// Mapped to `scm_policy_dynamic = false`.
    Static,
    /// Mapped to `scm_policy_dynamic = true`; this is the default.
    #[default]
    Dynamic,
}
191
/// Source of the mask string stored in the cache parameters' `scm_mask`.
///
/// The named presets are expanded from built-in bin tables; `Custom` carries a
/// ready-made mask string that is used verbatim.
#[derive(Debug, Default, Clone)]
pub enum ScmPreset {
    /// Built-in preset with the largest compute bins.
    Slow,
    /// Built-in preset; this is the default.
    #[default]
    Medium,
    /// Built-in preset with single-step compute bins after the first.
    Fast,
    /// Built-in preset with the largest cache bins.
    Ultra,
    /// Caller-provided mask string, passed through unchanged.
    Custom(String),
}
204
205impl ScmPreset {
206 fn to_vec_string(&self, steps: i32) -> String {
207 match self {
208 ScmPreset::Slow => ScmPresetBins {
209 compute_bins: vec![8, 3, 3, 2, 1, 1],
210 cache_bins: vec![1, 2, 2, 2, 3],
211 steps,
212 }
213 .to_string(),
214 ScmPreset::Medium => ScmPresetBins {
215 compute_bins: vec![6, 2, 2, 2, 2, 1],
216 cache_bins: vec![1, 3, 3, 3, 3],
217 steps,
218 }
219 .to_string(),
220 ScmPreset::Fast => ScmPresetBins {
221 compute_bins: vec![6, 1, 1, 1, 1, 1],
222 cache_bins: vec![1, 3, 4, 5, 4],
223 steps,
224 }
225 .to_string(),
226 ScmPreset::Ultra => ScmPresetBins {
227 compute_bins: vec![4, 1, 1, 1, 1],
228 cache_bins: vec![2, 5, 6, 7],
229 steps,
230 }
231 .to_string(),
232 ScmPreset::Custom(s) => s.clone(),
233 }
234 }
235}
236
/// Bin tables of a preset plus the step count the mask is generated for.
///
/// The mask alternates runs of `1` (lengths from `compute_bins`) and runs of `0`
/// (lengths from `cache_bins`). The tables are written for 28 steps and are
/// rescaled for any other positive step count.
#[derive(Debug, Clone)]
struct ScmPresetBins {
    /// Lengths of the consecutive runs of `1`.
    compute_bins: Vec<i32>,
    /// Lengths of the consecutive runs of `0`.
    cache_bins: Vec<i32>,
    /// Number of sampling steps the mask should cover.
    steps: i32,
}

impl ScmPresetBins {
    /// Returns the bins unchanged for the reference step count (28) or a
    /// non-positive step count, and a rescaled copy otherwise.
    fn maybe_scale(&self) -> ScmPresetBins {
        if self.steps == 28 || self.steps <= 0 {
            return self.clone();
        }
        self.scale()
    }

    /// Scales every bin by `steps / 28`, rounding to the nearest integer and
    /// never going below one step per bin.
    fn scale(&self) -> ScmPresetBins {
        let scale = self.steps as f32 / 28.0;
        // `+ 0.5` before truncation rounds to nearest; multiplying by 0.5 would
        // halve every bin and leave the mask far shorter than `steps`.
        let rescale = |bins: &[i32]| -> Vec<i32> {
            bins.iter()
                .map(|&b| max(1, (b as f32 * scale + 0.5) as i32))
                .collect()
        };
        ScmPresetBins {
            compute_bins: rescale(&self.compute_bins),
            cache_bins: rescale(&self.cache_bins),
            steps: self.steps,
        }
    }

    /// Builds the mask: alternating compute (`1`) and cache (`0`) runs, cut off
    /// at `steps` entries, with the final entry forced to `1`.
    ///
    /// The mask is shorter than `steps` when the bins run out first.
    fn generate_vec_mask(&self) -> Vec<i32> {
        // Same conversion as before: a negative step count wraps to a huge limit,
        // so the mask is then bounded only by the bins.
        let target = self.steps as usize;
        let mut mask: Vec<i32> = Vec::new();

        // Appends up to `count` copies of `value` without exceeding `target`.
        let push_run = |mask: &mut Vec<i32>, value: i32, count: i32| {
            let room = target.saturating_sub(mask.len());
            let run = (count.max(0) as usize).min(room);
            mask.extend(std::iter::repeat(value).take(run));
        };

        // Each table is walked with its own cursor; reading the cache table with
        // the compute cursor skips its first bin and runs past its end.
        let mut compute_runs = self.compute_bins.iter();
        let mut cache_runs = self.cache_bins.iter();
        while mask.len() < target {
            let compute_run = compute_runs.next();
            let cache_run = cache_runs.next();
            if compute_run.is_none() && cache_run.is_none() {
                break;
            }
            if let Some(&count) = compute_run {
                push_run(&mut mask, 1, count);
            }
            if let Some(&count) = cache_run {
                push_run(&mut mask, 0, count);
            }
        }

        if let Some(last) = mask.last_mut() {
            *last = 1;
        }
        mask
    }
}

impl Display for ScmPresetBins {
    /// Writes the (possibly rescaled) mask as comma-separated `0`/`1` values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mask: String = self
            .maybe_scale()
            .generate_vec_mask()
            .iter()
            .map(|x| x.to_string())
            .collect::<Vec<_>>()
            .join(",");
        write!(f, "{mask}")
    }
}
318
/// Model-level configuration plus the lazily created native contexts.
///
/// Built through `ModelConfigBuilder`; validation requires `model` or
/// `diffusion_model` to be set. The native contexts are created on first use and
/// released when the value is dropped.
// NOTE(review): the derived `Clone` copies the raw context pointers while `Drop`
// frees them — cloning a config whose contexts already exist looks like a double
// free; confirm whether cloning after first use is expected.
#[derive(Builder, Debug, Clone)]
#[builder(
    setter(into, strip_option),
    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
)]
pub struct ModelConfig {
    /// Thread count forwarded to both native contexts. Defaults to the number of
    /// physical cores; the custom setter applies the same fallback for values `<= 0`.
    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
    n_threads: i32,

    /// Forwarded as `enable_mmap`.
    #[builder(default = "false")]
    enable_mmap: bool,

    /// Forwarded as `offload_params_to_cpu` to both native contexts.
    #[builder(default = "false")]
    offload_params_to_cpu: bool,

    /// Upscaler model path; no upscaler context is created when unset.
    #[builder(default = "Default::default()")]
    upscale_model: Option<CLibPath>,

    /// How many times the upscaler is applied per image; `0` disables upscaling.
    #[builder(default = "1")]
    upscale_repeats: i32,

    /// Tile size forwarded to the upscaler context.
    #[builder(default = "128")]
    upscale_tile_size: i32,

    /// Forwarded as `model_path`.
    #[builder(default = "Default::default()")]
    model: CLibPath,

    /// Forwarded as `diffusion_model_path`.
    #[builder(default = "Default::default()")]
    diffusion_model: CLibPath,

    /// Forwarded as `llm_path`.
    #[builder(default = "Default::default()")]
    llm: CLibPath,

    /// Forwarded as `llm_vision_path`.
    #[builder(default = "Default::default()")]
    llm_vision: CLibPath,

    /// Forwarded as `clip_l_path`.
    #[builder(default = "Default::default()")]
    clip_l: CLibPath,

    /// Forwarded as `clip_g_path`.
    #[builder(default = "Default::default()")]
    clip_g: CLibPath,

    /// Forwarded as `clip_vision_path`.
    #[builder(default = "Default::default()")]
    clip_vision: CLibPath,

    /// Forwarded as `t5xxl_path`.
    #[builder(default = "Default::default()")]
    t5xxl: CLibPath,

    /// Forwarded as `vae_path`.
    #[builder(default = "Default::default()")]
    vae: CLibPath,

    /// Forwarded as `taesd_path`.
    #[builder(default = "Default::default()")]
    taesd: CLibPath,

    /// Forwarded as `control_net_path`.
    #[builder(default = "Default::default()")]
    control_net: CLibPath,

    /// Embeddings discovered by the custom `embeddings` setter.
    #[builder(default = "Default::default()", setter(custom))]
    embeddings: EmbeddingsStorage,

    /// Forwarded as `photo_maker_path`.
    #[builder(default = "Default::default()")]
    photo_maker: CLibPath,

    /// Forwarded as the PhotoMaker `id_embed_path` at generation time.
    #[builder(default = "Default::default()")]
    pm_id_embed_path: CLibPath,

    /// Forwarded as `wtype`; defaults to `SD_TYPE_COUNT`.
    #[builder(default = "WeightType::SD_TYPE_COUNT")]
    weight_type: WeightType,

    /// LoRAs selected by the custom `lora_models` setter.
    #[builder(default = "Default::default()", setter(custom))]
    lora_models: LoraStorage,

    /// Forwarded as `high_noise_diffusion_model_path`.
    #[builder(default = "Default::default()")]
    high_noise_diffusion_model: CLibPath,

    /// Forwarded as the `enabled` flag of the VAE tiling parameters.
    #[builder(default = "false")]
    vae_tiling: bool,

    /// Forwarded as `(tile_size_x, tile_size_y)`.
    #[builder(default = "(32,32)")]
    vae_tile_size: (i32, i32),

    /// Forwarded as `(rel_size_x, rel_size_y)`.
    #[builder(default = "(0.,0.)")]
    vae_relative_tile_size: (f32, f32),

    /// Forwarded as `target_overlap`.
    #[builder(default = "0.5")]
    vae_tile_overlap: f32,

    /// Forwarded as `rng_type`.
    #[builder(default = "RngFunction::CUDA_RNG")]
    rng: RngFunction,

    /// Forwarded as `sampler_rng_type`; defaults to `RNG_TYPE_COUNT`.
    #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
    sampler_rng_type: RngFunction,

    /// Scheduler to use; `SCHEDULER_COUNT` makes `gen_img` ask the native library
    /// for its default.
    #[builder(default = "Scheduler::SCHEDULER_COUNT")]
    scheduler: Scheduler,

    /// Forwarded as `custom_sigmas` together with its length.
    #[builder(default = "Default::default()")]
    sigmas: Vec<f32>,

    /// Forwarded as `prediction`; defaults to `PREDICTION_COUNT`.
    #[builder(default = "Prediction::PREDICTION_COUNT")]
    prediction: Prediction,

    /// Forwarded as `keep_vae_on_cpu`.
    #[builder(default = "false")]
    vae_on_cpu: bool,

    /// Forwarded as `keep_clip_on_cpu`.
    #[builder(default = "false")]
    clip_on_cpu: bool,

    /// Forwarded as `keep_control_net_on_cpu`.
    #[builder(default = "false")]
    control_net_cpu: bool,

    /// Forwarded as `diffusion_flash_attn`.
    #[builder(default = "false")]
    flash_attention: bool,

    /// Negated and forwarded as `chroma_use_dit_mask`.
    #[builder(default = "false")]
    chroma_disable_dit_mask: bool,

    /// Forwarded as `chroma_use_t5_mask`.
    #[builder(default = "false")]
    chroma_enable_t5_mask: bool,

    /// Forwarded as `chroma_t5_mask_pad`.
    #[builder(default = "1")]
    chroma_t5_mask_pad: i32,

    /// Forwarded as `qwen_image_zero_cond_t`.
    #[builder(default = "false")]
    use_qwen_image_zero_cond_true: bool,

    /// Forwarded as `diffusion_conv_direct`; also passed to the upscaler context.
    #[builder(default = "false")]
    diffusion_conv_direct: bool,

    /// Forwarded as `vae_conv_direct`.
    #[builder(default = "false")]
    vae_conv_direct: bool,

    /// Forwarded as `force_sdxl_vae_conv_scale`.
    #[builder(default = "false")]
    force_sdxl_vae_conv_scale: bool,

    /// Forwarded as `flow_shift`; defaults to `f32::INFINITY`.
    #[builder(default = "f32::INFINITY")]
    flow_shift: f32,

    /// Forwarded as `shifted_timestep` of the sampling parameters.
    #[builder(default = "0")]
    timestep_shift: i32,

    /// Forwarded as `tae_preview_only`.
    #[builder(default = "false")]
    taesd_preview_only: bool,

    /// Forwarded as `lora_apply_mode`.
    #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
    lora_apply_mode: LoraModeType,

    /// Enables both `circular_x` and `circular_y` when set.
    #[builder(default = "false")]
    circular: bool,

    /// Forwarded as `circular_x` (also enabled by `circular`).
    #[builder(default = "false")]
    circular_x: bool,

    /// Forwarded as `circular_y` (also enabled by `circular`).
    #[builder(default = "false")]
    circular_y: bool,

    /// Lazily created native upscaler context; freed on drop.
    #[builder(default = "None", private)]
    upscaler_ctx: Option<*mut upscaler_ctx_t>,

    /// Lazily created native diffusion context together with the parameters it
    /// was created from; the context is freed on drop.
    #[builder(default = "None", private)]
    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
}
538
539impl ModelConfigBuilder {
540 fn validate(&self) -> Result<(), ConfigBuilderError> {
541 self.validate_model()
542 }
543
544 fn validate_model(&self) -> Result<(), ConfigBuilderError> {
545 self.model
546 .as_ref()
547 .or(self.diffusion_model.as_ref())
548 .map(|_| ())
549 .ok_or(ConfigBuilderError::UninitializedField(
550 "Model OR DiffusionModel must be valorized",
551 ))
552 }
553
554 fn filter_valid_extensions(&self, path: &Path) -> impl Iterator<Item = DirEntry> {
555 WalkDir::new(path)
556 .into_iter()
557 .filter_map(|entry| entry.ok())
558 .filter(|entry| {
559 entry
560 .path()
561 .extension()
562 .and_then(|ext| ext.to_str())
563 .map(|ext_str| VALID_EXT.contains(&ext_str))
564 .unwrap_or(false)
565 })
566 }
567
568 fn build_single_lora_storage(
569 spec: &LoraSpec,
570 is_high_noise: bool,
571 valid_loras: &HashMap<String, PathBuf>,
572 ) -> ((CLibPath, String, f32), sd_lora_t) {
573 let path = valid_loras.get(&spec.file_name).unwrap().as_path();
574 let c_path = CLibPath::from(path);
575 let lora = sd_lora_t {
576 is_high_noise,
577 multiplier: spec.multiplier,
578 path: c_path.as_ptr(),
579 };
580 let data = (c_path, spec.file_name.clone(), spec.multiplier);
581 (data, lora)
582 }
583
584 pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
585 let data: Vec<(CLibString, CLibPath)> = self
586 .filter_valid_extensions(embeddings_dir)
587 .map(|entry| {
588 let file_stem = entry
589 .path()
590 .file_stem()
591 .and_then(|stem| stem.to_str())
592 .unwrap_or_default()
593 .to_owned();
594 (CLibString::from(file_stem), CLibPath::from(entry.path()))
595 })
596 .collect();
597 let data_pointer = data
598 .iter()
599 .map(|(name, path)| sd_embedding_t {
600 name: name.as_ptr(),
601 path: path.as_ptr(),
602 })
603 .collect();
604 self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
605 self
606 }
607
608 pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
609 let valid_loras: HashMap<String, PathBuf> = self
610 .filter_valid_extensions(lora_model_dir)
611 .map(|entry| {
612 let path = entry.path();
613 (
614 path.file_stem()
615 .and_then(|stem| stem.to_str())
616 .unwrap_or_default()
617 .to_owned(),
618 path.to_path_buf(),
619 )
620 })
621 .collect();
622 let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
623 let standard = specs
624 .iter()
625 .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
626 .map(|s| Self::build_single_lora_storage(s, false, &valid_loras));
627 let high_noise = specs
628 .iter()
629 .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
630 .map(|s| Self::build_single_lora_storage(s, true, &valid_loras));
631
632 let mut data = Vec::new();
633 let mut loras_t = Vec::new();
634 for lora in standard.chain(high_noise) {
635 data.push(lora.0);
636 loras_t.push(lora.1);
637 }
638
639 self.lora_models = Some(LoraStorage {
640 lora_model_dir: lora_model_dir.into(),
641 data,
642 loras_t,
643 });
644 self
645 }
646
647 pub fn n_threads(&mut self, value: i32) -> &mut Self {
648 self.n_threads = if value > 0 {
649 Some(value)
650 } else {
651 Some(num_cpus::get_physical() as i32)
652 };
653 self
654 }
655}
656
657impl ModelConfig {
658 unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
659 unsafe {
660 if self.upscale_model.is_none() || self.upscale_repeats == 0 {
661 None
662 } else {
663 if self.upscaler_ctx.is_none() {
664 let upscaler = new_upscaler_ctx(
665 self.upscale_model.as_ref().unwrap().as_ptr(),
666 self.offload_params_to_cpu,
667 self.diffusion_conv_direct,
668 self.n_threads,
669 self.upscale_tile_size,
670 );
671 self.upscaler_ctx = Some(upscaler);
672 }
673 self.upscaler_ctx
674 }
675 }
676 }
677
678 unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
679 unsafe {
680 if self.diffusion_ctx.is_none() {
681 let sd_ctx_params = sd_ctx_params_t {
682 model_path: self.model.as_ptr(),
683 llm_path: self.llm.as_ptr(),
684 llm_vision_path: self.llm_vision.as_ptr(),
685 clip_l_path: self.clip_l.as_ptr(),
686 clip_g_path: self.clip_g.as_ptr(),
687 clip_vision_path: self.clip_vision.as_ptr(),
688 high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
689 t5xxl_path: self.t5xxl.as_ptr(),
690 diffusion_model_path: self.diffusion_model.as_ptr(),
691 vae_path: self.vae.as_ptr(),
692 taesd_path: self.taesd.as_ptr(),
693 control_net_path: self.control_net.as_ptr(),
694 embeddings: self.embeddings.2.as_ptr(),
695 embedding_count: self.embeddings.1.len() as u32,
696 photo_maker_path: self.photo_maker.as_ptr(),
697 vae_decode_only,
698 free_params_immediately: false,
699 n_threads: self.n_threads,
700 wtype: self.weight_type,
701 rng_type: self.rng,
702 keep_clip_on_cpu: self.clip_on_cpu,
703 keep_control_net_on_cpu: self.control_net_cpu,
704 keep_vae_on_cpu: self.vae_on_cpu,
705 diffusion_flash_attn: self.flash_attention,
706 diffusion_conv_direct: self.diffusion_conv_direct,
707 chroma_use_dit_mask: !self.chroma_disable_dit_mask,
708 chroma_use_t5_mask: self.chroma_enable_t5_mask,
709 chroma_t5_mask_pad: self.chroma_t5_mask_pad,
710 vae_conv_direct: self.vae_conv_direct,
711 offload_params_to_cpu: self.offload_params_to_cpu,
712 flow_shift: self.flow_shift,
713 prediction: self.prediction,
714 force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
715 tae_preview_only: self.taesd_preview_only,
716 lora_apply_mode: self.lora_apply_mode,
717 tensor_type_rules: null_mut(),
718 sampler_rng_type: self.sampler_rng_type,
719 circular_x: self.circular || self.circular_x,
720 circular_y: self.circular || self.circular_y,
721 qwen_image_zero_cond_t: self.use_qwen_image_zero_cond_true,
722 enable_mmap: self.enable_mmap,
723 };
724 let ctx = new_sd_ctx(&sd_ctx_params);
725 self.diffusion_ctx = Some((ctx, sd_ctx_params))
726 }
727 self.diffusion_ctx.unwrap().0
728 }
729 }
730}
731
732impl Drop for ModelConfig {
733 fn drop(&mut self) {
734 unsafe {
736 if let Some((sd_ctx, _)) = self.diffusion_ctx {
737 free_sd_ctx(sd_ctx);
738 }
739
740 if let Some(upscaler_ctx) = self.upscaler_ctx {
741 free_upscaler_ctx(upscaler_ctx);
742 }
743 }
744 }
745}
746
747impl From<ModelConfig> for ModelConfigBuilder {
748 fn from(value: ModelConfig) -> Self {
749 let mut builder = ModelConfigBuilder::default();
750 builder
751 .n_threads(value.n_threads)
752 .offload_params_to_cpu(value.offload_params_to_cpu)
753 .upscale_repeats(value.upscale_repeats)
754 .model(value.model.clone())
755 .diffusion_model(value.diffusion_model.clone())
756 .llm(value.llm.clone())
757 .llm_vision(value.llm_vision.clone())
758 .clip_l(value.clip_l.clone())
759 .clip_g(value.clip_g.clone())
760 .clip_vision(value.clip_vision.clone())
761 .t5xxl(value.t5xxl.clone())
762 .vae(value.vae.clone())
763 .taesd(value.taesd.clone())
764 .control_net(value.control_net.clone())
765 .embeddings(&value.embeddings.0)
766 .photo_maker(value.photo_maker.clone())
767 .pm_id_embed_path(value.pm_id_embed_path.clone())
768 .weight_type(value.weight_type)
769 .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
770 .vae_tiling(value.vae_tiling)
771 .vae_tile_size(value.vae_tile_size)
772 .vae_relative_tile_size(value.vae_relative_tile_size)
773 .vae_tile_overlap(value.vae_tile_overlap)
774 .rng(value.rng)
775 .sampler_rng_type(value.rng)
776 .scheduler(value.scheduler)
777 .sigmas(value.sigmas.clone())
778 .prediction(value.prediction)
779 .vae_on_cpu(value.vae_on_cpu)
780 .clip_on_cpu(value.clip_on_cpu)
781 .control_net(value.control_net.clone())
782 .control_net_cpu(value.control_net_cpu)
783 .flash_attention(value.flash_attention)
784 .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
785 .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
786 .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
787 .diffusion_conv_direct(value.diffusion_conv_direct)
788 .vae_conv_direct(value.vae_conv_direct)
789 .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
790 .flow_shift(value.flow_shift)
791 .timestep_shift(value.timestep_shift)
792 .taesd_preview_only(value.taesd_preview_only)
793 .lora_apply_mode(value.lora_apply_mode)
794 .circular(value.circular)
795 .circular_x(value.circular_x)
796 .circular_y(value.circular_y)
797 .use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true);
798
799 let lora_model_dir = Into::<PathBuf>::into(&value.lora_models.lora_model_dir);
800 let lora_specs = value
801 .lora_models
802 .data
803 .iter()
804 .map(|(_, name, multiplier)| LoraSpec {
805 file_name: name.clone(),
806 is_high_noise: false,
807 multiplier: *multiplier,
808 })
809 .collect();
810
811 builder.lora_models(&lora_model_dir, lora_specs);
812
813 if let Some(model) = &value.upscale_model {
814 builder.upscale_model(model.clone());
815 }
816 builder
817 }
818}
819
/// Per-generation settings consumed by `gen_img`.
///
/// Built through `ConfigBuilder`. `prompt` is the only field without a default.
/// Validation requires `output` to be an existing directory exactly when
/// `batch_count` is greater than one.
#[derive(Builder, Debug, Clone)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
pub struct Config {
    /// Directory with PhotoMaker id images.
    // NOTE(review): not read by `gen_img` in this file — confirm intended use.
    #[builder(default = "Default::default()")]
    pm_id_images_dir: CLibPath,

    /// Path of the init image.
    // NOTE(review): not read by `gen_img` in this file — confirm intended use.
    #[builder(default = "Default::default()")]
    init_img: CLibPath,

    /// Path of the control image.
    // NOTE(review): not read by `gen_img` in this file — confirm intended use.
    #[builder(default = "Default::default()")]
    control_image: CLibPath,

    /// Output file when one image is generated, output directory otherwise.
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,

    /// File the preview callback writes preview frames to.
    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
    preview_output: PathBuf,

    /// Preview mode; the preview callback is only registered when this is not
    /// `PREVIEW_NONE`.
    #[builder(default = "PreviewType::PREVIEW_NONE")]
    preview_mode: PreviewType,

    /// Passed to the preview registration as the "noisy" flag; its negation is
    /// passed as the preceding flag.
    #[builder(default = "false")]
    preview_noisy: bool,

    /// Interval passed to the preview registration.
    #[builder(default = "1")]
    preview_interval: i32,

    /// Prompt text; required.
    prompt: String,

    /// Negative prompt; defaults to the empty string.
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,

    /// Forwarded as both `txt_cfg` and `img_cfg`.
    #[builder(default = "7.0")]
    cfg_scale: f32,

    /// Forwarded as `distilled_guidance`.
    #[builder(default = "3.5")]
    guidance: f32,

    /// Forwarded as `strength`.
    #[builder(default = "0.75")]
    strength: f32,

    /// Forwarded as the PhotoMaker `style_strength`.
    #[builder(default = "20.0")]
    pm_style_strength: f32,

    /// Forwarded as `control_strength`.
    #[builder(default = "0.9")]
    control_strength: f32,

    /// Image height; also used for the mask image.
    #[builder(default = "512")]
    height: i32,

    /// Image width; also used for the mask image.
    #[builder(default = "512")]
    width: i32,

    /// Sampling method; `SAMPLE_METHOD_COUNT` makes `gen_img` ask the native
    /// library for its default.
    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
    sampling_method: SampleMethod,

    /// Forwarded as `eta`.
    #[builder(default = "0.")]
    eta: f32,

    /// Forwarded as `sample_steps`.
    #[builder(default = "20")]
    steps: i32,

    /// Forwarded as `seed`.
    #[builder(default = "42")]
    seed: i64,

    /// Number of images to generate; also the number of output files.
    #[builder(default = "1")]
    batch_count: i32,

    /// Forwarded (as `i32`) as `clip_skip`.
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,

    /// Canny preprocessing flag.
    // NOTE(review): not read by `gen_img` in this file — confirm intended use.
    #[builder(default = "false")]
    canny: bool,

    /// Forwarded as the skip-layer-guidance `scale`.
    #[builder(default = "0.")]
    slg_scale: f32,

    /// Forwarded as the skip-layer-guidance `layers`.
    #[builder(default = "vec![7, 8, 9]")]
    skip_layer: Vec<i32>,

    /// Forwarded as the skip-layer-guidance `layer_start`.
    #[builder(default = "0.01")]
    skip_layer_start: f32,

    /// Forwarded as the skip-layer-guidance `layer_end`.
    #[builder(default = "0.2")]
    skip_layer_end: f32,

    /// Used by `gen_img` to fill the native `auto_resize_ref_image` flag.
    #[builder(default = "false")]
    disable_auto_resize_ref_image: bool,

    /// Native cache parameters; set through the `*_caching` builder methods.
    #[builder(default = "Self::cache_init()", private)]
    cache: sd_cache_params_t,

    /// Owned mask string that `cache.scm_mask` is meant to point at.
    #[builder(default = "CLibString::default()", private)]
    scm_mask: CLibString,
}
950
951impl ConfigBuilder {
952 fn validate(&self) -> Result<(), ConfigBuilderError> {
953 self.validate_output_dir()
954 }
955
956 fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
957 let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
958 let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
959 if is_dir == multiple_items {
960 Ok(())
961 } else {
962 Err(ConfigBuilderError::ValidationError(
963 "When batch_count > 1, output should point to folder and vice versa".to_owned(),
964 ))
965 }
966 }
967
968 fn cache_init() -> sd_cache_params_t {
969 sd_cache_params_t {
970 mode: sd_cache_mode_t::SD_CACHE_DISABLED,
971 reuse_threshold: 1.0,
972 start_percent: 0.15,
973 end_percent: 0.95,
974 error_decay_rate: 1.0,
975 use_relative_threshold: true,
976 reset_error_on_compute: true,
977 Fn_compute_blocks: 8,
978 Bn_compute_blocks: 0,
979 residual_diff_threshold: 0.08,
980 max_warmup_steps: 8,
981 max_cached_steps: -1,
982 max_continuous_cached_steps: -1,
983 taylorseer_n_derivatives: 1,
984 taylorseer_skip_interval: 1,
985 scm_mask: null(),
986 scm_policy_dynamic: true,
987 }
988 }
989
990 pub fn no_caching(&mut self) -> &mut Self {
991 let mut cache = Self::cache_init();
992 cache.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
993 self.cache = Some(cache);
994 self
995 }
996
997 pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
998 let mut cache = Self::cache_init();
999 cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
1000 cache.reuse_threshold = params.threshold;
1001 cache.start_percent = params.start;
1002 cache.end_percent = params.end;
1003 cache.error_decay_rate = params.decay;
1004 cache.use_relative_threshold = params.relative;
1005 cache.reset_error_on_compute = params.reset;
1006 self.cache = Some(cache);
1007 self
1008 }
1009
1010 pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
1011 let mut cache = Self::cache_init();
1012 cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
1013 cache.reuse_threshold = params.threshold;
1014 cache.start_percent = params.start;
1015 cache.end_percent = params.end;
1016 self.cache = Some(cache);
1017 self
1018 }
1019
1020 pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
1021 let mut cache = Self::cache_init();
1022 cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
1023 cache.Fn_compute_blocks = params.fn_blocks;
1024 cache.Bn_compute_blocks = params.bn_blocks;
1025 cache.residual_diff_threshold = params.threshold;
1026 cache.max_warmup_steps = params.warmup;
1027 cache.scm_policy_dynamic = match params.scm_policy_dynamic {
1028 ScmPolicy::Static => false,
1029 ScmPolicy::Dynamic => true,
1030 };
1031 self.scm_mask = Some(CLibString::from(
1032 params
1033 .scm_mask
1034 .to_vec_string(self.steps.unwrap_or_default()),
1035 ));
1036 cache.scm_mask = self.scm_mask.as_ref().unwrap().as_ptr();
1037
1038 self.cache = Some(cache);
1039 self
1040 }
1041
1042 pub fn taylor_seer_caching(&mut self) -> &mut Self {
1043 let mut cache = Self::cache_init();
1044 cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
1045 self.cache = Some(cache);
1046 self
1047 }
1048
1049 pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
1050 self.db_cache_caching(params).cache.unwrap().mode = sd_cache_mode_t::SD_CACHE_CACHE_DIT;
1051 self
1052 }
1053}
1054
1055impl From<Config> for ConfigBuilder {
1056 fn from(value: Config) -> Self {
1057 let mut builder = ConfigBuilder::default();
1058 let mut cache = value.cache;
1059 let scm_mask = value.scm_mask.clone();
1060 cache.scm_mask = scm_mask.as_ptr();
1061 builder
1062 .pm_id_images_dir(value.pm_id_images_dir)
1063 .init_img(value.init_img)
1064 .control_image(value.control_image)
1065 .output(value.output)
1066 .prompt(value.prompt)
1067 .negative_prompt(value.negative_prompt)
1068 .cfg_scale(value.cfg_scale)
1069 .strength(value.strength)
1070 .pm_style_strength(value.pm_style_strength)
1071 .control_strength(value.control_strength)
1072 .height(value.height)
1073 .width(value.width)
1074 .sampling_method(value.sampling_method)
1075 .steps(value.steps)
1076 .seed(value.seed)
1077 .batch_count(value.batch_count)
1078 .clip_skip(value.clip_skip)
1079 .slg_scale(value.slg_scale)
1080 .skip_layer(value.skip_layer)
1081 .skip_layer_start(value.skip_layer_start)
1082 .skip_layer_end(value.skip_layer_end)
1083 .canny(value.canny)
1084 .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
1085 .preview_output(value.preview_output)
1086 .preview_mode(value.preview_mode)
1087 .preview_noisy(value.preview_noisy)
1088 .preview_interval(value.preview_interval)
1089 .cache(cache)
1090 .scm_mask(scm_mask);
1091 builder
1092 }
1093}
1094
/// Owned, NUL-terminated string that can be handed to the native library.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Pointer to the NUL-terminated bytes; valid for as long as `self` lives.
    fn as_ptr(&self) -> *const c_char {
        let Self(inner) = self;
        inner.as_ptr()
    }
}

impl From<&str> for CLibString {
    /// # Panics
    /// Panics when `value` contains an interior NUL byte.
    fn from(value: &str) -> Self {
        CString::new(value).map(Self).unwrap()
    }
}

impl From<String> for CLibString {
    /// # Panics
    /// Panics when `value` contains an interior NUL byte.
    fn from(value: String) -> Self {
        CString::new(value).map(Self).unwrap()
    }
}
1115
/// Owned, NUL-terminated file-system path that can be handed to the native
/// library. Paths that are not valid UTF-8 are stored as the empty string.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Pointer to the NUL-terminated bytes; valid for as long as `self` lives.
    fn as_ptr(&self) -> *const c_char {
        let Self(inner) = self;
        inner.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    /// Same conversion as for `&Path`.
    fn from(value: PathBuf) -> Self {
        Self::from(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    /// # Panics
    /// Panics when the path text contains an interior NUL byte.
    fn from(value: &Path) -> Self {
        let text = value.to_str().unwrap_or_default();
        CString::new(text).map(Self).unwrap()
    }
}

impl From<&CLibPath> for PathBuf {
    /// # Panics
    /// Panics when the stored bytes are not valid UTF-8.
    fn from(value: &CLibPath) -> Self {
        let text = value.0.to_str().unwrap();
        PathBuf::from(text)
    }
}
1142
1143fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
1144 let date = Local::now().format("%Y.%m.%d-%H.%M.%S");
1145 if batch_size == 1 {
1146 vec![path.into()]
1147 } else {
1148 (1..=batch_size)
1149 .map(|id| path.join(format!("output_{date}_{id}.png")))
1150 .collect()
1151 }
1152}
1153
1154unsafe fn upscale(
1155 upscale_repeats: i32,
1156 upscaler_ctx: Option<*mut upscaler_ctx_t>,
1157 data: sd_image_t,
1158) -> Result<sd_image_t, DiffusionError> {
1159 unsafe {
1160 match upscaler_ctx {
1161 Some(upscaler_ctx) => {
1162 let upscale_factor = 4; let mut current_image = data;
1164 for _ in 0..upscale_repeats {
1165 let upscaled_image =
1166 diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
1167
1168 if upscaled_image.data.is_null() {
1169 return Err(DiffusionError::Upscaler);
1170 }
1171
1172 free(current_image.data as *mut c_void);
1173 current_image = upscaled_image;
1174 }
1175 Ok(current_image)
1176 }
1177 None => Ok(data),
1178 }
1179 }
1180}
1181
1182pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
1184 let prompt: CLibString = CLibString::from(config.prompt.as_str());
1185 let files = output_files(&config.output, config.batch_count);
1186 unsafe {
1187 let sd_ctx = model_config.diffusion_ctx(true);
1188 let upscaler_ctx = model_config.upscaler_ctx();
1189 let init_image = sd_image_t {
1190 width: 0,
1191 height: 0,
1192 channel: 3,
1193 data: null_mut(),
1194 };
1195 let mask_image = sd_image_t {
1196 width: config.width as u32,
1197 height: config.height as u32,
1198 channel: 1,
1199 data: null_mut(),
1200 };
1201 let mut layers = config.skip_layer.clone();
1202 let guidance = sd_guidance_params_t {
1203 txt_cfg: config.cfg_scale,
1204 img_cfg: config.cfg_scale,
1205 distilled_guidance: config.guidance,
1206 slg: sd_slg_params_t {
1207 layers: layers.as_mut_ptr(),
1208 layer_count: config.skip_layer.len(),
1209 layer_start: config.skip_layer_start,
1210 layer_end: config.skip_layer_end,
1211 scale: config.slg_scale,
1212 },
1213 };
1214 let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
1215 sd_get_default_scheduler(sd_ctx, config.sampling_method)
1216 } else {
1217 model_config.scheduler
1218 };
1219 let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
1220 sd_get_default_sample_method(sd_ctx)
1221 } else {
1222 config.sampling_method
1223 };
1224 let sample_params = sd_sample_params_t {
1225 guidance,
1226 sample_method,
1227 sample_steps: config.steps,
1228 eta: config.eta,
1229 scheduler,
1230 shifted_timestep: model_config.timestep_shift,
1231 custom_sigmas: model_config.sigmas.as_mut_ptr(),
1232 custom_sigmas_count: model_config.sigmas.len() as i32,
1233 };
1234 let control_image = sd_image_t {
1235 width: 0,
1236 height: 0,
1237 channel: 3,
1238 data: null_mut(),
1239 };
1240 let vae_tiling_params = sd_tiling_params_t {
1241 enabled: model_config.vae_tiling,
1242 tile_size_x: model_config.vae_tile_size.0,
1243 tile_size_y: model_config.vae_tile_size.1,
1244 target_overlap: model_config.vae_tile_overlap,
1245 rel_size_x: model_config.vae_relative_tile_size.0,
1246 rel_size_y: model_config.vae_relative_tile_size.1,
1247 };
1248 let pm_params = sd_pm_params_t {
1249 id_images: null_mut(),
1250 id_images_count: 0,
1251 id_embed_path: model_config.pm_id_embed_path.as_ptr(),
1252 style_strength: config.pm_style_strength,
1253 };
1254
1255 unsafe extern "C" fn save_preview_local(
1256 _step: ::std::os::raw::c_int,
1257 _frame_count: ::std::os::raw::c_int,
1258 frames: *mut sd_image_t,
1259 _is_noisy: bool,
1260 data: *mut ::std::os::raw::c_void,
1261 ) {
1262 unsafe {
1263 let path = &*data.cast::<PathBuf>();
1264 let _ = save_img(*frames, path);
1265 }
1266 }
1267
1268 if config.preview_mode != PreviewType::PREVIEW_NONE {
1269 let data = &config.preview_output as *const PathBuf;
1270
1271 sd_set_preview_callback(
1272 Some(save_preview_local),
1273 config.preview_mode,
1274 config.preview_interval,
1275 !config.preview_noisy,
1276 config.preview_noisy,
1277 data as *mut c_void,
1278 );
1279 }
1280
1281 let sd_img_gen_params = sd_img_gen_params_t {
1282 prompt: prompt.as_ptr(),
1283 negative_prompt: config.negative_prompt.as_ptr(),
1284 clip_skip: config.clip_skip as i32,
1285 init_image,
1286 ref_images: null_mut(),
1287 ref_images_count: 0,
1288 increase_ref_index: false,
1289 mask_image,
1290 width: config.width,
1291 height: config.height,
1292 sample_params,
1293 strength: config.strength,
1294 seed: config.seed,
1295 batch_count: config.batch_count,
1296 control_image,
1297 control_strength: config.control_strength,
1298 pm_params,
1299 vae_tiling_params,
1300 auto_resize_ref_image: config.disable_auto_resize_ref_image,
1301 cache: config.cache,
1302 loras: model_config.lora_models.loras_t.as_ptr(),
1303 lora_count: model_config.lora_models.loras_t.len() as u32,
1304 };
1305 let slice = diffusion_rs_sys::generate_image(sd_ctx, &sd_img_gen_params);
1306 let ret = {
1307 if slice.is_null() {
1308 return Err(DiffusionError::Forward);
1309 }
1310 for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
1311 .iter()
1312 .zip(files)
1313 {
1314 match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
1315 Ok(img) => save_img(img, &path)?,
1316 Err(err) => {
1317 return Err(err);
1318 }
1319 }
1320 }
1321 Ok(())
1322 };
1323 free(slice as *mut c_void);
1324 ret
1325 }
1326}
1327
1328fn save_img(img: sd_image_t, path: &Path) -> Result<(), DiffusionError> {
1329 let len = (img.width * img.height * img.channel) as usize;
1331 let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1332 let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
1333 .map(|img| RgbImage::from(img).save(path));
1334 if let Some(Err(err)) = save_state {
1335 return Err(DiffusionError::StoreImages(err));
1336 }
1337 Ok(())
1338}
1339
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    /// Checks builder validation: required fields, and the rule tying
    /// `batch_count` to `output` being a directory.
    #[test]
    fn test_required_args_txt2img() {
        // Neither builder can be built without its required settings.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // Several images with the default (file) output must be rejected.
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // Several images with a directory output are accepted.
        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end generation with downloaded models; ignored by default.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&config, &mut model_config).unwrap();
        // Reuse the same model configuration with a config derived from the first.
        let config2 = ConfigBuilder::from(config.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&config2, &mut model_config).unwrap();

        // Batch of two images written into a directory.
        let config3 = ConfigBuilder::from(config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();

        gen_img(&config3, &mut model_config).unwrap();
    }
}