datasynth_core/diffusion/
backend.rs1use serde::{Deserialize, Serialize};
2
3use super::schedule::NoiseSchedule;
4
/// A pluggable diffusion model implementation.
///
/// Implementors must be `Send + Sync`, so a backend (or a
/// `Box<dyn DiffusionBackend>`) can be shared across threads.
///
/// Data is passed as `&[Vec<f64>]` / `Vec<Vec<f64>>`. Presumably each inner
/// `Vec` is one sample's feature vector (samples × features) — confirm
/// against the concrete backends.
pub trait DiffusionBackend: Send + Sync {
    /// Human-readable identifier of this backend.
    fn name(&self) -> &str;

    /// Produces `n_samples` rows of `n_features` values from `seed`.
    ///
    /// NOTE(review): the expected output shape is inferred from the
    /// parameter names — confirm with implementors.
    fn generate(
        &self,
        n_samples: usize,
        n_features: usize,
        seed: u64,
    ) -> Vec<Vec<f64>>;

    /// Forward step applied to `x` at timestep `t`.
    ///
    /// Presumably this adds noise according to the schedule; `t` is assumed
    /// to index the schedule's steps — TODO confirm.
    fn forward(
        &self,
        x: &[Vec<f64>],
        t: usize,
    ) -> Vec<Vec<f64>>;

    /// Reverse step applied to `x_t` at timestep `t`.
    ///
    /// Presumably the denoising counterpart of [`DiffusionBackend::forward`]
    /// — TODO confirm.
    fn reverse(
        &self,
        x_t: &[Vec<f64>],
        t: usize,
    ) -> Vec<Vec<f64>>;
}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct DiffusionConfig {
20 #[serde(default = "default_n_steps")]
22 pub n_steps: usize,
23 #[serde(default)]
25 pub schedule: NoiseScheduleType,
26 #[serde(default)]
28 pub seed: u64,
29}
30
/// Default number of diffusion steps for [`DiffusionConfig`].
fn default_n_steps() -> usize {
    const DEFAULT_N_STEPS: usize = 1_000;
    DEFAULT_N_STEPS
}
34
35impl Default for DiffusionConfig {
36 fn default() -> Self {
37 Self {
38 n_steps: default_n_steps(),
39 schedule: NoiseScheduleType::default(),
40 seed: 0,
41 }
42 }
43}
44
45impl DiffusionConfig {
46 pub fn build_schedule(&self) -> NoiseSchedule {
48 NoiseSchedule::new(&self.schedule, self.n_steps)
49 }
50}
51
52#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54#[serde(rename_all = "snake_case")]
55pub enum NoiseScheduleType {
56 #[default]
57 Linear,
58 Cosine,
59 Sigmoid,
60}
61
#[cfg(test)]
// Unwrapping is acceptable in tests: a failed (de)serialization should abort
// the test loudly rather than be handled.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_diffusion_config_default() {
        let config = DiffusionConfig::default();
        assert_eq!(config.n_steps, 1000);
        assert!(matches!(config.schedule, NoiseScheduleType::Linear));
        assert_eq!(config.seed, 0);
    }

    #[test]
    fn test_diffusion_config_serde() {
        let config = DiffusionConfig {
            n_steps: 500,
            schedule: NoiseScheduleType::Cosine,
            seed: 42,
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: DiffusionConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.n_steps, 500);
        assert!(matches!(parsed.schedule, NoiseScheduleType::Cosine));
        assert_eq!(parsed.seed, 42);
    }

    #[test]
    fn test_diffusion_config_missing_fields_use_defaults() {
        // Every field is optional on the wire; absent ones must match Default.
        let parsed: DiffusionConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(parsed.n_steps, 1000);
        assert!(matches!(parsed.schedule, NoiseScheduleType::Linear));
        assert_eq!(parsed.seed, 0);

        let partial: DiffusionConfig = serde_json::from_str(r#"{"seed": 7}"#).unwrap();
        assert_eq!(partial.n_steps, 1000);
        assert_eq!(partial.seed, 7);
    }

    #[test]
    fn test_noise_schedule_type_uses_snake_case() {
        assert_eq!(
            serde_json::to_string(&NoiseScheduleType::Cosine).unwrap(),
            "\"cosine\""
        );
        let parsed: NoiseScheduleType = serde_json::from_str("\"sigmoid\"").unwrap();
        assert!(matches!(parsed, NoiseScheduleType::Sigmoid));
    }

    #[test]
    fn test_build_schedule() {
        let config = DiffusionConfig {
            n_steps: 100,
            schedule: NoiseScheduleType::Cosine,
            seed: 0,
        };
        let schedule = config.build_schedule();
        assert_eq!(schedule.n_steps(), 100);
    }
}