use alloc::format;
use alloc::vec::Vec;
use core::sync::atomic::{AtomicU32, Ordering};
use crate::domain::point::ensure_finite;
use crate::error::{RcfError, RcfResult};
use crate::tree::PointAccessor;
/// Slot-based storage for `D`-dimensional points, each slot carrying a reference count.
///
/// Invariants maintained by the methods in this file:
/// - `points` and `ref_counts` grow together (see `add`), so index `i` in both describes slot `i`.
/// - A slot is free exactly when `points[i]` is `None`; freeing pushes `i` onto `free_list`,
///   and `add` pops from it before growing the vectors.
///
/// NOTE(review): with the `serde` feature each field is deserialized independently, so none of
/// the invariants above are re-checked on load. A payload with mismatched lengths or a bogus
/// free-list index will make methods that index `ref_counts[idx]` / `points[idx]` panic —
/// consider validating after deserialization.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PointStore<const D: usize> {
    /// Slot table; `None` marks a free slot.
    #[cfg_attr(feature = "serde", serde(with = "point_slots_serde"))]
    points: Vec<Option<[f64; D]>>,
    /// Per-slot reference counts. They are atomic so `incr_ref`/`decr_ref` can work through `&self`.
    #[cfg_attr(feature = "serde", serde(with = "atomic_u32_vec_serde"))]
    ref_counts: Vec<AtomicU32>,
    /// Indices of vacated slots; `add` reuses the most recently freed one first (`Vec::pop`).
    free_list: Vec<usize>,
}
#[cfg(feature = "serde")]
mod point_slots_serde {
    //! Serde adapter for `Vec<Option<[f64; D]>>`.
    //!
    //! Each slot is encoded as `None` or a sequence of `f64`, i.e. the same wire format as
    //! `Vec<Option<Vec<f64>>>`. serde's built-in array impls are written for specific lengths,
    //! so a const-generic `[f64; D]` cannot use them directly.
    use alloc::format;
    use alloc::vec::Vec;
    use serde::{Deserialize, Deserializer, Serializer, de::Error as _};

    /// Serializes each slot as `None` or a sequence of its `D` coordinates.
    ///
    /// Borrowed slices are streamed straight into the serializer, so no per-slot `Vec` is
    /// allocated. The output matches serializing a `Vec<Option<Vec<f64>>>`, because serde encodes
    /// both `Vec<T>` and `&[T]` via `collect_seq`.
    pub fn serialize<S, const D: usize>(
        slots: &[Option<[f64; D]>],
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.collect_seq(slots.iter().map(|slot| slot.as_ref().map(|arr| &arr[..])))
    }

    /// Deserializes slots in the format written by `serialize`.
    ///
    /// # Errors
    ///
    /// Fails with a custom deserializer error when any present slot does not contain exactly
    /// `D` values. The message reports both the expected and the actual length.
    pub fn deserialize<'de, D2, const D: usize>(
        deserializer: D2,
    ) -> Result<Vec<Option<[f64; D]>>, D2::Error>
    where
        D2: Deserializer<'de>,
    {
        let raw: Vec<Option<Vec<f64>>> = Vec::deserialize(deserializer)?;
        raw.into_iter()
            .map(|slot| match slot {
                None => Ok(None),
                Some(coords) => {
                    // `TryFrom<Vec<T>> for [T; N]` hands the vector back on a length mismatch.
                    let arr: [f64; D] = coords.try_into().map_err(|rejected: Vec<f64>| {
                        D2::Error::custom(format!(
                            "PointStore slot length mismatch: expected {D}, got {}",
                            rejected.len()
                        ))
                    })?;
                    Ok(Some(arr))
                }
            })
            .collect()
    }
}
#[cfg(feature = "serde")]
mod atomic_u32_vec_serde {
    //! Serde adapter for `Vec<AtomicU32>`, encoded as a plain sequence of `u32`.
    use alloc::vec::Vec;
    use core::sync::atomic::{AtomicU32, Ordering};
    use serde::{Deserialize, Deserializer, Serializer};

