Skip to main content

blazen_image_diffusion/
provider.rs

1//! The [`DiffusionProvider`] type -- stub for Phase 5.1-5.2.
2//!
3//! The actual `ImageGeneration` trait implementation will be added in Phase 5.3
4//! once the diffusion-rs engine API is wired up.
5
6use std::fmt;
7
8use crate::DiffusionOptions;
9
/// Errors produced by the diffusion-rs image provider.
///
/// Each variant carries a free-form detail message. The [`fmt::Display`]
/// output prefixes that message with `diffusion-rs` and a short description
/// of the stage that failed.
#[derive(Debug)]
pub enum DiffusionError {
    /// An option was missing or held an unusable value.
    InvalidOptions(String),
    /// The model file could not be downloaded or found.
    ModelLoad(String),
    /// An image generation operation failed.
    Generation(String),
}

impl fmt::Display for DiffusionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use DiffusionError::{Generation, InvalidOptions, ModelLoad};

        match self {
            InvalidOptions(msg) => write!(f, "diffusion-rs invalid options: {msg}"),
            ModelLoad(msg) => write!(f, "diffusion-rs model load failed: {msg}"),
            Generation(msg) => write!(f, "diffusion-rs generation failed: {msg}"),
        }
    }
}

/// No variant wraps an underlying error, so the trait's default
/// `source()` (which returns `None`) is used.
impl std::error::Error for DiffusionError {}
32
33/// A local image generation provider backed by [`diffusion-rs`](https://github.com/huggingface/diffusion-rs).
34///
35/// Constructed via [`DiffusionProvider::from_options`]. The `ImageGeneration`
36/// trait implementation will be added in Phase 5.3.
37pub struct DiffusionProvider {
38    /// Full options preserved for deferred engine initialisation.
39    #[allow(dead_code)]
40    options: DiffusionOptions,
41    // pipeline: ... -- will hold the diffusion-rs pipeline once wired (Phase 5.3)
42}
43
44impl DiffusionProvider {
45    /// Create a new provider from the given options.
46    ///
47    /// This currently validates the options and stores them. The actual
48    /// diffusion-rs pipeline will be initialised in Phase 5.3.
49    ///
50    /// # Errors
51    ///
52    /// Returns [`DiffusionError::InvalidOptions`] if any option is present but
53    /// invalid (e.g. an empty device string, zero dimensions, or zero steps).
54    pub fn from_options(opts: DiffusionOptions) -> Result<Self, DiffusionError> {
55        if let Some(ref device) = opts.device
56            && device.is_empty()
57        {
58            return Err(DiffusionError::InvalidOptions(
59                "device must not be empty when specified".into(),
60            ));
61        }
62
63        if let Some(ref model_id) = opts.model_id
64            && model_id.is_empty()
65        {
66            return Err(DiffusionError::InvalidOptions(
67                "model_id must not be empty when specified".into(),
68            ));
69        }
70
71        if let Some(width) = opts.width
72            && width == 0
73        {
74            return Err(DiffusionError::InvalidOptions(
75                "width must be greater than zero".into(),
76            ));
77        }
78
79        if let Some(height) = opts.height
80            && height == 0
81        {
82            return Err(DiffusionError::InvalidOptions(
83                "height must be greater than zero".into(),
84            ));
85        }
86
87        if let Some(steps) = opts.num_inference_steps
88            && steps == 0
89        {
90            return Err(DiffusionError::InvalidOptions(
91                "num_inference_steps must be greater than zero".into(),
92            ));
93        }
94
95        if let Some(scale) = opts.guidance_scale
96            && scale <= 0.0
97        {
98            return Err(DiffusionError::InvalidOptions(
99                "guidance_scale must be positive".into(),
100            ));
101        }
102
103        Ok(Self { options: opts })
104    }
105
106    /// The resolved width (user-specified or default 512).
107    #[must_use]
108    pub fn width(&self) -> u32 {
109        self.options.width.unwrap_or(512)
110    }
111
112    /// The resolved height (user-specified or default 512).
113    #[must_use]
114    pub fn height(&self) -> u32 {
115        self.options.height.unwrap_or(512)
116    }
117
118    /// The resolved number of inference steps (user-specified or default 20).
119    #[must_use]
120    pub fn num_inference_steps(&self) -> u32 {
121        self.options.num_inference_steps.unwrap_or(20)
122    }
123
124    /// The resolved guidance scale (user-specified or default 7.5).
125    #[must_use]
126    pub fn guidance_scale(&self) -> f32 {
127        self.options.guidance_scale.unwrap_or(7.5)
128    }
129
130    /// The scheduler configured for this provider.
131    #[must_use]
132    pub const fn scheduler(&self) -> crate::DiffusionScheduler {
133        self.options.scheduler
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DiffusionOptions, DiffusionScheduler};

    /// Builds a provider, failing the test if `opts` does not validate.
    fn build(opts: DiffusionOptions) -> DiffusionProvider {
        DiffusionProvider::from_options(opts).expect("should succeed")
    }

    /// Fails the test unless `opts` is rejected by `from_options`.
    fn assert_rejected(opts: DiffusionOptions) {
        assert!(DiffusionProvider::from_options(opts).is_err());
    }

    /// Compares guidance scales using an `f32::EPSILON` tolerance.
    fn assert_scale_eq(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < f32::EPSILON);
    }

    #[test]
    fn from_options_with_defaults() {
        let provider = build(DiffusionOptions::default());
        assert_eq!(
            (provider.width(), provider.height(), provider.num_inference_steps()),
            (512, 512, 20)
        );
        assert_scale_eq(provider.guidance_scale(), 7.5);
        assert_eq!(provider.scheduler(), DiffusionScheduler::EulerA);
    }

    #[test]
    fn from_options_with_custom_values() {
        let provider = build(DiffusionOptions {
            model_id: Some("stabilityai/stable-diffusion-2-1".into()),
            width: Some(1024),
            height: Some(768),
            num_inference_steps: Some(30),
            guidance_scale: Some(10.0),
            scheduler: DiffusionScheduler::Dpm,
            ..DiffusionOptions::default()
        });
        assert_eq!(
            (provider.width(), provider.height(), provider.num_inference_steps()),
            (1024, 768, 30)
        );
        assert_scale_eq(provider.guidance_scale(), 10.0);
        assert_eq!(provider.scheduler(), DiffusionScheduler::Dpm);
    }

    #[test]
    fn from_options_rejects_empty_device() {
        assert_rejected(DiffusionOptions {
            device: Some(String::new()),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_empty_model_id() {
        assert_rejected(DiffusionOptions {
            model_id: Some(String::new()),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_width() {
        assert_rejected(DiffusionOptions {
            width: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_height() {
        assert_rejected(DiffusionOptions {
            height: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_steps() {
        assert_rejected(DiffusionOptions {
            num_inference_steps: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_non_positive_guidance() {
        for scale in [0.0, -1.0] {
            assert_rejected(DiffusionOptions {
                guidance_scale: Some(scale),
                ..DiffusionOptions::default()
            });
        }
    }

    #[test]
    fn from_options_accepts_valid_device() {
        let provider = build(DiffusionOptions {
            device: Some("cuda:0".into()),
            ..DiffusionOptions::default()
        });
        assert_eq!(provider.width(), 512);
    }
}