//! Configuration options for the diffusion-rs local image generation backend.

use std::path::PathBuf;

use serde::{Deserialize, Serialize};
7/// Options for constructing a [`DiffusionProvider`](crate::DiffusionProvider).
8///
9/// All fields are optional and have sensible defaults. The scheduler defaults
10/// to [`DiffusionScheduler::EulerA`] and image dimensions default to 512x512.
11///
12/// # Examples
13///
14/// ```
15/// use blazen_image_diffusion::{DiffusionOptions, DiffusionScheduler};
16///
17/// // Use defaults (512x512, EulerA scheduler, 20 steps)
18/// let opts = DiffusionOptions::default();
19/// assert_eq!(opts.scheduler, DiffusionScheduler::EulerA);
20///
21/// // Override specific fields
22/// let opts = DiffusionOptions {
23///     width: Some(1024),
24///     height: Some(1024),
25///     num_inference_steps: Some(30),
26///     ..DiffusionOptions::default()
27/// };
28/// ```
29#[derive(Debug, Clone, Serialize, Deserialize, Default)]
30pub struct DiffusionOptions {
31    /// `HuggingFace` model repository ID (e.g. `"stabilityai/stable-diffusion-2-1"`).
32    ///
33    /// When `None`, a sensible default model will be selected in Phase 5.3.
34    pub model_id: Option<String>,
35
36    /// Hardware device specifier string (e.g. `"cpu"`, `"cuda:0"`, `"metal"`).
37    ///
38    /// Accepts the same format strings as `blazen_llm::Device::parse`:
39    /// `"cpu"`, `"cuda"`, `"cuda:N"`, `"metal"`.
40    ///
41    /// When `None`, defaults to `"cpu"`.
42    pub device: Option<String>,
43
44    /// Output image width in pixels.
45    ///
46    /// When `None`, defaults to 512.
47    pub width: Option<u32>,
48
49    /// Output image height in pixels.
50    ///
51    /// When `None`, defaults to 512.
52    pub height: Option<u32>,
53
54    /// Number of denoising steps to run.
55    ///
56    /// More steps generally produce higher quality images at the cost of
57    /// longer generation time. When `None`, defaults to 20.
58    pub num_inference_steps: Option<u32>,
59
60    /// Classifier-free guidance scale.
61    ///
62    /// Higher values make the output more closely follow the prompt but may
63    /// reduce diversity. Typical values range from 5.0 to 15.0.
64    /// When `None`, defaults to 7.5.
65    pub guidance_scale: Option<f32>,
66
67    /// The noise scheduler to use during the diffusion process.
68    ///
69    /// Defaults to [`DiffusionScheduler::EulerA`].
70    #[serde(default)]
71    pub scheduler: DiffusionScheduler,
72
73    /// Path to cache downloaded models.
74    ///
75    /// When `None`, falls back to `blazen-model-cache`'s default cache
76    /// directory (`$BLAZEN_CACHE_DIR` or `~/.cache/blazen/models`).
77    pub cache_dir: Option<PathBuf>,
78}
79
80/// Noise schedulers available for the diffusion process.
81///
82/// Different schedulers trade off between generation speed and output quality.
83/// [`EulerA`](Self::EulerA) is a good default for most use cases.
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
85pub enum DiffusionScheduler {
86    /// Euler discrete scheduler.
87    Euler,
88    /// Euler ancestral discrete scheduler (stochastic, good default).
89    #[default]
90    EulerA,
91    /// DPM-Solver++ multistep scheduler (fast, high quality).
92    #[serde(rename = "DPM")]
93    Dpm,
94    /// Denoising Diffusion Implicit Models scheduler.
95    #[serde(rename = "DDIM")]
96    Ddim,
97}
98
99impl std::fmt::Display for DiffusionScheduler {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        let name = match self {
102            Self::Euler => "euler",
103            Self::EulerA => "euler_a",
104            Self::Dpm => "dpm",
105            Self::Ddim => "ddim",
106        };
107        f.write_str(name)
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// The enum-level `#[default]` attribute must pick `EulerA`.
    #[test]
    fn default_scheduler_is_euler_a() {
        assert_eq!(DiffusionScheduler::default(), DiffusionScheduler::EulerA);
    }

    /// `DiffusionOptions::default()` leaves every `Option` field unset and
    /// uses the scheduler's default.
    #[test]
    fn default_options_has_euler_a_scheduler() {
        let opts = DiffusionOptions::default();
        assert_eq!(opts.scheduler, DiffusionScheduler::EulerA);
        assert!(opts.model_id.is_none());
        assert!(opts.device.is_none());
        assert!(opts.width.is_none());
        assert!(opts.height.is_none());
        assert!(opts.num_inference_steps.is_none());
        assert!(opts.guidance_scale.is_none());
        assert!(opts.cache_dir.is_none());
    }

    /// Struct-update syntax (`..Default::default()`) overrides only the
    /// named fields, as shown in the crate-level doc example.
    #[test]
    fn struct_update_syntax_works() {
        let opts = DiffusionOptions {
            width: Some(1024),
            height: Some(1024),
            num_inference_steps: Some(30),
            ..DiffusionOptions::default()
        };
        assert_eq!(opts.width, Some(1024));
        assert_eq!(opts.height, Some(1024));
        assert_eq!(opts.num_inference_steps, Some(30));
        assert!(opts.model_id.is_none());
        assert!(opts.device.is_none());
    }

    /// `Display` uses lowercase snake_case names (distinct from serde names).
    #[test]
    fn display_impl() {
        assert_eq!(DiffusionScheduler::Euler.to_string(), "euler");
        assert_eq!(DiffusionScheduler::EulerA.to_string(), "euler_a");
        assert_eq!(DiffusionScheduler::Dpm.to_string(), "dpm");
        assert_eq!(DiffusionScheduler::Ddim.to_string(), "ddim");
    }

    /// A fully-populated options struct must survive a JSON round-trip.
    #[test]
    fn serde_roundtrip_options() {
        let opts = DiffusionOptions {
            model_id: Some("stabilityai/stable-diffusion-2-1".into()),
            device: Some("cuda:0".into()),
            width: Some(768),
            height: Some(768),
            num_inference_steps: Some(25),
            guidance_scale: Some(7.5),
            scheduler: DiffusionScheduler::Dpm,
            cache_dir: Some(PathBuf::from("/tmp/diffusion-cache")),
        };
        let json = serde_json::to_string(&opts).expect("serialize");
        let parsed: DiffusionOptions = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(
            parsed.model_id.as_deref(),
            Some("stabilityai/stable-diffusion-2-1")
        );
        assert_eq!(parsed.device.as_deref(), Some("cuda:0"));
        assert_eq!(parsed.width, Some(768));
        assert_eq!(parsed.height, Some(768));
        assert_eq!(parsed.num_inference_steps, Some(25));
        assert_eq!(parsed.guidance_scale, Some(7.5));
        assert_eq!(parsed.scheduler, DiffusionScheduler::Dpm);
        assert_eq!(
            parsed.cache_dir.as_deref(),
            Some(std::path::Path::new("/tmp/diffusion-cache"))
        );
    }

    /// Every variant — including the `#[serde(rename)]`d ones — must
    /// round-trip through JSON unchanged.
    #[test]
    fn serde_roundtrip_scheduler_enum() {
        for scheduler in [
            DiffusionScheduler::Euler,
            DiffusionScheduler::EulerA,
            DiffusionScheduler::Dpm,
            DiffusionScheduler::Ddim,
        ] {
            let json = serde_json::to_string(&scheduler).expect("serialize");
            let parsed: DiffusionScheduler = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(parsed, scheduler);
        }
    }
}
199}