    /// Serializes the current value of every counter as a `u32` sequence.
    ///
    /// Values are streamed into the serializer rather than collected into a temporary
    /// `Vec<u32>`; the encoded sequence is identical. Each counter is loaded independently with
    /// `Relaxed` ordering. If counters are modified concurrently, the output is therefore not a
    /// single consistent snapshot.
    pub fn serialize<S: Serializer>(v: &[AtomicU32], serializer: S) -> Result<S::Ok, S::Error> {
        serializer.collect_seq(v.iter().map(|a| a.load(Ordering::Relaxed)))
    }

    /// Deserializes a `u32` sequence into freshly constructed atomics.
    pub fn deserialize<'de, D: Deserializer<'de>>(
        deserializer: D,
    ) -> Result<Vec<AtomicU32>, D::Error> {
        let raw = Vec::<u32>::deserialize(deserializer)?;
        Ok(raw.into_iter().map(AtomicU32::new).collect())
    }
}
impl<const D: usize> PointStore<D> {
    /// Creates an empty store.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] when `D == 0`.
    pub fn new() -> RcfResult<Self> {
        if D == 0 {
            return Err(RcfError::InvalidConfig(
                "PointStore dimension must be > 0".into(),
            ));
        }
        Ok(Self {
            points: Vec::new(),
            ref_counts: Vec::new(),
            free_list: Vec::new(),
        })
    }

    /// Returns the point dimension `D`.
    #[must_use]
    #[inline]
    pub const fn dimension(&self) -> usize {
        D
    }

