use crate::error::{Error, Result};
use crate::ml_testing::{utils, GenerationConfig, GenerationResult, TestCase, TestCaseType};
use ndarray::{Array2, ArrayD};
use rand::Rng;
use std::collections::HashMap;
/// Tunable parameters for synthetic test-case generation.
#[derive(Debug, Clone)]
pub struct SyntheticConfig {
    /// Synthesis algorithm to run (defaults to SMOTE).
    pub method: SynthesisMethod,
    /// Number of nearest neighbors SMOTE may interpolate towards.
    pub k_neighbors: usize,
    /// Amplitude of the random noise added to generated samples.
    pub noise_level: f32,
    // NOTE(review): not read by any generator in this file — presumably
    // consumed elsewhere; verify before relying on it.
    pub preserve_distribution: bool,
    // NOTE(review): not read by any generator in this file — presumably an
    // upper bound on feature correlation; verify against callers.
    pub max_correlation: f32,
    /// RNG seed for reproducible output; a seed supplied on the
    /// `GenerationConfig` passed to `generate` takes precedence over this one.
    pub seed: Option<u64>,
}
impl Default for SyntheticConfig {
    /// Defaults: SMOTE with 5 neighbors, 0.1 noise level, distribution
    /// preservation enabled, correlation capped at 0.8, and no fixed seed.
    fn default() -> Self {
        Self {
            method: SynthesisMethod::SMOTE,
            k_neighbors: 5,
            noise_level: 0.1,
            preserve_distribution: true,
            max_correlation: 0.8,
            seed: None,
        }
    }
}
/// Available synthesis algorithms.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SynthesisMethod {
    /// Interpolation between a sample and one of its nearest neighbors
    /// (Synthetic Minority Over-sampling Technique).
    SMOTE,
    /// Sampling from sorted per-feature empirical marginals via a normal
    /// deviate (simplified copula; correlations are not modeled here).
    GaussianCopula,
    /// Variational autoencoder — currently falls back to noise augmentation
    /// with a warning.
    VAE,
    /// Generative adversarial network — currently falls back to SMOTE with a
    /// warning.
    GAN,
    /// Conditional tabular GAN — currently falls back to SMOTE with a warning.
    CTGAN,
    /// Random perturbation of cloned training samples.
    NoiseAugmentation,
}
/// Generates synthetic test cases from training data using the method
/// selected in its [`SyntheticConfig`].
pub struct SyntheticDataGenerator {
    // Full generation configuration; fixed at construction time.
    config: SyntheticConfig,
}
impl SyntheticDataGenerator {
pub fn new(method: SynthesisMethod) -> Self {
let config = SyntheticConfig {
method,
..Default::default()
};
Self { config }
}
pub fn smote() -> Self {
Self::new(SynthesisMethod::SMOTE)
}
pub fn gaussian_copula() -> Self {
Self::new(SynthesisMethod::GaussianCopula)
}
pub fn noise_augmentation(noise_level: f32) -> Self {
let config = SyntheticConfig {
method: SynthesisMethod::NoiseAugmentation,
noise_level,
..Default::default()
};
Self { config }
}
pub fn with_config(config: SyntheticConfig) -> Self {
Self { config }
}
pub fn generate(
&self,
training_data: &[ArrayD<f32>],
target_samples: usize,
config: &GenerationConfig,
) -> Result<GenerationResult> {
let mut result = GenerationResult::new();
let mut rng = utils::create_rng(config.seed.or(self.config.seed));
match self.config.method {
SynthesisMethod::SMOTE => {
self.generate_smote(training_data, target_samples, &mut result, &mut rng)?;
}
SynthesisMethod::GaussianCopula => {
self.generate_gaussian_copula(
training_data,
target_samples,
&mut result,
&mut rng,
)?;
}
SynthesisMethod::NoiseAugmentation => {
self.generate_noise_augmentation(
training_data,
target_samples,
&mut result,
&mut rng,
)?;
}
SynthesisMethod::VAE => {
self.generate_vae(training_data, target_samples, &mut result, &mut rng)?;
}
SynthesisMethod::GAN | SynthesisMethod::CTGAN => {
self.generate_gan(training_data, target_samples, &mut result, &mut rng)?;
}
}
result
.statistics
.insert("target_samples".to_string(), target_samples as f64);
result.statistics.insert(
"generated_samples".to_string(),
result.test_cases.len() as f64,
);
Ok(result)
}
fn generate_smote(
&self,
training_data: &[ArrayD<f32>],
target_samples: usize,
result: &mut GenerationResult,
rng: &mut impl Rng,
) -> Result<()> {
if training_data.is_empty() {
return Err(Error::parse("Training data is empty"));
}
let n_samples = training_data.len();
let n_features = training_data[0].len();
let mut data_matrix = Array2::<f32>::zeros((n_samples, n_features));
for (i, sample) in training_data.iter().enumerate() {
for (j, &val) in sample.iter().enumerate() {
data_matrix[[i, j]] = val;
}
}
for _i in 0..target_samples {
if result.test_cases.len() >= target_samples {
break;
}
let random_idx = rng.gen_range(0..n_samples);
let sample = data_matrix.row(random_idx);
let mut distances: Vec<(usize, f32)> = (0..n_samples)
.filter(|&idx| idx != random_idx)
.map(|idx| {
let other = data_matrix.row(idx);
let distance = self.euclidean_distance(sample, other);
(idx, distance)
})
.collect();
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let k = self.config.k_neighbors.min(distances.len());
let neighbor_idx = rng.gen_range(0..k);
let neighbor = data_matrix.row(distances[neighbor_idx].0);
let mut synthetic = ArrayD::zeros(vec![n_features]);
let gap = rng.gen::<f32>();
for j in 0..n_features {
let diff = neighbor[j] - sample[j];
synthetic[j] = sample[j] + gap * diff;
}
utils::add_noise(&mut synthetic, self.config.noise_level, rng);
let metadata = HashMap::from([
("method".to_string(), "SMOTE".to_string()),
("base_sample".to_string(), random_idx.to_string()),
(
"neighbor_sample".to_string(),
distances[neighbor_idx].0.to_string(),
),
("gap".to_string(), gap.to_string()),
]);
let test_case = TestCase {
input: synthetic,
expected_output: None, case_type: TestCaseType::Synthetic,
method: "SMOTE".to_string(),
confidence: 0.8, metadata,
};
result.test_cases.push(test_case);
}
Ok(())
}
fn generate_gaussian_copula(
&self,
training_data: &[ArrayD<f32>],
target_samples: usize,
result: &mut GenerationResult,
rng: &mut impl Rng,
) -> Result<()> {
if training_data.is_empty() {
return Err(Error::parse("Training data is empty"));
}
let n_features = training_data[0].len();
let mut marginals = Vec::new();
for feature_idx in 0..n_features {
let mut values: Vec<f32> = training_data
.iter()
.map(|sample| sample[feature_idx])
.collect();
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
marginals.push(values);
}
for _ in 0..target_samples {
if result.test_cases.len() >= target_samples {
break;
}
let mut synthetic = ArrayD::zeros(vec![n_features]);
for j in 0..n_features {
let u1: f32 = rng.gen();
let u2: f32 = rng.gen();
let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
let marginal = &marginals[j];
let n = marginal.len();
let rank = ((z + 3.0) / 6.0 * (n - 1) as f32) as usize; let rank = rank.min(n - 1);
synthetic[j] = marginal[rank];
}
let metadata = HashMap::from([
("method".to_string(), "GaussianCopula".to_string()),
("correlation_model".to_string(), "simplified".to_string()),
]);
let test_case = TestCase {
input: synthetic,
expected_output: None,
case_type: TestCaseType::Synthetic,
method: "GaussianCopula".to_string(),
confidence: 0.7, metadata,
};
result.test_cases.push(test_case);
}
Ok(())
}
fn generate_noise_augmentation(
&self,
training_data: &[ArrayD<f32>],
target_samples: usize,
result: &mut GenerationResult,
rng: &mut impl Rng,
) -> Result<()> {
for _ in 0..target_samples {
if result.test_cases.len() >= target_samples {
break;
}
let base_idx = rng.gen_range(0..training_data.len());
let mut synthetic = training_data[base_idx].clone();
utils::add_noise(&mut synthetic, self.config.noise_level, rng);
let metadata = HashMap::from([
("method".to_string(), "NoiseAugmentation".to_string()),
("base_sample".to_string(), base_idx.to_string()),
(
"noise_level".to_string(),
self.config.noise_level.to_string(),
),
]);
let test_case = TestCase {
input: synthetic,
expected_output: None, case_type: TestCaseType::Synthetic,
method: "NoiseAugmentation".to_string(),
confidence: 0.9, metadata,
};
result.test_cases.push(test_case);
}
Ok(())
}
fn generate_vae(
&self,
training_data: &[ArrayD<f32>],
target_samples: usize,
result: &mut GenerationResult,
rng: &mut impl Rng,
) -> Result<()> {
result.warnings.push(
"VAE generation not fully implemented - using noise augmentation fallback".to_string(),
);
self.generate_noise_augmentation(training_data, target_samples, result, rng)
}
fn generate_gan(
&self,
training_data: &[ArrayD<f32>],
target_samples: usize,
result: &mut GenerationResult,
rng: &mut impl Rng,
) -> Result<()> {
let method_name = match self.config.method {
SynthesisMethod::GAN => "GAN",
SynthesisMethod::CTGAN => "CTGAN",
_ => "GAN",
};
result.warnings.push(format!(
"{} generation not fully implemented - using SMOTE fallback",
method_name
));
self.generate_smote(training_data, target_samples, result, rng)
}
fn euclidean_distance(&self, a: ndarray::ArrayView1<f32>, b: ndarray::ArrayView1<f32>) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f32>()
.sqrt()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::ArrayD;

