oxirs_embed/federated_learning/
participant.rs

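//! Participant management data types for federated learning
//!
//! This module defines the data model used for participant registration,
//! capability assessment, trust scoring, and status management in federated
//! learning systems. It also declares the records describing training rounds,
//! local updates, global/local model state, privacy metrics, and
//! federation-wide statistics.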
5
6use chrono::{DateTime, Utc};
7use scirs2_core::ndarray_ext::Array2;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use uuid::Uuid;
11
/// Federated participant information
///
/// Plain data record describing one member of the federation: its identity
/// and contact details (`participant_id`, `name`, `endpoint`, `public_key`),
/// what it holds and can do (`data_stats`, `capabilities`), and its current
/// standing (`trust_score`, `last_communication`, `status`).
///
/// The type is cloneable and (de)serializable through serde; field names are
/// therefore part of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Participant {
    /// Participant ID
    pub participant_id: Uuid,
    /// Participant name
    pub name: String,
    /// Network endpoint
    // NOTE(review): stored as a free-form string; the expected format
    // (URL, host:port, ...) is not defined in this file — confirm.
    pub endpoint: String,
    /// Public key for verification
    // NOTE(review): the key encoding (PEM, hex, base64, ...) is not specified
    // here — confirm with the code that produces and verifies it.
    pub public_key: String,
    /// Data statistics
    pub data_stats: DataStatistics,
    /// Capability information
    pub capabilities: ParticipantCapabilities,
    /// Trust score
    // NOTE(review): the valid range and scale are not defined in this file.
    pub trust_score: f64,
    /// Last communication time (UTC)
    pub last_communication: DateTime<Utc>,
    /// Current participation status
    pub status: ParticipantStatus,
}
34
/// Data statistics for a participant
///
/// Summary of the local dataset a participant holds: its size
/// (`num_samples`, `num_features`), named summary and quality values, and
/// the amount of privacy budget already used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataStatistics {
    /// Number of samples
    pub num_samples: usize,
    /// Number of features
    pub num_features: usize,
    /// Data distribution summary, keyed by statistic name
    // NOTE(review): the set of keys is not defined in this file — confirm.
    pub distribution_summary: HashMap<String, f64>,
    /// Data quality metrics, keyed by metric name
    // NOTE(review): the set of keys is not defined in this file — confirm.
    pub quality_metrics: HashMap<String, f64>,
    /// Privacy budget used
    // NOTE(review): unit (e.g. epsilon) is not stated here — confirm.
    pub privacy_budget_used: f64,
}
49
/// Participant capabilities
///
/// Describes the resources a participant reports: a coarse compute class,
/// memory and bandwidth figures, the algorithms it supports, and any
/// hardware accelerators or security features it offers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParticipantCapabilities {
    /// Computational power
    pub compute_power: ComputePower,
    /// Available memory (GB)
    pub available_memory_gb: f64,
    /// Network bandwidth (Mbps)
    pub network_bandwidth_mbps: f64,
    /// Supported algorithms, identified by name
    // NOTE(review): the naming scheme for algorithms is not defined here.
    pub supported_algorithms: Vec<String>,
    /// Hardware accelerators (empty when none are reported)
    pub hardware_accelerators: Vec<HardwareAccelerator>,
    /// Security features (empty when none are reported)
    pub security_features: Vec<SecurityFeature>,
}
66
67/// Compute power levels
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub enum ComputePower {
70    /// Low computational resources
71    Low,
72    /// Medium computational resources
73    Medium,
74    /// High computational resources
75    High,
76    /// Very high computational resources
77    VeryHigh,
78}
79
80/// Hardware accelerators
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub enum HardwareAccelerator {
83    /// NVIDIA GPU
84    GPU,
85    /// Google TPU
86    TPU,
87    /// Intel Neural Compute Stick
88    NCS,
89    /// ARM Neural Processing Unit
90    NPU,
91    /// FPGA acceleration
92    FPGA,
93}
94
95/// Security features
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub enum SecurityFeature {
98    /// Trusted Execution Environment
99    TEE,
100    /// Hardware Security Module
101    HSM,
102    /// Secure Enclave
103    SecureEnclave,
104    /// Intel SGX
105    IntelSGX,
106    /// ARM TrustZone
107    ARMTrustZone,
108}
109
/// Participant status
///
/// Lifecycle state of a participant within the federation. The enum derives
/// `PartialEq`/`Eq`, so a status can be compared directly with `==`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ParticipantStatus {
    /// Active and available
    Active,
    /// Temporarily inactive
    Inactive,
    /// Disconnected
    Disconnected,
    /// Suspended due to issues
    Suspended,
    /// Excluded from federation
    Excluded,
}
124
/// Federated learning round information
///
/// Record of a single training round: when it ran, which participants took
/// part, the model parameters and aggregated updates associated with it, and
/// its metrics and status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedRound {
    /// Round number
    pub round_number: usize,
    /// Round start time (UTC)
    pub start_time: DateTime<Utc>,
    /// Round end time (UTC)
    // NOTE(review): presumably `None` while the round has not finished —
    // confirm against the code that populates this field.
    pub end_time: Option<DateTime<Utc>>,
    /// Participating clients, by participant ID
    pub participants: Vec<Uuid>,
    /// Global model parameters, keyed by parameter name
    // NOTE(review): key naming and array shapes are not defined here.
    pub global_parameters: HashMap<String, Array2<f32>>,
    /// Aggregated updates, keyed by parameter name
    // NOTE(review): presumably uses the same keys as `global_parameters` —
    // confirm.
    pub aggregated_updates: HashMap<String, Array2<f32>>,
    /// Round metrics
    pub metrics: RoundMetrics,
    /// Round status
    pub status: RoundStatus,
}
145
/// Metrics for a federated learning round
///
/// Aggregate figures for one round: participation, loss/accuracy,
/// communication cost, duration, privacy budget consumption, and
/// convergence indicators.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoundMetrics {
    /// Number of participating clients
    pub num_participants: usize,
    /// Total training samples
    pub total_samples: usize,
    /// Average local loss
    pub avg_local_loss: f64,
    /// Global model accuracy
    // NOTE(review): scale (fraction vs. percentage) is not stated here.
    pub global_accuracy: f64,
    /// Communication overhead (bytes)
    pub communication_overhead: u64,
    /// Round duration (seconds)
    pub duration_seconds: f64,
    /// Privacy budget consumed
    pub privacy_budget_consumed: f64,
    /// Convergence metrics
    pub convergence_metrics: ConvergenceMetrics,
}
166
/// Convergence tracking metrics
///
/// Numeric indicators of training progress plus a categorical
/// `convergence_status` and an optional estimate of the remaining rounds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceMetrics {
    /// Parameter change magnitude
    pub parameter_change: f64,
    /// Loss improvement
    // NOTE(review): sign convention (positive = loss decreased?) is not
    // defined in this file — confirm.
    pub loss_improvement: f64,
    /// Gradient norm
    pub gradient_norm: f64,
    /// Convergence status
    pub convergence_status: ConvergenceStatus,
    /// Estimated rounds to convergence, when an estimate is available
    pub estimated_rounds_to_convergence: Option<usize>,
}
181
182/// Convergence status
183#[derive(Debug, Clone, Serialize, Deserialize)]
184pub enum ConvergenceStatus {
185    /// Training is progressing
186    Progressing,
187    /// Converged to solution
188    Converged,
189    /// Diverging
190    Diverging,
191    /// Stagnated
192    Stagnated,
193    /// Oscillating
194    Oscillating,
195}
196
197/// Round status
198#[derive(Debug, Clone, Serialize, Deserialize)]
199pub enum RoundStatus {
200    /// Round is being initialized
201    Initializing,
202    /// Training in progress
203    Training,
204    /// Aggregating updates
205    Aggregating,
206    /// Round completed successfully
207    Completed,
208    /// Round failed
209    Failed,
210    /// Round was aborted
211    Aborted,
212}
213
/// Local training statistics for a participant
///
/// Figures a participant reports about one local training run: how long it
/// trained, the resulting loss/accuracy, how much data it used, and the
/// resources consumed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalTrainingStats {
    /// Local training epochs completed
    pub epochs_completed: usize,
    /// Local training time (seconds)
    pub training_time_seconds: f64,
    /// Local loss value
    pub local_loss: f64,
    /// Local accuracy
    // NOTE(review): scale (fraction vs. percentage) is not stated here.
    pub local_accuracy: f64,
    /// Number of samples used
    pub samples_used: usize,
    /// Gradient norm
    pub gradient_norm: f64,
    /// Resource utilization
    pub resource_utilization: ResourceUtilization,
}
232
/// Resource utilization metrics
///
/// Snapshot of the compute, memory, and network resources used by a
/// participant; embedded in `LocalTrainingStats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceUtilization {
    /// CPU usage percentage
    pub cpu_usage_percent: f64,
    /// Memory usage (GB)
    pub memory_usage_gb: f64,
    /// GPU usage percentage (if available); `None` when no GPU figure is
    /// reported
    pub gpu_usage_percent: Option<f64>,
    /// Network bandwidth used (Mbps)
    pub network_bandwidth_used_mbps: f64,
}
245
/// Local model update from a participant
///
/// The payload a participant produces for a given round: its parameter
/// updates together with the sample count, training statistics, timestamp,
/// and the data selection strategy that was applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalUpdate {
    /// Participant ID
    pub participant_id: Uuid,
    /// Round number
    pub round_number: usize,
    /// Model parameter updates, keyed by parameter name
    // NOTE(review): whether values are deltas or full parameter tensors is
    // not defined in this file — confirm with the aggregation code.
    pub parameter_updates: HashMap<String, Array2<f32>>,
    /// Number of local samples
    pub num_samples: usize,
    /// Local training statistics
    pub training_stats: LocalTrainingStats,
    /// Update timestamp (UTC)
    pub timestamp: DateTime<Utc>,
    /// Data selection strategy used
    pub data_selection: DataSelectionStrategy,
}
264
/// Data selection strategies for local training
///
/// Identifies how a participant chose the data used for a local training
/// run, with the parameters of that choice carried in the variant payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DataSelectionStrategy {
    /// Use all available data
    AllData,
    /// Random sampling
    // NOTE(review): `sample_rate` is presumably a fraction in [0, 1] — the
    // range is not enforced or documented here; confirm.
    RandomSampling { sample_rate: f64 },
    /// Stratified sampling
    // NOTE(review): `strata_proportions` maps a stratum name to its
    // proportion; whether the values must sum to 1 is not stated here.
    StratifiedSampling {
        strata_proportions: HashMap<String, f64>,
    },
    /// Active learning selection
    ActiveLearning { uncertainty_threshold: f64 },
    /// Importance sampling
    // NOTE(review): presumably one weight per local sample — confirm.
    ImportanceSampling { importance_weights: Vec<f64> },
}
281
/// Global model state in federated learning
///
/// The shared model as of a given global round, together with its version
/// label, last update time, performance metrics, and a per-participant
/// contribution value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalModelState {
    /// Model parameters, keyed by parameter name
    pub parameters: HashMap<String, Array2<f32>>,
    /// Global training round
    pub global_round: usize,
    /// Model version
    // NOTE(review): free-form string; the versioning scheme is not defined
    // in this file.
    pub model_version: String,
    /// Last update timestamp (UTC)
    pub last_updated: DateTime<Utc>,
    /// Model performance metrics, keyed by metric name
    pub performance_metrics: HashMap<String, f64>,
    /// Participant contributions, keyed by participant ID
    // NOTE(review): the meaning/scale of the contribution value is not
    // defined here — confirm.
    pub participant_contributions: HashMap<Uuid, f64>,
}
298
/// Local model state for a participant
///
/// A participant's own copy of the model: shared parameters, separately
/// stored personalized parameters, and bookkeeping about its last
/// synchronization with the global model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalModelState {
    /// Participant ID
    pub participant_id: Uuid,
    /// Local model parameters, keyed by parameter name
    pub parameters: HashMap<String, Array2<f32>>,
    /// Personalized layers, keyed by parameter name
    pub personalized_parameters: HashMap<String, Array2<f32>>,
    /// Global round synchronized to
    pub synchronized_round: usize,
    /// Local adaptation steps performed
    // NOTE(review): presumably counted since the last synchronization —
    // confirm.
    pub local_adaptation_steps: usize,
    /// Last synchronization time (UTC)
    pub last_sync_time: DateTime<Utc>,
}
315
/// Privacy metrics for federated learning
///
/// Federation-wide privacy accounting: total and per-participant budget
/// usage, named differential-privacy guarantee values, recorded violations,
/// and an overall risk score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrivacyMetrics {
    /// Total privacy budget spent
    pub total_budget_spent: f64,
    /// Privacy budget per participant, keyed by participant ID
    pub participant_budget_usage: HashMap<Uuid, f64>,
    /// Differential privacy guarantees, keyed by name
    // NOTE(review): keys are presumably parameters such as epsilon/delta —
    // not defined in this file; confirm.
    pub dp_guarantees: HashMap<String, f64>,
    /// Privacy violations detected
    pub privacy_violations: Vec<PrivacyViolation>,
    /// Privacy risk assessment
    // NOTE(review): the score's range and direction are not defined here.
    pub privacy_risk_score: f64,
}
330
/// Privacy violation record
///
/// One detected privacy incident: what kind it was, who (if anyone) it is
/// attributed to, when it happened, how severe it is, and what was done
/// about it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrivacyViolation {
    /// Violation type
    pub violation_type: PrivacyViolationType,
    /// Participant involved; `None` when no participant is recorded
    pub participant_id: Option<Uuid>,
    /// Violation timestamp (UTC)
    pub timestamp: DateTime<Utc>,
    /// Severity level
    pub severity: ViolationSeverity,
    /// Description
    pub description: String,
    /// Mitigation action taken; `None` when no action is recorded
    pub mitigation_action: Option<String>,
}
347
348/// Privacy violation types
349#[derive(Debug, Clone, Serialize, Deserialize)]
350pub enum PrivacyViolationType {
351    /// Budget exceeded
352    BudgetExceeded,
353    /// Information leakage detected
354    InformationLeakage,
355    /// Model inversion attack
356    ModelInversion,
357    /// Membership inference attack
358    MembershipInference,
359    /// Data reconstruction attack
360    DataReconstruction,
361}
362
363/// Violation severity levels
364#[derive(Debug, Clone, Serialize, Deserialize)]
365pub enum ViolationSeverity {
366    /// Low severity
367    Low,
368    /// Medium severity
369    Medium,
370    /// High severity
371    High,
372    /// Critical severity
373    Critical,
374}
375
/// Federation statistics for monitoring
///
/// Federation-wide summary: participant counts, completed rounds and their
/// average duration, cumulative communication cost, convergence status,
/// privacy metrics, uptime, and the time of the last recorded activity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederationStats {
    /// Total number of participants
    pub total_participants: usize,
    /// Active participants
    pub active_participants: usize,
    /// Total rounds completed
    pub rounds_completed: usize,
    /// Average round duration (seconds)
    pub avg_round_duration_seconds: f64,
    /// Total communication overhead (bytes)
    pub total_communication_overhead_bytes: u64,
    /// Model convergence status
    pub convergence_status: ConvergenceStatus,
    /// Privacy metrics
    pub privacy_metrics: PrivacyMetrics,
    /// System uptime (seconds)
    pub system_uptime_seconds: u64,
    /// Last activity timestamp (UTC)
    pub last_activity: DateTime<Utc>,
}