entrenar/optim/dp/
config.rs1use serde::{Deserialize, Serialize};
4
5use super::budget::PrivacyBudget;
6use super::error::{DpError, Result};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct DpSgdConfig {
11 pub max_grad_norm: f64,
13 pub noise_multiplier: f64,
15 pub budget: PrivacyBudget,
17 pub sample_rate: f64,
19 pub strict_budget: bool,
21}
22
23impl DpSgdConfig {
24 pub fn new() -> Self {
26 Self::default()
27 }
28
29 pub fn with_max_grad_norm(mut self, norm: f64) -> Self {
31 self.max_grad_norm = norm;
32 self
33 }
34
35 pub fn with_noise_multiplier(mut self, multiplier: f64) -> Self {
37 self.noise_multiplier = multiplier.max(0.0);
38 self
39 }
40
41 pub fn with_budget(mut self, budget: PrivacyBudget) -> Self {
43 self.budget = budget;
44 self
45 }
46
47 pub fn with_sample_rate(mut self, rate: f64) -> Self {
49 self.sample_rate = rate.clamp(0.0, 1.0);
50 self
51 }
52
53 pub fn with_strict_budget(mut self, strict: bool) -> Self {
55 self.strict_budget = strict;
56 self
57 }
58
59 pub fn noise_std(&self) -> f64 {
61 self.noise_multiplier * self.max_grad_norm
62 }
63
64 pub fn validate(&self) -> Result<()> {
66 if self.max_grad_norm <= 0.0 {
67 return Err(DpError::InvalidConfig("max_grad_norm must be positive".to_string()));
68 }
69 if self.noise_multiplier < 0.0 {
70 return Err(DpError::InvalidConfig(
71 "noise_multiplier must be non-negative".to_string(),
72 ));
73 }
74 if self.budget.epsilon <= 0.0 {
75 return Err(DpError::InvalidConfig("epsilon must be positive".to_string()));
76 }
77 if self.budget.delta <= 0.0 || self.budget.delta >= 1.0 {
78 return Err(DpError::InvalidConfig("delta must be in (0, 1)".to_string()));
79 }
80 Ok(())
81 }
82}
83
84impl Default for DpSgdConfig {
85 fn default() -> Self {
86 Self {
87 max_grad_norm: 1.0,
88 noise_multiplier: 1.1,
89 budget: PrivacyBudget::default(),
90 sample_rate: 0.01, strict_budget: true,
92 }
93 }
94}