Skip to main content

blazen_image_diffusion/
provider.rs

//! The [`DiffusionProvider`] type.
//!
//! Without the `engine` cargo feature this is a pure-stub provider: it
//! validates options and exposes accessors but cannot actually run image
//! generation. With `engine`, the inherent
//! [`DiffusionProvider::generate_image_inherent`] method lazily initialises a
//! [`crate::engine::Engine`] and runs the stable-diffusion.cpp pipeline
//! through `diffusion-rs`.
8
9use std::fmt;
10
11use crate::DiffusionOptions;
12
/// Error type for diffusion-rs operations.
///
/// Besides `Debug`, the type derives `Clone`, `PartialEq` and `Eq` so that
/// callers and tests can duplicate errors and compare them against an
/// expected variant. Every payload is a plain `String`, so the derives are
/// always available.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiffusionError {
    /// A required option was missing or invalid.
    InvalidOptions(String),
    /// The model file could not be downloaded or found.
    ModelLoad(String),
    /// An image generation operation failed.
    Generation(String),
    /// The crate was built without the `engine` feature so the underlying
    /// diffusion-rs runtime is not linked. Surface this distinctly from
    /// generic generation failures so bindings can map it to a clear error.
    EngineNotAvailable,
}
27
28impl fmt::Display for DiffusionError {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        match self {
31            Self::InvalidOptions(msg) => write!(f, "diffusion-rs invalid options: {msg}"),
32            Self::ModelLoad(msg) => write!(f, "diffusion-rs model load failed: {msg}"),
33            Self::Generation(msg) => write!(f, "diffusion-rs generation failed: {msg}"),
34            Self::EngineNotAvailable => f.write_str(
35                "diffusion-rs runtime is not linked -- rebuild blazen-image-diffusion \
36                 with the `engine` feature (or a forwarding feature such as `cuda` / \
37                 `metal`) to enable image generation",
38            ),
39        }
40    }
41}
42
43impl std::error::Error for DiffusionError {}
44
45/// A local image generation provider backed by [`diffusion-rs`](https://github.com/newfla/diffusion-rs).
46///
47/// Constructed via [`DiffusionProvider::from_options`]. With the `engine`
48/// feature on, [`DiffusionProvider::generate_image`] lazily initialises the
49/// underlying pipeline on first call and runs the synchronous
50/// stable-diffusion.cpp generation inside [`tokio::task::spawn_blocking`].
51pub struct DiffusionProvider {
52    /// Full options preserved for deferred engine initialisation.
53    options: DiffusionOptions,
54    #[cfg(feature = "engine")]
55    engine: tokio::sync::OnceCell<std::sync::Arc<crate::engine::Engine>>,
56}
57
58impl DiffusionProvider {
59    /// Create a new provider from the given options.
60    ///
61    /// This currently validates the options and stores them. The actual
62    /// diffusion-rs pipeline will be initialised in Phase 5.3.
63    ///
64    /// # Errors
65    ///
66    /// Returns [`DiffusionError::InvalidOptions`] if any option is present but
67    /// invalid (e.g. an empty device string, zero dimensions, or zero steps).
68    pub fn from_options(opts: DiffusionOptions) -> Result<Self, DiffusionError> {
69        if let Some(ref device) = opts.device
70            && device.is_empty()
71        {
72            return Err(DiffusionError::InvalidOptions(
73                "device must not be empty when specified".into(),
74            ));
75        }
76
77        if let Some(ref model_id) = opts.model_id
78            && model_id.is_empty()
79        {
80            return Err(DiffusionError::InvalidOptions(
81                "model_id must not be empty when specified".into(),
82            ));
83        }
84
85        if let Some(width) = opts.width
86            && width == 0
87        {
88            return Err(DiffusionError::InvalidOptions(
89                "width must be greater than zero".into(),
90            ));
91        }
92
93        if let Some(height) = opts.height
94            && height == 0
95        {
96            return Err(DiffusionError::InvalidOptions(
97                "height must be greater than zero".into(),
98            ));
99        }
100
101        if let Some(steps) = opts.num_inference_steps
102            && steps == 0
103        {
104            return Err(DiffusionError::InvalidOptions(
105                "num_inference_steps must be greater than zero".into(),
106            ));
107        }
108
109        if let Some(scale) = opts.guidance_scale
110            && scale <= 0.0
111        {
112            return Err(DiffusionError::InvalidOptions(
113                "guidance_scale must be positive".into(),
114            ));
115        }
116
117        Ok(Self {
118            options: opts,
119            #[cfg(feature = "engine")]
120            engine: tokio::sync::OnceCell::new(),
121        })
122    }
123
124    /// The resolved device string (`"cpu"` when unset).
125    #[must_use]
126    pub fn device_str(&self) -> &str {
127        self.options.device.as_deref().unwrap_or("cpu")
128    }
129
130    /// The configured model identifier (or `"sd-1.5"` when unset).
131    #[must_use]
132    pub fn model_id(&self) -> &str {
133        self.options.model_id.as_deref().unwrap_or("sd-1.5")
134    }
135
136    /// The resolved width (user-specified or default 512).
137    #[must_use]
138    pub fn width(&self) -> u32 {
139        self.options.width.unwrap_or(512)
140    }
141
142    /// The resolved height (user-specified or default 512).
143    #[must_use]
144    pub fn height(&self) -> u32 {
145        self.options.height.unwrap_or(512)
146    }
147
148    /// The resolved number of inference steps (user-specified or default 20).
149    #[must_use]
150    pub fn num_inference_steps(&self) -> u32 {
151        self.options.num_inference_steps.unwrap_or(20)
152    }
153
154    /// The resolved guidance scale (user-specified or default 7.5).
155    #[must_use]
156    pub fn guidance_scale(&self) -> f32 {
157        self.options.guidance_scale.unwrap_or(7.5)
158    }
159
160    /// The scheduler configured for this provider.
161    #[must_use]
162    pub const fn scheduler(&self) -> crate::DiffusionScheduler {
163        self.options.scheduler
164    }
165
166    /// Eagerly warm the underlying diffusion-rs pipeline.
167    ///
168    /// Without the `engine` feature this returns
169    /// [`DiffusionError::EngineNotAvailable`]. With it, this is idempotent
170    /// and safe to call from multiple tasks concurrently.
171    ///
172    /// # Errors
173    ///
174    /// Returns [`DiffusionError::ModelLoad`] if pipeline construction or the
175    /// output-directory bootstrap fails.
176    #[allow(clippy::unused_async)] // async to mirror LocalModel and the engine path
177    pub async fn load(&self) -> Result<(), DiffusionError> {
178        #[cfg(feature = "engine")]
179        {
180            let opts = self.options.clone();
181            self.engine
182                .get_or_try_init(|| async move {
183                    tokio::task::spawn_blocking(move || crate::engine::Engine::new(&opts))
184                        .await
185                        .map_err(|e| DiffusionError::ModelLoad(format!("join: {e}")))?
186                        .map(std::sync::Arc::new)
187                })
188                .await?;
189            Ok(())
190        }
191        #[cfg(not(feature = "engine"))]
192        {
193            Err(DiffusionError::EngineNotAvailable)
194        }
195    }
196
197    /// Best-effort unload. Always succeeds.
198    ///
199    /// `diffusion-rs` does not expose a "drop weights" entry point and the
200    /// cached pipeline lives behind a [`tokio::sync::OnceCell`] shared via
201    /// `&self`, so we cannot evict it from interior mutability alone.
202    /// Callers that require strict resource release should `drop` the
203    /// entire [`DiffusionProvider`] and construct a fresh one.
204    ///
205    /// # Errors
206    ///
207    /// Never errors today; the `Result` is kept to match the
208    /// [`blazen_llm::LocalModel::unload`] trait signature so the
209    /// bridge can forward without contortions.
210    #[allow(clippy::unused_async)]
211    pub async fn unload(&self) -> Result<(), DiffusionError> {
212        Ok(())
213    }
214
215    /// `true` if a pipeline has been warmed via [`Self::load`] or the first
216    /// generate call.
217    #[allow(clippy::unused_async)]
218    pub async fn is_loaded(&self) -> bool {
219        #[cfg(feature = "engine")]
220        {
221            self.engine.initialized()
222        }
223        #[cfg(not(feature = "engine"))]
224        {
225            false
226        }
227    }
228}
229
#[cfg(feature = "engine")]
impl DiffusionProvider {
    /// Inherent text-to-image entry point used by the
    /// [`blazen_llm::ImageGeneration`] trait impl in
    /// `blazen-llm::backends::diffusion`. Kept as an inherent method so the
    /// engine type does not leak across the `blazen-llm` boundary.
    ///
    /// `width` / `height` override the provider's configured dimensions for
    /// this call only; the step count and guidance scale always come from the
    /// provider's options. The engine is initialised on first use and the
    /// generation itself runs inside [`tokio::task::spawn_blocking`].
    ///
    /// # Errors
    ///
    /// Returns [`DiffusionError::InvalidOptions`] if a per-call `width` or
    /// `height` override is zero, [`DiffusionError::ModelLoad`] /
    /// [`DiffusionError::Generation`] if the respective blocking task cannot
    /// be joined, or whatever error engine initialisation or generation
    /// reports.
    pub async fn generate_image_inherent(
        &self,
        prompt: String,
        negative_prompt: Option<String>,
        width: Option<u32>,
        height: Option<u32>,
    ) -> Result<crate::engine::GeneratedImage, DiffusionError> {
        let w = width.unwrap_or_else(|| self.width());
        let h = height.unwrap_or_else(|| self.height());

        // `from_options` already guarantees the configured dimensions are
        // non-zero, so a zero here can only come from a per-call override.
        // Check before the (potentially expensive) engine initialisation.
        if w == 0 || h == 0 {
            return Err(DiffusionError::InvalidOptions(format!(
                "width and height must be greater than zero, got {w}x{h}"
            )));
        }

        let steps = self.num_inference_steps();
        let scale = self.guidance_scale();

        // Lazy init via the shared OnceCell; the blocking task needs owned options.
        let opts = self.options.clone();
        let cached = self
            .engine
            .get_or_try_init(|| async move {
                tokio::task::spawn_blocking(move || crate::engine::Engine::new(&opts))
                    .await
                    .map_err(|e| DiffusionError::ModelLoad(format!("join: {e}")))?
                    .map(std::sync::Arc::new)
            })
            .await?;
        // Cheap refcount bump so the engine can move into the blocking task.
        let engine = std::sync::Arc::clone(cached);

        tokio::task::spawn_blocking(move || {
            engine.txt2img(&prompt, negative_prompt.as_deref(), w, h, steps, scale)
        })
        .await
        .map_err(|e| DiffusionError::Generation(format!("join: {e}")))?
    }
}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DiffusionOptions, DiffusionScheduler};

    /// Asserts that [`DiffusionProvider::from_options`] refuses `opts`.
    #[track_caller]
    fn assert_rejected(opts: DiffusionOptions) {
        assert!(DiffusionProvider::from_options(opts).is_err());
    }

    #[test]
    fn from_options_with_defaults() {
        let provider =
            DiffusionProvider::from_options(DiffusionOptions::default()).expect("should succeed");

        assert_eq!(provider.width(), 512);
        assert_eq!(provider.height(), 512);
        assert_eq!(provider.num_inference_steps(), 20);
        assert!((provider.guidance_scale() - 7.5).abs() < f32::EPSILON);
        assert_eq!(provider.scheduler(), DiffusionScheduler::EulerA);
    }

    /// GPU/inference smoke test: fetches a fast diffusion model and renders a
    /// single image. Only compiled with `engine` (the real generate path) and
    /// marked `#[ignore]` so that it runs solely in the beastpc-e2e
    /// `--run-ignored only` step. SD-Turbo with one step keeps it quick. It
    /// runs on CPU unless the binary enables `diffusion-rs/cuda` (the
    /// crate-level `cuda` feature is a marker — see Cargo.toml).
    #[cfg(feature = "engine")]
    #[tokio::test]
    #[ignore = "downloads an SD-Turbo diffusion model + generates an image"]
    async fn smoke_generate_image() {
        let provider = DiffusionProvider::from_options(DiffusionOptions {
            model_id: Some("sd-turbo".into()),
            num_inference_steps: Some(1),
            ..DiffusionOptions::default()
        })
        .expect("options valid");

        let prompt = String::from("a red square");
        let image = provider
            .generate_image_inherent(prompt, None, Some(512), Some(512))
            .await
            .expect("image generation should succeed");

        assert!(!image.bytes.is_empty(), "should produce non-empty image bytes");
        assert!(
            image.width > 0 && image.height > 0,
            "image should have positive dimensions, got {}x{}",
            image.width,
            image.height
        );
    }

    #[test]
    fn from_options_with_custom_values() {
        let provider = DiffusionProvider::from_options(DiffusionOptions {
            model_id: Some("stabilityai/stable-diffusion-2-1".into()),
            width: Some(1024),
            height: Some(768),
            num_inference_steps: Some(30),
            guidance_scale: Some(10.0),
            scheduler: DiffusionScheduler::Dpm,
            ..DiffusionOptions::default()
        })
        .expect("should succeed");

        assert_eq!(provider.width(), 1024);
        assert_eq!(provider.height(), 768);
        assert_eq!(provider.num_inference_steps(), 30);
        assert!((provider.guidance_scale() - 10.0).abs() < f32::EPSILON);
        assert_eq!(provider.scheduler(), DiffusionScheduler::Dpm);
    }

    #[test]
    fn from_options_rejects_empty_device() {
        assert_rejected(DiffusionOptions {
            device: Some(String::new()),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_empty_model_id() {
        assert_rejected(DiffusionOptions {
            model_id: Some(String::new()),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_width() {
        assert_rejected(DiffusionOptions {
            width: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_height() {
        assert_rejected(DiffusionOptions {
            height: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_steps() {
        assert_rejected(DiffusionOptions {
            num_inference_steps: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_non_positive_guidance() {
        // Zero first, then a negative value, each checked independently.
        for scale in [0.0, -1.0] {
            assert_rejected(DiffusionOptions {
                guidance_scale: Some(scale),
                ..DiffusionOptions::default()
            });
        }
    }

    #[test]
    fn from_options_accepts_valid_device() {
        let provider = DiffusionProvider::from_options(DiffusionOptions {
            device: Some("cuda:0".into()),
            ..DiffusionOptions::default()
        })
        .expect("should succeed");

        assert_eq!(provider.width(), 512);
    }
}