1use std::sync::Arc;
9use std::time::Instant;
10
11use axum::extract::{Path, State};
12use axum::http::HeaderMap;
13use axum::Json;
14use serde::{Deserialize, Serialize};
15
16use crate::config::AuthConfig;
17use crate::rate_limit::{Operation, RateLimiter};
18use crate::rest::error::{ApiError, ApiResult};
19use crate::rest::VectorApiContext;
20
21fn extract_api_key(headers: &HeaderMap, auth_config: Option<&AuthConfig>) -> Option<String> {
26 let header_name = auth_config.map_or("x-api-key", |c| c.api_key_header.as_str());
27 headers
28 .get(header_name)
29 .and_then(|v| v.to_str().ok())
30 .map(String::from)
31}
32
33fn validate_auth(
34 headers: &HeaderMap,
35 auth_config: Option<&AuthConfig>,
36) -> Result<Option<String>, ApiError> {
37 let api_key = extract_api_key(headers, auth_config);
38
39 match (auth_config, api_key) {
40 (None, _) => Ok(None),
41 (Some(config), None) => {
42 if config.allow_anonymous {
43 Ok(None)
44 } else {
45 Err(ApiError::unauthorized("API key required"))
46 }
47 },
48 (Some(config), Some(key)) => config.validate_key(&key).map_or_else(
49 || Err(ApiError::unauthorized("Invalid API key")),
50 |identity| Ok(Some(identity.to_string())),
51 ),
52 }
53}
54
55fn check_rate_limit(
56 identity: Option<&String>,
57 rate_limiter: Option<&Arc<RateLimiter>>,
58 operation: &str,
59) -> Result<(), ApiError> {
60 if let Some(limiter) = rate_limiter {
61 if let Some(id) = identity {
62 if let Err(msg) = limiter.check_and_record(id, Operation::VectorOp) {
63 tracing::warn!("Rate limited: {id} for {operation}");
64 return Err(ApiError::rate_limited(msg));
65 }
66 }
67 }
68 Ok(())
69}
70
/// JSON body for `insert`: one entry to add to the spatial index.
///
/// The four numeric fields are passed unchanged to
/// `tensor_spatial::BoundingBox::new`, which may reject them. For example, this
/// module's tests expect a negative width to fail.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialInsertRequest {
    /// Identifier stored as the entry's payload (`SpatialEntry::data`).
    pub key: String,
    /// X coordinate of the bounding box. Whether this is a corner or the centre is
    /// defined by `BoundingBox` — TODO confirm.
    pub x: f32,
    /// Y coordinate of the bounding box (same anchor as `x`).
    pub y: f32,
    /// Width of the bounding box.
    pub width: f32,
    /// Height of the bounding box.
    pub height: f32,
}
89
/// JSON body for `query`: a radius search around a point.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialQueryRequest {
    /// X coordinate of the search point.
    pub x: f32,
    /// Y coordinate of the search point.
    pub y: f32,
    /// Search radius passed to `query_within_radius_with_distances`.
    pub radius: f32,
    /// Maximum number of results to return. When omitted (or `null`), every match
    /// is returned.
    pub limit: Option<usize>,
}
102
/// JSON body for `delete`: identifies the entry to remove.
///
/// An entry is removed only if both its key and its bounding box equal these
/// values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialDeleteRequest {
    /// Key (`SpatialEntry::data`) of the entry to remove.
    pub key: String,
    /// X coordinate of the entry's bounding box, as given at insert time.
    pub x: f32,
    /// Y coordinate of the entry's bounding box, as given at insert time.
    pub y: f32,
    /// Width of the entry's bounding box.
    pub width: f32,
    /// Height of the entry's bounding box.
    pub height: f32,
}
117
/// One match returned by `query`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialResultItem {
    /// Key (`SpatialEntry::data`) of the matched entry.
    pub key: String,
    /// Distance reported by the index for this entry. How it is measured (to the
    /// box edge or its centre) is defined by `tensor_spatial` — TODO confirm.
    pub distance: f32,
    /// X coordinate of the entry's bounding box (`bounds.x()`).
    pub x: f32,
    /// Y coordinate of the entry's bounding box (`bounds.y()`).
    pub y: f32,
    /// Width of the entry's bounding box (`bounds.width()`).
    pub width: f32,
    /// Height of the entry's bounding box (`bounds.height()`).
    pub height: f32,
}
134
/// Response body for `query`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialQueryResponse {
    /// Matches found within the requested radius, after `limit` is applied.
    pub result: Vec<SpatialResultItem>,
    /// Time spent on the lookup, in milliseconds.
    pub time: f64,
}
143
/// Response body for `count`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialCountResponse {
    /// Number of entries currently stored in the spatial index.
    pub count: usize,
}
150
151pub async fn insert(
162 State(ctx): State<Arc<VectorApiContext>>,
163 headers: HeaderMap,
164 Path(_name): Path<String>,
165 Json(body): Json<SpatialInsertRequest>,
166) -> ApiResult<serde_json::Value> {
167 let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
168 check_rate_limit(
169 identity.as_ref(),
170 ctx.rate_limiter.as_ref(),
171 "spatial_insert",
172 )?;
173
174 let spatial = ctx
175 .spatial
176 .as_ref()
177 .ok_or_else(|| ApiError::internal("Spatial index not configured"))?;
178
179 let bounds = tensor_spatial::BoundingBox::new(body.x, body.y, body.width, body.height)
180 .map_err(|e| ApiError::bad_request(e.to_string()))?;
181 let entry = tensor_spatial::SpatialEntry {
182 data: body.key,
183 bounds,
184 };
185 spatial.write().insert(entry);
186
187 Ok(Json(serde_json::json!({"status": "ok"})))
188}
189
190pub async fn query(
197 State(ctx): State<Arc<VectorApiContext>>,
198 headers: HeaderMap,
199 Path(_name): Path<String>,
200 Json(body): Json<SpatialQueryRequest>,
201) -> ApiResult<SpatialQueryResponse> {
202 let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
203 check_rate_limit(
204 identity.as_ref(),
205 ctx.rate_limiter.as_ref(),
206 "spatial_query",
207 )?;
208
209 let spatial = ctx
210 .spatial
211 .as_ref()
212 .ok_or_else(|| ApiError::internal("Spatial index not configured"))?;
213
214 let start = Instant::now();
215 let guard = spatial.read();
216 let mut results: Vec<SpatialResultItem> = guard
217 .query_within_radius_with_distances(body.x, body.y, body.radius)
218 .into_iter()
219 .map(|(e, dist)| SpatialResultItem {
220 key: e.data.clone(),
221 distance: dist,
222 x: e.bounds.x(),
223 y: e.bounds.y(),
224 width: e.bounds.width(),
225 height: e.bounds.height(),
226 })
227 .collect();
228 drop(guard);
229
230 if let Some(max) = body.limit {
231 results.truncate(max);
232 }
233
234 let elapsed = start.elapsed().as_secs_f64() * 1000.0;
235 Ok(Json(SpatialQueryResponse {
236 result: results,
237 time: elapsed,
238 }))
239}
240
241pub async fn delete(
248 State(ctx): State<Arc<VectorApiContext>>,
249 headers: HeaderMap,
250 Path(_name): Path<String>,
251 Json(body): Json<SpatialDeleteRequest>,
252) -> ApiResult<serde_json::Value> {
253 let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
254 check_rate_limit(
255 identity.as_ref(),
256 ctx.rate_limiter.as_ref(),
257 "spatial_delete",
258 )?;
259
260 let spatial = ctx
261 .spatial
262 .as_ref()
263 .ok_or_else(|| ApiError::internal("Spatial index not configured"))?;
264
265 let bounds = tensor_spatial::BoundingBox::new(body.x, body.y, body.width, body.height)
266 .map_err(|e| ApiError::bad_request(e.to_string()))?;
267 let key = body.key;
268 spatial
269 .write()
270 .remove(bounds, |e| e.data == key && e.bounds == bounds)
271 .map_err(|e| ApiError::not_found(e.to_string()))?;
272
273 Ok(Json(serde_json::json!({"status": "ok"})))
274}
275
276pub async fn count(
283 State(ctx): State<Arc<VectorApiContext>>,
284 headers: HeaderMap,
285 Path(_name): Path<String>,
286) -> ApiResult<SpatialCountResponse> {
287 let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
288 check_rate_limit(
289 identity.as_ref(),
290 ctx.rate_limiter.as_ref(),
291 "spatial_count",
292 )?;
293
294 let spatial = ctx
295 .spatial
296 .as_ref()
297 .ok_or_else(|| ApiError::internal("Spatial index not configured"))?;
298
299 let count = spatial.read().len();
300 Ok(Json(SpatialCountResponse { count }))
301}
302
#[cfg(test)]
mod tests {
    // Serde round-trips for every request/response type, followed by direct calls
    // to the handlers. Handler tests pass an empty `HeaderMap` and expect success,
    // so they presumably rely on `VectorApiContext::new` not configuring mandatory
    // authentication — TODO confirm.
    use super::*;

