use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Rényi orders α at which `RenyiDPAccountant` evaluates its RDP curve.
///
/// Powers of two from 2 through 128; the (ε, δ) conversion picks whichever order gives
/// the smallest epsilon.
pub const RDP_ALPHA_ORDERS: [f64; 7] = [
    2.0, 4.0, 8.0, 16.0, //
    32.0, 64.0, 128.0,
];
/// Composition rule used to combine the privacy loss of several mechanisms.
///
/// Serialized with snake_case names; `RenyiDP` and `ZeroCDP` carry explicit renames so
/// their wire names are `renyi_dp` and `zcdp`. The `Display` impl emits the same strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum CompositionMethod {
/// Basic composition: the epsilons of all mechanisms are summed. This is the default.
#[default]
Naive,
/// Advanced composition. No dedicated accountant exists for it here:
/// `create_accountant` builds a `NaiveAccountant` for this variant.
Advanced,
/// Rényi differential privacy accounting (`RenyiDPAccountant`).
#[serde(rename = "renyi_dp")]
RenyiDP,
/// Zero-concentrated differential privacy accounting (`ZeroCDPAccountant`).
#[serde(rename = "zcdp")]
ZeroCDP,
}
impl std::fmt::Display for CompositionMethod {
    /// Writes the lowercase identifier of the variant (the same string serde uses).
    ///
    /// Width/fill flags are not applied; the label is written verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Naive => "naive",
            Self::Advanced => "advanced",
            Self::RenyiDP => "renyi_dp",
            Self::ZeroCDP => "zcdp",
        };
        f.write_str(label)
    }
}
/// One privacy-consuming mechanism invocation, as fed to a `PrivacyAccountant`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MechanismRecord {
/// Privacy-loss parameter ε charged for this invocation.
pub epsilon: f64,
/// δ of an (ε, δ) guarantee; deserializes to 0.0 when the field is absent.
/// NOTE(review): none of the accountants in this module read this field — each charges
/// only `epsilon`.
#[serde(default)]
pub delta: f64,
/// Creation time in UTC (`MechanismRecord::new` stamps it with `Utc::now()`).
pub timestamp: DateTime<Utc>,
/// Caller-supplied description of the invocation (e.g. a query label).
pub description: String,
}
impl MechanismRecord {
    /// Creates a record with the given `epsilon`, `delta = 0.0`, and a timestamp taken
    /// from `Utc::now()`.
    pub fn new(epsilon: f64, description: impl Into<String>) -> Self {
        let timestamp = Utc::now();
        let description = description.into();
        Self {
            description,
            timestamp,
            delta: 0.0,
            epsilon,
        }
    }

    /// Returns this record with `delta` replaced; all other fields are kept.
    pub fn with_delta(self, delta: f64) -> Self {
        Self { delta, ..self }
    }
}
/// Common interface of the privacy-budget accountants.
///
/// Implementors take `MechanismRecord`s one at a time and report the cumulative privacy
/// loss under their own composition rule.
pub trait PrivacyAccountant {
    /// Composition rule implemented by this accountant.
    fn method(&self) -> CompositionMethod;

    /// Mechanisms recorded so far.
    fn mechanisms(&self) -> &[MechanismRecord];

    /// Charges one mechanism invocation against the budget.
    fn record_mechanism(&mut self, record: MechanismRecord);

    /// Cumulative epsilon spent under this accountant's composition rule.
    fn effective_epsilon(&self) -> f64;

    /// Epsilon still available before the budget is exhausted.
    fn remaining_budget(&self) -> f64;

    /// Whether the budget has been used up.
    fn is_exhausted(&self) -> bool;

    /// δ targeted when converting to (ε, δ)-DP. The default returns `None`.
    fn target_delta(&self) -> Option<f64> {
        None
    }

