use glam::Vec2;
use rstar::primitives::GeomWithData;
use rstar::{RTree, AABB};
#[cfg(feature = "serde_support")]
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A point stored in [`ViewportIndex`]: an identifier, a 2-D position, and a score used to
/// order query results.
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct IndexPoint {
/// Identifier. `ViewportIndex` keys its records by this value, so a later point with the
/// same id replaces an earlier one.
pub id: u32,
/// `[x, y]` position used for rectangle queries.
pub position: [f32; 2],
/// Ranking value used by `ScoreKey::Asc` / `ScoreKey::Desc` result ordering.
pub score: f32,
}
/// Result ordering for [`ViewportIndex::query`] and [`ViewportIndex::query_with_scores`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScoreKey {
/// No sorting: hits come back in the R-tree's traversal order.
Natural,
/// Highest score first; equal scores are ordered by ascending id.
Desc,
/// Lowest score first; equal scores are ordered by ascending id.
Asc,
}
/// R-tree entry: a 2-D point whose payload is the point's `(id, score)`.
type Entry = GeomWithData<[f32; 2], (u32, f32)>;
/// Spatial index answering "which points lie inside this rectangle?" queries, with an
/// optional score ordering and a result limit.
pub struct ViewportIndex {
/// R-tree over point positions; this is what `query` / `query_with_scores` search.
tree: RTree<Entry>,
/// Per-id point records (id, position, score), keyed by `IndexPoint::id`.
state: HashMap<u32, IndexPoint>,
/// Number of position changes passed to `update_positions` since the tree was last
/// bulk-loaded; reset to 0 whenever the tree is bulk-loaded.
dirty_count: usize,
}
impl ViewportIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self {
            tree: RTree::new(),
            state: HashMap::new(),
            dirty_count: 0,
        }
    }

    /// Number of distinct point ids currently stored.
    #[inline]
    pub fn len(&self) -> usize {
        self.state.len()
    }

    /// Returns `true` when no points are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.state.is_empty()
    }

    /// Replaces the entire contents of the index with `points` and bulk-loads a fresh tree.
    ///
    /// If several points share an id, the last one in `points` wins. The tree is built from
    /// the deduplicated set, so a query never reports the same id twice.
    pub fn rebuild(&mut self, points: &[IndexPoint]) {
        self.state.clear();
        self.state.reserve(points.len());
        self.state.extend(points.iter().map(|p| (p.id, *p)));
        self.reload_tree();
    }

    /// Moves points to new positions. Unknown ids are inserted with a score of `0.0`.
    ///
    /// Every change is applied to the tree immediately (the old entry is removed and the new
    /// one inserted), so subsequent queries see the new positions. Incremental edits slowly
    /// degrade the tree's packing, so once more than `max(1000, len / 20)` changes have
    /// accumulated since the last bulk load, the tree is bulk-loaded again from the current
    /// state.
    pub fn update_positions(&mut self, changes: &[(u32, Vec2)]) {
        if changes.is_empty() {
            return;
        }
        // `RTree::remove` locates entries with `==`, so an entry holding a NaN coordinate or
        // score can never be found. When that happens, only a full reload from `state` can
        // drop the stale entry.
        let mut needs_reload = false;
        for &(id, pos) in changes {
            let position = [pos.x, pos.y];
            if let Some(point) = self.state.get_mut(&id) {
                let stale = Self::entry_of(point);
                point.position = position;
                if self.tree.remove(&stale).is_none() {
                    needs_reload = true;
                }
                self.tree.insert(Self::entry_of(point));
            } else {
                let point = IndexPoint {
                    id,
                    position,
                    score: 0.0,
                };
                self.state.insert(id, point);
                self.tree.insert(Self::entry_of(&point));
            }
        }
        self.dirty_count += changes.len();
        let threshold = 1000usize.max(self.state.len() / 20);
        if needs_reload || self.dirty_count > threshold {
            self.reload_tree();
        }
    }

    /// Returns the ids of points inside the closed rectangle spanned by `min` and `max`.
    ///
    /// Corners may be given in any order. At most `limit` ids are returned, ordered per
    /// `order`. With `ScoreKey::Natural`, the result is the first `limit` hits in tree
    /// traversal order.
    pub fn query(&self, min: Vec2, max: Vec2, limit: usize, order: ScoreKey) -> Vec<u32> {
        self.query_with_scores(min, max, limit, order)
            .into_iter()
            .map(|(id, _)| id)
            .collect()
    }

    /// Like [`ViewportIndex::query`], but returns each hit as `(id, score)`.
    ///
    /// For `Asc` / `Desc`, NaN scores always sort after every numeric score, and equal
    /// scores are ordered by ascending id.
    pub fn query_with_scores(
        &self,
        min: Vec2,
        max: Vec2,
        limit: usize,
        order: ScoreKey,
    ) -> Vec<(u32, f32)> {
        if limit == 0 || self.tree.size() == 0 {
            return Vec::new();
        }
        let aabb = AABB::from_corners([min.x, min.y], [max.x, max.y]);
        let candidates = self
            .tree
            .locate_in_envelope_intersecting(&aabb)
            .map(|e| e.data);
        match order {
            // No ordering requested: stop as soon as `limit` hits have been found.
            ScoreKey::Natural => candidates.take(limit).collect(),
            ScoreKey::Desc | ScoreKey::Asc => {
                let descending = order == ScoreKey::Desc;
                let mut hits: Vec<(u32, f32)> = candidates.collect();
                hits.sort_unstable_by(|a, b| Self::compare_hits(*a, *b, descending));
                hits.truncate(limit);
                hits
            }
        }
    }

    /// Returns a copy of every stored point. The order is unspecified (hash-map order).
    pub fn snapshot(&self) -> Vec<IndexPoint> {
        self.state.values().copied().collect()
    }

    /// Builds the tree entry for `point`: its position, with `(id, score)` as the payload.
    #[inline]
    fn entry_of(point: &IndexPoint) -> Entry {
        GeomWithData::new(point.position, (point.id, point.score))
    }

    /// Bulk-loads a fresh tree from `state` and clears the pending-change counter.
    fn reload_tree(&mut self) {
        let entries: Vec<Entry> = self.state.values().map(Self::entry_of).collect();
        self.tree = RTree::bulk_load(entries);
        self.dirty_count = 0;
    }

    /// Total order on `(id, score)` hits: by score (descending if `descending`), with NaN
    /// scores always last, then by ascending id.
    ///
    /// The sort requires a total order. Treating NaN as "equal to everything" would not be
    /// one, and may make the sort panic.
    fn compare_hits(a: (u32, f32), b: (u32, f32), descending: bool) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        let by_score = match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => {
                // Neither value is NaN, so `partial_cmp` always succeeds here.
                let ord = a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal);
                if descending {
                    ord.reverse()
                } else {
                    ord
                }
            }
        };
        by_score.then_with(|| a.0.cmp(&b.0))
    }
}
impl Default for ViewportIndex {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
// Unit tests for `ViewportIndex`: rectangle queries, limit/ordering, and position updates.
use super::*;
use std::collections::HashSet;
// 5x5 grid of points at integer coordinates x, y in 0..5; id = x * 10 + y, score = x + y.
fn grid_points() -> Vec<IndexPoint> {
let mut out = Vec::new();
for x in 0..5 {
for y in 0..5 {
out.push(IndexPoint {
id: (x * 10 + y) as u32,
position: [x as f32, y as f32],
score: (x + y) as f32,
});
}
}
out
}
// The query rectangle [0,2]x[0,2] is closed: grid points on its edges must be included,
// giving the 3x3 block of ids x * 10 + y for x, y in 0..3.
#[test]
fn known_rect_returns_expected_set() {
let mut idx = ViewportIndex::new();
idx.rebuild(&grid_points());
let got: HashSet<u32> = idx
.query(Vec2::new(0.0, 0.0), Vec2::new(2.0, 2.0), 100, ScoreKey::Natural)
.into_iter()
.collect();
let want: HashSet<u32> = (0..3u32)
.flat_map(|x| (0..3u32).map(move |y| x * 10 + y))
.collect();
assert_eq!(got, want);
}
// A rectangle far away from every grid point yields no hits.
#[test]
fn empty_rect_returns_empty() {
let mut idx = ViewportIndex::new();
idx.rebuild(&grid_points());
let got = idx.query(
Vec2::new(100.0, 100.0),
Vec2::new(200.0, 200.0),
100,
ScoreKey::Natural,
);
assert!(got.is_empty());
}
// 20 points with score = id * 0.5, all inside the query box; `Desc` with limit 5 must
// return the five highest-scoring ids in descending order.
#[test]
fn limit_and_order_respected() {
let mut idx = ViewportIndex::new();
let points: Vec<IndexPoint> = (0..20u32)
.map(|i| IndexPoint {
id: i,
position: [i as f32 % 10.0, (i / 10) as f32],
score: (i as f32) * 0.5,
})
.collect();
idx.rebuild(&points);
let got = idx.query(
Vec2::new(-1.0, -1.0),
Vec2::new(10.0, 10.0),
5,
ScoreKey::Desc,
);
assert_eq!(got, vec![19, 18, 17, 16, 15]);
}
// Moves ids 0..30 to (50 + i, 50 + i), mirroring the moves into `base_points`, and
// compares the incrementally updated index against one rebuilt from `base_points`.
#[test]
fn incremental_update_matches_rebuild() {
let mut base_points: Vec<IndexPoint> = (0..100u32)
.map(|i| IndexPoint {
id: i,
position: [(i as f32) * 0.5, (i as f32) * 0.3],
score: i as f32,
})
.collect();
let mut incremental = ViewportIndex::new();
incremental.rebuild(&base_points);
let mut updates = Vec::new();
for i in 0..30u32 {
let new_pos = Vec2::new(50.0 + i as f32, 50.0 + i as f32);
updates.push((i, new_pos));
base_points[i as usize].position = [new_pos.x, new_pos.y];
}
incremental.update_positions(&updates);
let mut rebuilt = ViewportIndex::new();
rebuilt.rebuild(&base_points);
// Re-applies the same 30 moves 35 times (1050 changes). Positions do not change, but
// the total change count (30 + 1050 = 1080) exceeds the 1000 rebuild threshold.
// NOTE(review): the comparison below happens only after this second batch, so the
// test does not check query results between small batches of updates.
let mut more_updates = Vec::new();
for _ in 0..35 {
more_updates.extend_from_slice(&updates);
}
incremental.update_positions(&more_updates);
let bbox_min = Vec2::new(-10.0, -10.0);
let bbox_max = Vec2::new(100.0, 100.0);
let a: HashSet<u32> = incremental
.query(bbox_min, bbox_max, 1000, ScoreKey::Natural)
.into_iter()
.collect();
let b: HashSet<u32> = rebuilt
.query(bbox_min, bbox_max, 1000, ScoreKey::Natural)
.into_iter()
.collect();
assert_eq!(a, b);
}
// Pushes one batch of exactly 5001 moves (more than the max(1000, 100 / 20) = 1000
// threshold). Steps 0..50 move all 100 ids; step 50 moves only id 0 before the break.
// Every final position lies in [1000, 1050]^2, so all ids must be found in the new
// viewport and none at their old diagonal positions.
#[test]
fn incremental_update_triggers_rebuild() {
let points: Vec<IndexPoint> = (0..100u32)
.map(|i| IndexPoint {
id: i,
position: [i as f32, i as f32],
score: i as f32,
})
.collect();
let mut idx = ViewportIndex::new();
idx.rebuild(&points);
let mut updates = Vec::with_capacity(5001);
for step in 0..51u32 {
for id in 0..100u32 {
updates.push((id, Vec2::new(1000.0 + step as f32, 1000.0 + step as f32)));
if updates.len() == 5001 {
break;
}
}
if updates.len() == 5001 {
break;
}
}
assert_eq!(updates.len(), 5001);
idx.update_positions(&updates);
let hits = idx.query(
Vec2::new(900.0, 900.0),
Vec2::new(1100.0, 1100.0),
200,
ScoreKey::Asc,
);
assert_eq!(hits.len(), 100, "expected all 100 ids inside viewport");
// Scores equal ids, so ascending order puts id 0 first and id 99 last.
assert_eq!(hits[0], 0);
assert_eq!(hits[99], 99);
let old_hits = idx.query(
Vec2::new(-1.0, -1.0),
Vec2::new(100.0, 100.0),
200,
ScoreKey::Natural,
);
assert!(
old_hits.is_empty(),
"stale positions leaked: {old_hits:?}"
);
}
}