1use crate::error::{SpatialError, SpatialResult};
60use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
61use std::collections::{BTreeMap, HashMap, VecDeque};
62use std::sync::Arc;
63use std::time::{Duration, Instant};
64use tokio::sync::{mpsc, RwLock as TokioRwLock};
65
/// Settings used to build a `DistributedSpatialCluster`.
///
/// Start from [`NodeConfig::new`] (equivalently `NodeConfig::default()`) and refine the
/// result with the `with_*` builder methods, each of which consumes and returns the config.
#[derive(Debug, Clone)]
pub struct NodeConfig {
    /// Number of nodes the cluster creates.
    pub node_count: usize,
    /// When true, data partitions are assigned a replica node in addition to their primary.
    pub fault_tolerance: bool,
    /// Load-balancing toggle (default `false`).
    pub load_balancing: bool,
    /// Copied into the communication layer's compression flag.
    pub compression: bool,
    /// Communication timeout in milliseconds (default 5000).
    pub communication_timeout_ms: u64,
    /// Heartbeat interval in milliseconds (default 1000).
    pub heartbeat_interval_ms: u64,
    /// Maximum retry count (default 3).
    pub max_retries: usize,
    /// Desired copies per partition; `with_fault_tolerance(true)` raises it to at least 2.
    pub replication_factor: usize,
}

impl Default for NodeConfig {
    /// One node, all optional features off, 5000 ms timeout, 1000 ms heartbeat,
    /// 3 retries and a replication factor of 1.
    fn default() -> Self {
        NodeConfig {
            node_count: 1,
            fault_tolerance: false,
            load_balancing: false,
            compression: false,
            communication_timeout_ms: 5000,
            heartbeat_interval_ms: 1000,
            max_retries: 3,
            replication_factor: 1,
        }
    }
}

impl NodeConfig {
    /// Returns the default configuration (see the `Default` impl for the values).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets how many nodes the cluster will create.
    pub fn with_node_count(self, count: usize) -> Self {
        Self {
            node_count: count,
            ..self
        }
    }

    /// Enables or disables fault tolerance.
    ///
    /// Enabling it guarantees `replication_factor >= 2`; disabling it leaves the
    /// replication factor untouched.
    pub fn with_fault_tolerance(self, enabled: bool) -> Self {
        let replication_factor = match enabled {
            true => self.replication_factor.max(2),
            false => self.replication_factor,
        };
        Self {
            fault_tolerance: enabled,
            replication_factor,
            ..self
        }
    }

    /// Enables or disables load balancing.
    pub fn with_load_balancing(self, enabled: bool) -> Self {
        Self {
            load_balancing: enabled,
            ..self
        }
    }

    /// Enables or disables message compression.
    pub fn with_compression(self, enabled: bool) -> Self {
        Self {
            compression: enabled,
            ..self
        }
    }
}
135
/// A cluster of spatial-data nodes, all held in memory by this struct.
///
/// Each node is an `Arc<TokioRwLock<NodeInstance>>`. `distribute_data` splits input points
/// into partitions, copies them onto nodes and builds per-node indices; the query methods
/// (k-means, k-NN) then walk the nodes in vector order.
#[derive(Debug)]
pub struct DistributedSpatialCluster {
    /// Configuration the cluster was created with.
    config: NodeConfig,
    /// One entry per node; a node's position in this vector is its node id.
    nodes: Vec<Arc<TokioRwLock<NodeInstance>>>,
    // Always initialised to 0 and not otherwise read in this file, hence the allow.
    #[allow(dead_code)]
    master_node_id: usize,
    /// Partition id -> partition; replaced as a whole each time data is distributed.
    partitions: Arc<TokioRwLock<HashMap<usize, DataPartition>>>,
    /// Load-balancer state (strategy, threshold, per-node loads).
    load_balancer: Arc<TokioRwLock<LoadBalancer>>,
    // Constructed in `new` but not consulted elsewhere in this file, hence the allow.
    #[allow(dead_code)]
    fault_detector: Arc<TokioRwLock<FaultDetector>>,
    /// Per-node message senders plus communication counters.
    communication: Arc<TokioRwLock<CommunicationLayer>>,
    /// Aggregate counters and health/performance figures reported by statistics.
    cluster_state: Arc<TokioRwLock<ClusterState>>,
}
158
/// State of a single node in the cluster.
#[derive(Debug)]
pub struct NodeInstance {
    /// Identifier; equals the node's position in the cluster's node vector.
    pub node_id: usize,
    /// Current lifecycle status (`Active` when created).
    pub status: NodeStatus,
    /// Rows of every partition copied onto this node, concatenated; `None` until data
    /// is distributed.
    pub local_data: Option<Array2<f64>>,
    /// Index built over `local_data`; `None` until indices are built.
    pub local_index: Option<DistributedSpatialIndex>,
    /// Resource-usage snapshot; `partition_count` is bumped per assigned partition.
    pub load_metrics: LoadMetrics,
    /// Time of the last heartbeat (set at creation in this file).
    pub last_heartbeat: Instant,
    /// Ids of the partitions (primary or replica) held by this node.
    pub assigned_partitions: Vec<usize>,
}

