use std::path::PathBuf;

use serde::{Deserialize, Serialize};
7#[derive(Debug, Clone, Serialize, Deserialize, Default)]
30pub struct DiffusionOptions {
31 pub model_id: Option<String>,
35
36 pub device: Option<String>,
43
44 pub width: Option<u32>,
48
49 pub height: Option<u32>,
53
54 pub num_inference_steps: Option<u32>,
59
60 pub guidance_scale: Option<f32>,
66
67 #[serde(default)]
71 pub scheduler: DiffusionScheduler,
72
73 pub cache_dir: Option<PathBuf>,
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
85pub enum DiffusionScheduler {
86 Euler,
88 #[default]
90 EulerA,
91 #[serde(rename = "DPM")]
93 Dpm,
94 #[serde(rename = "DDIM")]
96 Ddim,
97}
98
99impl std::fmt::Display for DiffusionScheduler {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 let name = match self {
102 Self::Euler => "euler",
103 Self::EulerA => "euler_a",
104 Self::Dpm => "dpm",
105 Self::Ddim => "ddim",
106 };
107 f.write_str(name)
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_scheduler_is_euler_a() {
        assert_eq!(DiffusionScheduler::default(), DiffusionScheduler::EulerA);
    }

    #[test]
    fn default_options_has_euler_a_scheduler() {
        let opts = DiffusionOptions::default();
        assert_eq!(opts.scheduler, DiffusionScheduler::EulerA);
        assert!(opts.model_id.is_none());
        assert!(opts.device.is_none());
        assert!(opts.width.is_none());
        assert!(opts.height.is_none());
        assert!(opts.num_inference_steps.is_none());
        assert!(opts.guidance_scale.is_none());
        assert!(opts.cache_dir.is_none());
    }

    #[test]
    fn struct_update_syntax_works() {
        let opts = DiffusionOptions {
            width: Some(1024),
            height: Some(1024),
            num_inference_steps: Some(30),
            ..DiffusionOptions::default()
        };
        assert_eq!(opts.width, Some(1024));
        assert_eq!(opts.height, Some(1024));
        assert_eq!(opts.num_inference_steps, Some(30));
        assert!(opts.model_id.is_none());
        assert!(opts.device.is_none());
    }

    #[test]
    fn display_impl() {
        let cases = [
            (DiffusionScheduler::Euler, "euler"),
            (DiffusionScheduler::EulerA, "euler_a"),
            (DiffusionScheduler::Dpm, "dpm"),
            (DiffusionScheduler::Ddim, "ddim"),
        ];
        for (scheduler, expected) in cases {
            assert_eq!(scheduler.to_string(), expected);
        }
    }

    #[test]
    fn serde_roundtrip_options() {
        let opts = DiffusionOptions {
            model_id: Some("stabilityai/stable-diffusion-2-1".into()),
            device: Some("cuda:0".into()),
            width: Some(768),
            height: Some(768),
            num_inference_steps: Some(25),
            guidance_scale: Some(7.5),
            scheduler: DiffusionScheduler::Dpm,
            cache_dir: Some(PathBuf::from("/tmp/diffusion-cache")),
        };
        let json = serde_json::to_string(&opts).expect("serialize");
        let parsed: DiffusionOptions = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(
            parsed.model_id.as_deref(),
            Some("stabilityai/stable-diffusion-2-1")
        );
        assert_eq!(parsed.device.as_deref(), Some("cuda:0"));
        assert_eq!(parsed.width, Some(768));
        assert_eq!(parsed.height, Some(768));
        assert_eq!(parsed.num_inference_steps, Some(25));
        assert_eq!(parsed.guidance_scale, Some(7.5));
        assert_eq!(parsed.scheduler, DiffusionScheduler::Dpm);
        assert_eq!(
            parsed.cache_dir.as_deref(),
            Some(std::path::Path::new("/tmp/diffusion-cache"))
        );
    }

    #[test]
    fn serde_roundtrip_scheduler_enum() {
        for scheduler in [
            DiffusionScheduler::Euler,
            DiffusionScheduler::EulerA,
            DiffusionScheduler::Dpm,
            DiffusionScheduler::Ddim,
        ] {
            let json = serde_json::to_string(&scheduler).expect("serialize");
            let parsed: DiffusionScheduler = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(parsed, scheduler);
        }
    }

    // A roundtrip still passes if a variant is renamed, so pin the exact wire
    // names that persisted configs depend on.
    #[test]
    fn scheduler_wire_names_are_stable() {
        let cases = [
            (DiffusionScheduler::Euler, "\"Euler\""),
            (DiffusionScheduler::EulerA, "\"EulerA\""),
            (DiffusionScheduler::Dpm, "\"DPM\""),
            (DiffusionScheduler::Ddim, "\"DDIM\""),
        ];
        for (scheduler, expected) in cases {
            let json = serde_json::to_string(&scheduler).expect("serialize");
            assert_eq!(json, expected);
        }
    }

    #[test]
    fn unknown_scheduler_name_is_rejected() {
        let result: Result<DiffusionScheduler, _> = serde_json::from_str("\"not_a_scheduler\"");
        assert!(result.is_err());
    }

    // Pins the serialized field names and the default scheduler's wire form.
    #[test]
    fn serialized_options_field_names_are_stable() {
        let opts = DiffusionOptions {
            width: Some(512),
            ..DiffusionOptions::default()
        };
        let value = serde_json::to_value(&opts).expect("serialize");
        assert_eq!(value["width"], 512);
        assert_eq!(value["scheduler"], "EulerA");
        assert!(value["model_id"].is_null());
        assert!(value["device"].is_null());
        assert!(value["height"].is_null());
        assert!(value["num_inference_steps"].is_null());
        assert!(value["guidance_scale"].is_null());
        assert!(value["cache_dir"].is_null());
    }

    // Every field may be omitted: `Option`s become `None`, and `scheduler`
    // falls back to its `#[serde(default)]`.
    #[test]
    fn deserialize_empty_object_yields_defaults() {
        let opts: DiffusionOptions = serde_json::from_str("{}").expect("deserialize");
        assert_eq!(opts.scheduler, DiffusionScheduler::EulerA);
        assert!(opts.model_id.is_none());
        assert!(opts.device.is_none());
        assert!(opts.width.is_none());
        assert!(opts.height.is_none());
        assert!(opts.num_inference_steps.is_none());
        assert!(opts.guidance_scale.is_none());
        assert!(opts.cache_dir.is_none());
    }
}