use std::fmt;
use crate::DiffusionOptions;
/// Errors reported by the diffusion-rs image provider.
#[derive(Debug)]
pub enum DiffusionError {
    /// A `DiffusionOptions` value (or a per-call override) failed validation; carries the reason.
    InvalidOptions(String),
    /// Initialising the diffusion engine / model failed; carries the underlying message.
    ModelLoad(String),
    /// Producing an image failed; carries the underlying message.
    Generation(String),
    /// This build does not link the diffusion runtime (the `engine` feature is off).
    EngineNotAvailable,
}

impl DiffusionError {
    /// Human-readable explanation rendered for [`DiffusionError::EngineNotAvailable`].
    const ENGINE_NOT_AVAILABLE_MSG: &'static str =
        "diffusion-rs runtime is not linked -- rebuild blazen-image-diffusion \
         with the `engine` feature (or a forwarding feature such as `cuda` / \
         `metal`) to enable image generation";
}

impl fmt::Display for DiffusionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The only variant without a payload: emit the fixed rebuild hint.
            Self::EngineNotAvailable => f.write_str(Self::ENGINE_NOT_AVAILABLE_MSG),
            Self::InvalidOptions(msg) => write!(f, "diffusion-rs invalid options: {msg}"),
            Self::ModelLoad(msg) => write!(f, "diffusion-rs model load failed: {msg}"),
            Self::Generation(msg) => write!(f, "diffusion-rs generation failed: {msg}"),
        }
    }
}

impl std::error::Error for DiffusionError {}
pub struct DiffusionProvider {
options: DiffusionOptions,
#[cfg(feature = "engine")]
engine: tokio::sync::OnceCell<std::sync::Arc<crate::engine::Engine>>,
}
impl DiffusionProvider {
pub fn from_options(opts: DiffusionOptions) -> Result<Self, DiffusionError> {
if let Some(ref device) = opts.device
&& device.is_empty()
{
return Err(DiffusionError::InvalidOptions(
"device must not be empty when specified".into(),
));
}
if let Some(ref model_id) = opts.model_id
&& model_id.is_empty()
{
return Err(DiffusionError::InvalidOptions(
"model_id must not be empty when specified".into(),
));
}
if let Some(width) = opts.width
&& width == 0
{
return Err(DiffusionError::InvalidOptions(
"width must be greater than zero".into(),
));
}
if let Some(height) = opts.height
&& height == 0
{
return Err(DiffusionError::InvalidOptions(
"height must be greater than zero".into(),
));
}
if let Some(steps) = opts.num_inference_steps
&& steps == 0
{
return Err(DiffusionError::InvalidOptions(
"num_inference_steps must be greater than zero".into(),
));
}
if let Some(scale) = opts.guidance_scale
&& scale <= 0.0
{
return Err(DiffusionError::InvalidOptions(
"guidance_scale must be positive".into(),
));
}
Ok(Self {
options: opts,
#[cfg(feature = "engine")]
engine: tokio::sync::OnceCell::new(),
})
}
#[must_use]
pub fn device_str(&self) -> &str {
self.options.device.as_deref().unwrap_or("cpu")
}
#[must_use]
pub fn model_id(&self) -> &str {
self.options.model_id.as_deref().unwrap_or("sd-1.5")
}
#[must_use]
pub fn width(&self) -> u32 {
self.options.width.unwrap_or(512)
}
#[must_use]
pub fn height(&self) -> u32 {
self.options.height.unwrap_or(512)
}
#[must_use]
pub fn num_inference_steps(&self) -> u32 {
self.options.num_inference_steps.unwrap_or(20)
}
#[must_use]
pub fn guidance_scale(&self) -> f32 {
self.options.guidance_scale.unwrap_or(7.5)
}
#[must_use]
pub const fn scheduler(&self) -> crate::DiffusionScheduler {
self.options.scheduler
}
#[allow(clippy::unused_async)] pub async fn load(&self) -> Result<(), DiffusionError> {
#[cfg(feature = "engine")]
{
let opts = self.options.clone();
self.engine
.get_or_try_init(|| async move {
tokio::task::spawn_blocking(move || crate::engine::Engine::new(&opts))
.await
.map_err(|e| DiffusionError::ModelLoad(format!("join: {e}")))?
.map(std::sync::Arc::new)
})
.await?;
Ok(())
}
#[cfg(not(feature = "engine"))]
{
Err(DiffusionError::EngineNotAvailable)
}
}
#[allow(clippy::unused_async)]
pub async fn unload(&self) -> Result<(), DiffusionError> {
Ok(())
}
#[allow(clippy::unused_async)]
pub async fn is_loaded(&self) -> bool {
#[cfg(feature = "engine")]
{
self.engine.initialized()
}
#[cfg(not(feature = "engine"))]
{
false
}
}
}
#[cfg(feature = "engine")]
impl DiffusionProvider {
    /// Generates one image for `prompt`, initialising the engine on first use.
    ///
    /// `width` / `height` override the provider's configured dimensions for this call only.
    /// The step count and guidance scale always come from the provider options. Engine
    /// construction and `txt2img` both run via `tokio::task::spawn_blocking`.
    ///
    /// # Errors
    ///
    /// - [`DiffusionError::InvalidOptions`] if `width` or `height` is `Some(0)`. This is
    ///   checked before any model is loaded, matching the validation done by `from_options`.
    /// - [`DiffusionError::ModelLoad`] if the engine-initialisation task cannot be joined.
    ///   Errors returned by `Engine::new` are propagated unchanged.
    /// - [`DiffusionError::Generation`] if the generation task cannot be joined. Errors
    ///   returned by `Engine::txt2img` are propagated unchanged.
    pub async fn generate_image_inherent(
        &self,
        prompt: String,
        negative_prompt: Option<String>,
        width: Option<u32>,
        height: Option<u32>,
    ) -> Result<crate::engine::GeneratedImage, DiffusionError> {
        // Per-call overrides bypass `from_options`, so re-apply its zero-dimension checks.
        // Doing it first means a bad request never triggers a model load.
        if width == Some(0) {
            return Err(DiffusionError::InvalidOptions(
                "width must be greater than zero".into(),
            ));
        }
        if height == Some(0) {
            return Err(DiffusionError::InvalidOptions(
                "height must be greater than zero".into(),
            ));
        }

        let engine = self
            .engine
            .get_or_try_init(|| {
                // Clone the options only when the engine actually has to be built.
                let opts = self.options.clone();
                async move {
                    tokio::task::spawn_blocking(move || crate::engine::Engine::new(&opts))
                        .await
                        .map_err(|e| DiffusionError::ModelLoad(format!("join: {e}")))?
                        .map(std::sync::Arc::new)
                }
            })
            .await?
            // Cheap `Arc` clone so the blocking task below can own a handle to the engine.
            .clone();

        let w = width.unwrap_or_else(|| self.width());
        let h = height.unwrap_or_else(|| self.height());
        let steps = self.num_inference_steps();
        let scale = self.guidance_scale();

        tokio::task::spawn_blocking(move || {
            engine.txt2img(&prompt, negative_prompt.as_deref(), w, h, steps, scale)
        })
        .await
        .map_err(|e| DiffusionError::Generation(format!("join: {e}")))?
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DiffusionOptions, DiffusionScheduler};

