// optirs_core/privacy/dp_sgd.rs
//
// Differentially Private Stochastic Gradient Descent (DP-SGD)
//
// This module implements DP-SGD with adaptive clipping, noise calibration,
// and privacy budget tracking for training machine learning models with
// formal privacy guarantees.
6
7#[allow(dead_code)]
8use scirs2_core::ndarray::{Array, ArrayBase, Data, DataMut, Dimension};
9use scirs2_core::numeric::Float;
10use scirs2_core::random::RandNormal;
11use scirs2_core::random::{thread_rng, Rng};
12use std::collections::{HashMap, VecDeque};
13
14use super::moment_accountant::MomentsAccountant;
15use super::{DifferentialPrivacyConfig, NoiseMechanism, PrivacyBudget};
16use crate::error::{OptimError, Result};
17use crate::optimizers::Optimizer;
18
19/// DP-SGD optimizer with privacy guarantees
20pub struct DPSGDOptimizer<O, A, D>
21where
22    A: Float
23        + Send
24        + Sync
25        + scirs2_core::ndarray::ScalarOperand
26        + std::fmt::Debug
27        + Default
28        + Clone
29        + std::iter::Sum,
30    D: scirs2_core::ndarray::Dimension,
31    O: Optimizer<A, D>,
32{
33    /// Base optimizer (SGD, Adam, etc.)
34    baseoptimizer: O,
35
36    /// Privacy configuration
37    config: DifferentialPrivacyConfig,
38
39    /// Moment accountant for privacy tracking
40    accountant: MomentsAccountant,
41
42    /// Random number generator for noise
43    rng: scirs2_core::random::CoreRandom,
44
45    /// Adaptive clipping state
46    adaptive_clipping: Option<AdaptiveClippingState>,
47
48    /// Privacy budget tracker
49    privacy_budget: PrivacyBudgetTracker,
50
51    /// Gradient statistics
52    gradient_stats: GradientStatistics<A>,
53
54    /// Noise calibration
55    noise_calibrator: NoiseCalibrator<A>,
56
57    /// Current step count
58    step_count: usize,
59
60    /// Batch size for current iteration
61    current_batch_size: usize,
62
63    /// Phantom data for unused type parameter
64    _phantom: std::marker::PhantomData<D>,
65}
66
67/// Adaptive clipping state for DP-SGD
68#[derive(Debug, Clone)]
69struct AdaptiveClippingState {
70    /// Current clipping threshold
71    current_threshold: f64,
72
73    /// Target quantile (e.g., 0.5 for median)
74    target_quantile: f64,
75
76    /// Learning rate for threshold adaptation
77    adaptationlr: f64,
78
79    /// History of gradient norms
80    norm_history: VecDeque<f64>,
81
82    /// Update frequency (in steps)
83    update_frequency: usize,
84
85    /// Last update step
86    last_update_step: usize,
87
88    /// Quantile estimation parameters
89    quantile_estimator: QuantileEstimator,
90}
91
92/// Quantile estimator for adaptive clipping
93#[derive(Debug, Clone)]
94struct QuantileEstimator {
95    /// P² algorithm state
96    p2_state: P2AlgorithmState,
97
98    /// Simple moving average
99    moving_avg: f64,
100
101    /// Exponential moving average
102    ema: f64,
103
104    /// EMA decay factor
105    ema_decay: f64,
106}
107
/// Marker state for the P² streaming quantile algorithm.
///
/// All five arrays are indexed by marker (0 = minimum, 4 = maximum).
#[derive(Debug, Clone)]
struct P2AlgorithmState {
    /// Actual marker positions.
    markers: [f64; 5],

    /// Marker heights, i.e. the current value estimates at each marker.
    values: [f64; 5],

    /// Positions each marker should ideally occupy.
    desired_positions: [f64; 5],

    /// Per-observation increments applied to `desired_positions`.
    increments: [f64; 5],

