#[cfg(feature = "nn")]
use axonml_autograd::Variable;
#[cfg(feature = "nn")]
use axonml_nn::Module;
#[cfg(feature = "nn")]
use axonml_optim::Optimizer;
#[cfg(feature = "nn")]
use axonml_tensor::Tensor;
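/// Alternating trainer for GAN-style setups: the discriminator is updated on
/// every `step` call, the generator once every `disc_steps_per_gen` calls.
///
/// # Example
///
/// A minimal sketch, assuming the `Adam`, `Sequential`, and `Linear` types
/// exercised in the tests below:
///
/// ```ignore
/// let generator = Sequential::new().add(Linear::new(10, 4));
/// let discriminator = Sequential::new().add(Linear::new(4, 1));
/// let mut trainer = AdversarialTrainer::new(
///     Box::new(Adam::new(generator.parameters(), 0.001)),
///     Box::new(Adam::new(discriminator.parameters(), 0.001)),
/// );
/// trainer.set_disc_steps_per_gen(2);
/// ```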
#[cfg(feature = "nn")]
pub struct AdversarialTrainer {
    /// Optimizer for the generator's parameters.
    gen_optimizer: Box<dyn Optimizer>,
    /// Optimizer for the discriminator's parameters.
    disc_optimizer: Box<dyn Optimizer>,
    /// Number of discriminator updates per generator update.
    disc_steps_per_gen: usize,
    /// Number of `step` calls performed so far.
    current_step: usize,
}
#[cfg(feature = "nn")]
impl AdversarialTrainer {
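    /// Creates a trainer with one discriminator update per generator update.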
pub fn new(gen_optimizer: Box<dyn Optimizer>, disc_optimizer: Box<dyn Optimizer>) -> Self {
Self {
gen_optimizer,
disc_optimizer,
disc_steps_per_gen: 1,
current_step: 0,
}
}
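    /// Sets how many discriminator updates to run per generator update.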
    pub fn set_disc_steps_per_gen(&mut self, steps: usize) {
        // Clamp to at least 1 to avoid a modulo-by-zero panic in `step`.
        self.disc_steps_per_gen = steps.max(1);
    }
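    /// Runs one adversarial training step and returns
    /// `(generator_loss, discriminator_loss)`. The discriminator is updated
    /// on every call; the generator only every `disc_steps_per_gen` calls,
    /// so the returned generator loss is `0.0` on calls where its update is
    /// skipped. `loss_fn` should reduce to a scalar (e.g. a BCE-style loss).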
pub fn step<G, D, F>(
&mut self,
generator: &G,
discriminator: &D,
real_data: &Variable,
noise: &Variable,
loss_fn: F,
) -> (f32, f32)
where
G: Module,
D: Module,
F: Fn(&Variable, &Variable) -> Variable,
{
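        // --- Discriminator update: real samples are labeled 1.0 ---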
self.disc_optimizer.zero_grad();
let real_pred = discriminator.forward(real_data);
let real_labels = Variable::new(
Tensor::from_vec(vec![1.0; real_pred.numel()], &real_pred.shape()).unwrap(),
false,
);
let d_real_loss = loss_fn(&real_pred, &real_labels);
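        // Fake samples are labeled 0.0. The backward pass below also deposits
        // gradients on the generator's parameters; those are cleared by
        // `zero_grad` before the generator update.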
let fake_data = generator.forward(noise);
let fake_pred = discriminator.forward(&fake_data);
let fake_labels = Variable::new(
Tensor::from_vec(vec![0.0; fake_pred.numel()], &fake_pred.shape()).unwrap(),
false,
);
let d_fake_loss = loss_fn(&fake_pred, &fake_labels);
let d_loss = d_real_loss.add_var(&d_fake_loss);
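        // Report the mean of the real and fake losses (assumes a scalar loss).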
let d_loss_val = d_loss.data().to_vec()[0] * 0.5;
d_loss.backward();
self.disc_optimizer.step();
let mut g_loss_val = 0.0;
self.current_step += 1;
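        // --- Generator update: only every `disc_steps_per_gen` calls ---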
if self.current_step % self.disc_steps_per_gen == 0 {
self.gen_optimizer.zero_grad();
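            // Regenerate fakes for a fresh graph; labels of 1.0 train the
            // generator to fool the discriminator (the non-saturating
            // objective when paired with a BCE-style loss).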
let fake_data = generator.forward(noise);
let fake_pred = discriminator.forward(&fake_data);
let gen_labels = Variable::new(
Tensor::from_vec(vec![1.0; fake_pred.numel()], &fake_pred.shape()).unwrap(),
false,
);
let g_loss = loss_fn(&fake_pred, &gen_labels);
g_loss_val = g_loss.data().to_vec()[0];
g_loss.backward();
self.gen_optimizer.step();
}
(g_loss_val, d_loss_val)
}
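    /// Returns `(generator_optimizer, discriminator_optimizer)` as shared
    /// references.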
pub fn optimizers(&self) -> (&dyn Optimizer, &dyn Optimizer) {
(self.gen_optimizer.as_ref(), self.disc_optimizer.as_ref())
}
}
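/// Fast Gradient Sign Method (FGSM): perturbs `input` by
/// `epsilon * sign(grad)`, which bounds the L-infinity norm of the
/// perturbation by `epsilon`.
///
/// # Example
///
/// A minimal sketch; in practice `input_grad` (a hypothetical name here)
/// comes from a backward pass of the loss with respect to the input:
///
/// ```ignore
/// let adv = fgsm_attack(&input, &input_grad, 0.03);
/// assert_eq!(adv.shape(), input.shape());
/// ```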
#[cfg(feature = "nn")]
pub fn fgsm_attack(input: &Variable, grad: &Tensor<f32>, epsilon: f32) -> Variable {
let input_data = input.data().to_vec();
let grad_data = grad.to_vec();
assert_eq!(
input_data.len(),
grad_data.len(),
"Input and gradient must have same size"
);
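    // x + epsilon * sign(g); exact zeros in the gradient leave that element
    // unchanged. No clamping to a valid data range is applied here.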
let perturbed: Vec<f32> = input_data
.iter()
.zip(grad_data.iter())
.map(|(x, g)| {
let sign = if *g > 0.0 {
1.0
} else if *g < 0.0 {
-1.0
} else {
0.0
};
x + epsilon * sign
})
.collect();
Variable::new(Tensor::from_vec(perturbed, &input.shape()).unwrap(), false)
}
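/// Projected Gradient Descent (PGD) attack: takes `num_steps` FGSM-style
/// steps of size `alpha`, projecting the accumulated perturbation back into
/// the L-infinity ball of radius `epsilon` around `input` after each step.
/// The returned adversarial example does not require gradients.
///
/// Note: each `loss.backward()` call also accumulates gradients on the
/// model's parameters, so callers should `zero_grad` their optimizer before
/// the next update (as `adversarial_training_step` does).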
#[cfg(feature = "nn")]
pub fn pgd_attack<F>(
model: &dyn Module,
input: &Variable,
target: &Variable,
epsilon: f32,
alpha: f32,
num_steps: usize,
loss_fn: F,
) -> Variable
where
F: Fn(&Variable, &Variable) -> Variable,
{
let original_data = input.data().to_vec();
let shape = input.shape();
let n = original_data.len();
let mut perturbed = original_data.clone();
for _ in 0..num_steps {
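        // Rebuild the adversarial input as a fresh leaf so a new input
        // gradient is recorded on this iteration's backward pass.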
let adv_input = Variable::new(Tensor::from_vec(perturbed.clone(), &shape).unwrap(), true);
let output = model.forward(&adv_input);
let loss = loss_fn(&output, target);
loss.backward();
if let Some(grad) = adv_input.grad() {
let grad_data = grad.to_vec();
for i in 0..n {
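                // Ascend the loss: step `alpha` along the gradient sign.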
let sign = if grad_data[i] > 0.0 {
1.0
} else if grad_data[i] < 0.0 {
-1.0
} else {
0.0
};
perturbed[i] += alpha * sign;
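                // Project back into the epsilon-ball around the original input.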
let delta = perturbed[i] - original_data[i];
perturbed[i] = original_data[i] + delta.clamp(-epsilon, epsilon);
}
}
}
Variable::new(Tensor::from_vec(perturbed, &shape).unwrap(), false)
}
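/// One step of PGD adversarial training: crafts an adversarial example from
/// `clean_input`, then updates the model on the sum of the clean and
/// adversarial losses. Returns `(clean_loss, adversarial_loss)`.
///
/// # Example
///
/// A minimal sketch, assuming a model and optimizer as in the tests below:
///
/// ```ignore
/// let mse = |pred: &Variable, tgt: &Variable| {
///     let diff = pred.add_var(&tgt.neg());
///     diff.mul_var(&diff).mean()
/// };
/// let (clean_loss, adv_loss) =
///     adversarial_training_step(&model, &mut opt, &x, &y, 0.1, 5, mse);
/// ```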
#[cfg(feature = "nn")]
pub fn adversarial_training_step<F>(
model: &dyn Module,
optimizer: &mut dyn Optimizer,
clean_input: &Variable,
target: &Variable,
epsilon: f32,
pgd_steps: usize,
loss_fn: F,
) -> (f32, f32)
where
F: Fn(&Variable, &Variable) -> Variable,
{
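    // Step-size heuristic: 2 * epsilon / pgd_steps, enough for the attack to
    // traverse the epsilon-ball within the given number of steps.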
let alpha = epsilon / pgd_steps.max(1) as f32 * 2.0;
let adv_input = pgd_attack(
model,
clean_input,
target,
epsilon,
alpha,
pgd_steps,
&loss_fn,
);
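    // Joint update on clean and adversarial batches; `zero_grad` also clears
    // the parameter gradients accumulated while crafting the attack above.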
optimizer.zero_grad();
let clean_output = model.forward(clean_input);
let clean_loss = loss_fn(&clean_output, target);
let clean_loss_val = clean_loss.data().to_vec()[0];
let adv_output = model.forward(&adv_input);
let adv_loss = loss_fn(&adv_output, target);
let adv_loss_val = adv_loss.data().to_vec()[0];
let total_loss = clean_loss.add_var(&adv_loss);
total_loss.backward();
optimizer.step();
(clean_loss_val, adv_loss_val)
}
#[cfg(test)]
#[cfg(feature = "nn")]
mod tests {
use super::*;
use axonml_nn::{Linear, ReLU, Sequential};
use std::ops::Neg;
#[test]
fn test_fgsm_attack_perturbation_bound() {
let input_data = vec![1.0, 2.0, 3.0, 4.0];
let input = Variable::new(Tensor::from_vec(input_data.clone(), &[1, 4]).unwrap(), true);
let grad = Tensor::from_vec(vec![0.5, -0.3, 0.0, 1.0], &[1, 4]).unwrap();
let epsilon = 0.1;
let adv = fgsm_attack(&input, &grad, epsilon);
let adv_data = adv.data().to_vec();
for (orig, perturbed) in input_data.iter().zip(adv_data.iter()) {
let delta = (perturbed - orig).abs();
assert!(
delta <= epsilon + 1e-6,
"Perturbation {} exceeds epsilon {}",
delta,
epsilon
);
}
}
#[test]
fn test_fgsm_attack_sign_direction() {
let input = Variable::new(
Tensor::from_vec(vec![0.0, 0.0, 0.0], &[1, 3]).unwrap(),
true,
);
let grad = Tensor::from_vec(vec![1.0, -1.0, 0.0], &[1, 3]).unwrap();
let epsilon = 0.5;
let adv = fgsm_attack(&input, &grad, epsilon);
let adv_data = adv.data().to_vec();
        assert!((adv_data[0] - 0.5).abs() < 1e-6);
        assert!((adv_data[1] - (-0.5)).abs() < 1e-6);
        assert!((adv_data[2] - 0.0).abs() < 1e-6);
    }
#[test]
fn test_fgsm_attack_shape_preserved() {
let input = Variable::new(Tensor::from_vec(vec![1.0; 12], &[3, 4]).unwrap(), true);
let grad = Tensor::from_vec(vec![0.1; 12], &[3, 4]).unwrap();
let adv = fgsm_attack(&input, &grad, 0.01);
assert_eq!(adv.shape(), vec![3, 4]);
}
#[test]
fn test_pgd_attack_shape() {
let model = Sequential::new().add(Linear::new(4, 2));
let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), true);
let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[1, 2]).unwrap(), false);
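        // MSE loss assembled from the Variable ops available here.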
let adv = pgd_attack(&model, &input, &target, 0.1, 0.03, 3, |pred, tgt| {
let diff = pred.add_var(&tgt.neg());
let sq = diff.mul_var(&diff);
sq.mean()
});
assert_eq!(adv.shape(), vec![1, 4]);
}
#[test]
fn test_pgd_attack_within_epsilon() {
let model = Sequential::new().add(Linear::new(4, 2));
let input_data = vec![1.0, 2.0, 3.0, 4.0];
let input = Variable::new(Tensor::from_vec(input_data.clone(), &[1, 4]).unwrap(), true);
let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[1, 2]).unwrap(), false);
let epsilon = 0.1;
let adv = pgd_attack(&model, &input, &target, epsilon, 0.03, 5, |pred, tgt| {
let diff = pred.add_var(&tgt.neg());
let sq = diff.mul_var(&diff);
sq.mean()
});
let adv_data = adv.data().to_vec();
for (orig, perturbed) in input_data.iter().zip(adv_data.iter()) {
let delta = (perturbed - orig).abs();
assert!(
delta <= epsilon + 1e-5,
"PGD perturbation {} exceeds epsilon {}",
delta,
epsilon
);
}
}
#[test]
fn test_adversarial_trainer_creation() {
use axonml_optim::Adam;
let gen_model = Sequential::new().add(Linear::new(10, 4));
let disc_model = Sequential::new().add(Linear::new(4, 1));
let gen_opt = Box::new(Adam::new(gen_model.parameters(), 0.001));
let disc_opt = Box::new(Adam::new(disc_model.parameters(), 0.001));
let mut trainer = AdversarialTrainer::new(gen_opt, disc_opt);
trainer.set_disc_steps_per_gen(3);
let (gen_opt_ref, disc_opt_ref) = trainer.optimizers();
        assert_eq!(gen_opt_ref.parameters().len(), 2);
        assert_eq!(disc_opt_ref.parameters().len(), 2);
}
}