use std::collections::{HashMap, HashSet};
use std::io;
use std::time::{Duration, Instant};
use crate::geom::sphere::radec_to_xyz;
use crate::kdtree::{KdForest, KdTree};
use crate::quads::{Code, DIMCODES, DIMQUADS, Quad};
use crate::solver::SkyRegion;
use super::source::{HealpixCell, IndexFragment, IndexSource};
use super::{Index, IndexStar};
/// Payload of one loaded HEALPix cell, as returned by the index source.
struct LoadedCell {
    /// Catalog stars belonging to this cell.
    stars: Vec<IndexStar>,
    /// Quads of this cell; `LiveIndex::as_index` treats their `star_ids` as
    /// indices into `stars` of this same cell.
    quads: Vec<Quad>,
    /// Quad codes of this cell (presumably one per quad, parallel to `quads` —
    /// TODO confirm against the source format).
    codes: Vec<Code>,
    /// Time the cell was loaded. Nothing reads it yet, hence the `allow`;
    /// NOTE(review): presumably reserved for LRU-style eviction — confirm.
    #[allow(dead_code)]
    last_used: Instant,
}
/// Cell-granular view over an [`IndexSource`] that holds only the HEALPix cells
/// covering a requested sky region, loading and dropping them incrementally.
///
/// Each loaded cell contributes its own sub-tree to `star_forest` and
/// `code_forest`, so adding or dropping a cell never rebuilds a global tree.
pub struct LiveIndex<S: IndexSource> {
    /// Where cell payloads are loaded from.
    source: S,
    /// Payload of every currently loaded cell.
    loaded: HashMap<HealpixCell, LoadedCell>,
    /// Star positions (3-D points from `radec_to_xyz`), one sub-tree per cell.
    /// NOTE(review): sub-trees are tagged by `cell.id` only, not `(depth, id)`;
    /// this assumes every cell the source returns shares a single depth.
    star_forest: KdForest<3>,
    /// Quad codes, one sub-tree per cell, tagged the same way as `star_forest`.
    code_forest: KdForest<{ DIMCODES }>,
    /// Counter incremented whenever cells are added or removed.
    build_generation: u64,
}
/// Outcome of [`LiveIndex::ensure_region`] or [`LiveIndex::set_region`].
///
/// `Default` is the "nothing happened" report (zero counts, zero elapsed time);
/// `PartialEq`/`Eq` allow reports to be compared directly, e.g. in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EnsureReport {
    /// Number of cells added by the call.
    pub cells_added: usize,
    /// Total number of stars in the added cells.
    pub stars_added: usize,
    /// Wall-clock time spent inside the call.
    pub elapsed: Duration,
}
/// Outcome of [`LiveIndex::drop_outside`] or [`LiveIndex::drop_cells`].
///
/// `Default` is the "nothing dropped" report (zero counts, zero elapsed time);
/// `PartialEq`/`Eq` allow reports to be compared directly, e.g. in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DropReport {
    /// Number of loaded cells that were removed.
    pub cells_dropped: usize,
    /// Total number of stars in the removed cells.
    pub stars_dropped: usize,
    /// Wall-clock time spent inside the call.
    pub elapsed: Duration,
}
impl<S: IndexSource> LiveIndex<S> {
    /// Wraps `source` in an empty live index.
    ///
    /// Nothing is loaded until [`LiveIndex::ensure_region`] or
    /// [`LiveIndex::set_region`] is called; the build generation starts at 0.
    pub fn open(source: S) -> Self {
        Self {
            source,
            loaded: HashMap::new(),
            star_forest: KdForest::new(),
            code_forest: KdForest::new(),
            build_generation: 0,
        }
    }

    /// Returns the source that cells are loaded from.
    pub fn source(&self) -> &S {
        &self.source
    }

    /// Returns a counter that is incremented every time cells are added or removed.
    ///
    /// It never decreases, so two equal readings mean the loaded cell set did not
    /// change in between.
    pub fn build_generation(&self) -> u64 {
        self.build_generation
    }

    /// Iterates over the currently loaded cells, in unspecified (hash map) order.
    pub fn loaded_cells(&self) -> impl Iterator<Item = &HealpixCell> {
        self.loaded.keys()
    }

    /// Total number of stars across all loaded cells.
    pub fn loaded_star_count(&self) -> usize {
        self.loaded.values().map(|c| c.stars.len()).sum()
    }

    /// Total number of quads across all loaded cells.
    pub fn loaded_quad_count(&self) -> usize {
        self.loaded.values().map(|c| c.quads.len()).sum()
    }

