1use std::cmp::max;
2use std::collections::HashMap;
3use std::ffi::CString;
4use std::ffi::c_char;
5use std::ffi::c_void;
6use std::fmt::Display;
7use std::path::Path;
8use std::path::PathBuf;
9use std::ptr::null;
10use std::ptr::null_mut;
11use std::slice;
12use std::sync::mpsc::Sender;
13
14use chrono::Local;
15use derive_builder::Builder;
16use diffusion_rs_sys::free_upscaler_ctx;
17use diffusion_rs_sys::generate_image;
18use diffusion_rs_sys::new_upscaler_ctx;
19use diffusion_rs_sys::sd_cache_mode_t;
20use diffusion_rs_sys::sd_cache_params_t;
21use diffusion_rs_sys::sd_ctx_params_t;
22use diffusion_rs_sys::sd_embedding_t;
23use diffusion_rs_sys::sd_get_default_sample_method;
24use diffusion_rs_sys::sd_get_default_scheduler;
25use diffusion_rs_sys::sd_guidance_params_t;
26use diffusion_rs_sys::sd_image_t;
27use diffusion_rs_sys::sd_img_gen_params_t;
28use diffusion_rs_sys::sd_img_gen_params_to_str;
29use diffusion_rs_sys::sd_lora_t;
30use diffusion_rs_sys::sd_pm_params_t;
31use diffusion_rs_sys::sd_sample_params_t;
32use diffusion_rs_sys::sd_set_preview_callback;
33use diffusion_rs_sys::sd_set_progress_callback;
34use diffusion_rs_sys::sd_slg_params_t;
35use diffusion_rs_sys::sd_tiling_params_t;
36use diffusion_rs_sys::upscaler_ctx_t;
37use image::ImageBuffer;
38use image::ImageError;
39use image::RgbImage;
40use libc::free;
41use little_exif::exif_tag::ExifTag;
42use little_exif::metadata::Metadata;
43use thiserror::Error;
44use walkdir::DirEntry;
45use walkdir::WalkDir;
46
47use diffusion_rs_sys::free_sd_ctx;
48use diffusion_rs_sys::new_sd_ctx;
49use diffusion_rs_sys::sd_ctx_t;
50
51pub use diffusion_rs_sys::rng_type_t as RngFunction;
53
54pub use diffusion_rs_sys::sample_method_t as SampleMethod;
56
57pub use diffusion_rs_sys::scheduler_t as Scheduler;
59
60pub use diffusion_rs_sys::prediction_t as Prediction;
62
63pub use diffusion_rs_sys::sd_type_t as WeightType;
65
66pub use diffusion_rs_sys::preview_t as PreviewType;
68
69pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
71
/// File extensions (without the leading dot) accepted when scanning the
/// embeddings and LoRA directories for weight files.
static VALID_EXT: [&str; 3] = ["gguf", "safetensors", "pt"];
73
/// Progress report delivered through the `Sender<Progress>` handed to
/// [`gen_img_with_progress`].
// NOTE(review): the fields are private and only observable through `Debug`;
// `unused` is presumably allowed because nothing in this crate reads them.
#[allow(unused)]
#[derive(Debug)]
pub struct Progress {
    /// Current sampling step — presumably as reported by the native progress callback; TODO confirm.
    step: i32,
    /// Total number of sampling steps.
    steps: i32,
    /// Elapsed time reported with this step — units not visible here; TODO confirm.
    time: f32,
}
82
83#[non_exhaustive]
84#[derive(Error, Debug)]
85pub enum DiffusionError {
87 #[error("The underling stablediffusion.cpp function returned NULL")]
88 Forward,
89 #[error(transparent)]
90 StoreImages(#[from] ImageError),
91 #[error(transparent)]
92 Io(#[from] std::io::Error),
93 #[error("The underling upscaler model returned a NULL image")]
94 Upscaler,
95}
96
/// How many final CLIP layers to skip when encoding the prompt.
///
/// The discriminants are the raw integer values; presumably they are passed
/// straight to the native `clip_skip` parameter — TODO confirm at the call site.
#[repr(i32)]
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
pub enum ClipSkip {
    /// Value `0`; presumably "use the model's default" — TODO confirm.
    #[default]
    Unspecified = 0,
    /// Value `1`: skip no layers.
    None = 1,
    /// Value `2`: skip one layer.
    OneLayer = 2,
}
108
/// Embeddings directory, the owned `(name, path)` C strings discovered in it,
/// and the `sd_embedding_t` table whose raw pointers point into those strings.
type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
/// Resolved LoRA files: the C path of each file paired with the spec that selected it.
type LoraStorage = Vec<(CLibPath, LoraSpec)>;
111
/// A LoRA requested by the user, matched by file stem against the LoRA directory.
#[derive(Default, Debug, Clone)]
pub struct LoraSpec {
    /// File stem (name without extension) of the LoRA file to load.
    pub file_name: String,
    /// Whether this LoRA is grouped with the high-noise entries (stored after the standard ones).
    pub is_high_noise: bool,
    /// Strength multiplier for this LoRA.
    pub multiplier: f32,
}
119
/// Parameters for UCache caching, copied into `sd_cache_params_t` by
/// `ConfigBuilder::ucache_caching`.
#[derive(Builder, Debug, Clone)]
pub struct UCacheParams {
    /// Copied to `reuse_threshold`.
    #[builder(default = "1.0")]
    threshold: f32,

    /// Copied to `start_percent`.
    #[builder(default = "0.15")]
    start: f32,

    /// Copied to `end_percent`.
    #[builder(default = "0.95")]
    end: f32,

    /// Copied to `error_decay_rate`.
    #[builder(default = "1.0")]
    decay: f32,

    /// Copied to `use_relative_threshold`.
    #[builder(default = "true")]
    relative: bool,

    /// Copied to `reset_error_on_compute`.
    #[builder(default = "true")]
    reset: bool,
}
149
/// Parameters for EasyCache caching, copied into `sd_cache_params_t` by
/// `ConfigBuilder::easy_cache_caching`.
#[derive(Builder, Debug, Clone)]
pub struct EasyCacheParams {
    /// Copied to `reuse_threshold`.
    #[builder(default = "0.2")]
    threshold: f32,

    /// Copied to `start_percent`.
    #[builder(default = "0.15")]
    start: f32,

    /// Copied to `end_percent`.
    #[builder(default = "0.95")]
    end: f32,
}
165
/// Parameters for DBCache caching, also reused by `ConfigBuilder::cache_dit_caching`.
///
/// The fields are copied into `sd_cache_params_t` by `ConfigBuilder::db_cache_caching`.
#[derive(Builder, Debug, Clone)]
pub struct DbCacheParams {
    /// Copied to `Fn_compute_blocks`.
    #[builder(default = "8")]
    fn_blocks: i32,

    /// Copied to `Bn_compute_blocks`.
    #[builder(default = "0")]
    bn_blocks: i32,

    /// Copied to `residual_diff_threshold`.
    #[builder(default = "0.08")]
    threshold: f32,

    /// Copied to `max_warmup_steps`.
    #[builder(default = "8")]
    warmup: i32,

    /// Step-compute mask, rendered to a comma-separated `0`/`1` string.
    // NOTE(review): unlike every other field this one has no `#[builder(default)]`,
    // so `build()` fails unless it is set explicitly, even though `ScmPreset`
    // has a `Default` (`Medium`). Confirm whether that is intentional.
    scm_mask: ScmPreset,

    /// `Dynamic` maps to `scm_policy_dynamic = true`, `Static` to `false`.
    #[builder(default = "ScmPolicy::default()")]
    scm_policy_dynamic: ScmPolicy,
}
192
/// Step-compute-mask policy; converted to the native `scm_policy_dynamic` flag.
#[derive(Debug, Default, Clone)]
pub enum ScmPolicy {
    /// Sets `scm_policy_dynamic = false`.
    Static,
    /// Sets `scm_policy_dynamic = true` (the default).
    #[default]
    Dynamic,
}
202
/// Step-compute mask used by DBCache / cache-dit caching.
///
/// The built-in presets are compute/cache bin lists defined relative to 28
/// sampling steps and rescaled for other step counts. `Custom` is passed
/// through verbatim.
#[derive(Debug, Default, Clone)]
pub enum ScmPreset {
    Slow,
    #[default]
    Medium,
    Fast,
    Ultra,
    /// A ready-made comma-separated `0`/`1` mask string, used as-is.
    Custom(String),
}
215
216impl ScmPreset {
217 fn to_vec_string(&self, steps: i32) -> String {
218 match self {
219 ScmPreset::Slow => ScmPresetBins {
220 compute_bins: vec![8, 3, 3, 2, 1, 1],
221 cache_bins: vec![1, 2, 2, 2, 3],
222 steps,
223 }
224 .to_string(),
225 ScmPreset::Medium => ScmPresetBins {
226 compute_bins: vec![6, 2, 2, 2, 2, 1],
227 cache_bins: vec![1, 3, 3, 3, 3],
228 steps,
229 }
230 .to_string(),
231 ScmPreset::Fast => ScmPresetBins {
232 compute_bins: vec![6, 1, 1, 1, 1, 1],
233 cache_bins: vec![1, 3, 4, 5, 4],
234 steps,
235 }
236 .to_string(),
237 ScmPreset::Ultra => ScmPresetBins {
238 compute_bins: vec![4, 1, 1, 1, 1],
239 cache_bins: vec![2, 5, 6, 7],
240 steps,
241 }
242 .to_string(),
243 ScmPreset::Custom(s) => s.clone(),
244 }
245 }
246}
247
/// Alternating runs of "compute" (`1`) and "cache" (`0`) steps that make up a
/// step-compute mask, plus the number of sampling steps the mask must cover.
#[derive(Debug, Clone)]
struct ScmPresetBins {
    compute_bins: Vec<i32>,
    cache_bins: Vec<i32>,
    steps: i32,
}

impl ScmPresetBins {
    /// Bins are defined for 28 steps; rescale them for any other positive step count.
    fn maybe_scale(&self) -> ScmPresetBins {
        if self.steps == 28 || self.steps <= 0 {
            return self.clone();
        }
        self.scale()
    }

    /// Scales every bin by `steps / 28`, rounding to the nearest integer and
    /// keeping each bin at least one step long.
    ///
    /// Rounding keeps the total bin length close to `steps`. Earlier code
    /// multiplied by `0.5`, which halved every bin and left half of the steps
    /// uncovered.
    fn scale(&self) -> ScmPresetBins {
        let factor = self.steps as f32 / 28.0;
        let rescale = |bins: &[i32]| -> Vec<i32> {
            bins.iter()
                .map(|&bin| max(1, (bin as f32 * factor).round() as i32))
                .collect()
        };
        ScmPresetBins {
            compute_bins: rescale(&self.compute_bins),
            cache_bins: rescale(&self.cache_bins),
            steps: self.steps,
        }
    }

    /// Expands the bins into a per-step mask (`1` = compute, `0` = reuse cache).
    ///
    /// Compute and cache runs are interleaved (compute first). The mask is cut
    /// off at `steps` entries and may be shorter if the bins run out. The final
    /// step is always forced to `1`. Non-positive `steps` yield an empty mask.
    fn generate_vec_mask(&self) -> Vec<i32> {
        let steps = usize::try_from(self.steps).unwrap_or(0);
        let mut mask = Vec::with_capacity(steps);

        // Appends up to `count` copies of `value` without exceeding `steps`.
        let push_run = |mask: &mut Vec<i32>, value: i32, count: i32| {
            let room = steps.saturating_sub(mask.len());
            let run = usize::try_from(count).unwrap_or(0).min(room);
            mask.extend(std::iter::repeat(value).take(run));
        };

        let rounds = self.compute_bins.len().max(self.cache_bins.len());
        for round in 0..rounds {
            if mask.len() >= steps {
                break;
            }
            if let Some(&count) = self.compute_bins.get(round) {
                push_run(&mut mask, 1, count);
            }
            // Index the cache bins with their own position, not the compute index.
            if let Some(&count) = self.cache_bins.get(round) {
                push_run(&mut mask, 0, count);
            }
        }
        if let Some(last) = mask.last_mut() {
            *last = 1;
        }
        mask
    }
}

impl Display for ScmPresetBins {
    /// Formats the (rescaled) mask as comma-separated `0`/`1` values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mask: String = self
            .maybe_scale()
            .generate_vec_mask()
            .iter()
            .map(|x| x.to_string())
            .collect::<Vec<_>>()
            .join(",");
        write!(f, "{mask}")
    }
}
329
330#[derive(Builder, Debug, Clone)]
332#[builder(
333 setter(into, strip_option),
334 build_fn(error = "ConfigBuilderError", validate = "Self::validate")
335)]
336pub struct ModelConfig {
337 #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
340 n_threads: i32,
341
342 #[builder(default = "false")]
344 enable_mmap: bool,
345
346 #[builder(default = "false")]
348 offload_params_to_cpu: bool,
349
350 #[builder(default = "Default::default()")]
352 upscale_model: Option<CLibPath>,
353
354 #[builder(default = "1")]
356 upscale_repeats: i32,
357
358 #[builder(default = "128")]
360 upscale_tile_size: i32,
361
362 #[builder(default = "Default::default()")]
364 model: CLibPath,
365
366 #[builder(default = "Default::default()")]
368 diffusion_model: CLibPath,
369
370 #[builder(default = "Default::default()")]
372 llm: CLibPath,
373
374 #[builder(default = "Default::default()")]
376 llm_vision: CLibPath,
377
378 #[builder(default = "Default::default()")]
380 clip_l: CLibPath,
381
382 #[builder(default = "Default::default()")]
384 clip_g: CLibPath,
385
386 #[builder(default = "Default::default()")]
388 clip_vision: CLibPath,
389
390 #[builder(default = "Default::default()")]
392 t5xxl: CLibPath,
393
394 #[builder(default = "Default::default()")]
396 vae: CLibPath,
397
398 #[builder(default = "Default::default()")]
400 taesd: CLibPath,
401
402 #[builder(default = "Default::default()")]
404 control_net: CLibPath,
405
406 #[builder(default = "Default::default()", setter(custom))]
408 embeddings: EmbeddingsStorage,
409
410 #[builder(default = "Default::default()")]
412 photo_maker: CLibPath,
413
414 #[builder(default = "Default::default()")]
416 pm_id_embed_path: CLibPath,
417
418 #[builder(default = "WeightType::SD_TYPE_COUNT")]
420 weight_type: WeightType,
421
422 #[builder(default = "Default::default()", setter(custom))]
424 lora_models: LoraStorage,
425
426 #[builder(default = "Default::default()")]
428 high_noise_diffusion_model: CLibPath,
429
430 #[builder(default = "false")]
432 vae_tiling: bool,
433
434 #[builder(default = "(32,32)")]
436 vae_tile_size: (i32, i32),
437
438 #[builder(default = "(0.,0.)")]
440 vae_relative_tile_size: (f32, f32),
441
442 #[builder(default = "0.5")]
444 vae_tile_overlap: f32,
445
446 #[builder(default = "RngFunction::CUDA_RNG")]
448 rng: RngFunction,
449
450 #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
452 sampler_rng_type: RngFunction,
453
454 #[builder(default = "Scheduler::SCHEDULER_COUNT")]
458 scheduler: Scheduler,
459
460 #[builder(default = "Default::default()")]
462 sigmas: Vec<f32>,
463
464 #[builder(default = "Prediction::PREDICTION_COUNT")]
466 prediction: Prediction,
467
468 #[builder(default = "false")]
470 vae_on_cpu: bool,
471
472 #[builder(default = "false")]
474 clip_on_cpu: bool,
475
476 #[builder(default = "false")]
478 control_net_cpu: bool,
479
480 #[builder(default = "false")]
483 diffusion_flash_attention: bool,
484
485 #[builder(default = "false")]
488 flash_attention: bool,
489
490 #[builder(default = "false")]
492 chroma_disable_dit_mask: bool,
493
494 #[builder(default = "false")]
496 chroma_enable_t5_mask: bool,
497
498 #[builder(default = "1")]
500 chroma_t5_mask_pad: i32,
501
502 #[builder(default = "false")]
504 use_qwen_image_zero_cond_true: bool,
505
506 #[builder(default = "false")]
509 diffusion_conv_direct: bool,
510
511 #[builder(default = "false")]
514 vae_conv_direct: bool,
515
516 #[builder(default = "false")]
518 force_sdxl_vae_conv_scale: bool,
519
520 #[builder(default = "f32::INFINITY")]
522 flow_shift: f32,
523
524 #[builder(default = "0")]
526 timestep_shift: i32,
527
528 #[builder(default = "false")]
530 taesd_preview_only: bool,
531
532 #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
534 lora_apply_mode: LoraModeType,
535
536 #[builder(default = "false")]
538 circular: bool,
539
540 #[builder(default = "false")]
542 circular_x: bool,
543
544 #[builder(default = "false")]
546 circular_y: bool,
547
548 #[builder(default = "None", private)]
549 upscaler_ctx: Option<*mut upscaler_ctx_t>,
550
551 #[builder(default = "None", private)]
552 diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
553}
554
555impl ModelConfigBuilder {
556 fn validate(&self) -> Result<(), ConfigBuilderError> {
557 self.validate_model()
558 }
559
560 fn validate_model(&self) -> Result<(), ConfigBuilderError> {
561 self.model
562 .as_ref()
563 .or(self.diffusion_model.as_ref())
564 .map(|_| ())
565 .ok_or(ConfigBuilderError::UninitializedField(
566 "Model OR DiffusionModel must be valorized",
567 ))
568 }
569
570 fn filter_valid_extensions(path: &Path) -> impl Iterator<Item = DirEntry> {
571 WalkDir::new(path)
572 .into_iter()
573 .filter_map(|entry| entry.ok())
574 .filter(|entry| {
575 entry
576 .path()
577 .extension()
578 .and_then(|ext| ext.to_str())
579 .map(|ext_str| VALID_EXT.contains(&ext_str))
580 .unwrap_or(false)
581 })
582 }
583 fn build_single_lora_storage(
584 spec: &LoraSpec,
585 valid_loras: &HashMap<String, PathBuf>,
586 ) -> (CLibPath, LoraSpec) {
587 let path = valid_loras.get(&spec.file_name).unwrap().as_path();
588 let c_path = CLibPath::from(path);
589 (c_path, spec.clone())
590 }
591
592 pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
593 let data: Vec<(CLibString, CLibPath)> = Self::filter_valid_extensions(embeddings_dir)
594 .map(|entry| {
595 let file_stem = entry
596 .path()
597 .file_stem()
598 .and_then(|stem| stem.to_str())
599 .unwrap_or_default()
600 .to_owned();
601 (CLibString::from(file_stem), CLibPath::from(entry.path()))
602 })
603 .collect();
604 let data_pointer = data
605 .iter()
606 .map(|(name, path)| sd_embedding_t {
607 name: name.as_ptr(),
608 path: path.as_ptr(),
609 })
610 .collect();
611 self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
612 self
613 }
614
615 pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
616 let valid_loras: HashMap<String, PathBuf> = Self::filter_valid_extensions(lora_model_dir)
617 .map(|entry| {
618 let path = entry.path();
619 (
620 path.file_stem()
621 .and_then(|stem| stem.to_str())
622 .unwrap_or_default()
623 .to_owned(),
624 path.to_path_buf(),
625 )
626 })
627 .collect();
628 let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
629 let standard = specs
630 .iter()
631 .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
632 .map(|s| Self::build_single_lora_storage(s, &valid_loras));
633 let high_noise = specs
634 .iter()
635 .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
636 .map(|s| Self::build_single_lora_storage(s, &valid_loras));
637
638 self.lora_models_internal(standard.chain(high_noise).collect())
639 }
640
641 fn lora_models_internal(&mut self, lora_storage: LoraStorage) -> &mut Self {
642 self.lora_models = Some(lora_storage);
643 self
644 }
645
646 pub fn n_threads(&mut self, value: i32) -> &mut Self {
647 self.n_threads = if value > 0 {
648 Some(value)
649 } else {
650 Some(num_cpus::get_physical() as i32)
651 };
652 self
653 }
654}
655
656impl ModelConfig {
657 unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
658 unsafe {
659 if self.upscale_model.is_none() || self.upscale_repeats == 0 {
660 None
661 } else {
662 if self.upscaler_ctx.is_none() {
663 let upscaler = new_upscaler_ctx(
664 self.upscale_model.as_ref().unwrap().as_ptr(),
665 self.offload_params_to_cpu,
666 self.diffusion_conv_direct,
667 self.n_threads,
668 self.upscale_tile_size,
669 );
670 self.upscaler_ctx = Some(upscaler);
671 }
672 self.upscaler_ctx
673 }
674 }
675 }
676
677 unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
678 unsafe {
679 if let Some((sd_ctx, sd_ctx_params)) = self.diffusion_ctx.as_ref()
683 && sd_ctx_params.vae_decode_only != vae_decode_only
684 {
685 sd_set_progress_callback(None, null_mut());
686 free_sd_ctx(*sd_ctx);
687 self.diffusion_ctx = None;
688 }
689 if self.diffusion_ctx.is_none() {
690 let sd_ctx_params = sd_ctx_params_t {
691 model_path: self.model.as_ptr(),
692 llm_path: self.llm.as_ptr(),
693 llm_vision_path: self.llm_vision.as_ptr(),
694 clip_l_path: self.clip_l.as_ptr(),
695 clip_g_path: self.clip_g.as_ptr(),
696 clip_vision_path: self.clip_vision.as_ptr(),
697 high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
698 t5xxl_path: self.t5xxl.as_ptr(),
699 diffusion_model_path: self.diffusion_model.as_ptr(),
700 vae_path: self.vae.as_ptr(),
701 taesd_path: self.taesd.as_ptr(),
702 control_net_path: self.control_net.as_ptr(),
703 embeddings: self.embeddings.2.as_ptr(),
704 embedding_count: self.embeddings.1.len() as u32,
705 photo_maker_path: self.photo_maker.as_ptr(),
706 vae_decode_only,
707 free_params_immediately: false,
708 n_threads: self.n_threads,
709 wtype: self.weight_type,
710 rng_type: self.rng,
711 keep_clip_on_cpu: self.clip_on_cpu,
712 keep_control_net_on_cpu: self.control_net_cpu,
713 keep_vae_on_cpu: self.vae_on_cpu,
714 diffusion_flash_attn: self.diffusion_flash_attention,
715 flash_attn: self.flash_attention,
716 diffusion_conv_direct: self.diffusion_conv_direct,
717 chroma_use_dit_mask: !self.chroma_disable_dit_mask,
718 chroma_use_t5_mask: self.chroma_enable_t5_mask,
719 chroma_t5_mask_pad: self.chroma_t5_mask_pad,
720 vae_conv_direct: self.vae_conv_direct,
721 offload_params_to_cpu: self.offload_params_to_cpu,
722 prediction: self.prediction,
723 force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
724 tae_preview_only: self.taesd_preview_only,
725 lora_apply_mode: self.lora_apply_mode,
726 tensor_type_rules: null_mut(),
727 sampler_rng_type: self.sampler_rng_type,
728 circular_x: self.circular || self.circular_x,
729 circular_y: self.circular || self.circular_y,
730 qwen_image_zero_cond_t: self.use_qwen_image_zero_cond_true,
731 enable_mmap: self.enable_mmap,
732 };
733 let ctx = new_sd_ctx(&sd_ctx_params);
734 self.diffusion_ctx = Some((ctx, sd_ctx_params))
735 }
736 self.diffusion_ctx.unwrap().0
737 }
738 }
739}
740
741impl Drop for ModelConfig {
742 fn drop(&mut self) {
743 unsafe {
745 if let Some((sd_ctx, _)) = self.diffusion_ctx {
746 free_sd_ctx(sd_ctx);
747 }
748
749 if let Some(upscaler_ctx) = self.upscaler_ctx {
750 free_upscaler_ctx(upscaler_ctx);
751 }
752 }
753 }
754}
755
756impl From<ModelConfig> for ModelConfigBuilder {
757 fn from(value: ModelConfig) -> Self {
758 let mut builder = ModelConfigBuilder::default();
759 builder
760 .n_threads(value.n_threads)
761 .offload_params_to_cpu(value.offload_params_to_cpu)
762 .upscale_repeats(value.upscale_repeats)
763 .model(value.model.clone())
764 .diffusion_model(value.diffusion_model.clone())
765 .llm(value.llm.clone())
766 .llm_vision(value.llm_vision.clone())
767 .clip_l(value.clip_l.clone())
768 .clip_g(value.clip_g.clone())
769 .clip_vision(value.clip_vision.clone())
770 .t5xxl(value.t5xxl.clone())
771 .vae(value.vae.clone())
772 .taesd(value.taesd.clone())
773 .control_net(value.control_net.clone())
774 .embeddings(&value.embeddings.0)
775 .photo_maker(value.photo_maker.clone())
776 .pm_id_embed_path(value.pm_id_embed_path.clone())
777 .weight_type(value.weight_type)
778 .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
779 .vae_tiling(value.vae_tiling)
780 .vae_tile_size(value.vae_tile_size)
781 .vae_relative_tile_size(value.vae_relative_tile_size)
782 .vae_tile_overlap(value.vae_tile_overlap)
783 .rng(value.rng)
784 .sampler_rng_type(value.rng)
785 .scheduler(value.scheduler)
786 .sigmas(value.sigmas.clone())
787 .prediction(value.prediction)
788 .vae_on_cpu(value.vae_on_cpu)
789 .clip_on_cpu(value.clip_on_cpu)
790 .control_net(value.control_net.clone())
791 .control_net_cpu(value.control_net_cpu)
792 .flash_attention(value.flash_attention)
793 .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
794 .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
795 .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
796 .diffusion_conv_direct(value.diffusion_conv_direct)
797 .vae_conv_direct(value.vae_conv_direct)
798 .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
799 .flow_shift(value.flow_shift)
800 .timestep_shift(value.timestep_shift)
801 .taesd_preview_only(value.taesd_preview_only)
802 .lora_apply_mode(value.lora_apply_mode)
803 .circular(value.circular)
804 .circular_x(value.circular_x)
805 .circular_y(value.circular_y)
806 .use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true);
807
808 builder.lora_models_internal(value.lora_models.clone());
809
810 if let Some(model) = &value.upscale_model {
811 builder.upscale_model(model.clone());
812 }
813 builder
814 }
815}
816
/// Per-generation settings: prompt, sampling parameters, input images and
/// output locations.
///
/// NOTE(review): `cache.scm_mask` is a raw pointer into a `scm_mask` CString.
/// derive_builder's `build()` and the derived `Clone` clone that string into a
/// new allocation but copy the pointer verbatim. A built or cloned `Config` may
/// therefore point at the builder's (or the original's) string. Confirm that
/// the generation code re-points `cache.scm_mask` at `self.scm_mask` before
/// handing `cache` to the native library.
#[derive(Builder, Debug, Clone)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
pub struct Config {
    /// PhotoMaker ID images directory.
    #[builder(default = "Default::default()")]
    pm_id_images_dir: CLibPath,

