use chrono::{DateTime, Utc};
use scirs2_core::ndarray_ext::Array2;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// A single federation participant: identity, connectivity, local data
/// summary, declared capabilities, trust and current status.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct Participant {
    /// Unique identifier of this participant.
    pub participant_id: Uuid,
    /// Human-readable name.
    pub name: String,
    /// Address used to reach the participant (format not specified here).
    pub endpoint: String,
    /// Public key stored as a string (encoding not specified here).
    pub public_key: String,
    /// Summary statistics describing the participant's local data.
    pub data_stats: DataStatistics,
    /// Compute, memory, network, accelerator and security capabilities.
    pub capabilities: ParticipantCapabilities,
    /// Trust score for this participant; the scale is not defined in this module.
    pub trust_score: f64,
    /// UTC timestamp of the most recent communication with this participant.
    pub last_communication: DateTime<Utc>,
    /// Current participation status.
    pub status: ParticipantStatus,
}
/// Summary of a participant's local dataset.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct DataStatistics {
    /// Number of local samples.
    pub num_samples: usize,
    /// Number of features per sample.
    pub num_features: usize,
    /// Named distribution statistics (key set not fixed by this type).
    pub distribution_summary: HashMap<String, f64>,
    /// Named data-quality metrics (key set not fixed by this type).
    pub quality_metrics: HashMap<String, f64>,
    /// Amount of privacy budget already consumed (units not defined here).
    pub privacy_budget_used: f64,
}
/// Resources and features a participant declares it can offer.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ParticipantCapabilities {
    /// Coarse compute-power class.
    pub compute_power: ComputePower,
    /// Available memory (gigabytes, per the field name).
    pub available_memory_gb: f64,
    /// Network bandwidth (megabits per second, per the field name).
    pub network_bandwidth_mbps: f64,
    /// Names of supported algorithms.
    pub supported_algorithms: Vec<String>,
    /// Hardware accelerators present on the participant.
    pub hardware_accelerators: Vec<HardwareAccelerator>,
    /// Security features present on the participant.
    pub security_features: Vec<SecurityFeature>,
}
/// Coarse compute-power class of a participant.
///
/// Variants are declared from weakest to strongest, so the derived
/// `PartialOrd`/`Ord` impls order them as `Low < Medium < High < VeryHigh`,
/// which allows threshold checks such as `power >= ComputePower::High`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub enum ComputePower {
    Low,
    Medium,
    High,
    VeryHigh,
}
/// Kind of hardware accelerator available on a participant.
///
/// `Eq`/`Hash` allow accelerators to be stored in sets or used as map keys.
// The all-caps variant names are part of the public API and, with no serde
// rename attributes, are also the serialized names; keep them and silence the lint.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum HardwareAccelerator {
    GPU,
    TPU,
    NCS,
    NPU,
    FPGA,
}
/// Security feature available on a participant.
///
/// `Eq`/`Hash` allow features to be stored in sets or used as map keys.
// `TEE`/`HSM` are part of the public API and, with no serde rename
// attributes, are also the serialized names; keep them and silence the lint.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SecurityFeature {
    TEE,
    HSM,
    SecureEnclave,
    IntelSGX,
    ARMTrustZone,
}
/// Participation status of a federation member.
///
/// Fieldless, so it is `Copy`; `Hash` allows grouping participants by status.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ParticipantStatus {
    Active,
    Inactive,
    Disconnected,
    Suspended,
    Excluded,
}
/// State and outcome of one federated training round.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct FederatedRound {
    /// Round index.
    pub round_number: usize,
    /// UTC time the round started.
    pub start_time: DateTime<Utc>,
    /// UTC time the round ended; `None` while it has not ended.
    pub end_time: Option<DateTime<Utc>>,
    /// IDs of participants in this round.
    pub participants: Vec<Uuid>,
    /// Global model parameters, keyed by parameter name.
    pub global_parameters: HashMap<String, Array2<f32>>,
    /// Aggregated updates, keyed by parameter name.
    pub aggregated_updates: HashMap<String, Array2<f32>>,
    /// Metrics collected for the round.
    pub metrics: RoundMetrics,
    /// Lifecycle status of the round.
    pub status: RoundStatus,
}
/// Aggregate metrics for a single federated round.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct RoundMetrics {
    /// Number of participants that took part.
    pub num_participants: usize,
    /// Total number of samples across participants.
    pub total_samples: usize,
    /// Average of the participants' local losses.
    pub avg_local_loss: f64,
    /// Accuracy of the global model.
    pub global_accuracy: f64,
    /// Communication overhead (units not specified by this type).
    pub communication_overhead: u64,
    /// Round duration in seconds.
    pub duration_seconds: f64,
    /// Privacy budget consumed during the round.
    pub privacy_budget_consumed: f64,
    /// Convergence indicators for the round.
    pub convergence_metrics: ConvergenceMetrics,
}
/// Indicators used to judge whether training is converging.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ConvergenceMetrics {
    /// Magnitude of the parameter change (norm not specified here).
    pub parameter_change: f64,
    /// Improvement in loss.
    pub loss_improvement: f64,
    /// Gradient norm.
    pub gradient_norm: f64,
    /// Classified convergence state.
    pub convergence_status: ConvergenceStatus,
    /// Estimated remaining rounds until convergence, if an estimate exists.
    pub estimated_rounds_to_convergence: Option<usize>,
}
/// Classified convergence state of the federated training process.
///
/// Derives `PartialEq`/`Eq` (like `ParticipantStatus`) so callers can write
/// `status == ConvergenceStatus::Converged` instead of `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConvergenceStatus {
    Progressing,
    Converged,
    Diverging,
    Stagnated,
    Oscillating,
}
/// Lifecycle status of a federated round.
///
/// Derives `PartialEq`/`Eq` (like `ParticipantStatus`) so rounds can be
/// compared or filtered by status directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RoundStatus {
    Initializing,
    Training,
    Aggregating,
    Completed,
    Failed,
    Aborted,
}
/// Statistics reported by a participant about its local training run.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct LocalTrainingStats {
    /// Number of local epochs completed.
    pub epochs_completed: usize,
    /// Wall time spent training, in seconds.
    pub training_time_seconds: f64,
    /// Final local loss.
    pub local_loss: f64,
    /// Final local accuracy.
    pub local_accuracy: f64,
    /// Number of samples used for training.
    pub samples_used: usize,
    /// Gradient norm observed locally.
    pub gradient_norm: f64,
    /// Resources consumed while training.
    pub resource_utilization: ResourceUtilization,
}
/// Resource usage recorded during local training.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ResourceUtilization {
    /// CPU usage as a percentage.
    pub cpu_usage_percent: f64,
    /// Memory usage (gigabytes, per the field name).
    pub memory_usage_gb: f64,
    /// GPU usage as a percentage; `None` when not reported.
    pub gpu_usage_percent: Option<f64>,
    /// Network bandwidth used (megabits per second, per the field name).
    pub network_bandwidth_used_mbps: f64,
}
/// A participant's model update for a given round.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct LocalUpdate {
    /// ID of the participant that produced the update.
    pub participant_id: Uuid,
    /// Round the update belongs to.
    pub round_number: usize,
    /// Parameter updates, keyed by parameter name.
    pub parameter_updates: HashMap<String, Array2<f32>>,
    /// Number of samples the update was computed from.
    pub num_samples: usize,
    /// Statistics from the local training run.
    pub training_stats: LocalTrainingStats,
    /// UTC time the update was produced.
    pub timestamp: DateTime<Utc>,
    /// Strategy used to select local training data.
    pub data_selection: DataSelectionStrategy,
}
/// How a participant chooses which local data to train on.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub enum DataSelectionStrategy {
    /// Use all local data.
    AllData,
    /// Random subsampling at the given rate.
    RandomSampling { sample_rate: f64 },
    /// Stratified sampling with per-stratum proportions keyed by stratum name.
    StratifiedSampling { strata_proportions: HashMap<String, f64> },
    /// Active-learning selection driven by an uncertainty threshold.
    ActiveLearning { uncertainty_threshold: f64 },
    /// Importance sampling with explicit weights.
    ImportanceSampling { importance_weights: Vec<f64> },
}
/// Snapshot of the shared global model.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct GlobalModelState {
    /// Model parameters, keyed by parameter name.
    pub parameters: HashMap<String, Array2<f32>>,
    /// Global round this snapshot corresponds to.
    pub global_round: usize,
    /// Version label of the model (format not specified here).
    pub model_version: String,
    /// UTC time of the last update.
    pub last_updated: DateTime<Utc>,
    /// Named performance metrics.
    pub performance_metrics: HashMap<String, f64>,
    /// Per-participant contribution values, keyed by participant ID.
    pub participant_contributions: HashMap<Uuid, f64>,
}
/// A participant's local copy of the model plus its personalized parameters.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct LocalModelState {
    /// Owning participant's ID.
    pub participant_id: Uuid,
    /// Local model parameters, keyed by parameter name.
    pub parameters: HashMap<String, Array2<f32>>,
    /// Personalized parameters, keyed by parameter name.
    pub personalized_parameters: HashMap<String, Array2<f32>>,
    /// Global round the local model was last synchronized with.
    pub synchronized_round: usize,
    /// Number of local adaptation steps taken.
    pub local_adaptation_steps: usize,
    /// UTC time of the last synchronization.
    pub last_sync_time: DateTime<Utc>,
}
/// Privacy accounting across the federation.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct PrivacyMetrics {
    /// Total privacy budget spent.
    pub total_budget_spent: f64,
    /// Budget usage per participant, keyed by participant ID.
    pub participant_budget_usage: HashMap<Uuid, f64>,
    /// Named differential-privacy guarantee values (key set not fixed here).
    pub dp_guarantees: HashMap<String, f64>,
    /// Recorded privacy violations.
    pub privacy_violations: Vec<PrivacyViolation>,
    /// Overall privacy risk score (scale not defined in this module).
    pub privacy_risk_score: f64,
}
/// A single recorded privacy violation.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct PrivacyViolation {
    /// Category of the violation.
    pub violation_type: PrivacyViolationType,
    /// Participant involved, if attributable to one.
    pub participant_id: Option<Uuid>,
    /// UTC time the violation was recorded.
    pub timestamp: DateTime<Utc>,
    /// Severity of the violation.
    pub severity: ViolationSeverity,
    /// Free-text description.
    pub description: String,
    /// Mitigation taken, if any.
    pub mitigation_action: Option<String>,
}
/// Category of a privacy violation.
///
/// `Eq`/`Hash` allow violations to be grouped or counted by type
/// (e.g. as `HashMap` keys).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PrivacyViolationType {
    BudgetExceeded,
    InformationLeakage,
    ModelInversion,
    MembershipInference,
    DataReconstruction,
}
/// Severity level of a privacy violation.
///
/// Variants are declared from least to most severe, so the derived
/// `PartialOrd`/`Ord` impls order them as `Low < Medium < High < Critical`,
/// which allows threshold checks such as `severity >= ViolationSeverity::High`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub enum ViolationSeverity {
    Low,
    Medium,
    High,
    Critical,
}
/// Federation-wide summary statistics.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct FederationStats {
    /// Number of known participants.
    pub total_participants: usize,
    /// Number of participants counted as active.
    pub active_participants: usize,
    /// Number of completed rounds.
    pub rounds_completed: usize,
    /// Average round duration in seconds.
    pub avg_round_duration_seconds: f64,
    /// Total communication overhead in bytes.
    pub total_communication_overhead_bytes: u64,
    /// Current convergence classification.
    pub convergence_status: ConvergenceStatus,
    /// Aggregated privacy accounting.
    pub privacy_metrics: PrivacyMetrics,
    /// System uptime in seconds.
    pub system_uptime_seconds: u64,
    /// UTC time of the most recent activity.
    pub last_activity: DateTime<Utc>,
}