use std::sync::Arc;
use std::time::Instant;
use axum::extract::{Path, State};
use axum::http::HeaderMap;
use axum::Json;
use serde::{Deserialize, Serialize};
use crate::config::AuthConfig;
use crate::rate_limit::{Operation, RateLimiter};
use crate::rest::error::{ApiError, ApiResult};
use crate::rest::VectorApiContext;
/// Reads the caller's API key from the request headers.
///
/// The header name comes from `AuthConfig::api_key_header` when an auth config is present,
/// and falls back to `x-api-key` otherwise. A header value that cannot be viewed as text
/// (`HeaderValue::to_str` fails) is treated the same as a missing header.
fn extract_api_key(headers: &HeaderMap, auth_config: Option<&AuthConfig>) -> Option<String> {
    let header_name = match auth_config {
        Some(config) => config.api_key_header.as_str(),
        None => "x-api-key",
    };
    let raw_value = headers.get(header_name)?;
    raw_value.to_str().ok().map(str::to_owned)
}
/// Authenticates a request and returns the caller's identity, if any.
///
/// - No auth config: authentication is disabled and `Ok(None)` is returned.
/// - No API key present: `Ok(None)` when `allow_anonymous` is set, otherwise an
///   `unauthorized("API key required")` error.
/// - API key present: the identity returned by `AuthConfig::validate_key`, stringified, or
///   `unauthorized("Invalid API key")` when the key is not recognised.
fn validate_auth(
    headers: &HeaderMap,
    auth_config: Option<&AuthConfig>,
) -> Result<Option<String>, ApiError> {
    let config = match auth_config {
        Some(config) => config,
        None => return Ok(None),
    };
    match extract_api_key(headers, Some(config)) {
        None if config.allow_anonymous => Ok(None),
        None => Err(ApiError::unauthorized("API key required")),
        Some(key) => match config.validate_key(&key) {
            Some(identity) => Ok(Some(identity.to_string())),
            None => Err(ApiError::unauthorized("Invalid API key")),
        },
    }
}
/// Charges one `Operation::VectorOp` against the caller's rate-limit budget.
///
/// Nothing is checked when no limiter is configured or the caller is anonymous (`identity`
/// is `None`). `operation` is used only to label the warning logged when the limiter refuses
/// the request, in which case a `rate_limited` error carrying the limiter's message is returned.
fn check_rate_limit(
    identity: Option<&String>,
    rate_limiter: Option<&Arc<RateLimiter>>,
    operation: &str,
) -> Result<(), ApiError> {
    match (rate_limiter, identity) {
        (Some(limiter), Some(id)) => match limiter.check_and_record(id, Operation::VectorOp) {
            Ok(_) => Ok(()),
            Err(msg) => {
                tracing::warn!("Rate limited: {id} for {operation}");
                Err(ApiError::rate_limited(msg))
            },
        },
        _ => Ok(()),
    }
}
/// Request body for `insert_3d`: one keyed box to add to the 3D spatial index.
///
/// `(x, y, z, w, h, d)` are passed unchanged to `tensor_spatial::BoundingBox3D::new`, and the
/// handler answers `bad_request` if that constructor rejects them (e.g. a negative extent).
/// NOTE(review): `region_3d` builds its box from a min corner plus extents, so `(x, y, z)` is
/// presumably the minimum corner rather than the centre — confirm against `BoundingBox3D::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Spatial3DInsertRequest {
/// Caller-chosen key, stored as the entry's `data` payload and echoed back in query results.
pub key: String,
/// Box position on the x axis.
pub x: f32,
/// Box position on the y axis.
pub y: f32,
/// Box position on the z axis.
pub z: f32,
/// Extent along x; defaults to `1.0` when omitted from the JSON body.
#[serde(default = "default_extent")]
pub w: f32,
/// Extent along y; defaults to `1.0` when omitted from the JSON body.
#[serde(default = "default_extent")]
pub h: f32,
/// Extent along z; defaults to `1.0` when omitted from the JSON body.
#[serde(default = "default_extent")]
pub d: f32,
}
/// Serde default for an omitted box extent (`w`, `h` or `d`): one unit along that axis.
const fn default_extent() -> f32 {
    1_f32
}
/// Request body shared by `query_3d` (radius search) and `nearest_3d` (k-nearest search).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Spatial3DQueryRequest {
/// Query point, x coordinate.
pub x: f32,
/// Query point, y coordinate.
pub y: f32,
/// Query point, z coordinate.
pub z: f32,
/// Search radius for `query_3d`, which uses `100.0` when omitted; `nearest_3d` ignores it.
pub radius: Option<f32>,
/// Maximum number of results: a truncation cap for `query_3d` (no cap when omitted) and
/// the `k` of `nearest_3d` (default `10`).
pub limit: Option<usize>,
}
/// Request body for `region_3d`: the query box given by two opposite corners.
///
/// The handler derives each extent as `max[i] - min[i]`, so `max` must not be below `min` on
/// any axis. Otherwise `BoundingBox3D::new` sees a negative extent and the handler returns
/// `bad_request`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Spatial3DRegionRequest {
/// Minimum corner as `[x, y, z]`.
pub min: [f32; 3],
/// Maximum corner as `[x, y, z]`.
pub max: [f32; 3],
}
/// Request body for `delete_3d`: identifies exactly one previously inserted entry.
///
/// `delete_3d` removes only an entry whose key *and* bounds both equal the ones given here,
/// so these fields must repeat the values used at insert time (extents default to `1.0`,
/// just as they do for inserts).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Spatial3DDeleteRequest {
/// Key of the entry to remove.
pub key: String,
/// Box position on the x axis, as supplied at insert time.
pub x: f32,
/// Box position on the y axis, as supplied at insert time.
pub y: f32,
/// Box position on the z axis, as supplied at insert time.
pub z: f32,
/// Extent along x; defaults to `1.0` when omitted from the JSON body.
#[serde(default = "default_extent")]
pub w: f32,
/// Extent along y; defaults to `1.0` when omitted from the JSON body.
#[serde(default = "default_extent")]
pub h: f32,
/// Extent along z; defaults to `1.0` when omitted from the JSON body.
#[serde(default = "default_extent")]
pub d: f32,
}
/// One matched entry in a `Spatial3DQueryResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Spatial3DResultItem {
/// Key the entry was inserted with.
pub key: String,
/// Endpoint-specific distance: as reported by the index for `query_3d`, the Euclidean
/// distance from the query point to the box centre for `nearest_3d`, and always `0.0`
/// for `region_3d`.
pub distance: f32,
/// x coordinate of the entry's box centre (`bounds.center()`).
pub x: f32,
/// y coordinate of the entry's box centre.
pub y: f32,
/// z coordinate of the entry's box centre.
pub z: f32,
}
/// Response body for the `query_3d`, `nearest_3d` and `region_3d` endpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Spatial3DQueryResponse {
/// Matching entries.
pub results: Vec<Spatial3DResultItem>,
/// Wall-clock time spent on the query, in milliseconds.
pub time: f64,
}
/// Response body for `count_3d`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Spatial3DCountResponse {
/// Number of entries currently in the 3D index (its `len()`).
pub count: usize,
}
pub async fn insert_3d(
State(ctx): State<Arc<VectorApiContext>>,
headers: HeaderMap,
Path(_name): Path<String>,
Json(body): Json<Spatial3DInsertRequest>,
) -> ApiResult<serde_json::Value> {
let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
check_rate_limit(
identity.as_ref(),
ctx.rate_limiter.as_ref(),
"spatial3d_insert",
)?;
let spatial = ctx
.spatial_3d
.as_ref()
.ok_or_else(|| ApiError::internal("3D spatial index not configured"))?;
let bounds = tensor_spatial::BoundingBox3D::new(body.x, body.y, body.z, body.w, body.h, body.d)
.map_err(|e| ApiError::bad_request(e.to_string()))?;
let entry = tensor_spatial::SpatialEntry3D {
data: body.key,
bounds,
};
spatial.write().insert(entry);
Ok(Json(serde_json::json!({"status": "ok"})))
}
/// Returns the indexed entries within `radius` of the point `(x, y, z)`.
///
/// `radius` defaults to `100.0` when omitted. Results are ordered by ascending distance (the
/// distance reported by the index), so when `limit` is given it keeps the closest matches.
/// `time` is the query's wall-clock duration in milliseconds.
///
/// # Errors
///
/// - Authentication and rate-limit failures from `validate_auth` / `check_rate_limit`.
/// - `internal` when no 3D spatial index is configured.
/// - `bad_request` when `radius` is negative or NaN.
pub async fn query_3d(
    State(ctx): State<Arc<VectorApiContext>>,
    headers: HeaderMap,
    Path(_name): Path<String>,
    Json(body): Json<Spatial3DQueryRequest>,
) -> ApiResult<Spatial3DQueryResponse> {
    let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
    check_rate_limit(
        identity.as_ref(),
        ctx.rate_limiter.as_ref(),
        "spatial3d_query",
    )?;
    let spatial = ctx
        .spatial_3d
        .as_ref()
        .ok_or_else(|| ApiError::internal("3D spatial index not configured"))?;
    let radius = body.radius.unwrap_or(100.0);
    // `radius >= 0.0` is false for negative values *and* for NaN, so both are rejected as a
    // client error instead of being handed to the index as a meaningless search radius.
    let radius = (radius >= 0.0).then_some(radius).ok_or_else(|| {
        ApiError::bad_request(format!("radius must be non-negative, got {radius}"))
    })?;
    let start = Instant::now();
    let guard = spatial.read();
    let mut results: Vec<Spatial3DResultItem> = guard
        .query_within_radius_with_distances(body.x, body.y, body.z, radius)
        .into_iter()
        .map(|(e, dist)| {
            let (cx, cy, cz) = e.bounds.center();
            Spatial3DResultItem {
                key: e.data.clone(),
                distance: dist,
                x: cx,
                y: cy,
                z: cz,
            }
        })
        .collect();
    // Everything below works on owned results; release the read lock first.
    drop(guard);
    // Order nearest-first so that `limit` keeps the closest matches rather than whichever
    // entries the index happened to yield first. `total_cmp` is a total order even for NaN.
    results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
    if let Some(max) = body.limit {
        results.truncate(max);
    }
    let elapsed = start.elapsed().as_secs_f64() * 1000.0;
    Ok(Json(Spatial3DQueryResponse {
        results,
        time: elapsed,
    }))
}
/// Returns the `limit` entries (default `10`) whose box centres are closest to `(x, y, z)`.
///
/// `radius` is ignored by this endpoint. Each result's `distance` is the Euclidean distance
/// from the query point to the entry's box centre, and `time` is the query's wall-clock
/// duration in milliseconds.
///
/// # Errors
///
/// - Authentication and rate-limit failures from `validate_auth` / `check_rate_limit`.
/// - `internal` when no 3D spatial index is configured.
pub async fn nearest_3d(
    State(ctx): State<Arc<VectorApiContext>>,
    headers: HeaderMap,
    Path(_name): Path<String>,
    Json(body): Json<Spatial3DQueryRequest>,
) -> ApiResult<Spatial3DQueryResponse> {
    let caller = validate_auth(&headers, ctx.auth_config.as_ref())?;
    check_rate_limit(caller.as_ref(), ctx.rate_limiter.as_ref(), "spatial3d_nearest")?;

    let index = ctx
        .spatial_3d
        .as_ref()
        .ok_or_else(|| ApiError::internal("3D spatial index not configured"))?;

    let (qx, qy, qz) = (body.x, body.y, body.z);
    let k = body.limit.unwrap_or(10);

    let timer = Instant::now();
    let reader = index.read();
    let mut results: Vec<Spatial3DResultItem> = Vec::new();
    for entry in reader.query_nearest_by_centroid(qx, qy, qz, k) {
        let (cx, cy, cz) = entry.bounds.center();
        let (dx, dy, dz) = (qx - cx, qy - cy, qz - cz);
        results.push(Spatial3DResultItem {
            key: entry.data.clone(),
            distance: dz.mul_add(dz, dx.mul_add(dx, dy * dy)).sqrt(),
            x: cx,
            y: cy,
            z: cz,
        });
    }
    drop(reader);

    let time = timer.elapsed().as_secs_f64() * 1000.0;
    Ok(Json(Spatial3DQueryResponse { results, time }))
}
pub async fn region_3d(
State(ctx): State<Arc<VectorApiContext>>,
headers: HeaderMap,
Path(_name): Path<String>,
Json(body): Json<Spatial3DRegionRequest>,
) -> ApiResult<Spatial3DQueryResponse> {
let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
check_rate_limit(
identity.as_ref(),
ctx.rate_limiter.as_ref(),
"spatial3d_region",
)?;
let spatial = ctx
.spatial_3d
.as_ref()
.ok_or_else(|| ApiError::internal("3D spatial index not configured"))?;
let width = body.max[0] - body.min[0];
let height = body.max[1] - body.min[1];
let depth = body.max[2] - body.min[2];
let region = tensor_spatial::BoundingBox3D::new(
body.min[0],
body.min[1],
body.min[2],
width,
height,
depth,
)
.map_err(|e| ApiError::bad_request(e.to_string()))?;
let start = Instant::now();
let guard = spatial.read();
let results: Vec<Spatial3DResultItem> = guard
.query_region(region)
.into_iter()
.map(|e| {
let (cx, cy, cz) = e.bounds.center();
Spatial3DResultItem {
key: e.data.clone(),
distance: 0.0,
x: cx,
y: cy,
z: cz,
}
})
.collect();
drop(guard);
let elapsed = start.elapsed().as_secs_f64() * 1000.0;
Ok(Json(Spatial3DQueryResponse {
results,
time: elapsed,
}))
}
pub async fn delete_3d(
State(ctx): State<Arc<VectorApiContext>>,
headers: HeaderMap,
Path(_name): Path<String>,
Json(body): Json<Spatial3DDeleteRequest>,
) -> ApiResult<serde_json::Value> {
let identity = validate_auth(&headers, ctx.auth_config.as_ref())?;
check_rate_limit(
identity.as_ref(),
ctx.rate_limiter.as_ref(),
"spatial3d_delete",
)?;
let spatial = ctx
.spatial_3d
.as_ref()
.ok_or_else(|| ApiError::internal("3D spatial index not configured"))?;
let bounds = tensor_spatial::BoundingBox3D::new(body.x, body.y, body.z, body.w, body.h, body.d)
.map_err(|e| ApiError::bad_request(e.to_string()))?;
let key = body.key;
spatial
.write()
.remove(bounds, |e| e.data == key && e.bounds == bounds)
.map_err(|e| ApiError::not_found(e.to_string()))?;
Ok(Json(serde_json::json!({"status": "ok"})))
}
/// Reports how many entries the 3D spatial index currently holds.
///
/// # Errors
///
/// - Authentication and rate-limit failures from `validate_auth` / `check_rate_limit`.
/// - `internal` when no 3D spatial index is configured.
pub async fn count_3d(
    State(ctx): State<Arc<VectorApiContext>>,
    headers: HeaderMap,
    Path(_name): Path<String>,
) -> ApiResult<Spatial3DCountResponse> {
    let caller = validate_auth(&headers, ctx.auth_config.as_ref())?;
    check_rate_limit(caller.as_ref(), ctx.rate_limiter.as_ref(), "spatial3d_count")?;

    let count = ctx
        .spatial_3d
        .as_ref()
        .map(|index| index.read().len())
        .ok_or_else(|| ApiError::internal("3D spatial index not configured"))?;

    Ok(Json(Spatial3DCountResponse { count }))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collection name passed to the handlers; every handler ignores its `Path` segment.
    const COLLECTION: &str = "col";

    /// Builds an insert request for a box at `(x, y, z)` with the same `extent` on every axis.
    fn insert_request(key: &str, x: f32, y: f32, z: f32, extent: f32) -> Spatial3DInsertRequest {
        Spatial3DInsertRequest {
            key: key.to_string(),
            x,
            y,
            z,
            w: extent,
            h: extent,
            d: extent,
        }
    }

    /// Builds a context backed by an empty `SpatialIndex3D<String>`.
    fn make_ctx_with_spatial_3d() -> Arc<VectorApiContext> {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let spatial_3d = Arc::new(parking_lot::RwLock::new(
            tensor_spatial::SpatialIndex3D::<String>::new(),
        ));
        Arc::new(VectorApiContext::new(engine).with_spatial_3d(Some(spatial_3d)))
    }

    /// Inserts `body` through `insert_3d`, panicking if the handler rejects it.
    async fn insert_ok(ctx: &Arc<VectorApiContext>, body: Spatial3DInsertRequest) {
        insert_3d(
            State(Arc::clone(ctx)),
            HeaderMap::new(),
            Path(COLLECTION.to_string()),
            Json(body),
        )
        .await
        .unwrap();
    }

    /// Returns the entry count reported by `count_3d`, panicking on error.
    async fn count_of(ctx: &Arc<VectorApiContext>) -> usize {
        count_3d(
            State(Arc::clone(ctx)),
            HeaderMap::new(),
            Path(COLLECTION.to_string()),
        )
        .await
        .unwrap()
        .0
        .count
    }

    #[test]
    fn test_serde_insert_request() {
        let req = insert_request("paper:W123", 10.0, 20.0, 30.0, 1.0);
        let json = serde_json::to_string(&req).unwrap();
        let decoded: Spatial3DInsertRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.key, "paper:W123");
        assert!((decoded.z - 30.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_serde_insert_request_defaults() {
        let json = r#"{"key":"p1","x":1.0,"y":2.0,"z":3.0}"#;
        let decoded: Spatial3DInsertRequest = serde_json::from_str(json).unwrap();
        assert!((decoded.w - 1.0).abs() < f32::EPSILON);
        assert!((decoded.h - 1.0).abs() < f32::EPSILON);
        assert!((decoded.d - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_serde_query_request() {
        let req = Spatial3DQueryRequest {
            x: 5.0,
            y: 5.0,
            z: 5.0,
            radius: Some(10.0),
            limit: Some(50),
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: Spatial3DQueryRequest = serde_json::from_str(&json).unwrap();
        assert!((decoded.z - 5.0).abs() < f32::EPSILON);
        assert_eq!(decoded.limit, Some(50));
    }

    #[test]
    fn test_serde_region_request() {
        let req = Spatial3DRegionRequest {
            min: [-10.0, -10.0, -10.0],
            max: [10.0, 10.0, 10.0],
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: Spatial3DRegionRequest = serde_json::from_str(&json).unwrap();
        assert!((decoded.min[2] - (-10.0)).abs() < f32::EPSILON);
    }

    #[test]
    fn test_serde_result_item() {
        let item = Spatial3DResultItem {
            key: "a".to_string(),
            distance: 1.5,
            x: 2.0,
            y: 3.0,
            z: 4.0,
        };
        let json = serde_json::to_string(&item).unwrap();
        let decoded: Spatial3DResultItem = serde_json::from_str(&json).unwrap();
        assert!((decoded.distance - 1.5).abs() < f32::EPSILON);
        assert!((decoded.z - 4.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_serde_count_response() {
        let resp = Spatial3DCountResponse { count: 42 };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: Spatial3DCountResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.count, 42);
    }

    #[tokio::test]
    async fn test_insert_3d_no_spatial() {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let ctx = Arc::new(VectorApiContext::new(engine));
        let result = insert_3d(
            State(ctx),
            HeaderMap::new(),
            Path("default".to_string()),
            Json(insert_request("test", 1.0, 2.0, 3.0, 1.0)),
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_count_3d_no_spatial() {
        let engine = Arc::new(vector_engine::VectorEngine::new());
        let ctx = Arc::new(VectorApiContext::new(engine));
        let result = count_3d(State(ctx), HeaderMap::new(), Path(COLLECTION.to_string())).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_insert_3d_and_count() {
        let ctx = make_ctx_with_spatial_3d();
        insert_ok(&ctx, insert_request("p1", 10.0, 20.0, 30.0, 1.0)).await;
        assert_eq!(count_of(&ctx).await, 1);
    }

    #[tokio::test]
    async fn test_query_3d_within_radius() {
        let ctx = make_ctx_with_spatial_3d();
        insert_ok(&ctx, insert_request("near", 1.0, 1.0, 1.0, 1.0)).await;
        insert_ok(&ctx, insert_request("far", 100.0, 100.0, 100.0, 1.0)).await;
        let q = Spatial3DQueryRequest {
            x: 0.0,
            y: 0.0,
            z: 0.0,
            radius: Some(10.0),
            limit: None,
        };
        let result = query_3d(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            Path(COLLECTION.to_string()),
            Json(q),
        )
        .await
        .unwrap();
        assert_eq!(result.0.results.len(), 1);
        assert_eq!(result.0.results[0].key, "near");
    }

    #[tokio::test]
    async fn test_nearest_3d() {
        let ctx = make_ctx_with_spatial_3d();
        for (key, x) in [("a", 1.0_f32), ("b", 5.0), ("c", 10.0)] {
            insert_ok(&ctx, insert_request(key, x, 0.0, 0.0, 1.0)).await;
        }
        let q = Spatial3DQueryRequest {
            x: 0.0,
            y: 0.0,
            z: 0.0,
            radius: None,
            limit: Some(2),
        };
        let result = nearest_3d(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            Path(COLLECTION.to_string()),
            Json(q),
        )
        .await
        .unwrap();
        assert_eq!(result.0.results.len(), 2);
        assert_eq!(result.0.results[0].key, "a");
        // With k = 2, the farthest box must be left out.
        assert!(result.0.results.iter().all(|item| item.key != "c"));
    }

    #[tokio::test]
    async fn test_region_3d() {
        let ctx = make_ctx_with_spatial_3d();
        insert_ok(&ctx, insert_request("inside", 5.0, 5.0, 5.0, 1.0)).await;
        insert_ok(&ctx, insert_request("outside", 50.0, 50.0, 50.0, 1.0)).await;
        let region = Spatial3DRegionRequest {
            min: [0.0, 0.0, 0.0],
            max: [10.0, 10.0, 10.0],
        };
        let result = region_3d(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            Path(COLLECTION.to_string()),
            Json(region),
        )
        .await
        .unwrap();
        assert_eq!(result.0.results.len(), 1);
        assert_eq!(result.0.results[0].key, "inside");
    }

    #[tokio::test]
    async fn test_region_3d_inverted_bounds() {
        let ctx = make_ctx_with_spatial_3d();
        // max < min on every axis yields negative extents, which the box constructor rejects.
        let region = Spatial3DRegionRequest {
            min: [10.0, 10.0, 10.0],
            max: [0.0, 0.0, 0.0],
        };
        let result = region_3d(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            Path(COLLECTION.to_string()),
            Json(region),
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_delete_3d() {
        let ctx = make_ctx_with_spatial_3d();
        insert_ok(&ctx, insert_request("temp", 5.0, 5.0, 5.0, 2.0)).await;
        assert_eq!(count_of(&ctx).await, 1);
        let del = Spatial3DDeleteRequest {
            key: "temp".to_string(),
            x: 5.0,
            y: 5.0,
            z: 5.0,
            w: 2.0,
            h: 2.0,
            d: 2.0,
        };
        delete_3d(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            Path(COLLECTION.to_string()),
            Json(del),
        )
        .await
        .unwrap();
        assert_eq!(count_of(&ctx).await, 0);
    }

    #[tokio::test]
    async fn test_query_3d_with_limit() {
        let ctx = make_ctx_with_spatial_3d();
        for (key, x) in [("a", 1.0_f32), ("b", 2.0), ("c", 3.0)] {
            insert_ok(&ctx, insert_request(key, x, 0.0, 0.0, 1.0)).await;
        }
        let q = Spatial3DQueryRequest {
            x: 0.0,
            y: 0.0,
            z: 0.0,
            radius: Some(100.0),
            limit: Some(2),
        };
        let result = query_3d(
            State(Arc::clone(&ctx)),
            HeaderMap::new(),
            Path(COLLECTION.to_string()),
            Json(q),
        )
        .await
        .unwrap();
        // All three boxes lie well inside the radius, so the limit alone caps the result.
        assert_eq!(result.0.results.len(), 2);
    }

    #[test]
    fn test_validate_auth_no_config() {
        let result = validate_auth(&HeaderMap::new(), None);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_validate_auth_required_but_missing() {
        use crate::config::{ApiKey, AuthConfig};
        let auth = AuthConfig::new().with_api_key(ApiKey::new(
            "test-api-key-12345678".to_string(),
            "user:test".to_string(),
        ));
        let result = validate_auth(&HeaderMap::new(), Some(&auth));
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_auth_anonymous_allowed() {
        use crate::config::AuthConfig;
        let auth = AuthConfig::new().with_anonymous(true);
        let result = validate_auth(&HeaderMap::new(), Some(&auth));
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_validate_auth_valid_key() {
        use crate::config::{ApiKey, AuthConfig};
        let auth = AuthConfig::new().with_api_key(ApiKey::new(
            "test-api-key-12345678".to_string(),
            "user:alice".to_string(),
        ));
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", "test-api-key-12345678".parse().unwrap());
        let result = validate_auth(&headers, Some(&auth));
        assert!(result.is_ok());
        assert_eq!(result.unwrap().unwrap(), "user:alice");
    }

    #[test]
    fn test_validate_auth_invalid_key() {
        use crate::config::{ApiKey, AuthConfig};
        let auth = AuthConfig::new().with_api_key(ApiKey::new(
            "test-api-key-12345678".to_string(),
            "user:alice".to_string(),
        ));
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", "wrong-key-value".parse().unwrap());
        let result = validate_auth(&headers, Some(&auth));
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_api_key_default_header() {
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", "mykey".parse().unwrap());
        let key = extract_api_key(&headers, None);
        assert_eq!(key.unwrap(), "mykey");
    }

    #[test]
    fn test_extract_api_key_custom_header() {
        use crate::config::AuthConfig;
        let auth = AuthConfig::new().with_header("authorization".to_string());
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "bearer-token".parse().unwrap());
        let key = extract_api_key(&headers, Some(&auth));
        assert_eq!(key.unwrap(), "bearer-token");
    }

    #[test]
    fn test_check_rate_limit_no_limiter() {
        let result = check_rate_limit(None, None, "test");
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_rate_limit_no_identity() {
        let limiter = Arc::new(RateLimiter::default());
        let result = check_rate_limit(None, Some(&limiter), "test");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_insert_3d_invalid_bounds() {
        let ctx = make_ctx_with_spatial_3d();
        let body = Spatial3DInsertRequest {
            key: "bad".to_string(),
            x: 0.0,
            y: 0.0,
            z: 0.0,
            w: -1.0,
            h: 5.0,
            d: 5.0,
        };
        let result = insert_3d(
            State(ctx),
            HeaderMap::new(),
            Path(COLLECTION.to_string()),
            Json(body),
        )
        .await;
        assert!(result.is_err());
    }
}