1use std::collections::HashMap;
2use std::ffi::CString;
3use std::ffi::c_char;
4use std::ffi::c_void;
5use std::path::Path;
6use std::path::PathBuf;
7use std::ptr::null_mut;
8use std::slice;
9
10use derive_builder::Builder;
11use diffusion_rs_sys::free_upscaler_ctx;
12use diffusion_rs_sys::new_upscaler_ctx;
13use diffusion_rs_sys::sd_ctx_params_t;
14use diffusion_rs_sys::sd_easycache_params_t;
15use diffusion_rs_sys::sd_embedding_t;
16use diffusion_rs_sys::sd_get_default_sample_method;
17use diffusion_rs_sys::sd_get_default_scheduler;
18use diffusion_rs_sys::sd_guidance_params_t;
19use diffusion_rs_sys::sd_image_t;
20use diffusion_rs_sys::sd_img_gen_params_t;
21use diffusion_rs_sys::sd_lora_t;
22use diffusion_rs_sys::sd_pm_params_t;
23use diffusion_rs_sys::sd_sample_params_t;
24use diffusion_rs_sys::sd_set_preview_callback;
25use diffusion_rs_sys::sd_slg_params_t;
26use diffusion_rs_sys::sd_tiling_params_t;
27use diffusion_rs_sys::upscaler_ctx_t;
28use image::ImageBuffer;
29use image::ImageError;
30use image::RgbImage;
31use libc::free;
32use thiserror::Error;
33use walkdir::DirEntry;
34use walkdir::WalkDir;
35
36use diffusion_rs_sys::free_sd_ctx;
37use diffusion_rs_sys::new_sd_ctx;
38use diffusion_rs_sys::sd_ctx_t;
39
40pub use diffusion_rs_sys::rng_type_t as RngFunction;
42
43pub use diffusion_rs_sys::sample_method_t as SampleMethod;
45
46pub use diffusion_rs_sys::scheduler_t as Scheduler;
48
49pub use diffusion_rs_sys::prediction_t as Prediction;
51
52pub use diffusion_rs_sys::sd_type_t as WeightType;
54
55pub use diffusion_rs_sys::preview_t as PreviewType;
57
58pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
60
/// File extensions (without the leading dot) that are accepted when scanning the
/// embeddings and LoRA directories. The comparison is an exact, case-sensitive match.
static VALID_EXT: [&str; 3] = ["pt", "safetensors", "gguf"];
62
/// Errors reported by the image generation entry points of this module.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum DiffusionError {
    /// The native image generation call returned a NULL result.
    #[error("The underling stablediffusion.cpp function returned NULL")]
    Forward,
    /// Encoding or writing an image to disk failed; wraps the `image` crate error.
    #[error(transparent)]
    StoreImages(#[from] ImageError),
    /// An upscale pass returned an image whose data pointer is NULL.
    #[error("The underling upscaler model returned a NULL image")]
    Upscaler,
}
74
/// Value forwarded (cast to `i32`) as the `clip_skip` field of the image generation
/// parameters.
// NOTE(review): presumably the number of trailing CLIP layers to ignore — confirm
// against the stable-diffusion.cpp documentation.
#[repr(i32)]
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
pub enum ClipSkip {
    /// Sent as `0`. This is the default.
    #[default]
    Unspecified = 0,
    /// Sent as `1`.
    None = 1,
    /// Sent as `2`.
    OneLayer = 2,
}
86
/// Embeddings state kept by the model configuration:
/// 0. the directory that was scanned,
/// 1. the owned `(name, path)` C strings, one pair per embedding file found,
/// 2. the FFI table whose entries hold raw pointers into the strings of element 1.
///
/// NOTE(review): cloning this tuple copies the raw pointers of element 2 verbatim, so
/// they keep pointing at the strings of the value that was cloned *from*.
type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
88
89#[derive(Default, Debug, Clone)]
90struct LoraStorage {
91 lora_model_dir: CLibPath,
92 data: Vec<(CLibPath, String, f32)>,
93 loras_t: Vec<sd_lora_t>,
94}
95
/// Description of a single LoRA to apply, as passed to `ModelConfigBuilder::lora_models`.
#[derive(Default, Debug, Clone)]
pub struct LoraSpec {
    /// File stem (file name without extension) looked up in the LoRA directory.
    pub file_name: String,
    /// Value stored in the `is_high_noise` field of the native LoRA descriptor.
    pub is_high_noise: bool,
    /// Value stored in the `multiplier` field of the native LoRA descriptor.
    pub multiplier: f32,
}
103
/// Model-level configuration: which weights to load and how the native contexts are
/// created. Built through `ModelConfigBuilder`; the builder rejects configurations where
/// neither `model` nor `diffusion_model` has been set.
///
/// The native diffusion and upscaler contexts are created lazily on first use, cached in
/// this struct and released when it is dropped.
///
/// NOTE(review): `Clone` is derived, so a clone copies the cached raw context pointers;
/// dropping both values would release the same native context twice — confirm whether
/// cloning after the first generation is meant to be supported.
#[derive(Builder, Debug, Clone)]
#[builder(
    setter(into, strip_option),
    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
)]
pub struct ModelConfig {
    /// Thread count handed to the native contexts. Defaults to the number of physical
    /// cores; the custom setter also falls back to that for non-positive values.
    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
    n_threads: i32,

