use crate::common::IntegrateFloat;
use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::Array1;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Newtype identifier for a compute node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub u64);

impl NodeId {
    /// Wraps a raw `u64` as a node identifier.
    pub fn new(id: u64) -> Self {
        NodeId(id)
    }

    /// Returns the underlying raw id.
    pub fn value(&self) -> u64 {
        let NodeId(raw) = self;
        *raw
    }
}

impl std::fmt::Display for NodeId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Node({})", self.value())
    }
}
/// Newtype identifier for an integration job.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct JobId(pub u64);

impl JobId {
    /// Wraps a raw `u64` as a job identifier.
    pub fn new(id: u64) -> Self {
        JobId(id)
    }

    /// Returns the underlying raw id.
    pub fn value(&self) -> u64 {
        let JobId(raw) = *self;
        raw
    }
}
/// Newtype identifier for a single chunk of work within a job.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ChunkId(pub u64);

impl ChunkId {
    /// Wraps a raw `u64` as a chunk identifier.
    pub fn new(id: u64) -> Self {
        ChunkId(id)
    }

    /// Returns the underlying raw id.
    pub fn value(&self) -> u64 {
        let ChunkId(raw) = *self;
        raw
    }
}
/// Lifecycle state of a compute node as tracked by the coordinator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeStatus {
/// Ready to accept work.
Available,
/// Currently processing work.
Busy,
/// Considered failed; `NodeInfo::is_healthy` treats this as unhealthy.
Failed,
/// Temporarily taken out of service.
Maintenance,
/// Starting up; not yet ready (initial state set by `NodeInfo::new`).
Initializing,
/// Draining work before leaving the cluster; `NodeInfo::is_healthy`
/// treats this as unhealthy.
ShuttingDown,
}
/// Hardware and networking characteristics advertised by a node.
#[derive(Debug, Clone)]
pub struct NodeCapabilities {
/// Number of CPU cores available for work.
pub cpu_cores: usize,
/// Total memory in bytes.
pub memory_bytes: usize,
/// Whether a GPU is present; adds a bonus in `NodeInfo::processing_score`.
pub has_gpu: bool,
/// GPU memory in bytes, when a GPU is present.
pub gpu_memory_bytes: Option<usize>,
/// Network bandwidth (presumably bytes/s, given the 100 MiB default —
/// TODO confirm units against producers).
pub network_bandwidth: usize,
/// Network latency in microseconds; penalized in `NodeInfo::processing_score`.
pub latency_us: u64,
/// Floating-point precisions this node supports.
pub supported_precisions: Vec<FloatPrecision>,
/// SIMD instruction-set extensions available on this node.
pub simd_capabilities: SimdCapability,
}
impl Default for NodeCapabilities {
    /// Conservative baseline: one core, 1 GiB RAM, no GPU, 100 MiB/s
    /// network, 1 ms latency, f32/f64 support, no SIMD flags set.
    fn default() -> Self {
        Self {
            cpu_cores: 1,
            memory_bytes: 1024 * 1024 * 1024, // 1 GiB
            has_gpu: false,
            gpu_memory_bytes: None,
            network_bandwidth: 100 * 1024 * 1024, // 100 MiB/s
            latency_us: 1000, // 1 ms
            supported_precisions: vec![FloatPrecision::F32, FloatPrecision::F64],
            simd_capabilities: SimdCapability::default(),
        }
    }
}
/// Floating-point precision a node can compute with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FloatPrecision {
/// Half precision (16-bit).
F16,
/// Single precision (32-bit).
F32,
/// Double precision (64-bit).
F64,
}
/// SIMD instruction-set extensions detected on a node.
/// Default is all-false (no SIMD assumed).
#[derive(Debug, Clone, Default)]
pub struct SimdCapability {
// x86 extensions
pub has_sse: bool,
pub has_sse2: bool,
pub has_avx: bool,
pub has_avx2: bool,
pub has_avx512: bool,
// ARM extension
pub has_neon: bool,
}
/// Runtime record the coordinator keeps for each registered node.
#[derive(Debug, Clone)]
pub struct NodeInfo {
/// Stable identifier of the node.
pub id: NodeId,
/// Network address the node is reachable at.
pub address: SocketAddr,
/// Current lifecycle state.
pub status: NodeStatus,
/// Advertised hardware/network capabilities.
pub capabilities: NodeCapabilities,
/// Time of the most recent heartbeat; consulted by `is_healthy`.
pub last_heartbeat: Instant,
/// Number of jobs this node has completed.
pub jobs_completed: usize,
/// Average duration of completed jobs (how it is maintained is not
/// visible here — TODO confirm update site).
pub average_job_duration: Duration,
}
impl NodeInfo {
pub fn new(id: NodeId, address: SocketAddr) -> Self {
Self {
id,
address,
status: NodeStatus::Initializing,
capabilities: NodeCapabilities::default(),
last_heartbeat: Instant::now(),
jobs_completed: 0,
average_job_duration: Duration::ZERO,
}
}
pub fn is_healthy(&self, timeout: Duration) -> bool {
self.last_heartbeat.elapsed() < timeout
&& self.status != NodeStatus::Failed
&& self.status != NodeStatus::ShuttingDown
}
pub fn processing_score(&self) -> f64 {
let base_score = self.capabilities.cpu_cores as f64;
let gpu_bonus = if self.capabilities.has_gpu { 10.0 } else { 0.0 };
let latency_penalty = (self.capabilities.latency_us as f64 / 1000.0).min(5.0);
base_score + gpu_bonus - latency_penalty
}
}
/// A unit of integration work: one time sub-interval of a job together
/// with the state needed to process it on a single node.
#[derive(Debug, Clone)]
pub struct WorkChunk<F: IntegrateFloat> {
/// Identifier of this chunk.
pub id: ChunkId,
/// Job this chunk belongs to.
pub job_id: JobId,
/// (start, end) of the time sub-interval to integrate.
pub time_interval: (F, F),
/// State vector at the start of the interval.
pub initial_state: Array1<F>,
/// Boundary/coupling data exchanged with neighbouring chunks.
pub boundary_conditions: BoundaryConditions<F>,
/// Scheduling priority (ordering convention not visible here —
/// presumably higher runs first; TODO confirm against the scheduler).
pub priority: u32,
/// Heuristic processing cost: interval width × state size.
pub estimated_cost: f64,
/// Number of times this chunk has been retried so far.
pub retry_count: u32,
/// Maximum retries allowed; see `can_retry`.
pub max_retries: u32,
}
impl<F: IntegrateFloat> WorkChunk<F> {
pub fn new(
id: ChunkId,
job_id: JobId,
time_interval: (F, F),
initial_state: Array1<F>,
) -> Self {
let estimated_cost = Self::estimate_cost(&time_interval, initial_state.len());
Self {
id,
job_id,
time_interval,
initial_state,
boundary_conditions: BoundaryConditions::default(),
priority: 0,
estimated_cost,
retry_count: 0,
max_retries: 3,
}
}
fn estimate_cost(time_interval: &(F, F), state_size: usize) -> f64 {
let dt = (time_interval.1 - time_interval.0).to_f64().unwrap_or(1.0);
dt * state_size as f64
}
pub fn can_retry(&self) -> bool {
self.retry_count < self.max_retries
}
pub fn increment_retry(&mut self) {
self.retry_count += 1;
}
}
/// Boundary information a chunk exchanges with its neighbours.
#[derive(Debug, Clone)]
pub struct BoundaryConditions<F: IntegrateFloat> {
/// Data received from the left neighbour, if any.
pub left_boundary: Option<BoundaryData<F>>,
/// Data received from the right neighbour, if any.
pub right_boundary: Option<BoundaryData<F>>,
/// Ghost-cell values (layout defined by the consumer — TODO confirm).
pub ghost_cells: Vec<F>,
/// Auxiliary coupling arrays keyed by string tag.
pub coupling_data: HashMap<String, Array1<F>>,
}
/// Empty boundary conditions: no neighbours, no ghost cells, no coupling data.
impl<F: IntegrateFloat> Default for BoundaryConditions<F> {
fn default() -> Self {
Self {
left_boundary: None,
right_boundary: None,
ghost_cells: Vec::new(),
coupling_data: HashMap::new(),
}
}
}
/// Snapshot of state at a chunk boundary, produced by one chunk for another.
#[derive(Debug, Clone)]
pub struct BoundaryData<F: IntegrateFloat> {
/// Time the snapshot corresponds to.
pub time: F,
/// State vector at the boundary.
pub state: Array1<F>,
/// Time derivative at the boundary, when available.
pub derivative: Option<Array1<F>>,
/// Chunk that produced this data.
pub source_chunk: ChunkId,
}
/// Output produced by a node after processing one `WorkChunk`.
#[derive(Debug, Clone)]
pub struct ChunkResult<F: IntegrateFloat> {
/// Chunk this result is for.
pub chunk_id: ChunkId,
/// Node that produced the result.
pub node_id: NodeId,
/// Time points at which states were recorded.
pub time_points: Vec<F>,
/// State vectors corresponding to `time_points`.
pub states: Vec<Array1<F>>,
/// State at the end of the chunk's interval.
pub final_state: Array1<F>,
/// Derivative at the end of the interval, when available.
pub final_derivative: Option<Array1<F>>,
/// Estimated integration error for this chunk.
pub error_estimate: F,
/// Wall-clock time spent processing the chunk.
pub processing_time: Duration,
/// Memory used in bytes (peak vs. total not visible here — TODO confirm).
pub memory_used: usize,
/// Outcome classification.
pub status: ChunkResultStatus,
}
/// Outcome of processing a single chunk.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChunkResultStatus {
/// Chunk completed successfully.
Success,
/// Chunk processing failed.
Failed,
/// Result indicates further refinement is needed before acceptance.
NeedsRefinement,
/// Processing was cancelled before completion.
Cancelled,
}
/// Tunable parameters for a distributed integration run.
#[derive(Debug, Clone)]
pub struct DistributedConfig<F: IntegrateFloat> {
/// Smallest allowed chunk time-interval width.
pub min_chunk_size: F,
/// Largest allowed chunk time-interval width.
pub max_chunk_size: F,
/// Target number of chunks assigned per node.
pub chunks_per_node: usize,
/// Integration error tolerance.
pub tolerance: F,
/// Upper bound on iterations.
pub max_iterations: usize,
/// Whether periodic checkpoints are taken.
pub checkpointing_enabled: bool,
/// Spacing between checkpoints (unit not visible here — presumably
/// iterations; TODO confirm).
pub checkpoint_interval: usize,
/// Timeout applied to inter-node communication.
pub communication_timeout: Duration,
/// Interval between node heartbeats.
pub heartbeat_interval: Duration,
/// Maximum retries per failed chunk.
pub max_retries: u32,
/// Strategy used to distribute chunks across nodes.
pub load_balancing: LoadBalancingStrategy,
/// Level of fault tolerance to apply.
pub fault_tolerance: FaultToleranceMode,
}
/// Defaults: chunk widths in [0.001, 1.0], 4 chunks/node, tolerance 1e-6,
/// checkpoint every 10 with checkpointing on, 30 s comm timeout, 5 s
/// heartbeat, 3 retries, adaptive balancing, standard fault tolerance.
/// Falls back to `F::epsilon()` / `F::one()` when a constant is not
/// representable in `F`.
impl<F: IntegrateFloat> Default for DistributedConfig<F> {
fn default() -> Self {
Self {
min_chunk_size: F::from(0.001).unwrap_or(F::epsilon()),
max_chunk_size: F::from(1.0).unwrap_or(F::one()),
chunks_per_node: 4,
tolerance: F::from(1e-6).unwrap_or(F::epsilon()),
max_iterations: 1000,
checkpointing_enabled: true,
checkpoint_interval: 10,
communication_timeout: Duration::from_secs(30),
heartbeat_interval: Duration::from_secs(5),
max_retries: 3,
load_balancing: LoadBalancingStrategy::Adaptive,
fault_tolerance: FaultToleranceMode::Standard,
}
}
}
/// How chunks are assigned to nodes. Variant semantics are implemented
/// by the scheduler, which is not visible in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
/// Assign chunks to nodes in rotation.
RoundRobin,
/// Weight assignment by advertised node capabilities.
CapabilityBased,
/// Idle nodes take queued work from busy ones.
WorkStealing,
/// Adjust assignment dynamically from observed performance.
Adaptive,
/// Prefer nodes close to the data they need.
LocalityAware,
}
/// Degree of fault tolerance for a run. Variant semantics are implemented
/// by the coordinator, which is not visible in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FaultToleranceMode {
/// No failure handling.
None,
/// Default failure handling (retries).
Standard,
/// Stronger guarantees for high availability.
HighAvailability,
/// Recover failed work from checkpoints.
CheckpointRecovery,
}
/// Messages exchanged between coordinator and worker nodes.
#[derive(Debug, Clone)]
pub enum DistributedMessage<F: IntegrateFloat> {
/// Periodic liveness signal from a node.
Heartbeat {
node_id: NodeId,
status: NodeStatus,
/// Sender-side timestamp (epoch/units defined by sender — TODO confirm).
timestamp: u64,
},
/// Assignment of one chunk to a node, with an optional deadline.
WorkAssignment {
chunk: WorkChunk<F>,
deadline: Option<Duration>,
},
/// Completed (or failed) chunk result returned by a worker.
WorkResult { result: ChunkResult<F> },
/// Boundary data passed from one chunk to a neighbouring chunk.
BoundaryExchange {
source_chunk: ChunkId,
target_chunk: ChunkId,
boundary_data: BoundaryData<F>,
},
/// Request that nodes produce checkpoint data for a job.
CheckpointRequest { job_id: JobId, checkpoint_id: u64 },
/// Serialized checkpoint payload from one node.
CheckpointData {
job_id: JobId,
checkpoint_id: u64,
node_id: NodeId,
/// Opaque serialized checkpoint bytes (format not visible here).
data: Vec<u8>,
},
/// A node joining the cluster, advertising its capabilities.
NodeRegister {
node_id: NodeId,
address: SocketAddr,
capabilities: NodeCapabilities,
},
/// A node leaving the cluster, with a human-readable reason.
NodeDeregister { node_id: NodeId, reason: String },
/// Cancellation of an entire job, with a human-readable reason.
JobCancel { job_id: JobId, reason: String },
/// Participation in a synchronization barrier.
SyncBarrier { barrier_id: u64, node_id: NodeId },
/// Acknowledgement of a previously sent message.
Ack { message_id: u64, status: AckStatus },
}
/// Status carried by an `Ack` message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AckStatus {
/// Message handled successfully.
Ok,
/// Message handling failed.
Error,
/// Outcome could not be determined.
Unknown,
}
/// Aggregate counters and timings for a distributed run.
/// All fields start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct DistributedMetrics {
/// Chunks processed.
pub chunks_processed: usize,
/// Chunks that failed.
pub chunks_failed: usize,
/// Chunks that were retried.
pub chunks_retried: usize,
/// Total time spent processing chunks.
pub total_processing_time: Duration,
/// Total time spent in communication.
pub total_communication_time: Duration,
/// Average time per chunk.
pub average_chunk_time: Duration,
/// Load-balance quality in [0, 1]; maintained by `update_load_balance`.
pub load_balance_efficiency: f64,
/// Bytes sent over the network.
pub bytes_sent: usize,
/// Bytes received over the network.
pub bytes_received: usize,
/// Checkpoints created.
pub checkpoints_created: usize,
/// Recoveries performed.
pub recoveries: usize,
}
impl DistributedMetrics {
    /// Recomputes `load_balance_efficiency` from per-node load samples.
    ///
    /// Efficiency is `1 - cv`, clamped to [0, 1], where `cv` is the
    /// coefficient of variation (population std-dev / mean) of the loads.
    /// An empty slice or a non-positive mean counts as perfectly balanced.
    pub fn update_load_balance(&mut self, node_loads: &[f64]) {
        let n = node_loads.len();
        if n == 0 {
            self.load_balance_efficiency = 1.0;
            return;
        }
        let mean = node_loads.iter().sum::<f64>() / n as f64;
        if mean <= 0.0 {
            self.load_balance_efficiency = 1.0;
            return;
        }
        // Population variance of the loads.
        let variance = node_loads
            .iter()
            .map(|&load| {
                let deviation = load - mean;
                deviation * deviation
            })
            .sum::<f64>()
            / n as f64;
        let coefficient_of_variation = variance.sqrt() / mean;
        self.load_balance_efficiency = (1.0 - coefficient_of_variation.min(1.0)).max(0.0);
    }
}
/// Errors arising from distributed coordination and execution.
#[derive(Debug, Clone)]
pub enum DistributedError {
/// Failure to send or receive a message.
CommunicationError(String),
/// A node missed its communication deadline.
NodeTimeout(NodeId),
/// A node failed, with a descriptive message.
NodeFailure(NodeId, String),
/// Processing of a specific chunk failed.
ChunkError(ChunkId, String),
/// Synchronization (e.g. barrier) failure.
SyncError(String),
/// Checkpoint creation or restoration failure.
CheckpointError(String),
/// Invalid or inconsistent configuration.
ConfigError(String),
/// A resource limit (memory, nodes, …) was exceeded.
ResourceExhausted(String),
}
impl std::fmt::Display for DistributedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::CommunicationError(msg) => write!(f, "Communication error: {}", msg),
Self::NodeTimeout(id) => write!(f, "Node {} timed out", id),
Self::NodeFailure(id, msg) => write!(f, "Node {} failed: {}", id, msg),
Self::ChunkError(id, msg) => write!(f, "Chunk {:?} error: {}", id, msg),
Self::SyncError(msg) => write!(f, "Synchronization error: {}", msg),
Self::CheckpointError(msg) => write!(f, "Checkpoint error: {}", msg),
Self::ConfigError(msg) => write!(f, "Configuration error: {}", msg),
Self::ResourceExhausted(msg) => write!(f, "Resource exhausted: {}", msg),
}
}
}
impl std::error::Error for DistributedError {}
/// Bridge into the crate-wide error type so distributed failures can be
/// propagated with `?` where `IntegrateError` is expected.
impl From<DistributedError> for IntegrateError {
fn from(err: DistributedError) -> Self {
// Collapses to a string message; variant detail survives only in the text.
IntegrateError::ComputationError(err.to_string())
}
}
/// Convenience result alias for the distributed module.
pub type DistributedResult<T> = std::result::Result<T, DistributedError>;
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    #[test]
    fn test_node_id_display() {
        let id = NodeId::new(42);
        assert_eq!(format!("{}", id), "Node(42)");
    }

    #[test]
    fn test_node_info_health_check() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let mut node = NodeInfo::new(NodeId::new(1), addr);
        node.status = NodeStatus::Available;
        assert!(node.is_healthy(Duration::from_secs(60)));
        // `Instant - Duration` panics on underflow on platforms whose
        // monotonic epoch is recent; use checked_sub and skip the
        // stale-heartbeat case when the timestamp is unrepresentable.
        if let Some(stale) = Instant::now().checked_sub(Duration::from_secs(120)) {
            node.last_heartbeat = stale;
            assert!(!node.is_healthy(Duration::from_secs(60)));
        }
        // A failed node is unhealthy regardless of heartbeat freshness.
        node.last_heartbeat = Instant::now();
        node.status = NodeStatus::Failed;
        assert!(!node.is_healthy(Duration::from_secs(60)));
    }

    #[test]
    fn test_work_chunk_retry() {
        // Declare mutable up front instead of rebinding mid-test.
        let mut chunk: WorkChunk<f64> =
            WorkChunk::new(ChunkId::new(1), JobId::new(1), (0.0, 1.0), Array1::zeros(3));
        assert!(chunk.can_retry());
        for _ in 0..3 {
            chunk.increment_retry();
        }
        assert!(!chunk.can_retry());
    }

    #[test]
    fn test_distributed_metrics_load_balance() {
        let mut metrics = DistributedMetrics::default();
        // Empty input is defined as perfectly balanced.
        metrics.update_load_balance(&[]);
        assert_eq!(metrics.load_balance_efficiency, 1.0);
        // Uniform loads are perfectly balanced.
        metrics.update_load_balance(&[1.0, 1.0, 1.0, 1.0]);
        assert!((metrics.load_balance_efficiency - 1.0).abs() < 0.01);
        // Heavily skewed loads score poorly.
        metrics.update_load_balance(&[0.1, 0.1, 0.1, 3.7]);
        assert!(metrics.load_balance_efficiency < 0.5);
    }
}