    /// Rényi order giving the tightest conversion. The default returns `None`.
    fn optimal_alpha(&self) -> Option<f64> {
        None
    }
}
/// Accountant using basic (sequential) composition: the spent epsilon is the running sum
/// of the recorded epsilons.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NaiveAccountant {
/// Total epsilon budget.
pub total_budget: f64,
/// Every mechanism recorded so far.
pub mechanisms: Vec<MechanismRecord>,
/// Running sum of the epsilons charged so far.
pub epsilon_spent: f64,
}
impl NaiveAccountant {
    /// Creates an accountant holding `total_budget` epsilon, with nothing spent or recorded.
    pub fn new(total_budget: f64) -> Self {
        Self {
            epsilon_spent: 0.0,
            mechanisms: vec![],
            total_budget,
        }
    }
}
/// Basic composition: epsilons add up, and `record.delta` is not tracked.
impl PrivacyAccountant for NaiveAccountant {
    /// Adds the mechanism's epsilon to the running total and keeps the record.
    ///
    /// A negative epsilon is charged as zero so a malformed record cannot refund budget.
    /// NaN is deliberately *not* mapped to zero (as `f64::max` would do). It propagates
    /// into `epsilon_spent`, which makes `is_exhausted` report true.
    fn record_mechanism(&mut self, record: MechanismRecord) {
        let cost = if record.epsilon < 0.0 { 0.0 } else { record.epsilon };
        self.epsilon_spent += cost;
        self.mechanisms.push(record);
    }

    /// The running sum of charged epsilons.
    fn effective_epsilon(&self) -> f64 {
        self.epsilon_spent
    }

    /// Budget minus spend, floored at zero (a NaN difference also yields 0.0).
    fn remaining_budget(&self) -> f64 {
        (self.total_budget - self.epsilon_spent).max(0.0)
    }

    /// True unless the spend is strictly below the budget.
    ///
    /// A NaN on either side counts as exhausted (fail closed) instead of leaving the budget
    /// open forever, which a plain `>=` would do.
    fn is_exhausted(&self) -> bool {
        self.epsilon_spent.partial_cmp(&self.total_budget) != Some(std::cmp::Ordering::Less)
    }

    fn method(&self) -> CompositionMethod {
        CompositionMethod::Naive
    }

    fn mechanisms(&self) -> &[MechanismRecord] {
        &self.mechanisms
    }
}
/// Accountant using Rényi differential privacy (RDP) composition.
///
/// Each mechanism's per-order RDP cost is added into `rdp_curve`; the curve is converted
/// to an (ε, δ) guarantee at `target_delta` when queried.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RenyiDPAccountant {
/// Epsilon budget compared against the converted epsilon.
pub total_budget: f64,
/// δ used when converting the RDP curve to (ε, δ)-DP.
pub target_delta: f64,
/// Accumulated RDP cost, paired positionally with `RDP_ALPHA_ORDERS`.
pub rdp_curve: Vec<f64>,
/// Every mechanism recorded so far.
pub mechanisms: Vec<MechanismRecord>,
}
impl RenyiDPAccountant {
    /// Creates an accountant with an all-zero RDP curve (one entry per order in
    /// `RDP_ALPHA_ORDERS`) and the δ used when converting to (ε, δ)-DP.
    pub fn new(total_budget: f64, target_delta: f64) -> Self {
        Self {
            total_budget,
            target_delta,
            rdp_curve: vec![0.0; RDP_ALPHA_ORDERS.len()],
            mechanisms: Vec::new(),
        }
    }