    /// Init image for img2img; it is only loaded if the path exists.
    #[builder(default = "Default::default()")]
    init_img: PathBuf,

    /// Mask image; it is only loaded if the path exists.
    #[builder(default = "Default::default()")]
    mask_img: PathBuf,

    /// Control image for ControlNet.
    #[builder(default = "Default::default()")]
    control_image: CLibPath,

    /// Reference images; paths that do not exist are skipped.
    #[builder(default = "Default::default()")]
    ref_images: Vec<PathBuf>,

    /// Output file, or an existing directory when `batch_count > 1`.
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,

    /// Where preview images are written.
    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
    preview_output: PathBuf,

    /// Preview mode (`PREVIEW_NONE` by default).
    #[builder(default = "PreviewType::PREVIEW_NONE")]
    preview_mode: PreviewType,

    /// Whether previews show the noisy latent — presumably; TODO confirm.
    #[builder(default = "false")]
    preview_noisy: bool,

    /// Preview interval — presumably in sampling steps; TODO confirm.
    #[builder(default = "1")]
    preview_interval: i32,

    /// Positive prompt (required).
    prompt: String,

    /// Negative prompt.
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,

    /// Used as both the text CFG (`txt_cfg`) and the image CFG (`img_cfg`).
    #[builder(default = "7.0")]
    cfg_scale: f32,