    /// Forwarded as `offload_params_to_cpu` to both the diffusion and upscaler contexts.
    #[builder(default = "false")]
    offload_params_to_cpu: bool,

    /// Path of the upscaler model. When unset no upscaler context is created.
    #[builder(default = "Default::default()")]
    upscale_model: Option<CLibPath>,

    /// Number of consecutive upscale passes; `0` disables upscaling.
    #[builder(default = "1")]
    upscale_repeats: i32,

    /// Tile size passed to the upscaler context constructor.
    #[builder(default = "128")]
    upscale_tile_size: i32,

    /// Forwarded as `model_path`.
    #[builder(default = "Default::default()")]
    model: CLibPath,

    /// Forwarded as `diffusion_model_path`.
    #[builder(default = "Default::default()")]
    diffusion_model: CLibPath,

    /// Forwarded as `llm_path`.
    #[builder(default = "Default::default()")]
    llm: CLibPath,

    /// Forwarded as `llm_vision_path`.
    #[builder(default = "Default::default()")]
    llm_vision: CLibPath,

    /// Forwarded as `clip_l_path`.
    #[builder(default = "Default::default()")]
    clip_l: CLibPath,

    /// Forwarded as `clip_g_path`.
    #[builder(default = "Default::default()")]
    clip_g: CLibPath,

    /// Forwarded as `clip_vision_path`.
    #[builder(default = "Default::default()")]
    clip_vision: CLibPath,

    /// Forwarded as `t5xxl_path`.
    #[builder(default = "Default::default()")]
    t5xxl: CLibPath,

    /// Forwarded as `vae_path`.
    #[builder(default = "Default::default()")]
    vae: CLibPath,

    /// Forwarded as `taesd_path`.
    #[builder(default = "Default::default()")]
    taesd: CLibPath,

    /// Forwarded as `control_net_path`.
    #[builder(default = "Default::default()")]
    control_net: CLibPath,

    /// Embeddings discovered by the custom `embeddings` setter (directory, owned strings
    /// and the FFI table pointing into them).
    #[builder(default = "Default::default()", setter(custom))]
    embeddings: EmbeddingsStorage,

    /// Forwarded as `photo_maker_path`.
    #[builder(default = "Default::default()")]
    photo_maker: CLibPath,

    /// Forwarded as the `id_embed_path` of the PhotoMaker generation parameters.
    #[builder(default = "Default::default()")]
    pm_id_embed_path: CLibPath,

    /// Forwarded as `wtype`.
    // NOTE(review): the `SD_TYPE_COUNT` default presumably means "keep the type stored
    // in the model file" — confirm.
    #[builder(default = "WeightType::SD_TYPE_COUNT")]
    weight_type: WeightType,

    /// LoRAs selected by the custom `lora_models` setter.
    #[builder(default = "Default::default()", setter(custom))]
    lora_models: LoraStorage,

    /// Forwarded as `high_noise_diffusion_model_path`.
    #[builder(default = "Default::default()")]
    high_noise_diffusion_model: CLibPath,

    /// Forwarded as `enabled` of the VAE tiling parameters.
    #[builder(default = "false")]
    vae_tiling: bool,

    /// Forwarded as `(tile_size_x, tile_size_y)` of the VAE tiling parameters.
    #[builder(default = "(32,32)")]
    vae_tile_size: (i32, i32),

    /// Forwarded as `(rel_size_x, rel_size_y)` of the VAE tiling parameters.
    #[builder(default = "(0.,0.)")]
    vae_relative_tile_size: (f32, f32),

    /// Forwarded as `target_overlap` of the VAE tiling parameters.
    #[builder(default = "0.5")]
    vae_tile_overlap: f32,

    /// Forwarded as `rng_type`.
    #[builder(default = "RngFunction::CUDA_RNG")]
    rng: RngFunction,

    /// Forwarded as `sampler_rng_type`.
    // NOTE(review): the `RNG_TYPE_COUNT` default presumably means "same as `rng`" —
    // confirm.
    #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
    sampler_rng_type: RngFunction,