    /// Order-`alpha` RDP cost of a mechanism that is pure `epsilon`-DP.
    ///
    /// For finite `epsilon > 0` and `alpha > 1` this returns the Rényi divergence of binary
    /// randomized response, which is the worst case among ε-DP mechanisms:
    ///
    /// `ln((e^(αε) + e^((1-α)ε)) / (1 + e^ε)) / (α - 1)`
    ///
    /// It is evaluated in the overflow-free form
    /// `ε + (ln(1 + e^(-(2α-1)ε)) - ln(1 + e^(-ε))) / (α - 1)`.
    /// The result is then capped by two generic valid bounds: `ε` and `αε²/2`
    /// (ε-DP implies ε²/2-zCDP). The caps only absorb rounding error, which matters most
    /// for tiny ε.
    ///
    /// `epsilon <= 0` costs nothing. A NaN/infinite `epsilon`, or an order `alpha <= 1`,
    /// falls back to the trivial bound and returns `epsilon` itself.
    fn epsilon_to_rdp(epsilon: f64, alpha: f64) -> f64 {
        if epsilon <= 0.0 {
            return 0.0;
        }
        if alpha > 1.0 && epsilon.is_finite() {
            let order_gap = alpha - 1.0;
            let log_ratio =
                (-(2.0 * alpha - 1.0) * epsilon).exp().ln_1p() - (-epsilon).exp().ln_1p();
            let randomized_response = epsilon + log_ratio / order_gap;
            let zcdp_bound = alpha * epsilon * epsilon / 2.0;
            return randomized_response.min(zcdp_bound).clamp(0.0, epsilon);
        }
        epsilon
    }

    /// Finds the tracked order whose converted epsilon
    /// `RDP(α) + ln(1/δ) / (α - 1)` is smallest, returning `(alpha, epsilon)`.
    ///
    /// NaN candidates are skipped; returns `None` when none remain (e.g. an empty curve).
    /// Only meaningful for `target_delta > 0`.
    fn best_order(&self) -> Option<(f64, f64)> {
        let ln_inv_delta = (1.0 / self.target_delta).ln();
        RDP_ALPHA_ORDERS
            .iter()
            .zip(self.rdp_curve.iter())
            .map(|(&alpha, &rdp_val)| (alpha, rdp_val + ln_inv_delta / (alpha - 1.0)))
            .filter(|&(_, eps)| !eps.is_nan())
            .min_by(|a, b| a.1.total_cmp(&b.1))
    }

    /// Converts the accumulated RDP curve into an `(epsilon, delta)` DP guarantee.
    ///
    /// - `target_delta <= 0`: RDP at finite orders implies no pure-DP guarantee, so this
    ///   falls back to basic composition. It returns the sum of the recorded non-negative
    ///   epsilons, with `delta = 0`.
    /// - All-zero (non-empty) curve, i.e. nothing has been charged: `(0.0, 0.0)`.
    /// - Otherwise: `min_α [RDP(α) + ln(1/δ) / (α - 1)]` over the tracked orders, paired
    ///   with `target_delta`. The value is `f64::INFINITY` if no usable order remains,
    ///   which makes the accountant report itself as exhausted.
    pub fn rdp_to_dp(&self) -> (f64, f64) {
        if self.target_delta <= 0.0 {
            // Not `max(0.0)`: that would silently turn a NaN epsilon into a zero cost.
            let basic_composition: f64 = self
                .mechanisms
                .iter()
                .map(|m| if m.epsilon < 0.0 { 0.0 } else { m.epsilon })
                .sum();
            return (basic_composition, 0.0);
        }
        if !self.rdp_curve.is_empty() && self.rdp_curve.iter().all(|&rdp| rdp == 0.0) {
            return (0.0, 0.0);
        }
        let epsilon = self.best_order().map_or(f64::INFINITY, |(_, eps)| eps);
        (epsilon, self.target_delta)
    }

