1pub mod engine;
9pub mod grid;
10pub mod types;
11
12pub use engine::{CollectionSnapshot, CollectionStats, EngineSnapshot, GeoEngine};
13pub use grid::GridIndex;
14pub use types::{haversine_m, valid_coord, GeoError, GeoFeature, GeoHit, EARTH_RADIUS_M};
15
16#[cfg(test)]
17mod tests {
18 use super::*;
19
20 const NYC: (f64, f64) = (40.7128, -74.0060);
22 const CHICAGO: (f64, f64) = (41.8781, -87.6298);
23 const LA: (f64, f64) = (34.0522, -118.2437);
24 const LONDON: (f64, f64) = (51.5074, -0.1278);
25
26 fn seeded() -> GeoEngine {
27 let e = GeoEngine::new();
28 e.create_collection("cities").unwrap();
29 for (id, (lat, lon), country) in [
30 ("nyc", NYC, "us"),
31 ("chicago", CHICAGO, "us"),
32 ("la", LA, "us"),
33 ("london", LONDON, "uk"),
34 ] {
35 e.upsert(
36 "cities",
37 id,
38 lat,
39 lon,
40 serde_json::json!({ "country": country }),
41 )
42 .unwrap();
43 }
44 e
45 }
46
47 #[test]
48 fn haversine_known_distance() {
49 let d = haversine_m(NYC.0, NYC.1, LA.0, LA.1);
51 assert!((d - 3_936_000.0).abs() < 30_000.0, "got {d} m");
52 }
53
54 #[test]
55 fn within_radius_and_bbox() {
56 let e = seeded();
57 let hits = e
59 .within_radius(
60 "cities",
61 NYC.0,
62 NYC.1,
63 2_000_000.0,
64 &serde_json::Value::Null,
65 )
66 .unwrap();
67 let ids: Vec<&str> = hits.iter().map(|h| h.id.as_str()).collect();
68 assert_eq!(ids, vec!["nyc", "chicago"]);
69 assert!(hits[0].distance_m < hits[1].distance_m);
70
71 let bbox = e
73 .within_bbox(
74 "cities",
75 25.0,
76 -125.0,
77 50.0,
78 -65.0,
79 &serde_json::Value::Null,
80 )
81 .unwrap();
82 let mut ids: Vec<&str> = bbox.iter().map(|h| h.id.as_str()).collect();
83 ids.sort();
84 assert_eq!(ids, vec!["chicago", "la", "nyc"]);
85 }
86
87 #[test]
88 fn nearest_matches_bruteforce() {
89 let e = seeded();
90 let q = (39.0, -77.0); let hits = e
92 .nearest("cities", q.0, q.1, 3, &serde_json::Value::Null)
93 .unwrap();
94
95 let mut all = [
97 ("nyc", NYC),
98 ("chicago", CHICAGO),
99 ("la", LA),
100 ("london", LONDON),
101 ]
102 .map(|(id, c)| (id, haversine_m(q.0, q.1, c.0, c.1)));
103 all.sort_by(|a, b| a.1.total_cmp(&b.1));
104 let truth: Vec<&str> = all.iter().take(3).map(|(id, _)| *id).collect();
105
106 let got: Vec<&str> = hits.iter().map(|h| h.id.as_str()).collect();
107 assert_eq!(got, truth);
108 }
109
110 #[test]
111 fn metadata_filter() {
112 let e = seeded();
113 let hits = e
114 .nearest(
115 "cities",
116 NYC.0,
117 NYC.1,
118 5,
119 &serde_json::json!({"country": "uk"}),
120 )
121 .unwrap();
122 let ids: Vec<&str> = hits.iter().map(|h| h.id.as_str()).collect();
123 assert_eq!(ids, vec!["london"]);
124 }
125
126 #[test]
127 fn upsert_move_get_delete() {
128 let e = seeded();
129 assert_eq!(e.collection_stats("cities").unwrap().count, 4);
130 e.upsert(
132 "cities",
133 "nyc",
134 LONDON.0,
135 LONDON.1,
136 serde_json::json!({"country": "moved"}),
137 )
138 .unwrap();
139 assert_eq!(e.collection_stats("cities").unwrap().count, 4);
140 let f = e.get("cities", "nyc").unwrap().unwrap();
141 assert!((f.lat - LONDON.0).abs() < 1e-9);
142
143 assert!(e.delete("cities", "la").unwrap());
144 assert_eq!(e.collection_stats("cities").unwrap().count, 3);
145 assert!(e.get("cities", "la").unwrap().is_none());
146 }
147
148 #[test]
149 fn invalid_coord_and_missing_collection() {
150 let e = GeoEngine::new();
151 e.create_collection("c").unwrap();
152 assert!(matches!(
153 e.upsert("c", "x", 200.0, 0.0, serde_json::Value::Null),
154 Err(GeoError::InvalidCoordinate)
155 ));
156 assert!(matches!(
157 e.nearest("nope", 0.0, 0.0, 1, &serde_json::Value::Null),
158 Err(GeoError::CollectionNotFound(_))
159 ));
160 }
161
162 #[test]
163 fn snapshot_roundtrip() {
164 let e = seeded();
165 let bytes = serde_json::to_vec(&e.snapshot()).unwrap();
166 let restored = GeoEngine::new();
167 restored.load_snapshot(serde_json::from_slice(&bytes).unwrap());
168 assert_eq!(restored.collection_stats("cities").unwrap().count, 4);
169 let hits = restored
170 .nearest("cities", NYC.0, NYC.1, 1, &serde_json::Value::Null)
171 .unwrap();
172 assert_eq!(hits[0].id, "nyc");
173 }
174
175 #[test]
176 fn upsert_auto_creates_collection() {
177 let e = GeoEngine::new();
178 e.upsert("auto", "p", 1.0, 2.0, serde_json::Value::Null)
180 .unwrap();
181 assert_eq!(e.list_collections(), vec!["auto"]);
182 assert_eq!(e.collection_stats("auto").unwrap().count, 1);
183 assert!(matches!(
185 e.upsert("never", "x", 200.0, 0.0, serde_json::Value::Null),
186 Err(GeoError::InvalidCoordinate)
187 ));
188 assert!(!e.collection_exists("never"));
189 }
190
191 fn pacific() -> GeoEngine {
194 let e = GeoEngine::new();
195 e.create_collection("pac").unwrap();
196 e.upsert("pac", "east", 0.0, 179.9, serde_json::Value::Null)
198 .unwrap();
199 e.upsert("pac", "west", 0.0, -179.9, serde_json::Value::Null)
200 .unwrap();
201 e.upsert("pac", "far", 0.0, 100.0, serde_json::Value::Null)
202 .unwrap();
203 e
204 }
205
206 #[test]
207 fn antimeridian_radius_finds_both_sides() {
208 let e = pacific();
209 let hits = e
210 .within_radius("pac", 0.0, 180.0, 50_000.0, &serde_json::Value::Null)
211 .unwrap();
212 let ids: Vec<&str> = hits.iter().map(|h| h.id.as_str()).collect();
213 assert!(ids.contains(&"east"), "missed east of the date line");
214 assert!(ids.contains(&"west"), "missed west of the date line");
215 assert!(!ids.contains(&"far"));
216 }
217
218 #[test]
219 fn antimeridian_nearest_crosses_line() {
220 let e = pacific();
221 let hits = e
223 .nearest("pac", 0.0, -179.95, 2, &serde_json::Value::Null)
224 .unwrap();
225 let ids: Vec<&str> = hits.iter().map(|h| h.id.as_str()).collect();
226 assert!(
227 ids.contains(&"west") && ids.contains(&"east"),
228 "got {ids:?}"
229 );
230 assert_eq!(hits[0].id, "west"); }
232
233 #[test]
234 fn antimeridian_bbox_crossing() {
235 let e = pacific();
236 let hits = e
238 .within_bbox("pac", -10.0, 170.0, 10.0, -170.0, &serde_json::Value::Null)
239 .unwrap();
240 let mut ids: Vec<&str> = hits.iter().map(|h| h.id.as_str()).collect();
241 ids.sort();
242 assert_eq!(ids, vec!["east", "west"]);
243 }
244
245 #[test]
246 fn near_pole_query_finds_far_longitude_points() {
247 let e = GeoEngine::new();
250 e.create_collection("arctic").unwrap();
251 e.upsert("arctic", "a", 89.5, 0.0, serde_json::Value::Null)
253 .unwrap();
254 e.upsert("arctic", "b", 89.5, 90.0, serde_json::Value::Null)
255 .unwrap();
256 e.upsert("arctic", "c", 89.5, 180.0, serde_json::Value::Null)
257 .unwrap();
258 e.upsert("arctic", "d", 89.5, -90.0, serde_json::Value::Null)
259 .unwrap();
260 e.upsert("arctic", "equator", 0.0, 0.0, serde_json::Value::Null)
261 .unwrap();
262 let hits = e
264 .within_radius("arctic", 89.0, -179.0, 400_000.0, &serde_json::Value::Null)
265 .unwrap();
266 let mut ids: Vec<&str> = hits.iter().map(|h| h.id.as_str()).collect();
267 ids.sort();
268 assert_eq!(
269 ids,
270 vec!["a", "b", "c", "d"],
271 "near-pole radius missed far-longitude points"
272 );
273 let n = e
275 .nearest("arctic", 89.0, -179.0, 4, &serde_json::Value::Null)
276 .unwrap();
277 assert!(n.iter().all(|h| h.id != "equator"));
278 assert_eq!(n.len(), 4);
279 }
280
281 #[test]
284 fn nearest_with_sparse_filter_returns_k() {
285 let e = GeoEngine::new();
286 e.create_collection("c").unwrap();
287 for i in 0..60 {
289 e.upsert(
290 "c",
291 format!("r{i}"),
292 0.001 * i as f64,
293 0.0,
294 serde_json::json!({"color":"red"}),
295 )
296 .unwrap();
297 }
298 for (i, lat) in [80.0, 81.0, 82.0].iter().enumerate() {
299 e.upsert(
300 "c",
301 format!("b{i}"),
302 *lat,
303 0.0,
304 serde_json::json!({"color":"blue"}),
305 )
306 .unwrap();
307 }
308 let hits = e
310 .nearest("c", 0.0, 0.0, 3, &serde_json::json!({"color":"blue"}))
311 .unwrap();
312 assert_eq!(hits.len(), 3, "sparse filter under-returned");
313 assert!(hits.iter().all(|h| h.metadata["color"] == "blue"));
314 assert_eq!(hits[0].id, "b0");
316 }
317
318 #[test]
321 fn k_zero_and_k_over_size() {
322 let e = seeded();
323 assert!(e
324 .nearest("cities", NYC.0, NYC.1, 0, &serde_json::Value::Null)
325 .unwrap()
326 .is_empty());
327 let all = e
328 .nearest("cities", NYC.0, NYC.1, 100, &serde_json::Value::Null)
329 .unwrap();
330 assert_eq!(all.len(), 4); }
332
333 #[test]
334 fn radius_zero_matches_only_exact_point() {
335 let e = seeded();
336 let hits = e
337 .within_radius("cities", NYC.0, NYC.1, 0.0, &serde_json::Value::Null)
338 .unwrap();
339 assert_eq!(hits.len(), 1);
340 assert_eq!(hits[0].id, "nyc");
341 assert_eq!(hits[0].distance_m, 0.0);
342 }
343
344 #[test]
345 fn queries_on_empty_collection_are_empty() {
346 let e = GeoEngine::new();
347 e.create_collection("empty").unwrap();
348 let z = &serde_json::Value::Null;
349 assert!(e.nearest("empty", 0.0, 0.0, 5, z).unwrap().is_empty());
350 assert!(e
351 .within_radius("empty", 0.0, 0.0, 1e6, z)
352 .unwrap()
353 .is_empty());
354 assert!(e
355 .within_bbox("empty", -1.0, -1.0, 1.0, 1.0, z)
356 .unwrap()
357 .is_empty());
358 }
359
360 #[test]
361 fn coordinate_boundaries_are_validated() {
362 let e = GeoEngine::new();
363 e.create_collection("c").unwrap();
364 let z = serde_json::Value::Null;
365 assert!(e.upsert("c", "np", 90.0, 180.0, z.clone()).is_ok());
367 assert!(e.upsert("c", "sp", -90.0, -180.0, z.clone()).is_ok());
368 for (lat, lon) in [(90.1, 0.0), (-90.1, 0.0), (0.0, 180.1), (0.0, -180.1)] {
370 assert!(matches!(
371 e.upsert("c", "bad", lat, lon, z.clone()),
372 Err(GeoError::InvalidCoordinate)
373 ));
374 }
375 assert!(matches!(
377 e.upsert("c", "nan", f64::NAN, 0.0, z.clone()),
378 Err(GeoError::InvalidCoordinate)
379 ));
380 }
381
382 #[test]
383 fn collection_lifecycle() {
384 let e = GeoEngine::new();
385 e.create_collection("a").unwrap();
386 assert!(matches!(
387 e.create_collection("a"),
388 Err(GeoError::CollectionExists(_))
389 ));
390 e.create_collection("b").unwrap();
391 assert_eq!(e.list_collections(), vec!["a", "b"]);
392 assert!(e.collection_exists("a"));
393 e.drop_collection("a").unwrap();
394 assert!(!e.collection_exists("a"));
395 assert!(matches!(
396 e.drop_collection("a"),
397 Err(GeoError::CollectionNotFound(_))
398 ));
399 let z = &serde_json::Value::Null;
401 assert!(matches!(
402 e.within_radius("a", 0.0, 0.0, 1.0, z),
403 Err(GeoError::CollectionNotFound(_))
404 ));
405 assert!(matches!(
406 e.within_bbox("a", 0.0, 0.0, 1.0, 1.0, z),
407 Err(GeoError::CollectionNotFound(_))
408 ));
409 assert!(matches!(
410 e.get("a", "x"),
411 Err(GeoError::CollectionNotFound(_))
412 ));
413 assert!(matches!(
414 e.delete("a", "x"),
415 Err(GeoError::CollectionNotFound(_))
416 ));
417 }
418
419 #[test]
420 fn delete_removes_from_spatial_results() {
421 let e = seeded();
422 assert!(e.delete("cities", "chicago").unwrap());
423 let hits = e
424 .within_radius(
425 "cities",
426 NYC.0,
427 NYC.1,
428 2_000_000.0,
429 &serde_json::Value::Null,
430 )
431 .unwrap();
432 let ids: Vec<&str> = hits.iter().map(|h| h.id.as_str()).collect();
433 assert_eq!(ids, vec!["nyc"]); assert!(!e.delete("cities", "chicago").unwrap()); }
436
437 #[test]
438 fn haversine_symmetry_and_zero() {
439 assert_eq!(haversine_m(NYC.0, NYC.1, NYC.0, NYC.1), 0.0);
440 let ab = haversine_m(NYC.0, NYC.1, LONDON.0, LONDON.1);
441 let ba = haversine_m(LONDON.0, LONDON.1, NYC.0, NYC.1);
442 assert!((ab - ba).abs() < 1e-6);
443 assert!((ab - 5_570_000.0).abs() < 60_000.0, "got {ab}");
445 }
446
447 #[test]
450 fn nearest_matches_bruteforce_randomized() {
451 let e = GeoEngine::new();
452 e.create_collection("pts").unwrap();
453 let mut s: u64 = 0x1234_5678;
455 let mut next = || {
456 s = s
457 .wrapping_mul(6364136223846793005)
458 .wrapping_add(1442695040888963407);
459 (s >> 33) as f64 / (1u64 << 31) as f64 };
461 let mut pts: Vec<(f64, f64)> = (0..400)
462 .map(|_| (next() * 180.0 - 90.0, next() * 360.0 - 180.0))
463 .collect();
464 pts.extend([
466 (90.0, 0.0),
467 (-90.0, 0.0),
468 (89.7, 180.0),
469 (89.7, -180.0),
470 (88.0, 30.0),
471 (0.0, 180.0),
472 (0.0, -180.0),
473 ]);
474 for (i, (lat, lon)) in pts.iter().enumerate() {
475 e.upsert("pts", format!("p{i}"), *lat, *lon, serde_json::Value::Null)
476 .unwrap();
477 }
478 let z = &serde_json::Value::Null;
479 for q in [
481 (0.0, 179.99),
482 (89.0, -179.0),
483 (-45.0, 0.0),
484 (10.0, -179.5),
485 (89.9, 17.0),
486 (-89.5, -120.0),
487 (0.0, 180.0),
488 ] {
489 let got: Vec<String> = e
490 .nearest("pts", q.0, q.1, 5, z)
491 .unwrap()
492 .iter()
493 .map(|h| h.id.clone())
494 .collect();
495 let mut all: Vec<(String, f64)> = pts
497 .iter()
498 .enumerate()
499 .map(|(i, (la, lo))| (format!("p{i}"), haversine_m(q.0, q.1, *la, *lo)))
500 .collect();
501 all.sort_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
502 let truth: Vec<String> = all.iter().take(5).map(|(id, _)| id.clone()).collect();
503 assert_eq!(got, truth, "nearest mismatch near {q:?}");
504 }
505 }
506}