    /// Shorthand: does `from_options` refuse these options?
    fn is_rejected(opts: DiffusionOptions) -> bool {
        DiffusionProvider::from_options(opts).is_err()
    }

    #[test]
    fn from_options_with_defaults() {
        let provider =
            DiffusionProvider::from_options(DiffusionOptions::default()).expect("should succeed");
        assert_eq!(provider.width(), 512);
        assert_eq!(provider.height(), 512);
        assert_eq!(provider.num_inference_steps(), 20);
        assert!((provider.guidance_scale() - 7.5).abs() < f32::EPSILON);
        assert_eq!(provider.scheduler(), DiffusionScheduler::EulerA);
    }

    /// End-to-end run against a real model; opt-in because it downloads weights.
    #[cfg(feature = "engine")]
    #[tokio::test]
    #[ignore = "downloads an SD-Turbo diffusion model + generates an image"]
    async fn smoke_generate_image() {
        let provider = DiffusionProvider::from_options(DiffusionOptions {
            model_id: Some("sd-turbo".into()),
            num_inference_steps: Some(1),
            ..DiffusionOptions::default()
        })
        .expect("options valid");

        let image = provider
            .generate_image_inherent("a red square".into(), None, Some(512), Some(512))
            .await
            .expect("image generation should succeed");

        assert!(!image.bytes.is_empty(), "should produce non-empty image bytes");
        assert!(
            image.width > 0 && image.height > 0,
            "image should have positive dimensions, got {}x{}",
            image.width,
            image.height
        );
    }

    #[test]
    fn from_options_with_custom_values() {
        let provider = DiffusionProvider::from_options(DiffusionOptions {
            model_id: Some("stabilityai/stable-diffusion-2-1".into()),
            width: Some(1024),
            height: Some(768),
            num_inference_steps: Some(30),
            guidance_scale: Some(10.0),
            scheduler: DiffusionScheduler::Dpm,
            ..DiffusionOptions::default()
        })
        .expect("should succeed");

        assert_eq!(provider.width(), 1024);
        assert_eq!(provider.height(), 768);
        assert_eq!(provider.num_inference_steps(), 30);
        assert!((provider.guidance_scale() - 10.0).abs() < f32::EPSILON);
        assert_eq!(provider.scheduler(), DiffusionScheduler::Dpm);
    }

    #[test]
    fn from_options_rejects_empty_device() {
        assert!(is_rejected(DiffusionOptions {
            device: Some(String::new()),
            ..DiffusionOptions::default()
        }));
    }

    #[test]
    fn from_options_rejects_empty_model_id() {
        assert!(is_rejected(DiffusionOptions {
            model_id: Some(String::new()),
            ..DiffusionOptions::default()
        }));
    }

    #[test]
    fn from_options_rejects_zero_width() {
        assert!(is_rejected(DiffusionOptions {
            width: Some(0),
            ..DiffusionOptions::default()
        }));
    }

    #[test]
    fn from_options_rejects_zero_height() {
        assert!(is_rejected(DiffusionOptions {
            height: Some(0),
            ..DiffusionOptions::default()
        }));
    }

    #[test]
    fn from_options_rejects_zero_steps() {
        assert!(is_rejected(DiffusionOptions {
            num_inference_steps: Some(0),
            ..DiffusionOptions::default()
        }));
    }

    #[test]
    fn from_options_rejects_non_positive_guidance() {
        // Zero and a negative value must both be refused.
        for scale in [0.0, -1.0] {
            assert!(is_rejected(DiffusionOptions {
                guidance_scale: Some(scale),
                ..DiffusionOptions::default()
            }));
        }
    }

    #[test]
    fn from_options_accepts_valid_device() {
        let provider = DiffusionProvider::from_options(DiffusionOptions {
            device: Some("cuda:0".into()),
            ..DiffusionOptions::default()
        })
        .expect("should succeed");
        assert_eq!(provider.width(), 512);
    }
}