1use std::collections::HashMap;
2use std::ffi::CString;
3use std::ffi::c_char;
4use std::ffi::c_void;
5use std::path::Path;
6use std::path::PathBuf;
7use std::ptr::null;
8use std::ptr::null_mut;
9use std::slice;
10use std::str::FromStr;
11use std::sync::mpsc::Sender;
12
13use chrono::Local;
14use derive_builder::Builder;
15use diffusion_rs_sys::free_upscaler_ctx;
16use diffusion_rs_sys::generate_image;
17use diffusion_rs_sys::new_upscaler_ctx;
18use diffusion_rs_sys::sd_cache_mode_t;
19use diffusion_rs_sys::sd_cache_params_t;
20use diffusion_rs_sys::sd_ctx_params_t;
21use diffusion_rs_sys::sd_embedding_t;
22use diffusion_rs_sys::sd_get_default_sample_method;
23use diffusion_rs_sys::sd_get_default_scheduler;
24use diffusion_rs_sys::sd_guidance_params_t;
25use diffusion_rs_sys::sd_hires_params_t;
26use diffusion_rs_sys::sd_image_t;
27use diffusion_rs_sys::sd_img_gen_params_t;
28use diffusion_rs_sys::sd_img_gen_params_to_str;
29use diffusion_rs_sys::sd_lora_t;
30use diffusion_rs_sys::sd_pm_params_t;
31use diffusion_rs_sys::sd_sample_params_t;
32use diffusion_rs_sys::sd_set_preview_callback;
33use diffusion_rs_sys::sd_set_progress_callback;
34use diffusion_rs_sys::sd_slg_params_t;
35use diffusion_rs_sys::sd_tiling_params_t;
36use diffusion_rs_sys::upscaler_ctx_t;
37use image::ImageBuffer;
38use image::ImageError;
39use image::RgbImage;
40use libc::free;
41use little_exif::exif_tag::ExifTag;
42use little_exif::metadata::Metadata;
43use thiserror::Error;
44use walkdir::DirEntry;
45use walkdir::WalkDir;
46
47use diffusion_rs_sys::free_sd_ctx;
48use diffusion_rs_sys::new_sd_ctx;
49use diffusion_rs_sys::sd_ctx_t;
50
51pub use diffusion_rs_sys::rng_type_t as RngFunction;
53
54pub use diffusion_rs_sys::sample_method_t as SampleMethod;
56
57pub use diffusion_rs_sys::scheduler_t as Scheduler;
59
60pub use diffusion_rs_sys::prediction_t as Prediction;
62
63pub use diffusion_rs_sys::sd_type_t as WeightType;
65
66pub use diffusion_rs_sys::preview_t as PreviewType;
68
69pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
71
72pub use diffusion_rs_sys::sd_hires_upscaler_t as Upscaler;
74
75pub use diffusion_rs_sys::sd_vae_format_t as VaeFormat;
77
/// File extensions accepted as weight files when scanning LoRA and embedding directories.
const VALID_EXT: [&str; 3] = ["gguf", "safetensors", "pt"];
79
/// Progress report delivered through the `Sender<Progress>` given to `gen_img_with_progress`.
// The fields are private and not read directly in this file, hence `allow(unused)`;
// consumers observe them through the derived `Debug` output.
#[allow(unused)]
#[derive(Debug)]
pub struct Progress {
    /// Current step (presumably as reported by the stable-diffusion.cpp progress callback — confirm).
    step: i32,
    /// Total number of steps.
    steps: i32,
    /// Timing value reported with the step (units not visible here — presumably seconds; confirm).
    time: f32,
}
88
89#[non_exhaustive]
90#[derive(Error, Debug)]
91pub enum DiffusionError {
93 #[error("The underling stablediffusion.cpp function returned NULL")]
94 Forward,
95 #[error(transparent)]
96 StoreImages(#[from] ImageError),
97 #[error(transparent)]
98 Io(#[from] std::io::Error),
99 #[error("The underling upscaler model returned a NULL image")]
100 Upscaler,
101}
102
/// Device a [`Module`] can be assigned to via the `backend` / `params_backend`
/// setters of `ModelConfigBuilder`.
///
/// The lowercase variant name (from `strum::Display`) is the right-hand side of each
/// `module=device` pair in the comma-separated string handed to stable-diffusion.cpp,
/// so variant spelling is part of the contract.
#[non_exhaustive]
#[derive(Clone, Debug, strum::Display, strum::EnumIter)]
#[strum(serialize_all = "lowercase")]
pub enum BackendDevice {
    CPU,
    CUDA0,
    VULKAN0,
    METAL,
    GPU,
    AUTO,
    DISK,
    DEFAULT,
}
117
/// Model component that can be placed on a [`BackendDevice`].
///
/// Used as the key of the maps passed to `ModelConfigBuilder::backend` /
/// `params_backend`; the lowercase variant name (from `strum::Display`) is the
/// left-hand side of each `module=device` pair.
// NOTE(review): several variants look like alternative spellings of the same
// component (e.g. Te/Clip/Text/Textencoder, Vae/Firststage/Autoencoder); presumably
// the C side accepts each alias — confirm before removing any.
#[non_exhaustive]
#[derive(Clone, Debug, strum::Display, strum::EnumIter, PartialEq, Eq, Hash)]
#[strum(serialize_all = "lowercase")]
pub enum Module {
    Diffusion,
    Model,
    Unet,
    Dit,
    Te,
    Clip,
    Text,
    Textencoder,
    Textencoders,
    Conditioner,
    Cond,
    Llm,
    T5,
    T5xxl,
    ClipVision,
    Vision,
    Vae,
    Firststage,
    Autoencoder,
    Tae,
    Controlnet,
    Control,
    Photomaker,
    PhotomakerId,
    PmId,
    Photo,
    Upscaler,
    Esrgan,
    Hires,
}
153
/// How many final CLIP layers to skip when encoding the prompt.
///
/// The `repr(i32)` discriminant is the raw value: `None` = 1 (no layer skipped),
/// `OneLayer` = 2. `Unspecified` (0) presumably lets stable-diffusion.cpp choose a
/// model-specific default — confirm.
#[repr(i32)]
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
pub enum ClipSkip {
    #[default]
    Unspecified = 0,
    None = 1,
    OneLayer = 2,
}
165
/// Embeddings directory, the owned `(name, path)` C strings of every embedding file
/// found in it, and an `sd_embedding_t` table whose raw pointers reference those strings.
type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
/// Resolved LoRA file path paired with the spec that requested it.
type LoraStorage = Vec<(CLibPath, LoraSpec)>;
168
/// Request to apply one LoRA file; resolved by `ModelConfigBuilder::lora_models`.
#[derive(Default, Debug, Clone)]
pub struct LoraSpec {
    /// File stem (name without extension) of the LoRA inside the scanned directory.
    /// Specs whose stem is not found are ignored.
    pub file_name: String,
    /// Marks a LoRA for the high-noise model; these are stored after the standard ones.
    pub is_high_noise: bool,
    /// Strength multiplier (presumably forwarded to stable-diffusion.cpp — confirm).
    pub multiplier: f32,
}
176
/// Tuning knobs for the Spectrum cache, applied with `ConfigBuilder::spectrum_caching`.
///
/// Each field is copied into the matching `spectrum_*` field of `sd_cache_params_t`.
#[derive(Builder, Debug, Clone)]
pub struct SpectrumCacheParams {
    /// Copied to `spectrum_w`.
    #[builder(default = "0.40")]
    w: f32,

    /// Copied to `spectrum_m`.
    #[builder(default = "3")]
    m: i32,

    /// Copied to `spectrum_lam`.
    #[builder(default = "1.0")]
    lam: f32,

    /// Copied to `spectrum_window_size`.
    #[builder(default = "2")]
    window: i32,

    /// Copied to `spectrum_flex_window`.
    #[builder(default = "0.50")]
    flex: f32,

    /// Copied to `spectrum_warmup_steps`.
    #[builder(default = "4")]
    warmup: i32,

    /// Copied to `spectrum_stop_percent`.
    #[builder(default = "0.9")]
    stop: f32,
}
208
/// Tuning knobs for the UCache mode, applied with `ConfigBuilder::ucache_caching`.
#[derive(Builder, Debug, Clone)]
pub struct UCacheParams {
    /// Copied to `reuse_threshold`.
    #[builder(default = "1.0")]
    threshold: f32,

    /// Copied to `start_percent`.
    #[builder(default = "0.15")]
    start: f32,

    /// Copied to `end_percent`.
    #[builder(default = "0.95")]
    end: f32,

    /// Copied to `error_decay_rate`.
    #[builder(default = "1.0")]
    decay: f32,

    /// Copied to `use_relative_threshold`.
    #[builder(default = "true")]
    relative: bool,

    /// Copied to `reset_error_on_compute`.
    #[builder(default = "true")]
    reset: bool,
}
238
/// Tuning knobs for the EasyCache mode, applied with `ConfigBuilder::easy_cache_caching`.
#[derive(Builder, Debug, Clone)]
pub struct EasyCacheParams {
    /// Copied to `reuse_threshold`.
    #[builder(default = "0.2")]
    threshold: f32,

    /// Copied to `start_percent`.
    #[builder(default = "0.15")]
    start: f32,

    /// Copied to `end_percent`.
    #[builder(default = "0.95")]
    end: f32,
}
254
/// Tuning knobs shared by the DBCache and Cache-DiT modes
/// (`ConfigBuilder::db_cache_caching` / `cache_dit_caching`).
#[derive(Builder, Debug, Clone)]
pub struct DbCacheParams {
    /// Copied to `Fn_compute_blocks`.
    #[builder(default = "8")]
    fn_blocks: i32,

    /// Copied to `Bn_compute_blocks`.
    #[builder(default = "0")]
    bn_blocks: i32,

    /// Copied to `residual_diff_threshold`.
    #[builder(default = "0.08")]
    threshold: f32,

    /// Copied to `max_warmup_steps`.
    #[builder(default = "8")]
    warmup: i32,

    /// SCM mask string; kept as an owned C string next to the cache parameters.
    #[builder(default = "CLibString::default()")]
    scm_mask: CLibString,

    /// Selects the boolean written to `scm_policy_dynamic`.
    #[builder(default = "ScmPolicy::default()")]
    scm_policy_dynamic: ScmPolicy,
}
284
/// SCM policy for DB/DiT caching: `Static` maps to `scm_policy_dynamic = false`,
/// `Dynamic` (the default) to `true`.
#[derive(Debug, Default, Clone)]
pub enum ScmPolicy {
    Static,
    #[default]
    Dynamic,
}
294
/// High-resolution ("hires fix") pass settings, passed to `ModelConfigBuilder::hires_params`.
// NOTE(review): these values are consumed outside the visible code; the per-field
// notes below are inferred from names and should be confirmed.
#[derive(Builder, Debug, Clone)]
pub struct HiresParams {
    /// Target width (0 presumably means "derive from `scale`" — confirm).
    #[builder(default = "0")]
    width: i32,
    /// Target height (0 presumably means "derive from `scale`" — confirm).
    #[builder(default = "0")]
    height: i32,
    /// Steps for the hires pass (0 presumably reuses the main step count — confirm).
    #[builder(default = "0")]
    steps: i32,
    /// Tile size used while upscaling.
    #[builder(default = "128")]
    upscale_tile_size: i32,
    /// Upscale factor.
    #[builder(default = "2.0")]
    scale: f32,
    /// Denoising strength of the hires pass.
    #[builder(default = "0.7")]
    denoising_strength: f32,
    /// Optional explicit sigma schedule for the hires pass.
    #[builder(default = "None")]
    hires_sigmas: Option<Vec<f32>>,
}
320
/// Model-loading configuration plus the native contexts created from it.
///
/// Build it with `ModelConfigBuilder`. At least one of `model` / `diffusion_model`
/// must be set (checked by the builder's `validate`). The stable-diffusion.cpp and
/// upscaler contexts are created lazily and released in `Drop`.
#[derive(Builder, Debug)]
#[builder(
    setter(into, strip_option),
    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
)]
pub struct ModelConfig {
    /// Worker threads; the custom setter replaces values <= 0 with the physical core count.
    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
    n_threads: i32,

