pub mod crypto;
pub mod privacy;
pub use crypto::*;
pub use privacy::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::{Tensor};
use trustformers_core::error::{CoreError, Result};
/// Top-level configuration for [`FederatedLearningV2`], grouping one sub-configuration per concern.
///
/// `FederatedLearningV2::new` reads `accounting_config` to build the privacy budget tracker, and
/// `FederatedLearningV2::start_training_round` reads `privacy_config` and `training_config`.
/// Field order is part of the derived `Debug`/serde output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedLearningV2Config {
/// Privacy parameters; `start_training_round` splits its `epsilon` and `delta` evenly over
/// `training_config.num_rounds`. (Type presumably comes from the `privacy` module — confirm.)
pub privacy_config: AdvancedPrivacyConfig,
/// Cryptographic settings. (Type presumably comes from the `crypto` module — confirm.)
pub crypto_config: CryptographicConfig,
/// Secure-aggregation settings.
pub aggregation_config: SecureAggregationConfig,
/// Transport security, compression, bandwidth adaptation and message routing settings.
pub communication_config: CommunicationProtocolConfig,
/// Round count, client sampling and personalization settings.
pub training_config: FederatedTrainingConfig,
/// Attack-detection and robust-aggregation settings.
pub security_config: SecurityConfig,
/// Accounting limits (`max_epsilon`, `max_delta`, `accounting_method`) used by
/// `FederatedLearningV2::new` to construct the privacy budget tracker.
pub accounting_config: PrivacyAccountingConfig,
}
/// Settings for securely aggregating participant updates.
///
/// NOTE(review): nothing in this file reads these fields; presumably they are consumed by
/// aggregation code elsewhere — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureAggregationConfig {
/// Minimum participant count (default 3).
pub min_participants: u32,
/// Maximum participant count (default 1000). Not enforced by `FederatedLearningV2::add_participant`.
pub max_participants: u32,
/// Dropout threshold (default 0.1); presumably a fraction of participants — confirm.
pub dropout_threshold: f64,
/// Whether Byzantine fault tolerance is requested (default `true`).
/// NOTE(review): `SecurityConfig` has a field with the same name; which one is authoritative
/// is not defined in this file.
pub byzantine_fault_tolerance: bool,
/// How participant updates are combined (default [`AggregationFunction::FederatedAveraging`]).
pub aggregation_function: AggregationFunction,
/// Verification methods to apply (default: only [`VerificationMethod::DigitalSignatures`]).
pub verification_methods: Vec<VerificationMethod>,
}
/// Strategy for combining participant updates, selected via
/// [`SecureAggregationConfig::aggregation_function`].
///
/// Variant order is part of the derived serde representation for formats that encode
/// variants by index, so append new variants rather than reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationFunction {
FederatedAveraging,
WeightedAveraging,
MedianAggregation,
TrimmedMean,
CoordinateWiseMedian,
GeometricMedian,
}
/// Verification mechanism listed in [`SecureAggregationConfig::verification_methods`].
///
/// Variant order is part of the derived serde representation for index-based formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum VerificationMethod {
DigitalSignatures,
ZeroKnowledgeProofs,
CommitmentSchemes,
HashBased,
}
/// Communication-layer settings, one sub-configuration per aspect.
///
/// NOTE(review): nothing in this file reads these fields — presumably consumed elsewhere; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationProtocolConfig {
/// TLS usage, TLS version string and certificate-validation flag.
pub transport_security: TransportSecurity,
/// Compression algorithm, level and adaptivity flag.
pub compression: CompressionConfig,
/// Batching, QoS priority and congestion-control flags.
pub bandwidth_adaptation: BandwidthAdaptationConfig,
/// Gossip flag, gossip fanout and message TTL.
pub message_routing: MessageRoutingConfig,
}
/// Transport-layer security options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportSecurity {
/// Whether TLS is used (default `true`).
pub use_tls: bool,
/// TLS version as a free-form string (default `"1.3"`); its format is not validated in this file.
pub tls_version: String,
/// Whether peer certificates are validated (default `true`).
pub certificate_validation: bool,
}
/// Payload compression options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionConfig {
/// Compression algorithm (default [`CompressionAlgorithm::Gzip`]).
pub algorithm: CompressionAlgorithm,
/// Compression level (default 6).
/// NOTE(review): the valid range depends on the algorithm and is not checked in this file.
pub compression_level: u8,
/// Whether compression adapts at runtime (default `true`).
pub adaptive_compression: bool,
}
/// Compression algorithm selector used by [`CompressionConfig::algorithm`].
///
/// Variant order is part of the derived serde representation for index-based formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompressionAlgorithm {
None,
Gzip,
Brotli,
Zstd,
}
/// Bandwidth-adaptation options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BandwidthAdaptationConfig {
/// Whether batch sizes adapt (default `true`).
pub adaptive_batching: bool,
/// QoS priority (default 3).
/// NOTE(review): scale and direction (is higher more important?) are not defined here — confirm.
pub qos_priority: u8,
/// Whether congestion control is enabled (default `true`).
pub congestion_control: bool,
}
/// Message-routing options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageRoutingConfig {
/// Whether gossip-based routing is used (default `false`).
pub use_gossip: bool,
/// Gossip fanout (default 3).
pub gossip_fanout: u32,
/// Message time-to-live (default 100).
/// NOTE(review): the unit (hops vs. time) is not established in this file — confirm.
pub message_ttl: u32,
}
/// Training-loop settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedTrainingConfig {
/// Planned number of rounds (default 100). `FederatedLearningV2::start_training_round`
/// divides the privacy budget by this value, so it should be non-zero.
pub num_rounds: u32,
/// How clients are picked for a round (default [`ClientSamplingStrategy::Random`]).
pub client_sampling: ClientSamplingStrategy,
/// Update frequency (default 1); its unit is not defined in this file — confirm.
pub update_frequency: u32,
/// Personalization settings.
pub personalization: PersonalizationConfig,
}
/// Client-sampling strategy used by [`FederatedTrainingConfig::client_sampling`].
///
/// Variant order is part of the derived serde representation for index-based formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClientSamplingStrategy {
Random,
RoundRobin,
PerformanceBased,
ResourceAware,
}
/// Per-client personalization settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonalizationConfig {
/// Whether personalization is enabled (default `false`).
pub enabled: bool,
/// Personalization approach (default [`PersonalizationStrategy::LocalFineTuning`]).
pub strategy: PersonalizationStrategy,
/// Local update ratio (default 0.1); presumably a fraction in [0, 1] — not checked here.
pub local_update_ratio: f64,
}
/// Personalization approach used by [`PersonalizationConfig::strategy`].
///
/// Variant order is part of the derived serde representation for index-based formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PersonalizationStrategy {
LocalFineTuning,
MetaLearning,
MultiTaskLearning,
TransferLearning,
}
/// Security and robustness settings.
///
/// NOTE(review): nothing in this file reads these fields — presumably consumed elsewhere; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
/// Whether attack detection is enabled (default `true`).
pub attack_detection_enabled: bool,
/// Whether Byzantine fault tolerance is requested (default `true`).
/// NOTE(review): `SecureAggregationConfig` has a field with the same name; which one is
/// authoritative is not defined in this file.
pub byzantine_fault_tolerance: bool,
/// Robust aggregation methods (default: only [`RobustAggregationMethod::TrimmedMean`]).
pub robust_aggregation_methods: Vec<RobustAggregationMethod>,
/// Thresholds for anomaly detection.
pub anomaly_detection_thresholds: AnomalyDetectionThresholds,
}
/// Robust aggregation method listed in [`SecurityConfig::robust_aggregation_methods`].
///
/// Variant order is part of the derived serde representation for index-based formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RobustAggregationMethod {
TrimmedMean,
Median,
Krum,
Bulyan,
FoolsGold,
}
/// Numeric thresholds for anomaly detection.
///
/// NOTE(review): whether each threshold is an upper or a lower bound is not defined in this
/// file — confirm against the detector that reads them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectionThresholds {
/// Gradient-norm threshold (default 10.0).
pub gradient_norm_threshold: f64,
/// Model-accuracy threshold (default 0.5).
pub model_accuracy_threshold: f64,
/// Participation-rate threshold (default 0.8).
pub participation_rate_threshold: f64,
}
impl Default for FederatedLearningV2Config {
fn default() -> Self {
Self {
privacy_config: AdvancedPrivacyConfig::default(),
crypto_config: CryptographicConfig::default(),
aggregation_config: SecureAggregationConfig::default(),
communication_config: CommunicationProtocolConfig::default(),
training_config: FederatedTrainingConfig::default(),
security_config: SecurityConfig::default(),
accounting_config: PrivacyAccountingConfig::default(),
}
}
}
impl Default for SecureAggregationConfig {
fn default() -> Self {
Self {
min_participants: 3,
max_participants: 1000,
dropout_threshold: 0.1,
byzantine_fault_tolerance: true,
aggregation_function: AggregationFunction::FederatedAveraging,
verification_methods: vec![VerificationMethod::DigitalSignatures],
}
}
}
impl Default for CommunicationProtocolConfig {
fn default() -> Self {
Self {
transport_security: TransportSecurity::default(),
compression: CompressionConfig::default(),
bandwidth_adaptation: BandwidthAdaptationConfig::default(),
message_routing: MessageRoutingConfig::default(),
}
}
}
impl Default for TransportSecurity {
fn default() -> Self {
Self {
use_tls: true,
tls_version: "1.3".to_string(),
certificate_validation: true,
}
}
}
impl Default for CompressionConfig {
fn default() -> Self {
Self {
algorithm: CompressionAlgorithm::Gzip,
compression_level: 6,
adaptive_compression: true,
}
}
}
impl Default for BandwidthAdaptationConfig {
fn default() -> Self {
Self {
adaptive_batching: true,
qos_priority: 3,
congestion_control: true,
}
}
}
impl Default for MessageRoutingConfig {
fn default() -> Self {
Self {
use_gossip: false,
gossip_fanout: 3,
message_ttl: 100,
}
}
}
impl Default for FederatedTrainingConfig {
fn default() -> Self {
Self {
num_rounds: 100,
client_sampling: ClientSamplingStrategy::Random,
update_frequency: 1,
personalization: PersonalizationConfig::default(),
}
}
}
impl Default for PersonalizationConfig {
fn default() -> Self {
Self {
enabled: false,
strategy: PersonalizationStrategy::LocalFineTuning,
local_update_ratio: 0.1,
}
}
}
impl Default for SecurityConfig {
fn default() -> Self {
Self {
attack_detection_enabled: true,
byzantine_fault_tolerance: true,
robust_aggregation_methods: vec![RobustAggregationMethod::TrimmedMean],
anomaly_detection_thresholds: AnomalyDetectionThresholds::default(),
}
}
}
impl Default for AnomalyDetectionThresholds {
fn default() -> Self {
Self {
gradient_norm_threshold: 10.0,
model_accuracy_threshold: 0.5,
participation_rate_threshold: 0.8,
}
}
}
/// Coordinator state for a federated-learning run: configuration, privacy budget tracking,
/// participant keys and the registry of participants.
///
/// Fields are dropped in declaration order; keep the order stable if any field's `Drop`
/// has side effects.
pub struct FederatedLearningV2 {
/// Configuration passed to `new`.
config: FederatedLearningV2Config,
/// Privacy budget tracker, sized from `config.accounting_config` in `new`.
privacy_budget_tracker: PrivacyBudgetTracker,
/// Key manager; `remove_participant` asks it to drop the removed participant's key.
key_manager: CryptographicKeyManager,
/// Registered participants, keyed by `ParticipantInfo::id`.
participants: HashMap<String, ParticipantInfo>,
/// Round counter; 0 after `new`, incremented by `start_training_round`.
round_number: u32,
}
/// Registry entry describing one participant.
#[derive(Debug, Clone)]
pub struct ParticipantInfo {
/// Identifier. `FederatedLearningV2::add_participant` uses it as the map key, so adding a
/// second entry with the same id replaces the first.
pub id: String,
/// Public key bytes (encoding not defined in this file — confirm).
pub public_key: Vec<u8>,
/// Wall-clock time the participant was last seen; supplied by the caller.
pub last_seen: std::time::SystemTime,
/// Trust score; its range is not enforced in this file.
pub trust_score: f64,
/// History of per-contribution scores; their meaning is not defined in this file.
pub contribution_history: Vec<f64>,
}
impl FederatedLearningV2 {
    /// Creates a coordinator for `config` with no participants and the round counter at 0.
    ///
    /// The privacy budget tracker is built from `config.accounting_config`
    /// (`max_epsilon`, `max_delta`, `accounting_method`).
    pub fn new(config: FederatedLearningV2Config) -> Self {
        let privacy_budget_tracker = PrivacyBudgetTracker::new(
            config.accounting_config.max_epsilon,
            config.accounting_config.max_delta,
            config.accounting_config.accounting_method,
        );
        Self {
            config,
            privacy_budget_tracker,
            key_manager: CryptographicKeyManager::new(),
            participants: HashMap::new(),
            round_number: 0,
        }
    }

