1pub mod admin;
2pub mod admin_api;
3pub mod columnar;
4pub mod hnsw;
5pub mod kv;
6pub mod session;
7pub mod sql;
8pub mod vector;
9
10use std::sync::Arc;
11
12use axum::http::{HeaderValue, StatusCode};
13use axum::middleware;
14use axum::response::{IntoResponse, Response};
15use axum::{Json, Router};
16use serde::Serialize;
17use tower::ServiceBuilder;
18use tower_http::limit::RequestBodyLimitLayer;
19use tower_http::trace::TraceLayer;
20use tracing::Span;
21use uuid::Uuid;
22
23use crate::auth::AuthError;
24use crate::error::ServerError;
25use crate::server::ServerState;
26
/// Per-request metadata inserted into the request extensions by `context_middleware`
/// once the caller has been authenticated.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Correlation id for this request: client-supplied (`x-correlation-id` /
    /// `x-request-id`) or freshly generated.
    pub correlation_id: String,
    /// Actor returned by `state.auth.validate_http`; presumably `None` for
    /// anonymous access — confirm against the auth module.
    pub actor: Option<String>,
}
32
/// Inner payload of every JSON error response built in this module.
#[derive(Serialize)]
struct ErrorBody {
    /// Machine-readable error code (`"UNAUTHORIZED"` or `ServerError::error_code()`).
    code: String,
    /// Human-readable description, taken from the error's `Display` output.
    message: String,
    /// Correlation id of the failing request, echoed back to the client.
    correlation_id: String,
}

