1use std::path::Path;
13
14use serde::{Deserialize, Serialize};
15
16use super::backend::DiffusionConfig;
17use super::statistical::StatisticalDiffusionBackend;
18use super::DiffusionBackend;
19use crate::error::SynthError;
20
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub enum ColumnType {
28 Continuous,
30 Categorical { categories: Vec<String> },
32 Integer,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ColumnDiffusionParams {
39 pub name: String,
41 pub mean: f64,
43 pub std: f64,
45 pub min: f64,
47 pub max: f64,
49 pub col_type: ColumnType,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct ModelMetadata {
60 pub training_timestamp: String,
62 pub n_steps: usize,
64 pub schedule_type: String,
66 pub n_columns: usize,
68 pub version: String,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct TrainedDiffusionModel {
79 pub column_params: Vec<ColumnDiffusionParams>,
81 pub correlation_matrix: Vec<Vec<f64>>,
83 pub config: DiffusionConfig,
85 pub metadata: ModelMetadata,
87}
88
89impl TrainedDiffusionModel {
90 pub fn generate(&self, n_samples: usize, seed: u64) -> Vec<Vec<f64>> {
97 let n_features = self.column_params.len();
98 if n_samples == 0 || n_features == 0 {
99 return vec![];
100 }
101
102 let means: Vec<f64> = self.column_params.iter().map(|c| c.mean).collect();
104 let stds: Vec<f64> = self.column_params.iter().map(|c| c.std.max(1e-8)).collect();
105
106 let backend = StatisticalDiffusionBackend::new(means, stds, self.config.clone())
107 .with_correlations(self.correlation_matrix.clone());
108
109 let mut samples = backend.generate(n_samples, n_features, seed);
110
111 for row in samples.iter_mut() {
113 for (j, val) in row.iter_mut().enumerate() {
114 if j >= self.column_params.len() {
115 continue;
116 }
117 let col = &self.column_params[j];
118 match &col.col_type {
119 ColumnType::Continuous => {
120 *val = val.clamp(col.min, col.max);
121 }
122 ColumnType::Integer => {
123 *val = val.round().clamp(col.min, col.max);
124 }
125 ColumnType::Categorical { categories } => {
126 let n_cats = categories.len().max(1) as f64;
127 *val = val.round().clamp(0.0, n_cats - 1.0);
128 }
129 }
130 }
131 }
132
133 samples
134 }
135
136 pub fn save(&self, path: &Path) -> Result<(), SynthError> {
138 let json = serde_json::to_string_pretty(self)
139 .map_err(|e| SynthError::generation(format!("Failed to serialize model: {e}")))?;
140 std::fs::write(path, json).map_err(|e| {
141 SynthError::generation(format!("Failed to write model to {}: {e}", path.display()))
142 })?;
143 Ok(())
144 }
145
146 pub fn load(path: &Path) -> Result<Self, SynthError> {
148 let data = std::fs::read_to_string(path).map_err(|e| {
149 SynthError::generation(format!("Failed to read model from {}: {e}", path.display()))
150 })?;
151 let model: Self = serde_json::from_str(&data)
152 .map_err(|e| SynthError::generation(format!("Failed to deserialize model: {e}")))?;
153 Ok(model)
154 }
155}
156
/// Stateless entry point for building ([`DiffusionTrainer::fit`]) and scoring
/// ([`DiffusionTrainer::evaluate`]) diffusion models.
///
/// Carries no data; the derives make it usable wherever a `Debug`/`Default`/`Copy`
/// value is expected without affecting existing associated-function callers.
#[derive(Debug, Clone, Copy, Default)]
pub struct DiffusionTrainer;
167
168impl DiffusionTrainer {
169 pub fn fit(
174 column_params: Vec<ColumnDiffusionParams>,
175 correlation_matrix: Vec<Vec<f64>>,
176 config: DiffusionConfig,
177 ) -> TrainedDiffusionModel {
178 let schedule_type = match config.schedule {
179 super::backend::NoiseScheduleType::Linear => "linear".to_string(),
180 super::backend::NoiseScheduleType::Cosine => "cosine".to_string(),
181 super::backend::NoiseScheduleType::Sigmoid => "sigmoid".to_string(),
182 };
183
184 let metadata = ModelMetadata {
185 training_timestamp: chrono::Utc::now().to_rfc3339(),
186 n_steps: config.n_steps,
187 schedule_type,
188 n_columns: column_params.len(),
189 version: "1.0.0".to_string(),
190 };
191
192 TrainedDiffusionModel {
193 column_params,
194 correlation_matrix,
195 config,
196 metadata,
197 }
198 }
199
200 pub fn evaluate(model: &TrainedDiffusionModel, n_eval_samples: usize, seed: u64) -> FitReport {
206 let samples = model.generate(n_eval_samples, seed);
207 let n_cols = model.column_params.len();
208
209 if samples.is_empty() || n_cols == 0 {
210 return FitReport {
211 mean_errors: vec![],
212 std_errors: vec![],
213 correlation_error: 0.0,
214 overall_score: 0.0,
215 };
216 }
217
218 let n = samples.len() as f64;
219
220 let mut mean_errors = Vec::with_capacity(n_cols);
222 let mut std_errors = Vec::with_capacity(n_cols);
223
224 for j in 0..n_cols {
225 let col = &model.column_params[j];
226 let target_std = col.std.max(1e-8);
227
228 let sample_mean: f64 = samples.iter().map(|r| r[j]).sum::<f64>() / n;
229 let sample_var: f64 = samples
230 .iter()
231 .map(|r| (r[j] - sample_mean).powi(2))
232 .sum::<f64>()
233 / n;
234 let sample_std = sample_var.sqrt();
235
236 let me = (sample_mean - col.mean).abs() / target_std;
237 let se = (sample_std - col.std).abs() / target_std;
238
239 mean_errors.push(me);
240 std_errors.push(se);
241 }
242
243 let correlation_error = Self::compute_correlation_error(&samples, model);
245
246 let all_errors: Vec<f64> = mean_errors
248 .iter()
249 .chain(std_errors.iter())
250 .copied()
251 .collect();
252 let total_error_count = all_errors.len() as f64;
253 let avg_error = if total_error_count > 0.0 {
254 all_errors.iter().sum::<f64>() / total_error_count
255 } else {
256 0.0
257 };
258 let combined = (avg_error + correlation_error) / 2.0;
261 let overall_score = (1.0 - combined).clamp(0.0, 1.0);
262
263 FitReport {
264 mean_errors,
265 std_errors,
266 correlation_error,
267 overall_score,
268 }
269 }
270
271 fn compute_correlation_error(samples: &[Vec<f64>], model: &TrainedDiffusionModel) -> f64 {
274 let n_cols = model.column_params.len();
275 if n_cols < 2 || samples.is_empty() {
276 return 0.0;
277 }
278
279 let n = samples.len() as f64;
280
281 let mut means = vec![0.0; n_cols];
283 for row in samples {
284 for (j, &v) in row.iter().enumerate().take(n_cols) {
285 means[j] += v;
286 }
287 }
288 for m in &mut means {
289 *m /= n;
290 }
291
292 let mut stds = vec![0.0; n_cols];
294 for row in samples {
295 for (j, &v) in row.iter().enumerate().take(n_cols) {
296 stds[j] += (v - means[j]).powi(2);
297 }
298 }
299 for s in &mut stds {
300 *s = (*s / n).sqrt().max(1e-8);
301 }
302
303 let mut sample_corr = vec![vec![0.0; n_cols]; n_cols];
305 for row in samples {
306 for i in 0..n_cols {
307 for j in 0..n_cols {
308 sample_corr[i][j] += (row[i] - means[i]) * (row[j] - means[j]);
309 }
310 }
311 }
312 for (i, corr_row) in sample_corr.iter_mut().enumerate().take(n_cols) {
313 for (j, corr_val) in corr_row.iter_mut().enumerate().take(n_cols) {
314 *corr_val /= n * stds[i] * stds[j];
315 }
316 }
317
318 let target_corr = &model.correlation_matrix;
320 let mut frobenius_sq = 0.0;
321 for (i, corr_row) in sample_corr.iter().enumerate().take(n_cols) {
322 for (j, &corr_val) in corr_row.iter().enumerate().take(n_cols) {
323 let target_val = target_corr
324 .get(i)
325 .and_then(|row| row.get(j))
326 .copied()
327 .unwrap_or(if i == j { 1.0 } else { 0.0 });
328 let diff = corr_val - target_val;
329 frobenius_sq += diff * diff;
330 }
331 }
332
333 (frobenius_sq / (n_cols * n_cols) as f64).sqrt()
334 }
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct FitReport {
344 pub mean_errors: Vec<f64>,
346 pub std_errors: Vec<f64>,
348 pub correlation_error: f64,
350 pub overall_score: f64,
352}
353
#[cfg(test)]
// Test code may unwrap freely: a panic here is a failing test, not a production crash.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::super::backend::NoiseScheduleType;
    use super::*;

    /// Builds a linear-schedule config with the given step count and seed.
    fn make_config(n_steps: usize, seed: u64) -> DiffusionConfig {
        DiffusionConfig {
            n_steps,
            schedule: NoiseScheduleType::Linear,
            seed,
        }
    }

    /// Three representative columns: continuous, integer and a 4-way categorical.
    fn sample_column_params() -> Vec<ColumnDiffusionParams> {
        vec![
            ColumnDiffusionParams {
                name: "amount".to_string(),
                mean: 100.0,
                std: 15.0,
                min: 0.0,
                max: 500.0,
                col_type: ColumnType::Continuous,
            },
            ColumnDiffusionParams {
                name: "quantity".to_string(),
                mean: 10.0,
                std: 3.0,
                min: 1.0,
                max: 50.0,
                col_type: ColumnType::Integer,
            },
            ColumnDiffusionParams {
                name: "category".to_string(),
                mean: 1.5,
                std: 0.8,
                min: 0.0,
                max: 3.0,
                col_type: ColumnType::Categorical {
                    categories: vec![
                        "A".to_string(),
                        "B".to_string(),
                        "C".to_string(),
                        "D".to_string(),
                    ],
                },
            },
        ]
    }

    /// Symmetric 3x3 target correlation matrix matching `sample_column_params`.
    fn sample_correlation_matrix() -> Vec<Vec<f64>> {
        vec![
            vec![1.0, 0.6, 0.2],
            vec![0.6, 1.0, 0.3],
            vec![0.2, 0.3, 1.0],
        ]
    }

    /// `fit` stores its inputs verbatim and fills in the metadata.
    #[test]
    fn test_fit_produces_valid_model() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(100, 42),
        );

        assert_eq!(model.column_params.len(), 3);
        assert_eq!(model.correlation_matrix.len(), 3);
        assert_eq!(model.metadata.n_columns, 3);
        assert_eq!(model.metadata.n_steps, 100);
        assert_eq!(model.metadata.schedule_type, "linear");
        assert_eq!(model.metadata.version, "1.0.0");
        assert!(!model.metadata.training_timestamp.is_empty());

        assert_eq!(model.column_params[0].name, "amount");
        assert!((model.column_params[0].mean - 100.0).abs() < 1e-10);
        assert_eq!(model.column_params[1].col_type, ColumnType::Integer);
    }

    /// Saving then loading reproduces every persisted field.
    #[test]
    fn test_save_load_roundtrip() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(50, 42),
        );

        let dir = tempfile::tempdir().expect("Failed to create temp dir");
        let path = dir.path().join("model.json");

        model.save(&path).expect("Failed to save model");
        let loaded = TrainedDiffusionModel::load(&path).expect("Failed to load model");

        assert_eq!(model.column_params.len(), loaded.column_params.len());
        for (orig, load) in model.column_params.iter().zip(loaded.column_params.iter()) {
            assert_eq!(orig.name, load.name);
            assert!((orig.mean - load.mean).abs() < 1e-10);
            assert!((orig.std - load.std).abs() < 1e-10);
            assert!((orig.min - load.min).abs() < 1e-10);
            assert!((orig.max - load.max).abs() < 1e-10);
            assert_eq!(orig.col_type, load.col_type);
        }
        assert_eq!(model.correlation_matrix, loaded.correlation_matrix);
        assert_eq!(model.config.n_steps, loaded.config.n_steps);
        assert_eq!(model.config.seed, loaded.config.seed);
        assert_eq!(model.metadata.version, loaded.metadata.version);
        assert_eq!(
            model.metadata.training_timestamp,
            loaded.metadata.training_timestamp
        );
    }

    /// `generate` returns exactly `n_samples` rows of one value per column.
    #[test]
    fn test_generate_correct_dimensions() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(50, 42),
        );

        let samples = model.generate(200, 99);
        assert_eq!(samples.len(), 200);
        for row in &samples {
            assert_eq!(row.len(), 3);
        }
    }

    /// Column means of a large sample land near the configured targets.
    #[test]
    fn test_generated_means_within_tolerance() {
        let params = sample_column_params();
        let model = DiffusionTrainer::fit(
            params.clone(),
            sample_correlation_matrix(),
            make_config(100, 42),
        );

        let samples = model.generate(5000, 42);
        let n = samples.len() as f64;

        for (j, col) in params.iter().enumerate() {
            let sample_mean: f64 = samples.iter().map(|r| r[j]).sum::<f64>() / n;
            let tolerance = 2.0 * col.std;
            assert!(
                (sample_mean - col.mean).abs() < tolerance,
                "Column {} ('{}') mean {} is more than {} from target {}",
                j,
                col.name,
                sample_mean,
                tolerance,
                col.mean,
            );
        }
    }

    /// The same seed yields identical output.
    #[test]
    fn test_same_seed_deterministic() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(50, 42),
        );

        let samples1 = model.generate(100, 123);
        let samples2 = model.generate(100, 123);

        for (row1, row2) in samples1.iter().zip(samples2.iter()) {
            for (&v1, &v2) in row1.iter().zip(row2.iter()) {
                assert!(
                    (v1 - v2).abs() < 1e-12,
                    "Determinism failed: {} vs {}",
                    v1,
                    v2,
                );
            }
        }
    }

    /// Different seeds yield at least one differing value.
    #[test]
    fn test_different_seeds_differ() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(50, 42),
        );

        let samples1 = model.generate(100, 1);
        let samples2 = model.generate(100, 2);

        let any_diff = samples1
            .iter()
            .zip(&samples2)
            .flat_map(|(row1, row2)| row1.iter().zip(row2))
            .any(|(v1, v2)| (v1 - v2).abs() > 1e-8);
        assert!(any_diff, "Different seeds should produce different samples");
    }

    /// `evaluate` reports small errors and a positive score for a well-specified model.
    #[test]
    fn test_evaluation_reasonable_scores() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(100, 42),
        );

        let report = DiffusionTrainer::evaluate(&model, 5000, 42);

        for (j, &me) in report.mean_errors.iter().enumerate() {
            assert!(
                me < 1.0,
                "Column {} mean error {} is too large (should be < 1.0)",
                j,
                me,
            );
        }

        for (j, &se) in report.std_errors.iter().enumerate() {
            assert!(
                se < 1.5,
                "Column {} std error {} is too large (should be < 1.5)",
                j,
                se,
            );
        }

        assert!(
            report.correlation_error < 1.0,
            "Correlation error {} is too large",
            report.correlation_error,
        );

        assert!(
            report.overall_score > 0.0,
            "Overall score {} should be positive",
            report.overall_score,
        );
    }

    /// A strong positive target correlation shows up in the generated data.
    #[test]
    fn test_correlation_structure_preserved() {
        let params = vec![
            ColumnDiffusionParams {
                name: "x".to_string(),
                mean: 0.0,
                std: 1.0,
                min: -10.0,
                max: 10.0,
                col_type: ColumnType::Continuous,
            },
            ColumnDiffusionParams {
                name: "y".to_string(),
                mean: 0.0,
                std: 1.0,
                min: -10.0,
                max: 10.0,
                col_type: ColumnType::Continuous,
            },
        ];
        let corr = vec![vec![1.0, 0.8], vec![0.8, 1.0]];

        let model = DiffusionTrainer::fit(params, corr, make_config(100, 42));
        let samples = model.generate(5000, 42);

        let n = samples.len() as f64;
        let mean_x: f64 = samples.iter().map(|r| r[0]).sum::<f64>() / n;
        let mean_y: f64 = samples.iter().map(|r| r[1]).sum::<f64>() / n;
        let std_x: f64 = (samples.iter().map(|r| (r[0] - mean_x).powi(2)).sum::<f64>() / n).sqrt();
        let std_y: f64 = (samples.iter().map(|r| (r[1] - mean_y).powi(2)).sum::<f64>() / n).sqrt();
        let cov_xy: f64 = samples
            .iter()
            .map(|r| (r[0] - mean_x) * (r[1] - mean_y))
            .sum::<f64>()
            / n;

        let sample_corr = if std_x > 1e-8 && std_y > 1e-8 {
            cov_xy / (std_x * std_y)
        } else {
            0.0
        };

        assert!(
            sample_corr > 0.3,
            "Expected positive correlation near 0.8, got {}",
            sample_corr,
        );
    }

    /// Integer columns emit whole numbers inside `[min, max]`.
    #[test]
    fn test_integer_column_produces_integers() {
        let params = vec![ColumnDiffusionParams {
            name: "count".to_string(),
            mean: 10.0,
            std: 3.0,
            min: 1.0,
            max: 50.0,
            col_type: ColumnType::Integer,
        }];
        let corr = vec![vec![1.0]];

        let model = DiffusionTrainer::fit(params, corr, make_config(50, 42));
        let samples = model.generate(500, 42);

        for row in &samples {
            let val = row[0];
            assert!(
                (val - val.round()).abs() < 1e-10,
                "Integer column produced non-integer value: {}",
                val,
            );
            assert!(
                (1.0..=50.0).contains(&val),
                "Integer column value {} out of range [1, 50]",
                val,
            );
        }
    }

    /// Categorical columns emit valid whole-number category indices.
    #[test]
    fn test_categorical_column_produces_valid_indices() {
        let params = vec![ColumnDiffusionParams {
            name: "category".to_string(),
            mean: 1.0,
            std: 0.8,
            min: 0.0,
            max: 2.0,
            col_type: ColumnType::Categorical {
                categories: vec!["A".to_string(), "B".to_string(), "C".to_string()],
            },
        }];
        let corr = vec![vec![1.0]];

        let model = DiffusionTrainer::fit(params, corr, make_config(50, 42));
        let samples = model.generate(500, 42);

        for row in &samples {
            let val = row[0];
            assert!(
                (val - val.round()).abs() < 1e-10,
                "Categorical column produced non-integer index: {}",
                val,
            );
            assert!(
                (0.0..=2.0).contains(&val),
                "Categorical index {} out of range [0, 2]",
                val,
            );
        }
    }

    /// A model without columns generates nothing.
    #[test]
    fn test_empty_model_generates_empty() {
        let model = DiffusionTrainer::fit(vec![], vec![], make_config(50, 0));
        let samples = model.generate(100, 0);
        assert!(samples.is_empty());
    }

    /// Loading from a missing file returns an error instead of panicking.
    #[test]
    fn test_load_nonexistent_returns_error() {
        // A fresh temp dir guarantees the path is absent on every platform, unlike a
        // fixed `/tmp/...` path that may exist or not be valid at all (e.g. on Windows).
        let dir = tempfile::tempdir().expect("Failed to create temp dir");
        let path = dir.path().join("nonexistent_model.json");
        let result = TrainedDiffusionModel::load(&path);
        assert!(result.is_err());
    }

    /// Loading a file that is not valid model JSON returns an error.
    #[test]
    fn test_load_malformed_json_returns_error() {
        let dir = tempfile::tempdir().expect("Failed to create temp dir");
        let path = dir.path().join("model.json");
        std::fs::write(&path, "{ not valid json").expect("Failed to write fixture");
        assert!(TrainedDiffusionModel::load(&path).is_err());
    }

    /// Saving into a directory that does not exist returns an error.
    #[test]
    fn test_save_to_missing_directory_returns_error() {
        let model = DiffusionTrainer::fit(
            sample_column_params(),
            sample_correlation_matrix(),
            make_config(10, 1),
        );
        let dir = tempfile::tempdir().expect("Failed to create temp dir");
        let path = dir.path().join("missing_subdir").join("model.json");
        assert!(model.save(&path).is_err());
    }
}