/// Lifecycle status of a node.
///
/// Nodes are created `Active`; no code visible in this file moves a node to any other
/// variant, so the remaining variants' semantics are only suggested by their names.
#[derive(Debug, Clone, PartialEq)]
pub enum NodeStatus {
    Active,
    Inactive,
    Failed,
    Recovering,
    Joining,
    Leaving,
}
188
/// A contiguous chunk of the z-order-sorted input points, plus its placement.
#[derive(Debug, Clone)]
pub struct DataPartition {
    /// Partition identifier, assigned sequentially from 0.
    pub partition_id: usize,
    /// Tight axis-aligned bounds of the rows in `data`.
    pub bounds: SpatialBounds,
    /// The partition's points, one per row.
    pub data: Array2<f64>,
    /// Node that owns the primary copy.
    pub primary_node: usize,
    /// Nodes holding replica copies (empty when fault tolerance is off).
    pub replica_nodes: Vec<usize>,
    /// Number of rows in `data`.
    pub size: usize,
    /// When the partition was created or last changed.
    pub last_modified: Instant,
}
207
208#[derive(Debug, Clone)]
210pub struct SpatialBounds {
211 pub min_coords: Array1<f64>,
213 pub max_coords: Array1<f64>,
215}
216
217impl SpatialBounds {
218 pub fn contains(&self, point: &ArrayView1<f64>) -> bool {
220 point
221 .iter()
222 .zip(self.min_coords.iter())
223 .zip(self.max_coords.iter())
224 .all(|((&coord, &min_coord), &max_coord)| coord >= min_coord && coord <= max_coord)
225 }
226
227 pub fn volume(&self) -> f64 {
229 self.min_coords
230 .iter()
231 .zip(self.max_coords.iter())
232 .map(|(&min_coord, &max_coord)| max_coord - min_coord)
233 .product()
234 }
235}
236
/// Load-balancer state.
///
/// Every field is currently write-only within this file (set in
/// `DistributedSpatialCluster::new`), hence the `dead_code` allows.
#[derive(Debug)]
pub struct LoadBalancer {
    // Per-node load snapshots keyed by node id; initialised empty.
    #[allow(dead_code)]
    node_loads: HashMap<usize, LoadMetrics>,
    // Strategy in use; `new` selects `AdaptiveLoad`.
    #[allow(dead_code)]
    strategy: LoadBalancingStrategy,
    // When a rebalance last happened; set to the creation time.
    #[allow(dead_code)]
    last_rebalance: Instant,
    // Load threshold; `new` sets 0.8 (presumably compared against `load_score`, which
    // is at most 1 for fractional utilisations — TODO confirm once rebalancing exists).
    #[allow(dead_code)]
    load_threshold: f64,
}

/// Available load-balancing strategies.
///
/// No balancing logic is visible in this file, so the variants' exact behaviour is
/// only suggested by their names.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    LeastLoaded,
    ProportionalLoad,
    AdaptiveLoad,
}
262
/// Resource-usage snapshot for a single node.
#[derive(Debug, Clone)]
pub struct LoadMetrics {
    /// CPU utilisation, presumably a fraction in `[0, 1]` (not enforced).
    pub cpu_utilization: f64,
    /// Memory utilisation, presumably a fraction in `[0, 1]` (not enforced).
    pub memory_utilization: f64,
    /// Network utilisation, presumably a fraction in `[0, 1]` (not enforced).
    pub network_utilization: f64,
    /// Number of partitions held by the node.
    pub partition_count: usize,
    /// Number of operations performed; not part of `load_score`.
    pub operation_count: usize,
    /// When these metrics were last refreshed.
    pub last_update: Instant,
}

impl LoadMetrics {
    const CPU_WEIGHT: f64 = 0.4;
    const MEMORY_WEIGHT: f64 = 0.3;
    const NETWORK_WEIGHT: f64 = 0.2;
    const PARTITION_WEIGHT: f64 = 0.1;
    /// Partition count at which the partition term saturates at 1.0.
    const PARTITION_SATURATION: f64 = 10.0;

    /// Weighted load score: 40% CPU, 30% memory, 20% network and 10% partition
    /// pressure, where partition pressure is `partition_count / 10` capped at 1.0.
    pub fn load_score(&self) -> f64 {
        let partition_pressure = (self.partition_count as f64 / Self::PARTITION_SATURATION).min(1.0);

        let cpu_term = Self::CPU_WEIGHT * self.cpu_utilization;
        let memory_term = Self::MEMORY_WEIGHT * self.memory_utilization;
        let network_term = Self::NETWORK_WEIGHT * self.network_utilization;
        let partition_term = Self::PARTITION_WEIGHT * partition_pressure;

        // Summed left to right, matching the original expression's evaluation order.
        cpu_term + memory_term + network_term + partition_term
    }
}
289
/// Fault-detection state.
///
/// Fields are populated in `DistributedSpatialCluster::new` but not read elsewhere in this
/// file, hence the `dead_code` allows.
#[derive(Debug)]
pub struct FaultDetector {
    // Per-node health records keyed by node id.
    #[allow(dead_code)]
    node_health: HashMap<usize, NodeHealth>,
    // Silence threshold after which a node would be considered failed (10 s in `new`).
    #[allow(dead_code)]
    failure_threshold: Duration,
    // Strategy to apply for each kind of detected failure.
    #[allow(dead_code)]
    recovery_strategies: HashMap<FailureType, RecoveryStrategy>,
}

/// Health record for one node.
#[derive(Debug, Clone)]
pub struct NodeHealth {
    /// When the node was last heard from.
    pub last_contact: Instant,
    /// Number of failures observed in a row.
    pub consecutive_failures: usize,
    /// Recent response times.
    pub response_times: VecDeque<Duration>,
    /// Aggregate health score (scale not defined in this file).
    pub health_score: f64,
}

/// Kinds of failure the detector distinguishes; usable as a map key (`Hash + Eq`).
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum FailureType {
    NodeUnresponsive,
    HighLatency,
    ResourceExhaustion,
    PartialFailure,
    NetworkPartition,
}

/// Recovery actions that can be taken in response to a failure.
#[derive(Debug, Clone)]
pub enum RecoveryStrategy {
    Restart,
    Migrate,
    Replicate,
    Isolate,
    WaitAndRetry,
}
336
/// Per-node message channels plus communication counters.
#[derive(Debug)]
pub struct CommunicationLayer {
    // Sending half of a bounded tokio mpsc channel per node id; no code in this file sends
    // on them yet, hence the allow.
    #[allow(dead_code)]
    channels: HashMap<usize, mpsc::Sender<DistributedMessage>>,
    // Mirrors `NodeConfig::compression`; not read elsewhere in this file.
    #[allow(dead_code)]
    compression_enabled: bool,
    /// Counters surfaced by `get_cluster_statistics` (initialised to zero).
    stats: CommunicationStats,
}