    /// Forwarded as `enable_mmap` in `sd_ctx_params_t`.
    #[builder(default = "false")]
    enable_mmap: bool,

    /// Forwarded as `max_vram` (meaning of `-1.0` not visible here — presumably "no limit"; confirm).
    #[builder(default = "-1.0")]
    max_vram: f32,

    /// Upscaler weights; without it (or with `upscale_repeats == 0`) no upscaler context is created.
    #[builder(default = "Default::default()")]
    upscale_model: Option<CLibPath>,

    /// Number of upscaling passes applied to each image.
    #[builder(default = "1")]
    upscale_repeats: i32,

    /// Tile size passed to `new_upscaler_ctx`.
    #[builder(default = "128")]
    upscale_tile_size: i32,

    /// Full checkpoint path (forwarded as `model_path`).
    #[builder(default = "Default::default()")]
    model: CLibPath,

    /// Standalone diffusion model path (forwarded as `diffusion_model_path`).
    #[builder(default = "Default::default()")]
    diffusion_model: CLibPath,

    /// Forwarded as `uncond_diffusion_model_path`.
    #[builder(default = "Default::default()")]
    unconditional_diffusion_model: CLibPath,

    /// Forwarded as `llm_path`.
    #[builder(default = "Default::default()")]
    llm: CLibPath,

    /// Forwarded as `llm_vision_path`.
    #[builder(default = "Default::default()")]
    llm_vision: CLibPath,

    /// Forwarded as `clip_l_path`.
    #[builder(default = "Default::default()")]
    clip_l: CLibPath,

    /// Forwarded as `clip_g_path`.
    #[builder(default = "Default::default()")]
    clip_g: CLibPath,

    /// Forwarded as `clip_vision_path`.
    #[builder(default = "Default::default()")]
    clip_vision: CLibPath,

    /// Forwarded as `t5xxl_path`.
    #[builder(default = "Default::default()")]
    t5xxl: CLibPath,

    /// Forwarded as `vae_path`.
    #[builder(default = "Default::default()")]
    vae: CLibPath,

    /// Forwarded as `vae_format`.
    #[builder(default = "VaeFormat::SD_VAE_FORMAT_AUTO")]
    vae_format: VaeFormat,

    /// Forwarded as `taesd_path`.
    #[builder(default = "Default::default()")]
    taesd: CLibPath,

    /// Forwarded as `control_net_path`.
    #[builder(default = "Default::default()")]
    control_net: CLibPath,

    /// Set via the custom `embeddings(dir)` setter, which scans `dir` for weight files.
    #[builder(default = "Default::default()", setter(custom))]
    embeddings: EmbeddingsStorage,

    /// Forwarded as `photo_maker_path`.
    #[builder(default = "Default::default()")]
    photo_maker: CLibPath,

    /// PhotoMaker ID-embedding path (used outside the visible code).
    #[builder(default = "Default::default()")]
    pm_id_embed_path: CLibPath,