    /// Builds a dynamic-dimension 1-D sample from a plain vector of values.
    fn sample(values: Vec<f32>) -> ArrayD<f32> {
        ArrayD::from_shape_vec(ndarray::IxDyn(&[values.len()]), values).unwrap()
    }

    #[test]
    fn test_synthetic_config_default() {
        let defaults = SyntheticConfig::default();
        assert_eq!(defaults.method, SynthesisMethod::SMOTE);
        assert_eq!(defaults.k_neighbors, 5);
    }

    #[test]
    fn test_smote_generator() {
        assert_eq!(
            SyntheticDataGenerator::smote().config.method,
            SynthesisMethod::SMOTE
        );
    }

    #[test]
    fn test_noise_augmentation_generator() {
        let gen = SyntheticDataGenerator::noise_augmentation(0.2);
        assert_eq!(gen.config.method, SynthesisMethod::NoiseAugmentation);
        assert_eq!(gen.config.noise_level, 0.2);
    }

    #[test]
    fn test_generate_smote() {
        let data = vec![
            sample(vec![1.0, 2.0]),
            sample(vec![2.0, 3.0]),
            sample(vec![3.0, 4.0]),
        ];
        let gen_config = GenerationConfig {
            num_cases: 5,
            ..Default::default()
        };
        let outcome = SyntheticDataGenerator::smote()
            .generate(&data, 5, &gen_config)
            .expect("SMOTE generation should succeed");
        assert_eq!(outcome.test_cases.len(), 5);
        assert!(outcome
            .test_cases
            .iter()
            .all(|tc| tc.case_type == TestCaseType::Synthetic && tc.method == "SMOTE"));
    }

    #[test]
    fn test_generate_noise_augmentation() {
        let data = vec![sample(vec![1.0, 2.0])];
        let gen_config = GenerationConfig {
            num_cases: 3,
            ..Default::default()
        };
        let outcome = SyntheticDataGenerator::noise_augmentation(0.1)
            .generate(&data, 3, &gen_config)
            .expect("noise augmentation should succeed");
        assert_eq!(outcome.test_cases.len(), 3);
        for case in &outcome.test_cases {
            assert_eq!(case.case_type, TestCaseType::Synthetic);
            assert_eq!(case.method, "NoiseAugmentation");
        }
    }
}