/// Cumulative message/byte counters and average latency.
#[derive(Debug, Clone)]
pub struct CommunicationStats {
    pub messages_sent: u64,
    pub messages_received: u64,
    pub bytes_sent: u64,
    pub bytes_received: u64,
    /// Average latency in milliseconds.
    pub average_latency_ms: f64,
}
364
/// Message type carried by the per-node channels.
///
/// No code visible in this file constructs or sends these messages; the variant
/// descriptions below follow from their field names.
#[derive(Debug, Clone)]
pub enum DistributedMessage {
    /// Liveness signal carrying the sender's current load.
    Heartbeat {
        node_id: usize,
        timestamp: Instant,
        load_metrics: LoadMetrics,
    },
    /// Ships a partition's rows and bounds to a node.
    DataDistribution {
        partition_id: usize,
        data: Array2<f64>,
        bounds: SpatialBounds,
    },
    /// A query request identified by `query_id`.
    Query {
        query_id: usize,
        query_type: QueryType,
        parameters: QueryParameters,
    },
    /// A node's results for the query with the same `query_id`.
    QueryResponse {
        query_id: usize,
        results: QueryResults,
        node_id: usize,
    },
    /// A partition-migration plan to execute.
    LoadBalance { rebalance_plan: RebalancePlan },
    /// A failure notification together with the planned recovery.
    FaultTolerance {
        failure_type: FailureType,
        affected_nodes: Vec<usize>,
        recovery_plan: RecoveryPlan,
    },
}
401
/// Kind of distributed query a `Query` message requests.
#[derive(Debug, Clone)]
pub enum QueryType {
    KNearestNeighbors,
    RangeSearch,
    Clustering,
    DistanceMatrix,
}

/// Parameters for a query; which fields are meaningful depends on the `QueryType`
/// (e.g. `k` for nearest neighbours, `radius` for range search — inferred from names).
#[derive(Debug, Clone)]
pub struct QueryParameters {
    /// Query location, when the query has one.
    pub query_point: Option<Array1<f64>>,
    /// Search radius for range-style queries.
    pub radius: Option<f64>,
    /// Neighbour count for k-NN queries.
    pub k: Option<usize>,
    /// Cluster count for clustering queries.
    pub num_clusters: Option<usize>,
    /// Free-form additional numeric parameters keyed by name.
    pub extra_params: HashMap<String, f64>,
}

/// Results returned for each `QueryType`.
#[derive(Debug, Clone)]
pub enum QueryResults {
    /// Neighbour indices with the matching distances at the same positions.
    NearestNeighbors {
        indices: Vec<usize>,
        distances: Vec<f64>,
    },
    /// Indices of the matched points and their coordinates (one row per match).
    RangeSearch {
        indices: Vec<usize>,
        points: Array2<f64>,
    },
    /// Cluster centroids (one row per cluster) and a cluster id per point.
    Clustering {
        centroids: Array2<f64>,
        assignments: Array1<usize>,
    },
    /// Pairwise distance matrix.
    DistanceMatrix {
        matrix: Array2<f64>,
    },
}
445
/// A set of partition migrations proposed by load balancing.
#[derive(Debug, Clone)]
pub struct RebalancePlan {
    /// Migrations to perform.
    pub migrations: Vec<PartitionMigration>,
    /// Expected load improvement (units not defined in this file).
    pub load_improvement: f64,
    /// Estimated cost of performing the migrations (units not defined in this file).
    pub migration_cost: f64,
}

/// Move of one partition between nodes.
#[derive(Debug, Clone)]
pub struct PartitionMigration {
    pub partition_id: usize,
    /// Source node id.
    pub from_node: usize,
    /// Destination node id.
    pub to_node: usize,
    /// Relative priority of this migration.
    pub priority: f64,
}

/// Ordered recovery actions with an estimate of duration and success.
#[derive(Debug, Clone)]
pub struct RecoveryPlan {
    pub actions: Vec<RecoveryAction>,
    pub estimated_recovery_time: Duration,
    /// Estimated probability of success, presumably in `[0, 1]`.
    pub success_probability: f64,
}

/// A single recovery step applied to one node.
#[derive(Debug, Clone)]
pub struct RecoveryAction {
    pub action_type: RecoveryStrategy,
    pub target_node: usize,
    /// Strategy-specific string parameters.
    pub parameters: HashMap<String, String>,
}
491
/// Cluster-wide counters and health figures.
#[derive(Debug)]
pub struct ClusterState {
    /// Ids of active nodes; `new` lists every node.
    pub active_nodes: Vec<usize>,
    /// Number of input points from the most recent `distribute_data` call.
    pub total_data_points: usize,
    /// Number of partitions from the most recent `distribute_data` call.
    pub total_partitions: usize,
    /// Overall health score; initialised to 1.0.
    pub health_score: f64,
    /// Aggregate performance figures.
    pub performance_metrics: ClusterPerformanceMetrics,
}

/// Aggregate performance figures reported through `ClusterStatistics`.
#[derive(Debug, Clone)]
pub struct ClusterPerformanceMetrics {
    /// Average query latency in milliseconds.
    pub avg_query_latency_ms: f64,
    /// Query throughput, in queries per second.
    pub throughput_qps: f64,
    /// Load-balance score; initialised to 1.0.
    pub load_balance_score: f64,
    /// Fault-tolerance level; `new` sets 0.8 with fault tolerance enabled, else 0.0.
    pub fault_tolerance_level: f64,
}
519
/// Per-node index: local kd-tree plus cluster-wide metadata and routing information.
#[derive(Debug)]
pub struct DistributedSpatialIndex {
    /// Index over the node's own `local_data`.
    pub local_index: LocalSpatialIndex,
    /// Cluster-level metadata recorded when the index was built.
    pub global_metadata: GlobalIndexMetadata,
    /// Routing information (created empty by the index builder in this file).
    pub routing_table: RoutingTable,
}

