1use std::ops::{Add, Div, Mul, Sub};
2
3#[cfg(not(target_arch = "wasm32"))]
4use std::time::Instant;
5
/// Placeholder for [`std::time::Instant`] used when compiling for `wasm32`.
///
/// NOTE(review): presumably needed because `std::time::Instant::now` is not usable on
/// every wasm32 target (it panics on `wasm32-unknown-unknown`) — confirm the target set.
///
/// The shim carries no time information. Every value compares equal, and
/// [`Instant::elapsed`] always reports [`std::time::Duration::ZERO`]. It derives the same
/// traits as `std::time::Instant` (`Copy`, equality, ordering, hashing). Code that relies
/// on those traits on native targets therefore also compiles on wasm32.
#[cfg(target_arch = "wasm32")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Instant;

#[cfg(target_arch = "wasm32")]
impl Instant {
    /// Returns the placeholder instant (there is only one value).
    #[inline]
    pub fn now() -> Self {
        Instant
    }

    /// Always returns [`std::time::Duration::ZERO`]; nothing is measured on this target.
    #[inline]
    pub fn elapsed(&self) -> std::time::Duration {
        std::time::Duration::ZERO
    }
}
22
/// Scalar type usable as a point coordinate.
///
/// Implementors must be cheap to copy, shareable across threads, partially ordered,
/// closed under the four basic arithmetic operators, widenable to `f64` and
/// constructible from `f32`. Implemented here for `f32` and `f64`.
pub trait CoordType
where
    Self: Add<Output = Self>
        + Sub<Output = Self>
        + Mul<Output = Self>
        + Div<Output = Self>
        + PartialOrd
        + Into<f64>
        + From<f32>
        + Copy
        + Send
        + Sync
        + 'static,
{
    /// The additive identity.
    fn zero() -> Self;
    /// Positive infinity.
    fn infinity() -> Self;
    /// Absolute value.
    fn abs(self) -> Self;
    /// Square root.
    fn sqrt(self) -> Self;
}

/// Implements [`CoordType`] for primitive float types by forwarding to their
/// inherent associated constants and functions.
macro_rules! impl_coord_type {
    ($($float:ident),*) => {
        $(
            impl CoordType for $float {
                #[inline]
                fn zero() -> Self {
                    0.0
                }

                #[inline]
                fn infinity() -> Self {
                    $float::INFINITY
                }

                #[inline]
                fn abs(self) -> Self {
                    $float::abs(self)
                }

                #[inline]
                fn sqrt(self) -> Self {
                    $float::sqrt(self)
                }
            }
        )*
    };
}

impl_coord_type!(f32, f64);
80
81#[derive(Debug, Clone, Copy, PartialEq)]
83pub struct Point<C = f64, const D: usize = 2>([C; D]);
84
85impl<C: CoordType, const D: usize> Point<C, D> {
86 #[inline]
88 pub fn new(coords: [C; D]) -> Self {
89 Self(coords)
90 }
91
92 #[inline]
94 pub fn coords(&self) -> &[C; D] {
95 &self.0
96 }
97
98 #[inline]
100 pub fn coords_mut(&mut self) -> &mut [C; D] {
101 &mut self.0
102 }
103}
104
105#[derive(Debug, Clone, Copy, PartialEq)]
107pub struct BBox<C = f64, const D: usize = 2> {
108 pub min: Point<C, D>,
109 pub max: Point<C, D>,
110}
111
112impl<C: CoordType, const D: usize> BBox<C, D> {
113 #[inline]
115 pub fn new(min: Point<C, D>, max: Point<C, D>) -> Self {
116 Self { min, max }
117 }
118
119 pub fn contains_point(&self, point: &Point<C, D>) -> bool {
121 for d in 0..D {
122 let coord = point.coords()[d];
123 if coord < self.min.coords()[d] || coord > self.max.coords()[d] {
124 return false;
125 }
126 }
127 true
128 }
129
130 pub fn intersects(&self, other: &BBox<C, D>) -> bool {
132 for d in 0..D {
133 if self.max.coords()[d] < other.min.coords()[d]
134 || other.max.coords()[d] < self.min.coords()[d]
135 {
136 return false;
137 }
138 }
139 true
140 }
141}
142
/// Opaque identifier of an index entry, wrapping a `u64`.
///
/// Totally ordered and hashable, so it can key maps and sets.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct EntryId(pub u64);
146
/// Which spatial-index backend is in use.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum BackendKind {
    /// R-tree backend.
    RTree,
    /// k-d tree backend.
    KDTree,
    /// Quadtree backend.
    Quadtree,
    /// Grid backend.
    Grid,
}
155
/// Relative mix of query kinds in a workload, plus their mean selectivity.
///
/// NOTE(review): the three `*_frac` fields presumably sum to 1.0. Nothing here enforces
/// that — confirm with the producer.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq)]
pub struct QueryMix {
    /// Share of range queries.
    pub range_frac: f64,
    /// Share of k-nearest-neighbour queries.
    pub knn_frac: f64,
    /// Share of join queries.
    pub join_frac: f64,
    /// Mean query selectivity (presumably the fraction of entries matched — confirm).
    pub mean_selectivity: f64,
}

