use std::ffi::c_char;
use std::ffi::c_void;
use std::ffi::CString;
use std::path::Path;
use std::path::PathBuf;
use std::ptr::null;
use std::slice;
use derive_builder::Builder;
use diffusion_rs_sys::free_upscaler_ctx;
use diffusion_rs_sys::new_upscaler_ctx;
use diffusion_rs_sys::sd_image_t;
use diffusion_rs_sys::upscaler_ctx_t;
use libc::free;
use thiserror::Error;
use diffusion_rs_sys::free_sd_ctx;
use diffusion_rs_sys::new_sd_ctx;
use diffusion_rs_sys::sd_ctx_t;
use diffusion_rs_sys::stbi_write_png_custom;
pub use diffusion_rs_sys::rng_type_t as RngFunction;
pub use diffusion_rs_sys::sample_method_t as SampleMethod;
pub use diffusion_rs_sys::schedule_t as Schedule;
pub use diffusion_rs_sys::sd_type_t as WeightType;
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum DiffusionError {
#[error("The underling stablediffusion.cpp function returned NULL")]
Forward,
#[error("The underling stbi_write_image function returned 0 while saving image {0}/{1})")]
StoreImages(usize, i32),
#[error("The underling upsclaer model returned a NULL image")]
Upscaler,
}
/// How many trailing CLIP layers to skip.
///
/// The enum is `#[repr(i32)]` and its discriminant is what gets handed to the C side
/// (`clip_skip as i32`), so the explicit values below are part of the contract.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[non_exhaustive]
#[repr(i32)]
pub enum ClipSkip {
    /// Discriminant `0`; the default. NOTE(review): presumably lets the backend pick — confirm.
    #[default]
    Unspecified = 0,
    /// Discriminant `1`.
    None = 1,
    /// Discriminant `2`.
    OneLayer = 2,
}
/// Parameters for `txt2img`, assembled with [`ConfigBuilder`].
///
/// `prompt` has no default, and `build` also rejects a configuration in which neither
/// `model` nor `diffusion_model` is set. Path fields default to an empty C string, which is
/// what gets handed to the C side when they are left unset.
#[derive(Builder, Debug, Clone)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
pub struct Config {
    /// Thread count passed to both the diffusion and the upscaler context. Defaults to the
    /// number of physical cores; the custom setter also maps values `<= 0` to that number.
    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
    n_threads: i32,
    /// Path passed to `new_sd_ctx`. Either this or `diffusion_model` must be set.
    #[builder(default = "Default::default()")]
    model: CLibPath,
    /// Path passed to `new_sd_ctx`. Either this or `model` must be set.
    #[builder(default = "Default::default()")]
    diffusion_model: CLibPath,
    /// Path passed to `new_sd_ctx`.
    #[builder(default = "Default::default()")]
    clip_l: CLibPath,
    /// Path passed to `new_sd_ctx`.
    #[builder(default = "Default::default()")]
    clip_g: CLibPath,
    /// Path passed to `new_sd_ctx`.
    #[builder(default = "Default::default()")]
    t5xxl: CLibPath,
    /// Path passed to `new_sd_ctx`.
    #[builder(default = "Default::default()")]
    vae: CLibPath,
    /// Path passed to `new_sd_ctx`.
    #[builder(default = "Default::default()")]
    taesd: CLibPath,
    /// Path passed to `new_sd_ctx`.
    #[builder(default = "Default::default()")]
    control_net: CLibPath,
    /// Path passed to `new_sd_ctx`.
    #[builder(default = "Default::default()")]
    embeddings: CLibPath,
    /// Path passed to `new_sd_ctx`.
    #[builder(default = "Default::default()")]
    stacked_id_embd: CLibPath,
    /// Path passed to the C `txt2img` call.
    #[builder(default = "Default::default()")]
    input_id_images: CLibPath,
    /// Flag passed to the C `txt2img` call.
    #[builder(default = "false")]
    normalize_input: bool,
    /// Upscaler model. An upscaler context is only created when this is set and
    /// `upscale_repeats` is non-zero.
    #[builder(default = "Default::default()")]
    upscale_model: Option<CLibPath>,
    /// Number of upscaling passes applied to every generated image (`0` disables upscaling).
    #[builder(default = "0")]
    upscale_repeats: i32,
    /// Weight type passed to both contexts. NOTE(review): the `SD_TYPE_COUNT` default is
    /// presumably a sentinel meaning "keep the model's own type" — confirm upstream.
    #[builder(default = "WeightType::SD_TYPE_COUNT")]
    weight_type: WeightType,
    /// Path passed to `new_sd_ctx` after `control_net`. NOTE(review): confirm whether the C
    /// side expects a single LoRA file or a directory.
    #[builder(default = "Default::default()")]
    lora_model: CLibPath,
    /// Not read by `txt2img` in this module.
    #[builder(default = "Default::default()")]
    init_img: CLibPath,
    /// Not read by `txt2img` in this module.
    #[builder(default = "Default::default()")]
    control_image: CLibPath,
    /// Destination of the generated images: a file when `batch_count == 1`, otherwise a
    /// directory receiving `output_1.png`, `output_2.png`, ... (`build` enforces the pairing).
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,
    /// Positive prompt; required (no default).
    prompt: CLibString,
    /// Negative prompt passed to the C `txt2img` call; empty by default.
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,
    /// Passed to the C `txt2img` call.
    #[builder(default = "7.0")]
    cfg_scale: f32,
    /// Passed to the C `txt2img` call.
    #[builder(default = "3.5")]
    guidance: f32,
    /// Not read by `txt2img` in this module.
    #[builder(default = "0.75")]
    strength: f32,
    /// Passed to the C `txt2img` call.
    #[builder(default = "20.0")]
    style_ratio: f32,
    /// Passed to the C `txt2img` call.
    #[builder(default = "0.9")]
    control_strength: f32,
    /// Image height passed to the C `txt2img` call.
    #[builder(default = "512")]
    height: i32,
    /// Image width passed to the C `txt2img` call.
    #[builder(default = "512")]
    width: i32,
    /// Sampler passed to the C `txt2img` call.
    #[builder(default = "SampleMethod::EULER_A")]
    sampling_method: SampleMethod,
    /// Step count passed to the C `txt2img` call.
    #[builder(default = "20")]
    steps: i32,
    /// RNG selector passed to `new_sd_ctx`.
    #[builder(default = "RngFunction::CUDA_RNG")]
    rng: RngFunction,
    /// Seed passed to the C `txt2img` call.
    #[builder(default = "42")]
    seed: i64,
    /// Number of images to generate; values above 1 require `output` to be a directory.
    #[builder(default = "1")]
    batch_count: i32,
    /// Scheduler passed to `new_sd_ctx`.
    #[builder(default = "Schedule::DEFAULT")]
    schedule: Schedule,
    /// Cast to `i32` and passed to the C `txt2img` call.
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,
    /// Flag passed to `new_sd_ctx`.
    #[builder(default = "false")]
    vae_tiling: bool,
    /// Flag passed to `new_sd_ctx`.
    #[builder(default = "false")]
    vae_on_cpu: bool,
    /// Flag passed to `new_sd_ctx`.
    #[builder(default = "false")]
    clip_on_cpu: bool,
    /// Flag passed to `new_sd_ctx`.
    #[builder(default = "false")]
    control_net_cpu: bool,
    /// Not read by `txt2img` in this module.
    #[builder(default = "false")]
    canny: bool,
}
impl ConfigBuilder {
    /// Sets the number of worker threads.
    ///
    /// Non-positive values fall back to the number of physical CPU cores.
    pub fn n_threads(&mut self, value: i32) -> &mut Self {
        let threads = if value > 0 {
            value
        } else {
            num_cpus::get_physical() as i32
        };
        self.n_threads = Some(threads);
        self
    }