    /// Forwarded as `wtype` (`SD_TYPE_COUNT` presumably means "keep the file's types" — confirm).
    #[builder(default = "WeightType::SD_TYPE_COUNT")]
    weight_type: WeightType,

    /// Set via the custom `lora_models(dir, specs)` setter.
    #[builder(default = "Default::default()", setter(custom))]
    lora_models: LoraStorage,

    /// Forwarded as `high_noise_diffusion_model_path`.
    #[builder(default = "Default::default()")]
    high_noise_diffusion_model: CLibPath,

    /// Enables VAE tiling (used outside the visible code).
    #[builder(default = "false")]
    vae_tiling: bool,

    /// VAE tile size (x, y).
    #[builder(default = "(32,32)")]
    vae_tile_size: (i32, i32),

    /// VAE tile size relative to the image (x, y).
    #[builder(default = "(0.,0.)")]
    vae_relative_tile_size: (f32, f32),

    /// Overlap between VAE tiles.
    #[builder(default = "0.5")]
    vae_tile_overlap: f32,

    /// Forwarded as `rng_type`.
    #[builder(default = "RngFunction::CUDA_RNG")]
    rng: RngFunction,

    /// Forwarded as `sampler_rng_type` (`RNG_TYPE_COUNT` presumably means "same as `rng`" — confirm).
    #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
    sampler_rng_type: RngFunction,

    /// Scheduler; `SCHEDULER_COUNT` is resolved with `sd_get_default_scheduler` at generation time.
    #[builder(default = "Scheduler::SCHEDULER_COUNT")]
    scheduler: Scheduler,

    /// Custom sigma schedule (used outside the visible code).
    #[builder(default = "Default::default()")]
    sigmas: Vec<f32>,

    /// Forwarded as `prediction`.
    #[builder(default = "Prediction::PREDICTION_COUNT")]
    prediction: Prediction,

    /// Forwarded as `diffusion_flash_attn`.
    #[builder(default = "false")]
    diffusion_flash_attention: bool,

    /// Forwarded as `flash_attn`.
    #[builder(default = "false")]
    flash_attention: bool,

    /// Negated into `chroma_use_dit_mask`.
    #[builder(default = "false")]
    chroma_disable_dit_mask: bool,

    /// Forwarded as `chroma_use_t5_mask`.
    #[builder(default = "false")]
    chroma_enable_t5_mask: bool,

    /// Forwarded as `chroma_t5_mask_pad`.
    #[builder(default = "1")]
    chroma_t5_mask_pad: i32,

    /// Forwarded as `qwen_image_zero_cond_t`.
    #[builder(default = "false")]
    use_qwen_image_zero_cond_true: bool,

    /// Forwarded as `diffusion_conv_direct` to both the diffusion and upscaler contexts.
    #[builder(default = "false")]
    diffusion_conv_direct: bool,

    /// Forwarded as `vae_conv_direct`.
    #[builder(default = "false")]
    vae_conv_direct: bool,

    /// Forwarded as `force_sdxl_vae_conv_scale`.
    #[builder(default = "false")]
    force_sdxl_vae_conv_scale: bool,

    /// Flow shift (`INFINITY` presumably means "model default" — confirm).
    #[builder(default = "f32::INFINITY")]
    flow_shift: f32,

    /// Forwarded as `shifted_timestep` in the sample parameters.
    #[builder(default = "0")]
    timestep_shift: i32,

    /// Forwarded as `tae_preview_only`.
    #[builder(default = "false")]
    taesd_preview_only: bool,

    /// Forwarded as `lora_apply_mode`.
    #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
    lora_apply_mode: LoraModeType,

    /// Shorthand that enables both `circular_x` and `circular_y`.
    #[builder(default = "false")]
    circular: bool,

    /// Forwarded (OR-ed with `circular`) as `circular_x`.
    #[builder(default = "false")]
    circular_x: bool,

    /// Forwarded (OR-ed with `circular`) as `circular_y`.
    #[builder(default = "false")]
    circular_y: bool,

    /// Hires upscaler kind, its parameters and optional custom model (custom setter).
    #[builder(default = "Self::hires_init()", setter(custom))]
    hires_params: (Upscaler, HiresParams, Option<CLibPath>),

    /// Extra sampler parameters string (used outside the visible code).
    #[builder(default = "CLibString::default()")]
    extra_sample_params: CLibString,

    /// Module→device map and its rendered `module=device,...` C string.
    #[builder(default = "(None, CLibString::default())", setter(custom))]
    backend: (Option<HashMap<Module, BackendDevice>>, CLibString),

    /// Module→device map for parameters and its rendered C string.
    #[builder(default = "(None, CLibString::default())", setter(custom))]
    params_backend: (Option<HashMap<Module, BackendDevice>>, CLibString),

    /// Forwarded as `embeddings_connectors_path`.
    #[builder(default = "CLibPath::default()")]
    embeddings_connectors: CLibPath,

    /// Forwarded as `audio_vae_path`.
    #[builder(default = "CLibPath::default()")]
    audio_vae: CLibPath,

    /// Temporal VAE tiling toggle (used outside the visible code).
    #[builder(default = "true")]
    vae_temporal_tiling: bool,

    /// Extra tiling `key=value` map and its rendered C string.
    #[builder(default = "(None, CLibString::default())", setter(custom))]
    extra_tiling_args: (Option<HashMap<String, String>>, CLibString),

    /// Forwarded as `stream_layers`.
    #[builder(default = "false")]
    stream_layers: bool,

    /// Lazily created upscaler context; freed on drop.
    #[builder(default = "None", private)]
    upscaler_ctx: Option<*mut upscaler_ctx_t>,