    // An insert request survives a JSON round-trip.
    #[test]
    fn test_serde_spatial_insert_request() {
        let req = SpatialInsertRequest {
            key: "building".to_string(),
            x: 10.0,
            y: 20.0,
            width: 5.0,
            height: 3.0,
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: SpatialInsertRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.key, "building");
        assert!((decoded.x - 10.0).abs() < f32::EPSILON);
    }

    // A query request with an explicit limit survives a JSON round-trip.
    #[test]
    fn test_serde_spatial_query_request() {
        let req = SpatialQueryRequest {
            x: 5.0,
            y: 5.0,
            radius: 10.0,
            limit: Some(50),
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: SpatialQueryRequest = serde_json::from_str(&json).unwrap();
        assert!((decoded.radius - 10.0).abs() < f32::EPSILON);
        assert_eq!(decoded.limit, Some(50));
    }

    // `limit` is optional: omitting it from the JSON yields `None`.
    #[test]
    fn test_serde_spatial_query_request_no_limit() {
        let json = r#"{"x":1.0,"y":2.0,"radius":3.0}"#;
        let decoded: SpatialQueryRequest = serde_json::from_str(json).unwrap();
        assert!(decoded.limit.is_none());
    }

    // A delete request survives a JSON round-trip.
    #[test]
    fn test_serde_spatial_delete_request() {
        let req = SpatialDeleteRequest {
            key: "park".to_string(),
            x: 0.0,
            y: 0.0,
            width: 10.0,
            height: 10.0,
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: SpatialDeleteRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.key, "park");
    }

    // A single result item survives a JSON round-trip.
    #[test]
    fn test_serde_spatial_result_item() {
        let item = SpatialResultItem {
            key: "a".to_string(),
            distance: 1.5,
            x: 2.0,
            y: 3.0,
            width: 4.0,
            height: 5.0,
        };
        let json = serde_json::to_string(&item).unwrap();
        let decoded: SpatialResultItem = serde_json::from_str(&json).unwrap();
        assert!((decoded.distance - 1.5).abs() < f32::EPSILON);
    }

    // A query response with one item survives a JSON round-trip.
    #[test]
    fn test_serde_spatial_query_response() {
        let resp = SpatialQueryResponse {
            result: vec![SpatialResultItem {
                key: "b".to_string(),
                distance: 0.5,
                x: 1.0,
                y: 1.0,
                width: 2.0,
                height: 2.0,
            }],
            time: 1.234,
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: SpatialQueryResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.result.len(), 1);
    }

    // A count response survives a JSON round-trip.
    #[test]
    fn test_serde_spatial_count_response() {
        let resp = SpatialCountResponse { count: 42 };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: SpatialCountResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.count, 42);
    }