    /// Number of loaded cells.
    pub fn loaded_cell_count(&self) -> usize {
        self.loaded.len()
    }

    /// Star k-d forest with one sub-tree per loaded cell.
    pub fn star_forest(&self) -> &KdForest<3> {
        &self.star_forest
    }

    /// Quad-code k-d forest with one sub-tree per loaded cell.
    pub fn code_forest(&self) -> &KdForest<{ DIMCODES }> {
        &self.code_forest
    }

    /// Loads every cell intersecting `region` that is not already loaded.
    /// Cells outside `region` are left untouched.
    ///
    /// # Errors
    /// Propagates the first error from [`IndexSource::load_cells`]. In that case
    /// nothing from this call is committed and the index is unchanged.
    pub fn ensure_region(&mut self, region: &SkyRegion) -> io::Result<EnsureReport> {
        let start = Instant::now();
        let to_add = self.missing_cells(self.source.cells_intersecting(region));
        // `add_cells` is a no-op (and leaves the generation alone) for an empty list.
        let (cells_added, stars_added) = self.add_cells(&to_add)?;
        Ok(EnsureReport {
            cells_added,
            stars_added,
            elapsed: start.elapsed(),
        })
    }

    /// Unloads every loaded cell that does not intersect `region`.
    pub fn drop_outside(&mut self, region: &SkyRegion) -> DropReport {
        let start = Instant::now();
        let keep: HashSet<HealpixCell> =
            self.source.cells_intersecting(region).into_iter().collect();
        let to_drop: Vec<HealpixCell> = self
            .loaded
            .keys()
            .filter(|c| !keep.contains(c))
            .copied()
            .collect();
        let stars_dropped: usize = to_drop
            .iter()
            .filter_map(|cell| self.remove_cell(cell))
            .sum();
        DropReport {
            cells_dropped: to_drop.len(),
            stars_dropped,
            elapsed: start.elapsed(),
        }
    }

    /// Unloads the named cells. Cells that are not loaded, and repeated entries,
    /// are ignored and not counted.
    pub fn drop_cells(&mut self, cells: &[HealpixCell]) -> DropReport {
        let start = Instant::now();
        let mut cells_dropped = 0;
        let mut stars_dropped = 0;
        for cell in cells {
            // Count by "was loaded", not by star count, so a loaded cell with
            // zero stars is still reported as dropped.
            if let Some(n_stars) = self.remove_cell(cell) {
                cells_dropped += 1;
                stars_dropped += n_stars;
            }
        }
        DropReport {
            cells_dropped,
            stars_dropped,
            elapsed: start.elapsed(),
        }
    }

    /// Makes the loaded set match `region`: missing intersecting cells are loaded,
    /// then every loaded cell outside `region` is dropped.
    ///
    /// The returned report describes only the additions.
    ///
    /// # Errors
    /// Propagates the first error from [`IndexSource::load_cells`]. Loading happens
    /// before any drop, so on error the previous membership is preserved intact.
    pub fn set_region(&mut self, region: &SkyRegion) -> io::Result<EnsureReport> {
        let start = Instant::now();
        let wanted: HashSet<HealpixCell> =
            self.source.cells_intersecting(region).into_iter().collect();
        let to_add = self.missing_cells(wanted.iter().copied());
        let to_drop: Vec<HealpixCell> = self
            .loaded
            .keys()
            .filter(|c| !wanted.contains(c))
            .copied()
            .collect();
        let (cells_added, stars_added) = self.add_cells(&to_add)?;
        for cell in &to_drop {
            self.remove_cell(cell);
        }
        Ok(EnsureReport {
            cells_added,
            stars_added,
            elapsed: start.elapsed(),
        })
    }

    /// Narrows `candidates` to cells that are not loaded yet. Duplicates are
    /// removed and first-seen order is kept.
    fn missing_cells(
        &self,
        candidates: impl IntoIterator<Item = HealpixCell>,
    ) -> Vec<HealpixCell> {
        let mut seen: HashSet<HealpixCell> = HashSet::new();
        candidates
            .into_iter()
            .filter(|c| !self.loaded.contains_key(c) && seen.insert(*c))
            .collect()
    }