    /// Diffusion context and the parameters it was created with; freed on drop.
    #[builder(default = "None", private)]
    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
}
577
578impl ModelConfigBuilder {
579 fn validate(&self) -> Result<(), ConfigBuilderError> {
580 self.validate_model()
581 }
582
583 fn validate_model(&self) -> Result<(), ConfigBuilderError> {
584 self.model
585 .as_ref()
586 .or(self.diffusion_model.as_ref())
587 .map(|_| ())
588 .ok_or(ConfigBuilderError::UninitializedField(
589 "Model OR DiffusionModel must be valorized",
590 ))
591 }
592
593 fn filter_valid_extensions(path: &Path) -> impl Iterator<Item = DirEntry> {
594 WalkDir::new(path)
595 .into_iter()
596 .filter_map(|entry| entry.ok())
597 .filter(|entry| {
598 entry
599 .path()
600 .extension()
601 .and_then(|ext| ext.to_str())
602 .map(|ext_str| VALID_EXT.contains(&ext_str))
603 .unwrap_or(false)
604 })
605 }
606 fn build_single_lora_storage(
607 spec: &LoraSpec,
608 valid_loras: &HashMap<String, PathBuf>,
609 ) -> (CLibPath, LoraSpec) {
610 let path = valid_loras.get(&spec.file_name).unwrap().as_path();
611 let c_path = CLibPath::from(path);
612 (c_path, spec.clone())
613 }
614
615 pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
616 let data: Vec<(CLibString, CLibPath)> = Self::filter_valid_extensions(embeddings_dir)
617 .map(|entry| {
618 let file_stem = entry
619 .path()
620 .file_stem()
621 .and_then(|stem| stem.to_str())
622 .unwrap_or_default()
623 .to_owned();
624 (CLibString::from(file_stem), CLibPath::from(entry.path()))
625 })
626 .collect();
627 let data_pointer = data
628 .iter()
629 .map(|(name, path)| sd_embedding_t {
630 name: name.as_ptr(),
631 path: path.as_ptr(),
632 })
633 .collect();
634 self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
635 self
636 }
637
638 pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
639 let valid_loras: HashMap<String, PathBuf> = Self::filter_valid_extensions(lora_model_dir)
640 .map(|entry| {
641 let path = entry.path();
642 (
643 path.file_stem()
644 .and_then(|stem| stem.to_str())
645 .unwrap_or_default()
646 .to_owned(),
647 path.to_path_buf(),
648 )
649 })
650 .collect();
651 let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
652 let standard = specs
653 .iter()
654 .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
655 .map(|s| Self::build_single_lora_storage(s, &valid_loras));
656 let high_noise = specs
657 .iter()
658 .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
659 .map(|s| Self::build_single_lora_storage(s, &valid_loras));
660
661 self.lora_models_internal(standard.chain(high_noise).collect())
662 }
663
664 fn lora_models_internal(&mut self, lora_storage: LoraStorage) -> &mut Self {
665 self.lora_models = Some(lora_storage);
666 self
667 }
668
669 pub fn n_threads(&mut self, value: i32) -> &mut Self {
670 self.n_threads = if value > 0 {
671 Some(value)
672 } else {
673 Some(num_cpus::get_physical() as i32)
674 };
675 self
676 }
677
678 pub fn hires_params(
679 &mut self,
680 upscaler: Upscaler,
681 params: HiresParams,
682 custom_model: Option<&Path>,
683 ) -> &mut Self {
684 if upscaler == Upscaler::SD_HIRES_UPSCALER_COUNT
685 || (upscaler == Upscaler::SD_HIRES_UPSCALER_MODEL && custom_model.is_none())
686 {
687 panic!("Invalid combination for {upscaler:?} and {custom_model:?}")
688 }
689 self.hires_params = Some((upscaler, params, custom_model.map(Into::into)));
690
691 self
692 }
693
694 fn hires_init() -> (Upscaler, HiresParams, Option<CLibPath>) {
695 (
696 Upscaler::SD_HIRES_UPSCALER_NONE,
697 HiresParamsBuilder::default().build().unwrap(),
698 None,
699 )
700 }
701
702 pub fn backend(&mut self, backend_map: HashMap<Module, BackendDevice>) -> &mut Self {
703 let backend_str = backend_map
704 .iter()
705 .map(|(key, value)| format!("{}={}", key, value))
706 .collect::<Vec<String>>()
707 .join(",");
708 self.backend = Some((Some(backend_map), CLibString::from(backend_str)));
709 self
710 }
711
712 pub fn params_backend(&mut self, backend_map: HashMap<Module, BackendDevice>) -> &mut Self {
713 let params_backend_str = backend_map
714 .iter()
715 .map(|(key, value)| format!("{}={}", key, value))
716 .collect::<Vec<String>>()
717 .join(",");
718 self.params_backend = Some((Some(backend_map), CLibString::from(params_backend_str)));
719 self
720 }
721
722 pub fn extra_tiling_args(
723 &mut self,
724 extra_tiling_args_map: HashMap<String, String>,
725 ) -> &mut Self {
726 let extra_tiling_args_str = extra_tiling_args_map
727 .iter()
728 .map(|(key, value)| format!("{}={}", key, value))
729 .collect::<Vec<String>>()
730 .join(",");
731 self.extra_tiling_args = Some((
732 Some(extra_tiling_args_map),
733 CLibString::from(extra_tiling_args_str),
734 ));
735 self
736 }
737}
738
739impl ModelConfig {
740 unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
741 unsafe {
742 if self.upscale_model.is_none() || self.upscale_repeats == 0 {
743 None
744 } else {
745 if self.upscaler_ctx.is_none() {
746 let upscaler = new_upscaler_ctx(
747 self.upscale_model.as_ref().unwrap().as_ptr(),
748 self.diffusion_conv_direct,
749 self.n_threads,
750 self.upscale_tile_size,
751 self.backend.1.as_ptr(),
752 self.params_backend.1.as_ptr(),
753 );
754 self.upscaler_ctx = Some(upscaler);
755 }
756 self.upscaler_ctx
757 }
758 }
759 }
760
761 unsafe fn diffusion_ctx(&mut self) -> *mut sd_ctx_t {
762 unsafe {
763 if let Some((sd_ctx, _)) = self.diffusion_ctx.as_ref() {
767 sd_set_progress_callback(None, null_mut());
768 free_sd_ctx(*sd_ctx);
769 self.diffusion_ctx = None;
770 }
771 if self.diffusion_ctx.is_none() {
772 let sd_ctx_params = sd_ctx_params_t {
773 model_path: self.model.as_ptr(),
774 llm_path: self.llm.as_ptr(),
775 llm_vision_path: self.llm_vision.as_ptr(),
776 clip_l_path: self.clip_l.as_ptr(),
777 clip_g_path: self.clip_g.as_ptr(),
778 clip_vision_path: self.clip_vision.as_ptr(),
779 high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
780 t5xxl_path: self.t5xxl.as_ptr(),
781 diffusion_model_path: self.diffusion_model.as_ptr(),
782 uncond_diffusion_model_path: self.unconditional_diffusion_model.as_ptr(),
783 vae_path: self.vae.as_ptr(),
784 taesd_path: self.taesd.as_ptr(),
785 control_net_path: self.control_net.as_ptr(),
786 embeddings: self.embeddings.2.as_ptr(),
787 embedding_count: self.embeddings.1.len() as u32,
788 photo_maker_path: self.photo_maker.as_ptr(),
789 n_threads: self.n_threads,
790 wtype: self.weight_type,
791 rng_type: self.rng,
792 diffusion_flash_attn: self.diffusion_flash_attention,
793 flash_attn: self.flash_attention,
794 diffusion_conv_direct: self.diffusion_conv_direct,
795 chroma_use_dit_mask: !self.chroma_disable_dit_mask,
796 chroma_use_t5_mask: self.chroma_enable_t5_mask,
797 chroma_t5_mask_pad: self.chroma_t5_mask_pad,
798 vae_conv_direct: self.vae_conv_direct,
799 prediction: self.prediction,
800 force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
801 tae_preview_only: self.taesd_preview_only,
802 lora_apply_mode: self.lora_apply_mode,
803 tensor_type_rules: null_mut(),
804 sampler_rng_type: self.sampler_rng_type,
805 circular_x: self.circular || self.circular_x,
806 circular_y: self.circular || self.circular_y,
807 qwen_image_zero_cond_t: self.use_qwen_image_zero_cond_true,
808 enable_mmap: self.enable_mmap,
809 max_vram: self.max_vram,
810 backend: self.backend.1.as_ptr(),
811 params_backend: self.params_backend.1.as_ptr(),
812 embeddings_connectors_path: self.embeddings_connectors.as_ptr(),
813 audio_vae_path: self.audio_vae.as_ptr(),
814 vae_format: self.vae_format,
815 stream_layers: self.stream_layers,
816 };
817 let ctx = new_sd_ctx(&sd_ctx_params);
818 self.diffusion_ctx = Some((ctx, sd_ctx_params))
819 }
820 self.diffusion_ctx.unwrap().0
821 }
822 }
823}
824
825impl Drop for ModelConfig {
826 fn drop(&mut self) {
827 unsafe {
829 if let Some((sd_ctx, _)) = self.diffusion_ctx {
830 free_sd_ctx(sd_ctx);
831 }
832
833 if let Some(upscaler_ctx) = self.upscaler_ctx {
834 free_upscaler_ctx(upscaler_ctx);
835 }
836 }
837 }
838}
839
840impl From<&ModelConfig> for ModelConfigBuilder {
841 fn from(value: &ModelConfig) -> Self {
842 let mut builder = ModelConfigBuilder::default();
843 let hires_path = value
844 .hires_params
845 .2
846 .clone()
847 .map(|f| PathBuf::from_str(&f.0.into_string().unwrap()).unwrap());
848 builder
849 .n_threads(value.n_threads)
850 .max_vram(value.max_vram)
851 .upscale_repeats(value.upscale_repeats)
852 .model(value.model.clone())
853 .diffusion_model(value.diffusion_model.clone())
854 .unconditional_diffusion_model(value.unconditional_diffusion_model.clone())
855 .llm(value.llm.clone())
856 .llm_vision(value.llm_vision.clone())
857 .clip_l(value.clip_l.clone())
858 .clip_g(value.clip_g.clone())
859 .clip_vision(value.clip_vision.clone())
860 .t5xxl(value.t5xxl.clone())
861 .vae(value.vae.clone())
862 .taesd(value.taesd.clone())
863 .control_net(value.control_net.clone())
864 .embeddings(&value.embeddings.0)
865 .photo_maker(value.photo_maker.clone())
866 .pm_id_embed_path(value.pm_id_embed_path.clone())
867 .weight_type(value.weight_type)
868 .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
869 .vae_tiling(value.vae_tiling)
870 .vae_tile_size(value.vae_tile_size)
871 .vae_relative_tile_size(value.vae_relative_tile_size)
872 .vae_tile_overlap(value.vae_tile_overlap)
873 .rng(value.rng)
874 .sampler_rng_type(value.rng)
875 .scheduler(value.scheduler)
876 .sigmas(value.sigmas.clone())
877 .prediction(value.prediction)
878 .control_net(value.control_net.clone())
879 .flash_attention(value.flash_attention)
880 .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
881 .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
882 .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
883 .diffusion_conv_direct(value.diffusion_conv_direct)
884 .vae_conv_direct(value.vae_conv_direct)
885 .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
886 .flow_shift(value.flow_shift)
887 .timestep_shift(value.timestep_shift)
888 .taesd_preview_only(value.taesd_preview_only)
889 .lora_apply_mode(value.lora_apply_mode)
890 .circular(value.circular)
891 .circular_x(value.circular_x)
892 .circular_y(value.circular_y)
893 .use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true)
894 .hires_params(
895 value.hires_params.0,
896 value.hires_params.1.clone(),
897 hires_path.as_deref(),
898 )
899 .extra_sample_params(value.extra_sample_params.clone())
900 .backend(value.backend.0.clone().unwrap_or_default())
901 .params_backend(value.params_backend.0.clone().unwrap_or_default())
902 .extra_tiling_args(value.extra_tiling_args.0.clone().unwrap_or_default());
903
904 builder.lora_models_internal(value.lora_models.clone());
905
906 if let Some(model) = &value.upscale_model {
907 builder.upscale_model(model.clone());
908 }
909 builder
910 }
911}
912
/// Per-generation settings (prompt, sampling, output paths, caching).
///
/// Built with `ConfigBuilder`. The builder's `validate` requires `output` to be an
/// existing directory exactly when `batch_count > 1`.
#[derive(Builder, Debug)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
pub struct Config {
    /// PhotoMaker ID-images directory.
    #[builder(default = "Default::default()")]
    pm_id_images_dir: CLibPath,

