1#[allow(dead_code)]
8use scirs2_core::ndarray::{Array, ArrayBase, Data, DataMut, Dimension};
9use scirs2_core::numeric::Float;
10use scirs2_core::random::RandNormal;
11use scirs2_core::random::{thread_rng, Rng};
12use std::collections::{HashMap, VecDeque};
13
14use super::moment_accountant::MomentsAccountant;
15use super::{DifferentialPrivacyConfig, NoiseMechanism, PrivacyBudget};
16use crate::error::{OptimError, Result};
17use crate::optimizers::Optimizer;
18
19pub struct DPSGDOptimizer<O, A, D>
21where
22 A: Float
23 + Send
24 + Sync
25 + scirs2_core::ndarray::ScalarOperand
26 + std::fmt::Debug
27 + Default
28 + Clone
29 + std::iter::Sum,
30 D: scirs2_core::ndarray::Dimension,
31 O: Optimizer<A, D>,
32{
33 baseoptimizer: O,
35
36 config: DifferentialPrivacyConfig,
38
39 accountant: MomentsAccountant,
41
42 rng: scirs2_core::random::CoreRandom,
44
45 adaptive_clipping: Option<AdaptiveClippingState>,
47
48 privacy_budget: PrivacyBudgetTracker,
50
51 gradient_stats: GradientStatistics<A>,
53
54 noise_calibrator: NoiseCalibrator<A>,
56
57 step_count: usize,
59
60 current_batch_size: usize,
62
63 _phantom: std::marker::PhantomData<D>,
65}
66
67#[derive(Debug, Clone)]
69struct AdaptiveClippingState {
70 current_threshold: f64,
72
73 target_quantile: f64,
75
76 adaptationlr: f64,
78
79 norm_history: VecDeque<f64>,
81
82 update_frequency: usize,
84
85 last_update_step: usize,
87
88 quantile_estimator: QuantileEstimator,
90}
91
92#[derive(Debug, Clone)]
94struct QuantileEstimator {
95 p2_state: P2AlgorithmState,
97
98 moving_avg: f64,
100
101 ema: f64,
103
104 ema_decay: f64,
106}
107
/// Five-marker state of the P² streaming quantile algorithm.
#[derive(Clone, Debug)]
struct P2AlgorithmState {
    /// Actual marker positions.
    markers: [f64; 5],
    /// Marker heights; the first five observations are stored here before the
    /// estimator warms up.
    values: [f64; 5],
    /// Desired (ideal) marker positions.
    desired_positions: [f64; 5],
    /// Per-observation increments of the desired positions.
    increments: [f64; 5],
    /// Number of observations consumed so far.
    count: usize,
}
126
127#[derive(Debug, Clone)]
129struct PrivacyBudgetTracker {
130 epsilon_consumed: f64,
132
133 delta_consumed: f64,
135
136 target_epsilon: f64,
138
139 target_delta: f64,
141
142 epsilon_per_step: f64,
144
145 delta_per_step: f64,
147
148 consumption_history: Vec<PrivacyConsumption>,
150}
151
/// One entry of the privacy-spend history recorded after a DP step.
///
/// The fields are public because this type is handed to callers through
/// `PrivacyAccountingDetails::privacy_consumption_history`; with private fields the history
/// could only be printed, not inspected.
#[derive(Debug, Clone, PartialEq)]
pub struct PrivacyConsumption {
    /// Step number this entry was recorded for.
    pub step: usize,
    /// Epsilon reported by the accountant for all steps up to `step`.
    pub epsilon_spent: f64,
    /// Delta reported by the accountant for all steps up to `step`.
    pub delta_spent: f64,
    /// Batch size used for this step.
    pub batchsize: usize,
    /// Noise multiplier in effect for this step.
    pub noise_multiplier: f64,
}
161
162#[derive(Debug, Clone)]
164struct GradientStatistics<A: Float + Default + Clone + std::iter::Sum> {
165 norm_history: VecDeque<A>,
167
168 clipping_frequency: f64,
170
171 avg_norm: A,
173
174 std_norm: A,
176
177 percentiles: HashMap<String, A>,
179
180 max_history_size: usize,
182}
183
184#[derive(Debug, Clone)]
186struct NoiseCalibrator<A: Float> {
187 noise_multiplier: A,
189
190 base_noise_scale: A,
192
193 adaptive_scaling: bool,
195
196 mechanism: NoiseMechanism,
198
199 calibration_history: Vec<NoiseCalibration<A>>,
201}
202
203#[derive(Debug, Clone)]
205pub struct NoiseCalibration<A: Float> {
206 step: usize,
207 noise_scale: A,
208 gradientnorm: A,
209 clipping_threshold: A,
210 privacy_cost: A,
211}
212
213impl<O, A, D> DPSGDOptimizer<O, A, D>
214where
215 A: Float
216 + Default
217 + Clone
218 + Send
219 + Sync
220 + scirs2_core::ndarray::ScalarOperand
221 + std::fmt::Debug
222 + std::iter::Sum,
223 D: scirs2_core::ndarray::Dimension,
224 O: Optimizer<A, D> + Send + Sync,
225{
226 pub fn new(baseoptimizer: O, config: DifferentialPrivacyConfig) -> Result<Self> {
228 let accountant = MomentsAccountant::new(
229 config.noise_multiplier,
230 config.target_delta,
231 config.batch_size,
232 config.dataset_size,
233 );
234
235 let rng = thread_rng();
236
237 let adaptive_clipping = if config.adaptive_clipping {
238 Some(AdaptiveClippingState::new(
239 config.adaptive_clip_init,
240 config.adaptive_clip_lr,
241 )?)
242 } else {
243 None
244 };
245
246 let privacy_budget = PrivacyBudgetTracker::new(&config);
247 let gradient_stats = GradientStatistics::new();
248 let noise_calibrator = NoiseCalibrator::new(&config);
249
250 let batchsize = config.batch_size;
251 Ok(Self {
252 baseoptimizer,
253 config,
254 accountant,
255 rng,
256 adaptive_clipping,
257 privacy_budget,
258 gradient_stats,
259 noise_calibrator,
260 step_count: 0,
261 current_batch_size: batchsize,
262 _phantom: std::marker::PhantomData,
263 })
264 }
265
266 pub fn dp_step(
268 &mut self,
269 params: &Array<A, D>,
270 gradients: &mut Array<A, D>,
271 batchsize: usize,
272 ) -> Result<Array<A, D>> {
273 self.step_count += 1;
274 self.current_batch_size = batchsize;
275
276 if !self.has_privacy_budget()? {
278 return Err(OptimError::PrivacyBudgetExhausted {
279 consumed_epsilon: self.privacy_budget.epsilon_consumed,
280 target_epsilon: self.privacy_budget.target_epsilon,
281 });
282 }
283
284 let pre_clip_norm = self.compute_gradient_norm(gradients);
286
287 self.gradient_stats.update_norm(pre_clip_norm);
289
290 let clipping_threshold = self.get_clipping_threshold();
292
293 let was_clipped = self.clip_gradients(gradients, clipping_threshold)?;
295
296 if was_clipped {
298 self.gradient_stats.update_clipping();
299 }
300
301 let _post_clip_norm = self.compute_gradient_norm(gradients);
303
304 self.add_noise(gradients, clipping_threshold)?;
306
307 let (epsilon_spent, delta_spent) = self.accountant.get_privacy_spent(self.step_count)?;
309 self.privacy_budget.update_consumption(
310 self.step_count,
311 epsilon_spent,
312 delta_spent,
313 batchsize,
314 self.config.noise_multiplier,
315 );
316
317 let should_update = self.should_update_clipping_threshold();
319 if let Some(ref mut adaptive_state) = self.adaptive_clipping {
320 if should_update {
321 adaptive_state.update_threshold(pre_clip_norm.to_f64().unwrap_or(0.0));
322 }
323 }
324
325 self.noise_calibrator.update_calibration(
327 self.step_count,
328 pre_clip_norm,
329 A::from(clipping_threshold).unwrap(),
330 A::from(epsilon_spent).unwrap(),
331 );
332
333 let updated_params = self.baseoptimizer.step(params, gradients)?;
335
336 Ok(updated_params)
337 }
338
339 pub fn has_privacy_budget(&self) -> Result<bool> {
341 Ok(
342 self.privacy_budget.epsilon_consumed < self.privacy_budget.target_epsilon
343 && self.privacy_budget.delta_consumed < self.privacy_budget.target_delta,
344 )
345 }
346
347 pub fn get_privacy_budget(&self) -> PrivacyBudget {
349 PrivacyBudget {
350 epsilon_consumed: self.privacy_budget.epsilon_consumed,
351 delta_consumed: self.privacy_budget.delta_consumed,
352 epsilon_remaining: (self.privacy_budget.target_epsilon
353 - self.privacy_budget.epsilon_consumed)
354 .max(0.0),
355 delta_remaining: (self.privacy_budget.target_delta
356 - self.privacy_budget.delta_consumed)
357 .max(0.0),
358 steps_taken: self.step_count,
359 accounting_method: super::AccountingMethod::MomentsAccountant,
360 estimated_steps_remaining: self.estimate_remaining_steps(),
361 }
362 }
363
364 pub fn get_clipping_stats(&self) -> AdaptiveClippingStats {
366 AdaptiveClippingStats {
367 current_threshold: self.get_clipping_threshold(),
368 target_quantile: self
369 .adaptive_clipping
370 .as_ref()
371 .map(|ac| ac.target_quantile)
372 .unwrap_or(0.5),
373 clipping_frequency: self.gradient_stats.clipping_frequency,
374 avg_gradient_norm: self.gradient_stats.avg_norm.to_f64().unwrap_or(0.0),
375 std_gradient_norm: self.gradient_stats.std_norm.to_f64().unwrap_or(0.0),
376 adaptation_rate: self
377 .adaptive_clipping
378 .as_ref()
379 .map(|ac| ac.adaptationlr)
380 .unwrap_or(0.0),
381 }
382 }
383
384 pub fn set_batch_size(&mut self, batchsize: usize) {
386 self.current_batch_size = batchsize;
387 self.accountant = MomentsAccountant::new(
389 self.config.noise_multiplier,
390 self.config.target_delta,
391 batchsize,
392 self.config.dataset_size,
393 );
394 }
395
396 pub fn update_privacy_config(&mut self, newconfig: DifferentialPrivacyConfig) -> Result<()> {
398 if newconfig.target_epsilon < self.config.target_epsilon
400 || newconfig.target_delta < self.config.target_delta
401 {
402 return Err(OptimError::InvalidConfig(
403 "Cannot decrease privacy budget mid-training".to_string(),
404 ));
405 }
406
407 self.config = newconfig;
408 self.privacy_budget.target_epsilon = self.config.target_epsilon;
409 self.privacy_budget.target_delta = self.config.target_delta;
410
411 self.accountant = MomentsAccountant::new(
413 self.config.noise_multiplier,
414 self.config.target_delta,
415 self.current_batch_size,
416 self.config.dataset_size,
417 );
418
419 Ok(())
420 }
421
422 fn compute_gradient_norm<S, DIM>(&self, gradients: &ArrayBase<S, DIM>) -> A
424 where
425 S: Data<Elem = A>,
426 DIM: Dimension,
427 {
428 gradients.iter().map(|&g| g * g).sum::<A>().sqrt()
429 }
430
431 fn get_clipping_threshold(&self) -> f64 {
433 if let Some(ref adaptive_state) = self.adaptive_clipping {
434 adaptive_state.current_threshold
435 } else {
436 self.config.l2_norm_clip
437 }
438 }
439
440 fn clip_gradients<S, DIM>(
442 &self,
443 gradients: &mut ArrayBase<S, DIM>,
444 threshold: f64,
445 ) -> Result<bool>
446 where
447 S: DataMut<Elem = A>,
448 DIM: Dimension,
449 {
450 let norm = self.compute_gradient_norm(gradients);
451 let threshold_a = A::from(threshold).unwrap();
452
453 if norm > threshold_a {
454 let scale = threshold_a / norm;
455 gradients.mapv_inplace(|g| g * scale);
456 Ok(true)
457 } else {
458 Ok(false)
459 }
460 }
461
462 fn add_noise<S, DIM>(
464 &mut self,
465 gradients: &mut ArrayBase<S, DIM>,
466 clipping_threshold: f64,
467 ) -> Result<()>
468 where
469 S: DataMut<Elem = A>,
470 DIM: Dimension,
471 {
472 let noise_scale = self.config.noise_multiplier * clipping_threshold;
473
474 match self.config.noise_mechanism {
475 NoiseMechanism::Gaussian => {
476 let normal = RandNormal::new(0.0, noise_scale)
477 .map_err(|_| OptimError::InvalidConfig("Invalid noise scale".to_string()))?;
478 gradients.mapv_inplace(|g| {
479 let noise_f64 = self.rng.sample(normal);
481 let noise = A::from(noise_f64).unwrap();
482 g + noise
483 });
484 }
485 NoiseMechanism::Laplace => {
486 let normal = RandNormal::new(0.0, noise_scale * 1.414)
488 .map_err(|_| OptimError::InvalidConfig("Invalid noise scale".to_string()))?;
489 gradients.mapv_inplace(|g| {
490 let noise_f64 = self.rng.sample(normal);
492 let noise = A::from(noise_f64).unwrap();
493 g + noise
494 });
495 }
496 _ => {
497 let normal = RandNormal::new(0.0, noise_scale)
499 .map_err(|_| OptimError::InvalidConfig("Invalid noise scale".to_string()))?;
500 gradients.mapv_inplace(|g| {
501 let noise_f64 = self.rng.sample(normal);
503 let noise = A::from(noise_f64).unwrap();
504 g + noise
505 });
506 }
507 }
508
509 Ok(())
510 }
511
512 fn should_update_clipping_threshold(&self) -> bool {
514 if let Some(ref adaptive_state) = self.adaptive_clipping {
515 self.step_count - adaptive_state.last_update_step >= adaptive_state.update_frequency
516 } else {
517 false
518 }
519 }
520
521 fn estimate_remaining_steps(&self) -> usize {
523 if self.step_count == 0 {
524 return usize::MAX;
525 }
526
527 let epsilon_per_step = self.privacy_budget.epsilon_consumed / self.step_count as f64;
528 let remaining_epsilon =
529 self.privacy_budget.target_epsilon - self.privacy_budget.epsilon_consumed;
530
531 if epsilon_per_step > 0.0 {
532 (remaining_epsilon / epsilon_per_step) as usize
533 } else {
534 usize::MAX
535 }
536 }
537
538 pub fn get_privacy_accounting_details(&self) -> PrivacyAccountingDetails {
540 PrivacyAccountingDetails {
541 moment_accountant_orders: self.accountant.get_computed_orders(),
542 privacy_consumption_history: self.privacy_budget.consumption_history.clone(),
543 gradient_statistics: GradientStatsSnapshot {
544 avg_norm: self.gradient_stats.avg_norm.to_f64().unwrap_or(0.0),
545 std_norm: self.gradient_stats.std_norm.to_f64().unwrap_or(0.0),
546 clipping_frequency: self.gradient_stats.clipping_frequency,
547 percentiles: self
548 .gradient_stats
549 .percentiles
550 .iter()
551 .map(|(k, v)| (k.clone(), v.to_f64().unwrap_or(0.0)))
552 .collect(),
553 },
554 noise_calibration_history: self
555 .noise_calibrator
556 .calibration_history
557 .iter()
558 .map(|entry| NoiseCalibration {
559 step: entry.step,
560 noise_scale: entry.noise_scale.to_f64().unwrap_or(0.0),
561 gradientnorm: entry.gradientnorm.to_f64().unwrap_or(0.0),
562 clipping_threshold: entry.clipping_threshold.to_f64().unwrap_or(0.0),
563 privacy_cost: entry.privacy_cost.to_f64().unwrap_or(0.0),
564 })
565 .collect(),
566 }
567 }
568
569 pub fn validate_configuration(&self) -> Result<ConfigurationValidation> {
571 let mut warnings = Vec::new();
572 let mut errors = Vec::new();
573
574 if self.config.noise_multiplier < 0.1 {
576 warnings
577 .push("Very low noise multiplier may not provide sufficient privacy".to_string());
578 }
579 if self.config.noise_multiplier > 10.0 {
580 warnings.push("Very high noise multiplier may severely impact utility".to_string());
581 }
582
583 if self.config.l2_norm_clip < 0.01 {
585 warnings
586 .push("Very low clipping threshold may impact gradient information".to_string());
587 }
588 if self.config.l2_norm_clip > 100.0 {
589 warnings.push(
590 "Very high clipping threshold may not provide effective clipping".to_string(),
591 );
592 }
593
594 if self.config.batch_size < 16 {
596 warnings.push("Small batch size may reduce privacy amplification benefits".to_string());
597 }
598
599 if self.config.dataset_size < 1000 {
601 warnings
602 .push("Small dataset may limit achievable privacy-utility tradeoff".to_string());
603 }
604
605 if self.config.target_epsilon > 10.0 {
607 warnings.push("Large epsilon value provides limited privacy guarantee".to_string());
608 }
609 if self.config.target_delta > 1.0 / self.config.dataset_size as f64 {
610 errors.push("Delta should typically be much smaller than 1/n".to_string());
611 }
612
613 Ok(ConfigurationValidation {
614 is_valid: errors.is_empty(),
615 warnings,
616 errors,
617 recommended_adjustments: self.generate_recommendations(),
618 })
619 }
620
621 fn generate_recommendations(&self) -> Vec<String> {
622 let mut recommendations = Vec::new();
623
624 if self.gradient_stats.clipping_frequency > 0.8 {
625 recommendations.push(
626 "Consider increasing clipping threshold - high clipping frequency detected"
627 .to_string(),
628 );
629 }
630
631 if self.gradient_stats.clipping_frequency < 0.1 {
632 recommendations.push(
633 "Consider decreasing clipping threshold - low clipping frequency detected"
634 .to_string(),
635 );
636 }
637
638 if self.privacy_budget.epsilon_consumed / self.privacy_budget.target_epsilon > 0.9 {
639 recommendations.push(
640 "Privacy budget nearly exhausted - consider reducing noise multiplier".to_string(),
641 );
642 }
643
644 recommendations
645 }
646}
647
648impl AdaptiveClippingState {
651 fn new(initial_threshold: f64, adaptationlr: f64) -> Result<Self> {
652 Ok(Self {
653 current_threshold: initial_threshold,
654 target_quantile: 0.5,
655 adaptationlr,
656 norm_history: VecDeque::with_capacity(1000),
657 update_frequency: 50,
658 last_update_step: 0,
659 quantile_estimator: QuantileEstimator::new(),
660 })
661 }
662
663 fn update_threshold(&mut self, gradientnorm: f64) {
664 self.norm_history.push_back(gradientnorm);
665 if self.norm_history.len() > 1000 {
666 self.norm_history.pop_front();
667 }
668
669 self.quantile_estimator.update(gradientnorm);
671
672 let quantile_estimate = self.quantile_estimator.get_quantile(self.target_quantile);
674 let error = quantile_estimate - self.current_threshold;
675 self.current_threshold += self.adaptationlr * error;
676
677 self.current_threshold = self.current_threshold.max(1e-6);
679 }
680}
681
682impl QuantileEstimator {
683 fn new() -> Self {
684 Self {
685 p2_state: P2AlgorithmState::new(0.5),
686 moving_avg: 0.0,
687 ema: 0.0,
688 ema_decay: 0.99,
689 }
690 }
691
692 fn update(&mut self, value: f64) {
693 self.p2_state.update(value);
694
695 if self.p2_state.count == 1 {
697 self.ema = value;
698 } else {
699 self.ema = self.ema_decay * self.ema + (1.0 - self.ema_decay) * value;
700 }
701 }
702
703 fn get_quantile(&self, quantile: f64) -> f64 {
704 if self.p2_state.count >= 5 {
705 self.p2_state.get_quantile()
706 } else {
707 self.ema
708 }
709 }
710}
711
712impl P2AlgorithmState {
713 fn new(quantile: f64) -> Self {
714 Self {
715 markers: [0.0; 5],
716 values: [0.0; 5],
717 desired_positions: [0.0, quantile / 2.0, quantile, (1.0 + quantile) / 2.0, 1.0],
718 increments: [0.0, quantile / 2.0, quantile, (1.0 + quantile) / 2.0, 1.0],
719 count: 0,
720 }
721 }
722
723 fn update(&mut self, value: f64) {
724 if self.count < 5 {
725 self.values[self.count] = value;
726 if self.count == 4 {
727 self.values.sort_by(|a, b| a.partial_cmp(b).unwrap());
728 for i in 0..5 {
729 self.markers[i] = i as f64;
730 }
731 }
732 self.count += 1;
733 } else {
734 self.count += 1;
737 }
738 }
739
740 fn get_quantile(&self) -> f64 {
741 if self.count >= 5 {
742 self.values[2] } else {
744 0.0
745 }
746 }
747}
748
749impl PrivacyBudgetTracker {
750 fn new(config: &DifferentialPrivacyConfig) -> Self {
751 Self {
752 epsilon_consumed: 0.0,
753 delta_consumed: 0.0,
754 target_epsilon: config.target_epsilon,
755 target_delta: config.target_delta,
756 epsilon_per_step: 0.0,
757 delta_per_step: 0.0,
758 consumption_history: Vec::new(),
759 }
760 }
761
762 fn update_consumption(
763 &mut self,
764 step: usize,
765 epsilon_spent: f64,
766 delta_spent: f64,
767 batchsize: usize,
768 noise_multiplier: f64,
769 ) {
770 self.epsilon_consumed = epsilon_spent;
771 self.delta_consumed = delta_spent;
772
773 self.consumption_history.push(PrivacyConsumption {
774 step,
775 epsilon_spent,
776 delta_spent,
777 batchsize,
778 noise_multiplier,
779 });
780 }
781}
782
783impl<A: Float + Default + Clone + std::iter::Sum + Send + Sync> GradientStatistics<A> {
784 fn new() -> Self {
785 Self {
786 norm_history: VecDeque::with_capacity(1000),
787 clipping_frequency: 0.0,
788 avg_norm: A::zero(),
789 std_norm: A::zero(),
790 percentiles: HashMap::new(),
791 max_history_size: 1000,
792 }
793 }
794
795 fn update_norm(&mut self, norm: A) {
796 self.norm_history.push_back(norm);
797 if self.norm_history.len() > self.max_history_size {
798 self.norm_history.pop_front();
799 }
800
801 let n = A::from(self.norm_history.len()).unwrap();
803 self.avg_norm = self.norm_history.iter().cloned().sum::<A>() / n;
804
805 let variance = self
806 .norm_history
807 .iter()
808 .map(|&x| (x - self.avg_norm) * (x - self.avg_norm))
809 .sum::<A>()
810 / n;
811 self.std_norm = variance.sqrt();
812 }
813
814 fn update_clipping(&mut self) {
815 let alpha = 0.01; self.clipping_frequency = (1.0 - alpha) * self.clipping_frequency + alpha;
818 }
819}
820
821impl<A: Float + Default + Clone + Send + Sync> NoiseCalibrator<A> {
822 fn new(config: &DifferentialPrivacyConfig) -> Self {
823 Self {
824 noise_multiplier: A::from(config.noise_multiplier).unwrap(),
825 base_noise_scale: A::from(config.noise_multiplier * config.l2_norm_clip).unwrap(),
826 adaptive_scaling: false,
827 mechanism: config.noise_mechanism,
828 calibration_history: Vec::new(),
829 }
830 }
831
832 fn update_calibration(
833 &mut self,
834 step: usize,
835 gradientnorm: A,
836 clipping_threshold: A,
837 privacy_cost: A,
838 ) {
839 let noise_scale = self.noise_multiplier * clipping_threshold;
840
841 self.calibration_history.push(NoiseCalibration {
842 step,
843 noise_scale,
844 gradientnorm,
845 clipping_threshold,
846 privacy_cost,
847 });
848
849 if self.calibration_history.len() > 1000 {
851 self.calibration_history.remove(0);
852 }
853 }
854}
855
/// Snapshot of the clipping setup and gradient-norm statistics.
#[derive(Clone, Debug)]
pub struct AdaptiveClippingStats {
    /// Clipping threshold currently in effect.
    pub current_threshold: f64,
    /// Quantile the adaptive threshold is steered toward.
    pub target_quantile: f64,
    /// Smoothed fraction of steps that were clipped.
    pub clipping_frequency: f64,
    /// Mean of recent pre-clipping gradient norms.
    pub avg_gradient_norm: f64,
    /// Standard deviation of recent pre-clipping gradient norms.
    pub std_gradient_norm: f64,
    /// Step size of the adaptive threshold update.
    pub adaptation_rate: f64,
}
866
867#[derive(Debug, Clone)]
869pub struct PrivacyAccountingDetails {
870 pub moment_accountant_orders: Vec<f64>,
871 pub privacy_consumption_history: Vec<PrivacyConsumption>,
872 pub gradient_statistics: GradientStatsSnapshot,
873 pub noise_calibration_history: Vec<NoiseCalibration<f64>>,
874}
875
/// Gradient-norm statistics converted to `f64`.
#[derive(Clone, Debug)]
pub struct GradientStatsSnapshot {
    /// Mean of recent gradient norms.
    pub avg_norm: f64,
    /// Standard deviation of recent gradient norms.
    pub std_norm: f64,
    /// Smoothed fraction of clipped steps.
    pub clipping_frequency: f64,
    /// Named percentiles of the gradient norms.
    pub percentiles: HashMap<String, f64>,
}
884
/// Result of checking a differential-privacy configuration.
#[derive(Clone, Debug)]
pub struct ConfigurationValidation {
    /// `true` when `errors` is empty.
    pub is_valid: bool,
    /// Non-fatal concerns about the configuration.
    pub warnings: Vec<String>,
    /// Problems that make the configuration invalid.
    pub errors: Vec<String>,
    /// Suggested tuning adjustments.
    pub recommended_adjustments: Vec<String>,
}
893
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::SGD;

    /// Constructing the wrapper with the default config succeeds.
    #[test]
    fn test_dp_sgd_creation() {
        let base = SGD::new(0.01);
        let dp_config = DifferentialPrivacyConfig::default();
        let built = DPSGDOptimizer::<_, f64, scirs2_core::ndarray::Ix1>::new(base, dp_config);
        assert!(built.is_ok());
    }

    /// New adaptive clipping state keeps the given threshold and learning rate.
    #[test]
    fn test_adaptive_clipping_state() {
        let built = AdaptiveClippingState::new(1.0, 0.1);
        assert!(built.is_ok());

        let clipping = built.unwrap();
        assert_eq!(clipping.current_threshold, 1.0);
        assert_eq!(clipping.adaptationlr, 0.1);
    }

    /// A fresh tracker mirrors the config's epsilon target and has spent nothing.
    #[test]
    fn test_privacy_budget_tracker() {
        let dp_config = DifferentialPrivacyConfig::default();
        let budget = PrivacyBudgetTracker::new(&dp_config);

        assert_eq!(budget.target_epsilon, dp_config.target_epsilon);
        assert_eq!(budget.epsilon_consumed, 0.0);
    }

    /// After ten positive observations the estimate is positive.
    #[test]
    fn test_quantile_estimator() {
        let mut estimator = QuantileEstimator::new();
        (1..=10_i32)
            .map(f64::from)
            .for_each(|sample| estimator.update(sample));

        assert!(estimator.get_quantile(0.5) > 0.0);
    }

    /// Mean and spread of norms 1, 2, 3.
    #[test]
    fn test_gradient_statistics() {
        let mut stats = GradientStatistics::<f64>::new();
        for norm in [1.0, 2.0, 3.0] {
            stats.update_norm(norm);
        }

        assert_eq!(stats.avg_norm, 2.0);
        assert!(stats.std_norm > 0.0);
    }
}