    /// Scheduler used for sampling. `SCHEDULER_COUNT` (the default) makes `gen_img` ask
    /// the native library for the context's default scheduler.
    #[builder(default = "Scheduler::SCHEDULER_COUNT")]
    scheduler: Scheduler,

    /// Forwarded as `prediction`.
    #[builder(default = "Prediction::PREDICTION_COUNT")]
    prediction: Prediction,

    /// Forwarded as `keep_vae_on_cpu`.
    #[builder(default = "false")]
    vae_on_cpu: bool,

    /// Forwarded as `keep_clip_on_cpu`.
    #[builder(default = "false")]
    clip_on_cpu: bool,

    /// Forwarded as `keep_control_net_on_cpu`.
    #[builder(default = "false")]
    control_net_cpu: bool,

    /// Forwarded as `diffusion_flash_attn`.
    #[builder(default = "false")]
    flash_attention: bool,

    /// Negated and forwarded as `chroma_use_dit_mask`.
    #[builder(default = "false")]
    chroma_disable_dit_mask: bool,

    /// Forwarded as `chroma_use_t5_mask`.
    #[builder(default = "false")]
    chroma_enable_t5_mask: bool,

    /// Forwarded as `chroma_t5_mask_pad`.
    #[builder(default = "1")]
    chroma_t5_mask_pad: i32,

    /// Forwarded as `diffusion_conv_direct`; also passed to the upscaler constructor.
    #[builder(default = "false")]
    diffusion_conv_direct: bool,

    /// Forwarded as `vae_conv_direct`.
    #[builder(default = "false")]
    vae_conv_direct: bool,

    /// Forwarded as `force_sdxl_vae_conv_scale`.
    #[builder(default = "false")]
    force_sdxl_vae_conv_scale: bool,

    /// Forwarded as `flow_shift`.
    // NOTE(review): the infinite default presumably acts as a "not set" sentinel —
    // confirm.
    #[builder(default = "f32::INFINITY")]
    flow_shift: f32,

    /// Forwarded as `shifted_timestep` of the sampling parameters.
    #[builder(default = "0")]
    timestep_shift: i32,

    /// Forwarded as `tae_preview_only`.
    #[builder(default = "false")]
    taesd_preview_only: bool,

    /// Forwarded as `lora_apply_mode`.
    #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
    lora_apply_mode: LoraModeType,

    /// Forwarded as `enabled` of the easy-cache parameters.
    #[builder(default = "false")]
    easy_cache: bool,

    /// Forwarded as `reuse_threshold` of the easy-cache parameters.
    #[builder(default = "0.2")]
    easy_cache_reuse_threshold: f32,

    /// Forwarded as `start_percent` of the easy-cache parameters.
    #[builder(default = "0.15")]
    easy_cache_start_percent: f32,

    /// Forwarded as `end_percent` of the easy-cache parameters.
    #[builder(default = "0.95")]
    easy_cache_end_percent: f32,

    /// Lazily created native upscaler context (internal cache).
    #[builder(default = "None", private)]
    upscaler_ctx: Option<*mut upscaler_ctx_t>,

