optirs_core/privacy/
mod.rs

// Differential Privacy support for optimizers
//
// This module provides differential privacy mechanisms for machine learning
// optimization, including DP-SGD with a moments accountant for privacy budget tracking.
5
6#[allow(dead_code)]
7use crate::error::{OptimError, Result};
8use scirs2_core::ndarray::{Array, ArrayBase, Data, DataMut, Dimension, ScalarOperand};
9use scirs2_core::numeric::Float;
10use scirs2_core::random::{thread_rng, Rng};
11use scirs2_core::ScientificNumber;
12use std::collections::VecDeque;
13use std::fmt::Debug;
14
15pub mod byzantine_tolerance;
16pub mod differential_privacy; // New modular differential privacy
17pub mod dp_sgd;
18pub mod enhanced_audit;
19pub mod federated; // New modular federated privacy
20pub mod federated_privacy;
21pub mod moment_accountant;
22pub mod noise_mechanisms;
23pub mod private_hyperparameter_optimization;
24pub mod secure_multiparty;
25pub mod utility_analysis;
26
27use crate::optimizers::Optimizer;
28
29// Re-export key utility analysis types
30pub use utility_analysis::{
31    AnalysisConfig, AnalysisMetadata, BudgetRecommendations, OptimalConfiguration, ParetoPoint,
32    PrivacyConfiguration, PrivacyParameterSpace, PrivacyRiskAssessment, PrivacyUtilityAnalyzer,
33    PrivacyUtilityResults, RobustnessResults, SensitivityResults, StatisticalTestResults,
34    UtilityMetric,
35};
36
37// Re-export modular federated privacy types
38pub use federated::{
39    ByzantineRobustAggregator, ByzantineRobustConfig, ByzantineRobustMethod, ClientComposition,
40    CompositionStats, CrossDeviceConfig, CrossDevicePrivacyManager, DeviceProfile, DeviceType,
41    FederatedCompositionAnalyzer, FederatedCompositionMethod, OutlierDetectionResult,
42    ReputationSystemConfig, RoundComposition, SecureAggregationConfig, SecureAggregationPlan,
43    SecureAggregator, SeedSharingMethod, StatisticalTestConfig, StatisticalTestType, TemporalEvent,
44    TemporalEventType,
45};
46
47// Re-export modular differential privacy types
48pub use differential_privacy::{
49    AmplificationConfig, AmplificationStats, PrivacyAmplificationAnalyzer, SubsamplingEvent,
50};
51
52/// Differential privacy configuration
53#[derive(Debug, Clone)]
54pub struct DifferentialPrivacyConfig {
55    /// Target privacy parameter epsilon
56    pub target_epsilon: f64,
57
58    /// Privacy parameter delta (typically 1/n where n is dataset size)
59    pub target_delta: f64,
60
61    /// Noise multiplier for gradient perturbation
62    pub noise_multiplier: f64,
63
64    /// L2 norm clipping threshold for gradients
65    pub l2_norm_clip: f64,
66
67    /// Batch size for sampling
68    pub batch_size: usize,
69
70    /// Dataset size for privacy accounting
71    pub dataset_size: usize,
72
73    /// Maximum number of training steps
74    pub max_steps: usize,
75
76    /// Noise mechanism to use
77    pub noise_mechanism: NoiseMechanism,
78
79    /// Enable secure aggregation (for federated learning)
80    pub secure_aggregation: bool,
81
82    /// Enable adaptive clipping
83    pub adaptive_clipping: bool,
84
85    /// Initial clipping threshold for adaptive clipping
86    pub adaptive_clip_init: f64,
87
88    /// Learning rate for adaptive clipping
89    pub adaptive_clip_lr: f64,
90}
91
92impl Default for DifferentialPrivacyConfig {
93    fn default() -> Self {
94        Self {
95            target_epsilon: 1.0,
96            target_delta: 1e-5,
97            noise_multiplier: 1.1,
98            l2_norm_clip: 1.0,
99            batch_size: 256,
100            dataset_size: 50000,
101            max_steps: 1000,
102            noise_mechanism: NoiseMechanism::Gaussian,
103            secure_aggregation: false,
104            adaptive_clipping: false,
105            adaptive_clip_init: 1.0,
106            adaptive_clip_lr: 0.2,
107        }
108    }
109}
110
/// Noise mechanisms for differential privacy.
///
/// `DifferentiallyPrivateOptimizer` in this module samples Laplace noise only for
/// [`NoiseMechanism::Laplace`]; every other variant currently receives Gaussian noise.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NoiseMechanism {
    /// Gaussian noise mechanism.
    Gaussian,
    /// Laplace noise mechanism.
    Laplace,
    /// Tree aggregation with Gaussian noise.
    TreeAggregation,
    /// Improved composition with amplification.
    ImprovedComposition,
}
123
124/// Privacy budget tracking information
125#[derive(Debug, Clone)]
126pub struct PrivacyBudget {
127    /// Current epsilon consumed
128    pub epsilon_consumed: f64,
129
130    /// Current delta consumed
131    pub delta_consumed: f64,
132
133    /// Remaining epsilon budget
134    pub epsilon_remaining: f64,
135
136    /// Remaining delta budget
137    pub delta_remaining: f64,
138
139    /// Number of steps taken
140    pub steps_taken: usize,
141
142    /// Privacy accounting method used
143    pub accounting_method: AccountingMethod,
144
145    /// Estimated steps until budget exhaustion
146    pub estimated_steps_remaining: usize,
147}
148
149impl Default for PrivacyBudget {
150    fn default() -> Self {
151        Self {
152            epsilon_consumed: 0.0,
153            delta_consumed: 0.0,
154            epsilon_remaining: 1.0,
155            delta_remaining: 1e-5,
156            steps_taken: 0,
157            accounting_method: AccountingMethod::MomentsAccountant,
158            estimated_steps_remaining: 1000,
159        }
160    }
161}
162
/// Privacy accounting methods.
///
/// `DifferentiallyPrivateOptimizer` in this module always reports
/// [`AccountingMethod::MomentsAccountant`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccountingMethod {
    /// Moments accountant.
    MomentsAccountant,
    /// Renyi differential privacy.
    RenyiDP,
    /// Advanced composition.
    AdvancedComposition,
    /// Zero-concentrated differential privacy.
    ZCDP,
}
175
176/// Differentially private optimizer wrapper
177pub struct DifferentiallyPrivateOptimizer<O, A, D>
178where
179    A: Float + ScalarOperand + Debug + Send + Sync,
180    D: Dimension,
181    O: Optimizer<A, D>,
182{
183    /// Base optimizer
184    base_optimizer: O,
185
186    /// Privacy configuration
187    config: DifferentialPrivacyConfig,
188
189    /// Moment accountant for privacy tracking
190    accountant: MomentsAccountant,
191
192    /// Random number generator for noise
193    rng: scirs2_core::random::CoreRandom,
194
195    /// Adaptive clipping state
196    adaptive_clip_state: Option<AdaptiveClippingState>,
197
198    /// Gradient history for analysis
199    gradient_history: VecDeque<GradientNorms>,
200
201    /// Privacy audit trail
202    audit_trail: Vec<PrivacyEvent>,
203
204    /// Current step count
205    step_count: usize,
206
207    /// Phantom data for unused type parameters
208    _phantom: std::marker::PhantomData<(A, D)>,
209}
210
/// Mutable state of the adaptive clipping scheme.
#[derive(Debug, Clone)]
struct AdaptiveClippingState {
    /// Threshold currently applied to gradients.
    current_threshold: f64,
    /// Running quantile estimate of gradient norms.
    quantile_estimate: f64,
    /// Minimum number of steps between threshold updates.
    update_frequency: usize,
    /// Step at which the threshold was last changed.
    last_update_step: usize,
}
219
/// Per-step record of gradient norms before and after clipping.
#[derive(Debug, Clone)]
struct GradientNorms {
    /// Step the record belongs to.
    step: usize,
    /// L2 norm of the raw gradient.
    pre_clip_norm: f64,
    /// L2 norm after clipping (before noise).
    post_clip_norm: f64,
    /// Factor the gradient was scaled by (1.0 when it was not clipped).
    clipping_ratio: f64,
}
228
229/// Privacy event for audit trail
230#[derive(Debug, Clone)]
231pub struct PrivacyEvent {
232    step: usize,
233    event_type: PrivacyEventType,
234    epsilon_spent: f64,
235    delta_spent: f64,
236    noise_scale: f64,
237}
238
/// Kind of privacy-relevant event stored in the audit trail.
#[derive(Debug, Clone)]
enum PrivacyEventType {
    /// A noised gradient was handed to the base optimizer.
    GradientRelease,
    /// Model parameters were updated.
    ModelUpdate,
    /// Parameters were queried.
    ParameterQuery,
    /// The adaptive clipping threshold changed.
    AdaptiveClipUpdate,
}
246
247impl<O, A, D> DifferentiallyPrivateOptimizer<O, A, D>
248where
249    A: Float
250        + std::ops::AddAssign
251        + std::ops::SubAssign
252        + Send
253        + Sync
254        + scirs2_core::ndarray::ScalarOperand
255        + std::fmt::Debug,
256    D: Dimension,
257    O: Optimizer<A, D>,
258{
259    /// Create a new differentially private optimizer
260    pub fn new(baseoptimizer: O, config: DifferentialPrivacyConfig) -> Result<Self> {
261        let accountant = MomentsAccountant::new(
262            config.noise_multiplier,
263            config.target_delta,
264            config.batch_size,
265            config.dataset_size,
266        );
267
268        let rng = thread_rng();
269
270        let adaptive_clip_state = if config.adaptive_clipping {
271            Some(AdaptiveClippingState {
272                current_threshold: config.adaptive_clip_init,
273                quantile_estimate: config.l2_norm_clip,
274                update_frequency: 50, // Update every 50 steps
275                last_update_step: 0,
276            })
277        } else {
278            None
279        };
280
281        Ok(Self {
282            base_optimizer: baseoptimizer,
283            config,
284            accountant,
285            rng,
286            adaptive_clip_state,
287            gradient_history: VecDeque::with_capacity(1000),
288            audit_trail: Vec::new(),
289            step_count: 0,
290            _phantom: std::marker::PhantomData,
291        })
292    }
293
294    /// Perform a differentially private step
295    pub fn dp_step(
296        &mut self,
297        params: &Array<A, D>,
298        gradients: &mut Array<A, D>,
299    ) -> Result<Array<A, D>> {
300        self.step_count += 1;
301
302        // Check privacy budget
303        if !self.has_privacy_budget()? {
304            return Err(OptimError::PrivacyBudgetExhausted {
305                consumed_epsilon: self.get_privacy_budget().epsilon_consumed,
306                target_epsilon: self.config.target_epsilon,
307            });
308        }
309
310        // Compute gradient norm before clipping
311        let pre_clip_norm = self.compute_l2_norm(gradients);
312
313        // Apply gradient clipping
314        let clip_threshold = self.get_clipping_threshold();
315        let clipping_ratio = if pre_clip_norm > clip_threshold {
316            let scale = clip_threshold / pre_clip_norm;
317            gradients.mapv_inplace(|g| g * A::from(scale).unwrap());
318            scale
319        } else {
320            1.0
321        };
322
323        let post_clip_norm = self.compute_l2_norm(gradients);
324
325        // Add calibrated noise
326        self.add_calibrated_noise(gradients, clip_threshold)?;
327
328        // Update moment accountant
329        let (epsilon_spent, delta_spent) = self.accountant.get_privacy_spent(self.step_count)?;
330
331        // Record gradient statistics
332        self.gradient_history.push_back(GradientNorms {
333            step: self.step_count,
334            pre_clip_norm,
335            post_clip_norm,
336            clipping_ratio,
337        });
338
339        if self.gradient_history.len() > 1000 {
340            self.gradient_history.pop_front();
341        }
342
343        // Record privacy event
344        self.audit_trail.push(PrivacyEvent {
345            step: self.step_count,
346            event_type: PrivacyEventType::GradientRelease,
347            epsilon_spent,
348            delta_spent,
349            noise_scale: self.config.noise_multiplier * clip_threshold,
350        });
351
352        // Update adaptive clipping if enabled
353        if let Some(ref mut state) = self.adaptive_clip_state {
354            if self.step_count - state.last_update_step >= state.update_frequency {
355                state.last_update_step = self.step_count;
356                // Update threshold based on recent gradient norms
357                let target_ratio = 0.8; // Target 80% of gradients to be clipped
358                let new_threshold = pre_clip_norm * target_ratio;
359                state.current_threshold = new_threshold;
360            }
361        }
362
363        // Apply base optimizer step
364        let updated_params = self.base_optimizer.step(params, gradients)?;
365
366        Ok(updated_params)
367    }
368
369    /// Check if privacy budget is available
370    pub fn has_privacy_budget(&self) -> Result<bool> {
371        let budget = self.get_privacy_budget();
372        Ok(budget.epsilon_remaining > 0.0 && budget.delta_remaining > 0.0)
373    }
374
375    /// Get current privacy budget status
376    pub fn get_privacy_budget(&self) -> PrivacyBudget {
377        let (epsilon_consumed, delta_consumed) = self
378            .accountant
379            .get_privacy_spent(self.step_count)
380            .unwrap_or((0.0, 0.0));
381
382        let epsilon_remaining = (self.config.target_epsilon - epsilon_consumed).max(0.0);
383        let delta_remaining = (self.config.target_delta - delta_consumed).max(0.0);
384
385        // Estimate remaining steps
386        let epsilon_per_step = if self.step_count > 0 {
387            epsilon_consumed / self.step_count as f64
388        } else {
389            0.0
390        };
391
392        let estimated_steps_remaining = if epsilon_per_step > 0.0 {
393            (epsilon_remaining / epsilon_per_step) as usize
394        } else {
395            usize::MAX
396        };
397
398        PrivacyBudget {
399            epsilon_consumed,
400            delta_consumed,
401            epsilon_remaining,
402            delta_remaining,
403            steps_taken: self.step_count,
404            accounting_method: AccountingMethod::MomentsAccountant,
405            estimated_steps_remaining,
406        }
407    }
408
409    fn compute_l2_norm<S, DIM>(&self, array: &ArrayBase<S, DIM>) -> f64
410    where
411        S: Data<Elem = A>,
412        DIM: Dimension,
413    {
414        array
415            .iter()
416            .map(|&x| {
417                let val = x.to_f64().unwrap_or(0.0);
418                val * val
419            })
420            .sum::<f64>()
421            .sqrt()
422    }
423
424    fn get_clipping_threshold(&self) -> f64 {
425        if let Some(ref state) = self.adaptive_clip_state {
426            state.current_threshold
427        } else {
428            self.config.l2_norm_clip
429        }
430    }
431
432    fn add_calibrated_noise<S, DIM>(
433        &mut self,
434        gradients: &mut ArrayBase<S, DIM>,
435        clip_threshold: f64,
436    ) -> Result<()>
437    where
438        S: DataMut<Elem = A>,
439        DIM: Dimension,
440    {
441        let noise_scale = self.config.noise_multiplier * clip_threshold;
442
443        match self.config.noise_mechanism {
444            NoiseMechanism::Gaussian => {
445                let sigma_f64 = noise_scale.to_f64().unwrap_or(1.0);
446                gradients.mapv_inplace(|g| {
447                    // Use Box-Muller transformation for Gaussian noise
448                    let u1: f64 = self.rng.gen_range(0.0..1.0);
449                    let u2: f64 = self.rng.gen_range(0.0..1.0);
450                    let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
451                    let noise = A::from(z0 * sigma_f64).unwrap();
452                    g + noise
453                });
454            }
455            NoiseMechanism::Laplace => {
456                // Implement Laplace distribution using transformation method
457                let scale_f64 = noise_scale.to_f64().unwrap_or(1.0);
458                gradients.mapv_inplace(|g| {
459                    let u: f64 = self.rng.gen_range(0.0..1.0);
460                    let laplace_sample = if u < 0.5 {
461                        scale_f64 * (2.0 * u).ln()
462                    } else {
463                        -scale_f64 * (2.0 * (1.0 - u)).ln()
464                    };
465                    let noise = A::from(laplace_sample).unwrap();
466                    g + noise
467                });
468            }
469            _ => {
470                // Use Gaussian as fallback
471                let sigma_f64 = noise_scale.to_f64().unwrap_or(1.0);
472                gradients.mapv_inplace(|g| {
473                    let u1: f64 = self.rng.gen_range(0.0..1.0);
474                    let u2: f64 = self.rng.gen_range(0.0..1.0);
475                    let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
476                    let noise = A::from(z0 * sigma_f64).unwrap();
477                    g + noise
478                });
479            }
480        }
481
482        Ok(())
483    }
484
485    fn update_adaptive_clipping(&mut self, state: &mut AdaptiveClippingState, current_norm: f64) {
486        // Use exponential moving average to track gradient _norm quantiles
487        let alpha = self.config.adaptive_clip_lr;
488
489        // Target the 50th percentile of gradient norms
490        let target_quantile = 0.5;
491
492        // Update quantile estimate
493        if current_norm > state.quantile_estimate {
494            state.quantile_estimate += alpha * target_quantile;
495        } else {
496            state.quantile_estimate -= alpha * (1.0 - target_quantile);
497        }
498
499        // Update clipping threshold
500        state.current_threshold = state.quantile_estimate;
501        state.last_update_step = self.step_count;
502    }
503
504    /// Get gradient clipping statistics
505    pub fn get_clipping_stats(&self) -> ClippingStats {
506        if self.gradient_history.is_empty() {
507            return ClippingStats::default();
508        }
509
510        let total_steps = self.gradient_history.len();
511        let clipped_steps = self
512            .gradient_history
513            .iter()
514            .filter(|stats| stats.clipping_ratio < 1.0)
515            .count();
516
517        let avg_clipping_ratio: f64 = self
518            .gradient_history
519            .iter()
520            .map(|stats| stats.clipping_ratio)
521            .sum::<f64>()
522            / total_steps as f64;
523
524        let avg_pre_clip_norm: f64 = self
525            .gradient_history
526            .iter()
527            .map(|stats| stats.pre_clip_norm)
528            .sum::<f64>()
529            / total_steps as f64;
530
531        ClippingStats {
532            total_steps,
533            clipped_steps,
534            clipping_frequency: clipped_steps as f64 / total_steps as f64,
535            avg_clipping_ratio,
536            avg_pre_clip_norm,
537            current_threshold: self.get_clipping_threshold(),
538        }
539    }
540
541    /// Get privacy audit trail
542    pub fn get_audit_trail(&self) -> &[PrivacyEvent] {
543        &self.audit_trail
544    }
545
546    /// Validate privacy guarantees
547    pub fn validate_privacy(&self) -> PrivacyValidation {
548        let budget = self.get_privacy_budget();
549        let clipping_stats = self.get_clipping_stats();
550
551        let mut warnings = Vec::new();
552        let mut is_valid = true;
553
554        // Check if privacy budget is exceeded
555        if budget.epsilon_consumed > self.config.target_epsilon {
556            warnings.push("Epsilon budget exceeded".to_string());
557            is_valid = false;
558        }
559
560        if budget.delta_consumed > self.config.target_delta {
561            warnings.push("Delta budget exceeded".to_string());
562            is_valid = false;
563        }
564
565        // Check clipping frequency
566        if clipping_stats.clipping_frequency < 0.1 {
567            warnings.push(
568                "Low clipping frequency may indicate sub-optimal privacy-utility tradeoff"
569                    .to_string(),
570            );
571        }
572
573        if clipping_stats.clipping_frequency > 0.9 {
574            warnings.push("High clipping frequency may severely impact utility".to_string());
575        }
576
577        PrivacyValidation {
578            is_valid,
579            budget: budget.clone(),
580            clipping_stats: clipping_stats.clone(),
581            warnings,
582            recommendations: self.generate_recommendations(&budget, &clipping_stats),
583        }
584    }
585
586    fn generate_recommendations(
587        &self,
588        budget: &PrivacyBudget,
589        clipping: &ClippingStats,
590    ) -> Vec<String> {
591        let mut recommendations = Vec::new();
592
593        if clipping.clipping_frequency > 0.8 {
594            recommendations.push("Consider increasing the clipping threshold".to_string());
595        }
596
597        if clipping.clipping_frequency < 0.2 {
598            recommendations.push("Consider decreasing the clipping threshold".to_string());
599        }
600
601        if budget.epsilon_remaining < budget.epsilon_consumed * 0.1 {
602            recommendations.push("Privacy budget nearly exhausted - consider reducing noise multiplier for remaining steps".to_string());
603        }
604
605        recommendations
606    }
607}
608
/// Gradient clipping statistics over the optimizer's recorded history window.
#[derive(Debug, Clone, PartialEq)]
pub struct ClippingStats {
    /// Number of steps in the history window.
    pub total_steps: usize,
    /// Steps whose gradient was scaled down (clipping ratio below 1).
    pub clipped_steps: usize,
    /// `clipped_steps / total_steps`, or 0.0 when nothing has been recorded.
    pub clipping_frequency: f64,
    /// Mean scale factor applied to gradients (1.0 means no clipping).
    pub avg_clipping_ratio: f64,
    /// Mean L2 norm of gradients before clipping.
    pub avg_pre_clip_norm: f64,
    /// Clipping threshold in effect when the statistics were taken.
    pub current_threshold: f64,
}

