//! Diffusion model backend trait and configuration types.
//!
//! Source file: `datasynth_core/diffusion/backend.rs`.

use serde::{Deserialize, Serialize};

use super::schedule::NoiseSchedule;

/// Pluggable backend for a diffusion model.
///
/// The `Send + Sync` supertraits let a backend (including `dyn DiffusionBackend`)
/// be shared across threads. Data is passed as rows of `f64` values; for
/// [`generate`](DiffusionBackend::generate) each row is one sample of
/// `n_features` values. NOTE(review): presumably `forward`/`reverse` use the
/// same row-per-sample layout — confirm against implementors.
pub trait DiffusionBackend: Sync + Send {
    /// Short name identifying this backend.
    fn name(&self) -> &str;

    /// Forward (noising) process: returns `x` with noise applied for timestep `t`.
    fn forward(&self, x: &[Vec<f64>], t: usize) -> Vec<Vec<f64>>;

    /// Reverse (denoising) process: returns `x_t` denoised at timestep `t`.
    fn reverse(&self, x_t: &[Vec<f64>], t: usize) -> Vec<Vec<f64>>;

    /// Produces `n_samples` rows of `n_features` values each, starting from noise.
    ///
    /// NOTE(review): `seed` is presumably meant to make generation reproducible;
    /// confirm each implementor honors it.
    fn generate(&self, n_samples: usize, n_features: usize, seed: u64) -> Vec<Vec<f64>>;
}
17/// Diffusion configuration.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct DiffusionConfig {
20    /// Number of diffusion steps.
21    #[serde(default = "default_n_steps")]
22    pub n_steps: usize,
23    /// Noise schedule type.
24    #[serde(default)]
25    pub schedule: NoiseScheduleType,
26    /// Random seed.
27    #[serde(default)]
28    pub seed: u64,
29}
/// Number of diffusion steps used when none is configured.
const DEFAULT_N_STEPS: usize = 1_000;

/// Serde default for `DiffusionConfig::n_steps`.
fn default_n_steps() -> usize {
    DEFAULT_N_STEPS
}
35impl Default for DiffusionConfig {
36    fn default() -> Self {
37        Self {
38            n_steps: default_n_steps(),
39            schedule: NoiseScheduleType::default(),
40            seed: 0,
41        }
42    }
43}
45impl DiffusionConfig {
46    /// Build a noise schedule from this configuration.
47    pub fn build_schedule(&self) -> NoiseSchedule {
48        NoiseSchedule::new(&self.schedule, self.n_steps)
49    }
50}
52/// Noise schedule type.
53#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54#[serde(rename_all = "snake_case")]
55pub enum NoiseScheduleType {
56    #[default]
57    Linear,
58    Cosine,
59    Sigmoid,
60}
#[cfg(test)]
// Unwrapping is intentional in tests: a panic is the desired failure signal.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_diffusion_config_default() {
        let config = DiffusionConfig::default();
        assert_eq!(config.n_steps, 1000);
        assert!(matches!(config.schedule, NoiseScheduleType::Linear));
        assert_eq!(config.seed, 0);
    }

    #[test]
    fn test_diffusion_config_serde() {
        let config = DiffusionConfig {
            n_steps: 500,
            schedule: NoiseScheduleType::Cosine,
            seed: 42,
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: DiffusionConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.n_steps, 500);
        // The schedule must survive the round trip too, not just the numeric fields.
        assert!(matches!(parsed.schedule, NoiseScheduleType::Cosine));
        assert_eq!(parsed.seed, 42);
    }

    #[test]
    fn test_diffusion_config_serde_missing_fields_use_defaults() {
        // Every field carries a serde default, so an empty object must deserialize.
        let parsed: DiffusionConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(parsed.n_steps, 1000);
        assert!(matches!(parsed.schedule, NoiseScheduleType::Linear));
        assert_eq!(parsed.seed, 0);
    }

    #[test]
    fn test_noise_schedule_type_snake_case() {
        let json = serde_json::to_string(&NoiseScheduleType::Sigmoid).unwrap();
        assert_eq!(json, "\"sigmoid\"");
        let parsed: NoiseScheduleType = serde_json::from_str("\"cosine\"").unwrap();
        assert!(matches!(parsed, NoiseScheduleType::Cosine));
    }

    #[test]
    fn test_build_schedule() {
        let config = DiffusionConfig {
            n_steps: 100,
            schedule: NoiseScheduleType::Cosine,
            seed: 0,
        };
        let schedule = config.build_schedule();
        assert_eq!(schedule.n_steps(), 100);
    }
}