use std::fmt;
use crate::DiffusionOptions;
/// Errors reported by the diffusion-rs provider.
///
/// Each variant carries a human-readable detail string; the [`fmt::Display`]
/// output prefixes it with `diffusion-rs` and a short description of the
/// failing stage.
#[derive(Debug)]
pub enum DiffusionError {
    /// The provider options were rejected during validation.
    InvalidOptions(String),
    /// A model could not be loaded.
    ModelLoad(String),
    /// Generation did not complete successfully.
    Generation(String),
}

impl fmt::Display for DiffusionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the stage-specific label once, then render every variant through
        // the same template so the message shape stays uniform.
        let (label, detail) = match self {
            Self::InvalidOptions(detail) => ("invalid options", detail),
            Self::ModelLoad(detail) => ("model load failed", detail),
            Self::Generation(detail) => ("generation failed", detail),
        };
        write!(f, "diffusion-rs {label}: {detail}")
    }
}

/// No variant wraps an underlying error, so the default trait methods suffice.
impl std::error::Error for DiffusionError {}
pub struct DiffusionProvider {
#[allow(dead_code)]
options: DiffusionOptions,
}
impl DiffusionProvider {
pub fn from_options(opts: DiffusionOptions) -> Result<Self, DiffusionError> {
if let Some(ref device) = opts.device
&& device.is_empty()
{
return Err(DiffusionError::InvalidOptions(
"device must not be empty when specified".into(),
));
}
if let Some(ref model_id) = opts.model_id
&& model_id.is_empty()
{
return Err(DiffusionError::InvalidOptions(
"model_id must not be empty when specified".into(),
));
}
if let Some(width) = opts.width
&& width == 0
{
return Err(DiffusionError::InvalidOptions(
"width must be greater than zero".into(),
));
}
if let Some(height) = opts.height
&& height == 0
{
return Err(DiffusionError::InvalidOptions(
"height must be greater than zero".into(),
));
}
if let Some(steps) = opts.num_inference_steps
&& steps == 0
{
return Err(DiffusionError::InvalidOptions(
"num_inference_steps must be greater than zero".into(),
));
}
if let Some(scale) = opts.guidance_scale
&& scale <= 0.0
{
return Err(DiffusionError::InvalidOptions(
"guidance_scale must be positive".into(),
));
}
Ok(Self { options: opts })
}
#[must_use]
pub fn width(&self) -> u32 {
self.options.width.unwrap_or(512)
}
#[must_use]
pub fn height(&self) -> u32 {
self.options.height.unwrap_or(512)
}
#[must_use]
pub fn num_inference_steps(&self) -> u32 {
self.options.num_inference_steps.unwrap_or(20)
}
#[must_use]
pub fn guidance_scale(&self) -> f32 {
self.options.guidance_scale.unwrap_or(7.5)
}
#[must_use]
pub const fn scheduler(&self) -> crate::DiffusionScheduler {
self.options.scheduler
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DiffusionOptions, DiffusionScheduler};

    /// Asserts that `opts` is refused specifically with
    /// `DiffusionError::InvalidOptions`, not merely with some error.
    fn assert_invalid_options(opts: DiffusionOptions) {
        match DiffusionProvider::from_options(opts) {
            Err(DiffusionError::InvalidOptions(_)) => {}
            Err(other) => panic!("expected InvalidOptions, got {other:?}"),
            Ok(_) => panic!("expected InvalidOptions, but the options were accepted"),
        }
    }

    #[test]
    fn from_options_with_defaults() {
        let opts = DiffusionOptions::default();
        let provider = DiffusionProvider::from_options(opts).expect("should succeed");
        assert_eq!(provider.width(), 512);
        assert_eq!(provider.height(), 512);
        assert_eq!(provider.num_inference_steps(), 20);
        assert!((provider.guidance_scale() - 7.5).abs() < f32::EPSILON);
        assert_eq!(provider.scheduler(), DiffusionScheduler::EulerA);
    }

    #[test]
    fn from_options_with_custom_values() {
        let opts = DiffusionOptions {
            model_id: Some("stabilityai/stable-diffusion-2-1".into()),
            width: Some(1024),
            height: Some(768),
            num_inference_steps: Some(30),
            guidance_scale: Some(10.0),
            scheduler: DiffusionScheduler::Dpm,
            ..DiffusionOptions::default()
        };
        let provider = DiffusionProvider::from_options(opts).expect("should succeed");
        assert_eq!(provider.width(), 1024);
        assert_eq!(provider.height(), 768);
        assert_eq!(provider.num_inference_steps(), 30);
        assert!((provider.guidance_scale() - 10.0).abs() < f32::EPSILON);
        assert_eq!(provider.scheduler(), DiffusionScheduler::Dpm);
    }

    #[test]
    fn from_options_rejects_empty_device() {
        assert_invalid_options(DiffusionOptions {
            device: Some(String::new()),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_empty_model_id() {
        assert_invalid_options(DiffusionOptions {
            model_id: Some(String::new()),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_width() {
        assert_invalid_options(DiffusionOptions {
            width: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_height() {
        assert_invalid_options(DiffusionOptions {
            height: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_steps() {
        assert_invalid_options(DiffusionOptions {
            num_inference_steps: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_non_positive_guidance() {
        assert_invalid_options(DiffusionOptions {
            guidance_scale: Some(0.0),
            ..DiffusionOptions::default()
        });
        assert_invalid_options(DiffusionOptions {
            guidance_scale: Some(-1.0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_accepts_valid_device() {
        let opts = DiffusionOptions {
            device: Some("cuda:0".into()),
            ..DiffusionOptions::default()
        };
        let provider = DiffusionProvider::from_options(opts).expect("should succeed");
        assert_eq!(provider.width(), 512);
    }

    #[test]
    fn error_display_includes_stage_and_detail() {
        assert_eq!(
            DiffusionError::InvalidOptions("bad".into()).to_string(),
            "diffusion-rs invalid options: bad"
        );
        assert_eq!(
            DiffusionError::ModelLoad("missing".into()).to_string(),
            "diffusion-rs model load failed: missing"
        );
        assert_eq!(
            DiffusionError::Generation("oom".into()).to_string(),
            "diffusion-rs generation failed: oom"
        );
    }

    #[test]
    fn error_is_usable_as_std_error() {
        let err: Box<dyn std::error::Error> =
            Box::new(DiffusionError::Generation("oom".into()));
        assert_eq!(err.to_string(), "diffusion-rs generation failed: oom");
    }
}