flow_rs_core/spatial/
index.rs

1//! Core spatial indexing implementation
2
3use std::collections::HashMap;
4
5use crate::error::Result;
6use crate::graph::Node;
7use crate::types::{NodeId, Position, Rect, Viewport};
8
9use super::grid::GridCell;
10
11/// Spatial index for efficient viewport and proximity queries
12pub struct SpatialIndex {
13    entries: HashMap<NodeId, SpatialEntry>,
14    cell_size: f64,
15    grid: HashMap<GridCell, Vec<NodeId>>,
16}
17
18/// Entry in the spatial index
19#[derive(Debug, Clone)]
20struct SpatialEntry {
21    #[allow(dead_code)]
22    node_id: NodeId,
23    bounds: Rect,
24    grid_cells: Vec<GridCell>,
25}
26
27impl Default for SpatialIndex {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl SpatialIndex {
34    /// Create a new spatial index with default cell size
35    pub fn new() -> Self {
36        Self::with_cell_size(100.0) // Default 100x100 pixel cells
37    }
38
39    /// Create a new spatial index with specified cell size
40    pub fn with_cell_size(cell_size: f64) -> Self {
41        Self {
42            entries: HashMap::new(),
43            cell_size,
44            grid: HashMap::new(),
45        }
46    }
47
48    /// Insert a node into the spatial index
49    pub fn insert<T>(&mut self, node: &Node<T>) -> Result<()>
50    where
51        T: Clone,
52    {
53        // Remove existing entry if present
54        self.remove(&node.id);
55
56        let bounds = node.bounds();
57        let grid_cells = self.get_grid_cells_for_bounds(&bounds);
58
59        let entry = SpatialEntry {
60            node_id: node.id.clone(),
61            bounds,
62            grid_cells: grid_cells.clone(),
63        };
64
65        // Add to grid cells
66        for cell in &grid_cells {
67            self.grid.entry(*cell).or_default().push(node.id.clone());
68        }
69
70        // Store entry
71        self.entries.insert(node.id.clone(), entry);
72
73        Ok(())
74    }
75
76    /// Remove a node from the spatial index
77    pub fn remove(&mut self, node_id: &NodeId) -> bool {
78        if let Some(entry) = self.entries.remove(node_id) {
79            // Remove from all grid cells
80            for cell in &entry.grid_cells {
81                if let Some(cell_nodes) = self.grid.get_mut(cell) {
82                    cell_nodes.retain(|id| id != node_id);
83                    if cell_nodes.is_empty() {
84                        self.grid.remove(cell);
85                    }
86                }
87            }
88            true
89        } else {
90            false
91        }
92    }
93
94    /// Update a node's position in the spatial index
95    pub fn update<T>(&mut self, node: &Node<T>) -> Result<()>
96    where
97        T: Clone,
98    {
99        self.remove(&node.id);
100        self.insert(node)
101    }
102
103    /// Clear all entries from the spatial index
104    pub fn clear(&mut self) {
105        self.entries.clear();
106        self.grid.clear();
107    }
108
109    /// Query nodes within a rectangular area
110    pub fn query_rect(&self, bounds: &Rect) -> Vec<NodeId> {
111        let mut results = Vec::new();
112        let mut seen = std::collections::HashSet::new();
113
114        let grid_cells = self.get_grid_cells_for_bounds(bounds);
115
116        for cell in grid_cells {
117            if let Some(cell_nodes) = self.grid.get(&cell) {
118                for node_id in cell_nodes {
119                    if !seen.contains(node_id) {
120                        if let Some(entry) = self.entries.get(node_id) {
121                            if bounds.intersects(&entry.bounds) {
122                                results.push(node_id.clone());
123                                seen.insert(node_id.clone());
124                            }
125                        }
126                    }
127                }
128            }
129        }
130
131        results
132    }
133
134    /// Query nodes within a viewport
135    pub fn query_viewport(&self, viewport: &Viewport) -> Vec<NodeId> {
136        let bounds = viewport.bounds();
137        self.query_rect(&bounds)
138    }
139
140    /// Query nodes within a circular area
141    pub fn query_radius(&self, center: Position, radius: f64) -> Vec<NodeId> {
142        let mut results = Vec::new();
143        let mut seen = std::collections::HashSet::new();
144
145        // Create a bounding box for the circular area
146        let bounds = Rect::new(
147            center.x - radius,
148            center.y - radius,
149            radius * 2.0,
150            radius * 2.0,
151        );
152
153        let grid_cells = self.get_grid_cells_for_bounds(&bounds);
154
155        for cell in grid_cells {
156            if let Some(cell_nodes) = self.grid.get(&cell) {
157                for node_id in cell_nodes {
158                    if !seen.contains(node_id) {
159                        if let Some(entry) = self.entries.get(node_id) {
160                            // Check if the node's center is within the radius
161                            let node_center = entry.bounds.center();
162                            if center.distance_to(node_center) <= radius {
163                                results.push(node_id.clone());
164                                seen.insert(node_id.clone());
165                            }
166                        }
167                    }
168                }
169            }
170        }
171
172        results
173    }
174
175    /// Find the nearest node to a given point
176    pub fn nearest(&self, point: Position) -> Option<NodeId> {
177        let mut nearest_id = None;
178        let mut nearest_distance = f64::INFINITY;
179
180        for (node_id, entry) in &self.entries {
181            let node_center = entry.bounds.center();
182            let distance = point.distance_to(node_center);
183
184            if distance < nearest_distance {
185                nearest_distance = distance;
186                nearest_id = Some(node_id.clone());
187            }
188        }
189
190        nearest_id
191    }
192
193    /// Get all node IDs in the index
194    pub fn node_ids(&self) -> Vec<NodeId> {
195        self.entries.keys().cloned().collect()
196    }
197
198    /// Get the number of nodes in the index
199    pub fn len(&self) -> usize {
200        self.entries.len()
201    }
202
203    /// Check if the index is empty
204    pub fn is_empty(&self) -> bool {
205        self.entries.is_empty()
206    }
207
208    /// Bulk load multiple nodes into the index
209    pub fn bulk_load<T>(&mut self, nodes: &[Node<T>]) -> Result<()>
210    where
211        T: Clone,
212    {
213        for node in nodes {
214            self.insert(node)?;
215        }
216        Ok(())
217    }
218
219    /// Get the bounding rectangle of all nodes in the index
220    pub fn bounds(&self) -> Option<Rect> {
221        if self.entries.is_empty() {
222            return None;
223        }
224
225        let mut min_x = f64::INFINITY;
226        let mut min_y = f64::INFINITY;
227        let mut max_x = f64::NEG_INFINITY;
228        let mut max_y = f64::NEG_INFINITY;
229
230        for entry in self.entries.values() {
231            min_x = min_x.min(entry.bounds.x);
232            min_y = min_y.min(entry.bounds.y);
233            max_x = max_x.max(entry.bounds.x + entry.bounds.width);
234            max_y = max_y.max(entry.bounds.y + entry.bounds.height);
235        }
236
237        Some(Rect::new(min_x, min_y, max_x - min_x, max_y - min_y))
238    }
239
240    /// Get the cell size of the spatial index
241    pub fn cell_size(&self) -> f64 {
242        self.cell_size
243    }
244
245    /// Get grid cells that a bounds rectangle intersects
246    pub fn get_grid_cells_for_bounds(&self, bounds: &Rect) -> Vec<GridCell> {
247        // Handle invalid bounds
248        if !bounds.is_valid() || bounds.width < 0.0 || bounds.height < 0.0 {
249            return Vec::new();
250        }
251
252        // Handle NaN or infinity values
253        if bounds.x.is_nan() || bounds.y.is_nan() || bounds.width.is_nan() || bounds.height.is_nan() ||
254           bounds.x.is_infinite() || bounds.y.is_infinite() || bounds.width.is_infinite() || bounds.height.is_infinite() {
255            return Vec::new();
256        }
257
258        let min_cell_x = (bounds.x / self.cell_size).floor() as i32;
259        let min_cell_y = (bounds.y / self.cell_size).floor() as i32;
260        let max_cell_x = ((bounds.x + bounds.width) / self.cell_size).floor() as i32;
261        let max_cell_y = ((bounds.y + bounds.height) / self.cell_size).floor() as i32;
262
263        let mut cells = Vec::new();
264        for x in min_cell_x..=max_cell_x {
265            for y in min_cell_y..=max_cell_y {
266                cells.push(GridCell::new(x, y));
267            }
268        }
269        cells
270    }
271}