blazen_image_diffusion/
provider.rs1use std::fmt;
10
11use crate::DiffusionOptions;
12
/// Errors produced while configuring, loading, or running the diffusion provider.
#[derive(Debug)]
pub enum DiffusionError {
    /// The supplied options failed validation; the payload says which field and why.
    InvalidOptions(String),
    /// The diffusion engine could not be loaded.
    ModelLoad(String),
    /// An image-generation run failed.
    Generation(String),
    /// The crate was built without the `engine` feature, so no runtime is linked in.
    EngineNotAvailable,
}

impl fmt::Display for DiffusionError {
    /// Renders a one-line, user-facing message prefixed with `diffusion-rs`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Rebuild hint shown when the runtime is compiled out.
        const NOT_LINKED: &str =
            "diffusion-rs runtime is not linked -- rebuild blazen-image-diffusion \
             with the `engine` feature (or a forwarding feature such as `cuda` / \
             `metal`) to enable image generation";

        match *self {
            DiffusionError::InvalidOptions(ref msg) => {
                write!(f, "diffusion-rs invalid options: {msg}")
            }
            DiffusionError::ModelLoad(ref msg) => {
                write!(f, "diffusion-rs model load failed: {msg}")
            }
            DiffusionError::Generation(ref msg) => {
                write!(f, "diffusion-rs generation failed: {msg}")
            }
            DiffusionError::EngineNotAvailable => f.write_str(NOT_LINKED),
        }
    }
}