    /// Init image for img2img; only used when the path exists.
    #[builder(default = "Default::default()")]
    init_img: PathBuf,

    /// Mask image; only used when the path exists.
    #[builder(default = "Default::default()")]
    mask_img: PathBuf,

    /// ControlNet conditioning image.
    #[builder(default = "Default::default()")]
    control_image: CLibPath,

    /// Reference images.
    #[builder(default = "Default::default()")]
    ref_images: Vec<PathBuf>,

    /// Output file, or directory when `batch_count > 1`.
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,

    /// Where preview images are written.
    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
    preview_output: PathBuf,

    /// Preview mode; previews are off by default.
    #[builder(default = "PreviewType::PREVIEW_NONE")]
    preview_mode: PreviewType,

    /// Whether previews show the noisy latent.
    #[builder(default = "false")]
    preview_noisy: bool,

    /// Steps between previews.
    #[builder(default = "1")]
    preview_interval: i32,

    /// Prompt text (required).
    prompt: String,

    /// Negative prompt; empty by default.
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,

    /// CFG scale; used for both text and image CFG.
    #[builder(default = "7.0")]
    cfg_scale: f32,

    /// Distilled guidance value.
    #[builder(default = "3.5")]
    guidance: f32,

    /// img2img strength.
    #[builder(default = "0.75")]
    strength: f32,

    /// PhotoMaker style strength.
    #[builder(default = "20.0")]
    pm_style_strength: f32,

    /// ControlNet strength.
    #[builder(default = "0.9")]
    control_strength: f32,

    /// Image height in pixels.
    #[builder(default = "512")]
    height: i32,

    /// Image width in pixels.
    #[builder(default = "512")]
    width: i32,

    /// Sampler; `SAMPLE_METHOD_COUNT` is resolved with `sd_get_default_sample_method`.
    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
    sampling_method: SampleMethod,

    /// Sampler eta.
    #[builder(default = "0.")]
    eta: f32,

    /// Number of sampling steps.
    #[builder(default = "20")]
    steps: i32,

    /// RNG seed (`-1` presumably means random — confirm).
    #[builder(default = "-1")]
    seed: i64,

    /// Number of images to generate.
    #[builder(default = "1")]
    batch_count: i32,

    /// CLIP skip setting.
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,

    /// Apply canny preprocessing to the control image (presumably — confirm).
    #[builder(default = "false")]
    canny: bool,

    /// Skip-layer guidance scale.
    #[builder(default = "0.")]
    slg_scale: f32,

    /// Layers used by skip-layer guidance.
    #[builder(default = "vec![7, 8, 9]")]
    skip_layer: Vec<i32>,

    /// Skip-layer guidance start.
    #[builder(default = "0.01")]
    skip_layer_start: f32,

    /// Skip-layer guidance end.
    #[builder(default = "0.2")]
    skip_layer_end: f32,

    /// Disable automatic resizing of reference images.
    #[builder(default = "false")]
    disable_auto_resize_ref_image: bool,