    /// Lazily created native diffusion context plus the parameters it was created with
    /// (internal cache).
    #[builder(default = "None", private)]
    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
}
314
315impl ModelConfigBuilder {
316 fn validate(&self) -> Result<(), ConfigBuilderError> {
317 self.validate_model()
318 }
319
320 fn validate_model(&self) -> Result<(), ConfigBuilderError> {
321 self.model
322 .as_ref()
323 .or(self.diffusion_model.as_ref())
324 .map(|_| ())
325 .ok_or(ConfigBuilderError::UninitializedField(
326 "Model OR DiffusionModel must be valorized",
327 ))
328 }
329
330 fn filter_valid_extensions(&self, path: &Path) -> impl Iterator<Item = DirEntry> {
331 WalkDir::new(path)
332 .into_iter()
333 .filter_map(|entry| entry.ok())
334 .filter(|entry| {
335 entry
336 .path()
337 .extension()
338 .and_then(|ext| ext.to_str())
339 .map(|ext_str| VALID_EXT.contains(&ext_str))
340 .unwrap_or(false)
341 })
342 }
343
344 fn build_single_lora_storage(
345 spec: &LoraSpec,
346 is_high_noise: bool,
347 valid_loras: &HashMap<String, PathBuf>,
348 ) -> ((CLibPath, String, f32), sd_lora_t) {
349 let path = valid_loras.get(&spec.file_name).unwrap().as_path();
350 let c_path = CLibPath::from(path);
351 let lora = sd_lora_t {
352 is_high_noise,
353 multiplier: spec.multiplier,
354 path: c_path.as_ptr(),
355 };
356 let data = (c_path, spec.file_name.clone(), spec.multiplier);
357 (data, lora)
358 }
359
360 pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
361 let data: Vec<(CLibString, CLibPath)> = self
362 .filter_valid_extensions(embeddings_dir)
363 .map(|entry| {
364 let file_stem = entry
365 .path()
366 .file_stem()
367 .and_then(|stem| stem.to_str())
368 .unwrap_or_default()
369 .to_owned();
370 (CLibString::from(file_stem), CLibPath::from(entry.path()))
371 })
372 .collect();
373 let data_pointer = data
374 .iter()
375 .map(|(name, path)| sd_embedding_t {
376 name: name.as_ptr(),
377 path: path.as_ptr(),
378 })
379 .collect();
380 self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
381 self
382 }
383
384 pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
385 let valid_loras: HashMap<String, PathBuf> = self
386 .filter_valid_extensions(lora_model_dir)
387 .map(|entry| {
388 let path = entry.path();
389 (
390 path.file_stem()
391 .and_then(|stem| stem.to_str())
392 .unwrap_or_default()
393 .to_owned(),
394 path.to_path_buf(),
395 )
396 })
397 .collect();
398 let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
399 let standard = specs
400 .iter()
401 .filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
402 .map(|s| Self::build_single_lora_storage(s, false, &valid_loras));
403 let high_noise = specs
404 .iter()
405 .filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
406 .map(|s| Self::build_single_lora_storage(s, true, &valid_loras));
407
408 let mut data = Vec::new();
409 let mut loras_t = Vec::new();
410 for lora in standard.chain(high_noise) {
411 data.push(lora.0);
412 loras_t.push(lora.1);
413 }
414
415 self.lora_models = Some(LoraStorage {
416 lora_model_dir: lora_model_dir.into(),
417 data,
418 loras_t,
419 });
420 self
421 }
422
423 pub fn n_threads(&mut self, value: i32) -> &mut Self {
424 self.n_threads = if value > 0 {
425 Some(value)
426 } else {
427 Some(num_cpus::get_physical() as i32)
428 };
429 self
430 }
431}
432
433impl ModelConfig {
434 unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
435 unsafe {
436 if self.upscale_model.is_none() || self.upscale_repeats == 0 {
437 None
438 } else {
439 if self.upscaler_ctx.is_none() {
440 let upscaler = new_upscaler_ctx(
441 self.upscale_model.as_ref().unwrap().as_ptr(),
442 self.offload_params_to_cpu,
443 self.diffusion_conv_direct,
444 self.n_threads,
445 self.upscale_tile_size,
446 );
447 self.upscaler_ctx = Some(upscaler);
448 }
449 self.upscaler_ctx
450 }
451 }
452 }
453
454 unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
455 unsafe {
456 if self.diffusion_ctx.is_none() {
457 let sd_ctx_params = sd_ctx_params_t {
458 model_path: self.model.as_ptr(),
459 llm_path: self.llm.as_ptr(),
460 llm_vision_path: self.llm_vision.as_ptr(),
461 clip_l_path: self.clip_l.as_ptr(),
462 clip_g_path: self.clip_g.as_ptr(),
463 clip_vision_path: self.clip_vision.as_ptr(),
464 high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
465 t5xxl_path: self.t5xxl.as_ptr(),
466 diffusion_model_path: self.diffusion_model.as_ptr(),
467 vae_path: self.vae.as_ptr(),
468 taesd_path: self.taesd.as_ptr(),
469 control_net_path: self.control_net.as_ptr(),
470 lora_model_dir: self.lora_models.lora_model_dir.as_ptr(),
471 embeddings: self.embeddings.2.as_ptr(),
472 embedding_count: self.embeddings.1.len() as u32,
473 photo_maker_path: self.photo_maker.as_ptr(),
474 vae_decode_only,
475 free_params_immediately: false,
476 n_threads: self.n_threads,
477 wtype: self.weight_type,
478 rng_type: self.rng,
479 keep_clip_on_cpu: self.clip_on_cpu,
480 keep_control_net_on_cpu: self.control_net_cpu,
481 keep_vae_on_cpu: self.vae_on_cpu,
482 diffusion_flash_attn: self.flash_attention,
483 diffusion_conv_direct: self.diffusion_conv_direct,
484 chroma_use_dit_mask: !self.chroma_disable_dit_mask,
485 chroma_use_t5_mask: self.chroma_enable_t5_mask,
486 chroma_t5_mask_pad: self.chroma_t5_mask_pad,
487 vae_conv_direct: self.vae_conv_direct,
488 offload_params_to_cpu: self.offload_params_to_cpu,
489 flow_shift: self.flow_shift,
490 prediction: self.prediction,
491 force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
492 tae_preview_only: self.taesd_preview_only,
493 lora_apply_mode: self.lora_apply_mode,
494 tensor_type_rules: null_mut(),
495 sampler_rng_type: self.sampler_rng_type,
496 };
497 let ctx = new_sd_ctx(&sd_ctx_params);
498 self.diffusion_ctx = Some((ctx, sd_ctx_params))
499 }
500 self.diffusion_ctx.unwrap().0
501 }
502 }
503}
504
505impl Drop for ModelConfig {
506 fn drop(&mut self) {
507 unsafe {
509 if let Some((sd_ctx, _)) = self.diffusion_ctx {
510 free_sd_ctx(sd_ctx);
511 }
512
513 if let Some(upscaler_ctx) = self.upscaler_ctx {
514 free_upscaler_ctx(upscaler_ctx);
515 }
516 }
517 }
518}
519
520impl From<ModelConfig> for ModelConfigBuilder {
521 fn from(value: ModelConfig) -> Self {
522 let mut builder = ModelConfigBuilder::default();
523 builder
524 .n_threads(value.n_threads)
525 .offload_params_to_cpu(value.offload_params_to_cpu)
526 .upscale_repeats(value.upscale_repeats)
527 .model(value.model.clone())
528 .diffusion_model(value.diffusion_model.clone())
529 .llm(value.llm.clone())
530 .llm_vision(value.llm_vision.clone())
531 .clip_l(value.clip_l.clone())
532 .clip_g(value.clip_g.clone())
533 .clip_vision(value.clip_vision.clone())
534 .t5xxl(value.t5xxl.clone())
535 .vae(value.vae.clone())
536 .taesd(value.taesd.clone())
537 .control_net(value.control_net.clone())
538 .embeddings(&value.embeddings.0)
539 .photo_maker(value.photo_maker.clone())
540 .pm_id_embed_path(value.pm_id_embed_path.clone())
541 .weight_type(value.weight_type)
542 .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
543 .vae_tiling(value.vae_tiling)
544 .vae_tile_size(value.vae_tile_size)
545 .vae_relative_tile_size(value.vae_relative_tile_size)
546 .vae_tile_overlap(value.vae_tile_overlap)
547 .rng(value.rng)
548 .sampler_rng_type(value.rng)
549 .scheduler(value.scheduler)
550 .prediction(value.prediction)
551 .vae_on_cpu(value.vae_on_cpu)
552 .clip_on_cpu(value.clip_on_cpu)
553 .control_net(value.control_net.clone())
554 .control_net_cpu(value.control_net_cpu)
555 .flash_attention(value.flash_attention)
556 .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
557 .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
558 .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
559 .diffusion_conv_direct(value.diffusion_conv_direct)
560 .vae_conv_direct(value.vae_conv_direct)
561 .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
562 .flow_shift(value.flow_shift)
563 .timestep_shift(value.timestep_shift)
564 .taesd_preview_only(value.taesd_preview_only)
565 .lora_apply_mode(value.lora_apply_mode);
566
567 let lora_model_dir = Into::<PathBuf>::into(&value.lora_models.lora_model_dir);
568 let lora_specs = value
569 .lora_models
570 .data
571 .iter()
572 .map(|(_, name, multiplier)| LoraSpec {
573 file_name: name.clone(),
574 is_high_noise: false,
575 multiplier: *multiplier,
576 })
577 .collect();
578
579 builder.lora_models(&lora_model_dir, lora_specs);
580
581 if let Some(model) = &value.upscale_model {
582 builder.upscale_model(model.clone());
583 }
584 builder
585 }
586}
587
/// Per-generation configuration (prompt, image size, sampling settings and output
/// location). Built through `ConfigBuilder`; `prompt` is the only field without a
/// default.
#[derive(Builder, Debug, Clone)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
pub struct Config {
    /// Directory of PhotoMaker identity images.
    // NOTE(review): not read by `gen_img` in this file.
    #[builder(default = "Default::default()")]
    pm_id_images_dir: CLibPath,

