Skip to main content

graph_core/
viewport.rs

1//! R-tree viewport spatial index.
2//!
3//! Wraps `rstar::RTree` with an incremental-update policy so the
4//! caller can stream position changes from the force simulation
5//! into the tree without paying full O(n log n) bulk-rebuild cost
6//! on every tick. Queries return ids filtered to a viewport
7//! axis-aligned bounding box, optionally truncated and ordered by
8//! score.
9//!
10//! ## Update policy
11//!
//! `update_positions` writes each change into a per-id position
//! cache and adds the batch size to a dirty counter (the changes
//! themselves are not buffered). Once the dirty count
//! exceeds `max(1000, n / 20)` the tree is rebuilt in bulk.
//! Rebuild is O(n log n); per-update cost is amortised O(1)
//! across a burst of updates. The threshold was picked to keep
//! rebuild-induced pauses rare even on 100k-node graphs
//! (5000 updates before rebuild = ~5% of node count) while still
//! keeping queries close to up-to-date: the default
//! force-simulation tick cadence moves ≤ 100 nodes per frame,
//! so a 5000-update budget holds ≈ 50 frames of drift.
22
23use glam::Vec2;
24use rstar::primitives::GeomWithData;
25use rstar::{RTree, AABB};
26#[cfg(feature = "serde_support")]
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29
/// One point tracked by the index.
///
/// The coordinates are kept as a plain `[f32; 2]` so the value can be
/// handed to rstar as-is; `glam::Vec2` only appears at the public API
/// boundary, where it is converted to and from this array form.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct IndexPoint {
    /// Identifier returned by queries for this point.
    pub id: u32,
    /// `[x, y]` coordinates of the point.
    pub position: [f32; 2],
    /// Ranking value used when a query asks for score ordering.
    pub score: f32,
}
40
/// Ordering applied to the results of [`ViewportIndex::query`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScoreKey {
    /// No sorting: results come back in whatever order the tree
    /// traversal produced them. This is the cheapest option and the
    /// order must not be relied upon.
    Natural,
    /// Sorted by score, largest first.
    Desc,
    /// Sorted by score, smallest first.
    Asc,
}
52
53// `GeomWithData<[f32; 2], (u32, f32)>`: payload is (id, score).
54type Entry = GeomWithData<[f32; 2], (u32, f32)>;
55
56/// R-tree spatial index with a lazy-rebuild incremental-update
57/// policy. Construct with [`ViewportIndex::new`] and populate with
58/// [`ViewportIndex::rebuild`].
59pub struct ViewportIndex {
60    tree: RTree<Entry>,
61    // Current position and score for every id known to the index.
62    // Used to (a) translate incremental updates into the final
63    // `IndexPoint` set that a rebuild needs, and (b) let callers
64    // update the same id multiple times before a rebuild hits.
65    state: HashMap<u32, IndexPoint>,
66    // Pending incremental position updates. Position-only — score
67    // doesn't drift between ticks so it comes in only via `rebuild`.
68    dirty_count: usize,
69}
70
71impl ViewportIndex {
72    pub fn new() -> Self {
73        Self {
74            tree: RTree::new(),
75            state: HashMap::new(),
76            dirty_count: 0,
77        }
78    }
79
80    /// Number of points currently indexed (or pending in the dirty
81    /// buffer — the state map is the source of truth).
82    #[inline]
83    pub fn len(&self) -> usize {
84        self.state.len()
85    }
86
87    #[inline]
88    pub fn is_empty(&self) -> bool {
89        self.state.is_empty()
90    }
91
92    /// Bulk-load the index from `points`. O(n log n) via rstar's
93    /// STR packing. Clears any pending updates.
94    pub fn rebuild(&mut self, points: &[IndexPoint]) {
95        self.state.clear();
96        self.state.reserve(points.len());
97        let entries: Vec<Entry> = points
98            .iter()
99            .map(|p| {
100                self.state.insert(p.id, *p);
101                GeomWithData::new(p.position, (p.id, p.score))
102            })
103            .collect();
104        self.tree = RTree::bulk_load(entries);
105        self.dirty_count = 0;
106    }
107
108    /// Apply a batch of `(id, new_position)` updates. Updates are
109    /// stored in an internal position cache immediately; the
110    /// R-tree itself is rebuilt only when the dirty count crosses
111    /// `max(1000, n / 20)`. Amortised O(1) per update.
112    pub fn update_positions(&mut self, changes: &[(u32, Vec2)]) {
113        if changes.is_empty() {
114            return;
115        }
116        for &(id, pos) in changes {
117            let entry = self.state.entry(id).or_insert(IndexPoint {
118                id,
119                position: [pos.x, pos.y],
120                score: 0.0,
121            });
122            entry.position = [pos.x, pos.y];
123        }
124        self.dirty_count += changes.len();
125
126        let threshold = 1000usize.max(self.state.len() / 20);
127        if self.dirty_count > threshold {
128            // Rebuild from the state map. Materialise once; rstar's
129            // bulk_load consumes the vec.
130            let entries: Vec<Entry> = self
131                .state
132                .values()
133                .map(|p| GeomWithData::new(p.position, (p.id, p.score)))
134                .collect();
135            self.tree = RTree::bulk_load(entries);
136            self.dirty_count = 0;
137        }
138    }
139
140    /// Return the ids whose current position falls inside the
141    /// axis-aligned box `[min, max]`, optionally truncated and
142    /// ordered.
143    ///
144    /// The tree may be slightly out-of-date relative to the state
145    /// map between rebuilds — callers who need point-in-time
146    /// accuracy should call `rebuild` before querying. For the
147    /// viewport use case a handful of frames of drift is fine
148    /// (the user's viewport is typically moving faster than the
149    /// simulation is drifting).
150    pub fn query(&self, min: Vec2, max: Vec2, limit: usize, order: ScoreKey) -> Vec<u32> {
151        if self.tree.size() == 0 {
152            return Vec::new();
153        }
154
155        let aabb = AABB::from_corners([min.x, min.y], [max.x, max.y]);
156        // Collect (id, score) inside the envelope. `score` only
157        // matters when `order != Natural`.
158        let mut hits: Vec<(u32, f32)> = self
159            .tree
160            .locate_in_envelope_intersecting(&aabb)
161            .map(|e| e.data)
162            .collect();
163
164        match order {
165            ScoreKey::Natural => {}
166            ScoreKey::Desc => {
167                hits.sort_by(|a, b| {
168                    b.1.partial_cmp(&a.1)
169                        .unwrap_or(std::cmp::Ordering::Equal)
170                        .then_with(|| a.0.cmp(&b.0))
171                });
172            }
173            ScoreKey::Asc => {
174                hits.sort_by(|a, b| {
175                    a.1.partial_cmp(&b.1)
176                        .unwrap_or(std::cmp::Ordering::Equal)
177                        .then_with(|| a.0.cmp(&b.0))
178                });
179            }
180        }
181
182        hits.truncate(limit);
183        hits.into_iter().map(|(id, _)| id).collect()
184    }
185
186    /// Return the `(id, score)` pairs inside the axis-aligned box,
187    /// optionally truncated and ordered.
188    ///
189    /// Added in Phase 223 Wave 3. Same semantics as [`Self::query`]
190    /// but exposes the associated score so the PyO3 surface can
191    /// return both without a second lookup. The id-only [`Self::query`]
192    /// stays as a convenience for callers (including the WASM crate)
193    /// that already ignore the score.
194    pub fn query_with_scores(
195        &self,
196        min: Vec2,
197        max: Vec2,
198        limit: usize,
199        order: ScoreKey,
200    ) -> Vec<(u32, f32)> {
201        if self.tree.size() == 0 {
202            return Vec::new();
203        }
204
205        let aabb = AABB::from_corners([min.x, min.y], [max.x, max.y]);
206        let mut hits: Vec<(u32, f32)> = self
207            .tree
208            .locate_in_envelope_intersecting(&aabb)
209            .map(|e| e.data)
210            .collect();
211
212        match order {
213            ScoreKey::Natural => {}
214            ScoreKey::Desc => {
215                hits.sort_by(|a, b| {
216                    b.1.partial_cmp(&a.1)
217                        .unwrap_or(std::cmp::Ordering::Equal)
218                        .then_with(|| a.0.cmp(&b.0))
219                });
220            }
221            ScoreKey::Asc => {
222                hits.sort_by(|a, b| {
223                    a.1.partial_cmp(&b.1)
224                        .unwrap_or(std::cmp::Ordering::Equal)
225                        .then_with(|| a.0.cmp(&b.0))
226                });
227            }
228        }
229
230        hits.truncate(limit);
231        hits
232    }
233
234    /// Snapshot the indexed points — a clone of the internal state
235    /// map's values, in unspecified order.
236    ///
237    /// Phase 223 Wave 3: used by the PyO3 surface to hand the point
238    /// set to bincode for Redis snapshotting. The tree itself isn't
239    /// serialised; callers rebuild it in O(n log n) from the returned
240    /// slice via [`Self::rebuild`].
241    pub fn snapshot(&self) -> Vec<IndexPoint> {
242        self.state.values().copied().collect()
243    }
244}
245
246impl Default for ViewportIndex {
247    fn default() -> Self {
248        Self::new()
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// 5x5 lattice at integer coordinates 0..5 on both axes, with
    /// id = x*10 + y and score = x + y.
    fn grid_points() -> Vec<IndexPoint> {
        (0..5)
            .flat_map(|x| (0..5).map(move |y| (x, y)))
            .map(|(x, y)| IndexPoint {
                id: (x * 10 + y) as u32,
                position: [x as f32, y as f32],
                score: (x + y) as f32,
            })
            .collect()
    }

    #[test]
    fn known_rect_returns_expected_set() {
        let mut index = ViewportIndex::new();
        index.rebuild(&grid_points());

        // The 3x3 corner: x in 0..=2 and y in 0..=2.
        let mut want: HashSet<u32> = HashSet::new();
        for x in 0..3u32 {
            for y in 0..3u32 {
                want.insert(x * 10 + y);
            }
        }

        let ids = index.query(
            Vec2::new(0.0, 0.0),
            Vec2::new(2.0, 2.0),
            100,
            ScoreKey::Natural,
        );
        let got: HashSet<u32> = ids.into_iter().collect();
        assert_eq!(got, want);
    }

    #[test]
    fn empty_rect_returns_empty() {
        let mut index = ViewportIndex::new();
        index.rebuild(&grid_points());

        // A box far away from every grid point.
        let far_min = Vec2::new(100.0, 100.0);
        let far_max = Vec2::new(200.0, 200.0);
        let got = index.query(far_min, far_max, 100, ScoreKey::Natural);
        assert!(got.is_empty());
    }

    #[test]
    fn limit_and_order_respected() {
        let mut points: Vec<IndexPoint> = Vec::with_capacity(20);
        for i in 0..20u32 {
            points.push(IndexPoint {
                id: i,
                position: [i as f32 % 10.0, (i / 10) as f32],
                score: (i as f32) * 0.5,
            });
        }
        let mut index = ViewportIndex::new();
        index.rebuild(&points);

        // The box covers every point, so only ordering and the limit
        // decide the result.
        let everything_min = Vec2::new(-1.0, -1.0);
        let everything_max = Vec2::new(10.0, 10.0);
        let got = index.query(everything_min, everything_max, 5, ScoreKey::Desc);

        // Scores grow with the id, so the five best are 19 down to 15.
        assert_eq!(got, vec![19, 18, 17, 16, 15]);
    }

    #[test]
    fn incremental_update_matches_rebuild() {
        let mut base_points: Vec<IndexPoint> = Vec::with_capacity(100);
        for i in 0..100u32 {
            base_points.push(IndexPoint {
                id: i,
                position: [(i as f32) * 0.5, (i as f32) * 0.3],
                score: i as f32,
            });
        }

        let mut incremental = ViewportIndex::new();
        incremental.rebuild(&base_points);

        // Move the first 30 ids. With 100 points the threshold is
        // max(1000, 100 / 20 = 5) = 1000, so this batch alone stays
        // below it and does not refresh the tree.
        let updates: Vec<(u32, Vec2)> = (0..30u32)
            .map(|i| (i, Vec2::new(50.0 + i as f32, 50.0 + i as f32)))
            .collect();
        for &(id, pos) in &updates {
            base_points[id as usize].position = [pos.x, pos.y];
        }
        incremental.update_positions(&updates);

        // Reference index: bulk-loaded directly from the moved points.
        let mut rebuilt = ViewportIndex::new();
        rebuilt.rebuild(&base_points);

        // Re-apply the same 30 updates 35 more times. The positions in
        // the state map do not change, but the dirty count passes 1000,
        // which makes the tree pick up the moves.
        let more_updates = updates.repeat(35);
        incremental.update_positions(&more_updates);

        // Compare id sets over a box that contains every final position.
        let bbox_min = Vec2::new(-10.0, -10.0);
        let bbox_max = Vec2::new(100.0, 100.0);
        let ids_in_bbox = |index: &ViewportIndex| -> HashSet<u32> {
            index
                .query(bbox_min, bbox_max, 1000, ScoreKey::Natural)
                .into_iter()
                .collect()
        };
        let a = ids_in_bbox(&incremental);
        let b = ids_in_bbox(&rebuilt);
        assert_eq!(a, b);
    }

    #[test]
    fn incremental_update_triggers_rebuild() {
        let mut points: Vec<IndexPoint> = Vec::with_capacity(100);
        for i in 0..100u32 {
            points.push(IndexPoint {
                id: i,
                position: [i as f32, i as f32],
                score: i as f32,
            });
        }
        let mut idx = ViewportIndex::new();
        idx.rebuild(&points);

        // 5001 updates against a threshold of max(1000, 100 / 20 = 5)
        // = 1000, so the tree must be refreshed. Every sweep moves all
        // 100 ids to (1000 + step, 1000 + step); the sequence is cut
        // off at exactly 5001 entries.
        let updates: Vec<(u32, Vec2)> = (0..51u32)
            .flat_map(|step| {
                (0..100u32)
                    .map(move |id| (id, Vec2::new(1000.0 + step as f32, 1000.0 + step as f32)))
            })
            .take(5001)
            .collect();
        assert_eq!(updates.len(), 5001);
        idx.update_positions(&updates);

        // The last update per id wins, so every id ends up close to
        // (1050, 1050) and falls inside this box.
        let near_min = Vec2::new(900.0, 900.0);
        let near_max = Vec2::new(1100.0, 1100.0);
        let hits = idx.query(near_min, near_max, 200, ScoreKey::Asc);
        assert_eq!(hits.len(), 100, "expected all 100 ids inside viewport");
        // Ascending score order equals ascending id order here.
        assert_eq!(hits[0], 0);
        assert_eq!(hits[99], 99);

        // Nothing may remain at the original positions.
        let old_min = Vec2::new(-1.0, -1.0);
        let old_max = Vec2::new(100.0, 100.0);
        let old_hits = idx.query(old_min, old_max, 200, ScoreKey::Natural);
        assert!(
            old_hits.is_empty(),
            "stale positions leaked: {old_hits:?}"
        );
    }
}
433}