    /// Observations seen so far.
    count: usize,
}
126
127/// Privacy budget tracker
128#[derive(Debug, Clone)]
129struct PrivacyBudgetTracker {
130    /// Total epsilon consumed
131    epsilon_consumed: f64,
132
133    /// Total delta consumed
134    delta_consumed: f64,
135
136    /// Target epsilon
137    target_epsilon: f64,
138
139    /// Target delta
140    target_delta: f64,
141
142    /// Privacy budget per step
143    epsilon_per_step: f64,
144
145    /// Delta per step
146    delta_per_step: f64,
147
148    /// Privacy consumption history
149    consumption_history: Vec<PrivacyConsumption>,
150}
151
/// Privacy spend recorded after one DP-SGD step.
///
/// Instances are exposed through the public
/// `PrivacyAccountingDetails::privacy_consumption_history`, so the fields are
/// public; otherwise callers outside this module could not read the history.
#[derive(Debug, Clone)]
pub struct PrivacyConsumption {
    /// Step index this record belongs to.
    pub step: usize,
    /// Epsilon reported as spent at this step.
    pub epsilon_spent: f64,
    /// Delta reported as spent at this step.
    pub delta_spent: f64,
    /// Batch size used for the step.
    pub batchsize: usize,
    /// Noise multiplier in effect for the step.
    pub noise_multiplier: f64,
}
161
162/// Gradient statistics for DP-SGD
163#[derive(Debug, Clone)]
164struct GradientStatistics<A: Float + Default + Clone + std::iter::Sum> {
165    /// Recent gradient norms
166    norm_history: VecDeque<A>,
167
168    /// Clipping frequency
169    clipping_frequency: f64,
170
171    /// Average gradient norm
172    avg_norm: A,
173
174    /// Std deviation of gradient norms
175    std_norm: A,
176
177    /// Percentile statistics
178    percentiles: HashMap<String, A>,
179
180    /// Maximum history size
181    max_history_size: usize,
182}
183
184/// Noise calibration for different mechanisms
185#[derive(Debug, Clone)]
186struct NoiseCalibrator<A: Float> {
187    /// Current noise multiplier
188    noise_multiplier: A,
189
190    /// Base noise scale
191    base_noise_scale: A,
192
193    /// Adaptive noise scaling
194    adaptive_scaling: bool,
195
196    /// Noise mechanism
197    mechanism: NoiseMechanism,
198
199    /// Calibration history
200    calibration_history: Vec<NoiseCalibration<A>>,
201}
202
203/// Noise calibration record
204#[derive(Debug, Clone)]
205pub struct NoiseCalibration<A: Float> {
206    step: usize,
207    noise_scale: A,
208    gradientnorm: A,
209    clipping_threshold: A,
210    privacy_cost: A,
211}
212
213impl<O, A, D> DPSGDOptimizer<O, A, D>
214where
215    A: Float
216        + Default
217        + Clone
218        + Send
219        + Sync
220        + scirs2_core::ndarray::ScalarOperand
221        + std::fmt::Debug
222        + std::iter::Sum,
223    D: scirs2_core::ndarray::Dimension,
224    O: Optimizer<A, D> + Send + Sync,
225{
226    /// Create a new DP-SGD optimizer
227    pub fn new(baseoptimizer: O, config: DifferentialPrivacyConfig) -> Result<Self> {
228        let accountant = MomentsAccountant::new(
229            config.noise_multiplier,
230            config.target_delta,
231            config.batch_size,
232            config.dataset_size,
233        );
234
235        let rng = thread_rng();
236
237        let adaptive_clipping = if config.adaptive_clipping {
238            Some(AdaptiveClippingState::new(
239                config.adaptive_clip_init,
240                config.adaptive_clip_lr,
241            )?)
242        } else {
243            None
244        };
245
246        let privacy_budget = PrivacyBudgetTracker::new(&config);
247        let gradient_stats = GradientStatistics::new();
248        let noise_calibrator = NoiseCalibrator::new(&config);
249
250        let batchsize = config.batch_size;
251        Ok(Self {
252            baseoptimizer,
253            config,
254            accountant,
255            rng,
256            adaptive_clipping,
257            privacy_budget,
258            gradient_stats,
259            noise_calibrator,
260            step_count: 0,
261            current_batch_size: batchsize,
262            _phantom: std::marker::PhantomData,
263        })
264    }
265
266    /// Perform a DP-SGD step
267    pub fn dp_step(
268        &mut self,
269        params: &Array<A, D>,
270        gradients: &mut Array<A, D>,
271        batchsize: usize,
272    ) -> Result<Array<A, D>> {
273        self.step_count += 1;
274        self.current_batch_size = batchsize;
275
276        // Check privacy budget
277        if !self.has_privacy_budget()? {
278            return Err(OptimError::PrivacyBudgetExhausted {
279                consumed_epsilon: self.privacy_budget.epsilon_consumed,
280                target_epsilon: self.privacy_budget.target_epsilon,
281            });
282        }
283
284        // Compute gradient norm before clipping
285        let pre_clip_norm = self.compute_gradient_norm(gradients);
286
287        // Update gradient statistics
288        self.gradient_stats.update_norm(pre_clip_norm);
289
290        // Get current clipping threshold
291        let clipping_threshold = self.get_clipping_threshold();
292
293        // Apply gradient clipping
294        let was_clipped = self.clip_gradients(gradients, clipping_threshold)?;
295
296        // Update clipping statistics
297        if was_clipped {
298            self.gradient_stats.update_clipping();
299        }
300
301        // Compute post-clipping norm
302        let _post_clip_norm = self.compute_gradient_norm(gradients);
303
304        // Add calibrated noise
305        self.add_noise(gradients, clipping_threshold)?;
306
307        // Update moment accountant
308        let (epsilon_spent, delta_spent) = self.accountant.get_privacy_spent(self.step_count)?;
309        self.privacy_budget.update_consumption(
310            self.step_count,
311            epsilon_spent,
312            delta_spent,
313            batchsize,
314            self.config.noise_multiplier,
315        );
316
317        // Update adaptive clipping if enabled
318        let should_update = self.should_update_clipping_threshold();
319        if let Some(ref mut adaptive_state) = self.adaptive_clipping {
320            if should_update {
321                adaptive_state.update_threshold(pre_clip_norm.to_f64().unwrap_or(0.0));
322            }
323        }
324
325        // Update noise calibration
326        self.noise_calibrator.update_calibration(
327            self.step_count,
328            pre_clip_norm,
329            A::from(clipping_threshold).unwrap(),
330            A::from(epsilon_spent).unwrap(),
331        );
332
333        // Apply base optimizer step
334        let updated_params = self.baseoptimizer.step(params, gradients)?;
335
336        Ok(updated_params)
337    }
338
339    /// Check if privacy budget is available
340    pub fn has_privacy_budget(&self) -> Result<bool> {
341        Ok(
342            self.privacy_budget.epsilon_consumed < self.privacy_budget.target_epsilon
343                && self.privacy_budget.delta_consumed < self.privacy_budget.target_delta,
344        )
345    }
346
347    /// Get current privacy budget status
348    pub fn get_privacy_budget(&self) -> PrivacyBudget {
349        PrivacyBudget {
350            epsilon_consumed: self.privacy_budget.epsilon_consumed,
351            delta_consumed: self.privacy_budget.delta_consumed,
352            epsilon_remaining: (self.privacy_budget.target_epsilon
353                - self.privacy_budget.epsilon_consumed)
354                .max(0.0),
355            delta_remaining: (self.privacy_budget.target_delta
356                - self.privacy_budget.delta_consumed)
357                .max(0.0),
358            steps_taken: self.step_count,
359            accounting_method: super::AccountingMethod::MomentsAccountant,
360            estimated_steps_remaining: self.estimate_remaining_steps(),
361        }
362    }
363
364    /// Get adaptive clipping statistics
365    pub fn get_clipping_stats(&self) -> AdaptiveClippingStats {
366        AdaptiveClippingStats {
367            current_threshold: self.get_clipping_threshold(),
368            target_quantile: self
369                .adaptive_clipping
370                .as_ref()
371                .map(|ac| ac.target_quantile)
372                .unwrap_or(0.5),
373            clipping_frequency: self.gradient_stats.clipping_frequency,
374            avg_gradient_norm: self.gradient_stats.avg_norm.to_f64().unwrap_or(0.0),
375            std_gradient_norm: self.gradient_stats.std_norm.to_f64().unwrap_or(0.0),
376            adaptation_rate: self
377                .adaptive_clipping
378                .as_ref()
379                .map(|ac| ac.adaptationlr)
380                .unwrap_or(0.0),
381        }
382    }
383
384    /// Set batch size for next iterations
385    pub fn set_batch_size(&mut self, batchsize: usize) {
386        self.current_batch_size = batchsize;
387        // Update moment accountant with new batch _size
388        self.accountant = MomentsAccountant::new(
389            self.config.noise_multiplier,
390            self.config.target_delta,
391            batchsize,
392            self.config.dataset_size,
393        );
394    }
395
396    /// Update privacy configuration
397    pub fn update_privacy_config(&mut self, newconfig: DifferentialPrivacyConfig) -> Result<()> {
398        // Validate that privacy budget doesn't decrease
399        if newconfig.target_epsilon < self.config.target_epsilon
400            || newconfig.target_delta < self.config.target_delta
401        {
402            return Err(OptimError::InvalidConfig(
403                "Cannot decrease privacy budget mid-training".to_string(),
404            ));
405        }
406
407        self.config = newconfig;
408        self.privacy_budget.target_epsilon = self.config.target_epsilon;
409        self.privacy_budget.target_delta = self.config.target_delta;
410
411        // Update moment accountant
412        self.accountant = MomentsAccountant::new(
413            self.config.noise_multiplier,
414            self.config.target_delta,
415            self.current_batch_size,
416            self.config.dataset_size,
417        );
418
419        Ok(())
420    }
421
422    /// Compute gradient norm
423    fn compute_gradient_norm<S, DIM>(&self, gradients: &ArrayBase<S, DIM>) -> A
424    where
425        S: Data<Elem = A>,
426        DIM: Dimension,
427    {
428        gradients.iter().map(|&g| g * g).sum::<A>().sqrt()
429    }
430
431    /// Get current clipping threshold
432    fn get_clipping_threshold(&self) -> f64 {
433        if let Some(ref adaptive_state) = self.adaptive_clipping {
434            adaptive_state.current_threshold
435        } else {
436            self.config.l2_norm_clip
437        }
438    }
439
440    /// Clip gradients to threshold
441    fn clip_gradients<S, DIM>(
442        &self,
443        gradients: &mut ArrayBase<S, DIM>,
444        threshold: f64,
445    ) -> Result<bool>
446    where
447        S: DataMut<Elem = A>,
448        DIM: Dimension,
449    {
450        let norm = self.compute_gradient_norm(gradients);
451        let threshold_a = A::from(threshold).unwrap();
452
453        if norm > threshold_a {
454            let scale = threshold_a / norm;
455            gradients.mapv_inplace(|g| g * scale);
456            Ok(true)
457        } else {
458            Ok(false)
459        }
460    }
461
462    /// Add calibrated noise to gradients
463    fn add_noise<S, DIM>(
464        &mut self,
465        gradients: &mut ArrayBase<S, DIM>,
466        clipping_threshold: f64,
467    ) -> Result<()>
468    where
469        S: DataMut<Elem = A>,
470        DIM: Dimension,
471    {
472        let noise_scale = self.config.noise_multiplier * clipping_threshold;
473
474        match self.config.noise_mechanism {
475            NoiseMechanism::Gaussian => {
476                let normal = RandNormal::new(0.0, noise_scale)
477                    .map_err(|_| OptimError::InvalidConfig("Invalid noise scale".to_string()))?;
478                gradients.mapv_inplace(|g| {
479                    // Use scirs2_core random for Gaussian noise
480                    let noise_f64 = self.rng.sample(normal);
481                    let noise = A::from(noise_f64).unwrap();
482                    g + noise
483                });
484            }
485            NoiseMechanism::Laplace => {
486                // Simplified Laplace noise using Normal distribution approximation
487                let normal = RandNormal::new(0.0, noise_scale * 1.414)
488                    .map_err(|_| OptimError::InvalidConfig("Invalid noise scale".to_string()))?;
489                gradients.mapv_inplace(|g| {
490                    // Use scirs2_core random for Laplace-approximated noise
491                    let noise_f64 = self.rng.sample(normal);
492                    let noise = A::from(noise_f64).unwrap();
493                    g + noise
494                });
495            }
496            _ => {
497                // Default to Gaussian
498                let normal = RandNormal::new(0.0, noise_scale)
499                    .map_err(|_| OptimError::InvalidConfig("Invalid noise scale".to_string()))?;
500                gradients.mapv_inplace(|g| {
501                    // Use scirs2_core random for Gaussian noise
502                    let noise_f64 = self.rng.sample(normal);
503                    let noise = A::from(noise_f64).unwrap();
504                    g + noise
505                });
506            }
507        }
508
509        Ok(())
510    }
511
512    /// Check if clipping threshold should be updated
513    fn should_update_clipping_threshold(&self) -> bool {
514        if let Some(ref adaptive_state) = self.adaptive_clipping {
515            self.step_count - adaptive_state.last_update_step >= adaptive_state.update_frequency
516        } else {
517            false
518        }
519    }
520
521    /// Estimate remaining steps before budget exhaustion
522    fn estimate_remaining_steps(&self) -> usize {
523        if self.step_count == 0 {
524            return usize::MAX;
525        }
526
527        let epsilon_per_step = self.privacy_budget.epsilon_consumed / self.step_count as f64;
528        let remaining_epsilon =
529            self.privacy_budget.target_epsilon - self.privacy_budget.epsilon_consumed;
530
531        if epsilon_per_step > 0.0 {
532            (remaining_epsilon / epsilon_per_step) as usize
533        } else {
534            usize::MAX
535        }
536    }
537
538    /// Get privacy accounting details
539    pub fn get_privacy_accounting_details(&self) -> PrivacyAccountingDetails {
540        PrivacyAccountingDetails {
541            moment_accountant_orders: self.accountant.get_computed_orders(),
542            privacy_consumption_history: self.privacy_budget.consumption_history.clone(),
543            gradient_statistics: GradientStatsSnapshot {
544                avg_norm: self.gradient_stats.avg_norm.to_f64().unwrap_or(0.0),
545                std_norm: self.gradient_stats.std_norm.to_f64().unwrap_or(0.0),
546                clipping_frequency: self.gradient_stats.clipping_frequency,
547                percentiles: self
548                    .gradient_stats
549                    .percentiles
550                    .iter()
551                    .map(|(k, v)| (k.clone(), v.to_f64().unwrap_or(0.0)))
552                    .collect(),
553            },
554            noise_calibration_history: self
555                .noise_calibrator
556                .calibration_history
557                .iter()
558                .map(|entry| NoiseCalibration {
559                    step: entry.step,
560                    noise_scale: entry.noise_scale.to_f64().unwrap_or(0.0),
561                    gradientnorm: entry.gradientnorm.to_f64().unwrap_or(0.0),
562                    clipping_threshold: entry.clipping_threshold.to_f64().unwrap_or(0.0),
563                    privacy_cost: entry.privacy_cost.to_f64().unwrap_or(0.0),
564                })
565                .collect(),
566        }
567    }
568
569    /// Validate DP-SGD configuration
570    pub fn validate_configuration(&self) -> Result<ConfigurationValidation> {
571        let mut warnings = Vec::new();
572        let mut errors = Vec::new();
573
574        // Check noise multiplier
575        if self.config.noise_multiplier < 0.1 {
576            warnings
577                .push("Very low noise multiplier may not provide sufficient privacy".to_string());
578        }
579        if self.config.noise_multiplier > 10.0 {
580            warnings.push("Very high noise multiplier may severely impact utility".to_string());
581        }
582
583        // Check clipping threshold
584        if self.config.l2_norm_clip < 0.01 {
585            warnings
586                .push("Very low clipping threshold may impact gradient information".to_string());
587        }
588        if self.config.l2_norm_clip > 100.0 {
589            warnings.push(
590                "Very high clipping threshold may not provide effective clipping".to_string(),
591            );
592        }
593
594        // Check batch size
595        if self.config.batch_size < 16 {
596            warnings.push("Small batch size may reduce privacy amplification benefits".to_string());
597        }
598
599        // Check dataset size
600        if self.config.dataset_size < 1000 {
601            warnings
602                .push("Small dataset may limit achievable privacy-utility tradeoff".to_string());
603        }
604
605        // Check privacy budget
606        if self.config.target_epsilon > 10.0 {
607            warnings.push("Large epsilon value provides limited privacy guarantee".to_string());
608        }
609        if self.config.target_delta > 1.0 / self.config.dataset_size as f64 {
610            errors.push("Delta should typically be much smaller than 1/n".to_string());
611        }
612
613        Ok(ConfigurationValidation {
614            is_valid: errors.is_empty(),
615            warnings,
616            errors,
617            recommended_adjustments: self.generate_recommendations(),
618        })
619    }
620
621    fn generate_recommendations(&self) -> Vec<String> {
622        let mut recommendations = Vec::new();
623
624        if self.gradient_stats.clipping_frequency > 0.8 {
625            recommendations.push(
626                "Consider increasing clipping threshold - high clipping frequency detected"
627                    .to_string(),
628            );
629        }
630
631        if self.gradient_stats.clipping_frequency < 0.1 {
632            recommendations.push(
633                "Consider decreasing clipping threshold - low clipping frequency detected"
634                    .to_string(),
635            );
636        }
637
638        if self.privacy_budget.epsilon_consumed / self.privacy_budget.target_epsilon > 0.9 {
639            recommendations.push(
640                "Privacy budget nearly exhausted - consider reducing noise multiplier".to_string(),
641            );
642        }
643
644        recommendations
645    }
646}
647
// Implementation of helper structures
649
650impl AdaptiveClippingState {
651    fn new(initial_threshold: f64, adaptationlr: f64) -> Result<Self> {
652        Ok(Self {
653            current_threshold: initial_threshold,
654            target_quantile: 0.5,
655            adaptationlr,
656            norm_history: VecDeque::with_capacity(1000),
657            update_frequency: 50,
658            last_update_step: 0,
659            quantile_estimator: QuantileEstimator::new(),
660        })
661    }
662
663    fn update_threshold(&mut self, gradientnorm: f64) {
664        self.norm_history.push_back(gradientnorm);
665        if self.norm_history.len() > 1000 {
666            self.norm_history.pop_front();
667        }
668
669        // Update quantile estimate
670        self.quantile_estimator.update(gradientnorm);
671
672        // Adapt threshold towards target quantile
673        let quantile_estimate = self.quantile_estimator.get_quantile(self.target_quantile);
674        let error = quantile_estimate - self.current_threshold;
675        self.current_threshold += self.adaptationlr * error;
676
677        // Ensure threshold is positive
678        self.current_threshold = self.current_threshold.max(1e-6);
679    }
680}
681
682impl QuantileEstimator {
683    fn new() -> Self {
684        Self {
685            p2_state: P2AlgorithmState::new(0.5),
686            moving_avg: 0.0,
687            ema: 0.0,
688            ema_decay: 0.99,
689        }
690    }
691
692    fn update(&mut self, value: f64) {
693        self.p2_state.update(value);
694
695        // Update EMA
696        if self.p2_state.count == 1 {
697            self.ema = value;
698        } else {
699            self.ema = self.ema_decay * self.ema + (1.0 - self.ema_decay) * value;
700        }
701    }
702
703    fn get_quantile(&self, quantile: f64) -> f64 {
704        if self.p2_state.count >= 5 {
705            self.p2_state.get_quantile()
706        } else {
707            self.ema
708        }
709    }
710}
711
712impl P2AlgorithmState {
713    fn new(quantile: f64) -> Self {
714        Self {
715            markers: [0.0; 5],
716            values: [0.0; 5],
717            desired_positions: [0.0, quantile / 2.0, quantile, (1.0 + quantile) / 2.0, 1.0],
718            increments: [0.0, quantile / 2.0, quantile, (1.0 + quantile) / 2.0, 1.0],
719            count: 0,
720        }
721    }
722
723    fn update(&mut self, value: f64) {
724        if self.count < 5 {
725            self.values[self.count] = value;
726            if self.count == 4 {
727                self.values.sort_by(|a, b| a.partial_cmp(b).unwrap());
728                for i in 0..5 {
729                    self.markers[i] = i as f64;
730                }
731            }
732            self.count += 1;
733        } else {
734            // P² algorithm update
735            // Simplified implementation
736            self.count += 1;
737        }
738    }
739
740    fn get_quantile(&self) -> f64 {
741        if self.count >= 5 {
742            self.values[2] // Median for simplicity
743        } else {
744            0.0
745        }
746    }
747}
748
749impl PrivacyBudgetTracker {
750    fn new(config: &DifferentialPrivacyConfig) -> Self {
751        Self {
752            epsilon_consumed: 0.0,
753            delta_consumed: 0.0,
754            target_epsilon: config.target_epsilon,
755            target_delta: config.target_delta,
756            epsilon_per_step: 0.0,
757            delta_per_step: 0.0,
758            consumption_history: Vec::new(),
759        }
760    }
761
762    fn update_consumption(
763        &mut self,
764        step: usize,
765        epsilon_spent: f64,
766        delta_spent: f64,
767        batchsize: usize,
768        noise_multiplier: f64,
769    ) {
770        self.epsilon_consumed = epsilon_spent;
771        self.delta_consumed = delta_spent;
772
773        self.consumption_history.push(PrivacyConsumption {
774            step,
775            epsilon_spent,
776            delta_spent,
777            batchsize,
778            noise_multiplier,
779        });
780    }
781}
782
783impl<A: Float + Default + Clone + std::iter::Sum + Send + Sync> GradientStatistics<A> {
784    fn new() -> Self {
785        Self {
786            norm_history: VecDeque::with_capacity(1000),
787            clipping_frequency: 0.0,
788            avg_norm: A::zero(),
789            std_norm: A::zero(),
790            percentiles: HashMap::new(),
791            max_history_size: 1000,
792        }
793    }
794
795    fn update_norm(&mut self, norm: A) {
796        self.norm_history.push_back(norm);
797        if self.norm_history.len() > self.max_history_size {
798            self.norm_history.pop_front();
799        }
800
801        // Update statistics
802        let n = A::from(self.norm_history.len()).unwrap();
803        self.avg_norm = self.norm_history.iter().cloned().sum::<A>() / n;
804
805        let variance = self
806            .norm_history
807            .iter()
808            .map(|&x| (x - self.avg_norm) * (x - self.avg_norm))
809            .sum::<A>()
810            / n;
811        self.std_norm = variance.sqrt();
812    }
813
814    fn update_clipping(&mut self) {
815        // Simple moving average for clipping frequency
816        let alpha = 0.01; // Learning rate for frequency update
817        self.clipping_frequency = (1.0 - alpha) * self.clipping_frequency + alpha;
818    }
819}
820
821impl<A: Float + Default + Clone + Send + Sync> NoiseCalibrator<A> {
822    fn new(config: &DifferentialPrivacyConfig) -> Self {
823        Self {
824            noise_multiplier: A::from(config.noise_multiplier).unwrap(),
825            base_noise_scale: A::from(config.noise_multiplier * config.l2_norm_clip).unwrap(),
826            adaptive_scaling: false,
827            mechanism: config.noise_mechanism,
828            calibration_history: Vec::new(),
829        }
830    }
831
832    fn update_calibration(
833        &mut self,
834        step: usize,
835        gradientnorm: A,
836        clipping_threshold: A,
837        privacy_cost: A,
838    ) {
839        let noise_scale = self.noise_multiplier * clipping_threshold;
840
841        self.calibration_history.push(NoiseCalibration {
842            step,
843            noise_scale,
844            gradientnorm,
845            clipping_threshold,
846            privacy_cost,
847        });
848
849        // Limit history size
850        if self.calibration_history.len() > 1000 {
851            self.calibration_history.remove(0);
852        }
853    }
854}
855
/// Snapshot of clipping behaviour returned by `DPSGDOptimizer::get_clipping_stats`.
#[derive(Debug, Clone)]
pub struct AdaptiveClippingStats {
    /// Clipping threshold currently in force.
    pub current_threshold: f64,
    /// Quantile the adaptive threshold tracks.
    pub target_quantile: f64,
    /// Smoothed fraction of steps that were clipped.
    pub clipping_frequency: f64,
    /// Mean of recent gradient norms.
    pub avg_gradient_norm: f64,
    /// Standard deviation of recent gradient norms.
    pub std_gradient_norm: f64,
    /// Adaptation rate of the threshold.
    pub adaptation_rate: f64,
}
866
867/// Privacy accounting details
868#[derive(Debug, Clone)]
869pub struct PrivacyAccountingDetails {
870    pub moment_accountant_orders: Vec<f64>,
871    pub privacy_consumption_history: Vec<PrivacyConsumption>,
872    pub gradient_statistics: GradientStatsSnapshot,
873    pub noise_calibration_history: Vec<NoiseCalibration<f64>>,
874}
875
/// Point-in-time copy of the gradient statistics, converted to `f64`.
#[derive(Debug, Clone)]
pub struct GradientStatsSnapshot {
    /// Mean of recent gradient norms.
    pub avg_norm: f64,
    /// Standard deviation of recent gradient norms.
    pub std_norm: f64,
    /// Smoothed fraction of steps that were clipped.
    pub clipping_frequency: f64,
    /// Named percentile values of recent gradient norms.
    pub percentiles: HashMap<String, f64>,
}
884
/// Outcome of `DPSGDOptimizer::validate_configuration`.
#[derive(Debug, Clone)]
pub struct ConfigurationValidation {
    /// `true` exactly when `errors` is empty.
    pub is_valid: bool,
    /// Legal but questionable settings.
    pub warnings: Vec<String>,
    /// Settings that make the configuration invalid.
    pub errors: Vec<String>,
    /// Tuning suggestions based on observed statistics.
    pub recommended_adjustments: Vec<String>,
}
893
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::SGD;

    /// The optimizer can be built from the default configuration.
    #[test]
    fn test_dp_sgd_creation() {
        let config = DifferentialPrivacyConfig::default();
        let built =
            DPSGDOptimizer::<_, f64, scirs2_core::ndarray::Ix1>::new(SGD::new(0.01), config);
        assert!(built.is_ok());
    }

    /// Adaptive clipping state keeps the threshold and rate it was given.
    #[test]
    fn test_adaptive_clipping_state() {
        let created = AdaptiveClippingState::new(1.0, 0.1);
        assert!(created.is_ok());

        let state = created.unwrap();
        assert_eq!(state.current_threshold, 1.0);
        assert_eq!(state.adaptationlr, 0.1);
    }

    /// A fresh tracker copies the target epsilon and has consumed nothing.
    #[test]
    fn test_privacy_budget_tracker() {
        let config = DifferentialPrivacyConfig::default();
        let tracker = PrivacyBudgetTracker::new(&config);

        assert_eq!(tracker.target_epsilon, config.target_epsilon);
        assert_eq!(tracker.epsilon_consumed, 0.0);
    }

    /// After feeding 1..=10 the median estimate is positive.
    #[test]
    fn test_quantile_estimator() {
        let mut estimator = QuantileEstimator::new();
        (1..=10).map(f64::from).for_each(|v| estimator.update(v));

        assert!(estimator.get_quantile(0.5) > 0.0);
    }

    /// Mean of {1, 2, 3} is 2 and the spread is non-zero.
    #[test]
    fn test_gradient_statistics() {
        let mut stats = GradientStatistics::<f64>::new();
        for norm in [1.0, 2.0, 3.0].iter() {
            stats.update_norm(*norm);
        }

        assert_eq!(stats.avg_norm, 2.0);
        assert!(stats.std_norm > 0.0);
    }
}
949}