    /// Path of an initial image.
    // NOTE(review): not read by `gen_img` in this file, which always passes an empty
    // init image.
    #[builder(default = "Default::default()")]
    init_img: CLibPath,

    /// Path of a control image.
    // NOTE(review): not read by `gen_img` in this file, which always passes an empty
    // control image.
    #[builder(default = "Default::default()")]
    control_image: CLibPath,

    /// Output location: a file when `batch_count` is 1, otherwise an existing directory
    /// (enforced by the builder validation).
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,

    /// File the preview callback writes to when previews are enabled.
    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
    preview_output: PathBuf,

    /// Preview mode; `PREVIEW_NONE` (the default) registers no preview callback.
    #[builder(default = "PreviewType::PREVIEW_NONE")]
    preview_mode: PreviewType,

    /// Passed to the preview callback registration as the "noisy" flag; its negation is
    /// passed as the "denoised" flag.
    #[builder(default = "false")]
    preview_noisy: bool,

    /// Interval passed to the preview callback registration.
    #[builder(default = "1")]
    preview_interval: i32,

    /// Prompt to render. Also used to name the output files when `batch_count` > 1.
    prompt: String,

    /// Negative prompt.
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,

    /// Forwarded as both `txt_cfg` and `img_cfg` of the guidance parameters.
    #[builder(default = "7.0")]
    cfg_scale: f32,