    /// Cross-field validation run by the generated `build`.
    fn validate(&self) -> Result<(), ConfigBuilderError> {
        self.validate_model()?;
        self.validate_batch_count()?;
        self.validate_output_dir()
    }

    /// Requires at least one of `model` / `diffusion_model`.
    fn validate_model(&self) -> Result<(), ConfigBuilderError> {
        self.model
            .as_ref()
            .or(self.diffusion_model.as_ref())
            .map(|_| ())
            .ok_or(ConfigBuilderError::UninitializedField(
                "Model OR DiffusionModel must be valorized",
            ))
    }

    /// Rejects a non-positive `batch_count`.
    ///
    /// `txt2img` turns the count into a slice length over the C result array. A negative
    /// value would wrap to a huge `usize`, and zero images is never a meaningful request.
    fn validate_batch_count(&self) -> Result<(), ConfigBuilderError> {
        match self.batch_count {
            Some(count) if count < 1 => Err(ConfigBuilderError::ValidationError(format!(
                "batch_count must be at least 1, got {count}"
            ))),
            _ => Ok(()),
        }
    }

    /// Several images need a directory as `output`, and a single image needs a non-directory.
    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
        if is_dir == multiple_items {
            Ok(())
        } else {
            Err(ConfigBuilderError::ValidationError(
                "When batch_count > 1, output should point to a folder and vice versa".to_owned(),
            ))
        }
    }
}
impl Config {
    /// Creates a stable-diffusion.cpp context from this configuration.
    ///
    /// Returns the raw pointer from `new_sd_ctx` unchecked. Callers must treat NULL as a
    /// failure and release a non-NULL result with `free_sd_ctx`.
    ///
    /// # Safety
    ///
    /// The string pointers handed to C borrow from `self`. NOTE(review): this assumes the C
    /// side copies them and does not keep the pointers after returning — confirm.
    unsafe fn build_sd_ctx(&self, vae_decode_only: bool) -> *mut sd_ctx_t {
        new_sd_ctx(
            self.model.as_ptr(),
            self.clip_l.as_ptr(),
            self.clip_g.as_ptr(),
            self.t5xxl.as_ptr(),
            self.diffusion_model.as_ptr(),
            self.vae.as_ptr(),
            self.taesd.as_ptr(),
            self.control_net.as_ptr(),
            self.lora_model.as_ptr(),
            self.embeddings.as_ptr(),
            self.stacked_id_embd.as_ptr(),
            vae_decode_only,
            self.vae_tiling,
            // NOTE(review): presumably `free_params_immediately` — confirm against the bindings.
            true,
            self.n_threads,
            self.weight_type,
            self.rng,
            self.schedule,
            self.clip_on_cpu,
            self.control_net_cpu,
            self.vae_on_cpu,
        )
    }

