Skip to main content

datasynth_core/diffusion/
training.rs

1//! Diffusion model training pipeline: fit from column statistics, persist, and evaluate.
2//!
3//! The [`DiffusionTrainer`] fits a [`TrainedDiffusionModel`] from per-column statistics
4//! (mean, std, min, max, type) and an optional correlation matrix. The trained model
5//! can be serialized to JSON for persistence and later reloaded for generation.
6//!
7//! Generation uses the same statistical diffusion approach as
8//! [`StatisticalDiffusionBackend`](super::StatisticalDiffusionBackend): start from
9//! Gaussian noise, iteratively denoise toward the target distribution, then apply
10//! correlation structure via Cholesky decomposition.
11
12use std::path::Path;
13
14use serde::{Deserialize, Serialize};
15
16use super::backend::DiffusionConfig;
17use super::statistical::StatisticalDiffusionBackend;
18use super::DiffusionBackend;
19use crate::error::SynthError;
20
21// ---------------------------------------------------------------------------
22// Column types and parameters
23// ---------------------------------------------------------------------------
24
25/// The type of a column in the dataset.
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub enum ColumnType {
28    /// A continuous (floating-point) column.
29    Continuous,
30    /// A categorical column with a fixed set of string categories.
31    Categorical { categories: Vec<String> },
32    /// An integer-valued column.
33    Integer,
34}
35
/// Statistical parameters for a single column in the trained model.
///
/// These summary statistics plus the column type are everything the trained
/// model stores per column; generation draws samples targeting `mean`/`std`
/// and clips results into `[min, max]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnDiffusionParams {
    /// Column name.
    pub name: String,
    /// Target mean.
    pub mean: f64,
    /// Target standard deviation. Values at or near zero are floored to
    /// `1e-8` at generation/evaluation time to avoid division by zero.
    pub std: f64,
    /// Minimum observed value; generated values are clipped to this lower bound.
    pub min: f64,
    /// Maximum observed value; generated values are clipped to this upper bound.
    pub max: f64,
    /// Column type (continuous, categorical, integer).
    pub col_type: ColumnType,
}
52
53// ---------------------------------------------------------------------------
54// Metadata
55// ---------------------------------------------------------------------------
56
/// Metadata about a trained diffusion model.
///
/// Captured at fit time and persisted alongside the model parameters so a
/// reloaded model can be identified and audited.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// ISO-8601 (RFC 3339) timestamp of when the model was trained.
    pub training_timestamp: String,
    /// Number of diffusion steps used.
    pub n_steps: usize,
    /// Noise schedule type (e.g. "linear", "cosine", "sigmoid").
    pub schedule_type: String,
    /// Number of columns in the model.
    pub n_columns: usize,
    /// Model format version (currently "1.0.0").
    pub version: String,
}
71
72// ---------------------------------------------------------------------------
73// Trained model
74// ---------------------------------------------------------------------------
75
/// A trained diffusion model that can generate samples and be persisted to disk.
///
/// Serialized as JSON via `save`/`load`; all fields are public so callers can
/// inspect the fitted parameters directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainedDiffusionModel {
    /// Per-column statistical parameters.
    pub column_params: Vec<ColumnDiffusionParams>,
    /// Correlation matrix (n_columns x n_columns), applied to generated
    /// samples via the statistical backend's correlation machinery.
    pub correlation_matrix: Vec<Vec<f64>>,
    /// Diffusion configuration used during training.
    pub config: DiffusionConfig,
    /// Model metadata.
    pub metadata: ModelMetadata,
}
88
89impl TrainedDiffusionModel {
90    /// Generate `n_samples` rows of synthetic data using the trained model parameters.
91    ///
92    /// Each row contains one value per column. Column types are respected:
93    /// - **Continuous**: clipped to `[min, max]`
94    /// - **Integer**: rounded and clipped to `[min, max]`
95    /// - **Categorical**: mapped to category indices, rounded and clipped to `[0, n_categories - 1]`
96    pub fn generate(&self, n_samples: usize, seed: u64) -> Vec<Vec<f64>> {
97        let n_features = self.column_params.len();
98        if n_samples == 0 || n_features == 0 {
99            return vec![];
100        }
101
102        // Use the StatisticalDiffusionBackend for the core generation
103        let means: Vec<f64> = self.column_params.iter().map(|c| c.mean).collect();
104        let stds: Vec<f64> = self.column_params.iter().map(|c| c.std.max(1e-8)).collect();
105
106        let backend = StatisticalDiffusionBackend::new(means, stds, self.config.clone())
107            .with_correlations(self.correlation_matrix.clone());
108
109        let mut samples = backend.generate(n_samples, n_features, seed);
110
111        // Post-process according to column types
112        for row in samples.iter_mut() {
113            for (j, val) in row.iter_mut().enumerate() {
114                if j >= self.column_params.len() {
115                    continue;
116                }
117                let col = &self.column_params[j];
118                match &col.col_type {
119                    ColumnType::Continuous => {
120                        *val = val.clamp(col.min, col.max);
121                    }
122                    ColumnType::Integer => {
123                        *val = val.round().clamp(col.min, col.max);
124                    }
125                    ColumnType::Categorical { categories } => {
126                        let n_cats = categories.len().max(1) as f64;
127                        *val = val.round().clamp(0.0, n_cats - 1.0);
128                    }
129                }
130            }
131        }
132
133        samples
134    }
135
136    /// Serialize and save the model to a JSON file at `path`.
137    pub fn save(&self, path: &Path) -> Result<(), SynthError> {
138        let json = serde_json::to_string_pretty(self)
139            .map_err(|e| SynthError::generation(format!("Failed to serialize model: {e}")))?;
140        std::fs::write(path, json).map_err(|e| {
141            SynthError::generation(format!("Failed to write model to {}: {e}", path.display()))
142        })?;
143        Ok(())
144    }
145
146    /// Load a model from a JSON file at `path`.
147    pub fn load(path: &Path) -> Result<Self, SynthError> {
148        let data = std::fs::read_to_string(path).map_err(|e| {
149            SynthError::generation(format!("Failed to read model from {}: {e}", path.display()))
150        })?;
151        let model: Self = serde_json::from_str(&data)
152            .map_err(|e| SynthError::generation(format!("Failed to deserialize model: {e}")))?;
153        Ok(model)
154    }
155}
156
157// ---------------------------------------------------------------------------
158// Trainer
159// ---------------------------------------------------------------------------
160
/// Trainer that fits diffusion model parameters from column statistics.
///
/// This is a stateless builder: call [`DiffusionTrainer::fit`] with the desired
/// parameters, then use the returned [`TrainedDiffusionModel`] for generation or
/// persistence. A fitted model can be scored against its own target statistics
/// with [`DiffusionTrainer::evaluate`].
pub struct DiffusionTrainer;
167
168impl DiffusionTrainer {
169    /// Fit a diffusion model from per-column statistics and a correlation matrix.
170    ///
171    /// The resulting [`TrainedDiffusionModel`] captures the target distribution
172    /// and can generate new samples via its `generate` method.
173    pub fn fit(
174        column_params: Vec<ColumnDiffusionParams>,
175        correlation_matrix: Vec<Vec<f64>>,
176        config: DiffusionConfig,
177    ) -> TrainedDiffusionModel {
178        let schedule_type = match config.schedule {
179            super::backend::NoiseScheduleType::Linear => "linear".to_string(),
180            super::backend::NoiseScheduleType::Cosine => "cosine".to_string(),
181            super::backend::NoiseScheduleType::Sigmoid => "sigmoid".to_string(),
182        };
183
184        let metadata = ModelMetadata {
185            training_timestamp: chrono::Utc::now().to_rfc3339(),
186            n_steps: config.n_steps,
187            schedule_type,
188            n_columns: column_params.len(),
189            version: "1.0.0".to_string(),
190        };
191
192        TrainedDiffusionModel {
193            column_params,
194            correlation_matrix,
195            config,
196            metadata,
197        }
198    }
199
200    /// Evaluate a trained model by comparing generated samples against the
201    /// target statistics captured in the model.
202    ///
203    /// Returns a [`FitReport`] with per-column errors, correlation error, and
204    /// an overall quality score.
205    pub fn evaluate(model: &TrainedDiffusionModel, n_eval_samples: usize, seed: u64) -> FitReport {
206        let samples = model.generate(n_eval_samples, seed);
207        let n_cols = model.column_params.len();
208
209        if samples.is_empty() || n_cols == 0 {
210            return FitReport {
211                mean_errors: vec![],
212                std_errors: vec![],
213                correlation_error: 0.0,
214                overall_score: 0.0,
215            };
216        }
217
218        let n = samples.len() as f64;
219
220        // Per-column mean and std errors (normalized by target std)
221        let mut mean_errors = Vec::with_capacity(n_cols);
222        let mut std_errors = Vec::with_capacity(n_cols);
223
224        for j in 0..n_cols {
225            let col = &model.column_params[j];
226            let target_std = col.std.max(1e-8);
227
228            let sample_mean: f64 = samples.iter().map(|r| r[j]).sum::<f64>() / n;
229            let sample_var: f64 = samples
230                .iter()
231                .map(|r| (r[j] - sample_mean).powi(2))
232                .sum::<f64>()
233                / n;
234            let sample_std = sample_var.sqrt();
235
236            let me = (sample_mean - col.mean).abs() / target_std;
237            let se = (sample_std - col.std).abs() / target_std;
238
239            mean_errors.push(me);
240            std_errors.push(se);
241        }
242
243        // Correlation matrix error: Frobenius norm of difference
244        let correlation_error = Self::compute_correlation_error(&samples, model);
245
246        // Overall score: 1.0 - mean(all individual errors), clamped to [0, 1]
247        let all_errors: Vec<f64> = mean_errors
248            .iter()
249            .chain(std_errors.iter())
250            .copied()
251            .collect();
252        let total_error_count = all_errors.len() as f64;
253        let avg_error = if total_error_count > 0.0 {
254            all_errors.iter().sum::<f64>() / total_error_count
255        } else {
256            0.0
257        };
258        // Include correlation error in the overall score (weighted equally with
259        // the average of the per-column errors)
260        let combined = (avg_error + correlation_error) / 2.0;
261        let overall_score = (1.0 - combined).clamp(0.0, 1.0);
262
263        FitReport {
264            mean_errors,
265            std_errors,
266            correlation_error,
267            overall_score,
268        }
269    }
270
271    /// Compute the Frobenius norm of (sample_corr - target_corr) normalized by
272    /// the number of elements.
273    fn compute_correlation_error(samples: &[Vec<f64>], model: &TrainedDiffusionModel) -> f64 {
274        let n_cols = model.column_params.len();
275        if n_cols < 2 || samples.is_empty() {
276            return 0.0;
277        }
278
279        let n = samples.len() as f64;
280
281        // Compute sample means
282        let mut means = vec![0.0; n_cols];
283        for row in samples {
284            for (j, &v) in row.iter().enumerate().take(n_cols) {
285                means[j] += v;
286            }
287        }
288        for m in &mut means {
289            *m /= n;
290        }
291
292        // Compute sample stds
293        let mut stds = vec![0.0; n_cols];
294        for row in samples {
295            for (j, &v) in row.iter().enumerate().take(n_cols) {
296                stds[j] += (v - means[j]).powi(2);
297            }
298        }
299        for s in &mut stds {
300            *s = (*s / n).sqrt().max(1e-8);
301        }
302
303        // Compute sample correlation matrix
304        let mut sample_corr = vec![vec![0.0; n_cols]; n_cols];
305        for row in samples {
306            for i in 0..n_cols {
307                for j in 0..n_cols {
308                    sample_corr[i][j] += (row[i] - means[i]) * (row[j] - means[j]);
309                }
310            }
311        }
312        for (i, corr_row) in sample_corr.iter_mut().enumerate().take(n_cols) {
313            for (j, corr_val) in corr_row.iter_mut().enumerate().take(n_cols) {
314                *corr_val /= n * stds[i] * stds[j];
315            }
316        }
317
318        // Frobenius norm of difference, normalized by matrix size
319        let target_corr = &model.correlation_matrix;
320        let mut frobenius_sq = 0.0;
321        for (i, corr_row) in sample_corr.iter().enumerate().take(n_cols) {
322            for (j, &corr_val) in corr_row.iter().enumerate().take(n_cols) {
323                let target_val = target_corr
324                    .get(i)
325                    .and_then(|row| row.get(j))
326                    .copied()
327                    .unwrap_or(if i == j { 1.0 } else { 0.0 });
328                let diff = corr_val - target_val;
329                frobenius_sq += diff * diff;
330            }
331        }
332
333        (frobenius_sq / (n_cols * n_cols) as f64).sqrt()
334    }
335}
336
337// ---------------------------------------------------------------------------
338// Fit report
339// ---------------------------------------------------------------------------
340
/// Report from evaluating a trained diffusion model against its target statistics.
///
/// Produced by [`DiffusionTrainer::evaluate`]; all errors are normalized so
/// they are comparable across columns with different scales.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FitReport {
    /// Per-column normalized mean error: `|sample_mean - target_mean| / target_std`.
    pub mean_errors: Vec<f64>,
    /// Per-column normalized std error: `|sample_std - target_std| / target_std`.
    pub std_errors: Vec<f64>,
    /// Root-mean-square element difference between sample and target correlation matrices.
    pub correlation_error: f64,
    /// Overall quality score in `[0.0, 1.0]` (higher is better).
    pub overall_score: f64,
}
353
354// ---------------------------------------------------------------------------
355// Tests
356// ---------------------------------------------------------------------------
357
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::super::backend::NoiseScheduleType;
    use super::*;

    /// Build a `DiffusionConfig` with a linear noise schedule.
    fn make_config(n_steps: usize, seed: u64) -> DiffusionConfig {
        DiffusionConfig {
            n_steps,
            schedule: NoiseScheduleType::Linear,
            seed,
        }
    }

    /// Three-column fixture covering all `ColumnType` variants
    /// (continuous, integer, categorical).
    fn sample_column_params() -> Vec<ColumnDiffusionParams> {
        vec![
            ColumnDiffusionParams {
                name: "amount".to_string(),
                mean: 100.0,
                std: 15.0,
                min: 0.0,
                max: 500.0,
                col_type: ColumnType::Continuous,
            },
            ColumnDiffusionParams {
                name: "quantity".to_string(),
                mean: 10.0,
                std: 3.0,
                min: 1.0,
                max: 50.0,
                col_type: ColumnType::Integer,
            },
            ColumnDiffusionParams {
                name: "category".to_string(),
                mean: 1.5,
                std: 0.8,
                min: 0.0,
                max: 3.0,
                col_type: ColumnType::Categorical {
                    categories: vec![
                        "A".to_string(),
                        "B".to_string(),
                        "C".to_string(),
                        "D".to_string(),
                    ],
                },
            },
        ]
    }

    /// Symmetric 3x3 correlation matrix matching `sample_column_params`.
    fn sample_correlation_matrix() -> Vec<Vec<f64>> {
        vec![
            vec![1.0, 0.6, 0.2],
            vec![0.6, 1.0, 0.3],
            vec![0.2, 0.3, 1.0],
        ]
    }

    // 1. Fit from column params produces valid model
    #[test]
    fn test_fit_produces_valid_model() {
        let params = sample_column_params();
        let corr = sample_correlation_matrix();
        let config = make_config(100, 42);

        let model = DiffusionTrainer::fit(params.clone(), corr.clone(), config);

        assert_eq!(model.column_params.len(), 3);
        assert_eq!(model.correlation_matrix.len(), 3);
        assert_eq!(model.metadata.n_columns, 3);
        assert_eq!(model.metadata.n_steps, 100);
        assert_eq!(model.metadata.schedule_type, "linear");
        assert_eq!(model.metadata.version, "1.0.0");
        assert!(!model.metadata.training_timestamp.is_empty());

        // Column params are preserved
        assert_eq!(model.column_params[0].name, "amount");
        assert!((model.column_params[0].mean - 100.0).abs() < 1e-10);
        assert_eq!(model.column_params[1].col_type, ColumnType::Integer);
    }

    // 2. Save/load roundtrip produces identical model
    #[test]
    fn test_save_load_roundtrip() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(50, 42),
        );

        let dir = tempfile::tempdir().expect("Failed to create temp dir");
        let path = dir.path().join("model.json");

        model.save(&path).expect("Failed to save model");
        let loaded = TrainedDiffusionModel::load(&path).expect("Failed to load model");

        // Verify all fields match
        assert_eq!(model.column_params.len(), loaded.column_params.len());
        for (orig, load) in model.column_params.iter().zip(loaded.column_params.iter()) {
            assert_eq!(orig.name, load.name);
            assert!((orig.mean - load.mean).abs() < 1e-10);
            assert!((orig.std - load.std).abs() < 1e-10);
            assert!((orig.min - load.min).abs() < 1e-10);
            assert!((orig.max - load.max).abs() < 1e-10);
            assert_eq!(orig.col_type, load.col_type);
        }
        assert_eq!(model.correlation_matrix, loaded.correlation_matrix);
        assert_eq!(model.config.n_steps, loaded.config.n_steps);
        assert_eq!(model.config.seed, loaded.config.seed);
        assert_eq!(model.metadata.version, loaded.metadata.version);
        assert_eq!(
            model.metadata.training_timestamp,
            loaded.metadata.training_timestamp
        );
    }

    // 3. Generated samples have correct dimensions
    #[test]
    fn test_generate_correct_dimensions() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(50, 42),
        );

        let samples = model.generate(200, 99);
        assert_eq!(samples.len(), 200);
        for row in &samples {
            assert_eq!(row.len(), 3);
        }
    }

    // 4. Generated means within tolerance of target
    #[test]
    fn test_generated_means_within_tolerance() {
        let params = sample_column_params();
        let model = DiffusionTrainer::fit(
            params.clone(),
            sample_correlation_matrix(),
            make_config(100, 42),
        );

        let samples = model.generate(5000, 42);
        let n = samples.len() as f64;

        for (j, col) in params.iter().enumerate() {
            let sample_mean: f64 = samples.iter().map(|r| r[j]).sum::<f64>() / n;
            // Allow tolerance of 2 * target_std (generous but verifies convergence)
            let tolerance = 2.0 * col.std;
            assert!(
                (sample_mean - col.mean).abs() < tolerance,
                "Column {} ('{}') mean {} is more than {} from target {}",
                j,
                col.name,
                sample_mean,
                tolerance,
                col.mean,
            );
        }
    }

    // 5. Same seed produces same output
    #[test]
    fn test_same_seed_deterministic() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(50, 42),
        );

        let samples1 = model.generate(100, 123);
        let samples2 = model.generate(100, 123);

        for (row1, row2) in samples1.iter().zip(samples2.iter()) {
            for (&v1, &v2) in row1.iter().zip(row2.iter()) {
                assert!(
                    (v1 - v2).abs() < 1e-12,
                    "Determinism failed: {} vs {}",
                    v1,
                    v2,
                );
            }
        }
    }

    // 6. Different seeds produce different output
    #[test]
    fn test_different_seeds_differ() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(50, 42),
        );

        let samples1 = model.generate(100, 1);
        let samples2 = model.generate(100, 2);

        // A single differing value anywhere is sufficient evidence.
        let mut any_diff = false;
        for (row1, row2) in samples1.iter().zip(samples2.iter()) {
            for (&v1, &v2) in row1.iter().zip(row2.iter()) {
                if (v1 - v2).abs() > 1e-8 {
                    any_diff = true;
                    break;
                }
            }
            if any_diff {
                break;
            }
        }
        assert!(any_diff, "Different seeds should produce different samples");
    }

    // 7. Evaluation report has reasonable scores for well-fitted model
    #[test]
    fn test_evaluation_reasonable_scores() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(100, 42),
        );

        let report = DiffusionTrainer::evaluate(&model, 5000, 42);

        // Mean errors should be small (less than 1.0 normalized by std)
        for (j, &me) in report.mean_errors.iter().enumerate() {
            assert!(
                me < 1.0,
                "Column {} mean error {} is too large (should be < 1.0)",
                j,
                me,
            );
        }

        // Std errors should be reasonable
        for (j, &se) in report.std_errors.iter().enumerate() {
            assert!(
                se < 1.5,
                "Column {} std error {} is too large (should be < 1.5)",
                j,
                se,
            );
        }

        // Correlation error should be bounded
        assert!(
            report.correlation_error < 1.0,
            "Correlation error {} is too large",
            report.correlation_error,
        );

        // Overall score should be positive for a reasonable model
        assert!(
            report.overall_score > 0.0,
            "Overall score {} should be positive",
            report.overall_score,
        );
    }

    // 8. Correlation structure is preserved
    #[test]
    fn test_correlation_structure_preserved() {
        // Use only 2 continuous columns with strong correlation for cleaner test
        let params = vec![
            ColumnDiffusionParams {
                name: "x".to_string(),
                mean: 0.0,
                std: 1.0,
                min: -10.0,
                max: 10.0,
                col_type: ColumnType::Continuous,
            },
            ColumnDiffusionParams {
                name: "y".to_string(),
                mean: 0.0,
                std: 1.0,
                min: -10.0,
                max: 10.0,
                col_type: ColumnType::Continuous,
            },
        ];
        let corr = vec![vec![1.0, 0.8], vec![0.8, 1.0]];

        let model = DiffusionTrainer::fit(params, corr, make_config(100, 42));
        let samples = model.generate(5000, 42);

        // Compute sample correlation (Pearson, population form)
        let n = samples.len() as f64;
        let mean_x: f64 = samples.iter().map(|r| r[0]).sum::<f64>() / n;
        let mean_y: f64 = samples.iter().map(|r| r[1]).sum::<f64>() / n;
        let std_x: f64 = (samples.iter().map(|r| (r[0] - mean_x).powi(2)).sum::<f64>() / n).sqrt();
        let std_y: f64 = (samples.iter().map(|r| (r[1] - mean_y).powi(2)).sum::<f64>() / n).sqrt();
        let cov_xy: f64 = samples
            .iter()
            .map(|r| (r[0] - mean_x) * (r[1] - mean_y))
            .sum::<f64>()
            / n;

        let sample_corr = if std_x > 1e-8 && std_y > 1e-8 {
            cov_xy / (std_x * std_y)
        } else {
            0.0
        };

        // Correlation should be positive and reasonably close to 0.8
        assert!(
            sample_corr > 0.3,
            "Expected positive correlation near 0.8, got {}",
            sample_corr,
        );
    }

    // 9. Integer column type produces integer values
    #[test]
    fn test_integer_column_produces_integers() {
        let params = vec![ColumnDiffusionParams {
            name: "count".to_string(),
            mean: 10.0,
            std: 3.0,
            min: 1.0,
            max: 50.0,
            col_type: ColumnType::Integer,
        }];
        let corr = vec![vec![1.0]];

        let model = DiffusionTrainer::fit(params, corr, make_config(50, 42));
        let samples = model.generate(500, 42);

        for row in &samples {
            let val = row[0];
            assert!(
                (val - val.round()).abs() < 1e-10,
                "Integer column produced non-integer value: {}",
                val,
            );
            assert!(
                (1.0..=50.0).contains(&val),
                "Integer column value {} out of range [1, 50]",
                val,
            );
        }
    }

    // 10. Categorical column produces valid indices
    #[test]
    fn test_categorical_column_produces_valid_indices() {
        let params = vec![ColumnDiffusionParams {
            name: "category".to_string(),
            mean: 1.0,
            std: 0.8,
            min: 0.0,
            max: 2.0,
            col_type: ColumnType::Categorical {
                categories: vec!["A".to_string(), "B".to_string(), "C".to_string()],
            },
        }];
        let corr = vec![vec![1.0]];

        let model = DiffusionTrainer::fit(params, corr, make_config(50, 42));
        let samples = model.generate(500, 42);

        for row in &samples {
            let val = row[0];
            assert!(
                (val - val.round()).abs() < 1e-10,
                "Categorical column produced non-integer index: {}",
                val,
            );
            assert!(
                (0.0..=2.0).contains(&val),
                "Categorical index {} out of range [0, 2]",
                val,
            );
        }
    }

    // 11. Empty model generates empty samples
    #[test]
    fn test_empty_model_generates_empty() {
        let model = DiffusionTrainer::fit(vec![], vec![], make_config(50, 0));
        let samples = model.generate(100, 0);
        assert!(samples.is_empty());
    }

    // 12. Load from non-existent file returns error
    #[test]
    fn test_load_nonexistent_returns_error() {
        let result = TrainedDiffusionModel::load(Path::new("/tmp/nonexistent_model_12345.json"));
        assert!(result.is_err());
    }
}
747}