Skip to main content

entrenar/optim/dp/
config.rs

//! Configuration for DP-SGD.
2
3use serde::{Deserialize, Serialize};
4
5use super::budget::PrivacyBudget;
6use super::error::{DpError, Result};
7
8/// Configuration for DP-SGD
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct DpSgdConfig {
11    /// Maximum gradient norm for clipping
12    pub max_grad_norm: f64,
13    /// Noise multiplier (sigma = noise_multiplier * max_grad_norm)
14    pub noise_multiplier: f64,
15    /// Privacy budget
16    pub budget: PrivacyBudget,
17    /// Sampling rate (batch_size / dataset_size)
18    pub sample_rate: f64,
19    /// Whether to stop training when budget is exhausted
20    pub strict_budget: bool,
21}
22
23impl DpSgdConfig {
24    /// Create a new DP-SGD configuration
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Set maximum gradient norm
30    pub fn with_max_grad_norm(mut self, norm: f64) -> Self {
31        self.max_grad_norm = norm;
32        self
33    }
34
35    /// Set noise multiplier
36    pub fn with_noise_multiplier(mut self, multiplier: f64) -> Self {
37        self.noise_multiplier = multiplier.max(0.0);
38        self
39    }
40
41    /// Set privacy budget
42    pub fn with_budget(mut self, budget: PrivacyBudget) -> Self {
43        self.budget = budget;
44        self
45    }
46
47    /// Set sampling rate
48    pub fn with_sample_rate(mut self, rate: f64) -> Self {
49        self.sample_rate = rate.clamp(0.0, 1.0);
50        self
51    }
52
53    /// Set strict budget enforcement
54    pub fn with_strict_budget(mut self, strict: bool) -> Self {
55        self.strict_budget = strict;
56        self
57    }
58
59    /// Compute noise standard deviation
60    pub fn noise_std(&self) -> f64 {
61        self.noise_multiplier * self.max_grad_norm
62    }
63
64    /// Validate configuration
65    pub fn validate(&self) -> Result<()> {
66        if self.max_grad_norm <= 0.0 {
67            return Err(DpError::InvalidConfig("max_grad_norm must be positive".to_string()));
68        }
69        if self.noise_multiplier < 0.0 {
70            return Err(DpError::InvalidConfig(
71                "noise_multiplier must be non-negative".to_string(),
72            ));
73        }
74        if self.budget.epsilon <= 0.0 {
75            return Err(DpError::InvalidConfig("epsilon must be positive".to_string()));
76        }
77        if self.budget.delta <= 0.0 || self.budget.delta >= 1.0 {
78            return Err(DpError::InvalidConfig("delta must be in (0, 1)".to_string()));
79        }
80        Ok(())
81    }
82}
83
84impl Default for DpSgdConfig {
85    fn default() -> Self {
86        Self {
87            max_grad_norm: 1.0,
88            noise_multiplier: 1.1,
89            budget: PrivacyBudget::default(),
90            sample_rate: 0.01, // 1% batch size
91            strict_budget: true,
92        }
93    }
94}