    /// Order α at which `rdp_to_dp` attains its epsilon.
    ///
    /// Returns the first tracked order when `target_delta <= 0` or when no order is usable.
    pub fn optimal_alpha(&self) -> f64 {
        if self.target_delta <= 0.0 {
            return RDP_ALPHA_ORDERS[0];
        }
        self.best_order()
            .map_or(RDP_ALPHA_ORDERS[0], |(alpha, _)| alpha)
    }
}
/// RDP composition: per-order costs add up; only `record.epsilon` is used
/// (`record.delta` is not tracked).
impl PrivacyAccountant for RenyiDPAccountant {
    /// Adds the mechanism's RDP cost at every tracked order and keeps the record.
    ///
    /// Orders are paired with curve entries positionally, the same way the (ε, δ)
    /// conversion pairs them. If a curve (e.g. one deserialized from elsewhere) is shorter
    /// than `RDP_ALPHA_ORDERS`, the missing orders are skipped instead of panicking on an
    /// out-of-bounds index.
    fn record_mechanism(&mut self, record: MechanismRecord) {
        for (slot, &alpha) in self.rdp_curve.iter_mut().zip(RDP_ALPHA_ORDERS.iter()) {
            *slot += Self::epsilon_to_rdp(record.epsilon, alpha);
        }
        self.mechanisms.push(record);
    }

    /// The ε component of `rdp_to_dp`.
    fn effective_epsilon(&self) -> f64 {
        self.rdp_to_dp().0
    }

    /// Budget minus effective epsilon, floored at zero (a NaN difference also yields 0.0).
    fn remaining_budget(&self) -> f64 {
        (self.total_budget - self.effective_epsilon()).max(0.0)
    }

    /// True unless the effective epsilon is strictly below the budget; a NaN on either side
    /// counts as exhausted (fail closed).
    fn is_exhausted(&self) -> bool {
        self.effective_epsilon()
            .partial_cmp(&self.total_budget)
            != Some(std::cmp::Ordering::Less)
    }

    fn method(&self) -> CompositionMethod {
        CompositionMethod::RenyiDP
    }

    fn mechanisms(&self) -> &[MechanismRecord] {
        &self.mechanisms
    }

    fn target_delta(&self) -> Option<f64> {
        Some(self.target_delta)
    }

    /// Wraps the inherent `optimal_alpha`; inherent methods take precedence over trait
    /// methods in method-call syntax, so this does not recurse.
    fn optimal_alpha(&self) -> Option<f64> {
        Some(self.optimal_alpha())
    }
}
/// Accountant using zero-concentrated differential privacy (zCDP).
///
/// Each mechanism's epsilon is converted to a ρ that is summed into `total_rho`; the total
/// is converted to an (ε, δ) guarantee at `target_delta` when queried.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroCDPAccountant {
/// Epsilon budget compared against the converted epsilon.
pub total_budget: f64,
/// δ used when converting ρ to (ε, δ)-DP.
pub target_delta: f64,
/// Accumulated zCDP parameter ρ.
pub total_rho: f64,
/// Every mechanism recorded so far.
pub mechanisms: Vec<MechanismRecord>,
}
impl ZeroCDPAccountant {
    /// Creates an accountant with no accumulated ρ that converts to (ε, δ)-DP at
    /// `target_delta`.
    pub fn new(total_budget: f64, target_delta: f64) -> Self {
        Self {
            mechanisms: Vec::new(),
            total_rho: 0.0,
            target_delta,
            total_budget,
        }
    }

    /// ρ charged for a mechanism with the given epsilon: ρ = ε² / 2.
    pub fn epsilon_to_rho(epsilon: f64) -> f64 {
        let squared = epsilon * epsilon;
        squared / 2.0
    }

    /// Converts the accumulated ρ into an `(epsilon, delta)` guarantee.
    ///
    /// - ρ ≤ 0 (nothing charged): `(0.0, 0.0)`.
    /// - `target_delta` ≤ 0: `(f64::INFINITY, 0.0)`.
    /// - otherwise: ε = ρ + 2·sqrt(ρ · ln(1/δ)) with δ = `target_delta`.
    pub fn rho_to_dp(&self) -> (f64, f64) {
        let rho = self.total_rho;
        let delta = self.target_delta;
        if rho <= 0.0 {
            (0.0, 0.0)
        } else if delta <= 0.0 {
            (f64::INFINITY, 0.0)
        } else {
            let ln_inv_delta = (1.0 / delta).ln();
            (rho + 2.0 * (rho * ln_inv_delta).sqrt(), delta)
        }
    }

