use std::sync::Arc;
use axum::http::{header, HeaderName, HeaderValue, Method};
use axum::routing::{delete, get, post, put};
use axum::Router;
use tower_http::cors::CorsLayer;
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::trace::TraceLayer;
use vector_engine::VectorEngine;
use crate::audit::AuditLogger;
use crate::config::AuthConfig;
use crate::metrics::ServerMetrics;
use crate::rate_limit::RateLimiter;
pub mod collections;
pub mod error;
pub mod points;
pub mod spatial;
pub mod spatial3d;
pub mod types;
pub use error::{ApiError, ApiResult};
pub use types::*;
/// Default upper bound, in bytes, on an incoming HTTP request body (16 MiB).
const DEFAULT_MAX_BODY_SIZE: usize = 16 << 20;
/// Shared state installed on the vector REST router (see `router_with_config`).
///
/// Only `engine` is required. Every other member is optional; [`VectorApiContext::new`]
/// leaves them all as `None`, and the `with_*` builder methods fill them in.
pub struct VectorApiContext {
    /// Vector engine that the handlers operate on.
    pub engine: Arc<VectorEngine>,
    /// Authentication settings; `None` when none were supplied.
    pub auth_config: Option<AuthConfig>,
    /// Shared rate limiter; `None` when none was supplied.
    pub rate_limiter: Option<Arc<RateLimiter>>,
    /// Shared audit logger; `None` when none was supplied.
    pub audit_logger: Option<Arc<AuditLogger>>,
    /// Shared server metrics; `None` when none were supplied.
    pub metrics: Option<Arc<ServerMetrics>>,
    /// 2D spatial index behind a `parking_lot` read/write lock; `None` when not supplied.
    pub spatial: Option<Arc<parking_lot::RwLock<tensor_spatial::SpatialIndex<String>>>>,
    /// 3D spatial index behind a `parking_lot` read/write lock; `None` when not supplied.
    pub spatial_3d: Option<Arc<parking_lot::RwLock<tensor_spatial::SpatialIndex3D<String>>>>,
}
impl VectorApiContext {
    /// Creates a context around `engine` with every optional facility set to `None`.
    #[must_use]
    pub const fn new(engine: Arc<VectorEngine>) -> Self {
        Self {
            engine,
            auth_config: None,
            rate_limiter: None,
            audit_logger: None,
            metrics: None,
            spatial: None,
            spatial_3d: None,
        }
    }

    /// Returns `self` with the authentication config replaced (`None` clears it).
    #[must_use]
    pub fn with_auth(self, auth_config: Option<AuthConfig>) -> Self {
        Self {
            auth_config,
            ..self
        }
    }

    /// Returns `self` with the rate limiter replaced (`None` clears it).
    #[must_use]
    pub fn with_rate_limiter(self, rate_limiter: Option<Arc<RateLimiter>>) -> Self {
        Self {
            rate_limiter,
            ..self
        }
    }

    /// Returns `self` with the audit logger replaced (`None` clears it).
    #[must_use]
    pub fn with_audit_logger(self, audit_logger: Option<Arc<AuditLogger>>) -> Self {
        Self {
            audit_logger,
            ..self
        }
    }

    /// Returns `self` with the metrics handle replaced (`None` clears it).
    #[must_use]
    pub fn with_metrics(self, metrics: Option<Arc<ServerMetrics>>) -> Self {
        Self { metrics, ..self }
    }

    /// Returns `self` with the 2D spatial index replaced (`None` clears it).
    #[must_use]
    pub fn with_spatial(
        self,
        spatial: Option<Arc<parking_lot::RwLock<tensor_spatial::SpatialIndex<String>>>>,
    ) -> Self {
        Self { spatial, ..self }
    }

