// oxirs_embed/federated_learning/config.rs

//! Configuration structures for federated learning.
//!
//! This module defines the configuration types used for federated learning,
//! including privacy settings, communication protocols, security
//! configuration, and personalization options.
6
7use crate::ModelConfig;
8use serde::{Deserialize, Serialize};
9
/// Top-level configuration for federated learning.
///
/// Combines the base [`ModelConfig`] with the federation parameters
/// (participant counts, communication rounds, local epochs) and the nested
/// privacy, aggregation, communication, security, and personalization
/// settings. `FederatedConfig::default()` yields a complete configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedConfig {
    /// Base model configuration.
    pub base_config: ModelConfig,
    /// Number of federated participants.
    pub num_participants: usize,
    /// Number of communication rounds.
    pub communication_rounds: usize,
    /// Local training epochs per communication round.
    pub local_epochs: usize,
    /// Minimum number of participants required for aggregation.
    ///
    /// NOTE(review): presumably expected to be `<= num_participants`; this
    /// type does not enforce that relationship.
    pub min_participants: usize,
    /// Differential privacy configuration.
    pub privacy_config: PrivacyConfig,
    /// Aggregation strategy.
    pub aggregation_strategy: AggregationStrategy,
    /// Communication optimization settings (compression, sparsification,
    /// protocol, batching, timeout).
    pub communication_config: CommunicationConfig,
    /// Secure computation settings.
    pub security_config: SecurityConfig,
    /// Personalization settings.
    pub personalization_config: PersonalizationConfig,
}
34
35impl Default for FederatedConfig {
36    fn default() -> Self {
37        Self {
38            base_config: ModelConfig::default(),
39            num_participants: 10,
40            communication_rounds: 100,
41            local_epochs: 5,
42            min_participants: 5,
43            privacy_config: PrivacyConfig::default(),
44            aggregation_strategy: AggregationStrategy::FederatedAveraging,
45            communication_config: CommunicationConfig::default(),
46            security_config: SecurityConfig::default(),
47            personalization_config: PersonalizationConfig::default(),
48        }
49    }
50}
51
/// Privacy-preserving (differential privacy) configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrivacyConfig {
    /// Enable differential privacy.
    pub enable_differential_privacy: bool,
    /// Privacy budget (epsilon).
    ///
    /// NOTE(review): the relationship between this value and
    /// `local_epsilon` / `global_epsilon` is not defined in this module. The
    /// defaults happen to satisfy `local_epsilon + global_epsilon == epsilon`
    /// (0.5 + 0.5 == 1.0); confirm against the code that consumes them.
    pub epsilon: f64,
    /// Privacy delta parameter.
    pub delta: f64,
    /// Noise mechanism used for differential privacy.
    pub noise_mechanism: NoiseMechanism,
    /// Gradient clipping threshold.
    pub clipping_threshold: f64,
    /// Local privacy budget.
    pub local_epsilon: f64,
    /// Global privacy budget.
    pub global_epsilon: f64,
}
70
71impl Default for PrivacyConfig {
72    fn default() -> Self {
73        Self {
74            enable_differential_privacy: true,
75            epsilon: 1.0,
76            delta: 1e-5,
77            noise_mechanism: NoiseMechanism::Gaussian,
78            clipping_threshold: 1.0,
79            local_epsilon: 0.5,
80            global_epsilon: 0.5,
81        }
82    }
83}
84
85/// Noise mechanisms for differential privacy
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub enum NoiseMechanism {
88    /// Gaussian noise mechanism
89    Gaussian,
90    /// Laplace noise mechanism
91    Laplace,
92    /// Exponential mechanism
93    Exponential,
94    /// Sparse vector technique
95    SparseVector,
96}
97
98/// Aggregation strategies for federated learning
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub enum AggregationStrategy {
101    /// Standard federated averaging
102    FederatedAveraging,
103    /// Weighted federated averaging
104    WeightedAveraging,
105    /// Secure aggregation
106    SecureAggregation,
107    /// Robust aggregation (Byzantine-resistant)
108    RobustAggregation,
109    /// Personalized aggregation
110    PersonalizedAggregation,
111    /// Hierarchical aggregation
112    HierarchicalAggregation,
113}
114
/// Communication optimization configuration.
///
/// Groups the settings for update compression (quantization and
/// sparsification), the communication protocol, batching, and timeouts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationConfig {
    /// Enable gradient compression.
    pub enable_compression: bool,
    /// Compression ratio.
    ///
    /// NOTE(review): whether this is the fraction of data kept (default 0.1)
    /// or a reduction factor is not defined here; confirm with the compressor.
    pub compression_ratio: f64,
    /// Number of quantization bits.
    pub quantization_bits: u8,
    /// Enable sparsification.
    pub enable_sparsification: bool,
    /// Sparsity threshold.
    pub sparsity_threshold: f64,
    /// Communication protocol.
    pub protocol: CommunicationProtocol,
    /// Whether communication is batched.
    pub batch_communication: bool,
    /// Communication timeout, in seconds.
    pub timeout_seconds: u64,
}
135
136impl Default for CommunicationConfig {
137    fn default() -> Self {
138        Self {
139            enable_compression: true,
140            compression_ratio: 0.1,
141            quantization_bits: 8,
142            enable_sparsification: true,
143            sparsity_threshold: 0.01,
144            protocol: CommunicationProtocol::Synchronous,
145            batch_communication: true,
146            timeout_seconds: 300,
147        }
148    }
149}
150
151/// Communication protocols
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub enum CommunicationProtocol {
154    /// Synchronous communication
155    Synchronous,
156    /// Asynchronous communication
157    Asynchronous,
158    /// Semi-synchronous with staleness bounds
159    SemiSynchronous { staleness_bound: usize },
160    /// Peer-to-peer communication
161    PeerToPeer,
162}
163
/// Security configuration for secure computation.
///
/// Covers homomorphic encryption, secure multi-party computation,
/// verification mechanisms, certificate management, and authentication.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Enable homomorphic encryption.
    pub enable_homomorphic_encryption: bool,
    /// Homomorphic encryption scheme.
    ///
    /// NOTE(review): presumably only consulted when
    /// `enable_homomorphic_encryption` is true; confirm.
    pub encryption_scheme: EncryptionScheme,
    /// Enable secure multi-party computation.
    pub enable_secure_mpc: bool,
    /// Verification mechanisms (several may be listed).
    pub verification_mechanisms: Vec<VerificationMechanism>,
    /// Certificate management settings.
    pub certificate_config: CertificateConfig,
    /// Authentication settings.
    pub authentication_config: AuthenticationConfig,
}
180
181impl Default for SecurityConfig {
182    fn default() -> Self {
183        Self {
184            enable_homomorphic_encryption: false,
185            encryption_scheme: EncryptionScheme::CKKS,
186            enable_secure_mpc: false,
187            verification_mechanisms: vec![VerificationMechanism::DigitalSignature],
188            certificate_config: CertificateConfig::default(),
189            authentication_config: AuthenticationConfig::default(),
190        }
191    }
192}
193
194/// Homomorphic encryption schemes
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub enum EncryptionScheme {
197    /// CKKS scheme for approximate arithmetic
198    CKKS,
199    /// BFV scheme for exact arithmetic
200    BFV,
201    /// SEAL implementation
202    SEAL,
203    /// HElib implementation
204    HElib,
205}
206
207/// Verification mechanisms
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub enum VerificationMechanism {
210    /// Digital signatures
211    DigitalSignature,
212    /// Zero-knowledge proofs
213    ZeroKnowledgeProof,
214    /// Commitment schemes
215    CommitmentScheme,
216    /// Hash-based verification
217    HashVerification,
218}
219
/// Certificate management configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertificateConfig {
    /// Certificate authority endpoint URL.
    ///
    /// The default (`https://ca.example.com`) uses the reserved
    /// `example.com` domain and is presumably a placeholder to be overridden.
    pub ca_endpoint: String,
    /// Certificate validity period, in days.
    pub validity_days: u32,
    /// Key length.
    ///
    /// NOTE(review): unit presumably bits (default 2048); confirm.
    pub key_length: u32,
    /// Whether to validate the certificate chain.
    pub validate_chain: bool,
}
232
233impl Default for CertificateConfig {
234    fn default() -> Self {
235        Self {
236            ca_endpoint: "https://ca.example.com".to_string(),
237            validity_days: 365,
238            key_length: 2048,
239            validate_chain: true,
240        }
241    }
242}
243
/// Authentication configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthenticationConfig {
    /// Authentication method.
    pub method: AuthenticationMethod,
    /// Token expiry time, in hours.
    pub token_expiry_hours: u32,
    /// Enable multi-factor authentication.
    pub enable_mfa: bool,
    /// Identity provider endpoint URL.
    ///
    /// The default (`https://idp.example.com`) uses the reserved
    /// `example.com` domain and is presumably a placeholder to be overridden.
    pub identity_provider: String,
}
256
257impl Default for AuthenticationConfig {
258    fn default() -> Self {
259        Self {
260            method: AuthenticationMethod::OAuth2,
261            token_expiry_hours: 24,
262            enable_mfa: false,
263            identity_provider: "https://idp.example.com".to_string(),
264        }
265    }
266}
267
268/// Authentication methods
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub enum AuthenticationMethod {
271    /// OAuth 2.0
272    OAuth2,
273    /// JSON Web Tokens
274    JWT,
275    /// SAML
276    SAML,
277    /// Mutual TLS
278    MTLS,
279    /// API Keys
280    ApiKey,
281}
282
/// Personalization configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonalizationConfig {
    /// Enable personalized models.
    pub enable_personalization: bool,
    /// Personalization strategy.
    pub strategy: PersonalizationStrategy,
    /// Weight given to local adaptation.
    ///
    /// NOTE(review): the defaults (0.3 here, 0.7 for `global_model_weight`)
    /// sum to 1.0; whether that is required is not enforced by this type.
    pub local_adaptation_weight: f64,
    /// Weight given to the global model.
    pub global_model_weight: f64,
    /// Names of the layers to personalize (defaults: `"embedding"`,
    /// `"output"`).
    pub personalization_layers: Vec<String>,
    /// Meta-learning configuration.
    ///
    /// NOTE(review): presumably relevant when `strategy` is
    /// `PersonalizationStrategy::MetaLearning`; confirm.
    pub meta_learning_config: MetaLearningConfig,
}
299
300impl Default for PersonalizationConfig {
301    fn default() -> Self {
302        Self {
303            enable_personalization: true,
304            strategy: PersonalizationStrategy::LocalAdaptation,
305            local_adaptation_weight: 0.3,
306            global_model_weight: 0.7,
307            personalization_layers: vec!["embedding".to_string(), "output".to_string()],
308            meta_learning_config: MetaLearningConfig::default(),
309        }
310    }
311}
312
313/// Personalization strategies
314#[derive(Debug, Clone, Serialize, Deserialize)]
315pub enum PersonalizationStrategy {
316    /// Local adaptation of global model
317    LocalAdaptation,
318    /// Multi-task learning
319    MultiTaskLearning,
320    /// Meta-learning approach
321    MetaLearning,
322    /// Mixture of experts
323    MixtureOfExperts,
324    /// Personalized layers
325    PersonalizedLayers,
326}
327
/// Meta-learning configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningConfig {
    /// Meta-learning algorithm.
    pub algorithm: MetaLearningAlgorithm,
    /// Inner-loop learning rate.
    pub inner_learning_rate: f64,
    /// Outer-loop learning rate.
    pub outer_learning_rate: f64,
    /// Number of inner-loop steps.
    pub inner_steps: usize,
    /// Support set size.
    pub support_set_size: usize,
    /// Query set size.
    pub query_set_size: usize,
}
344
345impl Default for MetaLearningConfig {
346    fn default() -> Self {
347        Self {
348            algorithm: MetaLearningAlgorithm::MAML,
349            inner_learning_rate: 0.01,
350            outer_learning_rate: 0.001,
351            inner_steps: 5,
352            support_set_size: 10,
353            query_set_size: 5,
354        }
355    }
356}
357
358/// Meta-learning algorithms
359#[derive(Debug, Clone, Serialize, Deserialize)]
360pub enum MetaLearningAlgorithm {
361    /// Model-Agnostic Meta-Learning
362    MAML,
363    /// Reptile algorithm
364    Reptile,
365    /// Prototypical networks
366    PrototypicalNetworks,
367    /// Matching networks
368    MatchingNetworks,
369    /// Memory-augmented neural networks
370    MANN,
371}
372
/// Training configuration for federated learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Convergence threshold.
    pub convergence_threshold: f64,
    /// Maximum number of global iterations.
    pub max_global_iterations: usize,
    /// Patience for early stopping.
    ///
    /// NOTE(review): presumably the number of iterations without
    /// improvement before stopping; confirm with the training loop.
    pub patience: usize,
    /// Learning rate decay factor.
    ///
    /// NOTE(review): presumably applied multiplicatively (default 0.95);
    /// confirm with the scheduler.
    pub learning_rate_decay: f64,
    /// Minimum learning rate (presumably the floor for decay).
    pub min_learning_rate: f64,
}
387
388impl Default for TrainingConfig {
389    fn default() -> Self {
390        Self {
391            convergence_threshold: 1e-6,
392            max_global_iterations: 1000,
393            patience: 10,
394            learning_rate_decay: 0.95,
395            min_learning_rate: 1e-6,
396        }
397    }
398}