    /// Registers `participant_info` under its `id`, replacing any existing entry with that id.
    pub fn add_participant(&mut self, participant_info: ParticipantInfo) {
        self.participants.insert(participant_info.id.clone(), participant_info);
    }

    /// Removes the participant with `participant_id` from the registry and asks the key
    /// manager to drop its key. Removing an unknown id leaves the registry unchanged.
    pub fn remove_participant(&mut self, participant_id: &str) {
        self.participants.remove(participant_id);
        self.key_manager.remove_participant_key(participant_id);
    }

    /// Starts the next training round, charging one round's share of the privacy budget.
    ///
    /// The per-round cost is `privacy_config.epsilon / training_config.num_rounds` and
    /// `privacy_config.delta / training_config.num_rounds`. The round counter is advanced only
    /// after that cost has been consumed successfully. A rejected round therefore leaves
    /// [`get_round_number`](Self::get_round_number) unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - `training_config.num_rounds` is 0;
    /// - the budget tracker reports that the per-round cost cannot be afforded;
    /// - the tracker's `consume` call fails.
    pub fn start_training_round(&mut self) -> Result<()> {
        // NOTE(review): `TrustformersError` is not in this file's visible `use` list; presumably
        // it reaches scope through the `crypto`/`privacy` glob re-exports — confirm.
        let num_rounds = self.config.training_config.num_rounds;
        if num_rounds == 0 {
            // Dividing the budget by zero would hand inf/NaN costs to the tracker.
            return Err(TrustformersError::runtime_error(
                "num_rounds must be greater than zero".into(),
            )
            .into());
        }
        let rounds = f64::from(num_rounds);
        let epsilon_per_round = self.config.privacy_config.epsilon / rounds;
        let delta_per_round = self.config.privacy_config.delta / rounds;
        if !self.privacy_budget_tracker.can_consume(epsilon_per_round, delta_per_round) {
            return Err(TrustformersError::runtime_error("Privacy budget exceeded".into()).into());
        }
        self.privacy_budget_tracker
            .consume(epsilon_per_round, delta_per_round)
            .map_err(|e| TrustformersError::runtime_error(e))?;
        // Count the round only once its privacy cost has actually been charged.
        self.round_number += 1;
        Ok(())
    }