    /// Loads `cells` from the source and commits them in a single step.
    ///
    /// Returns `(cells_added, stars_added)`. Cells whose fragment has neither
    /// stars nor quads are skipped and not counted. Every fragment is loaded and
    /// its trees are built before any state is touched, so an error from the
    /// source leaves the index exactly as it was.
    fn add_cells(&mut self, cells: &[HealpixCell]) -> io::Result<(usize, usize)> {
        let mut staged: Vec<(HealpixCell, LoadedCell, KdTree<3>, KdTree<{ DIMCODES }>)> =
            Vec::with_capacity(cells.len());
        for &cell in cells {
            // One cell per request, so each fragment maps to exactly one forest tag.
            let frag: IndexFragment = self.source.load_cells(std::slice::from_ref(&cell))?;
            if frag.stars.is_empty() && frag.quads.is_empty() {
                // NOTE(review): empty cells are not recorded as loaded, so they
                // are fetched again by every later call that covers them.
                continue;
            }
            let star_points: Vec<[f64; 3]> = frag
                .stars
                .iter()
                .map(|s| radec_to_xyz(s.ra, s.dec))
                .collect();
            let star_indices: Vec<usize> = (0..frag.stars.len()).collect();
            let star_tree = KdTree::<3>::build(star_points, star_indices);
            let code_indices: Vec<usize> = (0..frag.codes.len()).collect();
            let code_tree = KdTree::<{ DIMCODES }>::build(frag.codes.clone(), code_indices);
            staged.push((
                cell,
                LoadedCell {
                    stars: frag.stars,
                    quads: frag.quads,
                    codes: frag.codes,
                    last_used: Instant::now(),
                },
                star_tree,
                code_tree,
            ));
        }

        let cells_added = staged.len();
        let mut stars_added = 0;
        for (cell, payload, star_tree, code_tree) in staged {
            stars_added += payload.stars.len();
            self.star_forest.insert(cell.id, star_tree);
            self.code_forest.insert(cell.id, code_tree);
            self.loaded.insert(cell, payload);
        }
        // Bump whenever the loaded set changed, even if the new cells hold no stars.
        if cells_added > 0 {
            self.build_generation += 1;
        }
        Ok((cells_added, stars_added))
    }

    /// Unloads `cell`. Returns its star count, or `None` if it was not loaded
    /// (in which case nothing changes).
    fn remove_cell(&mut self, cell: &HealpixCell) -> Option<usize> {
        let removed = self.loaded.remove(cell)?;
        self.star_forest.remove(cell.id);
        self.code_forest.remove(cell.id);
        self.build_generation += 1;
        Some(removed.stars.len())
    }