/// Kd-tree over one node's data together with that data's bounds and statistics.
#[derive(Debug)]
pub struct LocalSpatialIndex {
    /// Euclidean kd-tree built from the node's `local_data`, when built.
    pub kdtree: Option<crate::KDTree<f64, crate::EuclideanDistance<f64>>>,
    /// Tight bounds of the indexed points.
    pub bounds: SpatialBounds,
    /// Build/usage statistics.
    pub stats: IndexStatistics,
}

/// Cluster-wide view attached to each node's index.
#[derive(Debug, Clone)]
pub struct GlobalIndexMetadata {
    /// Bounds recorded as the index's global extent by the index builder.
    pub global_bounds: SpatialBounds,
    /// Partition id -> partition bounds.
    pub partition_map: HashMap<usize, SpatialBounds>,
    /// Metadata version number (the builder in this file writes 1).
    pub version: usize,
}

/// Maps spatial keys to node ids; the builder in this file leaves both maps empty.
#[derive(Debug)]
pub struct RoutingTable {
    /// Ordered key -> node ids.
    pub entries: BTreeMap<SpatialKey, Vec<usize>>,
    /// Lookup cache with the same key/value shape as `entries`.
    pub cache: HashMap<SpatialKey, Vec<usize>>,
}

/// Routing key. The derived `Ord` compares `z_order` first and then `level`, following the
/// field declaration order — do not reorder the fields.
#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)]
pub struct SpatialKey {
    pub z_order: u64,
    pub level: usize,
}

