use crate::ModelConfig;
use serde::{Deserialize, Serialize};
/// Top-level configuration for federated training.
///
/// Combines the base [`ModelConfig`] with participant, round and epoch counts
/// and the privacy, aggregation, communication, security and personalization
/// sections declared in this file. The serde derives make the field names part
/// of the serialized representation, so renaming a field changes the format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedConfig {
/// Configuration of the underlying model, imported from the crate root.
pub base_config: ModelConfig,
/// Number of participants; the `Default` impl sets 10.
/// NOTE(review): presumably the total number of federated clients — confirm.
pub num_participants: usize,
/// Number of communication rounds; the `Default` impl sets 100.
pub communication_rounds: usize,
/// Number of local epochs; the `Default` impl sets 5.
/// NOTE(review): presumably epochs each participant runs per round — confirm.
pub local_epochs: usize,
/// Minimum participant count; the `Default` impl sets 5.
/// NOTE(review): the visible code never checks this against
/// `num_participants`; any validation must live in the consumer.
pub min_participants: usize,
/// Differential-privacy settings (see [`PrivacyConfig`]).
pub privacy_config: PrivacyConfig,
/// Aggregation strategy selector; the `Default` impl picks
/// [`AggregationStrategy::FederatedAveraging`].
pub aggregation_strategy: AggregationStrategy,
/// Compression, protocol and timeout settings (see [`CommunicationConfig`]).
pub communication_config: CommunicationConfig,
/// Encryption, verification, certificate and authentication settings
/// (see [`SecurityConfig`]).
pub security_config: SecurityConfig,
/// Per-participant personalization settings (see [`PersonalizationConfig`]).
pub personalization_config: PersonalizationConfig,
}
impl Default for FederatedConfig {
fn default() -> Self {
Self {
base_config: ModelConfig::default(),
num_participants: 10,
communication_rounds: 100,
local_epochs: 5,
min_participants: 5,
privacy_config: PrivacyConfig::default(),
aggregation_strategy: AggregationStrategy::FederatedAveraging,
communication_config: CommunicationConfig::default(),
security_config: SecurityConfig::default(),
personalization_config: PersonalizationConfig::default(),
}
}
}
/// Privacy-related settings of the federated configuration.
///
/// Serde-serializable; field names are part of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrivacyConfig {
/// Switch for differential privacy; the `Default` impl sets `true`.
pub enable_differential_privacy: bool,
/// Overall epsilon value; the `Default` impl sets 1.0.
/// NOTE(review): presumably the total differential-privacy budget — confirm.
pub epsilon: f64,
/// Delta value; the `Default` impl sets 1e-5.
/// NOTE(review): presumably the delta of (epsilon, delta)-DP — confirm.
pub delta: f64,
/// Which [`NoiseMechanism`] to use; the `Default` impl picks `Gaussian`.
pub noise_mechanism: NoiseMechanism,
/// Clipping threshold; the `Default` impl sets 1.0.
/// NOTE(review): presumably a norm bound applied to updates before noise is
/// added — the clipping code is not visible here.
pub clipping_threshold: f64,
/// Local epsilon; the `Default` impl sets 0.5.
pub local_epsilon: f64,
/// Global epsilon; the `Default` impl sets 0.5.
/// NOTE(review): the default local and global values add up to the default
/// `epsilon`; whether that relation is required is not enforced here.
pub global_epsilon: f64,
}
impl Default for PrivacyConfig {
fn default() -> Self {
Self {
enable_differential_privacy: true,
epsilon: 1.0,
delta: 1e-5,
noise_mechanism: NoiseMechanism::Gaussian,
clipping_threshold: 1.0,
local_epsilon: 0.5,
global_epsilon: 0.5,
}
}
}
/// Selector for the noise mechanism named in [`PrivacyConfig`].
///
/// This enum only names the choices; the mechanisms themselves are implemented
/// elsewhere. Variant names are part of the serde representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NoiseMechanism {
/// Gaussian noise; the choice made by `PrivacyConfig::default()`.
Gaussian,
/// Laplace noise.
Laplace,
/// NOTE(review): presumably the exponential mechanism — confirm.
Exponential,
/// NOTE(review): presumably the sparse vector technique — confirm.
SparseVector,
}
/// Selector for how participant updates are aggregated.
///
/// Only the names are declared here; the aggregation logic lives elsewhere.
/// Variant names are part of the serde representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationStrategy {
/// Federated averaging; the choice made by `FederatedConfig::default()`.
FederatedAveraging,
/// Weighted averaging. NOTE(review): the weighting basis is not visible here.
WeightedAveraging,
/// Secure aggregation.
SecureAggregation,
/// Robust aggregation. NOTE(review): presumably tolerant of outlier or
/// malicious updates — confirm against the implementation.
RobustAggregation,
/// Personalized aggregation.
PersonalizedAggregation,
/// Hierarchical aggregation.
HierarchicalAggregation,
}
/// Communication settings: compression, quantization, sparsification,
/// protocol choice, batching and timeout.
///
/// Serde-serializable; field names are part of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationConfig {
/// Switch for compression; the `Default` impl sets `true`.
pub enable_compression: bool,
/// Compression ratio; the `Default` impl sets 0.1.
/// NOTE(review): whether this means compressed/original size or the inverse
/// is not visible here — confirm.
pub compression_ratio: f64,
/// Quantization width; the `Default` impl sets 8.
/// NOTE(review): presumably in bits, as the name says; valid range unchecked.
pub quantization_bits: u8,
/// Switch for sparsification; the `Default` impl sets `true`.
pub enable_sparsification: bool,
/// Sparsity threshold; the `Default` impl sets 0.01.
/// NOTE(review): presumably a magnitude cut-off below which values are
/// dropped — confirm.
pub sparsity_threshold: f64,
/// Which [`CommunicationProtocol`] to use; the `Default` impl picks
/// `Synchronous`.
pub protocol: CommunicationProtocol,
/// Switch for batched communication; the `Default` impl sets `true`.
pub batch_communication: bool,
/// Timeout in seconds; the `Default` impl sets 300.
pub timeout_seconds: u64,
}
impl Default for CommunicationConfig {
fn default() -> Self {
Self {
enable_compression: true,
compression_ratio: 0.1,
quantization_bits: 8,
enable_sparsification: true,
sparsity_threshold: 0.01,
protocol: CommunicationProtocol::Synchronous,
batch_communication: true,
timeout_seconds: 300,
}
}
}
/// Selector for the communication protocol named in [`CommunicationConfig`].
///
/// Variant and field names are part of the serde representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommunicationProtocol {
/// Synchronous protocol; the choice made by `CommunicationConfig::default()`.
Synchronous,
/// Asynchronous protocol.
Asynchronous,
/// Semi-synchronous protocol carrying a `staleness_bound`.
/// NOTE(review): presumably the maximum staleness (in rounds) tolerated for
/// an update — the unit is not visible here; confirm.
SemiSynchronous { staleness_bound: usize },
/// Peer-to-peer protocol.
PeerToPeer,
}
/// Security settings: homomorphic encryption, secure MPC, verification
/// mechanisms, certificates and authentication.
///
/// Serde-serializable; field names are part of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
/// Switch for homomorphic encryption; the `Default` impl sets `false`.
pub enable_homomorphic_encryption: bool,
/// Which [`EncryptionScheme`] to use; the `Default` impl picks `CKKS`.
/// NOTE(review): presumably only consulted when
/// `enable_homomorphic_encryption` is `true` — confirm.
pub encryption_scheme: EncryptionScheme,
/// Switch for secure multi-party computation; the `Default` impl sets `false`.
pub enable_secure_mpc: bool,
/// List of verification mechanisms; the `Default` impl supplies a single
/// `DigitalSignature` entry.
pub verification_mechanisms: Vec<VerificationMechanism>,
/// Certificate settings (see [`CertificateConfig`]).
pub certificate_config: CertificateConfig,
/// Authentication settings (see [`AuthenticationConfig`]).
pub authentication_config: AuthenticationConfig,
}
impl Default for SecurityConfig {
fn default() -> Self {
Self {
enable_homomorphic_encryption: false,
encryption_scheme: EncryptionScheme::CKKS,
enable_secure_mpc: false,
verification_mechanisms: vec![VerificationMechanism::DigitalSignature],
certificate_config: CertificateConfig::default(),
authentication_config: AuthenticationConfig::default(),
}
}
}
/// Selector for the encryption scheme named in [`SecurityConfig`].
///
/// Variant names are part of the serde representation.
/// NOTE(review): `CKKS` and `BFV` are names of homomorphic-encryption schemes,
/// whereas `SEAL` and `HElib` are names of libraries that implement such
/// schemes; confirm what selecting the latter two means to the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EncryptionScheme {
/// CKKS; the choice made by `SecurityConfig::default()`.
CKKS,
/// BFV.
BFV,
/// SEAL (see the review note on the enum).
SEAL,
/// HElib (see the review note on the enum).
HElib,
}
/// Verification mechanisms that can be listed in
/// `SecurityConfig::verification_mechanisms`.
///
/// Only the names are declared here; the checks themselves live elsewhere.
/// Variant names are part of the serde representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VerificationMechanism {
/// Digital signatures; the single entry used by `SecurityConfig::default()`.
DigitalSignature,
/// Zero-knowledge proofs.
ZeroKnowledgeProof,
/// Commitment schemes.
CommitmentScheme,
/// Hash-based verification.
HashVerification,
}
/// Certificate settings used by [`SecurityConfig`].
///
/// Serde-serializable; field names are part of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertificateConfig {
/// Certificate-authority endpoint; the `Default` impl sets the placeholder
/// `"https://ca.example.com"`, which a real deployment presumably overrides.
pub ca_endpoint: String,
/// Certificate validity in days; the `Default` impl sets 365.
pub validity_days: u32,
/// Key length; the `Default` impl sets 2048.
/// NOTE(review): presumably in bits — the key algorithm is not visible here.
pub key_length: u32,
/// Switch for certificate-chain validation; the `Default` impl sets `true`.
pub validate_chain: bool,
}
impl Default for CertificateConfig {
fn default() -> Self {
Self {
ca_endpoint: "https://ca.example.com".to_string(),
validity_days: 365,
key_length: 2048,
validate_chain: true,
}
}
}
/// Authentication settings used by [`SecurityConfig`].
///
/// Serde-serializable; field names are part of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthenticationConfig {
/// Which [`AuthenticationMethod`] to use; the `Default` impl picks `OAuth2`.
pub method: AuthenticationMethod,
/// Token lifetime in hours; the `Default` impl sets 24.
pub token_expiry_hours: u32,
/// Switch for multi-factor authentication; the `Default` impl sets `false`.
pub enable_mfa: bool,
/// Identity-provider location; the `Default` impl sets the placeholder
/// `"https://idp.example.com"`, which a real deployment presumably overrides.
pub identity_provider: String,
}
impl Default for AuthenticationConfig {
fn default() -> Self {
Self {
method: AuthenticationMethod::OAuth2,
token_expiry_hours: 24,
enable_mfa: false,
identity_provider: "https://idp.example.com".to_string(),
}
}
}
/// Selector for the authentication method named in [`AuthenticationConfig`].
///
/// Only the names are declared here; the authentication flows are implemented
/// elsewhere. Variant names are part of the serde representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuthenticationMethod {
/// OAuth 2; the choice made by `AuthenticationConfig::default()`.
OAuth2,
/// JSON Web Tokens.
JWT,
/// SAML.
SAML,
/// Mutual TLS.
MTLS,
/// API key.
ApiKey,
}
/// Personalization settings of the federated configuration.
///
/// Serde-serializable; field names are part of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonalizationConfig {
/// Switch for personalization; the `Default` impl sets `true`.
pub enable_personalization: bool,
/// Which [`PersonalizationStrategy`] to use; the `Default` impl picks
/// `LocalAdaptation`.
pub strategy: PersonalizationStrategy,
/// Weight of the local adaptation; the `Default` impl sets 0.3.
pub local_adaptation_weight: f64,
/// Weight of the global model; the `Default` impl sets 0.7.
/// NOTE(review): the two default weights add up to 1.0, but nothing visible
/// here enforces that relation — confirm whether the consumer requires it.
pub global_model_weight: f64,
/// Names of the personalization layers; the `Default` impl lists
/// `"embedding"` and `"output"`.
/// NOTE(review): presumably matched against layer names of the model —
/// confirm against the consumer.
pub personalization_layers: Vec<String>,
/// Meta-learning settings (see [`MetaLearningConfig`]).
pub meta_learning_config: MetaLearningConfig,
}
impl Default for PersonalizationConfig {
fn default() -> Self {
Self {
enable_personalization: true,
strategy: PersonalizationStrategy::LocalAdaptation,
local_adaptation_weight: 0.3,
global_model_weight: 0.7,
personalization_layers: vec!["embedding".to_string(), "output".to_string()],
meta_learning_config: MetaLearningConfig::default(),
}
}
}
/// Selector for the personalization strategy named in
/// [`PersonalizationConfig`].
///
/// Only the names are declared here; the strategies are implemented elsewhere.
/// Variant names are part of the serde representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PersonalizationStrategy {
/// Local adaptation; the choice made by `PersonalizationConfig::default()`.
LocalAdaptation,
/// Multi-task learning.
MultiTaskLearning,
/// Meta-learning. NOTE(review): presumably parameterized by
/// `PersonalizationConfig::meta_learning_config` — confirm.
MetaLearning,
/// Mixture of experts.
MixtureOfExperts,
/// Personalized layers. NOTE(review): presumably driven by
/// `PersonalizationConfig::personalization_layers` — confirm.
PersonalizedLayers,
}
/// Meta-learning settings used by [`PersonalizationConfig`].
///
/// Serde-serializable; field names are part of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningConfig {
/// Which [`MetaLearningAlgorithm`] to use; the `Default` impl picks `MAML`.
pub algorithm: MetaLearningAlgorithm,
/// Inner learning rate; the `Default` impl sets 0.01.
pub inner_learning_rate: f64,
/// Outer learning rate; the `Default` impl sets 0.001.
pub outer_learning_rate: f64,
/// Number of inner steps; the `Default` impl sets 5.
/// NOTE(review): presumably gradient steps of the inner loop — confirm.
pub inner_steps: usize,
/// Support-set size; the `Default` impl sets 10.
pub support_set_size: usize,
/// Query-set size; the `Default` impl sets 5.
pub query_set_size: usize,
}
impl Default for MetaLearningConfig {
fn default() -> Self {
Self {
algorithm: MetaLearningAlgorithm::MAML,
inner_learning_rate: 0.01,
outer_learning_rate: 0.001,
inner_steps: 5,
support_set_size: 10,
query_set_size: 5,
}
}
}
/// Selector for the meta-learning algorithm named in [`MetaLearningConfig`].
///
/// Only the names are declared here; the algorithms are implemented elsewhere.
/// Variant names are part of the serde representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MetaLearningAlgorithm {
/// MAML; the choice made by `MetaLearningConfig::default()`.
MAML,
/// Reptile.
Reptile,
/// Prototypical networks.
PrototypicalNetworks,
/// Matching networks.
MatchingNetworks,
/// MANN. NOTE(review): presumably memory-augmented neural networks — confirm.
MANN,
}
/// Training-loop settings: convergence threshold, iteration cap, patience and
/// learning-rate decay.
///
/// Serde-serializable; field names are part of the serialized representation.
/// NOTE(review): none of the other types visible here holds a `TrainingConfig`,
/// so it is presumably consumed directly by code outside this view.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
/// Convergence threshold; the `Default` impl sets 1e-6.
/// NOTE(review): the quantity compared against it is not visible here.
pub convergence_threshold: f64,
/// Upper bound on global iterations; the `Default` impl sets 1000.
pub max_global_iterations: usize,
/// Patience; the `Default` impl sets 10.
/// NOTE(review): presumably iterations without improvement tolerated before
/// stopping early — confirm.
pub patience: usize,
/// Learning-rate decay; the `Default` impl sets 0.95.
/// NOTE(review): presumably a multiplicative factor — the schedule is not
/// visible here.
pub learning_rate_decay: f64,
/// Lower bound for the learning rate; the `Default` impl sets 1e-6.
pub min_learning_rate: f64,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
convergence_threshold: 1e-6,
max_global_iterations: 1000,
patience: 10,
learning_rate_decay: 0.95,
min_learning_rate: 1e-6,
}
}
}