    /// Accumulated zCDP parameter ρ.
    pub fn current_rho(&self) -> f64 {
        self.total_rho
    }
}
/// zCDP composition: ρ values add up; only `record.epsilon` is used
/// (`record.delta` is not tracked).
impl PrivacyAccountant for ZeroCDPAccountant {
    /// Converts the mechanism's epsilon to ρ via `epsilon_to_rho` (ρ = ε²/2), adds it to
    /// the running total, and keeps the record.
    fn record_mechanism(&mut self, record: MechanismRecord) {
        self.total_rho += Self::epsilon_to_rho(record.epsilon);
        self.mechanisms.push(record);
    }

    /// The ε component of `rho_to_dp`.
    fn effective_epsilon(&self) -> f64 {
        self.rho_to_dp().0
    }

    /// Budget minus effective epsilon, floored at zero (a NaN difference also yields 0.0).
    fn remaining_budget(&self) -> f64 {
        (self.total_budget - self.effective_epsilon()).max(0.0)
    }

    /// True unless the effective epsilon is strictly below the budget.
    ///
    /// A NaN on either side (e.g. after recording a NaN epsilon, which makes ρ NaN) counts
    /// as exhausted, instead of leaving the budget open as a plain `>=` would.
    fn is_exhausted(&self) -> bool {
        self.effective_epsilon()
            .partial_cmp(&self.total_budget)
            != Some(std::cmp::Ordering::Less)
    }

    fn method(&self) -> CompositionMethod {
        CompositionMethod::ZeroCDP
    }

    fn mechanisms(&self) -> &[MechanismRecord] {
        &self.mechanisms
    }

    fn target_delta(&self) -> Option<f64> {
        Some(self.target_delta)
    }
}
/// Target δ used by [`create_accountant`] for the accountants that convert to (ε, δ)-DP.
pub const DEFAULT_TARGET_DELTA: f64 = 1e-5;

/// Builds a boxed accountant for `method` with `total_budget` epsilon, using
/// [`DEFAULT_TARGET_DELTA`] for the RDP and zCDP accountants.
///
/// `CompositionMethod::Advanced` has no dedicated accountant here and falls back to
/// `NaiveAccountant`. Basic composition is still a valid (if looser) bound, but the returned
/// accountant reports `CompositionMethod::Naive` from `method()`.
pub fn create_accountant(
    method: CompositionMethod,
    total_budget: f64,
) -> Box<dyn PrivacyAccountant> {
    create_accountant_with_delta(method, total_budget, DEFAULT_TARGET_DELTA)
}