/// Top-level JSON error envelope, serialized as
/// `{"error": {"code": ..., "message": ..., "correlation_id": ...}}`.
#[derive(Serialize)]
struct ErrorResponse {
    error: ErrorBody,
}
44
45pub fn router(state: Arc<ServerState>) -> Router {
46 let api = Router::new()
47 .route("/kv/get", axum::routing::post(kv::get))
48 .route("/kv/put", axum::routing::post(kv::put))
49 .route("/kv/delete", axum::routing::post(kv::delete))
50 .route("/kv/list", axum::routing::post(kv::list))
51 .route("/kv/txn/begin", axum::routing::post(kv::txn_begin))
52 .route("/kv/txn/get", axum::routing::post(kv::txn_get))
53 .route("/kv/txn/put", axum::routing::post(kv::txn_put))
54 .route("/kv/txn/delete", axum::routing::post(kv::txn_delete))
55 .route("/kv/txn/commit", axum::routing::post(kv::txn_commit))
56 .route("/kv/txn/rollback", axum::routing::post(kv::txn_rollback))
57 .route("/columnar/scan", axum::routing::post(columnar::scan))
58 .route("/columnar/stats", axum::routing::post(columnar::stats))
59 .route("/columnar/list", axum::routing::post(columnar::list))
60 .route("/columnar/ingest", axum::routing::post(columnar::ingest))
61 .route(
62 "/columnar/index/create",
63 axum::routing::post(columnar::index_create),
64 )
65 .route(
66 "/columnar/index/list",
67 axum::routing::post(columnar::index_list),
68 )
69 .route(
70 "/columnar/index/drop",
71 axum::routing::post(columnar::index_drop),
72 )
73 .route("/hnsw/search", axum::routing::post(hnsw::search))
74 .route("/hnsw/upsert", axum::routing::post(hnsw::upsert))
75 .route("/hnsw/delete", axum::routing::post(hnsw::delete))
76 .route("/hnsw/create", axum::routing::post(hnsw::create))
77 .route("/hnsw/drop", axum::routing::post(hnsw::drop))
78 .route("/hnsw/stats", axum::routing::post(hnsw::stats))
79 .route("/sql", axum::routing::post(sql::handle))
80 .route("/api/sql/query", axum::routing::post(sql::handle))
81 .route("/vector/search", axum::routing::post(vector::search))
82 .route("/vector/upsert", axum::routing::post(vector::upsert))
83 .route("/vector/delete", axum::routing::post(vector::delete))
84 .route(
85 "/vector/index/create",
86 axum::routing::post(vector::index_create),
87 )
88 .route(
89 "/vector/index/update",
90 axum::routing::post(vector::index_update),
91 )
92 .route(
93 "/vector/index/delete",
94 axum::routing::post(vector::index_delete),
95 )
96 .route(
97 "/vector/index/compact",
98 axum::routing::post(vector::index_compact),
99 )
100 .route(
101 "/api/admin/capabilities",
102 axum::routing::get(admin_api::capabilities),
103 )
104 .route("/api/admin/status", axum::routing::get(admin_api::status))
105 .route("/api/admin/metrics", axum::routing::get(admin_api::metrics))
106 .route("/api/admin/health", axum::routing::get(admin_api::health))
107 .route(
108 "/api/admin/lifecycle",
109 axum::routing::post(admin_api::lifecycle),
110 )
111 .route(
112 "/api/admin/compaction",
113 axum::routing::post(admin_api::compaction),
114 )
115 .route("/session/begin", axum::routing::post(session::begin))
116 .route("/session/:id/commit", axum::routing::post(session::commit))
117 .route(
118 "/session/:id/rollback",
119 axum::routing::post(session::rollback),
120 );
121
122 let api = if state.config.api_prefix.is_empty() {
123 api
124 } else {
125 Router::new().nest(&state.config.api_prefix, api)
126 };
127
128 let middleware = middleware::from_fn(context_middleware);
129 let connection_middleware = middleware::from_fn(connection_middleware);
130 api.layer(
131 ServiceBuilder::new()
132 .layer(RequestBodyLimitLayer::new(state.config.max_request_size))
133 .layer(tower::limit::ConcurrencyLimitLayer::new(
134 state.config.max_connections,
135 ))
136 .layer(TraceLayer::new_for_http().make_span_with(make_trace_span))
137 .layer(middleware)
138 .layer(connection_middleware),
139 )
140 .layer(axum::Extension(state))
141}
142
143pub fn admin_router(state: Arc<ServerState>) -> Router {
144 admin::router(state)
145}
146
147pub async fn context_middleware<B>(
148 axum::extract::Extension(state): axum::extract::Extension<Arc<ServerState>>,
149 mut req: axum::http::Request<B>,
150 next: middleware::Next<B>,
151) -> Response {
152 let correlation_id =
153 extract_correlation_id(req.headers()).unwrap_or_else(|| Uuid::new_v4().to_string());
154
155 let actor = match state.auth.validate_http(req.headers()) {
156 Ok(actor) => actor,
157 Err(err) => {
158 if state.config.audit_log_enabled {
159 state.audit.log(crate::audit::AuditLogEntry {
160 event_type: crate::audit::AuditEventType::AuthFailure,
161 actor: None,
162 target: "auth".into(),
163 correlation_id: correlation_id.clone(),
164 timestamp: chrono::Utc::now(),
165 details: serde_json::json!({ "error": err.to_string() }),
166 });
167 }
168 return auth_error_response(err, &correlation_id);
169 }
170 };
171
172 req.extensions_mut().insert(RequestContext {
173 correlation_id: correlation_id.clone(),
174 actor,
175 });
176
177 let mut res = next.run(req).await;
178 let _ = res.headers_mut().insert(
179 "x-correlation-id",
180 HeaderValue::from_str(&correlation_id).unwrap_or_else(|_| HeaderValue::from_static("")),
181 );
182 res
183}
184
185pub async fn connection_middleware<B>(
186 axum::extract::Extension(state): axum::extract::Extension<Arc<ServerState>>,
187 req: axum::http::Request<B>,
188 next: middleware::Next<B>,
189) -> Response {
190 state.metrics.record_connection(1);
191 let res = next.run(req).await;
192 state.metrics.record_connection(-1);
193 res
194}
195
196fn auth_error_response(err: AuthError, correlation_id: &str) -> Response {
197 let message = err.to_string();
198 let body = Json(ErrorResponse {
199 error: ErrorBody {
200 code: "UNAUTHORIZED".to_string(),
201 message,
202 correlation_id: correlation_id.to_string(),
203 },
204 });
205 (StatusCode::UNAUTHORIZED, body).into_response()
206}
207
208pub fn error_response(err: ServerError, ctx: &RequestContext) -> Response {
209 let body = Json(ErrorResponse {
210 error: ErrorBody {
211 code: err.error_code(),
212 message: err.to_string(),
213 correlation_id: ctx.correlation_id.clone(),
214 },
215 });
216 (err.status_code(), body).into_response()
217}
218
219fn make_trace_span<B>(request: &axum::http::Request<B>) -> Span {
220 let correlation_id = request
221 .extensions()
222 .get::<RequestContext>()
223 .map(|ctx| ctx.correlation_id.clone())
224 .or_else(|| extract_correlation_id(request.headers()))
225 .unwrap_or_else(|| Uuid::new_v4().to_string());
226 let traceparent = request
227 .headers()
228 .get("traceparent")
229 .and_then(|v| v.to_str().ok())
230 .unwrap_or("");
231 tracing::info_span!(
232 "http_request",
233 correlation_id = %correlation_id,
234 traceparent = %traceparent,
235 method = %request.method(),
236 path = %request.uri().path()
237 )
238}
239
240pub fn json_response<T: Serialize>(value: T, max_size: usize, ctx: &RequestContext) -> Response {
241 match serde_json::to_vec(&value) {
242 Ok(bytes) if bytes.len() <= max_size => (StatusCode::OK, Json(value)).into_response(),
243 Ok(_) => error_response(
244 ServerError::PayloadTooLarge("response size exceeds limit".into()),
245 ctx,
246 ),
247 Err(err) => error_response(ServerError::Internal(err.to_string()), ctx),
248 }
249}
250
251fn extract_correlation_id(headers: &axum::http::HeaderMap) -> Option<String> {
252 headers
253 .get("x-correlation-id")
254 .and_then(|v| v.to_str().ok())
255 .map(|v| v.to_string())
256 .or_else(|| {
257 headers
258 .get("x-request-id")
259 .and_then(|v| v.to_str().ok())
260 .map(|v| v.to_string())
261 })
262}