    /// `distilled_guidance`.
    #[builder(default = "3.5")]
    guidance: f32,

    /// Strength — presumably img2img denoising strength; TODO confirm.
    #[builder(default = "0.75")]
    strength: f32,

    /// PhotoMaker `style_strength`.
    #[builder(default = "20.0")]
    pm_style_strength: f32,

    /// ControlNet strength.
    #[builder(default = "0.9")]
    control_strength: f32,

    /// Output height in pixels.
    #[builder(default = "512")]
    height: i32,

    /// Output width in pixels.
    #[builder(default = "512")]
    width: i32,

    /// Sampler; `SAMPLE_METHOD_COUNT` selects `sd_get_default_sample_method`.
    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
    sampling_method: SampleMethod,

    /// Sampler `eta`.
    #[builder(default = "0.")]
    eta: f32,

    /// Number of sampling steps (`sample_steps`).
    #[builder(default = "20")]
    steps: i32,

    /// RNG seed.
    #[builder(default = "42")]
    seed: i64,

    /// Number of images to generate; values above 1 require `output` to be a directory.
    #[builder(default = "1")]
    batch_count: i32,

    /// CLIP layers to skip.
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,

    /// Apply canny preprocessing — presumably to the control image; TODO confirm.
    #[builder(default = "false")]
    canny: bool,