    /// Creates the upscaler context if upscaling is requested.
    ///
    /// Returns `None` when no `upscale_model` is set or `upscale_repeats` is not positive.
    /// In that case no pass would run, so loading the model would be wasted work. Otherwise
    /// returns the pointer from `new_upscaler_ctx`, which is not NULL-checked here.
    ///
    /// # Safety
    ///
    /// Same pointer-borrowing caveat as [`Config::build_sd_ctx`].
    unsafe fn upscaler_ctx(&self) -> Option<*mut upscaler_ctx_t> {
        let upscale_model = self.upscale_model.as_ref()?;
        if self.upscale_repeats <= 0 {
            return None;
        }
        Some(new_upscaler_ctx(
            upscale_model.as_ptr(),
            self.n_threads,
            self.weight_type,
        ))
    }
}
/// Owned, NUL-terminated string handed to the C API by pointer.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);

impl CLibString {
    /// Borrows the string as a C pointer; valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Builds the C string from `bytes`, cutting it at the first interior NUL.
    ///
    /// A C consumer stops reading at the first NUL anyway, so this yields exactly the text
    /// the C side would see. It replaces the previous panic on user-supplied prompts.
    fn truncated_at_nul(bytes: Vec<u8>) -> Self {
        let c_string = CString::new(bytes).unwrap_or_else(|err| {
            let nul_at = err.nul_position();
            let mut bytes = err.into_vec();
            bytes.truncate(nul_at);
            CString::new(bytes).expect("no NUL byte can precede the first NUL position")
        });
        Self(c_string)
    }
}

impl From<&str> for CLibString {
    fn from(value: &str) -> Self {
        Self::truncated_at_nul(value.as_bytes().to_vec())
    }
}

impl From<String> for CLibString {
    fn from(value: String) -> Self {
        Self::truncated_at_nul(value.into_bytes())
    }
}
/// Owned, NUL-terminated filesystem path handed to the C API by pointer.
///
/// The default value is the empty string.
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);

impl CLibPath {
    /// Borrows the path as a C pointer; valid only while `self` is alive.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Converts `path` into C-string bytes without silently dropping non-UTF-8 paths.
    ///
    /// On Unix the raw OS bytes are used as-is. Elsewhere the path is converted lossily, so
    /// a bad path fails loudly when opened instead of becoming `""` ("unset").
    ///
    /// # Panics
    ///
    /// Panics if the path contains an interior NUL byte; such a path cannot name a file.
    fn from_path(path: &Path) -> Self {
        #[cfg(unix)]
        let bytes = {
            use std::os::unix::ffi::OsStrExt;
            path.as_os_str().as_bytes().to_vec()
        };
        #[cfg(not(unix))]
        let bytes = path.to_string_lossy().into_owned().into_bytes();
        Self(CString::new(bytes).expect("path must not contain interior NUL bytes"))
    }
}

impl From<PathBuf> for CLibPath {
    fn from(value: PathBuf) -> Self {
        Self::from_path(&value)
    }
}