    /// Cache parameters plus the owned SCM mask string, set via the `*_caching` methods.
    #[builder(default = "Self::cache_init()", private)]
    cache: (sd_cache_params_t, Option<CLibString>),
}
1048
1049impl ConfigBuilder {
1050 fn validate(&self) -> Result<(), ConfigBuilderError> {
1051 self.validate_output_dir()
1052 }
1053
1054 fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
1055 let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
1056 let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
1057 if is_dir == multiple_items {
1058 Ok(())
1059 } else {
1060 Err(ConfigBuilderError::ValidationError(
1061 "When batch_count > 1, output should point to folder and vice versa".to_owned(),
1062 ))
1063 }
1064 }
1065
1066 fn cache_init() -> (sd_cache_params_t, Option<CLibString>) {
1067 (
1068 sd_cache_params_t {
1069 mode: sd_cache_mode_t::SD_CACHE_DISABLED,
1070 reuse_threshold: 1.0,
1071 start_percent: 0.15,
1072 end_percent: 0.95,
1073 error_decay_rate: 1.0,
1074 use_relative_threshold: true,
1075 reset_error_on_compute: true,
1076 Fn_compute_blocks: 8,
1077 Bn_compute_blocks: 0,
1078 residual_diff_threshold: 0.08,
1079 max_warmup_steps: 8,
1080 max_cached_steps: -1,
1081 max_continuous_cached_steps: -1,
1082 taylorseer_n_derivatives: 1,
1083 taylorseer_skip_interval: 1,
1084 scm_mask: null(),
1085 scm_policy_dynamic: true,
1086 spectrum_w: 0.4,
1087 spectrum_m: 3,
1088 spectrum_lam: 1.0,
1089 spectrum_window_size: 2,
1090 spectrum_flex_window: 0.5,
1091 spectrum_warmup_steps: 4,
1092 spectrum_stop_percent: 0.9,
1093 },
1094 None,
1095 )
1096 }
1097
1098 pub fn no_caching(&mut self) -> &mut Self {
1099 let mut cache = Self::cache_init();
1100 cache.0.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
1101 self.cache = Some(cache);
1102 self
1103 }
1104
1105 pub fn spectrum_caching(&mut self, params: SpectrumCacheParams) -> &mut Self {
1106 let (mut cache, mask) = Self::cache_init();
1107 cache.mode = sd_cache_mode_t::SD_CACHE_SPECTRUM;
1108 cache.spectrum_w = params.w;
1109 cache.spectrum_m = params.m;
1110 cache.spectrum_lam = params.lam;
1111 cache.spectrum_window_size = params.window;
1112 cache.spectrum_flex_window = params.flex;
1113 cache.spectrum_warmup_steps = params.warmup;
1114 cache.spectrum_stop_percent = params.stop;
1115 self.cache = Some((cache, mask));
1116 self
1117 }
1118
1119 pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
1120 let (mut cache, mask) = Self::cache_init();
1121 cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
1122 cache.reuse_threshold = params.threshold;
1123 cache.start_percent = params.start;
1124 cache.end_percent = params.end;
1125 cache.error_decay_rate = params.decay;
1126 cache.use_relative_threshold = params.relative;
1127 cache.reset_error_on_compute = params.reset;
1128 self.cache = Some((cache, mask));
1129 self
1130 }
1131
1132 pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
1133 let (mut cache, mask) = Self::cache_init();
1134 cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
1135 cache.reuse_threshold = params.threshold;
1136 cache.start_percent = params.start;
1137 cache.end_percent = params.end;
1138 self.cache = Some((cache, mask));
1139 self
1140 }
1141
1142 pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
1143 let (mut cache, _) = Self::cache_init();
1144 cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
1145 cache.Fn_compute_blocks = params.fn_blocks;
1146 cache.Bn_compute_blocks = params.bn_blocks;
1147 cache.residual_diff_threshold = params.threshold;
1148 cache.max_warmup_steps = params.warmup;
1149 cache.scm_policy_dynamic = match params.scm_policy_dynamic {
1150 ScmPolicy::Static => false,
1151 ScmPolicy::Dynamic => true,
1152 };
1153 self.cache = Some((cache, Some(params.scm_mask)));
1154 self
1155 }
1156
1157 pub fn taylor_seer_caching(&mut self) -> &mut Self {
1158 let (mut cache, mask) = Self::cache_init();
1159 cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
1160 self.cache = Some((cache, mask));
1161 self
1162 }
1163
1164 pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
1165 self.db_cache_caching(params).cache.as_mut().unwrap().0.mode =
1166 sd_cache_mode_t::SD_CACHE_CACHE_DIT;
1167 self
1168 }
1169}
1170
1171impl From<&Config> for ConfigBuilder {
1172 fn from(value: &Config) -> Self {
1173 let mut builder = ConfigBuilder::default();
1174 builder
1175 .pm_id_images_dir(value.pm_id_images_dir.clone())
1176 .init_img(value.init_img.clone())
1177 .mask_img(value.mask_img.clone())
1178 .control_image(value.control_image.clone())
1179 .ref_images(value.ref_images.clone())
1180 .output(value.output.clone())
1181 .prompt(value.prompt.clone())
1182 .negative_prompt(value.negative_prompt.clone())
1183 .cfg_scale(value.cfg_scale)
1184 .strength(value.strength)
1185 .pm_style_strength(value.pm_style_strength)
1186 .control_strength(value.control_strength)
1187 .height(value.height)
1188 .width(value.width)
1189 .sampling_method(value.sampling_method)
1190 .steps(value.steps)
1191 .seed(value.seed)
1192 .batch_count(value.batch_count)
1193 .clip_skip(value.clip_skip)
1194 .slg_scale(value.slg_scale)
1195 .skip_layer(value.skip_layer.clone())
1196 .skip_layer_start(value.skip_layer_start)
1197 .skip_layer_end(value.skip_layer_end)
1198 .canny(value.canny)
1199 .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
1200 .preview_output(value.preview_output.clone())
1201 .preview_mode(value.preview_mode)
1202 .preview_noisy(value.preview_noisy)
1203 .preview_interval(value.preview_interval)
1204 .cache(value.cache.clone());
1205 builder
1206 }
1207}
1208
/// Owned NUL-terminated string whose pointer can be handed to the C API.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Borrowed C pointer; valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<&str> for CLibString {
    /// # Panics
    /// If `value` contains an interior NUL byte, which a C string cannot represent.
    fn from(value: &str) -> Self {
        Self::from(value.to_owned())
    }
}

