datasynth_core/diffusion/
backend.rs1use serde::{Deserialize, Serialize};
2
3use super::schedule::NoiseSchedule;
4
/// A pluggable diffusion-model implementation.
///
/// The trait requires `Send + Sync`, so a backend can be shared across threads
/// (for example behind an `Arc<dyn DiffusionBackend>`). Batches are passed as
/// `Vec<Vec<f64>>` / `&[Vec<f64>]`.
///
/// NOTE(review): one inner vector per sample is inferred from
/// `generate(n_samples, n_features, ..)`; confirm against the implementations.
pub trait DiffusionBackend: Send + Sync {
    /// Name identifying this backend.
    fn name(&self) -> &str;

    /// Forward diffusion at timestep `t` for the batch `x`.
    ///
    /// Presumably adds noise according to the backend's schedule; the valid
    /// range of `t` is implementation-defined (TODO confirm, likely `0..n_steps`).
    #[must_use]
    fn forward(&self, x: &[Vec<f64>], t: usize) -> Vec<Vec<f64>>;

    /// Reverse (denoising) step at timestep `t` for the noisy batch `x_t`.
    ///
    /// The exact update rule is implementation-defined.
    #[must_use]
    fn reverse(&self, x_t: &[Vec<f64>], t: usize) -> Vec<Vec<f64>>;

    /// Generates a batch of `n_samples` new samples with `n_features` values each.
    ///
    /// NOTE(review): presumably reproducible for a fixed `seed`; confirm in the
    /// implementations.
    #[must_use]
    fn generate(&self, n_samples: usize, n_features: usize, seed: u64) -> Vec<Vec<f64>>;
}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct DiffusionConfig {
20 #[serde(default = "default_n_steps")]
22 pub n_steps: usize,
23 #[serde(default)]
25 pub schedule: NoiseScheduleType,
26 #[serde(default)]
28 pub seed: u64,
29}
30
/// Step count used when a [`DiffusionConfig`] does not specify one.
const DEFAULT_N_STEPS: usize = 1_000;

/// Fallback value for [`DiffusionConfig::n_steps`].
fn default_n_steps() -> usize {
    DEFAULT_N_STEPS
}
34
35impl Default for DiffusionConfig {
36 fn default() -> Self {
37 Self {
38 n_steps: default_n_steps(),
39 schedule: NoiseScheduleType::default(),
40 seed: 0,
41 }
42 }
43}
44
45impl DiffusionConfig {
46 pub fn build_schedule(&self) -> NoiseSchedule {
48 NoiseSchedule::new(&self.schedule, self.n_steps)
49 }
50}
51
52#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54#[serde(rename_all = "snake_case")]
55pub enum NoiseScheduleType {
56 #[default]
57 Linear,
58 Cosine,
59 Sigmoid,
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` yields 1000 steps, a linear schedule and seed 0.
    #[test]
    fn test_diffusion_config_default() {
        let config = DiffusionConfig::default();
        assert_eq!(config.n_steps, 1000);
        assert!(matches!(config.schedule, NoiseScheduleType::Linear));
        assert_eq!(config.seed, 0);
    }

    /// A JSON round trip preserves every field, including the schedule variant.
    #[test]
    fn test_diffusion_config_serde() {
        let config = DiffusionConfig {
            n_steps: 500,
            schedule: NoiseScheduleType::Cosine,
            seed: 42,
        };
        let json = serde_json::to_string(&config).unwrap();
        // `rename_all = "snake_case"` fixes the wire spelling of the variant.
        assert!(json.contains("\"cosine\""), "unexpected JSON: {json}");
        let parsed: DiffusionConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.n_steps, 500);
        assert!(matches!(parsed.schedule, NoiseScheduleType::Cosine));
        assert_eq!(parsed.seed, 42);
    }

    /// Missing fields fall back to the same values as `DiffusionConfig::default`.
    #[test]
    fn test_diffusion_config_missing_fields_use_defaults() {
        let empty: DiffusionConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(empty.n_steps, 1000);
        assert!(matches!(empty.schedule, NoiseScheduleType::Linear));
        assert_eq!(empty.seed, 0);

        let partial: DiffusionConfig =
            serde_json::from_str(r#"{"schedule":"sigmoid"}"#).unwrap();
        assert_eq!(partial.n_steps, 1000);
        assert!(matches!(partial.schedule, NoiseScheduleType::Sigmoid));
        assert_eq!(partial.seed, 0);
    }

    /// `build_schedule` forwards `n_steps` to the constructed schedule.
    #[test]
    fn test_build_schedule() {
        let config = DiffusionConfig {
            n_steps: 100,
            schedule: NoiseScheduleType::Cosine,
            seed: 0,
        };
        let schedule = config.build_schedule();
        assert_eq!(schedule.n_steps(), 100);
    }
}