1use ghostflow_core::Tensor;
11use rand::Rng;
12
/// Hyper-parameters shared by all adversarial attack implementations.
#[derive(Debug, Clone)]
pub struct AttackConfig {
    /// Which attack algorithm to run.
    pub attack_type: AttackType,
    /// Maximum L-infinity perturbation magnitude (epsilon ball radius).
    pub epsilon: f32,
    /// Number of optimization iterations (used by the iterative attacks).
    pub num_iterations: usize,
    /// Per-iteration step size for the iterative attacks.
    pub step_size: f32,
    /// Whether PGD starts from a random point inside the epsilon ball.
    pub random_init: bool,
}
27
28impl Default for AttackConfig {
29 fn default() -> Self {
30 AttackConfig {
31 attack_type: AttackType::PGD,
32 epsilon: 0.3,
33 num_iterations: 40,
34 step_size: 0.01,
35 random_init: true,
36 }
37 }
38}
39
/// Supported adversarial attack algorithms.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AttackType {
    /// Fast Gradient Sign Method (single gradient-sign step).
    FGSM,
    /// Projected Gradient Descent (iterative FGSM with epsilon-ball projection).
    PGD,
    /// Carlini–Wagner style margin-loss attack (simplified here).
    CW,
    /// DeepFool minimal-perturbation attack (simplified here).
    DeepFool,
}
51
/// Generates adversarial examples using the configured attack parameters.
pub struct AdversarialAttack {
    // Attack hyper-parameters (epsilon, step size, iteration count, ...).
    config: AttackConfig,
}
56
57impl AdversarialAttack {
58 pub fn new(config: AttackConfig) -> Self {
60 AdversarialAttack { config }
61 }
62
63 pub fn fgsm(&self, input: &Tensor, gradient: &Tensor) -> Result<Tensor, String> {
65 let input_data = input.data_f32();
66 let grad_data = gradient.data_f32();
67
68 if input_data.len() != grad_data.len() {
69 return Err("Input and gradient dimensions must match".to_string());
70 }
71
72 let perturbed: Vec<f32> = input_data.iter().zip(grad_data.iter()).map(|(&x, &g)| {
74 let sign = if g > 0.0 { 1.0 } else if g < 0.0 { -1.0 } else { 0.0 };
75 let perturbed_x = x + self.config.epsilon * sign;
76 perturbed_x.max(0.0).min(1.0)
78 }).collect();
79
80 Tensor::from_slice(&perturbed, input.dims())
81 .map_err(|e| format!("Failed to create perturbed tensor: {:?}", e))
82 }
83
84 pub fn pgd(&self, input: &Tensor, compute_gradient: impl Fn(&Tensor) -> Result<Tensor, String>)
86 -> Result<Tensor, String> {
87 let input_data = input.data_f32();
88 let mut perturbed_data = input_data.clone();
89
90 if self.config.random_init {
92 let mut rng = rand::thread_rng();
93 for (p, &x) in perturbed_data.iter_mut().zip(input_data.iter()) {
94 let noise: f32 = rng.gen_range(-self.config.epsilon..self.config.epsilon);
95 *p = (x + noise).max(0.0).min(1.0);
96 }
97 }
98
99 for _ in 0..self.config.num_iterations {
101 let perturbed = Tensor::from_slice(&perturbed_data, input.dims())
103 .map_err(|e| format!("Failed to create tensor: {:?}", e))?;
104
105 let gradient = compute_gradient(&perturbed)?;
107 let grad_data = gradient.data_f32();
108
109 for (p, &g) in perturbed_data.iter_mut().zip(grad_data.iter()) {
111 let sign = if g > 0.0 { 1.0 } else if g < 0.0 { -1.0 } else { 0.0 };
112 *p += self.config.step_size * sign;
113 }
114
115 for (p, &x) in perturbed_data.iter_mut().zip(input_data.iter()) {
117 let delta = (*p - x).max(-self.config.epsilon).min(self.config.epsilon);
119 *p = (x + delta).max(0.0).min(1.0);
120 }
121 }
122
123 Tensor::from_slice(&perturbed_data, input.dims())
124 .map_err(|e| format!("Failed to create final tensor: {:?}", e))
125 }
126
127 pub fn cw(&self, input: &Tensor, target_class: usize,
129 compute_logits: impl Fn(&Tensor) -> Result<Tensor, String>)
130 -> Result<Tensor, String> {
131 let input_data = input.data_f32();
132 let mut perturbed_data = input_data.clone();
133 let c = 1.0; for _ in 0..self.config.num_iterations {
136 let perturbed = Tensor::from_slice(&perturbed_data, input.dims())
138 .map_err(|e| format!("Failed to create tensor: {:?}", e))?;
139
140 let logits = compute_logits(&perturbed)?;
142 let logits_data = logits.data_f32();
143
144 if logits_data.len() <= target_class {
145 return Err("Invalid target class".to_string());
146 }
147
148 let target_logit = logits_data[target_class];
150 let max_other = logits_data.iter()
151 .enumerate()
152 .filter(|(i, _)| *i != target_class)
153 .map(|(_, &l)| l)
154 .fold(f32::NEG_INFINITY, f32::max);
155
156 let loss = (max_other - target_logit + c).max(0.0);
157
158 for (p, &x) in perturbed_data.iter_mut().zip(input_data.iter()) {
160 let grad_sign = if loss > 0.0 { 1.0 } else { -1.0 };
161 *p -= self.config.step_size * grad_sign;
162
163 let delta = (*p - x).max(-self.config.epsilon).min(self.config.epsilon);
165 *p = (x + delta).max(0.0).min(1.0);
166 }
167 }
168
169 Tensor::from_slice(&perturbed_data, input.dims())
170 .map_err(|e| format!("Failed to create final tensor: {:?}", e))
171 }
172
173 pub fn deepfool(&self, input: &Tensor,
175 compute_gradient: impl Fn(&Tensor, usize) -> Result<Tensor, String>,
176 num_classes: usize)
177 -> Result<Tensor, String> {
178 let input_data = input.data_f32();
179 let mut perturbed_data = input_data.clone();
180
181 for _ in 0..self.config.num_iterations {
182 let mut min_distance = f32::INFINITY;
183 let mut best_perturbation = vec![0.0; input_data.len()];
184
185 let perturbed = Tensor::from_slice(&perturbed_data, input.dims())
187 .map_err(|e| format!("Failed to create tensor: {:?}", e))?;
188
189 for class in 0..num_classes {
191 let gradient = compute_gradient(&perturbed, class)?;
192 let grad_data = gradient.data_f32();
193
194 let grad_norm: f32 = grad_data.iter().map(|g| g * g).sum::<f32>().sqrt();
196 if grad_norm > 1e-8 {
197 let distance = 1.0 / grad_norm;
198
199 if distance < min_distance {
200 min_distance = distance;
201 best_perturbation = grad_data.iter()
202 .map(|&g| (distance * g / grad_norm).min(self.config.epsilon))
203 .collect();
204 }
205 }
206 }
207
208 for (p, (&x, &delta)) in perturbed_data.iter_mut()
210 .zip(input_data.iter().zip(best_perturbation.iter())) {
211 *p = (x + delta).max(0.0).min(1.0);
212 }
213 }
214
215 Tensor::from_slice(&perturbed_data, input.dims())
216 .map_err(|e| format!("Failed to create final tensor: {:?}", e))
217 }
218}
219
/// Configuration for adversarial training.
#[derive(Debug, Clone)]
pub struct AdversarialTrainingConfig {
    /// Fraction of each batch (0.0..=1.0) replaced by adversarial examples.
    pub adversarial_ratio: f32,
    /// Attack used to generate the adversarial examples.
    pub attack_config: AttackConfig,
    /// Label smoothing factor applied by `smooth_labels`.
    pub label_smoothing: f32,
}
230
231impl Default for AdversarialTrainingConfig {
232 fn default() -> Self {
233 AdversarialTrainingConfig {
234 adversarial_ratio: 0.5,
235 attack_config: AttackConfig::default(),
236 label_smoothing: 0.1,
237 }
238 }
239}
240
/// Builds mixed clean/adversarial training batches and smoothed labels.
pub struct AdversarialTrainer {
    // Training configuration (adversarial ratio, attack settings, smoothing).
    config: AdversarialTrainingConfig,
    // Attack runner built from `config.attack_config`.
    attack: AdversarialAttack,
}
246
247impl AdversarialTrainer {
248 pub fn new(config: AdversarialTrainingConfig) -> Self {
250 let attack = AdversarialAttack::new(config.attack_config.clone());
251 AdversarialTrainer { config, attack }
252 }
253
254 pub fn generate_training_batch(&self,
256 clean_inputs: &[Tensor],
257 compute_gradient: impl Fn(&Tensor) -> Result<Tensor, String>)
258 -> Result<Vec<Tensor>, String> {
259 let num_adversarial = (clean_inputs.len() as f32 * self.config.adversarial_ratio) as usize;
260 let mut batch = Vec::with_capacity(clean_inputs.len());
261
262 for input in clean_inputs.iter().skip(num_adversarial) {
264 batch.push(input.clone());
265 }
266
267 for input in clean_inputs.iter().take(num_adversarial) {
269 let adv_example = self.attack.pgd(input, &compute_gradient)?;
270 batch.push(adv_example);
271 }
272
273 Ok(batch)
274 }
275
276 pub fn smooth_labels(&self, labels: &Tensor, num_classes: usize) -> Result<Tensor, String> {
278 let label_data = labels.data_f32();
279 let smoothing = self.config.label_smoothing;
280 let mut smoothed = Vec::with_capacity(label_data.len() * num_classes);
281
282 for &label in label_data.iter() {
283 let label_idx = label as usize;
284 if label_idx >= num_classes {
285 return Err("Label index out of bounds".to_string());
286 }
287
288 for i in 0..num_classes {
289 if i == label_idx {
290 smoothed.push(1.0 - smoothing);
291 } else {
292 smoothed.push(smoothing / (num_classes - 1) as f32);
293 }
294 }
295 }
296
297 Tensor::from_slice(&smoothed, &[label_data.len(), num_classes])
298 .map_err(|e| format!("Failed to create smoothed labels: {:?}", e))
299 }
300}
301
/// Prediction via randomized smoothing: majority vote over
/// noise-perturbed copies of the input, with an approximate
/// certified-robustness radius from `certify`.
pub struct RandomizedSmoothing {
    /// Scale of the noise added to each input element.
    pub sigma: f32,
    /// Number of noisy samples drawn per prediction.
    pub num_samples: usize,
    /// Confidence slack subtracted from the empirical agreement rate.
    pub alpha: f32,
}
311
312impl RandomizedSmoothing {
313 pub fn new(sigma: f32, num_samples: usize, alpha: f32) -> Self {
315 RandomizedSmoothing {
316 sigma,
317 num_samples,
318 alpha,
319 }
320 }
321
322 pub fn predict(&self, input: &Tensor,
324 model_predict: impl Fn(&Tensor) -> Result<usize, String>)
325 -> Result<usize, String> {
326 let mut rng = rand::thread_rng();
327 let input_data = input.data_f32();
328 let mut class_counts = std::collections::HashMap::new();
329
330 for _ in 0..self.num_samples {
332 let noisy_input: Vec<f32> = input_data.iter().map(|&x| {
333 let noise: f32 = rng.gen::<f32>() * self.sigma;
334 (x + noise).max(0.0).min(1.0)
335 }).collect();
336
337 let noisy_tensor = Tensor::from_slice(&noisy_input, input.dims())
338 .map_err(|e| format!("Failed to create noisy tensor: {:?}", e))?;
339 let pred = model_predict(&noisy_tensor)?;
340 *class_counts.entry(pred).or_insert(0) += 1;
341 }
342
343 class_counts.into_iter()
345 .max_by_key(|(_, count)| *count)
346 .map(|(class, _)| class)
347 .ok_or_else(|| "No predictions generated".to_string())
348 }
349
350 pub fn certify(&self, input: &Tensor,
352 model_predict: impl Fn(&Tensor) -> Result<usize, String>)
353 -> Result<(usize, f32), String> {
354 let predicted_class = self.predict(input, &model_predict)?;
355
356 let mut rng = rand::thread_rng();
358 let input_data = input.data_f32();
359 let mut correct_count = 0;
360
361 for _ in 0..self.num_samples {
362 let noisy_input: Vec<f32> = input_data.iter().map(|&x| {
363 let noise: f32 = rng.gen::<f32>() * self.sigma;
364 (x + noise).max(0.0).min(1.0)
365 }).collect();
366
367 let noisy_tensor = Tensor::from_slice(&noisy_input, input.dims())
368 .map_err(|e| format!("Failed to create noisy tensor: {:?}", e))?;
369 let pred = model_predict(&noisy_tensor)?;
370 if pred == predicted_class {
371 correct_count += 1;
372 }
373 }
374
375 let p_lower = (correct_count as f32 / self.num_samples as f32) - self.alpha;
376
377 let radius = if p_lower > 0.5 {
379 self.sigma * (2.0 * p_lower - 1.0).sqrt()
380 } else {
381 0.0
382 };
383
384 Ok((predicted_class, radius))
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fgsm_attack() {
        let config = AttackConfig {
            attack_type: AttackType::FGSM,
            epsilon: 0.1,
            ..Default::default()
        };
        let attack = AdversarialAttack::new(config);

        let input = Tensor::from_slice(&[0.5, 0.5, 0.5, 0.5], &[4]).unwrap();
        let gradient = Tensor::from_slice(&[1.0, -1.0, 0.5, -0.5], &[4]).unwrap();

        let adversarial = attack.fgsm(&input, &gradient).unwrap();
        let values = adversarial.data_f32();

        // Positive gradient pushes the pixel up by epsilon, negative down.
        assert!((values[0] - 0.6).abs() < 1e-5);
        assert!((values[1] - 0.4).abs() < 1e-5);
    }

    #[test]
    fn test_label_smoothing() {
        let config = AdversarialTrainingConfig {
            label_smoothing: 0.1,
            ..Default::default()
        };
        let trainer = AdversarialTrainer::new(config);

        let labels = Tensor::from_slice(&[0.0, 1.0, 2.0], &[3]).unwrap();
        let smoothed = trainer.smooth_labels(&labels, 3).unwrap();
        assert_eq!(smoothed.dims(), &[3, 3]);

        let values = smoothed.data_f32();
        // Row 0: class 0 keeps 0.9; the other two classes split 0.1.
        assert!((values[0] - 0.9).abs() < 1e-5);
        assert!((values[1] - 0.05).abs() < 1e-5);
        assert!((values[2] - 0.05).abs() < 1e-5);
    }

    #[test]
    fn test_randomized_smoothing() {
        let smoothing = RandomizedSmoothing::new(0.1, 100, 0.05);
        let input = Tensor::from_slice(&[0.5; 10], &[10]).unwrap();

        // A constant model must survive any amount of input noise.
        let constant_model = |_: &Tensor| Ok(0);

        assert_eq!(smoothing.predict(&input, constant_model).unwrap(), 0);
    }
}
444}