impl Default for QueryMix {
    /// A pure range-query workload with mean selectivity 0.01.
    fn default() -> Self {
        QueryMix {
            mean_selectivity: 0.01,
            join_frac: 0.0,
            knn_frac: 0.0,
            range_frac: 1.0,
        }
    }
}
175
176#[derive(Debug, Clone)]
178pub struct DataShape<const D: usize> {
179 pub point_count: usize,
180 pub bbox: BBox<f64, D>,
181 pub skewness: [f64; D],
183 pub clustering_coef: f64,
185 pub overlap_ratio: f64,
187 pub effective_dim: f64,
189 pub query_mix: QueryMix,
190}
191
192#[derive(Debug, Clone)]
194pub struct Stats<const D: usize> {
195 pub backend: BackendKind,
196 pub point_count: usize,
197 pub migrations: u64,
198 pub last_migration_at: Option<Instant>,
199 pub query_count: u64,
200 pub data_shape: DataShape<D>,
201 pub migrating: bool,
202 pub dimensions: usize,
203}
204
205#[derive(Debug)]
207pub enum BonsaiError {
208 NotFound(EntryId),
209 Frozen,
210 MigrationInProgress,
211 Serialisation(String),
212 Config(String),
213 DimensionMismatch { expected: usize, got: usize },
214}
215
216impl std::fmt::Display for BonsaiError {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 match self {
219 BonsaiError::NotFound(id) => write!(f, "entry {:?} not found", id),
220 BonsaiError::Frozen => {
221 write!(
222 f,
223 "index is frozen — call unfreeze() to re-enable adaptation"
224 )
225 }
226 BonsaiError::MigrationInProgress => write!(f, "migration already in progress"),
227 BonsaiError::Serialisation(msg) => write!(f, "serialisation error: {}", msg),
228 BonsaiError::Config(msg) => write!(f, "invalid configuration: {}", msg),
229 BonsaiError::DimensionMismatch { expected, got } => {
230 write!(f, "dimension mismatch: expected {}, got {}", expected, got)
231 }
232 }
233 }
234}
235
236impl std::error::Error for BonsaiError {}
237
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // A 2-D point always exposes exactly two coordinates.
    proptest! {
        #[test]
        fn prop_point_coord_len_d2(coords in prop::array::uniform2(-1.0e9_f64..1.0e9)) {
            let p = Point::<f64, 2>::new(coords);
            prop_assert_eq!(p.coords().len(), 2);
        }
    }

    // A 3-D point always exposes exactly three coordinates.
    proptest! {
        #[test]
        fn prop_point_coord_len_d3(
            a in -1.0e9_f64..1.0e9,
            b in -1.0e9_f64..1.0e9,
            c in -1.0e9_f64..1.0e9,
        ) {
            let p = Point::<f64, 3>::new([a, b, c]);
            prop_assert_eq!(p.coords().len(), 3);
        }
    }

    // DataShape<2> stores one skewness value per axis.
    proptest! {
        #[test]
        fn prop_datashape_skewness_len_d2(s0 in -10.0_f64..10.0, s1 in -10.0_f64..10.0) {
            let shape = DataShape::<2> {
                point_count: 0,
                bbox: BBox::new(Point::new([0.0, 0.0]), Point::new([1.0, 1.0])),
                skewness: [s0, s1],
                clustering_coef: 1.0,
                overlap_ratio: 0.0,
                effective_dim: 2.0,
                query_mix: QueryMix::default(),
            };
            prop_assert_eq!(shape.skewness.len(), 2);
        }
    }

    // DataShape<3> stores one skewness value per axis.
    proptest! {
        #[test]
        fn prop_datashape_skewness_len_d3(
            s0 in -10.0_f64..10.0,
            s1 in -10.0_f64..10.0,
            s2 in -10.0_f64..10.0,
        ) {
            let shape = DataShape::<3> {
                point_count: 0,
                bbox: BBox::new(Point::new([0.0, 0.0, 0.0]), Point::new([1.0, 1.0, 1.0])),
                skewness: [s0, s1, s2],
                clustering_coef: 1.0,
                overlap_ratio: 0.0,
                effective_dim: 3.0,
                query_mix: QueryMix::default(),
            };
            prop_assert_eq!(shape.skewness.len(), 3);
        }
    }

    // contains_point agrees with an inclusive per-axis range check (2-D).
    proptest! {
        #[test]
        fn prop_bbox_contains_point_d2(
            min0 in -1.0e6_f64..0.0, min1 in -1.0e6_f64..0.0,
            max0 in 0.0_f64..1.0e6, max1 in 0.0_f64..1.0e6,
            px in -1.5e6_f64..1.5e6, py in -1.5e6_f64..1.5e6,
        ) {
            let bbox = BBox::<f64, 2>::new(Point::new([min0, min1]), Point::new([max0, max1]));
            let point = Point::<f64, 2>::new([px, py]);
            let expected = px >= min0 && px <= max0 && py >= min1 && py <= max1;
            prop_assert_eq!(bbox.contains_point(&point), expected);
        }
    }

    // contains_point agrees with an inclusive per-axis range check (3-D).
    proptest! {
        #[test]
        fn prop_bbox_contains_point_d3(
            min0 in -1.0e6_f64..0.0, min1 in -1.0e6_f64..0.0, min2 in -1.0e6_f64..0.0,
            max0 in 0.0_f64..1.0e6, max1 in 0.0_f64..1.0e6, max2 in 0.0_f64..1.0e6,
            px in -1.5e6_f64..1.5e6, py in -1.5e6_f64..1.5e6, pz in -1.5e6_f64..1.5e6,
        ) {
            let bbox = BBox::<f64, 3>::new(
                Point::new([min0, min1, min2]),
                Point::new([max0, max1, max2]),
            );
            let point = Point::<f64, 3>::new([px, py, pz]);
            let expected = px >= min0 && px <= max0
                && py >= min1 && py <= max1
                && pz >= min2 && pz <= max2;
            prop_assert_eq!(bbox.contains_point(&point), expected);
        }
    }

    #[test]
    fn bbox_contains_point_basic() {
        let bbox = BBox::<f64, 2>::new(Point::new([0.0, 0.0]), Point::new([1.0, 1.0]));
        assert!(bbox.contains_point(&Point::new([0.5, 0.5])));
        // Both corners lie on the (inclusive) boundary.
        assert!(bbox.contains_point(&Point::new([0.0, 0.0])));
        assert!(bbox.contains_point(&Point::new([1.0, 1.0])));
        assert!(!bbox.contains_point(&Point::new([1.1, 0.5])));
        assert!(!bbox.contains_point(&Point::new([-0.1, 0.5])));
    }

    #[test]
    fn bbox_intersects_basic() {
        let a = BBox::<f64, 2>::new(Point::new([0.0, 0.0]), Point::new([2.0, 2.0]));
        let b = BBox::<f64, 2>::new(Point::new([1.0, 1.0]), Point::new([3.0, 3.0]));
        let c = BBox::<f64, 2>::new(Point::new([3.0, 3.0]), Point::new([4.0, 4.0]));
        assert!(a.intersects(&b));
        assert!(b.intersects(&a));
        assert!(!a.intersects(&c));
        assert!(!c.intersects(&a));
        // b and c share only the corner (3, 3); inclusive bounds make that an intersection.
        assert!(b.intersects(&c));
        assert!(c.intersects(&b));
    }

    #[test]
    fn entry_id_hash_eq() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(EntryId(1));
        set.insert(EntryId(2));
        set.insert(EntryId(1));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn bonsai_error_display() {
        let e = BonsaiError::NotFound(EntryId(42));
        assert!(e.to_string().contains("42"));
        let e2 = BonsaiError::DimensionMismatch {
            expected: 3,
            got: 2,
        };
        // Pin the full message: a bare `contains('3')` would also pass if the
        // `expected` and `got` values were swapped.
        assert_eq!(e2.to_string(), "dimension mismatch: expected 3, got 2");
    }
}