Skip to main content

entrenar/optim/dp/
dp_sgd.rs

//! Differentially private SGD wrapper.
//!
//! Wraps SGD parameter updates with differential-privacy (DP) guarantees by:
//! 1. Clipping each per-sample gradient
//! 2. Adding calibrated Gaussian noise
//! 3. Tracking the cumulative privacy cost (privacy accounting)
7
8use super::accountant::RdpAccountant;
9use super::config::DpSgdConfig;
10use super::error::{DpError, Result};
11use super::gradient::{add_gaussian_noise, clip_gradient};
12
13/// Differentially private SGD wrapper
14///
15/// Wraps any optimizer with DP guarantees by:
16/// 1. Per-sample gradient clipping
17/// 2. Adding calibrated Gaussian noise
18/// 3. Privacy accounting
19#[derive(Debug, Clone)]
20pub struct DpSgd {
21    /// DP configuration
22    config: DpSgdConfig,
23    /// Privacy accountant
24    accountant: RdpAccountant,
25    /// Learning rate
26    learning_rate: f64,
27}
28
29impl DpSgd {
30    /// Create a new DP-SGD optimizer
31    pub fn new(learning_rate: f64, config: DpSgdConfig) -> Result<Self> {
32        config.validate()?;
33        Ok(Self { config, accountant: RdpAccountant::new(), learning_rate })
34    }
35
36    /// Get current privacy spent
37    pub fn privacy_spent(&self) -> (f64, f64) {
38        self.accountant.get_privacy_spent(self.config.budget.delta)
39    }
40
41    /// Get current epsilon
42    pub fn current_epsilon(&self) -> f64 {
43        self.privacy_spent().0
44    }
45
46    /// Get remaining budget
47    pub fn remaining_budget(&self) -> f64 {
48        self.config.budget.remaining(self.current_epsilon())
49    }
50
51    /// Check if budget is exhausted
52    pub fn is_budget_exhausted(&self) -> bool {
53        !self.config.budget.allows(self.current_epsilon())
54    }
55
56    /// Get number of training steps
57    pub fn n_steps(&self) -> usize {
58        self.accountant.n_steps()
59    }
60
61    /// Get configuration
62    pub fn config(&self) -> &DpSgdConfig {
63        &self.config
64    }
65
66    /// Process per-sample gradients with DP mechanism
67    ///
68    /// Returns the privatized aggregated gradient
69    pub fn privatize_gradients(&mut self, per_sample_grads: &[Vec<f64>]) -> Result<Vec<f64>> {
70        if per_sample_grads.is_empty() {
71            return Err(DpError::GradientError("No gradients provided".to_string()));
72        }
73
74        // Check budget
75        if self.config.strict_budget && self.is_budget_exhausted() {
76            return Err(DpError::BudgetExhausted {
77                spent: self.current_epsilon(),
78                budget: self.config.budget.epsilon,
79            });
80        }
81
82        let n_samples = per_sample_grads.len();
83        let grad_dim = per_sample_grads[0].len();
84
85        // Step 1: Clip each per-sample gradient
86        let clipped: Vec<Vec<f64>> =
87            per_sample_grads.iter().map(|g| clip_gradient(g, self.config.max_grad_norm)).collect();
88
89        // Step 2: Average clipped gradients
90        let mut averaged = vec![0.0; grad_dim];
91        for g in &clipped {
92            for (i, &val) in g.iter().enumerate() {
93                averaged[i] += val / n_samples as f64;
94            }
95        }
96
97        // Step 3: Add Gaussian noise
98        let mut rng = rand::rng();
99        let noise_std = self.config.noise_std() / n_samples as f64;
100        let noised = add_gaussian_noise(&averaged, noise_std, &mut rng);
101
102        // Step 4: Update privacy accounting
103        self.accountant.step(self.config.noise_multiplier, self.config.sample_rate);
104
105        Ok(noised)
106    }
107
108    /// Apply gradient update to parameters
109    pub fn apply_update(&self, params: &mut [f64], grad: &[f64]) {
110        for (p, g) in params.iter_mut().zip(grad.iter()) {
111            *p -= self.learning_rate * g;
112        }
113    }
114
115    /// Full DP-SGD step: privatize gradients and update parameters
116    pub fn step(
117        &mut self,
118        params: &mut [f64],
119        per_sample_grads: &[Vec<f64>],
120    ) -> Result<(f64, f64)> {
121        let grad = self.privatize_gradients(per_sample_grads)?;
122        self.apply_update(params, &grad);
123        Ok(self.privacy_spent())
124    }
125
126    /// Reset privacy accountant
127    pub fn reset(&mut self) {
128        self.accountant.reset();
129    }
130}