impl Default for ClippingStats {
    /// Statistics for an optimizer that has not recorded any step yet.
    fn default() -> Self {
        Self {
            total_steps: 0,
            clipped_steps: 0,
            clipping_frequency: 0.0,
            avg_clipping_ratio: 1.0,
            avg_pre_clip_norm: 0.0,
            current_threshold: 1.0,
        }
    }
}
632
633/// Privacy validation results
634#[derive(Debug, Clone)]
635pub struct PrivacyValidation {
636    pub is_valid: bool,
637    pub budget: PrivacyBudget,
638    pub clipping_stats: ClippingStats,
639    pub warnings: Vec<String>,
640    pub recommendations: Vec<String>,
641}
642
643/// Moments accountant for privacy tracking
644pub struct MomentsAccountant {
645    noise_multiplier: f64,
646    target_delta: f64,
647    batch_size: usize,
648    dataset_size: usize,
649    sampling_probability: f64,
650}
651
652impl MomentsAccountant {
653    pub fn new(
654        noise_multiplier: f64,
655        target_delta: f64,
656        batch_size: usize,
657        dataset_size: usize,
658    ) -> Self {
659        let sampling_probability = batch_size as f64 / dataset_size as f64;
660
661        Self {
662            noise_multiplier,
663            target_delta,
664            batch_size,
665            dataset_size,
666            sampling_probability,
667        }
668    }
669
670    /// Compute privacy cost for given number of steps
671    pub fn get_privacy_spent(&self, steps: usize) -> Result<(f64, f64)> {
672        if steps == 0 {
673            return Ok((0.0, 0.0));
674        }
675
676        // Simplified moments accountant calculation
677        // In practice, this would use the full moment generating function
678
679        let sigma = self.noise_multiplier;
680        let q = self.sampling_probability;
681        let t = steps as f64;
682
683        // Gaussian mechanism with subsampling
684        let alpha_max = 32.0; // Maximum order for moment computation
685        let log_moments = self.compute_log_moments(sigma, q, t, alpha_max);
686
687        // Convert to (epsilon, delta)-DP
688        let epsilon = self.compute_epsilon_from_moments(&log_moments, self.target_delta);
689        let delta = self.target_delta;
690
691        Ok((epsilon, delta))
692    }
693
694    fn compute_log_moments(&self, sigma: f64, q: f64, t: f64, alphamax: f64) -> Vec<f64> {
695        let mut log_moments = Vec::new();
696
697        for alpha_int in 2..=(alphamax as usize) {
698            let alpha = alpha_int as f64;
699
700            // Log moment for Gaussian mechanism with subsampling
701            let log_moment = t
702                * (q * q * alpha * (alpha - 1.0) / (2.0 * sigma * sigma))
703                    .exp()
704                    .ln();
705
706            log_moments.push(log_moment);
707        }
708
709        log_moments
710    }
711
712    fn compute_epsilon_from_moments(&self, logmoments: &[f64], delta: f64) -> f64 {
713        let mut min_epsilon = f64::INFINITY;
714
715        for (i, &log_moment) in logmoments.iter().enumerate() {
716            let alpha = (i + 2) as f64;
717            let epsilon = (log_moment - delta.ln()) / (alpha - 1.0);
718
719            if epsilon < min_epsilon {
720                min_epsilon = epsilon;
721            }
722        }
723
724        min_epsilon.max(0.0)
725    }
726}
727
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::SGD;

    #[test]
    fn test_dp_config_default() {
        let cfg = DifferentialPrivacyConfig::default();
        assert!(matches!(cfg.noise_mechanism, NoiseMechanism::Gaussian));
        assert_eq!(cfg.noise_multiplier, 1.1);
        assert_eq!(cfg.target_epsilon, 1.0);
    }

    #[test]
    fn test_moments_accountant() {
        let acct = MomentsAccountant::new(1.1, 1e-5, 256, 50000);

        let (eps_100, delta_100) = acct.get_privacy_spent(100).unwrap();
        let (eps_200, _) = acct.get_privacy_spent(200).unwrap();

        assert!(eps_100 > 0.0);
        assert_eq!(delta_100, 1e-5);
        // Twice as many steps must cost strictly more epsilon.
        assert!(eps_200 > eps_100);
    }

    #[test]
    fn test_dp_optimizer_creation() {
        let created = DifferentiallyPrivateOptimizer::<_, f64, scirs2_core::ndarray::Ix1>::new(
            SGD::new(0.01),
            DifferentialPrivacyConfig::default(),
        );
        assert!(created.is_ok());
    }

    #[test]
    fn test_privacy_budget_tracking() {
        let cfg = DifferentialPrivacyConfig {
            target_epsilon: 1.0,
            max_steps: 100,
            ..Default::default()
        };

        let optimizer: DifferentiallyPrivateOptimizer<
            SGD<f64>,
            f64,
            scirs2_core::ndarray::Ix1,
        > = DifferentiallyPrivateOptimizer::new(SGD::new(0.01), cfg).unwrap();

        let budget = optimizer.get_privacy_budget();
        assert_eq!(budget.steps_taken, 0);
        assert_eq!(budget.epsilon_consumed, 0.0);
        assert_eq!(budget.epsilon_remaining, 1.0);
    }
}