    // `insert` fails when the context has no spatial index attached.
    #[tokio::test]
    async fn test_insert_no_spatial_configured() {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let ctx = Arc::new(VectorApiContext::new(engine));

        let body = SpatialInsertRequest {
            key: "test".to_string(),
            x: 1.0,
            y: 2.0,
            width: 3.0,
            height: 4.0,
        };

        let result = insert(
            State(ctx),
            HeaderMap::new(),
            Path("default".to_string()),
            Json(body),
        )
        .await;
        assert!(result.is_err());
    }

    // A successful insert is reflected by `count`.
    #[tokio::test]
    async fn test_insert_and_count() {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let spatial = Arc::new(parking_lot::RwLock::new(tensor_spatial::SpatialIndex::<
            String,
        >::new()));
        let ctx = Arc::new(VectorApiContext::new(engine).with_spatial(Some(Arc::clone(&spatial))));

        let body = SpatialInsertRequest {
            key: "obj1".to_string(),
            x: 10.0,
            y: 20.0,
            width: 5.0,
            height: 3.0,
        };
        let result = insert(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            Path("col".to_string()),
            Json(body),
        )
        .await;
        assert!(result.is_ok());

        let count_result = count(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            Path("col".to_string()),
        )
        .await;
        assert!(count_result.is_ok());
        assert_eq!(count_result.unwrap().0.count, 1);
    }