    /// Forwarded as `distilled_guidance` of the guidance parameters.
    #[builder(default = "3.5")]
    guidance: f32,

    /// Forwarded as `strength`.
    #[builder(default = "0.75")]
    strength: f32,

    /// Forwarded as `style_strength` of the PhotoMaker parameters.
    #[builder(default = "20.0")]
    pm_style_strength: f32,

    /// Forwarded as `control_strength`.
    #[builder(default = "0.9")]
    control_strength: f32,

    /// Image height, forwarded as `height`.
    #[builder(default = "512")]
    height: i32,

    /// Image width, forwarded as `width`.
    #[builder(default = "512")]
    width: i32,

    /// Sampling method. `SAMPLE_METHOD_COUNT` (the default) makes `gen_img` ask the
    /// native library for the context's default method.
    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
    sampling_method: SampleMethod,

    /// Forwarded as `eta` of the sampling parameters.
    #[builder(default = "0.")]
    eta: f32,

    /// Forwarded as `sample_steps`.
    #[builder(default = "20")]
    steps: i32,

    /// Forwarded as `seed`.
    #[builder(default = "42")]
    seed: i64,

    /// Number of images to generate.
    #[builder(default = "1")]
    batch_count: i32,

    /// Forwarded (as `i32`) as `clip_skip`.
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,

    /// Canny preprocessing flag.
    // NOTE(review): not read by `gen_img` in this file.
    #[builder(default = "false")]
    canny: bool,

    /// Forwarded as `scale` of the skip-layer-guidance parameters.
    #[builder(default = "0.")]
    slg_scale: f32,

    /// Forwarded as `layers` of the skip-layer-guidance parameters.
    #[builder(default = "vec![7, 8, 9]")]
    skip_layer: Vec<i32>,

    /// Forwarded as `layer_start` of the skip-layer-guidance parameters.
    #[builder(default = "0.01")]
    skip_layer_start: f32,

    /// Forwarded as `layer_end` of the skip-layer-guidance parameters.
    #[builder(default = "0.2")]
    skip_layer_end: f32,

    /// Controls the `auto_resize_ref_image` generation parameter.
    // NOTE(review): `gen_img` currently forwards this value without negating it —
    // confirm the intended polarity.
    #[builder(default = "false")]
    disable_auto_resize_ref_image: bool,
}
712
713impl ConfigBuilder {
714 fn validate(&self) -> Result<(), ConfigBuilderError> {
715 self.validate_output_dir()
716 }
717
718 fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
719 let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
720 let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
721 if is_dir == multiple_items {
722 Ok(())
723 } else {
724 Err(ConfigBuilderError::ValidationError(
725 "When batch_count > 1, output should point to folder and vice versa".to_owned(),
726 ))
727 }
728 }
729}
730
731impl From<Config> for ConfigBuilder {
732 fn from(value: Config) -> Self {
733 let mut builder = ConfigBuilder::default();
734 builder
735 .pm_id_images_dir(value.pm_id_images_dir)
736 .init_img(value.init_img)
737 .control_image(value.control_image)
738 .output(value.output)
739 .prompt(value.prompt)
740 .negative_prompt(value.negative_prompt)
741 .cfg_scale(value.cfg_scale)
742 .strength(value.strength)
743 .pm_style_strength(value.pm_style_strength)
744 .control_strength(value.control_strength)
745 .height(value.height)
746 .width(value.width)
747 .sampling_method(value.sampling_method)
748 .steps(value.steps)
749 .seed(value.seed)
750 .batch_count(value.batch_count)
751 .clip_skip(value.clip_skip)
752 .slg_scale(value.slg_scale)
753 .skip_layer(value.skip_layer)
754 .skip_layer_start(value.skip_layer_start)
755 .skip_layer_end(value.skip_layer_end)
756 .canny(value.canny)
757 .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
758 .preview_output(value.preview_output)
759 .preview_mode(value.preview_mode)
760 .preview_noisy(value.preview_noisy)
761 .preview_interval(value.preview_interval);
762 builder
763 }
764}
765
/// Owned, NUL-terminated string that can be handed to the native library.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Pointer to the NUL-terminated bytes; valid for as long as `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        let Self(inner) = self;
        inner.as_ptr()
    }
}

