1pub mod admin;
2pub mod admin_api;
3pub mod admin_resources;
4pub mod columnar;
5pub mod hnsw;
6pub mod kv;
7pub mod session;
8pub mod sql;
9pub mod vector;
10
11use std::sync::Arc;
12
13use axum::http::{HeaderValue, StatusCode};
14use axum::middleware;
15use axum::response::{IntoResponse, Response};
16use axum::{Json, Router};
17use serde::Serialize;
18use tower::ServiceBuilder;
19use tower_http::limit::RequestBodyLimitLayer;
20use tower_http::trace::TraceLayer;
21use tracing::Span;
22use uuid::Uuid;
23
24use crate::auth::AuthError;
25use crate::error::ServerError;
26use crate::server::ServerState;
27
/// Per-request metadata that `context_middleware` stores in the request
/// extensions once authentication succeeds.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Correlation id for this request; it is echoed back in the
    /// `x-correlation-id` response header and in JSON error bodies.
    pub correlation_id: String,
    /// Actor returned by `state.auth.validate_http`, or `None` when that call
    /// succeeded without identifying one.
    pub actor: Option<String>,
}
33
34#[derive(Serialize)]
35struct ErrorBody {
36 code: String,
37 message: String,
38 correlation_id: String,
39}
40
41#[derive(Serialize)]
42struct ErrorResponse {
43 error: ErrorBody,
44}
45
46pub fn router(state: Arc<ServerState>) -> Router {
47 let api = Router::new()
48 .route("/kv/get", axum::routing::post(kv::get))
49 .route("/kv/put", axum::routing::post(kv::put))
50 .route("/kv/delete", axum::routing::post(kv::delete))
51 .route("/kv/list", axum::routing::post(kv::list))
52 .route("/kv/txn/begin", axum::routing::post(kv::txn_begin))
53 .route("/kv/txn/get", axum::routing::post(kv::txn_get))
54 .route("/kv/txn/put", axum::routing::post(kv::txn_put))
55 .route("/kv/txn/delete", axum::routing::post(kv::txn_delete))
56 .route("/kv/txn/commit", axum::routing::post(kv::txn_commit))
57 .route("/kv/txn/rollback", axum::routing::post(kv::txn_rollback))
58 .route("/columnar/scan", axum::routing::post(columnar::scan))
59 .route("/columnar/stats", axum::routing::post(columnar::stats))
60 .route("/columnar/list", axum::routing::post(columnar::list))
61 .route("/columnar/ingest", axum::routing::post(columnar::ingest))
62 .route(
63 "/columnar/index/create",
64 axum::routing::post(columnar::index_create),
65 )
66 .route(
67 "/columnar/index/list",
68 axum::routing::post(columnar::index_list),
69 )
70 .route(
71 "/columnar/index/drop",
72 axum::routing::post(columnar::index_drop),
73 )
74 .route("/hnsw/search", axum::routing::post(hnsw::search))
75 .route("/hnsw/upsert", axum::routing::post(hnsw::upsert))
76 .route("/hnsw/delete", axum::routing::post(hnsw::delete))
77 .route("/hnsw/create", axum::routing::post(hnsw::create))
78 .route("/hnsw/drop", axum::routing::post(hnsw::drop))
79 .route("/hnsw/stats", axum::routing::post(hnsw::stats))
80 .route("/sql", axum::routing::post(sql::handle))
81 .route("/api/sql/query", axum::routing::post(sql::handle))
82 .route("/vector/search", axum::routing::post(vector::search))
83 .route("/vector/upsert", axum::routing::post(vector::upsert))
84 .route("/vector/delete", axum::routing::post(vector::delete))
85 .route(
86 "/vector/index/create",
87 axum::routing::post(vector::index_create),
88 )
89 .route(
90 "/vector/index/update",
91 axum::routing::post(vector::index_update),
92 )
93 .route(
94 "/vector/index/delete",
95 axum::routing::post(vector::index_delete),
96 )
97 .route(
98 "/vector/index/compact",
99 axum::routing::post(vector::index_compact),
100 )
101 .route(
102 "/api/admin/capabilities",
103 axum::routing::get(admin_api::capabilities),
104 )
105 .route(
106 "/api/admin/resources",
107 axum::routing::get(admin_resources::list),
108 )
109 .route("/api/admin/status", axum::routing::get(admin_api::status))
110 .route("/api/admin/metrics", axum::routing::get(admin_api::metrics))
111 .route("/api/admin/health", axum::routing::get(admin_api::health))
112 .route(
113 "/api/admin/backup",
114 axum::routing::post(admin_api::start_backup),
115 )
116 .route(
117 "/api/admin/backup/:id",
118 axum::routing::get(admin_api::backup_status),
119 )
120 .route(
121 "/api/admin/restore",
122 axum::routing::post(admin_api::start_restore),
123 )
124 .route(
125 "/api/admin/restore/:id",
126 axum::routing::get(admin_api::restore_status),
127 )
128 .route(
129 "/api/admin/lifecycle",
130 axum::routing::post(admin_api::lifecycle),
131 )
132 .route(
133 "/api/admin/compaction",
134 axum::routing::post(admin_api::compaction),
135 )
136 .route("/session/begin", axum::routing::post(session::begin))
137 .route("/session/:id/commit", axum::routing::post(session::commit))
138 .route(
139 "/session/:id/rollback",
140 axum::routing::post(session::rollback),
141 );
142
143 let api = if state.config.api_prefix.is_empty() {
144 api
145 } else {
146 Router::new().nest(&state.config.api_prefix, api)
147 };
148
149 let middleware = middleware::from_fn(context_middleware);
150 let connection_middleware = middleware::from_fn(connection_middleware);
151 api.layer(
152 ServiceBuilder::new()
153 .layer(RequestBodyLimitLayer::new(state.config.max_request_size))
154 .layer(tower::limit::ConcurrencyLimitLayer::new(
155 state.config.max_connections,
156 ))
157 .layer(TraceLayer::new_for_http().make_span_with(make_trace_span))
158 .layer(middleware)
159 .layer(connection_middleware),
160 )
161 .layer(axum::Extension(state))
162}
163
164pub fn admin_router(state: Arc<ServerState>) -> Router {
165 admin::router(state)
166}
167
168pub async fn context_middleware<B>(
169 axum::extract::Extension(state): axum::extract::Extension<Arc<ServerState>>,
170 mut req: axum::http::Request<B>,
171 next: middleware::Next<B>,
172) -> Response {
173 let correlation_id =
174 extract_correlation_id(req.headers()).unwrap_or_else(|| Uuid::new_v4().to_string());
175
176 let actor = match state.auth.validate_http(req.headers()) {
177 Ok(actor) => actor,
178 Err(err) => {
179 if state.config.audit_log_enabled {
180 state.audit.log(crate::audit::AuditLogEntry {
181 event_type: crate::audit::AuditEventType::AuthFailure,
182 actor: None,
183 target: "auth".into(),
184 correlation_id: correlation_id.clone(),
185 timestamp: chrono::Utc::now(),
186 details: serde_json::json!({ "error": err.to_string() }),
187 });
188 }
189 return auth_error_response(err, &correlation_id);
190 }
191 };
192
193 req.extensions_mut().insert(RequestContext {
194 correlation_id: correlation_id.clone(),
195 actor,
196 });
197
198 let mut res = next.run(req).await;
199 let _ = res.headers_mut().insert(
200 "x-correlation-id",
201 HeaderValue::from_str(&correlation_id).unwrap_or_else(|_| HeaderValue::from_static("")),
202 );
203 res
204}
205
206pub async fn connection_middleware<B>(
207 axum::extract::Extension(state): axum::extract::Extension<Arc<ServerState>>,
208 req: axum::http::Request<B>,
209 next: middleware::Next<B>,
210) -> Response {
211 state.metrics.record_connection(1);
212 let res = next.run(req).await;
213 state.metrics.record_connection(-1);
214 res
215}
216
217fn auth_error_response(err: AuthError, correlation_id: &str) -> Response {
218 let message = err.to_string();
219 let body = Json(ErrorResponse {
220 error: ErrorBody {
221 code: "UNAUTHORIZED".to_string(),
222 message,
223 correlation_id: correlation_id.to_string(),
224 },
225 });
226 (StatusCode::UNAUTHORIZED, body).into_response()
227}
228
229pub fn error_response(err: ServerError, ctx: &RequestContext) -> Response {
230 let body = Json(ErrorResponse {
231 error: ErrorBody {
232 code: err.error_code(),
233 message: err.to_string(),
234 correlation_id: ctx.correlation_id.clone(),
235 },
236 });
237 (err.status_code(), body).into_response()
238}
239
240fn make_trace_span<B>(request: &axum::http::Request<B>) -> Span {
241 let correlation_id = request
242 .extensions()
243 .get::<RequestContext>()
244 .map(|ctx| ctx.correlation_id.clone())
245 .or_else(|| extract_correlation_id(request.headers()))
246 .unwrap_or_else(|| Uuid::new_v4().to_string());
247 let traceparent = request
248 .headers()
249 .get("traceparent")
250 .and_then(|v| v.to_str().ok())
251 .unwrap_or("");
252 tracing::info_span!(
253 "http_request",
254 correlation_id = %correlation_id,
255 traceparent = %traceparent,
256 method = %request.method(),
257 path = %request.uri().path()
258 )
259}
260
261pub fn json_response<T: Serialize>(value: T, max_size: usize, ctx: &RequestContext) -> Response {
262 match serde_json::to_vec(&value) {
263 Ok(bytes) if bytes.len() <= max_size => (StatusCode::OK, Json(value)).into_response(),
264 Ok(_) => error_response(
265 ServerError::PayloadTooLarge("response size exceeds limit".into()),
266 ctx,
267 ),
268 Err(err) => error_response(ServerError::Internal(err.to_string()), ctx),
269 }
270}
271
272fn extract_correlation_id(headers: &axum::http::HeaderMap) -> Option<String> {
273 headers
274 .get("x-correlation-id")
275 .and_then(|v| v.to_str().ok())
276 .map(|v| v.to_string())
277 .or_else(|| {
278 headers
279 .get("x-request-id")
280 .and_then(|v| v.to_str().ok())
281 .map(|v| v.to_string())
282 })
283}