    // Of two inserted entries, only the one near the origin falls inside a radius of 10.
    #[tokio::test]
    async fn test_query_within_radius() {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let spatial = Arc::new(parking_lot::RwLock::new(tensor_spatial::SpatialIndex::<
            String,
        >::new()));
        let ctx = Arc::new(VectorApiContext::new(engine).with_spatial(Some(Arc::clone(&spatial))));

        for (key, x, y) in [("near", 1.0_f32, 1.0_f32), ("far", 100.0, 100.0)] {
            // Insert one unit-sized box per (key, x, y) tuple.
            let body = SpatialInsertRequest {
                key: key.to_string(),
                x,
                y,
                width: 1.0,
                height: 1.0,
            };
            let _ = insert(
                State(Arc::clone(&ctx)),
                HeaderMap::new(),
                Path("col".to_string()),
                Json(body),
            )
            .await
            .unwrap();
        }

        let q = SpatialQueryRequest {
            x: 0.0,
            y: 0.0,
            radius: 10.0,
            limit: None,
        };
        let result = query(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            Path("col".to_string()),
            Json(q),
        )
        .await
        .unwrap();
        assert_eq!(result.0.result.len(), 1);
        assert_eq!(result.0.result[0].key, "near");
    }

    // Deleting with the exact key and bounds used at insert time removes the entry.
    #[tokio::test]
    async fn test_delete_entry() {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let spatial = Arc::new(parking_lot::RwLock::new(tensor_spatial::SpatialIndex::<
            String,
        >::new()));
        let ctx = Arc::new(VectorApiContext::new(engine).with_spatial(Some(Arc::clone(&spatial))));

        let body = SpatialInsertRequest {
            // Seed the index with a single entry.
            key: "temp".to_string(),
            x: 5.0,
            y: 5.0,
            width: 2.0,
            height: 2.0,
        };
        let _ = insert(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            Path("col".to_string()),
            Json(body),
        )
        .await
        .unwrap();
        assert_eq!(spatial.read().len(), 1);

        let del = SpatialDeleteRequest {
            // Same key and bounds as the insert above.
            key: "temp".to_string(),
            x: 5.0,
            y: 5.0,
            width: 2.0,
            height: 2.0,
        };
        let _ = delete(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            Path("col".to_string()),
            Json(del),
        )
        .await
        .unwrap();
        assert_eq!(spatial.read().len(), 0);
    }

    // A negative width is rejected by `BoundingBox::new`, so `insert` errors.
    #[tokio::test]
    async fn test_insert_invalid_bounds() {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let spatial = Arc::new(parking_lot::RwLock::new(tensor_spatial::SpatialIndex::<
            String,
        >::new()));
        let ctx = Arc::new(VectorApiContext::new(engine).with_spatial(Some(spatial)));

        let body = SpatialInsertRequest {
            key: "bad".to_string(),
            x: 0.0,
            y: 0.0,
            width: -1.0,
            height: 5.0,
        };
        let result = insert(
            State(ctx),
            HeaderMap::new(),
            Path("col".to_string()),
            Json(body),
        )
        .await;
        assert!(result.is_err());
    }

    // `count` fails when the context has no spatial index attached.
    #[tokio::test]
    async fn test_count_no_spatial_configured() {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let ctx = Arc::new(VectorApiContext::new(engine));

        let result = count(State(ctx), HeaderMap::new(), Path("default".to_string())).await;
        assert!(result.is_err());
    }

    // `query` fails when the context has no spatial index attached.
    #[tokio::test]
    async fn test_query_no_spatial_configured() {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let ctx = Arc::new(VectorApiContext::new(engine));

        let body = SpatialQueryRequest {
            x: 0.0,
            y: 0.0,
            radius: 10.0,
            limit: None,
        };
        let result = query(
            State(ctx),
            HeaderMap::new(),
            Path("default".to_string()),
            Json(body),
        )
        .await;
        assert!(result.is_err());
    }

    // `delete` fails when the context has no spatial index attached.
    #[tokio::test]
    async fn test_delete_no_spatial_configured() {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let ctx = Arc::new(VectorApiContext::new(engine));

        let body = SpatialDeleteRequest {
            key: "missing".to_string(),
            x: 0.0,
            y: 0.0,
            width: 1.0,
            height: 1.0,
        };
        let result = delete(
            State(ctx),
            HeaderMap::new(),
            Path("default".to_string()),
            Json(body),
        )
        .await;
        assert!(result.is_err());
    }
}