use crate::error::Result;
use crate::ml_testing::{
utils, GenerationConfig, GenerationResult, MLModel, TestCase, TestCaseType,
};
use ndarray::ArrayD;
use rand::Rng;
/// Tunable parameters shared by every attack method.
///
/// Fields that a given method does not use (e.g. `num_iterations` for FGSM)
/// are simply ignored by that method.
#[derive(Debug, Clone)]
pub struct AdversarialConfig {
/// Attack algorithm to run (see [`AttackMethod`]).
pub method: AttackMethod,
/// Maximum perturbation budget (radius of the ball around the input).
pub epsilon: f32,
/// Per-iteration step size for the iterative attacks (PGD, BIM).
pub step_size: f32,
/// Number of gradient steps for PGD/BIM.
pub num_iterations: usize,
/// Confidence threshold used by the CW attack's early-stop check.
pub confidence: f32,
/// Gradient scaling factor for the CW attack.
pub learning_rate: f32,
/// Iteration cap for the CW attack's optimization loop.
pub max_iterations: usize,
/// Extra random-start passes for PGD (0 = single deterministic start).
pub random_restarts: usize,
}
impl Default for AdversarialConfig {
fn default() -> Self {
Self {
method: AttackMethod::FGSM,
epsilon: 0.1,
step_size: 0.01,
num_iterations: 10,
confidence: 0.0,
learning_rate: 0.01,
max_iterations: 1000,
random_restarts: 0,
}
}
}
/// Supported adversarial attack algorithms.
///
/// `Eq` and `Hash` are derived (the enum carries no data) so the method can
/// be compared exactly and used as a hash-map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AttackMethod {
    /// Fast Gradient Sign Method: one sign-of-gradient step.
    FGSM,
    /// Projected Gradient Descent: iterative steps with projection and
    /// optional random restarts.
    PGD,
    /// Basic Iterative Method: iterative FGSM with projection.
    BIM,
    /// Carlini–Wagner style gradient-based optimization attack.
    CW,
}
/// Generates adversarial inputs for an [`MLModel`] using the attack method
/// selected in its configuration.
pub struct AdversarialGenerator {
/// Attack parameters; fixed at construction time.
config: AdversarialConfig,
}
impl AdversarialGenerator {
pub fn new(method: AttackMethod, epsilon: f32) -> Self {
let config = AdversarialConfig {
method,
epsilon,
..Default::default()
};
Self { config }
}
pub fn with_config(config: AdversarialConfig) -> Self {
Self { config }
}
pub fn fgsm(epsilon: f32) -> Self {
Self::new(AttackMethod::FGSM, epsilon)
}
pub fn pgd(epsilon: f32, num_iterations: usize) -> Self {
let config = AdversarialConfig {
method: AttackMethod::PGD,
epsilon,
num_iterations,
..Default::default()
};
Self { config }
}
pub fn generate<M: MLModel>(
&self,
model: &M,
input: &ArrayD<f32>,
target: Option<&ArrayD<f32>>,
) -> Result<ArrayD<f32>> {
match self.config.method {
AttackMethod::FGSM => self.fgsm_attack(model, input, target),
AttackMethod::PGD => self.pgd_attack(model, input, target),
AttackMethod::BIM => self.bim_attack(model, input, target),
AttackMethod::CW => self.cw_attack(model, input, target),
}
}
pub fn generate_batch<M: MLModel>(
&self,
model: &M,
inputs: &[ArrayD<f32>],
targets: Option<&[ArrayD<f32>]>,
config: &GenerationConfig,
) -> Result<GenerationResult> {
let mut result = GenerationResult::new();
let mut successful_attacks = 0;
let mut rng = utils::create_rng(config.seed);
for (i, input) in inputs.iter().enumerate() {
if result.test_cases.len() >= config.num_cases {
break;
}
let target = targets.and_then(|t| t.get(i));
match self.generate(model, input, target) {
Ok(adversarial) => {
let original_pred = model.forward(input)?;
let adversarial_pred = model.forward(&adversarial)?;
let original_class = original_pred
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(i, _)| i)
.unwrap_or(0);
let adversarial_class = adversarial_pred
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(i, _)| i)
.unwrap_or(0);
let success = original_class != adversarial_class;
if success || rng.gen_bool(config.target_success_rate) {
let mut metadata = std::collections::HashMap::new();
metadata.insert("original_class".to_string(), original_class.to_string());
metadata.insert(
"adversarial_class".to_string(),
adversarial_class.to_string(),
);
metadata.insert("attack_success".to_string(), success.to_string());
metadata.insert("method".to_string(), format!("{:?}", self.config.method));
let test_case = TestCase {
input: adversarial,
expected_output: target.cloned(),
case_type: TestCaseType::Adversarial,
method: format!("{:?}", self.config.method),
confidence: if success { 1.0 } else { 0.0 },
metadata,
};
result.test_cases.push(test_case);
if success {
successful_attacks += 1;
}
}
}
Err(e) => {
result.warnings.push(format!(
"Failed to generate adversarial example {}: {}",
i, e
));
}
}
}
result.success_rate = if !result.test_cases.is_empty() {
successful_attacks as f64 / result.test_cases.len() as f64
} else {
0.0
};
result
.statistics
.insert("total_attempts".to_string(), inputs.len() as f64);
result.statistics.insert(
"successful_generations".to_string(),
result.test_cases.len() as f64,
);
result
.statistics
.insert("success_rate".to_string(), result.success_rate);
Ok(result)
}
fn fgsm_attack<M: MLModel>(
&self,
model: &M,
input: &ArrayD<f32>,
target: Option<&ArrayD<f32>>,
) -> Result<ArrayD<f32>> {
let grad = model.gradient(input, target)?;
let perturbation = utils::sign(&grad);
let perturbation = &perturbation * self.config.epsilon;
let mut adversarial = input.clone();
adversarial += &perturbation;
utils::clip(&mut adversarial, 0.0, 1.0);
Ok(adversarial)
}
fn pgd_attack<M: MLModel>(
&self,
model: &M,
input: &ArrayD<f32>,
target: Option<&ArrayD<f32>>,
) -> Result<ArrayD<f32>> {
let mut adversarial = input.clone();
let mut rng = utils::create_rng(None);
for _restart in 0..=self.config.random_restarts {
let mut current = if _restart == 0 {
input.clone()
} else {
let mut init = input.clone();
utils::add_noise(&mut init, self.config.epsilon * 0.1, &mut rng);
utils::clip(&mut init, 0.0, 1.0);
init
};
for _ in 0..self.config.num_iterations {
let grad = model.gradient(¤t, target)?;
let perturbation = utils::sign(&grad);
let perturbation = &perturbation * self.config.step_size;
current += &perturbation;
let diff = ¤t - input;
let norm = utils::l2_norm(&diff);
if norm > self.config.epsilon {
let scale = self.config.epsilon / norm;
current = input + &(&diff * scale);
}
utils::clip(&mut current, 0.0, 1.0);
}
let current_pred = model.forward(¤t)?;
let adversarial_pred = model.forward(&adversarial)?;
let current_loss = self.compute_loss(¤t_pred, target)?;
let adversarial_loss = self.compute_loss(&adversarial_pred, target)?;
if current_loss > adversarial_loss {
adversarial = current;
}
}
Ok(adversarial)
}
fn bim_attack<M: MLModel>(
&self,
model: &M,
input: &ArrayD<f32>,
target: Option<&ArrayD<f32>>,
) -> Result<ArrayD<f32>> {
let mut adversarial = input.clone();
for _ in 0..self.config.num_iterations {
let grad = model.gradient(&adversarial, target)?;
let perturbation = utils::sign(&grad);
let perturbation = &perturbation * self.config.step_size;
adversarial += &perturbation;
let diff = &adversarial - input;
let norm = utils::l2_norm(&diff);
if norm > self.config.epsilon {
let scale = self.config.epsilon / norm;
adversarial = input + &(&diff * scale);
}
utils::clip(&mut adversarial, 0.0, 1.0);
}
Ok(adversarial)
}
fn cw_attack<M: MLModel>(
&self,
model: &M,
input: &ArrayD<f32>,
target: Option<&ArrayD<f32>>,
) -> Result<ArrayD<f32>> {
let adversarial = input.clone();
let mut best_loss = f32::INFINITY;
for _ in 0..self.config.max_iterations {
let grad = model.gradient(&adversarial, target)?;
let perturbation = &grad * self.config.learning_rate;
let mut candidate = &adversarial + perturbation;
utils::clip(&mut candidate, 0.0, 1.0);
let pred = model.forward(&adversarial)?;
let loss = self.compute_loss(&pred, target)?;
if loss < best_loss && self.is_adversarial(&pred, target) {
best_loss = loss;
}
if let Some(target) = target {
let target_class = target
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(i, _)| i)
.unwrap_or(0);
if pred[target_class] > self.config.confidence {
break;
}
}
}
Ok(adversarial)
}
fn compute_loss(&self, prediction: &ArrayD<f32>, target: Option<&ArrayD<f32>>) -> Result<f32> {
match target {
Some(target) => {
let mut loss = 0.0;
for (pred, targ) in prediction.iter().zip(target.iter()) {
if *targ > 0.0 {
loss -= targ * pred.ln();
}
}
Ok(loss)
}
None => {
Ok(-prediction.iter().cloned().fold(f32::NEG_INFINITY, f32::max))
}
}
}
fn is_adversarial(&self, prediction: &ArrayD<f32>, target: Option<&ArrayD<f32>>) -> bool {
match target {
Some(target) => {
let pred_class = prediction
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(i, _)| i)
.unwrap_or(0);
let target_class = target
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(i, _)| i)
.unwrap_or(0);
pred_class != target_class
}
None => {
prediction.iter().cloned().fold(f32::NEG_INFINITY, f32::max) < 0.5
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::ArrayD;

    /// Toy model for exercising the attacks: `forward` echoes its input and
    /// `gradient` points from the input toward an all-ones weight tensor.
    struct MockModel {
        weights: ArrayD<f32>,
    }

    impl MockModel {
        fn new() -> Self {
            let weights = ArrayD::from_elem(vec![10, 10], 1.0);
            MockModel { weights }
        }
    }

    impl MLModel for MockModel {
        fn forward(&self, input: &ArrayD<f32>) -> Result<ArrayD<f32>> {
            Ok(input.to_owned())
        }

        fn gradient(
            &self,
            input: &ArrayD<f32>,
            _target: Option<&ArrayD<f32>>,
        ) -> Result<ArrayD<f32>> {
            Ok(&self.weights - input)
        }

        fn input_shape(&self) -> Vec<usize> {
            self.weights.shape().to_vec()
        }

        fn output_shape(&self) -> Vec<usize> {
            self.weights.shape().to_vec()
        }
    }

    #[test]
    fn test_adversarial_config_default() {
        let cfg = AdversarialConfig::default();
        assert_eq!(cfg.method, AttackMethod::FGSM);
        assert_eq!(cfg.epsilon, 0.1);
    }

    #[test]
    fn test_fgsm_generator() {
        let gen = AdversarialGenerator::fgsm(0.1);
        assert_eq!(gen.config.method, AttackMethod::FGSM);
        assert_eq!(gen.config.epsilon, 0.1);
    }

    #[test]
    fn test_pgd_generator() {
        let gen = AdversarialGenerator::pgd(0.1, 20);
        assert_eq!(gen.config.method, AttackMethod::PGD);
        assert_eq!(gen.config.epsilon, 0.1);
        assert_eq!(gen.config.num_iterations, 20);
    }

    #[test]
    fn test_fgsm_attack() {
        let model = MockModel::new();
        let input = ArrayD::from_elem(vec![10, 10], 0.5);
        let outcome = AdversarialGenerator::fgsm(0.1).generate(&model, &input, None);
        assert!(outcome.is_ok());
    }
}