use super::accountant::RdpAccountant;
use super::config::DpSgdConfig;
use super::error::{DpError, Result};
use super::gradient::{add_gaussian_noise, clip_gradient};
/// Differentially private SGD optimizer: clips per-sample gradients,
/// averages them, adds Gaussian noise, and tracks cumulative privacy
/// spend through an RDP accountant.
#[derive(Debug, Clone)]
pub struct DpSgd {
    // DP-SGD hyperparameters (clipping norm, noise multiplier, sample
    // rate, privacy budget, strict-budget flag); validated in `new`.
    config: DpSgdConfig,
    // Rényi-DP accountant; advanced once per privatization step.
    accountant: RdpAccountant,
    // Step size used by `apply_update` (`p -= lr * g`).
    learning_rate: f64,
}
impl DpSgd {
    /// Creates a DP-SGD optimizer with the given learning rate and a
    /// validated configuration.
    ///
    /// # Errors
    /// Propagates whatever error `config.validate()` reports.
    pub fn new(learning_rate: f64, config: DpSgdConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self { config, accountant: RdpAccountant::new(), learning_rate })
    }

    /// The `(epsilon, delta)` privacy cost accumulated so far, as
    /// reported by the accountant at the configured `delta`.
    pub fn privacy_spent(&self) -> (f64, f64) {
        self.accountant.get_privacy_spent(self.config.budget.delta)
    }

    /// Epsilon spent so far (first component of [`Self::privacy_spent`]).
    pub fn current_epsilon(&self) -> f64 {
        self.privacy_spent().0
    }

    /// Epsilon remaining in the configured budget.
    pub fn remaining_budget(&self) -> f64 {
        self.config.budget.remaining(self.current_epsilon())
    }

    /// Whether the budget no longer allows the current epsilon spend.
    pub fn is_budget_exhausted(&self) -> bool {
        !self.config.budget.allows(self.current_epsilon())
    }

    /// Number of steps recorded by the accountant.
    pub fn n_steps(&self) -> usize {
        self.accountant.n_steps()
    }

    /// Borrows the optimizer's configuration.
    pub fn config(&self) -> &DpSgdConfig {
        &self.config
    }

    /// Produces a differentially private gradient estimate: clips each
    /// per-sample gradient to `max_grad_norm`, averages them, adds
    /// Gaussian noise, and records the step with the accountant.
    ///
    /// # Errors
    /// - [`DpError::GradientError`] if no gradients are supplied or the
    ///   per-sample gradients do not all share one dimension.
    /// - [`DpError::BudgetExhausted`] if `strict_budget` is set and the
    ///   budget is already spent (checked *before* this step's cost is
    ///   added, so the final step may overshoot the budget slightly).
    pub fn privatize_gradients(&mut self, per_sample_grads: &[Vec<f64>]) -> Result<Vec<f64>> {
        if per_sample_grads.is_empty() {
            return Err(DpError::GradientError("No gradients provided".to_string()));
        }
        let n_samples = per_sample_grads.len();
        let grad_dim = per_sample_grads[0].len();
        // All per-sample gradients must share one dimension: the original
        // accumulation indexed `averaged[i]` for every component of every
        // gradient, which panics out-of-bounds when a later gradient is
        // longer than the first. Fail with a typed error instead.
        if per_sample_grads.iter().any(|g| g.len() != grad_dim) {
            return Err(DpError::GradientError(
                "Per-sample gradients have inconsistent dimensions".to_string(),
            ));
        }
        if self.config.strict_budget && self.is_budget_exhausted() {
            return Err(DpError::BudgetExhausted {
                spent: self.current_epsilon(),
                budget: self.config.budget.epsilon,
            });
        }
        // Clip every per-sample gradient to the configured norm bound.
        let clipped: Vec<Vec<f64>> =
            per_sample_grads.iter().map(|g| clip_gradient(g, self.config.max_grad_norm)).collect();
        // Component-wise average (same arithmetic as before: divide each
        // contribution by n as it is accumulated). `zip` is safe after
        // the dimension check above and avoids per-access bounds checks.
        let mut averaged = vec![0.0; grad_dim];
        for g in &clipped {
            for (acc, &val) in averaged.iter_mut().zip(g.iter()) {
                *acc += val / n_samples as f64;
            }
        }
        let mut rng = rand::rng();
        // Dividing the noise std by n presumably matches noise calibrated
        // for the gradient *sum* to the averaged gradient — TODO confirm
        // against `DpSgdConfig::noise_std`'s contract.
        let noise_std = self.config.noise_std() / n_samples as f64;
        let noised = add_gaussian_noise(&averaged, noise_std, &mut rng);
        self.accountant.step(self.config.noise_multiplier, self.config.sample_rate);
        Ok(noised)
    }

    /// Applies one SGD update in place: `p -= learning_rate * g`.
    /// If `params` and `grad` differ in length, `zip` stops at the
    /// shorter one and trailing entries are left untouched.
    pub fn apply_update(&self, params: &mut [f64], grad: &[f64]) {
        for (p, g) in params.iter_mut().zip(grad.iter()) {
            *p -= self.learning_rate * g;
        }
    }

    /// Runs one full DP-SGD step (privatize then update) and returns the
    /// `(epsilon, delta)` spent so far.
    ///
    /// # Errors
    /// Propagates any error from [`Self::privatize_gradients`]; `params`
    /// is left unmodified in that case.
    pub fn step(
        &mut self,
        params: &mut [f64],
        per_sample_grads: &[Vec<f64>],
    ) -> Result<(f64, f64)> {
        let grad = self.privatize_gradients(per_sample_grads)?;
        self.apply_update(params, &grad);
        Ok(self.privacy_spent())
    }

    /// Resets the accountant, discarding all recorded privacy spend.
    /// Configuration and learning rate are kept.
    pub fn reset(&mut self) {
        self.accountant.reset();
    }
}