    /// Skip-layer guidance `scale`.
    #[builder(default = "0.")]
    slg_scale: f32,

    /// Layers used by skip-layer guidance.
    #[builder(default = "vec![7, 8, 9]")]
    skip_layer: Vec<i32>,

    /// Skip-layer guidance `layer_start`.
    #[builder(default = "0.01")]
    skip_layer_start: f32,

    /// Skip-layer guidance `layer_end`.
    #[builder(default = "0.2")]
    skip_layer_end: f32,

    /// Disable automatic resizing of reference images.
    #[builder(default = "false")]
    disable_auto_resize_ref_image: bool,

    /// Native cache parameters, set through the `*_caching` builder methods.
    #[builder(default = "Self::cache_init()", private)]
    cache: sd_cache_params_t,

    /// Owned storage for the string that `cache.scm_mask` points to.
    #[builder(default = "CLibString::default()", private)]
    scm_mask: CLibString,
}
955
956impl ConfigBuilder {
957 fn validate(&self) -> Result<(), ConfigBuilderError> {
958 self.validate_output_dir()
959 }
960
961 fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
962 let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
963 let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
964 if is_dir == multiple_items {
965 Ok(())
966 } else {
967 Err(ConfigBuilderError::ValidationError(
968 "When batch_count > 1, output should point to folder and vice versa".to_owned(),
969 ))
970 }
971 }
972
973 fn cache_init() -> sd_cache_params_t {
974 sd_cache_params_t {
975 mode: sd_cache_mode_t::SD_CACHE_DISABLED,
976 reuse_threshold: 1.0,
977 start_percent: 0.15,
978 end_percent: 0.95,
979 error_decay_rate: 1.0,
980 use_relative_threshold: true,
981 reset_error_on_compute: true,
982 Fn_compute_blocks: 8,
983 Bn_compute_blocks: 0,
984 residual_diff_threshold: 0.08,
985 max_warmup_steps: 8,
986 max_cached_steps: -1,
987 max_continuous_cached_steps: -1,
988 taylorseer_n_derivatives: 1,
989 taylorseer_skip_interval: 1,
990 scm_mask: null(),
991 scm_policy_dynamic: true,
992 }
993 }
994
995 pub fn no_caching(&mut self) -> &mut Self {
996 let mut cache = Self::cache_init();
997 cache.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
998 self.cache = Some(cache);
999 self
1000 }
1001
1002 pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
1003 let mut cache = Self::cache_init();
1004 cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
1005 cache.reuse_threshold = params.threshold;
1006 cache.start_percent = params.start;
1007 cache.end_percent = params.end;
1008 cache.error_decay_rate = params.decay;
1009 cache.use_relative_threshold = params.relative;
1010 cache.reset_error_on_compute = params.reset;
1011 self.cache = Some(cache);
1012 self
1013 }
1014
1015 pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
1016 let mut cache = Self::cache_init();
1017 cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
1018 cache.reuse_threshold = params.threshold;
1019 cache.start_percent = params.start;
1020 cache.end_percent = params.end;
1021 self.cache = Some(cache);
1022 self
1023 }
1024
1025 pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
1026 let mut cache = Self::cache_init();
1027 cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
1028 cache.Fn_compute_blocks = params.fn_blocks;
1029 cache.Bn_compute_blocks = params.bn_blocks;
1030 cache.residual_diff_threshold = params.threshold;
1031 cache.max_warmup_steps = params.warmup;
1032 cache.scm_policy_dynamic = match params.scm_policy_dynamic {
1033 ScmPolicy::Static => false,
1034 ScmPolicy::Dynamic => true,
1035 };
1036 self.scm_mask = Some(CLibString::from(
1037 params
1038 .scm_mask
1039 .to_vec_string(self.steps.unwrap_or_default()),
1040 ));
1041 cache.scm_mask = self.scm_mask.as_ref().unwrap().as_ptr();
1042
1043 self.cache = Some(cache);
1044 self
1045 }
1046
1047 pub fn taylor_seer_caching(&mut self) -> &mut Self {
1048 let mut cache = Self::cache_init();
1049 cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
1050 self.cache = Some(cache);
1051 self
1052 }
1053
1054 pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
1055 self.db_cache_caching(params).cache.unwrap().mode = sd_cache_mode_t::SD_CACHE_CACHE_DIT;
1056 self
1057 }
1058}
1059
1060impl From<Config> for ConfigBuilder {
1061 fn from(value: Config) -> Self {
1062 let mut builder = ConfigBuilder::default();
1063 let mut cache = value.cache;
1064 let scm_mask = value.scm_mask.clone();
1065 cache.scm_mask = scm_mask.as_ptr();
1066 builder
1067 .pm_id_images_dir(value.pm_id_images_dir)
1068 .init_img(value.init_img)
1069 .mask_img(value.mask_img)
1070 .control_image(value.control_image)
1071 .ref_images(value.ref_images)
1072 .output(value.output)
1073 .prompt(value.prompt)
1074 .negative_prompt(value.negative_prompt)
1075 .cfg_scale(value.cfg_scale)
1076 .strength(value.strength)
1077 .pm_style_strength(value.pm_style_strength)
1078 .control_strength(value.control_strength)
1079 .height(value.height)
1080 .width(value.width)
1081 .sampling_method(value.sampling_method)
1082 .steps(value.steps)
1083 .seed(value.seed)
1084 .batch_count(value.batch_count)
1085 .clip_skip(value.clip_skip)
1086 .slg_scale(value.slg_scale)
1087 .skip_layer(value.skip_layer)
1088 .skip_layer_start(value.skip_layer_start)
1089 .skip_layer_end(value.skip_layer_end)
1090 .canny(value.canny)
1091 .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
1092 .preview_output(value.preview_output)
1093 .preview_mode(value.preview_mode)
1094 .preview_noisy(value.preview_noisy)
1095 .preview_interval(value.preview_interval)
1096 .cache(cache)
1097 .scm_mask(scm_mask);
1098 builder
1099 }
1100}
1101
/// Owned NUL-terminated string handed to the native library by pointer.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Borrows the string as a C pointer; valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