// No variant wraps an underlying error, so the default `source()` (always `None`) is correct.
impl std::error::Error for DiffusionError {}
44
45pub struct DiffusionProvider {
52 options: DiffusionOptions,
54 #[cfg(feature = "engine")]
55 engine: tokio::sync::OnceCell<std::sync::Arc<crate::engine::Engine>>,
56}
57
58impl DiffusionProvider {
59 pub fn from_options(opts: DiffusionOptions) -> Result<Self, DiffusionError> {
69 if let Some(ref device) = opts.device
70 && device.is_empty()
71 {
72 return Err(DiffusionError::InvalidOptions(
73 "device must not be empty when specified".into(),
74 ));
75 }
76
77 if let Some(ref model_id) = opts.model_id
78 && model_id.is_empty()
79 {
80 return Err(DiffusionError::InvalidOptions(
81 "model_id must not be empty when specified".into(),
82 ));
83 }
84
85 if let Some(width) = opts.width
86 && width == 0
87 {
88 return Err(DiffusionError::InvalidOptions(
89 "width must be greater than zero".into(),
90 ));
91 }
92
93 if let Some(height) = opts.height
94 && height == 0
95 {
96 return Err(DiffusionError::InvalidOptions(
97 "height must be greater than zero".into(),
98 ));
99 }
100
101 if let Some(steps) = opts.num_inference_steps
102 && steps == 0
103 {
104 return Err(DiffusionError::InvalidOptions(
105 "num_inference_steps must be greater than zero".into(),
106 ));
107 }
108
109 if let Some(scale) = opts.guidance_scale
110 && scale <= 0.0
111 {
112 return Err(DiffusionError::InvalidOptions(
113 "guidance_scale must be positive".into(),
114 ));
115 }
116
117 Ok(Self {
118 options: opts,
119 #[cfg(feature = "engine")]
120 engine: tokio::sync::OnceCell::new(),
121 })
122 }
123
124 #[must_use]
126 pub fn device_str(&self) -> &str {
127 self.options.device.as_deref().unwrap_or("cpu")
128 }
129
130 #[must_use]
132 pub fn model_id(&self) -> &str {
133 self.options.model_id.as_deref().unwrap_or("sd-1.5")
134 }
135
136 #[must_use]
138 pub fn width(&self) -> u32 {
139 self.options.width.unwrap_or(512)
140 }
141
142 #[must_use]
144 pub fn height(&self) -> u32 {
145 self.options.height.unwrap_or(512)
146 }
147
148 #[must_use]
150 pub fn num_inference_steps(&self) -> u32 {
151 self.options.num_inference_steps.unwrap_or(20)
152 }
153
154 #[must_use]
156 pub fn guidance_scale(&self) -> f32 {
157 self.options.guidance_scale.unwrap_or(7.5)
158 }
159
160 #[must_use]
162 pub const fn scheduler(&self) -> crate::DiffusionScheduler {
163 self.options.scheduler
164 }
165
166 #[allow(clippy::unused_async)] pub async fn load(&self) -> Result<(), DiffusionError> {
178 #[cfg(feature = "engine")]
179 {
180 let opts = self.options.clone();
181 self.engine
182 .get_or_try_init(|| async move {
183 tokio::task::spawn_blocking(move || crate::engine::Engine::new(&opts))
184 .await
185 .map_err(|e| DiffusionError::ModelLoad(format!("join: {e}")))?
186 .map(std::sync::Arc::new)
187 })
188 .await?;
189 Ok(())
190 }
191 #[cfg(not(feature = "engine"))]
192 {
193 Err(DiffusionError::EngineNotAvailable)
194 }
195 }
196
197 #[allow(clippy::unused_async)]
211 pub async fn unload(&self) -> Result<(), DiffusionError> {
212 Ok(())
213 }
214
215 #[allow(clippy::unused_async)]
218 pub async fn is_loaded(&self) -> bool {
219 #[cfg(feature = "engine")]
220 {
221 self.engine.initialized()
222 }
223 #[cfg(not(feature = "engine"))]
224 {
225 false
226 }
227 }
228}
229
#[cfg(feature = "engine")]
impl DiffusionProvider {
    /// Generates one image from `prompt`, loading the engine first if it is not yet loaded.
    ///
    /// `width` / `height` override the provider defaults for this call only. The step count
    /// and guidance scale always come from the provider's options.
    ///
    /// # Errors
    ///
    /// - [`DiffusionError::InvalidOptions`] if a `width` or `height` override is `Some(0)`.
    ///   This is checked before any model load is attempted.
    /// - [`DiffusionError::ModelLoad`] if the blocking load task fails to join; errors from
    ///   engine construction are propagated unchanged.
    /// - [`DiffusionError::Generation`] if the blocking generation task fails to join; errors
    ///   returned by the engine's `txt2img` are propagated unchanged.
    pub async fn generate_image_inherent(
        &self,
        prompt: String,
        negative_prompt: Option<String>,
        width: Option<u32>,
        height: Option<u32>,
    ) -> Result<crate::engine::GeneratedImage, DiffusionError> {
        // Apply the same zero-dimension rule `from_options` enforces for the stored defaults,
        // and do it before touching the engine so a bad request never triggers a model load.
        let dimension = |value: Option<u32>, fallback: u32, name: &str| match value {
            Some(0) => Err(DiffusionError::InvalidOptions(format!(
                "{name} must be greater than zero"
            ))),
            Some(v) => Ok(v),
            None => Ok(fallback),
        };
        let w = dimension(width, self.width(), "width")?;
        let h = dimension(height, self.height(), "height")?;
        let steps = self.num_inference_steps();
        let scale = self.guidance_scale();

        let engine = std::sync::Arc::clone(
            self.engine
                .get_or_try_init(|| {
                    // Clone the options only when the engine actually has to be built.
                    let opts = self.options.clone();
                    async move {
                        tokio::task::spawn_blocking(move || crate::engine::Engine::new(&opts))
                            .await
                            .map_err(|e| DiffusionError::ModelLoad(format!("join: {e}")))?
                            .map(std::sync::Arc::new)
                    }
                })
                .await?,
        );

        // Generation is run on the blocking pool rather than inline in this async fn.
        tokio::task::spawn_blocking(move || {
            engine.txt2img(&prompt, negative_prompt.as_deref(), w, h, steps, scale)
        })
        .await
        .map_err(|e| DiffusionError::Generation(format!("join: {e}")))?
    }
}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DiffusionOptions, DiffusionScheduler};

    /// Asserts that `opts` is rejected with `DiffusionError::InvalidOptions` specifically, so a
    /// regression that fails for some unrelated reason cannot pass silently.
    fn assert_invalid_options(opts: DiffusionOptions) {
        let result = DiffusionProvider::from_options(opts);
        assert!(
            matches!(result, Err(DiffusionError::InvalidOptions(_))),
            "expected DiffusionError::InvalidOptions"
        );
    }

    #[test]
    fn from_options_with_defaults() {
        let opts = DiffusionOptions::default();
        let provider = DiffusionProvider::from_options(opts).expect("should succeed");
        assert_eq!(provider.width(), 512);
        assert_eq!(provider.height(), 512);
        assert_eq!(provider.num_inference_steps(), 20);
        assert!((provider.guidance_scale() - 7.5).abs() < f32::EPSILON);
        assert_eq!(provider.scheduler(), DiffusionScheduler::EulerA);
    }

    #[cfg(feature = "engine")]
    #[tokio::test]
    #[ignore = "downloads an SD-Turbo diffusion model + generates an image"]
    async fn smoke_generate_image() {
        let opts = DiffusionOptions {
            model_id: Some("sd-turbo".into()),
            num_inference_steps: Some(1),
            ..DiffusionOptions::default()
        };
        let provider = DiffusionProvider::from_options(opts).expect("options valid");
        let image = provider
            .generate_image_inherent("a red square".into(), None, Some(512), Some(512))
            .await
            .expect("image generation should succeed");
        assert!(!image.bytes.is_empty(), "should produce non-empty image bytes");
        assert!(
            image.width > 0 && image.height > 0,
            "image should have positive dimensions, got {}x{}",
            image.width,
            image.height
        );
    }

    #[test]
    fn from_options_with_custom_values() {
        let opts = DiffusionOptions {
            model_id: Some("stabilityai/stable-diffusion-2-1".into()),
            width: Some(1024),
            height: Some(768),
            num_inference_steps: Some(30),
            guidance_scale: Some(10.0),
            scheduler: DiffusionScheduler::Dpm,
            ..DiffusionOptions::default()
        };
        let provider = DiffusionProvider::from_options(opts).expect("should succeed");
        assert_eq!(provider.model_id(), "stabilityai/stable-diffusion-2-1");
        assert_eq!(provider.width(), 1024);
        assert_eq!(provider.height(), 768);
        assert_eq!(provider.num_inference_steps(), 30);
        assert!((provider.guidance_scale() - 10.0).abs() < f32::EPSILON);
        assert_eq!(provider.scheduler(), DiffusionScheduler::Dpm);
    }

    #[test]
    fn from_options_rejects_empty_device() {
        assert_invalid_options(DiffusionOptions {
            device: Some(String::new()),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_empty_model_id() {
        assert_invalid_options(DiffusionOptions {
            model_id: Some(String::new()),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_width() {
        assert_invalid_options(DiffusionOptions {
            width: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_height() {
        assert_invalid_options(DiffusionOptions {
            height: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_steps() {
        assert_invalid_options(DiffusionOptions {
            num_inference_steps: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_non_positive_guidance() {
        for scale in [0.0_f32, -1.0] {
            assert_invalid_options(DiffusionOptions {
                guidance_scale: Some(scale),
                ..DiffusionOptions::default()
            });
        }
    }

    #[test]
    fn from_options_accepts_valid_device() {
        let opts = DiffusionOptions {
            device: Some("cuda:0".into()),
            ..DiffusionOptions::default()
        };
        let provider = DiffusionProvider::from_options(opts).expect("should succeed");
        assert_eq!(provider.device_str(), "cuda:0");
        assert_eq!(provider.width(), 512);
    }

    #[test]
    fn engine_not_available_message_names_feature() {
        let msg = DiffusionError::EngineNotAvailable.to_string();
        assert!(msg.contains("`engine` feature"), "unexpected message: {msg}");
    }
}
407}