/// Statistics recorded for a local index.
#[derive(Debug, Clone)]
pub struct IndexStatistics {
    /// Index build time in milliseconds, as recorded by the builder.
    pub build_time_ms: f64,
    /// Estimated bytes of indexed coordinates (points × dims × size of `f64` in the builder).
    pub memory_usage_bytes: usize,
    /// Number of queries served (initialised to 0).
    pub query_count: u64,
    /// Average query time in milliseconds (initialised to 0).
    pub avg_query_time_ms: f64,
}
583
584impl DistributedSpatialCluster {
585 pub fn new(config: NodeConfig) -> SpatialResult<Self> {
587 let mut nodes = Vec::new();
588 let mut channels = HashMap::new();
589
590 for node_id in 0..config.node_count {
592 let (sender, receiver) = mpsc::channel(1000);
593 channels.insert(node_id, sender);
594
595 let node = NodeInstance {
596 node_id,
597 status: NodeStatus::Active,
598 local_data: None,
599 local_index: None,
600 load_metrics: LoadMetrics {
601 cpu_utilization: 0.0,
602 memory_utilization: 0.0,
603 network_utilization: 0.0,
604 partition_count: 0,
605 operation_count: 0,
606 last_update: Instant::now(),
607 },
608 last_heartbeat: Instant::now(),
609 assigned_partitions: Vec::new(),
610 };
611
612 nodes.push(Arc::new(TokioRwLock::new(node)));
613 }
614
615 let load_balancer = LoadBalancer {
616 node_loads: HashMap::new(),
617 strategy: LoadBalancingStrategy::AdaptiveLoad,
618 last_rebalance: Instant::now(),
619 load_threshold: 0.8,
620 };
621
622 let fault_detector = FaultDetector {
623 node_health: HashMap::new(),
624 failure_threshold: Duration::from_secs(10),
625 recovery_strategies: HashMap::new(),
626 };
627
628 let communication = CommunicationLayer {
629 channels,
630 compression_enabled: config.compression,
631 stats: CommunicationStats {
632 messages_sent: 0,
633 messages_received: 0,
634 bytes_sent: 0,
635 bytes_received: 0,
636 average_latency_ms: 0.0,
637 },
638 };
639
640 let cluster_state = ClusterState {
641 active_nodes: (0..config.node_count).collect(),
642 total_data_points: 0,
643 total_partitions: 0,
644 health_score: 1.0,
645 performance_metrics: ClusterPerformanceMetrics {
646 avg_query_latency_ms: 0.0,
647 throughput_qps: 0.0,
648 load_balance_score: 1.0,
649 fault_tolerance_level: if config.fault_tolerance { 0.8 } else { 0.0 },
650 },
651 };
652
653 Ok(Self {
654 config,
655 nodes,
656 master_node_id: 0,
657 partitions: Arc::new(TokioRwLock::new(HashMap::new())),
658 load_balancer: Arc::new(TokioRwLock::new(load_balancer)),
659 fault_detector: Arc::new(TokioRwLock::new(fault_detector)),
660 communication: Arc::new(TokioRwLock::new(communication)),
661 cluster_state: Arc::new(TokioRwLock::new(cluster_state)),
662 })
663 }
664
665 #[allow(dead_code)]
667 fn default_recovery_strategies(&self) -> HashMap<FailureType, RecoveryStrategy> {
668 let mut strategies = HashMap::new();
669 strategies.insert(FailureType::NodeUnresponsive, RecoveryStrategy::Restart);
670 strategies.insert(FailureType::HighLatency, RecoveryStrategy::WaitAndRetry);
671 strategies.insert(FailureType::ResourceExhaustion, RecoveryStrategy::Migrate);
672 strategies.insert(FailureType::PartialFailure, RecoveryStrategy::Replicate);
673 strategies.insert(FailureType::NetworkPartition, RecoveryStrategy::Isolate);
674 strategies
675 }
676
677 pub async fn distribute_data(&mut self, data: &ArrayView2<'_, f64>) -> SpatialResult<()> {
679 let (n_points, n_dims) = data.dim();
680
681 let partitions = self.create_spatial_partitions(data).await?;
683
684 self.assign_partitions_to_nodes(&partitions).await?;
686
687 self.build_distributed_indices().await?;
689
690 {
692 let mut state = self.cluster_state.write().await;
693 state.total_data_points = n_points;
694 state.total_partitions = partitions.len();
695 }
696
697 Ok(())
698 }
699
700 async fn create_spatial_partitions(
702 &self,
703 data: &ArrayView2<'_, f64>,
704 ) -> SpatialResult<Vec<DataPartition>> {
705 let (n_points, n_dims) = data.dim();
706 let target_partitions = self.config.node_count * 2; let mut min_coords = Array1::from_elem(n_dims, f64::INFINITY);
710 let mut max_coords = Array1::from_elem(n_dims, f64::NEG_INFINITY);
711
712 for point in data.outer_iter() {
713 for (i, &coord) in point.iter().enumerate() {
714 min_coords[i] = min_coords[i].min(coord);
715 max_coords[i] = max_coords[i].max(coord);
716 }
717 }
718
719 let global_bounds = SpatialBounds {
720 min_coords,
721 max_coords,
722 };
723
724 let mut point_z_orders = Vec::new();
726 for (i, point) in data.outer_iter().enumerate() {
727 let z_order = self.calculate_z_order(&point.to_owned(), &global_bounds, 16);
728 point_z_orders.push((i, z_order, point.to_owned()));
729 }
730
731 point_z_orders.sort_by_key(|(_, z_order_, _)| *z_order_);
733
734 let points_per_partition = n_points.div_ceil(target_partitions);
736 let mut partitions = Vec::new();
737
738 for partition_id in 0..target_partitions {
739 let start_idx = partition_id * points_per_partition;
740 let end_idx = ((partition_id + 1) * points_per_partition).min(n_points);
741
742 if start_idx >= n_points {
743 break;
744 }
745
746 let partition_size = end_idx - start_idx;
748 let mut partition_data = Array2::zeros((partition_size, n_dims));
749 let mut partition_min = Array1::from_elem(n_dims, f64::INFINITY);
750 let mut partition_max = Array1::from_elem(n_dims, f64::NEG_INFINITY);
751
752 for (i, (_, _, point)) in point_z_orders[start_idx..end_idx].iter().enumerate() {
753 partition_data.row_mut(i).assign(point);
754
755 for (j, &coord) in point.iter().enumerate() {
756 partition_min[j] = partition_min[j].min(coord);
757 partition_max[j] = partition_max[j].max(coord);
758 }
759 }
760
761 let partition_bounds = SpatialBounds {
762 min_coords: partition_min,
763 max_coords: partition_max,
764 };
765
766 let partition = DataPartition {
767 partition_id,
768 bounds: partition_bounds,
769 data: partition_data,
770 primary_node: partition_id % self.config.node_count,
771 replica_nodes: if self.config.fault_tolerance {
772 vec![(partition_id + 1) % self.config.node_count]
773 } else {
774 Vec::new()
775 },
776 size: partition_size,
777 last_modified: Instant::now(),
778 };
779
780 partitions.push(partition);
781 }
782
783 Ok(partitions)
784 }
785
786 fn calculate_z_order(
788 &self,
789 point: &Array1<f64>,
790 bounds: &SpatialBounds,
791 resolution: usize,
792 ) -> u64 {
793 let mut z_order = 0u64;
794
795 for bit in 0..resolution {
796 for (dim, ((&coord, &min_coord), &max_coord)) in point
797 .iter()
798 .zip(bounds.min_coords.iter())
799 .zip(bounds.max_coords.iter())
800 .enumerate()
801 {
802 if dim >= 3 {
803 break;
804 } let normalized = if max_coord > min_coord {
807 (coord - min_coord) / (max_coord - min_coord)
808 } else {
809 0.5
810 };
811
812 let bit_val = if normalized >= 0.5 { 1u64 } else { 0u64 };
813 let bit_pos = bit * 3 + dim; if bit_pos < 64 {
816 z_order |= bit_val << bit_pos;
817 }
818 }
819 }
820
821 z_order
822 }
823
824 async fn assign_partitions_to_nodes(
826 &mut self,
827 partitions: &[DataPartition],
828 ) -> SpatialResult<()> {
829 let mut partition_map = HashMap::new();
830
831 for partition in partitions {
832 partition_map.insert(partition.partition_id, partition.clone());
833
834 let primary_node = &self.nodes[partition.primary_node];
836 {
837 let mut node = primary_node.write().await;
838 node.assigned_partitions.push(partition.partition_id);
839
840 if let Some(ref existing_data) = node.local_data {
842 let (existing_rows, cols) = existing_data.dim();
844 let (new_rows_, _) = partition.data.dim();
845 let total_rows = existing_rows + new_rows_;
846
847 let mut combined_data = Array2::zeros((total_rows, cols));
848 combined_data
849 .slice_mut(s![..existing_rows, ..])
850 .assign(existing_data);
851 combined_data
852 .slice_mut(s![existing_rows.., ..])
853 .assign(&partition.data);
854 node.local_data = Some(combined_data);
855 } else {
856 node.local_data = Some(partition.data.clone());
857 }
858
859 node.load_metrics.partition_count += 1;
860 }
861
862 for &replica_node_id in &partition.replica_nodes {
864 let replica_node = &self.nodes[replica_node_id];
865 let mut node = replica_node.write().await;
866 node.assigned_partitions.push(partition.partition_id);
867
868 if let Some(ref existing_data) = node.local_data {
870 let (existing_rows, cols) = existing_data.dim();
872 let (new_rows_, _) = partition.data.dim();
873 let total_rows = existing_rows + new_rows_;
874
875 let mut combined_data = Array2::zeros((total_rows, cols));
876 combined_data
877 .slice_mut(s![..existing_rows, ..])
878 .assign(existing_data);
879 combined_data
880 .slice_mut(s![existing_rows.., ..])
881 .assign(&partition.data);
882 node.local_data = Some(combined_data);
883 } else {
884 node.local_data = Some(partition.data.clone());
885 }
886
887 node.load_metrics.partition_count += 1;
888 }
889 }
890
891 {
892 let mut partitions_lock = self.partitions.write().await;
893 *partitions_lock = partition_map;
894 }
895
896 Ok(())
897 }
898
899 async fn build_distributed_indices(&mut self) -> SpatialResult<()> {
901 for node_arc in &self.nodes {
903 let mut node = node_arc.write().await;
904
905 if let Some(ref local_data) = node.local_data {
906 let (n_points, n_dims) = local_data.dim();
908 let mut min_coords = Array1::from_elem(n_dims, f64::INFINITY);
909 let mut max_coords = Array1::from_elem(n_dims, f64::NEG_INFINITY);
910
911 for point in local_data.outer_iter() {
912 for (i, &coord) in point.iter().enumerate() {
913 min_coords[i] = min_coords[i].min(coord);
914 max_coords[i] = max_coords[i].max(coord);
915 }
916 }
917
918 let local_bounds = SpatialBounds {
919 min_coords,
920 max_coords,
921 };
922
923 let kdtree = crate::KDTree::new(local_data)?;
925
926 let local_index = LocalSpatialIndex {
927 kdtree: Some(kdtree),
928 bounds: local_bounds.clone(),
929 stats: IndexStatistics {
930 build_time_ms: 0.0, memory_usage_bytes: n_points * n_dims * 8, query_count: 0,
933 avg_query_time_ms: 0.0,
934 },
935 };
936
937 let routing_table = RoutingTable {
939 entries: BTreeMap::new(),
940 cache: HashMap::new(),
941 };
942
943 let global_metadata = GlobalIndexMetadata {
945 global_bounds: local_bounds.clone(), partition_map: HashMap::new(),
947 version: 1,
948 };
949
950 let distributed_index = DistributedSpatialIndex {
951 local_index,
952 global_metadata,
953 routing_table,
954 };
955
956 node.local_index = Some(distributed_index);
957 }
958 }
959
960 Ok(())
961 }
962
963 pub async fn distributed_kmeans(
965 &mut self,
966 k: usize,
967 max_iterations: usize,
968 ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
969 let initial_centroids = self.initialize_distributed_centroids(k).await?;
971 let mut centroids = initial_centroids;
972
973 for _iteration in 0..max_iterations {
974 let local_assignments = self.distributed_assignment_step(¢roids).await?;
976
977 let new_centroids = self
979 .distributed_centroid_update(&local_assignments, k)
980 .await?;
981
982 let centroid_change = self.calculate_centroid_change(¢roids, &new_centroids);
984 if centroid_change < 1e-6 {
985 break;
986 }
987
988 centroids = new_centroids;
989 }
990
991 let final_assignments = self.collect_final_assignments(¢roids).await?;
993
994 Ok((centroids, final_assignments))
995 }
996
997 async fn initialize_distributed_centroids(&self, k: usize) -> SpatialResult<Array2<f64>> {
999 let first_centroid = self.get_random_point_from_cluster().await?;
1001
1002 let n_dims = first_centroid.len();
1003 let mut centroids = Array2::zeros((k, n_dims));
1004 centroids.row_mut(0).assign(&first_centroid);
1005
1006 for i in 1..k {
1008 let distances = self
1009 .compute_distributed_distances(¢roids.slice(s![..i, ..]))
1010 .await?;
1011 let next_centroid = self.select_next_centroid_weighted(&distances).await?;
1012 centroids.row_mut(i).assign(&next_centroid);
1013 }
1014
1015 Ok(centroids)
1016 }
1017
1018 async fn get_random_point_from_cluster(&self) -> SpatialResult<Array1<f64>> {
1020 for node_arc in &self.nodes {
1021 let node = node_arc.read().await;
1022 if let Some(ref local_data) = node.local_data {
1023 if local_data.nrows() > 0 {
1024 let idx = (rand::random::<f64>() * local_data.nrows() as f64) as usize;
1025 return Ok(local_data.row(idx).to_owned());
1026 }
1027 }
1028 }
1029
1030 Err(SpatialError::InvalidInput(
1031 "No data found in cluster".to_string(),
1032 ))
1033 }
1034
1035 async fn compute_distributed_distances(
1037 &self,
1038 centroids: &ArrayView2<'_, f64>,
1039 ) -> SpatialResult<Vec<f64>> {
1040 let mut all_distances = Vec::new();
1041
1042 for node_arc in &self.nodes {
1043 let node = node_arc.read().await;
1044 if let Some(ref local_data) = node.local_data {
1045 for point in local_data.outer_iter() {
1046 let mut min_distance = f64::INFINITY;
1047
1048 for centroid in centroids.outer_iter() {
1049 let distance: f64 = point
1050 .iter()
1051 .zip(centroid.iter())
1052 .map(|(&a, &b)| (a - b).powi(2))
1053 .sum::<f64>()
1054 .sqrt();
1055
1056 min_distance = min_distance.min(distance);
1057 }
1058
1059 all_distances.push(min_distance);
1060 }
1061 }
1062 }
1063
1064 Ok(all_distances)
1065 }
1066
1067 async fn select_next_centroid_weighted(
1069 &self,
1070 _distances: &[f64],
1071 ) -> SpatialResult<Array1<f64>> {
1072 let total_distance: f64 = _distances.iter().sum();
1073 let target = rand::random::<f64>() * total_distance;
1074
1075 let mut cumulative = 0.0;
1076 let mut point_index = 0;
1077
1078 for &distance in _distances {
1079 cumulative += distance;
1080 if cumulative >= target {
1081 break;
1082 }
1083 point_index += 1;
1084 }
1085
1086 let mut current_index = 0;
1088 for node_arc in &self.nodes {
1089 let node = node_arc.read().await;
1090 if let Some(ref local_data) = node.local_data {
1091 if current_index + local_data.nrows() > point_index {
1092 let local_index = point_index - current_index;
1093 return Ok(local_data.row(local_index).to_owned());
1094 }
1095 current_index += local_data.nrows();
1096 }
1097 }
1098
1099 Err(SpatialError::InvalidInput(
1100 "Point index out of range".to_string(),
1101 ))
1102 }
1103
1104 async fn distributed_assignment_step(
1106 &self,
1107 centroids: &Array2<f64>,
1108 ) -> SpatialResult<Vec<(usize, Array1<usize>)>> {
1109 let mut local_assignments = Vec::new();
1110
1111 for (node_id, node_arc) in self.nodes.iter().enumerate() {
1112 let node = node_arc.read().await;
1113 if let Some(ref local_data) = node.local_data {
1114 let (n_points_, _) = local_data.dim();
1115 let mut assignments = Array1::zeros(n_points_);
1116
1117 for (i, point) in local_data.outer_iter().enumerate() {
1118 let mut best_cluster = 0;
1119 let mut best_distance = f64::INFINITY;
1120
1121 for (j, centroid) in centroids.outer_iter().enumerate() {
1122 let distance: f64 = point
1123 .iter()
1124 .zip(centroid.iter())
1125 .map(|(&a, &b)| (a - b).powi(2))
1126 .sum::<f64>()
1127 .sqrt();
1128
1129 if distance < best_distance {
1130 best_distance = distance;
1131 best_cluster = j;
1132 }
1133 }
1134
1135 assignments[i] = best_cluster;
1136 }
1137
1138 local_assignments.push((node_id, assignments));
1139 }
1140 }
1141
1142 Ok(local_assignments)
1143 }
1144
1145 async fn distributed_centroid_update(
1147 &self,
1148 local_assignments: &[(usize, Array1<usize>)],
1149 k: usize,
1150 ) -> SpatialResult<Array2<f64>> {
1151 let mut cluster_sums: HashMap<usize, Array1<f64>> = HashMap::new();
1153 let mut cluster_counts: HashMap<usize, usize> = HashMap::new();
1154
1155 for (node_id, assignments) in local_assignments {
1156 let node = self.nodes[*node_id].read().await;
1157 if let Some(ref local_data) = node.local_data {
1158 let (_, n_dims) = local_data.dim();
1159
1160 for (i, &cluster) in assignments.iter().enumerate() {
1161 let point = local_data.row(i);
1162
1163 let cluster_sum = cluster_sums
1164 .entry(cluster)
1165 .or_insert_with(|| Array1::zeros(n_dims));
1166 let cluster_count = cluster_counts.entry(cluster).or_insert(0);
1167
1168 for (j, &coord) in point.iter().enumerate() {
1169 cluster_sum[j] += coord;
1170 }
1171 *cluster_count += 1;
1172 }
1173 }
1174 }
1175
1176 let n_dims = cluster_sums
1178 .values()
1179 .next()
1180 .map(|sum| sum.len())
1181 .unwrap_or(2);
1182
1183 let mut new_centroids = Array2::zeros((k, n_dims));
1184
1185 for cluster in 0..k {
1186 if let (Some(sum), Some(&count)) =
1187 (cluster_sums.get(&cluster), cluster_counts.get(&cluster))
1188 {
1189 if count > 0 {
1190 for j in 0..n_dims {
1191 new_centroids[[cluster, j]] = sum[j] / count as f64;
1192 }
1193 }
1194 }
1195 }
1196
1197 Ok(new_centroids)
1198 }
1199
1200 fn calculate_centroid_change(
1202 &self,
1203 old_centroids: &Array2<f64>,
1204 new_centroids: &Array2<f64>,
1205 ) -> f64 {
1206 let mut total_change = 0.0;
1207
1208 for (old_row, new_row) in old_centroids.outer_iter().zip(new_centroids.outer_iter()) {
1209 let change: f64 = old_row
1210 .iter()
1211 .zip(new_row.iter())
1212 .map(|(&a, &b)| (a - b).powi(2))
1213 .sum::<f64>()
1214 .sqrt();
1215 total_change += change;
1216 }
1217
1218 total_change / old_centroids.nrows() as f64
1219 }
1220
1221 async fn collect_final_assignments(
1223 &self,
1224 centroids: &Array2<f64>,
1225 ) -> SpatialResult<Array1<usize>> {
1226 let mut all_assignments = Vec::new();
1227
1228 for node_arc in &self.nodes {
1229 let node = node_arc.read().await;
1230 if let Some(ref local_data) = node.local_data {
1231 for point in local_data.outer_iter() {
1232 let mut best_cluster = 0;
1233 let mut best_distance = f64::INFINITY;
1234
1235 for (j, centroid) in centroids.outer_iter().enumerate() {
1236 let distance: f64 = point
1237 .iter()
1238 .zip(centroid.iter())
1239 .map(|(&a, &b)| (a - b).powi(2))
1240 .sum::<f64>()
1241 .sqrt();
1242
1243 if distance < best_distance {
1244 best_distance = distance;
1245 best_cluster = j;
1246 }
1247 }
1248
1249 all_assignments.push(best_cluster);
1250 }
1251 }
1252 }
1253
1254 Ok(Array1::from(all_assignments))
1255 }
1256
1257 pub async fn distributed_knn_search(
1259 &self,
1260 query_point: &ArrayView1<'_, f64>,
1261 k: usize,
1262 ) -> SpatialResult<Vec<(usize, f64)>> {
1263 let mut all_neighbors = Vec::new();
1264
1265 for node_arc in &self.nodes {
1267 let node = node_arc.read().await;
1268 if let Some(ref local_index) = node.local_index {
1269 if let Some(ref kdtree) = local_index.local_index.kdtree {
1270 if local_index.local_index.bounds.contains(query_point) {
1272 let (indices, distances) =
1273 kdtree.query(query_point.as_slice().unwrap(), k)?;
1274
1275 for (idx, dist) in indices.iter().zip(distances.iter()) {
1276 all_neighbors.push((*idx, *dist));
1277 }
1278 }
1279 }
1280 }
1281 }
1282
1283 all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
1285 all_neighbors.truncate(k);
1286
1287 Ok(all_neighbors)
1288 }
1289
1290 pub async fn get_cluster_statistics(&self) -> SpatialResult<ClusterStatistics> {
1292 let state = self.cluster_state.read().await;
1293 let _load_balancer = self.load_balancer.read().await;
1294 let communication = self.communication.read().await;
1295
1296 let active_node_count = state.active_nodes.len();
1297 let total_partitions = state.total_partitions;
1298 let avg_partitions_per_node = if active_node_count > 0 {
1299 total_partitions as f64 / active_node_count as f64
1300 } else {
1301 0.0
1302 };
1303
1304 Ok(ClusterStatistics {
1305 active_nodes: active_node_count,
1306 total_data_points: state.total_data_points,
1307 total_partitions,
1308 avg_partitions_per_node,
1309 health_score: state.health_score,
1310 load_balance_score: state.performance_metrics.load_balance_score,
1311 avg_query_latency_ms: state.performance_metrics.avg_query_latency_ms,
1312 throughput_qps: state.performance_metrics.throughput_qps,
1313 total_messages_sent: communication.stats.messages_sent,
1314 total_bytes_sent: communication.stats.bytes_sent,
1315 avg_communication_latency_ms: communication.stats.average_latency_ms,
1316 })
1317 }
1318}
1319
/// Point-in-time summary returned by `DistributedSpatialCluster::get_cluster_statistics`.
#[derive(Debug, Clone)]
pub struct ClusterStatistics {
    /// Number of entries in the cluster's active-node list.
    pub active_nodes: usize,
    /// Points distributed by the most recent `distribute_data` call.
    pub total_data_points: usize,
    /// Partitions created by the most recent `distribute_data` call.
    pub total_partitions: usize,
    /// `total_partitions / active_nodes` (0.0 when there are no active nodes).
    pub avg_partitions_per_node: f64,
    pub health_score: f64,
    pub load_balance_score: f64,
    /// Average query latency in milliseconds.
    pub avg_query_latency_ms: f64,
    /// Query throughput in queries per second.
    pub throughput_qps: f64,
    /// Messages sent through the communication layer.
    pub total_messages_sent: u64,
    /// Bytes sent through the communication layer.
    pub total_bytes_sent: u64,
    /// Average communication latency in milliseconds.
    pub avg_communication_latency_ms: f64,
}
1335
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Builder calls land in the config; enabling fault tolerance bumps replication to 2.
    #[test]
    fn test_nodeconfig() {
        let cfg = NodeConfig::new()
            .with_node_count(4)
            .with_fault_tolerance(true)
            .with_load_balancing(true);

        assert_eq!(cfg.node_count, 4);
        assert_eq!(cfg.replication_factor, 2);
        assert!(cfg.fault_tolerance);
        assert!(cfg.load_balancing);
    }

    /// Unit-square containment on both sides of the boundary, plus its volume.
    #[test]
    fn test_spatial_bounds() {
        let unit_square = SpatialBounds {
            min_coords: array![0.0, 0.0],
            max_coords: array![1.0, 1.0],
        };
        let inside = array![0.5, 0.5];
        let outside = array![1.5, 0.5];

        assert!(unit_square.contains(&inside.view()));
        assert!(!unit_square.contains(&outside.view()));
        assert_eq!(unit_square.volume(), 1.0);
    }

    /// A moderate load snapshot scores strictly between 0 and 1.
    #[test]
    fn test_load_metrics() {
        let snapshot = LoadMetrics {
            cpu_utilization: 0.5,
            memory_utilization: 0.3,
            network_utilization: 0.2,
            partition_count: 2,
            operation_count: 100,
            last_update: Instant::now(),
        };

        let score = snapshot.load_score();
        assert!(0.0 < score && score < 1.0);
    }

    /// A two-node cluster is created with node 0 as master.
    #[tokio::test]
    async fn test_distributed_cluster_creation() {
        let cfg = NodeConfig::new()
            .with_node_count(2)
            .with_fault_tolerance(false);

        let cluster =
            DistributedSpatialCluster::new(cfg).expect("a two-node cluster should be constructible");

        assert_eq!(cluster.nodes.len(), 2);
        assert_eq!(cluster.master_node_id, 0);
    }

    /// Distributing four points is reflected in the cluster statistics.
    #[tokio::test]
    async fn test_data_distribution() {
        let cfg = NodeConfig::new()
            .with_node_count(2)
            .with_fault_tolerance(false);
        let mut cluster = DistributedSpatialCluster::new(cfg).unwrap();
        let corners = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];

        cluster
            .distribute_data(&corners.view())
            .await
            .expect("distribution should succeed");

        let stats = cluster.get_cluster_statistics().await.unwrap();
        assert_eq!(stats.total_data_points, 4);
        assert!(stats.total_partitions > 0);
    }

    /// k-means on two well-separated groups yields two centroids and one label per point.
    #[tokio::test]
    async fn test_distributed_kmeans() {
        let mut cluster =
            DistributedSpatialCluster::new(NodeConfig::new().with_node_count(2)).unwrap();

        let points = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [10.0, 10.0],
            [11.0, 10.0]
        ];
        cluster.distribute_data(&points.view()).await.unwrap();

        let (centroids, assignments) = cluster
            .distributed_kmeans(2, 10)
            .await
            .expect("k-means should succeed");

        assert_eq!(centroids.nrows(), 2);
        assert_eq!(assignments.len(), 6);
    }
}