Skip to main content

neumann_server/rest/
spatial.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! REST API handlers for spatial operations.
3//!
4//! The spatial index is global (shared across all collections via `QueryRouter`).
5//! The `{name}` path parameter is accepted for URL consistency but the index
6//! is not per-collection.  Future work can add per-collection spatial indexes.
7
8use std::sync::Arc;
9use std::time::Instant;
10
11use axum::extract::{Path, State};
12use axum::http::HeaderMap;
13use axum::Json;
14use serde::{Deserialize, Serialize};
15
16use crate::config::AuthConfig;
17use crate::rate_limit::{Operation, RateLimiter};
18use crate::rest::error::{ApiError, ApiResult};
19use crate::rest::VectorApiContext;
20
21// ---------------------------------------------------------------------------
22// Auth helpers (same pattern as points.rs / collections.rs)
23// ---------------------------------------------------------------------------
24
25fn extract_api_key(headers: &HeaderMap, auth_config: Option<&AuthConfig>) -> Option<String> {
26    let header_name = auth_config.map_or("x-api-key", |c| c.api_key_header.as_str());
27    headers
28        .get(header_name)
29        .and_then(|v| v.to_str().ok())
30        .map(String::from)
31}
32
33fn validate_auth(
34    headers: &HeaderMap,
35    auth_config: Option<&AuthConfig>,
36) -> Result<Option<String>, ApiError> {
37    let api_key = extract_api_key(headers, auth_config);
38
39    match (auth_config, api_key) {
40        (None, _) => Ok(None),
41        (Some(config), None) => {
42            if config.allow_anonymous {
43                Ok(None)
44            } else {
45                Err(ApiError::unauthorized("API key required"))
46            }
47        },
48        (Some(config), Some(key)) => config.validate_key(&key).map_or_else(
49            || Err(ApiError::unauthorized("Invalid API key")),
50            |identity| Ok(Some(identity.to_string())),
51        ),
52    }
53}
54
55fn check_rate_limit(
56    identity: Option<&String>,
57    rate_limiter: Option<&Arc<RateLimiter>>,
58    operation: &str,
59) -> Result<(), ApiError> {
60    if let Some(limiter) = rate_limiter {
61        if let Some(id) = identity {
62            if let Err(msg) = limiter.check_and_record(id, Operation::VectorOp) {
63                tracing::warn!("Rate limited: {id} for {operation}");
64                return Err(ApiError::rate_limited(msg));
65            }
66        }
67    }
68    Ok(())
69}
70
71// ---------------------------------------------------------------------------
72// Request / response types
73// ---------------------------------------------------------------------------
74
/// Request body for the spatial insert endpoint.
///
/// The `insert` handler passes `x`, `y`, `width` and `height` to
/// `tensor_spatial::BoundingBox::new` and rejects the request if that
/// constructor reports invalid dimensions. `key` is stored as the entry data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialInsertRequest {
    /// Key identifying the spatial entry (stored as the entry's data).
    pub key: String,
    /// X coordinate of the bounding-box origin.
    pub x: f32,
    /// Y coordinate of the bounding-box origin.
    pub y: f32,
    /// Width of the bounding box (must be non-negative).
    pub width: f32,
    /// Height of the bounding box (must be non-negative).
    pub height: f32,
}
89
/// Request body for the spatial radius-query endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialQueryRequest {
    /// X coordinate of the query centre.
    pub x: f32,
    /// Y coordinate of the query centre.
    pub y: f32,
    /// Search radius (must be non-negative and finite).
    pub radius: f32,
    /// Maximum number of results to return. When the field is omitted from
    /// the JSON body (or is `null`), every match is returned.
    pub limit: Option<usize>,
}
102
/// Request body for the spatial delete endpoint.
///
/// The `delete` handler removes an entry only if both its key and its
/// bounding box equal the ones described here. The coordinates must therefore
/// describe the same box that was used when the entry was inserted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialDeleteRequest {
    /// Key of the entry to remove.
    pub key: String,
    /// X coordinate of the bounding-box origin.
    pub x: f32,
    /// Y coordinate of the bounding-box origin.
    pub y: f32,
    /// Width of the bounding box.
    pub width: f32,
    /// Height of the bounding box.
    pub height: f32,
}
117
/// A single match returned by the spatial radius query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialResultItem {
    /// Key of the spatial entry.
    pub key: String,
    /// Distance from query centre to closest point on the bounding box.
    pub distance: f32,
    /// X coordinate of the entry's bounding-box origin.
    pub x: f32,
    /// Y coordinate of the entry's bounding-box origin.
    pub y: f32,
    /// Width of the entry's bounding box.
    pub width: f32,
    /// Height of the entry's bounding box.
    pub height: f32,
}
134
/// Response body for the spatial radius query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialQueryResponse {
    /// Matching entries, capped at the request's `limit` when one is given.
    pub result: Vec<SpatialResultItem>,
    /// Elapsed time in milliseconds, measured around the index read and
    /// result assembly (auth and rate limiting are not included).
    pub time: f64,
}
143
/// Response body for the spatial count endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialCountResponse {
    /// Number of entries in the spatial index (global, shared by all collections).
    pub count: usize,
}
150
151// ---------------------------------------------------------------------------
152// Handlers
153// ---------------------------------------------------------------------------
154
155/// Insert a spatial entry.
156///
157/// # Errors
158///
159/// Returns an error if authentication fails, the spatial index is not
160/// configured, or the bounding box has invalid dimensions.
161pub async fn insert(
162    State(ctx): State<Arc<VectorApiContext>>,
163    headers: HeaderMap,
164    Path(_name): Path<String>,
165    Json(body): Json<SpatialInsertRequest>,
166) -> ApiResult<serde_json::Value> {
167    let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
168    check_rate_limit(
169        identity.as_ref(),
170        ctx.rate_limiter.as_ref(),
171        "spatial_insert",
172    )?;
173
174    let spatial = ctx
175        .spatial
176        .as_ref()
177        .ok_or_else(|| ApiError::internal("Spatial index not configured"))?;
178
179    let bounds = tensor_spatial::BoundingBox::new(body.x, body.y, body.width, body.height)
180        .map_err(|e| ApiError::bad_request(e.to_string()))?;
181    let entry = tensor_spatial::SpatialEntry {
182        data: body.key,
183        bounds,
184    };
185    spatial.write().insert(entry);
186
187    Ok(Json(serde_json::json!({"status": "ok"})))
188}
189
190/// Query entries within a radius.
191///
192/// # Errors
193///
194/// Returns an error if authentication fails or the spatial index is not
195/// configured.
196pub async fn query(
197    State(ctx): State<Arc<VectorApiContext>>,
198    headers: HeaderMap,
199    Path(_name): Path<String>,
200    Json(body): Json<SpatialQueryRequest>,
201) -> ApiResult<SpatialQueryResponse> {
202    let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
203    check_rate_limit(
204        identity.as_ref(),
205        ctx.rate_limiter.as_ref(),
206        "spatial_query",
207    )?;
208
209    let spatial = ctx
210        .spatial
211        .as_ref()
212        .ok_or_else(|| ApiError::internal("Spatial index not configured"))?;
213
214    let start = Instant::now();
215    let guard = spatial.read();
216    let mut results: Vec<SpatialResultItem> = guard
217        .query_within_radius_with_distances(body.x, body.y, body.radius)
218        .into_iter()
219        .map(|(e, dist)| SpatialResultItem {
220            key: e.data.clone(),
221            distance: dist,
222            x: e.bounds.x(),
223            y: e.bounds.y(),
224            width: e.bounds.width(),
225            height: e.bounds.height(),
226        })
227        .collect();
228    drop(guard);
229
230    if let Some(max) = body.limit {
231        results.truncate(max);
232    }
233
234    let elapsed = start.elapsed().as_secs_f64() * 1000.0;
235    Ok(Json(SpatialQueryResponse {
236        result: results,
237        time: elapsed,
238    }))
239}
240
241/// Delete a spatial entry.
242///
243/// # Errors
244///
245/// Returns an error if authentication fails, the spatial index is not
246/// configured, or the entry is not found.
247pub async fn delete(
248    State(ctx): State<Arc<VectorApiContext>>,
249    headers: HeaderMap,
250    Path(_name): Path<String>,
251    Json(body): Json<SpatialDeleteRequest>,
252) -> ApiResult<serde_json::Value> {
253    let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
254    check_rate_limit(
255        identity.as_ref(),
256        ctx.rate_limiter.as_ref(),
257        "spatial_delete",
258    )?;
259
260    let spatial = ctx
261        .spatial
262        .as_ref()
263        .ok_or_else(|| ApiError::internal("Spatial index not configured"))?;
264
265    let bounds = tensor_spatial::BoundingBox::new(body.x, body.y, body.width, body.height)
266        .map_err(|e| ApiError::bad_request(e.to_string()))?;
267    let key = body.key;
268    spatial
269        .write()
270        .remove(bounds, |e| e.data == key && e.bounds == bounds)
271        .map_err(|e| ApiError::not_found(e.to_string()))?;
272
273    Ok(Json(serde_json::json!({"status": "ok"})))
274}
275
276/// Get the number of entries in the spatial index.
277///
278/// # Errors
279///
280/// Returns an error if authentication fails or the spatial index is not
281/// configured.
282pub async fn count(
283    State(ctx): State<Arc<VectorApiContext>>,
284    headers: HeaderMap,
285    Path(_name): Path<String>,
286) -> ApiResult<SpatialCountResponse> {
287    let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
288    check_rate_limit(
289        identity.as_ref(),
290        ctx.rate_limiter.as_ref(),
291        "spatial_count",
292    )?;
293
294    let spatial = ctx
295        .spatial
296        .as_ref()
297        .ok_or_else(|| ApiError::internal("Spatial index not configured"))?;
298
299    let count = spatial.read().len();
300    Ok(Json(SpatialCountResponse { count }))
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Handle type the handlers use for the shared spatial index.
    type SharedIndex = Arc<parking_lot::RwLock<tensor_spatial::SpatialIndex<String>>>;

    // ========== Fixtures ==========

    /// Context with no spatial index attached.
    fn ctx_without_spatial() -> Arc<VectorApiContext> {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        Arc::new(VectorApiContext::new(engine))
    }

    /// Context backed by a fresh, empty spatial index. The index handle is
    /// returned as well so tests can inspect it directly.
    fn ctx_with_spatial() -> (Arc<VectorApiContext>, SharedIndex) {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let index: SharedIndex =
            Arc::new(parking_lot::RwLock::new(tensor_spatial::SpatialIndex::<String>::new()));
        let ctx = Arc::new(VectorApiContext::new(engine).with_spatial(Some(Arc::clone(&index))));
        (ctx, index)
    }

    fn path_param(name: &str) -> Path<String> {
        Path(name.to_string())
    }

    fn insert_req(key: &str, x: f32, y: f32, width: f32, height: f32) -> SpatialInsertRequest {
        SpatialInsertRequest {
            key: key.to_string(),
            x,
            y,
            width,
            height,
        }
    }

    fn delete_req(key: &str, x: f32, y: f32, width: f32, height: f32) -> SpatialDeleteRequest {
        SpatialDeleteRequest {
            key: key.to_string(),
            x,
            y,
            width,
            height,
        }
    }

    fn query_req(x: f32, y: f32, radius: f32, limit: Option<usize>) -> SpatialQueryRequest {
        SpatialQueryRequest {
            x,
            y,
            radius,
            limit,
        }
    }

    /// Serialize `value` to JSON and decode it back.
    fn round_trip<T>(value: &T) -> T
    where
        T: Serialize + serde::de::DeserializeOwned,
    {
        let json = serde_json::to_string(value).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    // ========== Serde Round-Trip Tests ==========

    #[test]
    fn test_serde_spatial_insert_request() {
        let decoded = round_trip(&insert_req("building", 10.0, 20.0, 5.0, 3.0));
        assert_eq!(decoded.key, "building");
        assert!((decoded.x - 10.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_serde_spatial_query_request() {
        let decoded = round_trip(&query_req(5.0, 5.0, 10.0, Some(50)));
        assert!((decoded.radius - 10.0).abs() < f32::EPSILON);
        assert_eq!(decoded.limit, Some(50));
    }

    #[test]
    fn test_serde_spatial_query_request_no_limit() {
        let decoded: SpatialQueryRequest =
            serde_json::from_str(r#"{"x":1.0,"y":2.0,"radius":3.0}"#).unwrap();
        assert!(decoded.limit.is_none());
    }

    #[test]
    fn test_serde_spatial_delete_request() {
        let decoded = round_trip(&delete_req("park", 0.0, 0.0, 10.0, 10.0));
        assert_eq!(decoded.key, "park");
    }

    #[test]
    fn test_serde_spatial_result_item() {
        let item = SpatialResultItem {
            key: "a".to_string(),
            distance: 1.5,
            x: 2.0,
            y: 3.0,
            width: 4.0,
            height: 5.0,
        };
        assert!((round_trip(&item).distance - 1.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_serde_spatial_query_response() {
        let resp = SpatialQueryResponse {
            result: vec![SpatialResultItem {
                key: "b".to_string(),
                distance: 0.5,
                x: 1.0,
                y: 1.0,
                width: 2.0,
                height: 2.0,
            }],
            time: 1.234,
        };
        assert_eq!(round_trip(&resp).result.len(), 1);
    }

    #[test]
    fn test_serde_spatial_count_response() {
        assert_eq!(round_trip(&SpatialCountResponse { count: 42 }).count, 42);
    }

    // ========== Handler Unit Tests ==========

    #[tokio::test]
    async fn test_insert_no_spatial_configured() {
        let result = insert(
            State(ctx_without_spatial()),
            HeaderMap::new(),
            path_param("default"),
            Json(insert_req("test", 1.0, 2.0, 3.0, 4.0)),
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_insert_and_count() {
        let (ctx, _index) = ctx_with_spatial();

        let inserted = insert(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            path_param("col"),
            Json(insert_req("obj1", 10.0, 20.0, 5.0, 3.0)),
        )
        .await;
        assert!(inserted.is_ok());

        let counted = count(State(Arc::clone(&ctx)), HeaderMap::new(), path_param("col")).await;
        assert!(counted.is_ok());
        assert_eq!(counted.unwrap().0.count, 1);
    }

    #[tokio::test]
    async fn test_query_within_radius() {
        let (ctx, _index) = ctx_with_spatial();

        // One entry inside the search radius, one far outside it.
        for (key, x, y) in [("near", 1.0_f32, 1.0_f32), ("far", 100.0, 100.0)] {
            let _ = insert(
                State(Arc::clone(&ctx)),
                HeaderMap::new(),
                path_param("col"),
                Json(insert_req(key, x, y, 1.0, 1.0)),
            )
            .await
            .unwrap();
        }

        let response = query(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            path_param("col"),
            Json(query_req(0.0, 0.0, 10.0, None)),
        )
        .await
        .unwrap();
        let hits = &response.0.result;
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].key, "near");
    }

    #[tokio::test]
    async fn test_delete_entry() {
        let (ctx, index) = ctx_with_spatial();

        let _ = insert(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            path_param("col"),
            Json(insert_req("temp", 5.0, 5.0, 2.0, 2.0)),
        )
        .await
        .unwrap();
        assert_eq!(index.read().len(), 1);

        let _ = delete(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            path_param("col"),
            Json(delete_req("temp", 5.0, 5.0, 2.0, 2.0)),
        )
        .await
        .unwrap();
        assert_eq!(index.read().len(), 0);
    }

    #[tokio::test]
    async fn test_insert_invalid_bounds() {
        let (ctx, _index) = ctx_with_spatial();

        let result = insert(
            State(ctx),
            HeaderMap::new(),
            path_param("col"),
            Json(insert_req("bad", 0.0, 0.0, -1.0, 5.0)),
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_count_no_spatial_configured() {
        let result = count(
            State(ctx_without_spatial()),
            HeaderMap::new(),
            path_param("default"),
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_query_no_spatial_configured() {
        let result = query(
            State(ctx_without_spatial()),
            HeaderMap::new(),
            path_param("default"),
            Json(query_req(0.0, 0.0, 10.0, None)),
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_delete_no_spatial_configured() {
        let result = delete(
            State(ctx_without_spatial()),
            HeaderMap::new(),
            path_param("default"),
            Json(delete_req("missing", 0.0, 0.0, 1.0, 1.0)),
        )
        .await;
        assert!(result.is_err());
    }
}