use clap::Parser;
use stable_diffusion::*;
use candle::{Device, Result};
/// Picks the compute device to run on.
///
/// Selection order:
/// 1. `cpu == true`: always the CPU.
/// 2. CUDA device 0, if `candle::utils::cuda_is_available()` reports it.
/// 3. Metal device 0, if `candle::utils::metal_is_available()` reports it.
/// 4. Otherwise the CPU, after printing a hint to stdout naming the cargo
///    feature that would enable GPU support on the compile target.
///
/// # Errors
/// Propagates any error returned while creating the CUDA or Metal device.
pub fn device(cpu: bool) -> Result<Device> {
    if cpu {
        return Ok(Device::Cpu);
    }
    if candle::utils::cuda_is_available() {
        return Ok(Device::new_cuda(0)?);
    }
    if candle::utils::metal_is_available() {
        return Ok(Device::new_metal(0)?);
    }

    // No accelerator is usable: tell the user which feature flag to rebuild with.
    // The message depends on the compile target, not on runtime detection.
    let hint = if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
        "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
    } else {
        "Running on CPU, to run on GPU, build this example with `--features cuda`"
    };
    println!("{}", hint);
    Ok(Device::Cpu)
}
// Command-line options for a single image-generation run, parsed by clap and
// consumed by `Arguments::execute`.
//
// NOTE: plain `//` comments are used deliberately in this struct. clap's derive
// macro turns `///` doc comments into `--help` text, which would change the
// program's CLI output.
#[derive(Debug, Parser, Clone)]
#[command(author, version, about, long_about = None)]
pub struct Arguments {
// Text prompt to generate from; `execute` passes it to `GenerationParameters::new`.
#[arg(
long,
default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
)]
prompt: String,
// File the generated image is saved to; `execute` uses "output.png" when omitted.
#[arg(long)]
output: Option<String>,
// Optional value handed to `StableDiffusionWeights::from_repository`.
// Presumably a model repository id overriding the default for `sd_version` —
// TODO confirm against the library.
#[arg(long)]
repository: Option<String>,
// Forwarded to `with_uncond_prompt`; defaults to the empty string.
// Presumably the unconditional/negative prompt — TODO confirm.
#[arg(long, default_value = "")]
uncond_prompt: String,
// Optional style prompt.
// NOTE(review): presumably the secondary prompt for models that take one — confirm.
#[arg(long)]
style_prompt: Option<String>,
// Forwarded to `with_uncond_style_prompt`.
#[arg(long)]
uncond_style_prompt: Option<String>,
// When set, `device` returns the CPU without probing for CUDA or Metal.
#[arg(long)]
cpu: bool,
// Optional output height, forwarded to `with_height`.
#[arg(long)]
height: Option<usize>,
// Optional output width, forwarded to `with_width`.
#[arg(long)]
width: Option<usize>,
// Optional step count, forwarded to `with_n_steps`.
#[arg(long)]
n_steps: Option<usize>,
// Defaults to 1.
// NOTE(review): `execute` in this file never reads this field.
#[arg(long, default_value_t = 1)]
num_samples: i64,
// Defaults to "sd_final.png".
// NOTE(review): `execute` in this file never reads this field; the saved file
// name comes from `output` instead.
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
final_image: String,
// Model version to use, parsed as a clap value enum (default "v2-1") and
// converted into the library's own version enum by `execute`.
#[arg(long, value_enum, default_value = "v2-1")]
sd_version: StableDiffusionVersion,
// Optional guidance scale, forwarded to `with_guidance_scale`.
#[arg(long)]
guidance_scale: Option<f64>,
// Path to a source image for image-to-image generation; `execute` loads it
// through `image_preprocess`.
#[arg(long, value_name = "FILE")]
img2img: Option<String>,
// Forwarded to `with_img2img_strength`; defaults to 0.8.
// Presumably controls how far the result may depart from the source image —
// TODO confirm range and semantics against the library.
#[arg(long, default_value_t = 0.8)]
img2img_strength: f64,
}
/// Loads the image at `path` and fits it to dimensions that are multiples of 32.
///
/// Each side is rounded *down* to the nearest multiple of 32, then the image is
/// resized with `resize_to_fill` (CatmullRom filter) to exactly that size and
/// converted to 8-bit RGB.
///
/// # Errors
/// Returns an error if the file cannot be opened or decoded, or if either side
/// of the image is shorter than 32 pixels. Such a side would round down to 0,
/// so the requested resize target would be empty.
fn image_preprocess<T: AsRef<std::path::Path>>(
    path: T,
) -> anyhow::Result<image::ImageBuffer<image::Rgb<u8>, Vec<u8>>> {
    // Granularity both output sides are aligned to.
    // NOTE(review): presumably dictated by the model's downsampling — confirm.
    const ALIGNMENT: u32 = 32;

    let img = image::io::Reader::open(path)?.decode()?;
    let (width, height) = (img.width(), img.height());

    // Reject inputs that would round down to a zero-sized target instead of
    // handing an empty image to the rest of the pipeline.
    anyhow::ensure!(
        width >= ALIGNMENT && height >= ALIGNMENT,
        "image is {}x{} pixels but must be at least {}x{}",
        width,
        height,
        ALIGNMENT,
        ALIGNMENT
    );

    // Stay in `u32` (the image crate's native dimension type) so no lossy
    // `as` casts are needed.
    let width = width - width % ALIGNMENT;
    let height = height - height % ALIGNMENT;

    let img = img.resize_to_fill(width, height, image::imageops::FilterType::CatmullRom);
    Ok(img.to_rgb8())
}
/// Opens `output` with the platform's default file opener.
///
/// The launcher is chosen at compile time: `explorer` on Windows, `open` on
/// macOS and `xdg-open` on Linux. On any other target this is a no-op.
/// The call waits for the launcher process to finish; its exit status and
/// captured output are discarded.
///
/// # Errors
/// Returns an error if the launcher process cannot be run.
fn view(output: &str) -> anyhow::Result<()> {
    let launcher = if cfg!(target_os = "windows") {
        Some("explorer")
    } else if cfg!(target_os = "macos") {
        Some("open")
    } else if cfg!(target_os = "linux") {
        Some("xdg-open")
    } else {
        None
    };

    if let Some(program) = launcher {
        std::process::Command::new(program).arg(output).output()?;
    }
    Ok(())
}
// Model version selectable on the command line (`--sd-version`).
//
// This is a local, clap-parsable enum with the same variant names as
// `stable_diffusion::StableDiffusionVersion`; the `From` impl in this file
// converts it into the library's enum variant-for-variant.
//
// NOTE: plain `//` comments are used deliberately. clap's `ValueEnum` derive
// can surface `///` doc comments as help text for the possible values.
#[derive(Debug, Clone, clap::ValueEnum, Copy, PartialEq, Eq)]
pub enum StableDiffusionVersion {
// Maps to the library's `V1_5`.
V1_5,
// Maps to the library's `V2_1`; this is the CLI default ("v2-1").
V2_1,
// Maps to the library's `XL`.
XL,
// Maps to the library's `Turbo`.
Turbo,
}
/// Converts the CLI-facing version enum into the library's version enum.
///
/// The mapping is one-to-one by variant name. The `match` is exhaustive on
/// purpose, so adding a CLI variant without a library counterpart fails to
/// compile.
impl From<StableDiffusionVersion> for stable_diffusion::StableDiffusionVersion {
    fn from(version: StableDiffusionVersion) -> Self {
        // Short alias for the CLI enum; `Self` is the library enum.
        type Cli = StableDiffusionVersion;
        match version {
            Cli::V1_5 => Self::V1_5,
            Cli::V2_1 => Self::V2_1,
            Cli::XL => Self::XL,
            Cli::Turbo => Self::Turbo,
        }
    }
}
impl Arguments {
pub fn execute(self) -> anyhow::Result<()> {
let args = self;
let device = device(args.cpu)?;
let output = args.output.as_ref().map(String::from).unwrap_or(String::from("output.png"));
let weights = StableDiffusionWeights::from_repository(args.sd_version.into(), args.repository, DType::F32);
let parameters = StableDiffusionParameters::new(weights, device, DType::F16)?;
let stable_diffusion = StableDiffusion::new(parameters)?;
let args = GenerationParameters::new(args.prompt.clone())
.with_width(args.width)
.with_height(args.height)
.with_uncond_prompt(args.uncond_prompt)
.with_style_prompt(Some(args.prompt))
.with_uncond_style_prompt(args.uncond_style_prompt)
.with_n_steps(args.n_steps)
.with_guidance_scale(args.guidance_scale)
.with_img2img(args.img2img.as_ref().and_then(|path| image_preprocess(path).ok()))
.with_img2img_strength(args.img2img_strength);
let image = stable_diffusion.generate(args)?;
image.save(&output)?;
view(&output)?;
Ok(())
}
}