1use std::path::Path;
13
14use serde::{Deserialize, Serialize};
15
16use super::backend::DiffusionConfig;
17use super::statistical::StatisticalDiffusionBackend;
18use super::DiffusionBackend;
19use crate::error::SynthError;
20
/// Kind of data stored in a column.
///
/// [`TrainedDiffusionModel::generate`] matches on this to decide how each raw
/// generated value is post-processed.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ColumnType {
    /// Real-valued column; generated values are clamped to the column's range.
    Continuous,
    /// Categorical column; generated values are rounded and clamped so that they
    /// form an index into `categories`.
    Categorical { categories: Vec<String> },
    /// Integer column; generated values are rounded and clamped to the column's range.
    Integer,
}
35
/// Per-column statistics that drive sample generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnDiffusionParams {
    /// Column name.
    pub name: String,
    /// Target mean handed to the generation backend.
    pub mean: f64,
    /// Target standard deviation; floored at `1e-8` wherever this file uses it.
    pub std: f64,
    /// Lower bound applied to generated values of continuous and integer columns.
    pub min: f64,
    /// Upper bound applied to generated values of continuous and integer columns.
    pub max: f64,
    /// How generated values for this column are post-processed.
    pub col_type: ColumnType,
}
52
/// Descriptive information recorded when a model is created by [`DiffusionTrainer::fit`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Creation time as an RFC 3339 UTC timestamp.
    pub training_timestamp: String,
    /// Number of diffusion steps, copied from the configuration.
    pub n_steps: usize,
    /// Noise schedule name: `"linear"`, `"cosine"` or `"sigmoid"`.
    pub schedule_type: String,
    /// Number of columns the model was built with.
    pub n_columns: usize,
    /// Model format version string (`fit` writes `"1.0.0"`).
    pub version: String,
}
71
/// A fitted model: column statistics, a correlation matrix and the diffusion
/// configuration, serializable to and from JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainedDiffusionModel {
    /// Statistics and type of every column, in output order.
    pub column_params: Vec<ColumnDiffusionParams>,
    /// Target correlation matrix passed to the generation backend.
    pub correlation_matrix: Vec<Vec<f64>>,
    /// Diffusion configuration used when generating samples.
    pub config: DiffusionConfig,
    /// Descriptive metadata recorded at fit time.
    pub metadata: ModelMetadata,
}
88
89impl TrainedDiffusionModel {
90 pub fn generate(&self, n_samples: usize, seed: u64) -> Vec<Vec<f64>> {
97 let n_features = self.column_params.len();
98 if n_samples == 0 || n_features == 0 {
99 return vec![];
100 }
101
102 let means: Vec<f64> = self.column_params.iter().map(|c| c.mean).collect();
104 let stds: Vec<f64> = self.column_params.iter().map(|c| c.std.max(1e-8)).collect();
105
106 let backend = StatisticalDiffusionBackend::new(means, stds, self.config.clone())
107 .with_correlations(self.correlation_matrix.clone());
108
109 let mut samples = backend.generate(n_samples, n_features, seed);
110
111 for row in samples.iter_mut() {
113 for (j, val) in row.iter_mut().enumerate() {
114 if j >= self.column_params.len() {
115 continue;
116 }
117 let col = &self.column_params[j];
118 match &col.col_type {
119 ColumnType::Continuous => {
120 *val = val.clamp(col.min, col.max);
121 }
122 ColumnType::Integer => {
123 *val = val.round().clamp(col.min, col.max);
124 }
125 ColumnType::Categorical { categories } => {
126 let n_cats = categories.len().max(1) as f64;
127 *val = val.round().clamp(0.0, n_cats - 1.0);
128 }
129 }
130 }
131 }
132
133 samples
134 }
135
136 pub fn save(&self, path: &Path) -> Result<(), SynthError> {
138 let json = serde_json::to_string_pretty(self)
139 .map_err(|e| SynthError::generation(format!("Failed to serialize model: {e}")))?;
140 std::fs::write(path, json).map_err(|e| {
141 SynthError::generation(format!("Failed to write model to {}: {e}", path.display()))
142 })?;
143 Ok(())
144 }
145
146 pub fn load(path: &Path) -> Result<Self, SynthError> {
148 let data = std::fs::read_to_string(path).map_err(|e| {
149 SynthError::generation(format!("Failed to read model from {}: {e}", path.display()))
150 })?;
151 let model: Self = serde_json::from_str(&data)
152 .map_err(|e| SynthError::generation(format!("Failed to deserialize model: {e}")))?;
153 Ok(model)
154 }
155}
156
/// Stateless entry point for building ([`DiffusionTrainer::fit`]) and scoring
/// ([`DiffusionTrainer::evaluate`]) a [`TrainedDiffusionModel`].
pub struct DiffusionTrainer;
167
168impl DiffusionTrainer {
169 pub fn fit(
174 column_params: Vec<ColumnDiffusionParams>,
175 correlation_matrix: Vec<Vec<f64>>,
176 config: DiffusionConfig,
177 ) -> TrainedDiffusionModel {
178 let schedule_type = match config.schedule {
179 super::backend::NoiseScheduleType::Linear => "linear".to_string(),
180 super::backend::NoiseScheduleType::Cosine => "cosine".to_string(),
181 super::backend::NoiseScheduleType::Sigmoid => "sigmoid".to_string(),
182 };
183
184 let metadata = ModelMetadata {
185 training_timestamp: chrono::Utc::now().to_rfc3339(),
186 n_steps: config.n_steps,
187 schedule_type,
188 n_columns: column_params.len(),
189 version: "1.0.0".to_string(),
190 };
191
192 TrainedDiffusionModel {
193 column_params,
194 correlation_matrix,
195 config,
196 metadata,
197 }
198 }
199
200 pub fn evaluate(model: &TrainedDiffusionModel, n_eval_samples: usize, seed: u64) -> FitReport {
206 let samples = model.generate(n_eval_samples, seed);
207 let n_cols = model.column_params.len();
208
209 if samples.is_empty() || n_cols == 0 {
210 return FitReport {
211 mean_errors: vec![],
212 std_errors: vec![],
213 correlation_error: 0.0,
214 overall_score: 0.0,
215 };
216 }
217
218 let n = samples.len() as f64;
219
220 let mut mean_errors = Vec::with_capacity(n_cols);
222 let mut std_errors = Vec::with_capacity(n_cols);
223
224 for j in 0..n_cols {
225 let col = &model.column_params[j];
226 let target_std = col.std.max(1e-8);
227
228 let sample_mean: f64 = samples.iter().map(|r| r[j]).sum::<f64>() / n;
229 let sample_var: f64 = samples
230 .iter()
231 .map(|r| (r[j] - sample_mean).powi(2))
232 .sum::<f64>()
233 / n;
234 let sample_std = sample_var.sqrt();
235
236 let me = (sample_mean - col.mean).abs() / target_std;
237 let se = (sample_std - col.std).abs() / target_std;
238
239 mean_errors.push(me);
240 std_errors.push(se);
241 }
242
243 let correlation_error = Self::compute_correlation_error(&samples, model);
245
246 let all_errors: Vec<f64> = mean_errors
248 .iter()
249 .chain(std_errors.iter())
250 .copied()
251 .collect();
252 let total_error_count = all_errors.len() as f64;
253 let avg_error = if total_error_count > 0.0 {
254 all_errors.iter().sum::<f64>() / total_error_count
255 } else {
256 0.0
257 };
258 let combined = (avg_error + correlation_error) / 2.0;
261 let overall_score = (1.0 - combined).clamp(0.0, 1.0);
262
263 FitReport {
264 mean_errors,
265 std_errors,
266 correlation_error,
267 overall_score,
268 }
269 }
270
271 fn compute_correlation_error(samples: &[Vec<f64>], model: &TrainedDiffusionModel) -> f64 {
274 let n_cols = model.column_params.len();
275 if n_cols < 2 || samples.is_empty() {
276 return 0.0;
277 }
278
279 let n = samples.len() as f64;
280
281 let mut means = vec![0.0; n_cols];
283 for row in samples {
284 for (j, &v) in row.iter().enumerate().take(n_cols) {
285 means[j] += v;
286 }
287 }
288 for m in &mut means {
289 *m /= n;
290 }
291
292 let mut stds = vec![0.0; n_cols];
294 for row in samples {
295 for (j, &v) in row.iter().enumerate().take(n_cols) {
296 stds[j] += (v - means[j]).powi(2);
297 }
298 }
299 for s in &mut stds {
300 *s = (*s / n).sqrt().max(1e-8);
301 }
302
303 let mut sample_corr = vec![vec![0.0; n_cols]; n_cols];
305 for row in samples {
306 for i in 0..n_cols {
307 for j in 0..n_cols {
308 sample_corr[i][j] += (row[i] - means[i]) * (row[j] - means[j]);
309 }
310 }
311 }
312 for (i, corr_row) in sample_corr.iter_mut().enumerate().take(n_cols) {
313 for (j, corr_val) in corr_row.iter_mut().enumerate().take(n_cols) {
314 *corr_val /= n * stds[i] * stds[j];
315 }
316 }
317
318 let target_corr = &model.correlation_matrix;
320 let mut frobenius_sq = 0.0;
321 for (i, corr_row) in sample_corr.iter().enumerate().take(n_cols) {
322 for (j, &corr_val) in corr_row.iter().enumerate().take(n_cols) {
323 let target_val = target_corr
324 .get(i)
325 .and_then(|row| row.get(j))
326 .copied()
327 .unwrap_or(if i == j { 1.0 } else { 0.0 });
328 let diff = corr_val - target_val;
329 frobenius_sq += diff * diff;
330 }
331 }
332
333 (frobenius_sq / (n_cols * n_cols) as f64).sqrt()
334 }
335}
336
/// Result of [`DiffusionTrainer::evaluate`]: how far generated samples deviate
/// from the model's target statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FitReport {
    /// Per-column absolute difference between sample and target mean, divided
    /// by the target std (floored at `1e-8`).
    pub mean_errors: Vec<f64>,
    /// Per-column absolute difference between sample and target std, divided
    /// by the target std (floored at `1e-8`).
    pub std_errors: Vec<f64>,
    /// Root-mean-square element-wise difference between the sample and target
    /// correlation matrices.
    pub correlation_error: f64,
    /// Combined score in `[0, 1]`; higher means a closer fit.
    pub overall_score: f64,
}
353
#[cfg(test)]
mod tests {
    use super::super::backend::NoiseScheduleType;
    use super::*;