/// Like [`create_accountant`], but with an explicit `target_delta` for the RDP and zCDP
/// accountants. The naive fallback does not use it.
pub fn create_accountant_with_delta(
    method: CompositionMethod,
    total_budget: f64,
    target_delta: f64,
) -> Box<dyn PrivacyAccountant> {
    match method {
        CompositionMethod::Naive | CompositionMethod::Advanced => {
            Box::new(NaiveAccountant::new(total_budget))
        }
        CompositionMethod::RenyiDP => {
            Box::new(RenyiDPAccountant::new(total_budget, target_delta))
        }
        CompositionMethod::ZeroCDP => {
            Box::new(ZeroCDPAccountant::new(total_budget, target_delta))
        }
    }
}
// Unit tests for the accountants, the factory, and the serde/Display contracts of the
// public types.
#[cfg(test)]
mod tests {
use super::*;
// Basic composition is a plain sum: 0.1 + 0.2 + 0.3 = 0.6 of a 1.0 budget.
#[test]
fn test_naive_accountant_simple_sum() {
let mut acc = NaiveAccountant::new(1.0);
acc.record_mechanism(MechanismRecord::new(0.1, "query 1"));
acc.record_mechanism(MechanismRecord::new(0.2, "query 2"));
acc.record_mechanism(MechanismRecord::new(0.3, "query 3"));
assert!((acc.effective_epsilon() - 0.6).abs() < 1e-10);
assert!((acc.remaining_budget() - 0.4).abs() < 1e-10);
assert!(!acc.is_exhausted());
assert_eq!(acc.mechanisms().len(), 3);
}
// Crossing the budget (0.6 > 0.5) must flip `is_exhausted` and floor the remainder at 0.
#[test]
fn test_naive_accountant_exhaustion() {
let mut acc = NaiveAccountant::new(0.5);
acc.record_mechanism(MechanismRecord::new(0.3, "query 1"));
assert!(!acc.is_exhausted());
acc.record_mechanism(MechanismRecord::new(0.3, "query 2"));
assert!(acc.is_exhausted());
assert_eq!(acc.remaining_budget(), 0.0);
}
// 1000 queries at eps = 0.1 sum to 100 under basic composition; RDP composition must
// report a strictly smaller epsilon.
#[test]
fn test_renyi_accountant_composition() {
let mut acc = RenyiDPAccountant::new(100.0, 1e-5);
let n_queries = 1000;
let eps_per_query = 0.1;
for i in 0..n_queries {
acc.record_mechanism(MechanismRecord::new(eps_per_query, format!("query {}", i)));
}
let effective = acc.effective_epsilon();
let naive_total = n_queries as f64 * eps_per_query;
assert!(effective > 0.0, "Effective epsilon should be positive");
assert!(
effective < naive_total,
"RDP ({:.4}) should be tighter than naive ({:.4}) for many queries",
effective,
naive_total
);
}
// The conversion reports the accountant's configured target delta.
#[test]
fn test_renyi_accountant_rdp_to_dp() {
let mut acc = RenyiDPAccountant::new(10.0, 1e-5);
acc.record_mechanism(MechanismRecord::new(1.0, "large query"));
let (eps, delta) = acc.rdp_to_dp();
assert!(eps > 0.0);
assert!((delta - 1e-5).abs() < 1e-15);
}
// The chosen order must be one of the tracked orders.
#[test]
fn test_renyi_optimal_alpha() {
let mut acc = RenyiDPAccountant::new(10.0, 1e-5);
acc.record_mechanism(MechanismRecord::new(0.5, "query"));
let alpha = acc.optimal_alpha();
assert!(RDP_ALPHA_ORDERS.contains(&alpha));
}
// rho composes additively: rho = sum(eps_i^2 / 2).
#[test]
fn test_zcdp_accountant_composition() {
let mut acc = ZeroCDPAccountant::new(10.0, 1e-5);
acc.record_mechanism(MechanismRecord::new(0.1, "query 1"));
acc.record_mechanism(MechanismRecord::new(0.2, "query 2"));
let expected_rho = 0.1_f64.powi(2) / 2.0 + 0.2_f64.powi(2) / 2.0;
assert!(
(acc.current_rho() - expected_rho).abs() < 1e-10,
"Expected rho={}, got rho={}",
expected_rho,
acc.current_rho()
);
}
// Pins the zCDP -> (eps, delta) conversion eps = rho + 2 * sqrt(rho * ln(1/delta)).
#[test]
fn test_zcdp_rho_to_dp() {
let mut acc = ZeroCDPAccountant::new(10.0, 1e-5);
acc.record_mechanism(MechanismRecord::new(1.0, "query"));
let (eps, delta) = acc.rho_to_dp();
assert!(eps > 0.0);
assert!((delta - 1e-5).abs() < 1e-15);
let expected_rho = 0.5;
let expected_eps = expected_rho + 2.0 * (expected_rho * (1.0 / 1e-5_f64).ln()).sqrt();
assert!(
(eps - expected_eps).abs() < 1e-10,
"Expected eps={}, got eps={}",
expected_eps,
eps
);
}
// rho = eps^2 / 2.
#[test]
fn test_zcdp_epsilon_to_rho() {
assert!((ZeroCDPAccountant::epsilon_to_rho(1.0) - 0.5).abs() < 1e-10);
assert!((ZeroCDPAccountant::epsilon_to_rho(0.0) - 0.0).abs() < 1e-10);
assert!((ZeroCDPAccountant::epsilon_to_rho(2.0) - 2.0).abs() < 1e-10);
}
// Display strings must match the serde wire names checked below.
#[test]
fn test_composition_method_display() {
assert_eq!(CompositionMethod::Naive.to_string(), "naive");
assert_eq!(CompositionMethod::Advanced.to_string(), "advanced");
assert_eq!(CompositionMethod::RenyiDP.to_string(), "renyi_dp");
assert_eq!(CompositionMethod::ZeroCDP.to_string(), "zcdp");
}
// Wire names of the renamed variants, and the default variant.
#[test]
fn test_composition_method_serde() {
let json = serde_json::to_string(&CompositionMethod::RenyiDP).unwrap();
assert_eq!(json, "\"renyi_dp\"");
let parsed: CompositionMethod = serde_json::from_str("\"zcdp\"").unwrap();
assert_eq!(parsed, CompositionMethod::ZeroCDP);
let default: CompositionMethod = Default::default();
assert_eq!(default, CompositionMethod::Naive);
}
// A JSON round-trip preserves epsilon, delta and description.
#[test]
fn test_mechanism_record_serde() {
let record = MechanismRecord::new(0.5, "test mechanism").with_delta(1e-5);
let json = serde_json::to_string(&record).unwrap();
let parsed: MechanismRecord = serde_json::from_str(&json).unwrap();
assert!((parsed.epsilon - 0.5).abs() < 1e-10);
assert!((parsed.delta - 1e-5).abs() < 1e-15);
assert_eq!(parsed.description, "test mechanism");
}
// The factory returns an accountant reporting the requested method.
#[test]
fn test_create_accountant_factory() {
let acc = create_accountant(CompositionMethod::Naive, 1.0);
assert_eq!(acc.method(), CompositionMethod::Naive);
let acc = create_accountant(CompositionMethod::RenyiDP, 1.0);
assert_eq!(acc.method(), CompositionMethod::RenyiDP);
let acc = create_accountant(CompositionMethod::ZeroCDP, 1.0);
assert_eq!(acc.method(), CompositionMethod::ZeroCDP);
}
// 100 queries at eps = 0.01: RDP must not be looser than basic composition (1e-10 slack
// for float rounding).
#[test]
fn test_rdp_tighter_than_naive_many_queries() {
let n_queries = 100;
let eps_per_query = 0.01;
let mut naive = NaiveAccountant::new(100.0);
let mut rdp = RenyiDPAccountant::new(100.0, 1e-5);
for i in 0..n_queries {
let record = MechanismRecord::new(eps_per_query, format!("q{}", i));
naive.record_mechanism(record.clone());
rdp.record_mechanism(record);
}
let naive_eps = naive.effective_epsilon();
let rdp_eps = rdp.effective_epsilon();
assert!(
rdp_eps <= naive_eps + 1e-10,
"RDP ({}) should not be worse than naive ({})",
rdp_eps,
naive_eps
);
}
// Repeated 0.1-eps queries must exhaust a 1.0 budget within 100 steps. The `i > 0`
// assertion checks that a fresh accountant does not start out exhausted.
#[test]
fn test_zcdp_exhaustion() {
let mut acc = ZeroCDPAccountant::new(1.0, 1e-5);
for i in 0..100 {
if acc.is_exhausted() {
assert!(i > 0, "Should not be exhausted immediately");
return;
}
acc.record_mechanism(MechanismRecord::new(0.1, format!("q{}", i)));
}
assert!(acc.is_exhausted());
}
}