impl From<&Path> for CLibPath {
    fn from(value: &Path) -> Self {
        Self::from_path(value)
    }
}
fn output_files(path: PathBuf, batch_size: i32) -> Vec<CLibPath> {
if batch_size == 1 {
vec![path.into()]
} else {
(1..=batch_size)
.map(|id| path.join(format!("output_{id}.png")).into())
.collect()
}
}
/// Applies `upscale_repeats` upscaling passes (factor 4 each) to `data`.
///
/// With no upscaler context, `data` is returned untouched.
///
/// Ownership of `data.data` passes to this function:
/// - On success, the returned image owns its pixel buffer. Every buffer replaced along the
///   way is freed.
/// - On failure, the buffer held at that point is freed before the error is returned.
///   Previously, an intermediate buffer leaked when a pass after the first failed.
///
/// # Safety
///
/// `data.data` must be a `malloc`-allocated buffer (or NULL) that nothing else frees, and
/// `upscaler_ctx`, when `Some`, must be a valid, non-NULL upscaler context.
unsafe fn upscale(
    upscale_repeats: i32,
    upscaler_ctx: Option<*mut upscaler_ctx_t>,
    data: sd_image_t,
) -> Result<sd_image_t, DiffusionError> {
    let Some(upscaler_ctx) = upscaler_ctx else {
        return Ok(data);
    };
    let upscale_factor = 4;
    let mut current_image = data;
    for _ in 0..upscale_repeats {
        let upscaled_image =
            diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor);
        // Either way the current buffer is done: replaced on success, abandoned on failure.
        free(current_image.data as *mut c_void);
        if upscaled_image.data.is_null() {
            return Err(DiffusionError::Upscaler);
        }
        current_image = upscaled_image;
    }
    Ok(current_image)
}
pub fn txt2img(config: Config) -> Result<(), DiffusionError> {
unsafe {
let sd_ctx = config.build_sd_ctx(true);
let upscaler_ctx = config.upscaler_ctx();
let res = {
let slice = diffusion_rs_sys::txt2img(
sd_ctx,
config.prompt.as_ptr(),
config.negative_prompt.as_ptr(),
config.clip_skip as i32,
config.cfg_scale,
config.guidance,
config.width,
config.height,
config.sampling_method,
config.steps,
config.seed,
config.batch_count,
null(),
config.control_strength,
config.style_ratio,
config.normalize_input,
config.input_id_images.as_ptr(),
);
if slice.is_null() {
return Err(DiffusionError::Forward);
}
let files = output_files(config.output, config.batch_count);
for (id, (img, path)) in slice::from_raw_parts(slice, config.batch_count as usize)
.iter()
.zip(files)
.enumerate()
{
match upscale(config.upscale_repeats, upscaler_ctx, *img) {
Ok(img) => {
let status = stbi_write_png_custom(
path.as_ptr(),
img.width as i32,
img.height as i32,
img.channel as i32,
img.data as *const c_void,
0,
);
if status == 0 {
return Err(DiffusionError::StoreImages(id, config.batch_count));
}
}
Err(err) => {
return Err(err);
}
}
}
free(slice as *mut c_void);
Ok(())
};
free_sd_ctx(sd_ctx);
if let Some(upscaler_ctx) = upscaler_ctx {
free_upscaler_ctx(upscaler_ctx);
}
res
}
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::{txt2img, ConfigBuilder};
    use crate::{api::ConfigBuilderError, util::download_file_hf_hub};

    /// Placeholder checkpoint; `build` only checks that a model is set, never reads it.
    fn dummy_model() -> PathBuf {
        PathBuf::from("./test.ckpt")
    }

    /// Checks which argument combinations `ConfigBuilder::build` accepts or rejects.
    #[test]
    fn test_required_args_txt2img() {
        let cat = "a lovely cat driving a sport car";

        // Neither a model nor a prompt.
        assert!(ConfigBuilder::default().build().is_err());
        // A model but no prompt.
        assert!(ConfigBuilder::default().model(dummy_model()).build().is_err());
        // A prompt but no model.
        assert!(ConfigBuilder::default().prompt(cat).build().is_err());

        // Several images cannot go to the default (file) output.
        let many_into_file = ConfigBuilder::default()
            .model(dummy_model())
            .prompt(cat)
            .batch_count(10)
            .build();
        assert!(matches!(
            many_into_file,
            Err(ConfigBuilderError::ValidationError(_))
        ));

        // Minimal valid configuration.
        ConfigBuilder::default()
            .model(dummy_model())
            .prompt(cat)
            .build()
            .unwrap();

        // Several images into an existing directory.
        ConfigBuilder::default()
            .model(dummy_model())
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end run: fetches a checkpoint and an upscaler via `download_file_hf_hub`,
    /// then generates and upscales a single image.
    #[ignore]
    #[test]
    fn test_txt2img() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();

        let mut builder = ConfigBuilder::default();
        builder
            .model(model_path)
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .batch_count(1);
        let config = builder.build().unwrap();

        txt2img(config).unwrap();
    }
}