blazen_image_diffusion/
provider.rs1use std::fmt;
7
8use crate::DiffusionOptions;
9
/// Errors reported by the diffusion provider.
///
/// Every variant carries a human-readable message. The type is `Clone` and
/// comparable so callers and tests can match on exact error values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiffusionError {
    /// The supplied options failed validation.
    InvalidOptions(String),
    /// The model could not be loaded.
    ModelLoad(String),
    /// Image generation failed.
    Generation(String),
}
20
21impl fmt::Display for DiffusionError {
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 match self {
24 Self::InvalidOptions(msg) => write!(f, "diffusion-rs invalid options: {msg}"),
25 Self::ModelLoad(msg) => write!(f, "diffusion-rs model load failed: {msg}"),
26 Self::Generation(msg) => write!(f, "diffusion-rs generation failed: {msg}"),
27 }
28 }
29}
30
// Marker impl only: every variant carries just a message `String`, so there is
// no underlying error to chain and the default `source()` (returns `None`) is kept.
impl std::error::Error for DiffusionError {}
32
/// Wraps a set of [`DiffusionOptions`] and exposes the effective settings
/// through accessor methods.
///
/// Instances are built with `DiffusionProvider::from_options`, which validates
/// the options before storing them.
pub struct DiffusionProvider {
    // NOTE(review): the accessors in this file read `options`, so this `allow`
    // looks unnecessary — confirm before removing it.
    #[allow(dead_code)]
    /// The options this provider was constructed with.
    options: DiffusionOptions,
}
43
44impl DiffusionProvider {
45 pub fn from_options(opts: DiffusionOptions) -> Result<Self, DiffusionError> {
55 if let Some(ref device) = opts.device
56 && device.is_empty()
57 {
58 return Err(DiffusionError::InvalidOptions(
59 "device must not be empty when specified".into(),
60 ));
61 }
62
63 if let Some(ref model_id) = opts.model_id
64 && model_id.is_empty()
65 {
66 return Err(DiffusionError::InvalidOptions(
67 "model_id must not be empty when specified".into(),
68 ));
69 }
70
71 if let Some(width) = opts.width
72 && width == 0
73 {
74 return Err(DiffusionError::InvalidOptions(
75 "width must be greater than zero".into(),
76 ));
77 }
78
79 if let Some(height) = opts.height
80 && height == 0
81 {
82 return Err(DiffusionError::InvalidOptions(
83 "height must be greater than zero".into(),
84 ));
85 }
86
87 if let Some(steps) = opts.num_inference_steps
88 && steps == 0
89 {
90 return Err(DiffusionError::InvalidOptions(
91 "num_inference_steps must be greater than zero".into(),
92 ));
93 }
94
95 if let Some(scale) = opts.guidance_scale
96 && scale <= 0.0
97 {
98 return Err(DiffusionError::InvalidOptions(
99 "guidance_scale must be positive".into(),
100 ));
101 }
102
103 Ok(Self { options: opts })
104 }
105
106 #[must_use]
108 pub fn width(&self) -> u32 {
109 self.options.width.unwrap_or(512)
110 }
111
112 #[must_use]
114 pub fn height(&self) -> u32 {
115 self.options.height.unwrap_or(512)
116 }
117
118 #[must_use]
120 pub fn num_inference_steps(&self) -> u32 {
121 self.options.num_inference_steps.unwrap_or(20)
122 }
123
124 #[must_use]
126 pub fn guidance_scale(&self) -> f32 {
127 self.options.guidance_scale.unwrap_or(7.5)
128 }
129
130 #[must_use]
132 pub const fn scheduler(&self) -> crate::DiffusionScheduler {
133 self.options.scheduler
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DiffusionOptions, DiffusionScheduler};

    /// Builds a provider from `opts`, panicking if validation rejects them.
    #[track_caller]
    fn build(opts: DiffusionOptions) -> DiffusionProvider {
        DiffusionProvider::from_options(opts).expect("should succeed")
    }

    /// Asserts that validation refuses `opts`.
    #[track_caller]
    fn assert_rejected(opts: DiffusionOptions) {
        assert!(DiffusionProvider::from_options(opts).is_err());
    }

    // Unset options must resolve to the documented defaults.
    #[test]
    fn from_options_with_defaults() {
        let provider = build(DiffusionOptions::default());

        assert_eq!(provider.width(), 512);
        assert_eq!(provider.height(), 512);
        assert_eq!(provider.num_inference_steps(), 20);
        assert!((provider.guidance_scale() - 7.5).abs() < f32::EPSILON);
        assert_eq!(provider.scheduler(), DiffusionScheduler::EulerA);
    }

    // Explicit values must be returned unchanged by the accessors.
    #[test]
    fn from_options_with_custom_values() {
        let provider = build(DiffusionOptions {
            model_id: Some("stabilityai/stable-diffusion-2-1".into()),
            width: Some(1024),
            height: Some(768),
            num_inference_steps: Some(30),
            guidance_scale: Some(10.0),
            scheduler: DiffusionScheduler::Dpm,
            ..DiffusionOptions::default()
        });

        assert_eq!(provider.width(), 1024);
        assert_eq!(provider.height(), 768);
        assert_eq!(provider.num_inference_steps(), 30);
        assert!((provider.guidance_scale() - 10.0).abs() < f32::EPSILON);
        assert_eq!(provider.scheduler(), DiffusionScheduler::Dpm);
    }

    #[test]
    fn from_options_rejects_empty_device() {
        assert_rejected(DiffusionOptions {
            device: Some(String::new()),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_empty_model_id() {
        assert_rejected(DiffusionOptions {
            model_id: Some(String::new()),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_width() {
        assert_rejected(DiffusionOptions {
            width: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_height() {
        assert_rejected(DiffusionOptions {
            height: Some(0),
            ..DiffusionOptions::default()
        });
    }

    #[test]
    fn from_options_rejects_zero_steps() {
        assert_rejected(DiffusionOptions {
            num_inference_steps: Some(0),
            ..DiffusionOptions::default()
        });
    }

    // Zero first, then a negative value, mirroring the two rejected cases.
    #[test]
    fn from_options_rejects_non_positive_guidance() {
        for scale in [0.0, -1.0] {
            assert_rejected(DiffusionOptions {
                guidance_scale: Some(scale),
                ..DiffusionOptions::default()
            });
        }
    }

    #[test]
    fn from_options_accepts_valid_device() {
        let provider = build(DiffusionOptions {
            device: Some("cuda:0".into()),
            ..DiffusionOptions::default()
        });
        assert_eq!(provider.width(), 512);
    }
}