    /// Returns `self` with the 3D spatial index replaced (`None` clears it).
    #[must_use]
    pub fn with_spatial_3d(
        self,
        spatial_3d: Option<Arc<parking_lot::RwLock<tensor_spatial::SpatialIndex3D<String>>>>,
    ) -> Self {
        Self {
            spatial_3d,
            ..self
        }
    }
}
/// Transport-level settings for the vector REST router.
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly (e.g. in tests or
/// when detecting configuration changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RestConfig {
    /// Maximum accepted request body size, in bytes.
    pub max_body_size: usize,
    /// Whether a CORS layer is added to the router.
    pub cors_enabled: bool,
    /// Origins permitted by CORS when `cors_enabled` is true. Entries that cannot be
    /// parsed as an HTTP header value are ignored.
    pub cors_origins: Vec<String>,
}
impl Default for RestConfig {
    /// Body limit of `DEFAULT_MAX_BODY_SIZE`, CORS disabled, and no allowed origins.
    fn default() -> Self {
        let cors_origins = Vec::new();
        Self {
            cors_origins,
            cors_enabled: false,
            max_body_size: DEFAULT_MAX_BODY_SIZE,
        }
    }
}
impl RestConfig {
    /// Equivalent to [`RestConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the maximum accepted request body size, in bytes.
    #[must_use]
    pub const fn with_max_body_size(mut self, size: usize) -> Self {
        self.max_body_size = size;
        self
    }

    /// Turns the CORS layer on or off.
    #[must_use]
    pub const fn with_cors(mut self, enabled: bool) -> Self {
        self.cors_enabled = enabled;
        self
    }