impl From<String> for CLibString {
    /// # Panics
    /// If `value` contains an interior NUL byte, which a C string cannot represent.
    fn from(value: String) -> Self {
        Self(CString::new(value).expect("CLibString: text must not contain interior NUL bytes"))
    }
}
1229
/// Owned NUL-terminated path string for the C API.
///
/// Only UTF-8 paths are representable: a non-UTF-8 path is stored as an empty
/// string, which the C side will see as "no path".
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Borrowed C pointer; valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    /// # Panics
    /// If the path contains an interior NUL byte.
    fn from(value: PathBuf) -> Self {
        Self::from(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    /// # Panics
    /// If the path contains an interior NUL byte.
    fn from(value: &Path) -> Self {
        // NOTE(review): non-UTF-8 paths silently degrade to "" (see type docs).
        let text = value.to_str().unwrap_or_default();
        Self(CString::new(text).expect("CLibPath: path must not contain interior NUL bytes"))
    }
}

impl From<&CLibPath> for PathBuf {
    /// # Panics
    /// If the stored bytes are not UTF-8; values built with the `From` impls above are
    /// always UTF-8.
    fn from(value: &CLibPath) -> Self {
        PathBuf::from(value.0.to_str().expect("CLibPath always holds UTF-8"))
    }
}
1256
1257fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
1258 let date = Local::now().format("%Y.%m.%d-%H.%M.%S");
1259 if batch_size == 1 {
1260 vec![path.into()]
1261 } else {
1262 (1..=batch_size)
1263 .map(|id| path.join(format!("output_{date}_{id}.png")))
1264 .collect()
1265 }
1266}
1267
1268unsafe fn upscale(
1269 upscale_repeats: i32,
1270 upscaler_ctx: Option<*mut upscaler_ctx_t>,
1271 data: sd_image_t,
1272) -> Result<sd_image_t, DiffusionError> {
1273 unsafe {
1274 match upscaler_ctx {
1275 Some(upscaler_ctx) => {
1276 let upscale_factor = 4; let mut current_image = data;
1278 for _ in 0..upscale_repeats {
1279 let upscaled_image =
1280 diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
1281
1282 if upscaled_image.data.is_null() {
1283 return Err(DiffusionError::Upscaler);
1284 }
1285
1286 free(current_image.data as *mut c_void);
1287 current_image = upscaled_image;
1288 }
1289 Ok(current_image)
1290 }
1291 None => Ok(data),
1292 }
1293 }
1294}
1295
1296pub fn gen_img_with_progress(
1298 config: &Config,
1299 model_config: &mut ModelConfig,
1300 sender: Sender<Progress>,
1301) -> Result<(), DiffusionError> {
1302 gen_img_maybe_progress(config, model_config, Some(sender))
1303}
1304
1305pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
1307 gen_img_maybe_progress(config, model_config, None)
1308}
1309
1310fn gen_img_maybe_progress(
1311 config: &Config,
1312 model_config: &mut ModelConfig,
1313 mut sender: Option<Sender<Progress>>,
1314) -> Result<(), DiffusionError> {
1315 let prompt: CLibString = CLibString::from(config.prompt.as_str());
1316 let files = output_files(&config.output, config.batch_count);
1317 unsafe {
1318 let has_init_image = config.init_img.exists();
1319 let has_mask_image = config.mask_img.exists();
1320
1321 let sd_ctx = model_config.diffusion_ctx();
1322 let upscaler_ctx = model_config.upscaler_ctx();
1323
1324 let mut init_image = sd_image_t {
1325 width: 0,
1326 height: 0,
1327 channel: 3,
1328 data: std::ptr::null_mut(),
1329 };
1330 let mut mask_image = sd_image_t {
1331 width: config.width as u32,
1332 height: config.height as u32,
1333 channel: 1,
1334 data: null_mut(),
1335 };
1336 let mut layers = config.skip_layer.clone();
1337 let guidance = sd_guidance_params_t {
1338 txt_cfg: config.cfg_scale,
1339 img_cfg: config.cfg_scale,
1340 distilled_guidance: config.guidance,
1341 slg: sd_slg_params_t {
1342 layers: layers.as_mut_ptr(),
1343 layer_count: config.skip_layer.len(),
1344 layer_start: config.skip_layer_start,
1345 layer_end: config.skip_layer_end,
1346 scale: config.slg_scale,
1347 },
1348 };
1349 let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
1350 sd_get_default_scheduler(sd_ctx, config.sampling_method)
1351 } else {
1352 model_config.scheduler
1353 };
1354 let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
1355 sd_get_default_sample_method(sd_ctx)
1356 } else {
1357 config.sampling_method
1358 };
1359 let sample_params = sd_sample_params_t {
1360 guidance,
1361 sample_method,
1362 sample_steps: config.steps,
1363 eta: config.eta,
1364 scheduler,
1365 shifted_timestep: model_config.timestep_shift,
1366 custom_sigmas: model_config.sigmas.as_mut_ptr(),
1367 custom_sigmas_count: model_config.sigmas.len() as i32,
1368 flow_shift: model_config.flow_shift,
1369 extra_sample_args: model_config.extra_sample_params.as_ptr(),
1370 };
1371 let control_image = sd_image_t {
1372 width: 0,
1373 height: 0,
1374 channel: 3,
1375 data: null_mut(),
1376 };
1377 let vae_tiling_params = sd_tiling_params_t {
1378 enabled: model_config.vae_tiling,
1379 tile_size_x: model_config.vae_tile_size.0,
1380 tile_size_y: model_config.vae_tile_size.1,
1381 target_overlap: model_config.vae_tile_overlap,
1382 rel_size_x: model_config.vae_relative_tile_size.0,
1383 rel_size_y: model_config.vae_relative_tile_size.1,
1384 temporal_tiling: model_config.vae_temporal_tiling,
1385 extra_tiling_args: model_config.extra_tiling_args.1.as_ptr(),
1386 };
1387 let pm_params = sd_pm_params_t {
1388 id_images: null_mut(),
1389 id_images_count: 0,
1390 id_embed_path: model_config.pm_id_embed_path.as_ptr(),
1391 style_strength: config.pm_style_strength,
1392 };
1393
1394 let mut image_buffer: Vec<u8> = Vec::new();
1396 let mut mask_buffer: Vec<u8> = Vec::new();
1397
1398 if has_init_image {
1399 let img = image::open(&config.init_img)?;
1400 image_buffer = img.to_rgb8().into_raw();
1401
1402 init_image = sd_image_t {
1403 width: img.width(),
1404 height: img.height(),
1405 channel: 3,
1406 data: image_buffer.as_mut_ptr(),
1407 }
1408 }
1409
1410 if has_mask_image {
1411 let img = image::open(&config.mask_img)?;
1412 mask_buffer = img.to_luma8().into_raw();
1414
1415 mask_image = sd_image_t {
1416 width: img.width(),
1417 height: img.height(),
1418 channel: 1,
1419 data: mask_buffer.as_mut_ptr(),
1420 }
1421 }
1422
1423 if !image_buffer.is_empty() && mask_buffer.is_empty() {
1427 let img: ImageBuffer<image::Luma<u8>, Vec<u8>> =
1428 ImageBuffer::from_pixel(init_image.width, init_image.height, image::Luma([255]));
1429 mask_buffer = img.into_raw();
1430 mask_image = sd_image_t {
1431 width: init_image.width,
1432 height: init_image.height,
1433 channel: 1,
1434 data: mask_buffer.as_mut_ptr(),
1435 }
1436 }
1437
1438 let mut ref_image_list = Vec::new();
1439 let mut ref_pixel_storage = Vec::new();
1440 for ref_path in &config.ref_images {
1441 if ref_path.exists() {
1442 let img = image::open(ref_path)?;
1443 let image_data = img.to_rgb8().into_raw();
1444
1445 ref_pixel_storage.push(image_data);
1446 let storage_ref = ref_pixel_storage.last_mut().unwrap();
1447 ref_image_list.push(sd_image_t {
1448 width: img.width(),
1449 height: img.height(),
1450 channel: 3,
1451 data: storage_ref.as_mut_ptr(),
1452 });
1453 }
1454 }
1455
1456 let num_ref_images = ref_image_list.len();
1457 let ref_image_ptr = if num_ref_images > 0 {
1458 ref_image_list.as_mut_ptr()
1459 } else {
1460 null_mut()
1461 };
1462
1463 unsafe extern "C" fn save_preview_local(
1464 _step: ::std::os::raw::c_int,
1465 _frame_count: ::std::os::raw::c_int,
1466 frames: *mut sd_image_t,
1467 _is_noisy: bool,
1468 data: *mut ::std::os::raw::c_void,
1469 ) {
1470 unsafe {
1471 let path = &*data.cast::<PathBuf>();
1472 let _ = save_img(*frames, path, None);
1473 }
1474 }
1475
1476 if config.preview_mode != PreviewType::PREVIEW_NONE {
1477 let data = &config.preview_output as *const PathBuf;
1478
1479 sd_set_preview_callback(
1480 Some(save_preview_local),
1481 config.preview_mode,
1482 config.preview_interval,
1483 !config.preview_noisy,
1484 config.preview_noisy,
1485 data as *mut c_void,
1486 );
1487 }
1488
1489 if sender.is_some() {
1490 unsafe extern "C" fn progress_callback(
1491 step: ::std::os::raw::c_int,
1492 steps: ::std::os::raw::c_int,
1493 time: f32,
1494 data: *mut ::std::os::raw::c_void,
1495 ) {
1496 unsafe {
1497 let sender = &*data.cast::<Option<Sender<Progress>>>();
1498
1499 if let Some(sender) = sender {
1500 let _ = sender.send(Progress { step, steps, time });
1501 }
1502 }
1503 }
1504 let sender_ptr: *mut c_void = &mut sender as *mut _ as *mut c_void;
1505 sd_set_progress_callback(Some(progress_callback), sender_ptr);
1506 }
1507
1508 let loras: Vec<sd_lora_t> = model_config
1509 .lora_models
1510 .iter()
1511 .map(|(c_path, spec)| sd_lora_t {
1512 is_high_noise: spec.is_high_noise,
1513 multiplier: spec.multiplier,
1514 path: c_path.as_ptr(),
1515 })
1516 .collect();
1517
1518 let mut cache = config.cache.0;
1519 if let Some(scm_mask) = &config.cache.1 {
1520 cache.scm_mask = scm_mask.as_ptr();
1521 }
1522
1523 let mut hires_path = null();
1524 let mut hires_sigmas = null_mut();
1525 let mut hires_sigmas_count = 0;
1526 if let Some(path) = &model_config.hires_params.2 {
1527 hires_path = path.as_ptr();
1528 }
1529 if let Some(sigmas) = &mut model_config.hires_params.1.hires_sigmas {
1530 hires_sigmas = sigmas.as_mut_ptr();
1531 hires_sigmas_count = sigmas.len() as i32;
1532 }
1533
1534 let hires = sd_hires_params_t {
1535 enabled: model_config.hires_params.0 != Upscaler::SD_HIRES_UPSCALER_NONE,
1536 upscaler: model_config.hires_params.0,
1537 model_path: hires_path,
1538 scale: model_config.hires_params.1.scale,
1539 target_width: model_config.hires_params.1.width,
1540 target_height: model_config.hires_params.1.height,
1541 steps: model_config.hires_params.1.steps,
1542 denoising_strength: model_config.hires_params.1.denoising_strength,
1543 upscale_tile_size: model_config.hires_params.1.upscale_tile_size,
1544 custom_sigmas: hires_sigmas,
1545 custom_sigmas_count: hires_sigmas_count,
1546 };
1547
1548 let sd_img_gen_params = sd_img_gen_params_t {
1549 prompt: prompt.as_ptr(),
1550 negative_prompt: config.negative_prompt.as_ptr(),
1551 clip_skip: config.clip_skip as i32,
1552 init_image,
1553 ref_images: ref_image_ptr,
1554 ref_images_count: num_ref_images as i32,
1555 increase_ref_index: false,
1556 mask_image,
1557 width: config.width,
1558 height: config.height,
1559 sample_params,
1560 strength: config.strength,
1561 seed: config.seed,
1562 batch_count: config.batch_count,
1563 control_image,
1564 control_strength: config.control_strength,
1565 pm_params,
1566 vae_tiling_params,
1567 auto_resize_ref_image: config.disable_auto_resize_ref_image,
1568 cache,
1569 loras: loras.as_ptr(),
1570 lora_count: loras.len() as u32,
1571 hires,
1572 };
1573
1574 let params_str = CString::from_raw(sd_img_gen_params_to_str(&sd_img_gen_params))
1575 .into_string()
1576 .unwrap();
1577
1578 let slice = generate_image(sd_ctx, &sd_img_gen_params);
1579 let ret = {
1580 if slice.is_null() {
1581 return Err(DiffusionError::Forward);
1582 }
1583 for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
1584 .iter()
1585 .zip(files)
1586 {
1587 if img.data.is_null() {
1590 return Err(DiffusionError::Forward);
1591 }
1592 match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
1593 Ok(img) => save_img(img, &path, Some(¶ms_str))?,
1594 Err(err) => {
1595 return Err(err);
1596 }
1597 }
1598 }
1599 Ok(())
1600 };
1601 free(slice as *mut c_void);
1602 ret
1603 }
1604}
1605
1606fn save_img(img: sd_image_t, path: &Path, params: Option<&str>) -> Result<(), DiffusionError> {
1607 let len = (img.width * img.height * img.channel) as usize;
1609 let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1610 let save_state = ImageBuffer::from_raw(img.width, img.height, buffer).map(|img| {
1611 RgbImage::from(img)
1612 .save(path)
1613 .map_err(DiffusionError::StoreImages)
1614 });
1615 if let Some(Err(err)) = save_state {
1616 return Err(err);
1617 }
1618 if let Some(params) = params {
1619 let mut metadata = Metadata::new();
1620 metadata.set_tag(ExifTag::ImageDescription(params.to_string()));
1621 metadata.write_to_file(path)?;
1622 }
1623 Ok(())
1624}
1625
#[cfg(test)]
mod tests {
    use image::{DynamicImage, ImageBuffer, Rgba};
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    /// Hugging Face repository hosting the Stable Diffusion v1.4 checkpoint.
    const SD14_REPO: &str = "CompVis/stable-diffusion-v-1-4-original";
    /// Checkpoint file inside [`SD14_REPO`].
    const SD14_FILE: &str = "sd-v1-4.ckpt";
    /// Prompt reused by the builder-validation checks.
    const CAT_PROMPT: &str = "a lovely cat driving a sport car";