/// Builds a `CString` from `bytes`, keeping only the part before the first
/// interior NUL byte instead of panicking.
///
/// A C reader would stop at that NUL anyway, and prompts or paths come from
/// user input.
fn c_string_until_nul(bytes: impl Into<Vec<u8>>) -> CString {
    CString::new(bytes).unwrap_or_else(|err| {
        let nul_at = err.nul_position();
        let mut bytes = err.into_vec();
        bytes.truncate(nul_at);
        CString::new(bytes).expect("bytes were truncated before the first NUL")
    })
}

impl From<&str> for CLibString {
    fn from(value: &str) -> Self {
        Self(c_string_until_nul(value))
    }
}

impl From<String> for CLibString {
    fn from(value: String) -> Self {
        Self(c_string_until_nul(value))
    }
}

/// Owned NUL-terminated file path handed to the native library by pointer.
/// The default value is the empty string.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Borrows the path as a C pointer; valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    /// Non-UTF-8 paths become the empty string.
    fn from(value: PathBuf) -> Self {
        Self(c_string_until_nul(value.to_str().unwrap_or_default()))
    }
}

impl From<&Path> for CLibPath {
    /// Non-UTF-8 paths become the empty string.
    fn from(value: &Path) -> Self {
        Self(c_string_until_nul(value.to_str().unwrap_or_default()))
    }
}