    /// Replaces the list of CORS origins.
    #[must_use]
    pub fn with_cors_origins(self, origins: Vec<String>) -> Self {
        Self {
            cors_origins: origins,
            ..self
        }
    }
}
/// Builds the vector REST router with [`RestConfig::default`] settings.
pub fn router(ctx: Arc<VectorApiContext>) -> Router {
    let defaults = RestConfig::default();
    router_with_config(ctx, &defaults)
}
pub fn router_with_config(ctx: Arc<VectorApiContext>, config: &RestConfig) -> Router {
let mut router = Router::new()
.route(
"/collections/{name}/points",
put(points::upsert).post(points::upsert),
)
.route("/collections/{name}/points/get", post(points::get))
.route("/collections/{name}/points/delete", post(points::delete))
.route("/collections/{name}/points/query", post(points::query))
.route("/collections/{name}/points/scroll", post(points::scroll))
.route("/collections/{name}", put(collections::create))
.route("/collections/{name}", get(collections::get))
.route("/collections/{name}", delete(collections::delete))
.route("/collections", get(collections::list))
.route(
"/collections/{name}/spatial/insert",
post(spatial::insert),
)
.route("/collections/{name}/spatial/query", post(spatial::query))
.route(
"/collections/{name}/spatial/delete",
post(spatial::delete),
)
.route("/collections/{name}/spatial/count", get(spatial::count))
.route(
"/collections/{name}/spatial3d/insert",
post(spatial3d::insert_3d),
)
.route(
"/collections/{name}/spatial3d/query",
post(spatial3d::query_3d),
)
.route(
"/collections/{name}/spatial3d/nearest",
post(spatial3d::nearest_3d),
)
.route(
"/collections/{name}/spatial3d/region",
post(spatial3d::region_3d),
)
.route(
"/collections/{name}/spatial3d/delete",
post(spatial3d::delete_3d),
)
.route(
"/collections/{name}/spatial3d/count",
get(spatial3d::count_3d),
)
.layer(TraceLayer::new_for_http())
.layer(RequestBodyLimitLayer::new(config.max_body_size))
.with_state(ctx);
if config.cors_enabled {
let origins: Vec<HeaderValue> = config
.cors_origins
.iter()
.filter_map(|o| o.parse().ok())
.collect();
let cors = CorsLayer::new()
.allow_origin(origins)
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers([header::CONTENT_TYPE, HeaderName::from_static("x-api-key")]);
router = router.layer(cors);
}
router
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vector_api_context_new() {
        let engine = Arc::new(VectorEngine::new());
        let ctx = VectorApiContext::new(engine);
        assert!(ctx.auth_config.is_none());
        assert!(ctx.rate_limiter.is_none());
        assert!(ctx.audit_logger.is_none());
        assert!(ctx.metrics.is_none());
        // The spatial indexes are optional too and must start unset.
        assert!(ctx.spatial.is_none());
        assert!(ctx.spatial_3d.is_none());
    }

    #[test]
    fn test_vector_api_context_with_auth() {
        use crate::config::ApiKey;
        let engine = Arc::new(VectorEngine::new());
        let auth_config = AuthConfig::new().with_api_key(ApiKey::new(
            "test-api-key-12345678".to_string(),
            "user:test".to_string(),
        ));
        let ctx = VectorApiContext::new(engine).with_auth(Some(auth_config));
        assert!(ctx.auth_config.is_some());
    }

    #[test]
    fn test_vector_api_context_with_rate_limiter() {
        let engine = Arc::new(VectorEngine::new());
        let rate_limiter = Arc::new(RateLimiter::default());
        let ctx = VectorApiContext::new(engine).with_rate_limiter(Some(rate_limiter));
        assert!(ctx.rate_limiter.is_some());
    }

    #[test]
    fn test_vector_api_context_with_audit_logger() {
        let engine = Arc::new(VectorEngine::new());
        let audit_logger = Arc::new(AuditLogger::default());
        let ctx = VectorApiContext::new(engine).with_audit_logger(Some(audit_logger));
        assert!(ctx.audit_logger.is_some());
    }

    #[test]
    fn test_rest_config_default() {
        let config = RestConfig::default();
        assert_eq!(config.max_body_size, DEFAULT_MAX_BODY_SIZE);
        assert!(!config.cors_enabled);
        assert!(config.cors_origins.is_empty());
    }

    #[test]
    fn test_rest_config_builder() {
        let config = RestConfig::new()
            .with_max_body_size(32 * 1024 * 1024)
            .with_cors(true)
            .with_cors_origins(vec!["http://localhost:3000".to_string()]);
        assert_eq!(config.max_body_size, 32 * 1024 * 1024);
        assert!(config.cors_enabled);
        assert_eq!(config.cors_origins.len(), 1);
    }

    #[test]
    fn test_router_creation() {
        let engine = Arc::new(VectorEngine::new());
        let ctx = Arc::new(VectorApiContext::new(engine));
        let _router = router(ctx);
    }

    #[test]
    fn test_router_with_config_creation() {
        let engine = Arc::new(VectorEngine::new());
        let ctx = Arc::new(VectorApiContext::new(engine));
        let config = RestConfig::new().with_max_body_size(8 * 1024 * 1024);
        let _router = router_with_config(ctx, &config);
    }

    #[test]
    fn test_vector_api_context_with_spatial_3d() {
        let engine = Arc::new(VectorEngine::new());
        let spatial_3d = Arc::new(parking_lot::RwLock::new(tensor_spatial::SpatialIndex3D::<
            String,
        >::new()));
        let ctx = VectorApiContext::new(engine).with_spatial_3d(Some(spatial_3d));
        assert!(ctx.spatial_3d.is_some());
    }

    #[test]
    fn test_vector_api_context_with_metrics() {
        use opentelemetry::metrics::MeterProvider;
        use opentelemetry_sdk::metrics::SdkMeterProvider;
        let engine = Arc::new(VectorEngine::new());
        let provider = SdkMeterProvider::builder().build();
        let meter = provider.meter("test");
        let metrics = Arc::new(ServerMetrics::new(meter));
        let ctx = VectorApiContext::new(engine).with_metrics(Some(metrics));
        assert!(ctx.metrics.is_some());
    }

    #[test]
    fn test_router_with_cors_enabled() {
        let engine = Arc::new(VectorEngine::new());
        let ctx = Arc::new(VectorApiContext::new(engine));
        let config = RestConfig::new()
            .with_cors(true)
            .with_cors_origins(vec!["http://localhost:3000".to_string()]);
        let _router = router_with_config(ctx, &config);
    }

    #[test]
    fn test_router_with_cors_skips_unparsable_origins() {
        // A control character makes the entry an invalid header value; building the
        // router must skip it instead of failing.
        let engine = Arc::new(VectorEngine::new());
        let ctx = Arc::new(VectorApiContext::new(engine));
        let config = RestConfig::new().with_cors(true).with_cors_origins(vec![
            "http://localhost:3000".to_string(),
            "bad\norigin".to_string(),
        ]);
        let _router = router_with_config(ctx, &config);
    }

    #[test]
    fn test_router_with_cors_enabled_and_no_origins() {
        let engine = Arc::new(VectorEngine::new());
        let ctx = Arc::new(VectorApiContext::new(engine));
        let config = RestConfig::new().with_cors(true);
        let _router = router_with_config(ctx, &config);
    }
}