impl From<&str> for CLibString {
    /// # Panics
    /// Panics when `value` contains an interior NUL byte.
    fn from(value: &str) -> Self {
        let c_string = CString::new(value.as_bytes()).unwrap();
        Self(c_string)
    }
}

impl From<String> for CLibString {
    /// # Panics
    /// Panics when `value` contains an interior NUL byte.
    fn from(value: String) -> Self {
        let c_string = CString::new(value.into_bytes()).unwrap();
        Self(c_string)
    }
}
786
/// Owned, NUL-terminated file system path that can be handed to the native library.
/// The default value is the empty string.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Pointer to the NUL-terminated bytes; valid for as long as `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        let Self(inner) = self;
        inner.as_ptr()
    }
}

impl From<PathBuf> for CLibPath {
    /// Same conversion as the one from `&Path`.
    fn from(value: PathBuf) -> Self {
        Self::from(value.as_path())
    }
}

impl From<&Path> for CLibPath {
    /// Paths that are not valid UTF-8 are converted to the empty string.
    ///
    /// # Panics
    /// Panics when the path contains an interior NUL byte.
    fn from(value: &Path) -> Self {
        let utf8 = value.to_str().unwrap_or("");
        Self(CString::new(utf8).unwrap())
    }
}

impl From<&CLibPath> for PathBuf {
    /// # Panics
    /// Panics when the stored bytes are not valid UTF-8.
    fn from(value: &CLibPath) -> Self {
        value.0.to_str().map(PathBuf::from).unwrap()
    }
}
813
/// Computes the file each generated image is written to.
///
/// * `batch_size == 1`: `path` itself is the single output file.
/// * otherwise: `path` is treated as a directory and one `"{prompt}_{id}.png"` file per
///   image is produced, with `id` counting from 1 (an empty list for `batch_size < 1`).
///
/// Path separators inside `prompt` are replaced by `_` in the file name. Without this,
/// a prompt such as `"a/b"` would point into a sub-directory that does not exist and a
/// prompt starting with a separator would make `Path::join` discard `path` entirely.
fn output_files(path: &Path, prompt: &str, batch_size: i32) -> Vec<PathBuf> {
    if batch_size == 1 {
        return vec![path.to_path_buf()];
    }

    let stem: String = prompt
        .chars()
        .map(|c| if std::path::is_separator(c) { '_' } else { c })
        .collect();

    (1..=batch_size)
        .map(|id| path.join(format!("{stem}_{id}.png")))
        .collect()
}
823
824unsafe fn upscale(
825 upscale_repeats: i32,
826 upscaler_ctx: Option<*mut upscaler_ctx_t>,
827 data: sd_image_t,
828) -> Result<sd_image_t, DiffusionError> {
829 unsafe {
830 match upscaler_ctx {
831 Some(upscaler_ctx) => {
832 let upscale_factor = 4; let mut current_image = data;
834 for _ in 0..upscale_repeats {
835 let upscaled_image =
836 diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
837
838 if upscaled_image.data.is_null() {
839 return Err(DiffusionError::Upscaler);
840 }
841
842 free(current_image.data as *mut c_void);
843 current_image = upscaled_image;
844 }
845 Ok(current_image)
846 }
847 None => Ok(data),
848 }
849 }
850}
851
852pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
854 let prompt: CLibString = CLibString::from(config.prompt.as_str());
855 let files = output_files(&config.output, &config.prompt, config.batch_count);
856 unsafe {
857 let sd_ctx = model_config.diffusion_ctx(true);
858 let upscaler_ctx = model_config.upscaler_ctx();
859 let init_image = sd_image_t {
860 width: 0,
861 height: 0,
862 channel: 3,
863 data: null_mut(),
864 };
865 let mask_image = sd_image_t {
866 width: config.width as u32,
867 height: config.height as u32,
868 channel: 1,
869 data: null_mut(),
870 };
871 let mut layers = config.skip_layer.clone();
872 let guidance = sd_guidance_params_t {
873 txt_cfg: config.cfg_scale,
874 img_cfg: config.cfg_scale,
875 distilled_guidance: config.guidance,
876 slg: sd_slg_params_t {
877 layers: layers.as_mut_ptr(),
878 layer_count: config.skip_layer.len(),
879 layer_start: config.skip_layer_start,
880 layer_end: config.skip_layer_end,
881 scale: config.slg_scale,
882 },
883 };
884 let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
885 sd_get_default_scheduler(sd_ctx)
886 } else {
887 model_config.scheduler
888 };
889 let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
890 sd_get_default_sample_method(sd_ctx)
891 } else {
892 config.sampling_method
893 };
894 let sample_params = sd_sample_params_t {
895 guidance,
896 sample_method,
897 sample_steps: config.steps,
898 eta: config.eta,
899 scheduler,
900 shifted_timestep: model_config.timestep_shift,
901 };
902 let control_image = sd_image_t {
903 width: 0,
904 height: 0,
905 channel: 3,
906 data: null_mut(),
907 };
908 let vae_tiling_params = sd_tiling_params_t {
909 enabled: model_config.vae_tiling,
910 tile_size_x: model_config.vae_tile_size.0,
911 tile_size_y: model_config.vae_tile_size.1,
912 target_overlap: model_config.vae_tile_overlap,
913 rel_size_x: model_config.vae_relative_tile_size.0,
914 rel_size_y: model_config.vae_relative_tile_size.1,
915 };
916 let pm_params = sd_pm_params_t {
917 id_images: null_mut(),
918 id_images_count: 0,
919 id_embed_path: model_config.pm_id_embed_path.as_ptr(),
920 style_strength: config.pm_style_strength,
921 };
922
923 unsafe extern "C" fn save_preview_local(
924 _step: ::std::os::raw::c_int,
925 _frame_count: ::std::os::raw::c_int,
926 frames: *mut sd_image_t,
927 _is_noisy: bool,
928 data: *mut ::std::os::raw::c_void,
929 ) {
930 unsafe {
931 let path = &*data.cast::<PathBuf>();
932 let _ = save_img(*frames, path);
933 }
934 }
935
936 if config.preview_mode != PreviewType::PREVIEW_NONE {
937 let data = &config.preview_output as *const PathBuf;
938
939 sd_set_preview_callback(
940 Some(save_preview_local),
941 config.preview_mode,
942 config.preview_interval,
943 !config.preview_noisy,
944 config.preview_noisy,
945 data as *mut c_void,
946 );
947 }
948
949 let easy_cache = sd_easycache_params_t {
950 enabled: model_config.easy_cache,
951 reuse_threshold: model_config.easy_cache_reuse_threshold,
952 start_percent: model_config.easy_cache_start_percent,
953 end_percent: model_config.easy_cache_end_percent,
954 };
955
956 let sd_img_gen_params = sd_img_gen_params_t {
957 prompt: prompt.as_ptr(),
958 negative_prompt: config.negative_prompt.as_ptr(),
959 clip_skip: config.clip_skip as i32,
960 init_image,
961 ref_images: null_mut(),
962 ref_images_count: 0,
963 increase_ref_index: false,
964 mask_image,
965 width: config.width,
966 height: config.height,
967 sample_params,
968 strength: config.strength,
969 seed: config.seed,
970 batch_count: config.batch_count,
971 control_image,
972 control_strength: config.control_strength,
973 pm_params,
974 vae_tiling_params,
975 auto_resize_ref_image: config.disable_auto_resize_ref_image,
976 easycache: easy_cache,
977 loras: model_config.lora_models.loras_t.as_ptr(),
978 lora_count: model_config.lora_models.loras_t.len() as u32,
979 };
980 let slice = diffusion_rs_sys::generate_image(sd_ctx, &sd_img_gen_params);
981 let ret = {
982 if slice.is_null() {
983 return Err(DiffusionError::Forward);
984 }
985 for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
986 .iter()
987 .zip(files)
988 {
989 match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
990 Ok(img) => save_img(img, &path)?,
991 Err(err) => {
992 return Err(err);
993 }
994 }
995 }
996 Ok(())
997 };
998 free(slice as *mut c_void);
999 ret
1000 }
1001}
1002
1003fn save_img(img: sd_image_t, path: &Path) -> Result<(), DiffusionError> {
1004 let len = (img.width * img.height * img.channel) as usize;
1006 let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
1007 let save_state = ImageBuffer::from_raw(img.width, img.height, buffer)
1008 .map(|img| RgbImage::from(img).save(path));
1009 if let Some(Err(err)) = save_state {
1010 return Err(DiffusionError::StoreImages(err));
1011 }
1012 Ok(())
1013}
1014
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };

    use super::{ConfigBuilder, gen_img};

    /// Exercises the builder validation rules without touching the native library.
    #[test]
    fn test_required_args_txt2img() {
        // Both builders have a mandatory piece: the prompt, and a model path.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // More than one image with the default (file) output must be rejected.
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));

        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();

        // More than one image is accepted when the output is an existing directory.
        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end generation with upscaling. Ignored by default because it downloads
    /// model files and runs the native library.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();

        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();

        gen_img(&config, &mut model_config).unwrap();
        // Reuse the same model configuration with a modified generation config.
        let config2 = ConfigBuilder::from(config.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&config2, &mut model_config).unwrap();

        // Batch generation writes into a directory.
        let config3 = ConfigBuilder::from(config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();

        gen_img(&config3, &mut model_config).unwrap();
    }
}