    /// Builds a linear-schedule configuration with the given step count and seed.
    fn make_config(n_steps: usize, seed: u64) -> DiffusionConfig {
        DiffusionConfig {
            n_steps,
            schedule: NoiseScheduleType::Linear,
            seed,
        }
    }

    /// Three columns covering every `ColumnType` variant.
    fn sample_column_params() -> Vec<ColumnDiffusionParams> {
        vec![
            ColumnDiffusionParams {
                name: "amount".to_string(),
                mean: 100.0,
                std: 15.0,
                min: 0.0,
                max: 500.0,
                col_type: ColumnType::Continuous,
            },
            ColumnDiffusionParams {
                name: "quantity".to_string(),
                mean: 10.0,
                std: 3.0,
                min: 1.0,
                max: 50.0,
                col_type: ColumnType::Integer,
            },
            ColumnDiffusionParams {
                name: "category".to_string(),
                mean: 1.5,
                std: 0.8,
                min: 0.0,
                max: 3.0,
                col_type: ColumnType::Categorical {
                    categories: vec![
                        "A".to_string(),
                        "B".to_string(),
                        "C".to_string(),
                        "D".to_string(),
                    ],
                },
            },
        ]
    }

    /// Correlation matrix matching `sample_column_params`.
    fn sample_correlation_matrix() -> Vec<Vec<f64>> {
        vec![
            vec![1.0, 0.6, 0.2],
            vec![0.6, 1.0, 0.3],
            vec![0.2, 0.3, 1.0],
        ]
    }