impl From<&CLibPath> for PathBuf {
    fn from(value: &CLibPath) -> Self {
        // Always valid UTF-8 (built from `&str`), so the lossy conversion never
        // replaces anything; it just avoids a panic path.
        PathBuf::from(value.0.to_string_lossy().into_owned())
    }
}
1149
1150fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
1151 let date = Local::now().format("%Y.%m.%d-%H.%M.%S");
1152 if batch_size == 1 {
1153 vec![path.into()]
1154 } else {
1155 (1..=batch_size)
1156 .map(|id| path.join(format!("output_{date}_{id}.png")))
1157 .collect()
1158 }
1159}
1160
1161unsafe fn upscale(
1162 upscale_repeats: i32,
1163 upscaler_ctx: Option<*mut upscaler_ctx_t>,
1164 data: sd_image_t,
1165) -> Result<sd_image_t, DiffusionError> {
1166 unsafe {
1167 match upscaler_ctx {
1168 Some(upscaler_ctx) => {
1169 let upscale_factor = 4; let mut current_image = data;
1171 for _ in 0..upscale_repeats {
1172 let upscaled_image =
1173 diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
1174
1175 if upscaled_image.data.is_null() {
1176 return Err(DiffusionError::Upscaler);
1177 }
1178
1179 free(current_image.data as *mut c_void);
1180 current_image = upscaled_image;
1181 }
1182 Ok(current_image)
1183 }
1184 None => Ok(data),
1185 }
1186 }
1187}
1188
1189pub fn gen_img_with_progress(
1191 config: &Config,
1192 model_config: &mut ModelConfig,
1193 sender: Sender<Progress>,
1194) -> Result<(), DiffusionError> {
1195 gen_img_maybe_progress(config, model_config, Some(sender))
1196}
1197
1198pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
1200 gen_img_maybe_progress(config, model_config, None)
1201}
1202
1203fn gen_img_maybe_progress(
1204 config: &Config,
1205 model_config: &mut ModelConfig,
1206 mut sender: Option<Sender<Progress>>,
1207) -> Result<(), DiffusionError> {
1208 let prompt: CLibString = CLibString::from(config.prompt.as_str());
1209 let files = output_files(&config.output, config.batch_count);
1210 unsafe {
1211 let has_init_image = config.init_img.exists();
1212 let has_mask_image = config.mask_img.exists();
1213
1214 let is_decode_only = !has_init_image;
1215 let sd_ctx = model_config.diffusion_ctx(is_decode_only);
1216 let upscaler_ctx = model_config.upscaler_ctx();
1217
1218 let mut init_image = sd_image_t {
1219 width: 0,
1220 height: 0,
1221 channel: 3,
1222 data: std::ptr::null_mut(),
1223 };
1224 let mut mask_image = sd_image_t {
1225 width: config.width as u32,
1226 height: config.height as u32,
1227 channel: 1,
1228 data: null_mut(),
1229 };
1230 let mut layers = config.skip_layer.clone();
1231 let guidance = sd_guidance_params_t {
1232 txt_cfg: config.cfg_scale,
1233 img_cfg: config.cfg_scale,
1234 distilled_guidance: config.guidance,
1235 slg: sd_slg_params_t {
1236 layers: layers.as_mut_ptr(),
1237 layer_count: config.skip_layer.len(),
1238 layer_start: config.skip_layer_start,
1239 layer_end: config.skip_layer_end,
1240 scale: config.slg_scale,
1241 },
1242 };
1243 let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
1244 sd_get_default_scheduler(sd_ctx, config.sampling_method)
1245 } else {
1246 model_config.scheduler
1247 };
1248 let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
1249 sd_get_default_sample_method(sd_ctx)
1250 } else {
1251 config.sampling_method
1252 };
1253 let sample_params = sd_sample_params_t {
1254 guidance,
1255 sample_method,
1256 sample_steps: config.steps,
1257 eta: config.eta,
1258 scheduler,
1259 shifted_timestep: model_config.timestep_shift,
1260 custom_sigmas: model_config.sigmas.as_mut_ptr(),
1261 custom_sigmas_count: model_config.sigmas.len() as i32,
1262 flow_shift: model_config.flow_shift,
1263 };
1264 let control_image = sd_image_t {
1265 width: 0,
1266 height: 0,
1267 channel: 3,
1268 data: null_mut(),
1269 };
1270 let vae_tiling_params = sd_tiling_params_t {
1271 enabled: model_config.vae_tiling,
1272 tile_size_x: model_config.vae_tile_size.0,
1273 tile_size_y: model_config.vae_tile_size.1,
1274 target_overlap: model_config.vae_tile_overlap,
1275 rel_size_x: model_config.vae_relative_tile_size.0,
1276 rel_size_y: model_config.vae_relative_tile_size.1,
1277 };
1278 let pm_params = sd_pm_params_t {
1279 id_images: null_mut(),
1280 id_images_count: 0,
1281 id_embed_path: model_config.pm_id_embed_path.as_ptr(),
1282 style_strength: config.pm_style_strength,
1283 };
1284
1285 let mut image_buffer: Vec<u8> = Vec::new();
1287 let mut mask_buffer: Vec<u8> = Vec::new();
1288
1289 if has_init_image {
1290 let img = image::open(&config.init_img)?;
1291 image_buffer = img.to_rgb8().into_raw();
1292
1293 init_image = sd_image_t {
1294 width: img.width(),
1295 height: img.height(),
1296 channel: 3,
1297 data: image_buffer.as_mut_ptr(),
1298 }
1299 }
1300
1301 if has_mask_image {
1302 let img = image::open(&config.mask_img)?;
1303 mask_buffer = img.to_luma8().into_raw();
1305
1306 mask_image = sd_image_t {
1307 width: img.width(),
1308 height: img.height(),
1309 channel: 1,
1310 data: mask_buffer.as_mut_ptr(),
1311 }
1312 }
1313
1314 if !image_buffer.is_empty() && mask_buffer.is_empty() {
1318 let img: ImageBuffer<image::Luma<u8>, Vec<u8>> =
1319 ImageBuffer::from_pixel(init_image.width, init_image.height, image::Luma([255]));
1320 mask_buffer = img.into_raw();
1321 mask_image = sd_image_t {
1322 width: init_image.width,
1323 height: init_image.height,
1324 channel: 1,
1325 data: mask_buffer.as_mut_ptr(),
1326 }
1327 }
1328
1329 let mut ref_image_list = Vec::new();
1330 let mut ref_pixel_storage = Vec::new();
1331 for ref_path in &config.ref_images {
1332 if ref_path.exists() {
1333 let img = image::open(ref_path)?;
1334 let image_data = img.to_rgb8().into_raw();
1335
1336 ref_pixel_storage.push(image_data);
1337 let storage_ref = ref_pixel_storage.last_mut().unwrap();
1338 ref_image_list.push(sd_image_t {
1339 width: img.width(),
1340 height: img.height(),
1341 channel: 3,
1342 data: storage_ref.as_mut_ptr(),
1343 });
1344 }
1345 }
1346
1347 let num_ref_images = ref_image_list.len();
1348 let ref_image_ptr = if num_ref_images > 0 {
1349 ref_image_list.as_mut_ptr()
1350 } else {
1351 null_mut()
1352 };
1353
1354 unsafe extern "C" fn save_preview_local(
1355 _step: ::std::os::raw::c_int,
1356 _frame_count: ::std::os::raw::c_int,
1357 frames: *mut sd_image_t,
1358 _is_noisy: bool,
1359 data: *mut ::std::os::raw::c_void,
1360 ) {
1361 unsafe {
1362 let path = &*data.cast::<PathBuf>();
1363 let _ = save_img(*frames, path, None);
1364 }
1365 }
1366
1367 if config.preview_mode != PreviewType::PREVIEW_NONE {
1368 let data = &config.preview_output as *const PathBuf;
1369
1370 sd_set_preview_callback(
1371 Some(save_preview_local),
1372 config.preview_mode,
1373 config.preview_interval,
1374 !config.preview_noisy,
1375 config.preview_noisy,
1376 data as *mut c_void,
1377 );
1378 }
1379
1380 if sender.is_some() {
1381 unsafe extern "C" fn progress_callback(
1382 step: ::std::os::raw::c_int,
1383 steps: ::std::os::raw::c_int,
1384 time: f32,
1385 data: *mut ::std::os::raw::c_void,
1386 ) {
1387 unsafe {
1388 let sender = &*data.cast::<Option<Sender<Progress>>>();
1389
1390 if let Some(sender) = sender {
1391 let _ = sender.send(Progress { step, steps, time });
1392 }
1393 }
1394 }
1395 let sender_ptr: *mut c_void = &mut sender as *mut _ as *mut c_void;
1396 sd_set_progress_callback(Some(progress_callback), sender_ptr);
1397 }
1398
1399 let loras: Vec<sd_lora_t> = model_config
1400 .lora_models
1401 .iter()
1402 .map(|(c_path, spec)| sd_lora_t {
1403 is_high_noise: spec.is_high_noise,
1404 multiplier: spec.multiplier,
1405 path: c_path.as_ptr(),
1406 })
1407 .collect();
1408
1409 let sd_img_gen_params = sd_img_gen_params_t {
1410 prompt: prompt.as_ptr(),
1411 negative_prompt: config.negative_prompt.as_ptr(),
1412 clip_skip: config.clip_skip as i32,
1413 init_image,
1414 ref_images: ref_image_ptr,
1415 ref_images_count: num_ref_images as i32,
1416 increase_ref_index: false,
1417 mask_image,
1418 width: config.width,
1419 height: config.height,
1420 sample_params,
1421 strength: config.strength,
1422 seed: config.seed,
1423 batch_count: config.batch_count,
1424 control_image,
1425 control_strength: config.control_strength,
1426 pm_params,
1427 vae_tiling_params,
1428 auto_resize_ref_image: config.disable_auto_resize_ref_image,
1429 cache: config.cache,
1430 loras: loras.as_ptr(),
1431 lora_count: loras.len() as u32,
1432 };
1433
1434 let params_str = CString::from_raw(sd_img_gen_params_to_str(&sd_img_gen_params))
1435 .into_string()
1436 .unwrap();
1437
1438 let slice = generate_image(sd_ctx, &sd_img_gen_params);
1439 let ret = {
1440 if slice.is_null() {
1441 return Err(DiffusionError::Forward);
1442 }
1443 for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
1444 .iter()
1445 .zip(files)
1446 {
1447 match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
1448 Ok(img) => save_img(img, &path, Some(¶ms_str))?,
1449 Err(err) => {
1450 return Err(err);
1451 }
1452 }
1453 }
1454 Ok(())
1455 };
1456 free(slice as *mut c_void);
1457 ret
1458 }
1459}
1460
1461fn save_img(img: sd_image_t, path: &Path, params: Option<&str>) -> Result<(), DiffusionError> {
1462 let len = (img.width * img.height * img.channel) as usize;
1464 let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1465 let save_state = ImageBuffer::from_raw(img.width, img.height, buffer).map(|img| {
1466 RgbImage::from(img)
1467 .save(path)
1468 .map_err(DiffusionError::StoreImages)
1469 });
1470 if let Some(Err(err)) = save_state {
1471 return Err(err);
1472 }
1473 if let Some(params) = params {
1474 let mut metadata = Metadata::new();
1475 metadata.set_tag(ExifTag::ImageDescription(params.to_string()));
1476 metadata.write_to_file(path)?;
1477 }
1478 Ok(())
1479}
1480
#[cfg(test)]
mod tests {
    use image::{DynamicImage, ImageBuffer, Rgba};
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    /// Exercises builder validation: empty builders fail, the minimal required
    /// fields succeed, and `batch_count(10)` without an output is rejected.
    #[test]
    fn test_required_args_txt2img() {
        // Empty builders must fail to build.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());

        // A model path alone is enough for a model config.
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        // A prompt alone is enough for a generation config.
        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // A batch of 10 with no output set is rejected by the builder's validation.
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        // A batch with an output directory builds successfully.
        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// Generates an image from text, then feeds it back as the init image
    /// together with a synthetic gradient reference image.
    ///
    /// Ignored by default: it downloads the SD v1.4 checkpoint and runs
    /// generation, writing PNG files to the working directory.
    #[ignore]
    #[test]
    fn test_img2img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let gen_img_output = "./output_img.png";
        let config = ConfigBuilder::default()
            .prompt("A high quality 3d texture")
            .output(PathBuf::from(gen_img_output))
            .batch_count(1)
            .build()
            .unwrap();

        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .build()
            .unwrap();

        gen_img(&config, &mut model_config).unwrap();

        // Build a 512x512 RGBA gradient (red along x, green along y) to use
        // as a reference image.
        let mut cond = ImageBuffer::new(512, 512);
        for (x, y, pixel) in cond.enumerate_pixels_mut() {
            let r = (x as f32 / 512.0 * 255.0) as u8;
            let g = (y as f32 / 512.0 * 255.0) as u8;
            let b = 127;
            *pixel = Rgba([r, g, b, 255]);
        }
        let cond_path = "test_cond_image.png";
        DynamicImage::ImageRgba8(cond)
            .save(cond_path)
            .expect("Failed to save reference image");

        let refine_prompt = "PBR texture map, matching the lighting and micro-detail density of the reference image.";
        let img2img_config = ConfigBuilder::default()
            .prompt(refine_prompt)
            .output(PathBuf::from("./output_img_ref.png"))
            .ref_images(vec![PathBuf::from(cond_path)])
            .init_img(PathBuf::from(gen_img_output))
            .batch_count(1)
            .build()
            .unwrap();
        gen_img(&img2img_config, &mut model_config).unwrap();

        // Run the original text-to-image config again with the same model
        // config after the img2img pass — presumably to check that the reused
        // model config still works for plain txt2img.
        gen_img(&config, &mut model_config).unwrap();
    }

    /// Runs txt2img with a RealESRGAN upscaler, then reuses the same model
    /// config for a second prompt and for a batch of two written to `./`.
    ///
    /// Ignored by default: it downloads model weights and runs generation.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&config, &mut model_config).unwrap();
        let config2 = ConfigBuilder::from(config.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&config2, &mut model_config).unwrap();

        let config3 = ConfigBuilder::from(config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();

        gen_img(&config3, &mut model_config).unwrap();
    }
}