reasonkit-core 0.1.8

The Reasoning Engine — Auditable Reasoning for Production AI | Rust-Native | Turn Prompts into Protocols
//! # ML Test Case Generators - Example Usage
//!
//! This example demonstrates how to use the ML test case generators
//! to create adversarial examples, edge cases, and synthetic data.

use ndarray::ArrayD;
use reasonkit::ml_testing::{
    AdversarialGenerator, AttackMethod, EdgeCaseGenerator, EdgeCaseType, FeatureConstraint,
    FeatureType, GenerationConfig, InputSchema, SynthesisMethod, SyntheticDataGenerator,
};
use std::collections::HashMap;

/// Mock ML model for demonstration.
///
/// Implements [`reasonkit::ml_testing::MLModel`] with trivial, deterministic
/// arithmetic so the generators in `main` can be exercised without a real network.
struct MockModel;

impl reasonkit::ml_testing::MLModel for MockModel {
    /// Identity mapping: returns a copy of `input` unchanged.
    ///
    /// NOTE(review): the returned array has the same shape as `input` (784
    /// elements for the demo input), not the `[10]` advertised by
    /// `output_shape`. Confirm the generators do not rely on the two agreeing.
    fn forward(&self, input: &ArrayD<f32>) -> reasonkit::error::Result<ArrayD<f32>> {
        Ok(input.clone())
    }

    /// Deterministic stand-in for a gradient; no randomness is involved.
    ///
    /// - With a `target`: returns `target - input` (element-wise).
    /// - Without a target: returns `-input`, a direction pointing from `input`
    ///   toward the origin.
    fn gradient(
        &self,
        input: &ArrayD<f32>,
        target: Option<&ArrayD<f32>>,
    ) -> reasonkit::error::Result<ArrayD<f32>> {
        match target {
            Some(target) => Ok(target - input),
            None => Ok(-input),
        }
    }

    /// Declared input shape: a flat 784-element vector (e.g. a flattened 28x28 image).
    fn input_shape(&self) -> Vec<usize> {
        vec![784]
    }

    /// Declared output shape: 10 class scores.
    fn output_shape(&self) -> Vec<usize> {
        vec![10]
    }
}

/// Euclidean (L2) norm of a sequence of `f32` values.
///
/// Accepts anything iterable by reference over `f32`, e.g. `&ArrayD<f32>` or `&Vec<f32>`.
fn l2_norm<'a>(values: impl IntoIterator<Item = &'a f32>) -> f32 {
    values.into_iter().map(|x| x * x).sum::<f32>().sqrt()
}

/// Runs the three generator demos in sequence.
///
/// The demos are adversarial examples (FGSM), edge cases (boundary values) and
/// synthetic data (SMOTE). A failure in one generator is printed and the demo
/// continues with the next section, so as written this function always returns `Ok(())`.
fn main() -> reasonkit::error::Result<()> {
    println!("🚀 ML Test Case Generators Demo");

    // 1. Adversarial Example Generation
    println!("\n1️⃣ Adversarial Examples");
    let model = MockModel;
    // Constant mock input sized to match the mock model's 784-element input shape.
    let input = ArrayD::from_elem(vec![784], 0.5);

    let attacker = AdversarialGenerator::fgsm(0.1);
    match attacker.generate(&model, &input, None) {
        Ok(adversarial) => {
            println!("✓ Generated FGSM adversarial example");
            println!("  Original norm: {:.3}", l2_norm(&input));
            println!("  Adversarial norm: {:.3}", l2_norm(&adversarial));
        }
        Err(e) => println!("✗ Failed to generate adversarial example: {}", e),
    }

    // 2. Edge Case Generation
    println!("\n2️⃣ Edge Cases");
    let mut schema = InputSchema {
        features: HashMap::new(),
        constraints: HashMap::new(),
    };

    // Add a numeric feature constrained to [0, 100].
    schema
        .features
        .insert("age".to_string(), FeatureType::Numeric);
    schema.constraints.insert(
        "age".to_string(),
        FeatureConstraint::Range {
            min: 0.0,
            max: 100.0,
        },
    );

    // Add a categorical feature with three allowed values.
    schema.features.insert(
        "category".to_string(),
        FeatureType::Categorical(vec!["A".to_string(), "B".to_string(), "C".to_string()]),
    );

    let edge_generator = EdgeCaseGenerator::boundary_values();
    let config = GenerationConfig {
        num_cases: 5,
        ..Default::default()
    };

    match edge_generator.generate(&schema, &config) {
        Ok(result) => {
            println!("✓ Generated {} edge cases", result.test_cases.len());
            for (i, test_case) in result.test_cases.iter().enumerate() {
                // Borrow the label as &str; fall back to a static literal
                // instead of allocating a fresh String on every iteration.
                let case_type = test_case
                    .metadata
                    .get("type")
                    .map_or("unknown", String::as_str);
                println!(
                    "  Case {}: {} (confidence: {:.2})",
                    i + 1,
                    case_type,
                    test_case.confidence
                );
            }
        }
        Err(e) => println!("✗ Failed to generate edge cases: {}", e),
    }

    // 3. Synthetic Data Generation
    println!("\n3️⃣ Synthetic Data");
    // `arr1(..).into_dyn()` builds each 1-D sample infallibly, so there is no
    // `unwrap()` on a shape check that can never fail.
    let training_data: Vec<ArrayD<f32>> = vec![
        ndarray::arr1(&[1.0, 2.0, 3.0]).into_dyn(),
        ndarray::arr1(&[2.0, 3.0, 4.0]).into_dyn(),
        ndarray::arr1(&[3.0, 4.0, 5.0]).into_dyn(),
    ];

    let synth_generator = SyntheticDataGenerator::smote();
    match synth_generator.generate(&training_data, 5, &config) {
        Ok(result) => {
            println!("✓ Generated {} synthetic samples", result.test_cases.len());
            println!("  Statistics:");
            for (key, value) in &result.statistics {
                println!("    {}: {:.1}", key, value);
            }
        }
        Err(e) => println!("✗ Failed to generate synthetic data: {}", e),
    }

    println!("\n✅ Demo completed!");
    Ok(())
}