use ndarray::ArrayD;
use reasonkit::ml_testing::{
AdversarialGenerator, AttackMethod, EdgeCaseGenerator, EdgeCaseType, FeatureConstraint,
FeatureType, GenerationConfig, InputSchema, SynthesisMethod, SyntheticDataGenerator,
};
use std::collections::HashMap;
/// Minimal stand-in model used to drive the generators in this demo.
///
/// `forward` echoes its input unchanged; the declared shapes are `[784]`
/// for input and `[10]` for output.
struct MockModel;

impl reasonkit::ml_testing::MLModel for MockModel {
    /// Identity forward pass: returns an owned copy of `input`.
    ///
    /// NOTE(review): the returned array has the shape of `input`, which does
    /// not match the `[10]` reported by `output_shape` — confirm whether any
    /// generator relies on that declared output shape.
    fn forward(&self, input: &ArrayD<f32>) -> reasonkit::error::Result<ArrayD<f32>> {
        let output = input.clone();
        Ok(output)
    }

    /// Mock gradient: `target - input` when a target is supplied, `-input` otherwise.
    ///
    /// NOTE(review): placeholder arithmetic, not the derivative of a specific loss.
    fn gradient(
        &self,
        input: &ArrayD<f32>,
        target: Option<&ArrayD<f32>>,
    ) -> reasonkit::error::Result<ArrayD<f32>> {
        let grad = if let Some(goal) = target {
            goal - input
        } else {
            -input
        };
        Ok(grad)
    }

    /// Declared input shape: a single axis of length 784.
    fn input_shape(&self) -> Vec<usize> {
        let features = 784;
        vec![features]
    }

    /// Declared output shape: a single axis of length 10.
    fn output_shape(&self) -> Vec<usize> {
        let outputs = 10;
        vec![outputs]
    }
}
/// Returns the Euclidean (L2) norm of the `f32` values yielded by `values`.
fn l2_norm<'a, I>(values: I) -> f32
where
    I: IntoIterator<Item = &'a f32>,
{
    values.into_iter().map(|x| x * x).sum::<f32>().sqrt()
}

/// Demonstrates the three `reasonkit::ml_testing` generator families
/// (adversarial, edge-case, synthetic) and prints what each produces.
///
/// Each generator failure is printed to stdout and the demo moves on to the
/// next section; as written, `main` always returns `Ok(())`.
fn main() -> reasonkit::error::Result<()> {
    println!("🚀 ML Test Case Generators Demo");

    // --- Adversarial examples ---------------------------------------------
    println!("\n1️⃣ Adversarial Examples");
    let model = MockModel;
    // Size the input from the model's declared shape rather than repeating the
    // literal; the fully-qualified call avoids needing the trait in scope.
    let input = ArrayD::from_elem(reasonkit::ml_testing::MLModel::input_shape(&model), 0.5);
    let attacker = AdversarialGenerator::fgsm(0.1);
    match attacker.generate(&model, &input, None) {
        Ok(adversarial) => {
            println!("✓ Generated FGSM adversarial example");
            println!(" Original norm: {:.3}", l2_norm(input.iter()));
            println!(" Adversarial norm: {:.3}", l2_norm(adversarial.iter()));
        }
        Err(e) => println!("✗ Failed to generate adversarial example: {}", e),
    }

    // --- Edge cases -------------------------------------------------------
    println!("\n2️⃣ Edge Cases");
    let mut schema = InputSchema {
        features: HashMap::new(),
        constraints: HashMap::new(),
    };
    schema
        .features
        .insert("age".to_string(), FeatureType::Numeric);
    schema.constraints.insert(
        "age".to_string(),
        FeatureConstraint::Range {
            min: 0.0,
            max: 100.0,
        },
    );
    schema.features.insert(
        "category".to_string(),
        FeatureType::Categorical(vec!["A".to_string(), "B".to_string(), "C".to_string()]),
    );
    let edge_generator = EdgeCaseGenerator::boundary_values();
    let config = GenerationConfig {
        num_cases: 5,
        ..Default::default()
    };
    match edge_generator.generate(&schema, &config) {
        Ok(result) => {
            println!("✓ Generated {} edge cases", result.test_cases.len());
            for (i, test_case) in result.test_cases.iter().enumerate() {
                // Borrow the stored label; fall back to a static str instead of
                // allocating a fresh String on every iteration.
                let case_type = test_case
                    .metadata
                    .get("type")
                    .map_or("unknown", String::as_str);
                println!(
                    " Case {}: {} (confidence: {:.2})",
                    i + 1,
                    case_type,
                    test_case.confidence
                );
            }
        }
        Err(e) => println!("✗ Failed to generate edge cases: {}", e),
    }

    // --- Synthetic data ---------------------------------------------------
    println!("\n3️⃣ Synthetic Data");
    // `arr1` is infallible, so there is no shape/length mismatch to unwrap.
    let training_data: Vec<ArrayD<f32>> = vec![
        ndarray::arr1(&[1.0, 2.0, 3.0]).into_dyn(),
        ndarray::arr1(&[2.0, 3.0, 4.0]).into_dyn(),
        ndarray::arr1(&[3.0, 4.0, 5.0]).into_dyn(),
    ];
    let synth_generator = SyntheticDataGenerator::smote();
    match synth_generator.generate(&training_data, 5, &config) {
        Ok(result) => {
            println!("✓ Generated {} synthetic samples", result.test_cases.len());
            println!(" Statistics:");
            // Sort by key so the printed order is stable between runs.
            // NOTE(review): assumes `statistics` is a map (e.g. HashMap) whose
            // keys implement `Ord` — confirm against the reasonkit type.
            let mut stats: Vec<_> = result.statistics.iter().collect();
            stats.sort_by(|a, b| a.0.cmp(b.0));
            for (key, value) in stats {
                println!(" {}: {:.1}", key, value);
            }
        }
        Err(e) => println!("✗ Failed to generate synthetic data: {}", e),
    }

    println!("\n✅ Demo completed!");
    Ok(())
}