1use std::cmp::max;
2use std::collections::HashMap;
3use std::ffi::CString;
4use std::ffi::c_char;
5use std::ffi::c_void;
6use std::fmt::Display;
7use std::path::Path;
8use std::path::PathBuf;
9use std::ptr::null;
10use std::ptr::null_mut;
11use std::slice;
12use std::sync::mpsc::Sender;
13
14use chrono::Local;
15use derive_builder::Builder;
16use diffusion_rs_sys::free_upscaler_ctx;
17use diffusion_rs_sys::generate_image;
18use diffusion_rs_sys::new_upscaler_ctx;
19use diffusion_rs_sys::sd_cache_mode_t;
20use diffusion_rs_sys::sd_cache_params_t;
21use diffusion_rs_sys::sd_ctx_params_t;
22use diffusion_rs_sys::sd_embedding_t;
23use diffusion_rs_sys::sd_get_default_sample_method;
24use diffusion_rs_sys::sd_get_default_scheduler;
25use diffusion_rs_sys::sd_guidance_params_t;
26use diffusion_rs_sys::sd_image_t;
27use diffusion_rs_sys::sd_img_gen_params_t;
28use diffusion_rs_sys::sd_img_gen_params_to_str;
29use diffusion_rs_sys::sd_lora_t;
30use diffusion_rs_sys::sd_pm_params_t;
31use diffusion_rs_sys::sd_sample_params_t;
32use diffusion_rs_sys::sd_set_preview_callback;
33use diffusion_rs_sys::sd_set_progress_callback;
34use diffusion_rs_sys::sd_slg_params_t;
35use diffusion_rs_sys::sd_tiling_params_t;
36use diffusion_rs_sys::upscaler_ctx_t;
37use image::ImageBuffer;
38use image::ImageError;
39use image::RgbImage;
40use libc::free;
41use little_exif::exif_tag::ExifTag;
42use little_exif::metadata::Metadata;
43use thiserror::Error;
44use walkdir::DirEntry;
45use walkdir::WalkDir;
46
47use diffusion_rs_sys::free_sd_ctx;
48use diffusion_rs_sys::new_sd_ctx;
49use diffusion_rs_sys::sd_ctx_t;
50
51pub use diffusion_rs_sys::rng_type_t as RngFunction;
53
54pub use diffusion_rs_sys::sample_method_t as SampleMethod;
56
57pub use diffusion_rs_sys::scheduler_t as Scheduler;
59
60pub use diffusion_rs_sys::prediction_t as Prediction;
62
63pub use diffusion_rs_sys::sd_type_t as WeightType;
65
66pub use diffusion_rs_sys::preview_t as PreviewType;
68
69pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
71
/// File extensions (without the leading dot) that `filter_valid_extensions`
/// accepts when walking the embeddings and LoRA directories.
static VALID_EXT: [&str; 3] = [
    "gguf",
    "safetensors",
    "pt",
];
73
/// Progress report sent through the channel given to `gen_img_with_progress`.
///
/// The values are copied verbatim from the arguments of the native progress
/// callback. The fields are public so that the receiving side of the channel
/// can actually read them.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Progress {
    /// `step` argument of the native progress callback.
    pub step: i32,
    /// `steps` argument of the native progress callback.
    pub steps: i32,
    /// `time` argument of the native progress callback.
    pub time: f32,
}
82
83#[non_exhaustive]
84#[derive(Error, Debug)]
85pub enum DiffusionError {
87 #[error("The underling stablediffusion.cpp function returned NULL")]
88 Forward,
89 #[error(transparent)]
90 StoreImages(#[from] ImageError),
91 #[error(transparent)]
92 Io(#[from] std::io::Error),
93 #[error("The underling upscaler model returned a NULL image")]
94 Upscaler,
95}
96
/// CLIP skip setting.
///
/// The discriminant is what gets forwarded to the native library
/// (`clip_skip: config.clip_skip as i32`), so the numeric values are part of
/// the contract.
#[repr(i32)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ClipSkip {
    /// Forwarded as `0`; this is the default.
    #[default]
    Unspecified = 0,
    /// Forwarded as `1`.
    None = 1,
    /// Forwarded as `2`.
    OneLayer = 2,
}
108
109type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
110
111#[derive(Default, Debug, Clone)]
112struct LoraStorage {
113 lora_model_dir: CLibPath,
114 data: Vec<(CLibPath, String, f32)>,
115 loras_t: Vec<sd_lora_t>,
116}
117
/// User-facing description of one LoRA to apply.
#[derive(Clone, Debug, Default)]
pub struct LoraSpec {
    /// File stem (name without extension) looked up in the LoRA directory.
    pub file_name: String,
    /// Whether the LoRA is flagged as high-noise when handed to the library.
    pub is_high_noise: bool,
    /// Multiplier forwarded to the library for this LoRA.
    pub multiplier: f32,
}
125
126#[derive(Builder, Debug, Clone)]
128pub struct UCacheParams {
129 #[builder(default = "1.0")]
131 threshold: f32,
132
133 #[builder(default = "0.15")]
135 start: f32,
136
137 #[builder(default = "0.95")]
139 end: f32,
140
141 #[builder(default = "1.0")]
143 decay: f32,
144
145 #[builder(default = "true")]
147 relative: bool,
148
149 #[builder(default = "true")]
153 reset: bool,
154}
155
156#[derive(Builder, Debug, Clone)]
158pub struct EasyCacheParams {
159 #[builder(default = "0.2")]
161 threshold: f32,
162
163 #[builder(default = "0.15")]
165 start: f32,
166
167 #[builder(default = "0.95")]
169 end: f32,
170}
171
172#[derive(Builder, Debug, Clone)]
174pub struct DbCacheParams {
175 #[builder(default = "8")]
177 fn_blocks: i32,
178
179 #[builder(default = "0")]
181 bn_blocks: i32,
182
183 #[builder(default = "0.08")]
185 threshold: f32,
186
187 #[builder(default = "8")]
189 warmup: i32,
190
191 scm_mask: ScmPreset,
193
194 #[builder(default = "ScmPolicy::default()")]
196 scm_policy_dynamic: ScmPolicy,
197}
198
/// Value of the `scm_policy_dynamic` cache flag.
#[derive(Clone, Debug, Default)]
pub enum ScmPolicy {
    /// Sets `scm_policy_dynamic` to `false`.
    Static,
    /// Sets `scm_policy_dynamic` to `true`; this is the default.
    #[default]
    Dynamic,
}
208
/// Source of the SCM mask string: one of the built-in bin presets, or a
/// caller-supplied string that is used as-is.
#[derive(Clone, Debug, Default)]
pub enum ScmPreset {
    /// Built-in preset "slow".
    Slow,
    /// Built-in preset "medium"; this is the default.
    #[default]
    Medium,
    /// Built-in preset "fast".
    Fast,
    /// Built-in preset "ultra".
    Ultra,
    /// Mask string passed through without any processing.
    Custom(String),
}
221
222impl ScmPreset {
223 fn to_vec_string(&self, steps: i32) -> String {
224 match self {
225 ScmPreset::Slow => ScmPresetBins {
226 compute_bins: vec![8, 3, 3, 2, 1, 1],
227 cache_bins: vec![1, 2, 2, 2, 3],
228 steps,
229 }
230 .to_string(),
231 ScmPreset::Medium => ScmPresetBins {
232 compute_bins: vec![6, 2, 2, 2, 2, 1],
233 cache_bins: vec![1, 3, 3, 3, 3],
234 steps,
235 }
236 .to_string(),
237 ScmPreset::Fast => ScmPresetBins {
238 compute_bins: vec![6, 1, 1, 1, 1, 1],
239 cache_bins: vec![1, 3, 4, 5, 4],
240 steps,
241 }
242 .to_string(),
243 ScmPreset::Ultra => ScmPresetBins {
244 compute_bins: vec![4, 1, 1, 1, 1],
245 cache_bins: vec![2, 5, 6, 7],
246 steps,
247 }
248 .to_string(),
249 ScmPreset::Custom(s) => s.clone(),
250 }
251 }
252}
253
/// Bin layout of an SCM preset: alternating runs of computed (`1`) and cached
/// (`0`) steps, expressed for a 28-step schedule.
#[derive(Debug, Clone)]
struct ScmPresetBins {
    /// Lengths of the successive "compute" runs.
    compute_bins: Vec<i32>,
    /// Lengths of the successive "cache" runs, interleaved after each compute run.
    cache_bins: Vec<i32>,
    /// Number of sampling steps the mask has to cover.
    steps: i32,
}

impl ScmPresetBins {
    /// Step count the preset bins are expressed for.
    const BASE_STEPS: i32 = 28;

    /// Returns the bins rescaled to `self.steps`, or an unchanged copy when
    /// the step count already matches the base or is not positive.
    fn maybe_scale(&self) -> ScmPresetBins {
        if self.steps == Self::BASE_STEPS || self.steps <= 0 {
            return self.clone();
        }
        self.scale()
    }

    /// Rescales every bin by `steps / 28`, rounding to the nearest integer and
    /// never going below 1.
    fn scale(&self) -> ScmPresetBins {
        let scale = self.steps as f32 / Self::BASE_STEPS as f32;
        // `+ 0.5` followed by truncation rounds to nearest. Multiplying by
        // 0.5 instead would halve every bin and leave the mask too short.
        let rescale = |bins: &[i32]| -> Vec<i32> {
            bins.iter()
                .map(|&bin| max(1, (bin as f32 * scale + 0.5) as i32))
                .collect()
        };
        ScmPresetBins {
            compute_bins: rescale(&self.compute_bins),
            cache_bins: rescale(&self.cache_bins),
            steps: self.steps,
        }
    }

    /// Expands the bins into a per-step mask (`1` = compute, `0` = cache).
    ///
    /// The mask is truncated to `steps` entries, and the last entry is always
    /// forced to `1`. It can be shorter than `steps` when the bins run out.
    fn generate_vec_mask(&self) -> Vec<i32> {
        // A negative step count wraps to a huge value here, i.e. "no limit":
        // the whole bin layout is emitted.
        let target = self.steps as usize;
        let mut mask: Vec<i32> = Vec::new();
        let rounds = self.compute_bins.len().max(self.cache_bins.len());

        for round in 0..rounds {
            if mask.len() >= target {
                break;
            }
            // One compute run followed by one cache run, each read from its
            // own bin list at the same position.
            for (bins, value) in [(&self.compute_bins, 1), (&self.cache_bins, 0)] {
                if let Some(&count) = bins.get(round) {
                    let room = target - mask.len();
                    let take = (count.max(0) as usize).min(room);
                    mask.extend(std::iter::repeat(value).take(take));
                }
            }
        }

        if let Some(last) = mask.last_mut() {
            *last = 1;
        }
        mask
    }
}

impl Display for ScmPresetBins {
    /// Formats the (possibly rescaled) mask as comma-separated `0`/`1` values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mask = self
            .maybe_scale()
            .generate_vec_mask()
            .iter()
            .map(|step| step.to_string())
            .collect::<Vec<_>>()
            .join(",");
        write!(f, "{mask}")
    }
}
335
336#[derive(Builder, Debug, Clone)]
338#[builder(
339 setter(into, strip_option),
340 build_fn(error = "ConfigBuilderError", validate = "Self::validate")
341)]
342pub struct ModelConfig {
343 #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
346 n_threads: i32,
347
348 #[builder(default = "false")]
350 enable_mmap: bool,
351
352 #[builder(default = "false")]
354 offload_params_to_cpu: bool,
355
356 #[builder(default = "Default::default()")]
358 upscale_model: Option<CLibPath>,
359
360 #[builder(default = "1")]
362 upscale_repeats: i32,
363
364 #[builder(default = "128")]
366 upscale_tile_size: i32,
367
368 #[builder(default = "Default::default()")]
370 model: CLibPath,
371
372 #[builder(default = "Default::default()")]
374 diffusion_model: CLibPath,
375
376 #[builder(default = "Default::default()")]
378 llm: CLibPath,
379
380 #[builder(default = "Default::default()")]
382 llm_vision: CLibPath,
383
384 #[builder(default = "Default::default()")]
386 clip_l: CLibPath,
387
388 #[builder(default = "Default::default()")]
390 clip_g: CLibPath,
391
392 #[builder(default = "Default::default()")]
394 clip_vision: CLibPath,
395
396 #[builder(default = "Default::default()")]
398 t5xxl: CLibPath,
399
400 #[builder(default = "Default::default()")]
402 vae: CLibPath,
403
404 #[builder(default = "Default::default()")]
406 taesd: CLibPath,
407
408 #[builder(default = "Default::default()")]
410 control_net: CLibPath,
411
412 #[builder(default = "Default::default()", setter(custom))]
414 embeddings: EmbeddingsStorage,
415
416 #[builder(default = "Default::default()")]
418 photo_maker: CLibPath,
419
420 #[builder(default = "Default::default()")]
422 pm_id_embed_path: CLibPath,
423
424 #[builder(default = "WeightType::SD_TYPE_COUNT")]
426 weight_type: WeightType,
427
428 #[builder(default = "Default::default()", setter(custom))]
430 lora_models: LoraStorage,
431
432 #[builder(default = "Default::default()")]
434 high_noise_diffusion_model: CLibPath,
435
436 #[builder(default = "false")]
438 vae_tiling: bool,
439
440 #[builder(default = "(32,32)")]
442 vae_tile_size: (i32, i32),
443
444 #[builder(default = "(0.,0.)")]
446 vae_relative_tile_size: (f32, f32),
447
448 #[builder(default = "0.5")]
450 vae_tile_overlap: f32,
451
452 #[builder(default = "RngFunction::CUDA_RNG")]
454 rng: RngFunction,
455
456 #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
458 sampler_rng_type: RngFunction,
459
460 #[builder(default = "Scheduler::SCHEDULER_COUNT")]
464 scheduler: Scheduler,
465
466 #[builder(default = "Default::default()")]
468 sigmas: Vec<f32>,
469
470 #[builder(default = "Prediction::PREDICTION_COUNT")]
472 prediction: Prediction,
473
474 #[builder(default = "false")]
476 vae_on_cpu: bool,
477
478 #[builder(default = "false")]
480 clip_on_cpu: bool,
481
482 #[builder(default = "false")]
484 control_net_cpu: bool,
485
486 #[builder(default = "false")]
489 flash_attention: bool,
490
491 #[builder(default = "false")]
493 chroma_disable_dit_mask: bool,
494
495 #[builder(default = "false")]
497 chroma_enable_t5_mask: bool,
498
499 #[builder(default = "1")]
501 chroma_t5_mask_pad: i32,
502
503 #[builder(default = "false")]
505 use_qwen_image_zero_cond_true: bool,
506
507 #[builder(default = "false")]
510 diffusion_conv_direct: bool,
511
512 #[builder(default = "false")]
515 vae_conv_direct: bool,
516
517 #[builder(default = "false")]
519 force_sdxl_vae_conv_scale: bool,
520
521 #[builder(default = "f32::INFINITY")]
523 flow_shift: f32,
524
525 #[builder(default = "0")]
527 timestep_shift: i32,
528
529 #[builder(default = "false")]
531 taesd_preview_only: bool,
532
533 #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
535 lora_apply_mode: LoraModeType,
536
537 #[builder(default = "false")]
539 circular: bool,
540
541 #[builder(default = "false")]
543 circular_x: bool,
544
545 #[builder(default = "false")]
547 circular_y: bool,
548
549 #[builder(default = "None", private)]
550 upscaler_ctx: Option<*mut upscaler_ctx_t>,
551
552 #[builder(default = "None", private)]
553 diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
554}
555
556impl ModelConfigBuilder {
557 fn validate(&self) -> Result<(), ConfigBuilderError> {
558 self.validate_model()
559 }
560
561 fn validate_model(&self) -> Result<(), ConfigBuilderError> {
562 self.model
563 .as_ref()
564 .or(self.diffusion_model.as_ref())
565 .map(|_| ())
566 .ok_or(ConfigBuilderError::UninitializedField(
567 "Model OR DiffusionModel must be valorized",
568 ))
569 }
570
571 fn filter_valid_extensions(&self, path: &Path) -> impl Iterator<Item = DirEntry> {
572 WalkDir::new(path)
573 .into_iter()
574 .filter_map(|entry| entry.ok())
575 .filter(|entry| {
576 entry
577 .path()
578 .extension()
579 .and_then(|ext| ext.to_str())
580 .map(|ext_str| VALID_EXT.contains(&ext_str))
581 .unwrap_or(false)
582 })
583 }
584
585 fn build_single_lora_storage(
586 spec: &LoraSpec,
587 is_high_noise: bool,
588 valid_loras: &HashMap<String, PathBuf>,
589 ) -> ((CLibPath, String, f32), sd_lora_t) {
590 let path = valid_loras.get(&spec.file_name).unwrap().as_path();
591 let c_path = CLibPath::from(path);
592 let lora = sd_lora_t {
593 is_high_noise,
594 multiplier: spec.multiplier,
595 path: c_path.as_ptr(),
596 };
597 let data = (c_path, spec.file_name.clone(), spec.multiplier);
598 (data, lora)
599 }
600
601 pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
602 let data: Vec<(CLibString, CLibPath)> = self
603 .filter_valid_extensions(embeddings_dir)
604 .map(|entry| {
605 let file_stem = entry
606 .path()
607 .file_stem()
608 .and_then(|stem| stem.to_str())
609 .unwrap_or_default()
610 .to_owned();
611 (CLibString::from(file_stem), CLibPath::from(entry.path()))
612 })
613 .collect();
614 let data_pointer = data
615 .iter()
616 .map(|(name, path)| sd_embedding_t {
617 name: name.as_ptr(),
618 path: path.as_ptr(),
619 })
620 .collect();
621 self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
622 self
623 }
624
625 pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
626 let valid_loras: HashMap<String, PathBuf> = self
627 .filter_valid_extensions(lora_model_dir)
628 .map(|entry| {
629 let path = entry.path();
630 (
631 path.file_stem()
632 .and_then(|stem| stem.to_str())
633 .unwrap_or_default()
634 .to_owned(),
635 path.to_path_buf(),
636 )
637 })
638 .collect();
639 let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
640 let standard = specs
641 .iter()
642 .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
643 .map(|s| Self::build_single_lora_storage(s, false, &valid_loras));
644 let high_noise = specs
645 .iter()
646 .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
647 .map(|s| Self::build_single_lora_storage(s, true, &valid_loras));
648
649 let mut data = Vec::new();
650 let mut loras_t = Vec::new();
651 for lora in standard.chain(high_noise) {
652 data.push(lora.0);
653 loras_t.push(lora.1);
654 }
655
656 self.lora_models = Some(LoraStorage {
657 lora_model_dir: lora_model_dir.into(),
658 data,
659 loras_t,
660 });
661 self
662 }
663
664 pub fn n_threads(&mut self, value: i32) -> &mut Self {
665 self.n_threads = if value > 0 {
666 Some(value)
667 } else {
668 Some(num_cpus::get_physical() as i32)
669 };
670 self
671 }
672}
673
674impl ModelConfig {
675 unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
676 unsafe {
677 if self.upscale_model.is_none() || self.upscale_repeats == 0 {
678 None
679 } else {
680 if self.upscaler_ctx.is_none() {
681 let upscaler = new_upscaler_ctx(
682 self.upscale_model.as_ref().unwrap().as_ptr(),
683 self.offload_params_to_cpu,
684 self.diffusion_conv_direct,
685 self.n_threads,
686 self.upscale_tile_size,
687 );
688 self.upscaler_ctx = Some(upscaler);
689 }
690 self.upscaler_ctx
691 }
692 }
693 }
694
695 unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
696 unsafe {
697 if self.diffusion_ctx.is_none() {
698 let sd_ctx_params = sd_ctx_params_t {
699 model_path: self.model.as_ptr(),
700 llm_path: self.llm.as_ptr(),
701 llm_vision_path: self.llm_vision.as_ptr(),
702 clip_l_path: self.clip_l.as_ptr(),
703 clip_g_path: self.clip_g.as_ptr(),
704 clip_vision_path: self.clip_vision.as_ptr(),
705 high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
706 t5xxl_path: self.t5xxl.as_ptr(),
707 diffusion_model_path: self.diffusion_model.as_ptr(),
708 vae_path: self.vae.as_ptr(),
709 taesd_path: self.taesd.as_ptr(),
710 control_net_path: self.control_net.as_ptr(),
711 embeddings: self.embeddings.2.as_ptr(),
712 embedding_count: self.embeddings.1.len() as u32,
713 photo_maker_path: self.photo_maker.as_ptr(),
714 vae_decode_only,
715 free_params_immediately: false,
716 n_threads: self.n_threads,
717 wtype: self.weight_type,
718 rng_type: self.rng,
719 keep_clip_on_cpu: self.clip_on_cpu,
720 keep_control_net_on_cpu: self.control_net_cpu,
721 keep_vae_on_cpu: self.vae_on_cpu,
722 diffusion_flash_attn: self.flash_attention,
723 diffusion_conv_direct: self.diffusion_conv_direct,
724 chroma_use_dit_mask: !self.chroma_disable_dit_mask,
725 chroma_use_t5_mask: self.chroma_enable_t5_mask,
726 chroma_t5_mask_pad: self.chroma_t5_mask_pad,
727 vae_conv_direct: self.vae_conv_direct,
728 offload_params_to_cpu: self.offload_params_to_cpu,
729 flow_shift: self.flow_shift,
730 prediction: self.prediction,
731 force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
732 tae_preview_only: self.taesd_preview_only,
733 lora_apply_mode: self.lora_apply_mode,
734 tensor_type_rules: null_mut(),
735 sampler_rng_type: self.sampler_rng_type,
736 circular_x: self.circular || self.circular_x,
737 circular_y: self.circular || self.circular_y,
738 qwen_image_zero_cond_t: self.use_qwen_image_zero_cond_true,
739 enable_mmap: self.enable_mmap,
740 };
741 let ctx = new_sd_ctx(&sd_ctx_params);
742 self.diffusion_ctx = Some((ctx, sd_ctx_params))
743 }
744 self.diffusion_ctx.unwrap().0
745 }
746 }
747}
748
749impl Drop for ModelConfig {
750 fn drop(&mut self) {
751 unsafe {
753 if let Some((sd_ctx, _)) = self.diffusion_ctx {
754 free_sd_ctx(sd_ctx);
755 }
756
757 if let Some(upscaler_ctx) = self.upscaler_ctx {
758 free_upscaler_ctx(upscaler_ctx);
759 }
760 }
761 }
762}
763
764impl From<ModelConfig> for ModelConfigBuilder {
765 fn from(value: ModelConfig) -> Self {
766 let mut builder = ModelConfigBuilder::default();
767 builder
768 .n_threads(value.n_threads)
769 .offload_params_to_cpu(value.offload_params_to_cpu)
770 .upscale_repeats(value.upscale_repeats)
771 .model(value.model.clone())
772 .diffusion_model(value.diffusion_model.clone())
773 .llm(value.llm.clone())
774 .llm_vision(value.llm_vision.clone())
775 .clip_l(value.clip_l.clone())
776 .clip_g(value.clip_g.clone())
777 .clip_vision(value.clip_vision.clone())
778 .t5xxl(value.t5xxl.clone())
779 .vae(value.vae.clone())
780 .taesd(value.taesd.clone())
781 .control_net(value.control_net.clone())
782 .embeddings(&value.embeddings.0)
783 .photo_maker(value.photo_maker.clone())
784 .pm_id_embed_path(value.pm_id_embed_path.clone())
785 .weight_type(value.weight_type)
786 .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
787 .vae_tiling(value.vae_tiling)
788 .vae_tile_size(value.vae_tile_size)
789 .vae_relative_tile_size(value.vae_relative_tile_size)
790 .vae_tile_overlap(value.vae_tile_overlap)
791 .rng(value.rng)
792 .sampler_rng_type(value.rng)
793 .scheduler(value.scheduler)
794 .sigmas(value.sigmas.clone())
795 .prediction(value.prediction)
796 .vae_on_cpu(value.vae_on_cpu)
797 .clip_on_cpu(value.clip_on_cpu)
798 .control_net(value.control_net.clone())
799 .control_net_cpu(value.control_net_cpu)
800 .flash_attention(value.flash_attention)
801 .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
802 .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
803 .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
804 .diffusion_conv_direct(value.diffusion_conv_direct)
805 .vae_conv_direct(value.vae_conv_direct)
806 .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
807 .flow_shift(value.flow_shift)
808 .timestep_shift(value.timestep_shift)
809 .taesd_preview_only(value.taesd_preview_only)
810 .lora_apply_mode(value.lora_apply_mode)
811 .circular(value.circular)
812 .circular_x(value.circular_x)
813 .circular_y(value.circular_y)
814 .use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true);
815
816 let lora_model_dir = Into::<PathBuf>::into(&value.lora_models.lora_model_dir);
817 let lora_specs = value
818 .lora_models
819 .data
820 .iter()
821 .map(|(_, name, multiplier)| LoraSpec {
822 file_name: name.clone(),
823 is_high_noise: false,
824 multiplier: *multiplier,
825 })
826 .collect();
827
828 builder.lora_models(&lora_model_dir, lora_specs);
829
830 if let Some(model) = &value.upscale_model {
831 builder.upscale_model(model.clone());
832 }
833 builder
834 }
835}
836
837#[derive(Builder, Debug, Clone)]
838#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
839pub struct Config {
841 #[builder(default = "Default::default()")]
843 pm_id_images_dir: CLibPath,
844
845 #[builder(default = "Default::default()")]
847 init_img: CLibPath,
848
849 #[builder(default = "Default::default()")]
851 control_image: CLibPath,
852
853 #[builder(default = "PathBuf::from(\"./output.png\")")]
855 output: PathBuf,
856
857 #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
859 preview_output: PathBuf,
860
861 #[builder(default = "PreviewType::PREVIEW_NONE")]
863 preview_mode: PreviewType,
864
865 #[builder(default = "false")]
867 preview_noisy: bool,
868
869 #[builder(default = "1")]
871 preview_interval: i32,
872
873 prompt: String,
875
876 #[builder(default = "\"\".into()")]
878 negative_prompt: CLibString,
879
880 #[builder(default = "7.0")]
882 cfg_scale: f32,
883
884 #[builder(default = "3.5")]
886 guidance: f32,
887
888 #[builder(default = "0.75")]
890 strength: f32,
891
892 #[builder(default = "20.0")]
894 pm_style_strength: f32,
895
896 #[builder(default = "0.9")]
899 control_strength: f32,
900
901 #[builder(default = "512")]
903 height: i32,
904
905 #[builder(default = "512")]
907 width: i32,
908
909 #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
913 sampling_method: SampleMethod,
914
915 #[builder(default = "0.")]
917 eta: f32,
918
919 #[builder(default = "20")]
921 steps: i32,
922
923 #[builder(default = "42")]
925 seed: i64,
926
927 #[builder(default = "1")]
929 batch_count: i32,
930
931 #[builder(default = "ClipSkip::Unspecified")]
934 clip_skip: ClipSkip,
935
936 #[builder(default = "false")]
938 canny: bool,
939
940 #[builder(default = "0.")]
943 slg_scale: f32,
944
945 #[builder(default = "vec![7, 8, 9]")]
947 skip_layer: Vec<i32>,
948
949 #[builder(default = "0.01")]
951 skip_layer_start: f32,
952
953 #[builder(default = "0.2")]
955 skip_layer_end: f32,
956
957 #[builder(default = "false")]
959 disable_auto_resize_ref_image: bool,
960
961 #[builder(default = "Self::cache_init()", private)]
962 cache: sd_cache_params_t,
963
964 #[builder(default = "CLibString::default()", private)]
965 scm_mask: CLibString,
966}
967
968impl ConfigBuilder {
969 fn validate(&self) -> Result<(), ConfigBuilderError> {
970 self.validate_output_dir()
971 }
972
973 fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
974 let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
975 let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
976 if is_dir == multiple_items {
977 Ok(())
978 } else {
979 Err(ConfigBuilderError::ValidationError(
980 "When batch_count > 1, output should point to folder and vice versa".to_owned(),
981 ))
982 }
983 }
984
985 fn cache_init() -> sd_cache_params_t {
986 sd_cache_params_t {
987 mode: sd_cache_mode_t::SD_CACHE_DISABLED,
988 reuse_threshold: 1.0,
989 start_percent: 0.15,
990 end_percent: 0.95,
991 error_decay_rate: 1.0,
992 use_relative_threshold: true,
993 reset_error_on_compute: true,
994 Fn_compute_blocks: 8,
995 Bn_compute_blocks: 0,
996 residual_diff_threshold: 0.08,
997 max_warmup_steps: 8,
998 max_cached_steps: -1,
999 max_continuous_cached_steps: -1,
1000 taylorseer_n_derivatives: 1,
1001 taylorseer_skip_interval: 1,
1002 scm_mask: null(),
1003 scm_policy_dynamic: true,
1004 }
1005 }
1006
1007 pub fn no_caching(&mut self) -> &mut Self {
1008 let mut cache = Self::cache_init();
1009 cache.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
1010 self.cache = Some(cache);
1011 self
1012 }
1013
1014 pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
1015 let mut cache = Self::cache_init();
1016 cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
1017 cache.reuse_threshold = params.threshold;
1018 cache.start_percent = params.start;
1019 cache.end_percent = params.end;
1020 cache.error_decay_rate = params.decay;
1021 cache.use_relative_threshold = params.relative;
1022 cache.reset_error_on_compute = params.reset;
1023 self.cache = Some(cache);
1024 self
1025 }
1026
1027 pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
1028 let mut cache = Self::cache_init();
1029 cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
1030 cache.reuse_threshold = params.threshold;
1031 cache.start_percent = params.start;
1032 cache.end_percent = params.end;
1033 self.cache = Some(cache);
1034 self
1035 }
1036
1037 pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
1038 let mut cache = Self::cache_init();
1039 cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
1040 cache.Fn_compute_blocks = params.fn_blocks;
1041 cache.Bn_compute_blocks = params.bn_blocks;
1042 cache.residual_diff_threshold = params.threshold;
1043 cache.max_warmup_steps = params.warmup;
1044 cache.scm_policy_dynamic = match params.scm_policy_dynamic {
1045 ScmPolicy::Static => false,
1046 ScmPolicy::Dynamic => true,
1047 };
1048 self.scm_mask = Some(CLibString::from(
1049 params
1050 .scm_mask
1051 .to_vec_string(self.steps.unwrap_or_default()),
1052 ));
1053 cache.scm_mask = self.scm_mask.as_ref().unwrap().as_ptr();
1054
1055 self.cache = Some(cache);
1056 self
1057 }
1058
1059 pub fn taylor_seer_caching(&mut self) -> &mut Self {
1060 let mut cache = Self::cache_init();
1061 cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
1062 self.cache = Some(cache);
1063 self
1064 }
1065
1066 pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
1067 self.db_cache_caching(params).cache.unwrap().mode = sd_cache_mode_t::SD_CACHE_CACHE_DIT;
1068 self
1069 }
1070}
1071
1072impl From<Config> for ConfigBuilder {
1073 fn from(value: Config) -> Self {
1074 let mut builder = ConfigBuilder::default();
1075 let mut cache = value.cache;
1076 let scm_mask = value.scm_mask.clone();
1077 cache.scm_mask = scm_mask.as_ptr();
1078 builder
1079 .pm_id_images_dir(value.pm_id_images_dir)
1080 .init_img(value.init_img)
1081 .control_image(value.control_image)
1082 .output(value.output)
1083 .prompt(value.prompt)
1084 .negative_prompt(value.negative_prompt)
1085 .cfg_scale(value.cfg_scale)
1086 .strength(value.strength)
1087 .pm_style_strength(value.pm_style_strength)
1088 .control_strength(value.control_strength)
1089 .height(value.height)
1090 .width(value.width)
1091 .sampling_method(value.sampling_method)
1092 .steps(value.steps)
1093 .seed(value.seed)
1094 .batch_count(value.batch_count)
1095 .clip_skip(value.clip_skip)
1096 .slg_scale(value.slg_scale)
1097 .skip_layer(value.skip_layer)
1098 .skip_layer_start(value.skip_layer_start)
1099 .skip_layer_end(value.skip_layer_end)
1100 .canny(value.canny)
1101 .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
1102 .preview_output(value.preview_output)
1103 .preview_mode(value.preview_mode)
1104 .preview_noisy(value.preview_noisy)
1105 .preview_interval(value.preview_interval)
1106 .cache(cache)
1107 .scm_mask(scm_mask);
1108 builder
1109 }
1110}
1111
/// Owned, NUL-terminated string whose pointer can be handed to the C library.
#[derive(Clone, Debug, Default)]
struct CLibString(CString);

