use serde::{Deserialize, Serialize};
/// Summary statistics contributed by a single data source.
///
/// Instances are the inputs to `FederatedFingerprintProtocol::aggregate`.
/// The per-column vectors are read by index, position `i` being matched with
/// `column_names[i]`; a source whose vector is shorter than `column_names`
/// contributes nothing for the missing columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialFingerprint {
/// Identifier of the contributing source; quoted in aggregation error messages.
pub source_id: String,
/// Epsilon reported by this source. Aggregation rejects the source when this
/// exceeds `FederatedConfig::max_epsilon_per_source` and sums it into
/// `AggregatedFingerprint::total_epsilon`.
pub local_epsilon: f64,
/// Number of records behind these statistics. Must be non-zero for
/// aggregation; used as the weight in weighted averaging.
pub record_count: u64,
/// Column names; every source in one aggregation must list the same names
/// in the same order.
pub column_names: Vec<String>,
/// Per-column means.
pub means: Vec<f64>,
/// Per-column standard deviations.
pub stds: Vec<f64>,
/// Per-column minimums; only finite entries are considered during aggregation.
pub mins: Vec<f64>,
/// Per-column maximums; only finite entries are considered during aggregation.
pub maxs: Vec<f64>,
/// Correlation values. Aggregated only when every source supplies exactly
/// `column_names.len()` squared entries (presumably a flattened square
/// matrix — the layout is not fixed by this file; confirm with the producer).
pub correlations: Vec<f64>,
}
/// Strategy used to combine per-source means and standard deviations.
///
/// Serialized in snake_case (`weighted_average`, `median`, `trimmed_mean`).
/// Minimums and maximums are combined the same way under every method.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum AggregationMethod {
/// Average weighted by each source's `record_count` (the default). This is
/// the only method that also aggregates correlations.
#[default]
WeightedAverage,
/// Per-column median of the values reported by the sources.
Median,
/// Per-column mean after discarding roughly 10% of the sorted values from
/// each end.
TrimmedMean,
}
/// Settings that control validation and combination of partial fingerprints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedConfig {
/// Minimum number of partial fingerprints required; aggregation returns an
/// error when fewer are supplied.
pub min_sources: usize,
/// Upper bound on a single source's `local_epsilon`; a source exceeding it
/// causes aggregation to fail.
pub max_epsilon_per_source: f64,
/// How per-source means and standard deviations are combined.
pub aggregation_method: AggregationMethod,
}
impl Default for FederatedConfig {
    /// Returns the default configuration: at least two sources, a per-source
    /// epsilon cap of `5.0`, and record-count-weighted averaging.
    fn default() -> Self {
        let min_sources = 2;
        let max_epsilon_per_source = 5.0;
        let aggregation_method = AggregationMethod::WeightedAverage;
        Self {
            min_sources,
            max_epsilon_per_source,
            aggregation_method,
        }
    }
}
/// Result of combining several `PartialFingerprint`s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregatedFingerprint {
/// Column names, copied from the first source.
pub column_names: Vec<String>,
/// Aggregated per-column means, one entry per column.
pub means: Vec<f64>,
/// Aggregated per-column standard deviations, one entry per column.
pub stds: Vec<f64>,
/// Smallest finite minimum reported for each column; stays at
/// `f64::INFINITY` when no source reported a finite value.
pub mins: Vec<f64>,
/// Largest finite maximum reported for each column; stays at
/// `f64::NEG_INFINITY` when no source reported a finite value.
pub maxs: Vec<f64>,
/// Weighted-average correlations; empty unless the weighted-average method
/// is used and every source supplied a full set of correlation values.
pub correlations: Vec<f64>,
/// Sum of the sources' `record_count` values.
pub total_record_count: u64,
/// Sum of the sources' `local_epsilon` values.
pub total_epsilon: f64,
/// Number of partial fingerprints that were combined.
pub source_count: usize,
}
/// Validates and combines `PartialFingerprint`s according to a `FederatedConfig`.
#[derive(Debug, Clone)]
pub struct FederatedFingerprintProtocol {
// Validation thresholds and aggregation method applied by `aggregate`.
config: FederatedConfig,
}
impl FederatedFingerprintProtocol {
pub fn new(config: FederatedConfig) -> Self {
Self { config }
}
#[allow(clippy::too_many_arguments)]
pub fn create_partial(
source_id: &str,
columns: Vec<String>,
record_count: u64,
means: Vec<f64>,
stds: Vec<f64>,
mins: Vec<f64>,
maxs: Vec<f64>,
correlations: Vec<f64>,
epsilon: f64,
) -> PartialFingerprint {
PartialFingerprint {
source_id: source_id.to_string(),
local_epsilon: epsilon,
record_count,
column_names: columns,
means,
stds,
mins,
maxs,
correlations,
}
}
pub fn aggregate(
&self,
partials: &[PartialFingerprint],
) -> Result<AggregatedFingerprint, String> {
if partials.len() < self.config.min_sources {
return Err(format!(
"Need at least {} sources, got {}",
self.config.min_sources,
partials.len()
));
}
for p in partials {
if p.record_count == 0 {
return Err(format!("Source '{}' has zero records", p.source_id));
}
if p.local_epsilon > self.config.max_epsilon_per_source {
return Err(format!(
"Source '{}' epsilon {} exceeds max {}",
p.source_id, p.local_epsilon, self.config.max_epsilon_per_source
));
}
}
let first = &partials[0];
let n_cols = first.column_names.len();
for p in &partials[1..] {
if p.column_names.len() != n_cols {
return Err(format!(
"Column count mismatch: source '{}' has {} columns, expected {}",
p.source_id,
p.column_names.len(),
n_cols
));
}
for (i, name) in p.column_names.iter().enumerate() {
if name != &first.column_names[i] {
return Err(format!(
"Column name mismatch at index {}: source '{}' has '{}', expected '{}'",
i, p.source_id, name, first.column_names[i]
));
}
}
}
let total_record_count: u64 = partials.iter().map(|p| p.record_count).sum();
let total_epsilon: f64 = partials.iter().map(|p| p.local_epsilon).sum();
match self.config.aggregation_method {
AggregationMethod::WeightedAverage => {
self.aggregate_weighted(partials, n_cols, total_record_count, total_epsilon)
}
AggregationMethod::Median => {
self.aggregate_median(partials, n_cols, total_record_count, total_epsilon)
}
AggregationMethod::TrimmedMean => {
self.aggregate_trimmed_mean(partials, n_cols, total_record_count, total_epsilon)
}
}
}
fn aggregate_weighted(
&self,
partials: &[PartialFingerprint],
n_cols: usize,
total_record_count: u64,
total_epsilon: f64,
) -> Result<AggregatedFingerprint, String> {
if total_record_count == 0 {
return Err("Cannot aggregate fingerprints: total record count is zero".to_string());
}
let total_f = total_record_count as f64;
let mut agg_means = vec![0.0_f64; n_cols];
let mut agg_stds = vec![0.0_f64; n_cols];
let mut agg_mins = vec![f64::INFINITY; n_cols];
let mut agg_maxs = vec![f64::NEG_INFINITY; n_cols];
for p in partials {
let w = p.record_count as f64 / total_f;
for i in 0..n_cols {
if i < p.means.len() {
agg_means[i] += w * p.means[i];
}
if i < p.stds.len() {
agg_stds[i] += w * p.stds[i];
}
if i < p.mins.len() && p.mins[i].is_finite() && p.mins[i] < agg_mins[i] {
agg_mins[i] = p.mins[i];
}
if i < p.maxs.len() && p.maxs[i].is_finite() && p.maxs[i] > agg_maxs[i] {
agg_maxs[i] = p.maxs[i];
}
}
}
let corr_len = n_cols * n_cols;
let all_have_corr = partials.iter().all(|p| p.correlations.len() == corr_len);
let agg_corr = if all_have_corr && corr_len > 0 {
let mut corr = vec![0.0_f64; corr_len];
for p in partials {
let w = p.record_count as f64 / total_f;
for (j, val) in p.correlations.iter().enumerate() {
corr[j] += w * val;
}
}
corr
} else {
Vec::new()
};
Ok(AggregatedFingerprint {
column_names: partials[0].column_names.clone(),
means: agg_means,
stds: agg_stds,
mins: agg_mins,
maxs: agg_maxs,
correlations: agg_corr,
total_record_count,
total_epsilon,
source_count: partials.len(),
})
}
fn aggregate_median(
&self,
partials: &[PartialFingerprint],
n_cols: usize,
total_record_count: u64,
total_epsilon: f64,
) -> Result<AggregatedFingerprint, String> {
let mut agg_means = vec![0.0_f64; n_cols];
let mut agg_stds = vec![0.0_f64; n_cols];
let mut agg_mins = vec![f64::INFINITY; n_cols];
let mut agg_maxs = vec![f64::NEG_INFINITY; n_cols];
for i in 0..n_cols {
let mut col_means: Vec<f64> = partials
.iter()
.filter(|p| i < p.means.len())
.map(|p| p.means[i])
.collect();
let mut col_stds: Vec<f64> = partials
.iter()
.filter(|p| i < p.stds.len())
.map(|p| p.stds[i])
.collect();
col_means.sort_by(f64::total_cmp);
col_stds.sort_by(f64::total_cmp);
agg_means[i] = compute_median(&col_means);
agg_stds[i] = compute_median(&col_stds);
for p in partials {
if i < p.mins.len() && p.mins[i].is_finite() && p.mins[i] < agg_mins[i] {
agg_mins[i] = p.mins[i];
}
if i < p.maxs.len() && p.maxs[i].is_finite() && p.maxs[i] > agg_maxs[i] {
agg_maxs[i] = p.maxs[i];
}
}
}
Ok(AggregatedFingerprint {
column_names: partials[0].column_names.clone(),
means: agg_means,
stds: agg_stds,
mins: agg_mins,
maxs: agg_maxs,
correlations: Vec::new(),
total_record_count,
total_epsilon,
source_count: partials.len(),
})
}
fn aggregate_trimmed_mean(
&self,
partials: &[PartialFingerprint],
n_cols: usize,
total_record_count: u64,
total_epsilon: f64,
) -> Result<AggregatedFingerprint, String> {
let n = partials.len();
let trim_count = (n as f64 * 0.1).floor() as usize;
let mut agg_means = vec![0.0_f64; n_cols];
let mut agg_stds = vec![0.0_f64; n_cols];
let mut agg_mins = vec![f64::INFINITY; n_cols];
let mut agg_maxs = vec![f64::NEG_INFINITY; n_cols];
for i in 0..n_cols {
let mut col_means: Vec<f64> = partials
.iter()
.filter(|p| i < p.means.len())
.map(|p| p.means[i])
.collect();
let mut col_stds: Vec<f64> = partials
.iter()
.filter(|p| i < p.stds.len())
.map(|p| p.stds[i])
.collect();
col_means.sort_by(f64::total_cmp);
col_stds.sort_by(f64::total_cmp);
let trimmed_means = trim_slice(&col_means, trim_count);
let trimmed_stds = trim_slice(&col_stds, trim_count);
agg_means[i] = if trimmed_means.is_empty() {
0.0
} else {
trimmed_means.iter().sum::<f64>() / trimmed_means.len() as f64
};
agg_stds[i] = if trimmed_stds.is_empty() {
0.0
} else {
trimmed_stds.iter().sum::<f64>() / trimmed_stds.len() as f64
};
for p in partials {
if i < p.mins.len() && p.mins[i].is_finite() && p.mins[i] < agg_mins[i] {
agg_mins[i] = p.mins[i];
}
if i < p.maxs.len() && p.maxs[i].is_finite() && p.maxs[i] > agg_maxs[i] {
agg_maxs[i] = p.maxs[i];
}
}
}
Ok(AggregatedFingerprint {
column_names: partials[0].column_names.clone(),
means: agg_means,
stds: agg_stds,
mins: agg_mins,
maxs: agg_maxs,
correlations: Vec::new(),
total_record_count,
total_epsilon,
source_count: partials.len(),
})
}
}
/// Returns the median of a slice that is already sorted ascending.
///
/// An even-length slice yields the mean of its two middle values; an empty
/// slice yields `0.0`.
fn compute_median(sorted: &[f64]) -> f64 {
    // Splitting at len / 2 gives equal halves for even lengths, and puts the
    // middle element at the front of the upper half for odd lengths.
    let (below, above) = sorted.split_at(sorted.len() / 2);
    match (below.last(), above.first()) {
        (_, None) => 0.0,
        (Some(&lo), Some(&hi)) if below.len() == above.len() => (lo + hi) / 2.0,
        (_, Some(&middle)) => middle,
    }
}
/// Drops `count` elements from each end of `sorted`.
///
/// The slice is returned untouched when trimming would leave nothing, i.e.
/// when `count * 2` is at least its length.
fn trim_slice(sorted: &[f64], count: usize) -> &[f64] {
    let len = sorted.len();
    if len > count * 2 {
        &sorted[count..len - count]
    } else {
        sorted
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing aggregated floating-point values.
    const TOLERANCE: f64 = 1e-10;

    /// Builds a two-column (`amount`, `qty`) partial without mins, maxs or correlations.
    fn make_partial(
        source_id: &str,
        record_count: u64,
        means: Vec<f64>,
        stds: Vec<f64>,
        epsilon: f64,
    ) -> PartialFingerprint {
        FederatedFingerprintProtocol::create_partial(
            source_id,
            vec!["amount".to_string(), "qty".to_string()],
            record_count,
            means,
            stds,
            Vec::new(),
            Vec::new(),
            Vec::new(),
            epsilon,
        )
    }

    /// Protocol built from the default configuration.
    fn default_protocol() -> FederatedFingerprintProtocol {
        FederatedFingerprintProtocol::new(FederatedConfig::default())
    }

    /// Whether `actual` is within `TOLERANCE` of `expected`.
    fn is_close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Whether `result` is an error whose message contains `needle`.
    fn fails_with(result: &Result<AggregatedFingerprint, String>, needle: &str) -> bool {
        matches!(result, Err(message) if message.contains(needle))
    }

    #[test]
    fn test_three_sources_aggregate_correctly() {
        let protocol = FederatedFingerprintProtocol::new(FederatedConfig {
            min_sources: 2,
            max_epsilon_per_source: 5.0,
            aggregation_method: AggregationMethod::WeightedAverage,
        });
        let sources = [
            make_partial("site-a", 1000, vec![100.0, 5.0], vec![10.0, 1.0], 1.0),
            make_partial("site-b", 2000, vec![200.0, 10.0], vec![20.0, 2.0], 0.5),
            make_partial("site-c", 1000, vec![300.0, 15.0], vec![30.0, 3.0], 0.8),
        ];

        let aggregated = protocol.aggregate(&sources).expect("should aggregate");

        assert_eq!(aggregated.source_count, 3);
        assert_eq!(aggregated.total_record_count, 4000);
        assert!(is_close(aggregated.means[0], 200.0));
        assert!(is_close(aggregated.means[1], 10.0));
    }

    #[test]
    fn test_weights_proportional_to_record_count() {
        let sources = [
            make_partial("site-a", 1000, vec![100.0, 1.0], vec![10.0, 1.0], 1.0),
            make_partial("site-b", 3000, vec![200.0, 3.0], vec![20.0, 3.0], 1.0),
        ];

        let aggregated = default_protocol()
            .aggregate(&sources)
            .expect("should aggregate");

        assert!(is_close(aggregated.means[0], 175.0));
        assert!(is_close(aggregated.means[1], 2.5));
    }

    #[test]
    fn test_total_epsilon_sums_correctly() {
        let sources = [
            make_partial("a", 100, vec![1.0, 2.0], vec![0.1, 0.2], 0.5),
            make_partial("b", 200, vec![3.0, 4.0], vec![0.3, 0.4], 1.5),
            make_partial("c", 300, vec![5.0, 6.0], vec![0.5, 0.6], 2.0),
        ];

        let aggregated = default_protocol()
            .aggregate(&sources)
            .expect("should aggregate");

        assert!(is_close(aggregated.total_epsilon, 4.0));
    }

    #[test]
    fn test_empty_sources_rejected() {
        let protocol = FederatedFingerprintProtocol::new(FederatedConfig {
            min_sources: 2,
            ..FederatedConfig::default()
        });
        let lone = make_partial("a", 100, vec![1.0, 2.0], vec![0.1, 0.2], 1.0);

        let too_few = protocol.aggregate(&[lone]);
        assert!(fails_with(&too_few, "Need at least 2 sources"));

        let none_at_all = protocol.aggregate(&[]);
        assert!(none_at_all.is_err());
    }

    #[test]
    fn test_zero_record_count_rejected() {
        let sources = [
            make_partial("a", 100, vec![1.0, 2.0], vec![0.1, 0.2], 1.0),
            make_partial("b", 0, vec![3.0, 4.0], vec![0.3, 0.4], 1.0),
        ];

        let outcome = default_protocol().aggregate(&sources);

        assert!(fails_with(&outcome, "zero records"));
    }

    #[test]
    fn test_single_source_works() {
        let protocol = FederatedFingerprintProtocol::new(FederatedConfig {
            min_sources: 1,
            ..FederatedConfig::default()
        });
        let only = make_partial("only", 500, vec![42.0, 7.0], vec![5.0, 1.0], 1.0);

        let aggregated = protocol
            .aggregate(&[only])
            .expect("single source should work");

        assert_eq!(aggregated.source_count, 1);
        assert_eq!(aggregated.total_record_count, 500);
        assert!(is_close(aggregated.means[0], 42.0));
        assert!(is_close(aggregated.means[1], 7.0));
        assert!(is_close(aggregated.total_epsilon, 1.0));
    }

    #[test]
    fn test_epsilon_per_source_limit() {
        let protocol = FederatedFingerprintProtocol::new(FederatedConfig {
            min_sources: 2,
            max_epsilon_per_source: 1.0,
            ..FederatedConfig::default()
        });
        let sources = [
            make_partial("a", 100, vec![1.0, 2.0], vec![0.1, 0.2], 0.5),
            make_partial("b", 200, vec![3.0, 4.0], vec![0.3, 0.4], 2.0),
        ];

        let outcome = protocol.aggregate(&sources);

        assert!(fails_with(&outcome, "exceeds max"));
    }

    #[test]
    fn test_column_mismatch_rejected() {
        let protocol = FederatedFingerprintProtocol::new(FederatedConfig {
            min_sources: 2,
            ..FederatedConfig::default()
        });
        let reference = make_partial("a", 100, vec![1.0, 2.0], vec![0.1, 0.2], 1.0);
        let mut renamed = make_partial("b", 200, vec![3.0, 4.0], vec![0.3, 0.4], 1.0);
        renamed.column_names = vec!["amount".to_string(), "price".to_string()];

        let outcome = protocol.aggregate(&[reference, renamed]);

        assert!(fails_with(&outcome, "Column name mismatch"));
    }

    #[test]
    fn test_median_aggregation() {
        let protocol = FederatedFingerprintProtocol::new(FederatedConfig {
            min_sources: 1,
            max_epsilon_per_source: 5.0,
            aggregation_method: AggregationMethod::Median,
        });
        let sources = [
            make_partial("a", 100, vec![10.0, 1.0], vec![1.0, 0.1], 1.0),
            make_partial("b", 100, vec![20.0, 2.0], vec![2.0, 0.2], 1.0),
            make_partial("c", 100, vec![30.0, 3.0], vec![3.0, 0.3], 1.0),
        ];

        let aggregated = protocol.aggregate(&sources).expect("should aggregate");

        assert!(is_close(aggregated.means[0], 20.0));
        assert!(is_close(aggregated.means[1], 2.0));
    }

    #[test]
    fn test_serde_roundtrip() {
        let original = make_partial("test", 100, vec![1.0, 2.0], vec![0.1, 0.2], 1.0);

        let json = serde_json::to_string(&original).expect("serialize");
        let restored: PartialFingerprint = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(restored.source_id, "test");
        assert_eq!(restored.record_count, 100);
    }
}