    /// Builder validation only: no model is loaded or run.
    #[test]
    fn test_required_args_txt2img() {
        // Neither builder succeeds with all defaults.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());

        // A model path alone is enough for a ModelConfig.
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        // A prompt alone is enough for a Config.
        ConfigBuilder::default().prompt(CAT_PROMPT).build().unwrap();

        // batch_count 10 with the default output is rejected by validation
        // (presumably the default output is a file, not a directory — confirm).
        let rejected = ConfigBuilder::default()
            .prompt(CAT_PROMPT)
            .batch_count(10)
            .build();
        assert!(matches!(
            rejected,
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default().prompt(CAT_PROMPT).build().unwrap();

        // A batch is accepted when the output is a directory.
        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// txt2img, then img2img seeded by that output plus a gradient reference
    /// image, then txt2img again on the same model context.
    #[ignore]
    #[test]
    fn test_img2img_gen() {
        let model_path = download_file_hf_hub(SD14_REPO, SD14_FILE).unwrap();
        let base_output = "./output_img.png";
        let config = ConfigBuilder::default()
            .prompt("A high quality 3d texture")
            .output(PathBuf::from(base_output))
            .batch_count(1)
            .build()
            .unwrap();

        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .build()
            .unwrap();

        gen_img(&config, &mut model_config).unwrap();

        // 512x512 gradient: red grows along x, green along y, blue fixed at 127.
        let cond = ImageBuffer::from_fn(512, 512, |x, y| {
            let r = (x as f32 / 512.0 * 255.0) as u8;
            let g = (y as f32 / 512.0 * 255.0) as u8;
            Rgba([r, g, 127, 255])
        });
        let cond_path = "test_cond_image.png";
        DynamicImage::ImageRgba8(cond)
            .save(cond_path)
            .expect("Failed to save reference image");

        let refine_prompt = "PBR texture map, matching the lighting and micro-detail density of the reference image.";
        let img2img_config = ConfigBuilder::default()
            .prompt(refine_prompt)
            .output(PathBuf::from("./output_img_ref.png"))
            .ref_images(vec![PathBuf::from(cond_path)])
            .init_img(PathBuf::from(base_output))
            .batch_count(1)
            .build()
            .unwrap();
        gen_img(&img2img_config, &mut model_config).unwrap();

        gen_img(&config, &mut model_config).unwrap();
    }

    /// txt2img with a RealESRGAN upscaler: two single images, then a batch of two.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path = download_file_hf_hub(SD14_REPO, SD14_FILE).unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&config, &mut model_config).unwrap();

        let straw_config = ConfigBuilder::from(&config)
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&straw_config, &mut model_config).unwrap();

        let batch_config = ConfigBuilder::from(&config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();

        gen_img(&batch_config, &mut model_config).unwrap();
    }
}