    /// Returns the number of slots ever allocated, live and free combined.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.points.len()
    }

    /// Returns the number of slots currently holding a point. This scans every slot.
    #[must_use]
    pub fn live_count(&self) -> usize {
        self.points.iter().filter(|slot| slot.is_some()).count()
    }

    /// Stores `point` and returns its slot index.
    ///
    /// Reuses the most recently freed slot when one exists, otherwise appends a new slot.
    /// The slot's reference count starts at 0.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ensure_finite` for non-finite coordinates
    /// (e.g. `RcfError::NaNValue` for a NaN).
    pub fn add(&mut self, point: [f64; D]) -> RcfResult<usize> {
        ensure_finite(&point)?;
        if let Some(idx) = self.free_list.pop() {
            self.points[idx] = Some(point);
            // Slots are only freed at refcount 0, but reset anyway so reuse always starts clean.
            self.ref_counts[idx].store(0, Ordering::Relaxed);
            return Ok(idx);
        }
        let idx = self.points.len();
        self.points.push(Some(point));
        self.ref_counts.push(AtomicU32::new(0));
        Ok(idx)
    }

    /// Frees slot `idx` if nothing references it.
    ///
    /// # Errors
    ///
    /// - [`RcfError::OutOfBounds`] if `idx` is past the end or the slot is already free.
    /// - [`RcfError::InvalidConfig`] if the slot's reference count is non-zero.
    pub fn drop_unreferenced(&mut self, idx: usize) -> RcfResult<()> {
        self.ensure_live(idx)?;
        let rc = self.ref_counts[idx].load(Ordering::Acquire);
        if rc != 0 {
            return Err(RcfError::InvalidConfig(
                format!("PointStore::drop_unreferenced: slot {idx} still has refcount {rc}").into(),
            ));
        }
        self.vacate(idx);
        Ok(())
    }

    /// Increments the reference count of slot `idx`.
    ///
    /// # Errors
    ///
    /// - [`RcfError::OutOfBounds`] if `idx` is past the end or the slot is free.
    /// - [`RcfError::InvalidConfig`] if the count is already `u32::MAX`. A plain `fetch_add`
    ///   would wrap to 0 here, which would let `set_free`/`drop_unreferenced` vacate a slot
    ///   that is still referenced.
    pub fn incr_ref(&self, idx: usize) -> RcfResult<()> {
        self.ensure_live(idx)?;
        self.ref_counts[idx]
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                current.checked_add(1)
            })
            .map(|_| ())
            .map_err(|current| {
                RcfError::InvalidConfig(
                    format!("PointStore::incr_ref: slot {idx} refcount overflow at {current}")
                        .into(),
                )
            })
    }

    /// Decrements the reference count of slot `idx`.
    ///
    /// Returns `Ok(true)` when this call brought the count to zero. The slot is not freed here;
    /// call `set_free` or `drop_unreferenced` for that.
    ///
    /// # Errors
    ///
    /// - [`RcfError::OutOfBounds`] if `idx` is past the end or the slot is free.
    /// - [`RcfError::InvalidConfig`] if the count is already zero.
    pub fn decr_ref(&self, idx: usize) -> RcfResult<bool> {
        self.ensure_live(idx)?;
        self.ref_counts[idx]
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                current.checked_sub(1)
            })
            // `fetch_update` yields the previous value; 1 -> 0 is the transition we report.
            .map(|previous| previous == 1)
            .map_err(|_| {
                RcfError::InvalidConfig(
                    format!("PointStore::decr_ref: slot {idx} already at zero refcount").into(),
                )
            })
    }

    /// Frees slot `idx`, which must have a zero reference count.
    ///
    /// # Errors
    ///
    /// - [`RcfError::OutOfBounds`] if `idx` is past the end or the slot is already free.
    /// - [`RcfError::InvalidConfig`] if the slot's reference count is non-zero.
    pub fn set_free(&mut self, idx: usize) -> RcfResult<()> {
        self.ensure_live(idx)?;
        let rc = self.ref_counts[idx].load(Ordering::Acquire);
        if rc != 0 {
            return Err(RcfError::InvalidConfig(
                format!("PointStore::set_free: slot {idx} has refcount {rc}, not zero").into(),
            ));
        }
        self.vacate(idx);
        Ok(())
    }

    /// Returns the reference count stored for slot `idx`, or 0 when `idx` is past the end.
    ///
    /// For a free slot this returns whatever count is stored for it.
    #[must_use]
    pub fn ref_count(&self, idx: usize) -> u32 {
        self.ref_counts
            .get(idx)
            .map_or(0, |rc| rc.load(Ordering::Acquire))
    }

    /// Rough size in bytes of the stored elements.
    ///
    /// Counts `len()` of each vector times its element size, including the `Option` tag of each
    /// point slot. Spare `Vec` capacity and allocator overhead are not included.
    #[must_use]
    pub fn memory_estimate(&self) -> usize {
        self.points.len() * core::mem::size_of::<Option<[f64; D]>>()
            + self.ref_counts.len() * core::mem::size_of::<AtomicU32>()
            + self.free_list.len() * core::mem::size_of::<usize>()
    }

    /// Succeeds when `idx` names an occupied slot; otherwise returns `OutOfBounds`.
    fn ensure_live(&self, idx: usize) -> RcfResult<()> {
        if matches!(self.points.get(idx), Some(Some(_))) {
            Ok(())
        } else {
            Err(RcfError::OutOfBounds {
                index: idx,
                len: self.points.len(),
            })
        }
    }

    /// Empties slot `idx` and queues it for reuse by `add`.
    fn vacate(&mut self, idx: usize) {
        self.points[idx] = None;
        self.free_list.push(idx);
    }
}
impl<const D: usize> PointAccessor<D> for PointStore<D> {
    /// Borrows the point in slot `idx`.
    ///
    /// Returns `None` when `idx` is past the end of the store or the slot is free.
    fn point(&self, idx: usize) -> Option<&[f64; D]> {
        match self.points.get(idx) {
            Some(Some(coords)) => Some(coords),
            _ => None,
        }
    }
}
#[cfg(test)]
// Exact float equality is intended: these tests read back values that were stored unchanged.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;
    #[test]
    fn new_rejects_zero_dimension() {
        assert!(matches!(
            PointStore::<0>::new().unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
    }
    #[test]
    fn dimension_reports_const_parameter() {
        let s = PointStore::<3>::new().unwrap();
        assert_eq!(s.dimension(), 3);
        assert_eq!(s.capacity(), 0);
        assert_eq!(s.live_count(), 0);
    }
    #[test]
    fn add_validates_finite() {
        let mut s = PointStore::<2>::new().unwrap();
        assert!(matches!(
            s.add([1.0, f64::NAN]).unwrap_err(),
            RcfError::NaNValue
        ));
    }
    #[test]
    fn add_returns_increasing_indices_initially() {
        let mut s = PointStore::<2>::new().unwrap();
        assert_eq!(s.add([0.0, 0.0]).unwrap(), 0);
        assert_eq!(s.add([1.0, 1.0]).unwrap(), 1);
        assert_eq!(s.capacity(), 2);
    }
    #[test]
    fn point_returns_inserted_value() {
        let mut s = PointStore::<2>::new().unwrap();
        let idx = s.add([1.5, 2.5]).unwrap();
        assert_eq!(s.point(idx), Some(&[1.5, 2.5]));
    }
    #[test]
    fn point_returns_none_for_free_or_oob() {
        let mut s = PointStore::<2>::new().unwrap();
        assert_eq!(s.point(99), None);
        let idx = s.add([0.0, 0.0]).unwrap();
        s.incr_ref(idx).unwrap();
        let hit_zero = s.decr_ref(idx).unwrap();
        assert!(hit_zero);
        s.set_free(idx).unwrap();
        assert_eq!(s.point(idx), None);
    }
    #[test]
    fn ref_count_starts_at_zero() {
        let mut s = PointStore::<2>::new().unwrap();
        let idx = s.add([0.0, 0.0]).unwrap();
        assert_eq!(s.ref_count(idx), 0);
    }
    #[test]
    fn ref_count_out_of_bounds_is_zero() {
        let s = PointStore::<2>::new().unwrap();
        assert_eq!(s.ref_count(42), 0);
    }
    #[test]
    fn incr_decr_cycle_frees_slot() {
        let mut s = PointStore::<2>::new().unwrap();
        let a = s.add([1.0, 1.0]).unwrap();
        s.incr_ref(a).unwrap();
        s.incr_ref(a).unwrap();
        assert_eq!(s.ref_count(a), 2);
        let hit_zero = s.decr_ref(a).unwrap();
        assert!(!hit_zero);
        assert_eq!(s.ref_count(a), 1);
        let hit_zero = s.decr_ref(a).unwrap();
        assert!(hit_zero);
        assert_eq!(s.ref_count(a), 0);
        s.set_free(a).unwrap();
        let b = s.add([2.0, 2.0]).unwrap();
        assert_eq!(b, a);
    }
    #[test]
    fn reused_slot_starts_with_zero_refcount() {
        let mut s = PointStore::<2>::new().unwrap();
        let a = s.add([1.0, 1.0]).unwrap();
        s.drop_unreferenced(a).unwrap();
        let b = s.add([3.0, 4.0]).unwrap();
        assert_eq!(b, a);
        assert_eq!(s.ref_count(b), 0);
        assert_eq!(s.point(b), Some(&[3.0, 4.0]));
    }
    #[test]
    fn incr_oob_or_free_returns_err() {
        let mut s = PointStore::<2>::new().unwrap();
        assert!(matches!(
            s.incr_ref(99).unwrap_err(),
            RcfError::OutOfBounds { .. }
        ));
        let idx = s.add([0.0, 0.0]).unwrap();
        s.incr_ref(idx).unwrap();
        assert!(s.decr_ref(idx).unwrap());
        s.set_free(idx).unwrap();
        assert!(matches!(
            s.incr_ref(idx).unwrap_err(),
            RcfError::OutOfBounds { .. }
        ));
    }
    #[test]
    fn decr_oob_or_free_returns_err() {
        let mut s = PointStore::<2>::new().unwrap();
        assert!(matches!(
            s.decr_ref(7).unwrap_err(),
            RcfError::OutOfBounds { .. }
        ));
        let idx = s.add([0.0, 0.0]).unwrap();
        s.drop_unreferenced(idx).unwrap();
        assert!(matches!(
            s.decr_ref(idx).unwrap_err(),
            RcfError::OutOfBounds { .. }
        ));
    }
    #[test]
    fn decr_at_zero_is_invalid_config() {
        let mut s = PointStore::<2>::new().unwrap();
        let idx = s.add([0.0, 0.0]).unwrap();
        assert!(matches!(
            s.decr_ref(idx).unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
    }
    #[test]
    fn drop_unreferenced_frees_zero_refcount_slot() {
        let mut s = PointStore::<2>::new().unwrap();
        let idx = s.add([0.0, 0.0]).unwrap();
        s.drop_unreferenced(idx).unwrap();
        assert_eq!(s.point(idx), None);
        assert_eq!(s.live_count(), 0);
        let new_idx = s.add([1.0, 1.0]).unwrap();
        assert_eq!(new_idx, idx);
    }
    #[test]
    fn drop_unreferenced_rejects_live_slot() {
        let mut s = PointStore::<2>::new().unwrap();
        let idx = s.add([0.0, 0.0]).unwrap();
        s.incr_ref(idx).unwrap();
        assert!(matches!(
            s.drop_unreferenced(idx).unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
    }
    #[test]
    fn set_free_rejects_referenced_slot() {
        let mut s = PointStore::<2>::new().unwrap();
        let idx = s.add([0.5, 0.5]).unwrap();
        s.incr_ref(idx).unwrap();
        assert!(matches!(
            s.set_free(idx).unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
        // The rejected free must leave the point in place.
        assert_eq!(s.point(idx), Some(&[0.5, 0.5]));
    }
    #[test]
    fn double_free_is_rejected() {
        let mut s = PointStore::<2>::new().unwrap();
        let idx = s.add([0.0, 0.0]).unwrap();
        s.set_free(idx).unwrap();
        assert!(matches!(
            s.set_free(idx).unwrap_err(),
            RcfError::OutOfBounds { .. }
        ));
        assert!(matches!(
            s.drop_unreferenced(idx).unwrap_err(),
            RcfError::OutOfBounds { .. }
        ));
        // The index must be on the free list only once, so two adds get distinct slots.
        let a = s.add([1.0, 1.0]).unwrap();
        let b = s.add([2.0, 2.0]).unwrap();
        assert_eq!(a, idx);
        assert_ne!(a, b);
    }
    #[test]
    fn live_count_tracks_active_slots() {
        let mut s = PointStore::<2>::new().unwrap();
        let a = s.add([0.0, 0.0]).unwrap();
        let b = s.add([1.0, 1.0]).unwrap();
        assert_eq!(s.live_count(), 2);
        s.incr_ref(a).unwrap();
        assert!(s.decr_ref(a).unwrap());
        s.set_free(a).unwrap();
        assert_eq!(s.live_count(), 1);
        s.incr_ref(b).unwrap();
        assert!(s.decr_ref(b).unwrap());
        s.set_free(b).unwrap();
        assert_eq!(s.live_count(), 0);
    }
    #[test]
    fn point_accessor_impl_works() {
        let mut s = PointStore::<3>::new().unwrap();
        let idx = s.add([7.0, 8.0, 9.0]).unwrap();
        let acc: &dyn PointAccessor<3> = &s;
        assert_eq!(acc.point(idx), Some(&[7.0, 8.0, 9.0]));
    }
    #[test]
    fn memory_estimate_grows_with_capacity() {
        let mut s = PointStore::<4>::new().unwrap();
        let before = s.memory_estimate();
        s.add([1.0; 4]).unwrap();
        let after = s.memory_estimate();
        assert!(after > before);
    }
}