    /// Fits the three-column fixture model with `n_steps` steps and config seed 42.
    fn sample_model(n_steps: usize) -> TrainedDiffusionModel {
        DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(n_steps, 42),
        )
    }

    #[test]
    fn test_fit_produces_valid_model() {
        let model = sample_model(100);

        assert_eq!(model.column_params.len(), 3);
        assert_eq!(model.correlation_matrix.len(), 3);
        assert_eq!(model.metadata.n_columns, 3);
        assert_eq!(model.metadata.n_steps, 100);
        assert_eq!(model.metadata.schedule_type, "linear");
        assert_eq!(model.metadata.version, "1.0.0");
        assert!(!model.metadata.training_timestamp.is_empty());

        assert_eq!(model.column_params[0].name, "amount");
        assert!((model.column_params[0].mean - 100.0).abs() < 1e-10);
        assert_eq!(model.column_params[1].col_type, ColumnType::Integer);
    }

    #[test]
    fn test_save_load_roundtrip() {
        let model = sample_model(50);

        let dir = tempfile::tempdir().expect("Failed to create temp dir");
        let path = dir.path().join("model.json");

        model.save(&path).expect("Failed to save model");
        let loaded = TrainedDiffusionModel::load(&path).expect("Failed to load model");

        assert_eq!(model.column_params.len(), loaded.column_params.len());
        for (orig, load) in model.column_params.iter().zip(loaded.column_params.iter()) {
            assert_eq!(orig.name, load.name);
            assert!((orig.mean - load.mean).abs() < 1e-10);
            assert!((orig.std - load.std).abs() < 1e-10);
            assert!((orig.min - load.min).abs() < 1e-10);
            assert!((orig.max - load.max).abs() < 1e-10);
            assert_eq!(orig.col_type, load.col_type);
        }
        assert_eq!(model.correlation_matrix, loaded.correlation_matrix);
        assert_eq!(model.config.n_steps, loaded.config.n_steps);
        assert_eq!(model.config.seed, loaded.config.seed);
        assert_eq!(model.metadata.version, loaded.metadata.version);
        assert_eq!(
            model.metadata.training_timestamp,
            loaded.metadata.training_timestamp
        );
    }

    #[test]
    fn test_generate_correct_dimensions() {
        let model = sample_model(50);

        let samples = model.generate(200, 99);
        assert_eq!(samples.len(), 200);
        for row in &samples {
            assert_eq!(row.len(), 3);
        }
    }

    #[test]
    fn test_generated_means_within_tolerance() {
        let params = sample_column_params();
        let model = sample_model(100);

        let samples = model.generate(5000, 42);
        let n = samples.len() as f64;

        for (j, col) in params.iter().enumerate() {
            let sample_mean: f64 = samples.iter().map(|r| r[j]).sum::<f64>() / n;
            let tolerance = 2.0 * col.std;
            assert!(
                (sample_mean - col.mean).abs() < tolerance,
                "Column {} ('{}') mean {} is more than {} from target {}",
                j,
                col.name,
                sample_mean,
                tolerance,
                col.mean,
            );
        }
    }

    #[test]
    fn test_same_seed_deterministic() {
        let model = sample_model(50);

        let samples1 = model.generate(100, 123);
        let samples2 = model.generate(100, 123);

        // Pin the shapes first: `zip` alone would pass on empty or truncated output.
        assert_eq!(samples1.len(), 100);
        assert_eq!(samples1.len(), samples2.len());

        for (row1, row2) in samples1.iter().zip(samples2.iter()) {
            assert_eq!(row1.len(), row2.len());
            for (&v1, &v2) in row1.iter().zip(row2.iter()) {
                assert!(
                    (v1 - v2).abs() < 1e-12,
                    "Determinism failed: {} vs {}",
                    v1,
                    v2,
                );
            }
        }
    }

    #[test]
    fn test_different_seeds_differ() {
        let model = sample_model(50);

        let samples1 = model.generate(100, 1);
        let samples2 = model.generate(100, 2);

        let any_diff = samples1.iter().zip(samples2.iter()).any(|(row1, row2)| {
            row1.iter()
                .zip(row2.iter())
                .any(|(&v1, &v2)| (v1 - v2).abs() > 1e-8)
        });
        assert!(any_diff, "Different seeds should produce different samples");
    }

    #[test]
    fn test_evaluation_reasonable_scores() {
        let model = sample_model(100);

        let report = DiffusionTrainer::evaluate(&model, 5000, 42);

        // One entry per column; otherwise the loops below would check nothing.
        assert_eq!(report.mean_errors.len(), 3);
        assert_eq!(report.std_errors.len(), 3);

        for (j, &me) in report.mean_errors.iter().enumerate() {
            assert!(
                me < 1.0,
                "Column {} mean error {} is too large (should be < 1.0)",
                j,
                me,
            );
        }

        for (j, &se) in report.std_errors.iter().enumerate() {
            assert!(
                se < 1.5,
                "Column {} std error {} is too large (should be < 1.5)",
                j,
                se,
            );
        }

        assert!(
            report.correlation_error < 1.0,
            "Correlation error {} is too large",
            report.correlation_error,
        );

        assert!(
            report.overall_score > 0.0,
            "Overall score {} should be positive",
            report.overall_score,
        );
    }

    #[test]
    fn test_correlation_structure_preserved() {
        let params = vec![
            ColumnDiffusionParams {
                name: "x".to_string(),
                mean: 0.0,
                std: 1.0,
                min: -10.0,
                max: 10.0,
                col_type: ColumnType::Continuous,
            },
            ColumnDiffusionParams {
                name: "y".to_string(),
                mean: 0.0,
                std: 1.0,
                min: -10.0,
                max: 10.0,
                col_type: ColumnType::Continuous,
            },
        ];
        let corr = vec![vec![1.0, 0.8], vec![0.8, 1.0]];

        let model = DiffusionTrainer::fit(params, corr, make_config(100, 42));
        let samples = model.generate(5000, 42);

        let n = samples.len() as f64;
        let mean_x: f64 = samples.iter().map(|r| r[0]).sum::<f64>() / n;
        let mean_y: f64 = samples.iter().map(|r| r[1]).sum::<f64>() / n;
        let std_x: f64 = (samples.iter().map(|r| (r[0] - mean_x).powi(2)).sum::<f64>() / n).sqrt();
        let std_y: f64 = (samples.iter().map(|r| (r[1] - mean_y).powi(2)).sum::<f64>() / n).sqrt();
        let cov_xy: f64 = samples
            .iter()
            .map(|r| (r[0] - mean_x) * (r[1] - mean_y))
            .sum::<f64>()
            / n;

        let sample_corr = if std_x > 1e-8 && std_y > 1e-8 {
            cov_xy / (std_x * std_y)
        } else {
            0.0
        };

        assert!(
            sample_corr > 0.3,
            "Expected positive correlation near 0.8, got {}",
            sample_corr,
        );
    }

    #[test]
    fn test_integer_column_produces_integers() {
        let params = vec![ColumnDiffusionParams {
            name: "count".to_string(),
            mean: 10.0,
            std: 3.0,
            min: 1.0,
            max: 50.0,
            col_type: ColumnType::Integer,
        }];
        let corr = vec![vec![1.0]];

        let model = DiffusionTrainer::fit(params, corr, make_config(50, 42));
        let samples = model.generate(500, 42);

        // Guard against a vacuous pass on empty output.
        assert_eq!(samples.len(), 500);

        for row in &samples {
            let val = row[0];
            assert!(
                (val - val.round()).abs() < 1e-10,
                "Integer column produced non-integer value: {}",
                val,
            );
            assert!(
                (1.0..=50.0).contains(&val),
                "Integer column value {} out of range [1, 50]",
                val,
            );
        }
    }

    #[test]
    fn test_categorical_column_produces_valid_indices() {
        let params = vec![ColumnDiffusionParams {
            name: "category".to_string(),
            mean: 1.0,
            std: 0.8,
            min: 0.0,
            max: 2.0,
            col_type: ColumnType::Categorical {
                categories: vec!["A".to_string(), "B".to_string(), "C".to_string()],
            },
        }];
        let corr = vec![vec![1.0]];

        let model = DiffusionTrainer::fit(params, corr, make_config(50, 42));
        let samples = model.generate(500, 42);

        // Guard against a vacuous pass on empty output.
        assert_eq!(samples.len(), 500);

        for row in &samples {
            let val = row[0];
            assert!(
                (val - val.round()).abs() < 1e-10,
                "Categorical column produced non-integer index: {}",
                val,
            );
            assert!(
                (0.0..=2.0).contains(&val),
                "Categorical index {} out of range [0, 2]",
                val,
            );
        }
    }

    #[test]
    fn test_empty_model_generates_empty() {
        let model = DiffusionTrainer::fit(vec![], vec![], make_config(50, 0));
        let samples = model.generate(100, 0);
        assert!(samples.is_empty());
    }

    #[test]
    fn test_load_nonexistent_returns_error() {
        // A fresh temporary directory is empty, so this file cannot exist.
        let dir = tempfile::tempdir().expect("Failed to create temp dir");
        let missing = dir.path().join("nonexistent_model.json");

        let result = TrainedDiffusionModel::load(&missing);
        assert!(result.is_err());
    }
}