    /// Number of rounds successfully started so far.
    pub fn get_round_number(&self) -> u32 {
        self.round_number
    }

    /// Number of registered participants.
    pub fn get_participant_count(&self) -> usize {
        self.participants.len()
    }

    /// Remaining `(epsilon, delta)` budget as reported by the tracker.
    pub fn get_privacy_budget_status(&self) -> (f64, f64) {
        self.privacy_budget_tracker.remaining_budget()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a coordinator from the default configuration.
    fn default_coordinator() -> FederatedLearningV2 {
        FederatedLearningV2::new(FederatedLearningV2Config::default())
    }

    /// Sample participant used by the registry tests.
    fn sample_participant() -> ParticipantInfo {
        ParticipantInfo {
            id: String::from("participant1"),
            public_key: vec![1, 2, 3, 4],
            last_seen: std::time::SystemTime::now(),
            trust_score: 0.9,
            contribution_history: vec![0.8, 0.85, 0.9],
        }
    }

    #[test]
    fn test_federated_learning_v2_creation() {
        let fl = default_coordinator();
        assert_eq!(fl.get_round_number(), 0);
        assert_eq!(fl.get_participant_count(), 0);
    }

    #[test]
    fn test_participant_management() {
        let mut fl = default_coordinator();

        fl.add_participant(sample_participant());
        assert_eq!(fl.get_participant_count(), 1);

        fl.remove_participant("participant1");
        assert_eq!(fl.get_participant_count(), 0);
    }

    #[test]
    fn test_training_round_privacy_budget() {
        let mut fl = default_coordinator();
        (0..10).for_each(|_| assert!(fl.start_training_round().is_ok()));
        assert_eq!(fl.get_round_number(), 10);

        let (eps_left, delta_left) = fl.get_privacy_budget_status();
        assert!(eps_left >= 0.0);
        assert!(delta_left >= 0.0);
    }
}