    /// Flattens every loaded cell into a standalone [`Index`] with freshly built
    /// k-d trees.
    ///
    /// Cells are concatenated in `(depth, id)` order, so the result does not depend
    /// on load history. Each cell's quad star ids are shifted by that cell's
    /// offset into the combined star list.
    pub fn as_index(&self) -> Index {
        // Deterministic cell order, independent of HashMap iteration order.
        let mut cells: Vec<(&HealpixCell, &LoadedCell)> = self.loaded.iter().collect();
        cells.sort_unstable_by_key(|(cell, _)| (cell.depth, cell.id));

        let n_quads = self.loaded_quad_count();
        let mut stars: Vec<IndexStar> = Vec::with_capacity(self.loaded_star_count());
        let mut quads: Vec<Quad> = Vec::with_capacity(n_quads);
        let mut codes: Vec<Code> = Vec::with_capacity(n_quads);
        for (_, payload) in cells {
            // Quad star ids are cell-local; move them past the stars already emitted.
            let base = stars.len();
            stars.extend(payload.stars.iter().cloned());
            quads.extend(payload.quads.iter().map(|q| {
                let star_ids: [usize; DIMQUADS] = q.star_ids.map(|sid| sid + base);
                Quad { star_ids }
            }));
            codes.extend_from_slice(&payload.codes);
        }

        let star_points: Vec<[f64; 3]> = stars.iter().map(|s| radec_to_xyz(s.ra, s.dec)).collect();
        let star_idx: Vec<usize> = (0..stars.len()).collect();
        let star_tree = KdTree::<3>::build(star_points, star_idx);
        let code_idx: Vec<usize> = (0..codes.len()).collect();
        let code_tree = KdTree::<{ DIMCODES }>::build(codes, code_idx);
        let (scale_lower, scale_upper) = self.source.scale_range();
        Index {
            star_tree,
            stars,
            code_tree,
            quads,
            scale_lower,
            scale_upper,
            metadata: self.source.metadata().cloned(),
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `LiveIndex`, driven by an in-memory `MockSource` with three
    //! six-star cells and a `FailingSource` wrapper that errors on one cell id.
    use super::*;
    use crate::index::{HealpixCell, IndexFragment, IndexMetadata, IndexSource};
    use crate::kdtree::KdQueryable;
    use crate::quads::{DIMQUADS, Quad};
    use std::sync::Mutex;

    /// In-memory source; `load_count` records how many times `load_cells` ran.
    struct MockSource {
        cells: Vec<MockCell>,
        scale_lower: f64,
        scale_upper: f64,
        load_count: Mutex<usize>,
    }

    /// One mock cell; it intersects a region when its centre lies inside it.
    struct MockCell {
        cell: HealpixCell,
        center_ra: f64,
        center_dec: f64,
        stars: Vec<IndexStar>,
        quads: Vec<Quad>,
        codes: Vec<Code>,
    }

    impl IndexSource for MockSource {
        fn cells_intersecting(&self, region: &SkyRegion) -> Vec<HealpixCell> {
            self.cells
                .iter()
                .filter(|c| region.contains(c.center_ra, c.center_dec))
                .map(|c| c.cell)
                .collect()
        }

        fn load_cells(&self, cells: &[HealpixCell]) -> io::Result<IndexFragment> {
            let mut count = self.load_count.lock().unwrap();
            *count += 1;
            drop(count);
            let mut stars = Vec::new();
            let mut quads = Vec::new();
            let mut codes = Vec::new();
            for cell in cells {
                if let Some(mc) = self.cells.iter().find(|c| c.cell == *cell) {
                    // Rebase quad star ids onto the combined fragment star list.
                    let base = stars.len();
                    stars.extend(mc.stars.iter().cloned());
                    for q in &mc.quads {
                        let mut new_ids = [0usize; DIMQUADS];
                        for (i, &sid) in q.star_ids.iter().enumerate() {
                            new_ids[i] = sid + base;
                        }
                        quads.push(Quad { star_ids: new_ids });
                    }
                    codes.extend(mc.codes.iter().copied());
                }
            }
            Ok(IndexFragment {
                stars,
                quads,
                codes,
                scale_lower: self.scale_lower,
                scale_upper: self.scale_upper,
                metadata: None,
            })
        }

        fn cell_depth(&self) -> u8 {
            5
        }

        fn metadata(&self) -> Option<&IndexMetadata> {
            None
        }

        fn star_count(&self) -> usize {
            self.cells.iter().map(|c| c.stars.len()).sum()
        }

        fn quad_count(&self) -> usize {
            self.cells.iter().map(|c| c.quads.len()).sum()
        }

        fn scale_range(&self) -> (f64, f64) {
            (self.scale_lower, self.scale_upper)
        }
    }

    /// Three depth-5 cells (ids 0, 1, 2) with six stars each and no quads.
    fn make_mock_source() -> MockSource {
        let mut cells = Vec::new();
        for (cell_id, (ra, dec)) in [
            (0u64, (0.5_f64, 0.3_f64)),
            (1u64, (1.5, -0.1)),
            (2u64, (3.0, 0.5)),
        ] {
            let mut stars = Vec::new();
            for i in 0..6 {
                let frac = i as f64 / 6.0;
                stars.push(IndexStar {
                    catalog_id: cell_id * 1000 + i as u64,
                    ra: ra + frac * 0.005,
                    dec: dec + frac * 0.005,
                    mag: 5.0 + frac,
                });
            }
            cells.push(MockCell {
                cell: HealpixCell {
                    depth: 5,
                    id: cell_id,
                },
                center_ra: ra,
                center_dec: dec,
                stars,
                quads: Vec::new(),
                codes: Vec::new(),
            });
        }
        MockSource {
            cells,
            scale_lower: 0.001,
            scale_upper: 0.05,
            load_count: Mutex::new(0),
        }
    }

    #[test]
    fn open_starts_empty() {
        let live = LiveIndex::open(make_mock_source());
        assert_eq!(live.loaded_cell_count(), 0);
        assert_eq!(live.loaded_star_count(), 0);
        assert_eq!(live.build_generation(), 0);
    }

    #[test]
    fn ensure_region_loads_intersecting_cells() {
        let mut live = LiveIndex::open(make_mock_source());
        let region = SkyRegion::from_radians(starfield::Equatorial::new(1.0, 0.1), 0.6);
        let report = live.ensure_region(&region).unwrap();
        assert_eq!(report.cells_added, 2);
        assert_eq!(report.stars_added, 12);
        assert_eq!(live.loaded_star_count(), 12);
        assert_eq!(live.build_generation(), 1);
    }

    #[test]
    fn ensure_region_idempotent() {
        let mut live = LiveIndex::open(make_mock_source());
        let region = SkyRegion::from_radians(starfield::Equatorial::new(0.5, 0.3), 0.05);
        let r1 = live.ensure_region(&region).unwrap();
        let gen_after_first = live.build_generation();
        let r2 = live.ensure_region(&region).unwrap();
        assert!(r1.cells_added > 0);
        assert_eq!(r2.cells_added, 0);
        assert_eq!(r2.stars_added, 0);
        assert_eq!(live.build_generation(), gen_after_first);
    }

    #[test]
    fn drop_outside_compacts() {
        let mut live = LiveIndex::open(make_mock_source());
        let all_sky =
            SkyRegion::from_radians(starfield::Equatorial::new(0.0, 0.0), std::f64::consts::PI);
        live.ensure_region(&all_sky).unwrap();
        assert_eq!(live.loaded_cell_count(), 3);
        let tight = SkyRegion::from_radians(starfield::Equatorial::new(0.5, 0.3), 0.05);
        let report = live.drop_outside(&tight);
        assert!(report.cells_dropped >= 2);
        assert!(live.loaded_cell_count() <= 1);
    }

    #[test]
    fn set_region_replaces_membership() {
        let mut live = LiveIndex::open(make_mock_source());
        let region_a = SkyRegion::from_radians(starfield::Equatorial::new(0.5, 0.3), 0.05);
        live.set_region(&region_a).unwrap();
        let cells_before: HashSet<HealpixCell> = live.loaded_cells().copied().collect();
        let region_b = SkyRegion::from_radians(starfield::Equatorial::new(3.0, 0.5), 0.05);
        live.set_region(&region_b).unwrap();
        let cells_after: HashSet<HealpixCell> = live.loaded_cells().copied().collect();
        assert_ne!(cells_before, cells_after);
        assert!(!cells_after.is_empty());
    }

    #[test]
    fn build_generation_increments_on_change() {
        let mut live = LiveIndex::open(make_mock_source());
        let g0 = live.build_generation();
        let region = SkyRegion::from_radians(starfield::Equatorial::new(0.5, 0.3), 0.05);
        live.ensure_region(&region).unwrap();
        let g1 = live.build_generation();
        assert!(g1 > g0);
        live.ensure_region(&region).unwrap();
        assert_eq!(g1, live.build_generation());
        let nowhere = SkyRegion::from_radians(starfield::Equatorial::new(5.0, 1.5), 0.001);
        live.drop_outside(&nowhere);
        assert!(live.build_generation() > g1);
    }

    #[test]
    fn star_forest_query_unions_subtrees() {
        let mut live = LiveIndex::open(make_mock_source());
        let all_sky =
            SkyRegion::from_radians(starfield::Equatorial::new(0.0, 0.0), std::f64::consts::PI);
        live.ensure_region(&all_sky).unwrap();
        let total_via_forest = live.star_forest().len();
        assert_eq!(total_via_forest, live.loaded_star_count());
        let center = radec_to_xyz(0.5, 0.3);
        let hit = live.star_forest().nearest(&center);
        assert!(hit.is_some());
    }

    /// Wraps `MockSource` and fails any `load_cells` call that includes
    /// `fail_on_cell_id`.
    struct FailingSource {
        inner: MockSource,
        fail_on_cell_id: u64,
    }

    impl IndexSource for FailingSource {
        fn cells_intersecting(&self, region: &SkyRegion) -> Vec<HealpixCell> {
            self.inner.cells_intersecting(region)
        }

        fn load_cells(&self, cells: &[HealpixCell]) -> io::Result<IndexFragment> {
            for c in cells {
                if c.id == self.fail_on_cell_id {
                    return Err(io::Error::other(format!(
                        "simulated failure on cell {}",
                        c.id
                    )));
                }
            }
            self.inner.load_cells(cells)
        }

        fn cell_depth(&self) -> u8 {
            self.inner.cell_depth()
        }

        fn metadata(&self) -> Option<&IndexMetadata> {
            self.inner.metadata()
        }

        fn star_count(&self) -> usize {
            self.inner.star_count()
        }

        fn quad_count(&self) -> usize {
            self.inner.quad_count()
        }

        fn scale_range(&self) -> (f64, f64) {
            self.inner.scale_range()
        }
    }

    #[test]
    fn add_cells_failure_leaves_state_untouched() {
        let source = FailingSource {
            inner: make_mock_source(),
            fail_on_cell_id: 2,
        };
        let mut live = LiveIndex::open(source);
        let all_sky =
            SkyRegion::from_radians(starfield::Equatorial::new(0.0, 0.0), std::f64::consts::PI);
        let cells = live.source().cells_intersecting(&all_sky);
        assert!(
            cells.iter().any(|c| c.id == 2),
            "test fixture must include the failing cell"
        );
        let result = live.ensure_region(&all_sky);
        assert!(result.is_err(), "load should fail");
        assert_eq!(live.loaded_cell_count(), 0);
        assert_eq!(live.loaded_star_count(), 0);
        assert_eq!(live.build_generation(), 0);
    }

    #[test]
    fn set_region_failure_preserves_prior_state() {
        let source = FailingSource {
            inner: make_mock_source(),
            fail_on_cell_id: 2,
        };
        let mut live = LiveIndex::open(source);
        let region_a = SkyRegion::from_radians(starfield::Equatorial::new(0.5, 0.3), 0.05);
        live.set_region(&region_a).unwrap();
        let gen_before = live.build_generation();
        let cells_before: HashSet<HealpixCell> = live.loaded_cells().copied().collect();
        let region_b = SkyRegion::from_radians(starfield::Equatorial::new(3.0, 0.5), 0.05);
        let result = live.set_region(&region_b);
        assert!(result.is_err(), "load on cell 2 should fail");
        let cells_after: HashSet<HealpixCell> = live.loaded_cells().copied().collect();
        assert_eq!(cells_before, cells_after);
        assert_eq!(live.build_generation(), gen_before);
    }

    #[test]
    fn drop_cells_removes_only_the_named_ones() {
        let mut live = LiveIndex::open(make_mock_source());
        let all_sky =
            SkyRegion::from_radians(starfield::Equatorial::new(0.0, 0.0), std::f64::consts::PI);
        live.ensure_region(&all_sky).unwrap();
        let before = live.loaded_cell_count();
        let target = HealpixCell { depth: 5, id: 1 };
        let report = live.drop_cells(&[target]);
        assert_eq!(report.cells_dropped, 1);
        assert_eq!(live.loaded_cell_count(), before - 1);
        assert!(live.loaded_cells().all(|c| *c != target));
    }

    #[test]
    fn drop_cells_ignores_unknown_cells() {
        let mut live = LiveIndex::open(make_mock_source());
        let region = SkyRegion::from_radians(starfield::Equatorial::new(0.5, 0.3), 0.05);
        live.ensure_region(&region).unwrap();
        let before = live.loaded_cell_count();
        let unknown = HealpixCell { depth: 5, id: 9999 };
        let report = live.drop_cells(&[unknown]);
        assert_eq!(report.cells_dropped, 0);
        assert_eq!(report.stars_dropped, 0);
        assert_eq!(live.loaded_cell_count(), before);
    }

    #[test]
    fn kd_forest_insert_replaces_existing_tag() {
        let pts_a: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]];
        let pts_b: Vec<[f64; 2]> = vec![[10.0, 10.0]];
        let tree_a = KdTree::<2>::build(pts_a, vec![0, 1, 2]);
        let tree_b = KdTree::<2>::build(pts_b, vec![0]);
        let mut forest: KdForest<2> = KdForest::new();
        forest.insert(42, tree_a);
        assert_eq!(forest.len(), 3);
        forest.insert(42, tree_b);
        assert_eq!(forest.len(), 1);
        assert_eq!(forest.sub_tree_count(), 1);
        let near_origin = forest.range_search(&[0.0, 0.0], 0.5);
        assert!(near_origin.is_empty());
        let near_b = forest.range_search(&[10.0, 10.0], 0.5);
        assert_eq!(near_b.len(), 1);
    }

    #[test]
    fn as_index_flattens_loaded_set() {
        let mut live = LiveIndex::open(make_mock_source());
        let region =
            SkyRegion::from_radians(starfield::Equatorial::new(0.0, 0.0), std::f64::consts::PI);
        live.ensure_region(&region).unwrap();
        let idx = live.as_index();
        assert_eq!(idx.stars.len(), live.loaded_star_count());
        assert_eq!(idx.star_tree.len(), idx.stars.len());
        for q in &idx.quads {
            for &sid in &q.star_ids {
                assert!(sid < idx.stars.len());
            }
        }
    }
}