impl CLibString {
    /// Wraps `raw` in a `CString`.
    ///
    /// Panics if `raw` contains an interior NUL byte.
    fn wrap<T: Into<Vec<u8>>>(raw: T) -> Self {
        Self(CString::new(raw).unwrap())
    }

    /// Pointer to the NUL-terminated bytes; valid for as long as `self` lives.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<&str> for CLibString {
    /// Panics if `value` contains an interior NUL byte.
    fn from(value: &str) -> Self {
        Self::wrap(value)
    }
}

impl From<String> for CLibString {
    /// Panics if `value` contains an interior NUL byte.
    fn from(value: String) -> Self {
        Self::wrap(value)
    }
}
1132
/// Owned, NUL-terminated path string whose pointer can be handed to the C library.
#[derive(Clone, Debug, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Pointer to the NUL-terminated bytes; valid for as long as `self` lives.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    /// Same conversion as `From<&Path>`.
    fn from(value: PathBuf) -> Self {
        Self::from(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    /// A path that is not valid UTF-8 becomes the empty string.
    ///
    /// Panics if the path contains an interior NUL byte.
    fn from(value: &Path) -> Self {
        let text = value.to_str().unwrap_or_default();
        Self(CString::new(text).unwrap())
    }
}

impl From<&CLibPath> for PathBuf {
    /// Converts the stored C string back into a path.
    ///
    /// Panics if the stored bytes are not valid UTF-8.
    fn from(value: &CLibPath) -> Self {
        let text = value.0.to_str().unwrap();
        PathBuf::from(text)
    }
}
1159
1160fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
1161 let date = Local::now().format("%Y.%m.%d-%H.%M.%S");
1162 if batch_size == 1 {
1163 vec![path.into()]
1164 } else {
1165 (1..=batch_size)
1166 .map(|id| path.join(format!("output_{date}_{id}.png")))
1167 .collect()
1168 }
1169}
1170
1171unsafe fn upscale(
1172 upscale_repeats: i32,
1173 upscaler_ctx: Option<*mut upscaler_ctx_t>,
1174 data: sd_image_t,
1175) -> Result<sd_image_t, DiffusionError> {
1176 unsafe {
1177 match upscaler_ctx {
1178 Some(upscaler_ctx) => {
1179 let upscale_factor = 4; let mut current_image = data;
1181 for _ in 0..upscale_repeats {
1182 let upscaled_image =
1183 diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
1184
1185 if upscaled_image.data.is_null() {
1186 return Err(DiffusionError::Upscaler);
1187 }
1188
1189 free(current_image.data as *mut c_void);
1190 current_image = upscaled_image;
1191 }
1192 Ok(current_image)
1193 }
1194 None => Ok(data),
1195 }
1196 }
1197}
1198
1199pub fn gen_img_with_progress(
1201 config: &Config,
1202 model_config: &mut ModelConfig,
1203 sender: Sender<Progress>,
1204) -> Result<(), DiffusionError> {
1205 gen_img_maybe_progress(config, model_config, Some(sender))
1206}
1207
1208pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
1210 gen_img_maybe_progress(config, model_config, None)
1211}
1212
1213fn gen_img_maybe_progress(
1214 config: &Config,
1215 model_config: &mut ModelConfig,
1216 mut sender: Option<Sender<Progress>>,
1217) -> Result<(), DiffusionError> {
1218 let prompt: CLibString = CLibString::from(config.prompt.as_str());
1219 let files = output_files(&config.output, config.batch_count);
1220 unsafe {
1221 let sd_ctx = model_config.diffusion_ctx(true);
1222 let upscaler_ctx = model_config.upscaler_ctx();
1223 let init_image = sd_image_t {
1224 width: 0,
1225 height: 0,
1226 channel: 3,
1227 data: null_mut(),
1228 };
1229 let mask_image = sd_image_t {
1230 width: config.width as u32,
1231 height: config.height as u32,
1232 channel: 1,
1233 data: null_mut(),
1234 };
1235 let mut layers = config.skip_layer.clone();
1236 let guidance = sd_guidance_params_t {
1237 txt_cfg: config.cfg_scale,
1238 img_cfg: config.cfg_scale,
1239 distilled_guidance: config.guidance,
1240 slg: sd_slg_params_t {
1241 layers: layers.as_mut_ptr(),
1242 layer_count: config.skip_layer.len(),
1243 layer_start: config.skip_layer_start,
1244 layer_end: config.skip_layer_end,
1245 scale: config.slg_scale,
1246 },
1247 };
1248 let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
1249 sd_get_default_scheduler(sd_ctx, config.sampling_method)
1250 } else {
1251 model_config.scheduler
1252 };
1253 let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
1254 sd_get_default_sample_method(sd_ctx)
1255 } else {
1256 config.sampling_method
1257 };
1258 let sample_params = sd_sample_params_t {
1259 guidance,
1260 sample_method,
1261 sample_steps: config.steps,
1262 eta: config.eta,
1263 scheduler,
1264 shifted_timestep: model_config.timestep_shift,
1265 custom_sigmas: model_config.sigmas.as_mut_ptr(),
1266 custom_sigmas_count: model_config.sigmas.len() as i32,
1267 };
1268 let control_image = sd_image_t {
1269 width: 0,
1270 height: 0,
1271 channel: 3,
1272 data: null_mut(),
1273 };
1274 let vae_tiling_params = sd_tiling_params_t {
1275 enabled: model_config.vae_tiling,
1276 tile_size_x: model_config.vae_tile_size.0,
1277 tile_size_y: model_config.vae_tile_size.1,
1278 target_overlap: model_config.vae_tile_overlap,
1279 rel_size_x: model_config.vae_relative_tile_size.0,
1280 rel_size_y: model_config.vae_relative_tile_size.1,
1281 };
1282 let pm_params = sd_pm_params_t {
1283 id_images: null_mut(),
1284 id_images_count: 0,
1285 id_embed_path: model_config.pm_id_embed_path.as_ptr(),
1286 style_strength: config.pm_style_strength,
1287 };
1288
1289 unsafe extern "C" fn save_preview_local(
1290 _step: ::std::os::raw::c_int,
1291 _frame_count: ::std::os::raw::c_int,
1292 frames: *mut sd_image_t,
1293 _is_noisy: bool,
1294 data: *mut ::std::os::raw::c_void,
1295 ) {
1296 unsafe {
1297 let path = &*data.cast::<PathBuf>();
1298 let _ = save_img(*frames, path, None);
1299 }
1300 }
1301
1302 if config.preview_mode != PreviewType::PREVIEW_NONE {
1303 let data = &config.preview_output as *const PathBuf;
1304
1305 sd_set_preview_callback(
1306 Some(save_preview_local),
1307 config.preview_mode,
1308 config.preview_interval,
1309 !config.preview_noisy,
1310 config.preview_noisy,
1311 data as *mut c_void,
1312 );
1313 }
1314
1315 if sender.is_some() {
1316 unsafe extern "C" fn progress_callback(
1317 step: ::std::os::raw::c_int,
1318 steps: ::std::os::raw::c_int,
1319 time: f32,
1320 data: *mut ::std::os::raw::c_void,
1321 ) {
1322 unsafe {
1323 let sender = &*data.cast::<Option<Sender<Progress>>>();
1324
1325 if let Some(sender) = sender {
1326 let _ = sender.send(Progress { step, steps, time });
1327 }
1328 }
1329 }
1330 let sender_ptr: *mut c_void = &mut sender as *mut _ as *mut c_void;
1331 sd_set_progress_callback(Some(progress_callback), sender_ptr);
1332 }
1333
1334 let sd_img_gen_params = sd_img_gen_params_t {
1335 prompt: prompt.as_ptr(),
1336 negative_prompt: config.negative_prompt.as_ptr(),
1337 clip_skip: config.clip_skip as i32,
1338 init_image,
1339 ref_images: null_mut(),
1340 ref_images_count: 0,
1341 increase_ref_index: false,
1342 mask_image,
1343 width: config.width,
1344 height: config.height,
1345 sample_params,
1346 strength: config.strength,
1347 seed: config.seed,
1348 batch_count: config.batch_count,
1349 control_image,
1350 control_strength: config.control_strength,
1351 pm_params,
1352 vae_tiling_params,
1353 auto_resize_ref_image: config.disable_auto_resize_ref_image,
1354 cache: config.cache,
1355 loras: model_config.lora_models.loras_t.as_ptr(),
1356 lora_count: model_config.lora_models.loras_t.len() as u32,
1357 };
1358
1359 let params_str = CString::from_raw(sd_img_gen_params_to_str(&sd_img_gen_params))
1360 .into_string()
1361 .unwrap();
1362
1363 let slice = generate_image(sd_ctx, &sd_img_gen_params);
1364 let ret = {
1365 if slice.is_null() {
1366 return Err(DiffusionError::Forward);
1367 }
1368 for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
1369 .iter()
1370 .zip(files)
1371 {
1372 match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
1373 Ok(img) => save_img(img, &path, Some(¶ms_str))?,
1374 Err(err) => {
1375 return Err(err);
1376 }
1377 }
1378 }
1379 Ok(())
1380 };
1381 free(slice as *mut c_void);
1382 ret
1383 }
1384}
1385
1386fn save_img(img: sd_image_t, path: &Path, params: Option<&str>) -> Result<(), DiffusionError> {
1387 let len = (img.width * img.height * img.channel) as usize;
1389 let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1390 let save_state = ImageBuffer::from_raw(img.width, img.height, buffer).map(|img| {
1391 RgbImage::from(img)
1392 .save(path)
1393 .map_err(DiffusionError::StoreImages)
1394 });
1395 if let Some(Err(err)) = save_state {
1396 return Err(err);
1397 }
1398 if let Some(params) = params {
1399 let mut metadata = Metadata::new();
1400 metadata.set_tag(ExifTag::ImageDescription(params.to_string()));
1401 metadata.write_to_file(path)?;
1402 }
1403 Ok(())
1404}
1405
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::api::ConfigBuilderError;
    use crate::api::ModelConfigBuilder;
    use crate::util::download_file_hf_hub;

    use super::ConfigBuilder;
    use super::gen_img;

    /// Exercises builder validation: mandatory fields and the batch-count/output check.
    #[test]
    fn test_required_args_txt2img() {
        // Both builders must refuse to build when nothing has been set.
        let empty_config = ConfigBuilder::default().build();
        assert!(empty_config.is_err());
        let empty_model = ModelConfigBuilder::default().build();
        assert!(empty_model.is_err());

        // A model path alone is enough for the model configuration.
        let model_only = ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build();
        model_only.unwrap();

        // A prompt alone is enough for the generation configuration.
        let prompt_only = ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build();
        prompt_only.unwrap();

        // A batch of 10 without an explicit output is rejected by validation.
        let large_batch = ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .batch_count(10)
            .build();
        assert!(matches!(
            large_batch,
            Err(ConfigBuilderError::ValidationError(_))
        ));

        // The prompt-only configuration still builds after the rejected attempt.
        let prompt_only_again = ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build();
        prompt_only_again.unwrap();

        // A batch of 2 is accepted once the output points at a directory.
        let batch_with_dir = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build();
        batch_with_dir.unwrap();
    }

    /// End-to-end generation with an upscaler; ignored by default because it downloads models.
    #[test]
    #[ignore]
    fn test_img_gen() {
        let checkpoint =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let upscaler = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();

        let first = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model = ModelConfigBuilder::default()
            .model(checkpoint)
            .upscale_model(upscaler)
            .upscale_repeats(1)
            .build()
            .unwrap();

        // First run: single image written to an explicit file.
        gen_img(&first, &mut model).unwrap();

        // Second run: same settings with a new prompt and file, reusing the model config.
        let second = ConfigBuilder::from(first.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&second, &mut model).unwrap();

        // Third run: a batch of two images written into a directory.
        let third = ConfigBuilder::from(first)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
        gen_img(&third, &mut model).unwrap();
    }
}