use std::path::PathBuf;
use serde::{Deserialize, Serialize};
/// Options for a diffusion image-generation run.
///
/// All fields except [`scheduler`](Self::scheduler) are optional. `None`
/// means the value was not specified; how unset values are resolved is up
/// to the consumer of these options.
///
/// Deserialization is lenient: any field absent from the input is taken
/// from [`DiffusionOptions::default()`], i.e. every `Option` is `None` and
/// the scheduler is `DiffusionScheduler::default()`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct DiffusionOptions {
    /// Model identifier, e.g. `"stabilityai/stable-diffusion-2-1"`.
    pub model_id: Option<String>,
    /// Device string, e.g. `"cuda:0"`.
    pub device: Option<String>,
    /// Output width (presumably pixels — confirm with the consumer).
    pub width: Option<u32>,
    /// Output height (presumably pixels — confirm with the consumer).
    pub height: Option<u32>,
    /// Number of inference steps to run.
    pub num_inference_steps: Option<u32>,
    /// Guidance scale, e.g. `7.5`.
    pub guidance_scale: Option<f32>,
    /// Scheduler to use; defaults to `DiffusionScheduler::default()` when
    /// missing from serialized input.
    pub scheduler: DiffusionScheduler,
    /// Optional cache directory (presumably for model files — verify).
    pub cache_dir: Option<PathBuf>,
}
/// Scheduler used during diffusion inference.
///
/// # Serialization
///
/// Variants serialize under their wire names `"Euler"`, `"EulerA"`, `"DPM"`
/// and `"DDIM"`. These are kept unchanged so previously serialized data
/// still loads.
///
/// Deserialization additionally accepts the lowercase names produced by this
/// type's `Display` implementation (`"euler"`, `"euler_a"`, `"dpm"`,
/// `"ddim"`). A value printed with `to_string()` can therefore be read back.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum DiffusionScheduler {
    /// Euler scheduler (`"Euler"`, alias `"euler"`).
    #[serde(alias = "euler")]
    Euler,
    /// Euler-A scheduler (`"EulerA"`, alias `"euler_a"`). This is the default.
    #[default]
    #[serde(alias = "euler_a")]
    EulerA,
    /// DPM scheduler (`"DPM"`, alias `"dpm"`).
    #[serde(rename = "DPM", alias = "dpm")]
    Dpm,
    /// DDIM scheduler (`"DDIM"`, alias `"ddim"`).
    #[serde(rename = "DDIM", alias = "ddim")]
    Ddim,
}
impl std::fmt::Display for DiffusionScheduler {
    /// Writes the scheduler's lowercase short name: `euler`, `euler_a`,
    /// `dpm` or `ddim`.
    ///
    /// These names differ from the serde wire names (`Euler`, `EulerA`,
    /// `DPM`, `DDIM`).
    ///
    /// Output goes through `Formatter::pad`, so width, fill, alignment and
    /// precision flags (e.g. `format!("{:>10}", s)`) are honored, as for a
    /// plain `&str`. `write_str` would silently ignore them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Euler => "euler",
            Self::EulerA => "euler_a",
            Self::Dpm => "dpm",
            Self::Ddim => "ddim",
        };
        f.pad(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_scheduler_is_euler_a() {
        assert_eq!(DiffusionScheduler::default(), DiffusionScheduler::EulerA);
    }

    #[test]
    fn default_options_has_euler_a_scheduler() {
        let opts = DiffusionOptions::default();
        assert_eq!(opts.scheduler, DiffusionScheduler::EulerA);
        assert!(opts.model_id.is_none());
        assert!(opts.device.is_none());
        assert!(opts.width.is_none());
        assert!(opts.height.is_none());
        assert!(opts.num_inference_steps.is_none());
        assert!(opts.guidance_scale.is_none());
        assert!(opts.cache_dir.is_none());
    }

    #[test]
    fn struct_update_syntax_works() {
        let opts = DiffusionOptions {
            width: Some(1024),
            height: Some(1024),
            num_inference_steps: Some(30),
            ..DiffusionOptions::default()
        };
        assert_eq!(opts.width, Some(1024));
        assert_eq!(opts.height, Some(1024));
        assert_eq!(opts.num_inference_steps, Some(30));
        assert!(opts.model_id.is_none());
        assert!(opts.device.is_none());
    }

    #[test]
    fn display_impl() {
        assert_eq!(DiffusionScheduler::Euler.to_string(), "euler");
        assert_eq!(DiffusionScheduler::EulerA.to_string(), "euler_a");
        assert_eq!(DiffusionScheduler::Dpm.to_string(), "dpm");
        assert_eq!(DiffusionScheduler::Ddim.to_string(), "ddim");
    }

    #[test]
    fn serde_roundtrip_options() {
        let opts = DiffusionOptions {
            model_id: Some("stabilityai/stable-diffusion-2-1".into()),
            device: Some("cuda:0".into()),
            width: Some(768),
            height: Some(768),
            num_inference_steps: Some(25),
            guidance_scale: Some(7.5),
            scheduler: DiffusionScheduler::Dpm,
            cache_dir: Some(PathBuf::from("/tmp/diffusion-cache")),
        };
        let json = serde_json::to_string(&opts).expect("serialize");
        let parsed: DiffusionOptions = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(
            parsed.model_id.as_deref(),
            Some("stabilityai/stable-diffusion-2-1")
        );
        assert_eq!(parsed.device.as_deref(), Some("cuda:0"));
        assert_eq!(parsed.width, Some(768));
        assert_eq!(parsed.height, Some(768));
        assert_eq!(parsed.num_inference_steps, Some(25));
        assert_eq!(parsed.guidance_scale, Some(7.5));
        assert_eq!(parsed.scheduler, DiffusionScheduler::Dpm);
        assert_eq!(
            parsed.cache_dir.as_deref(),
            Some(std::path::Path::new("/tmp/diffusion-cache"))
        );
    }

    #[test]
    fn serde_roundtrip_scheduler_enum() {
        for scheduler in [
            DiffusionScheduler::Euler,
            DiffusionScheduler::EulerA,
            DiffusionScheduler::Dpm,
            DiffusionScheduler::Ddim,
        ] {
            let json = serde_json::to_string(&scheduler).expect("serialize");
            let parsed: DiffusionScheduler = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(parsed, scheduler);
        }
    }

    /// Pins the serialized scheduler names. Changing any of these would break
    /// previously written configs, so this test should fail loudly first.
    #[test]
    fn scheduler_wire_names_are_stable() {
        let cases = [
            (DiffusionScheduler::Euler, "\"Euler\""),
            (DiffusionScheduler::EulerA, "\"EulerA\""),
            (DiffusionScheduler::Dpm, "\"DPM\""),
            (DiffusionScheduler::Ddim, "\"DDIM\""),
        ];
        for (scheduler, expected) in cases {
            let json = serde_json::to_string(&scheduler).expect("serialize");
            assert_eq!(json, expected);
        }
    }

    /// Missing fields must fall back to defaults rather than failing, which is
    /// what the `#[serde(default)]` handling on the options exists for.
    #[test]
    fn deserialize_missing_fields_uses_defaults() {
        let empty: DiffusionOptions = serde_json::from_str("{}").expect("deserialize empty");
        assert_eq!(empty.scheduler, DiffusionScheduler::EulerA);
        assert!(empty.model_id.is_none());
        assert!(empty.device.is_none());
        assert!(empty.width.is_none());
        assert!(empty.height.is_none());
        assert!(empty.num_inference_steps.is_none());
        assert!(empty.guidance_scale.is_none());
        assert!(empty.cache_dir.is_none());

        let partial: DiffusionOptions =
            serde_json::from_str(r#"{"scheduler":"DDIM","width":512}"#).expect("deserialize partial");
        assert_eq!(partial.scheduler, DiffusionScheduler::Ddim);
        assert_eq!(partial.width, Some(512));
        assert!(partial.height.is_none());
        assert!(partial.model_id.is_none());
    }
}