alopex_server/http/
mod.rs1pub mod admin;
2pub mod session;
3pub mod sql;
4pub mod vector;
5
6use std::sync::Arc;
7
8use axum::http::{HeaderValue, StatusCode};
9use axum::middleware;
10use axum::response::{IntoResponse, Response};
11use axum::{Json, Router};
12use serde::Serialize;
13use tower::ServiceBuilder;
14use tower_http::limit::RequestBodyLimitLayer;
15use tower_http::trace::TraceLayer;
16use tracing::Span;
17use uuid::Uuid;
18
19use crate::auth::AuthError;
20use crate::error::ServerError;
21use crate::server::ServerState;
22
/// Per-request metadata that `context_middleware` inserts into the request
/// extensions after successful authentication.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RequestContext {
    /// Correlation ID taken from the request headers or freshly generated.
    pub correlation_id: String,
    /// Actor returned by the HTTP auth validator.
    /// NOTE(review): presumably `None` for anonymous access — confirm.
    pub actor: Option<String>,
}
28
29#[derive(Serialize)]
30struct ErrorBody {
31 code: String,
32 message: String,
33 correlation_id: String,
34}
35
36#[derive(Serialize)]
37struct ErrorResponse {
38 error: ErrorBody,
39}
40
41pub fn router(state: Arc<ServerState>) -> Router {
42 let api = Router::new()
43 .route("/sql", axum::routing::post(sql::handle))
44 .route("/vector/search", axum::routing::post(vector::search))
45 .route("/vector/upsert", axum::routing::post(vector::upsert))
46 .route("/vector/delete", axum::routing::post(vector::delete))
47 .route(
48 "/vector/index/create",
49 axum::routing::post(vector::index_create),
50 )
51 .route(
52 "/vector/index/update",
53 axum::routing::post(vector::index_update),
54 )
55 .route(
56 "/vector/index/delete",
57 axum::routing::post(vector::index_delete),
58 )
59 .route(
60 "/vector/index/compact",
61 axum::routing::post(vector::index_compact),
62 )
63 .route("/session/begin", axum::routing::post(session::begin))
64 .route("/session/:id/commit", axum::routing::post(session::commit))
65 .route(
66 "/session/:id/rollback",
67 axum::routing::post(session::rollback),
68 );
69
70 let api = if state.config.api_prefix.is_empty() {
71 api
72 } else {
73 Router::new().nest(&state.config.api_prefix, api)
74 };
75
76 let middleware = middleware::from_fn(context_middleware);
77 let connection_middleware = middleware::from_fn(connection_middleware);
78 api.layer(
79 ServiceBuilder::new()
80 .layer(RequestBodyLimitLayer::new(state.config.max_request_size))
81 .layer(tower::limit::ConcurrencyLimitLayer::new(
82 state.config.max_connections,
83 ))
84 .layer(TraceLayer::new_for_http().make_span_with(make_trace_span))
85 .layer(middleware)
86 .layer(connection_middleware),
87 )
88 .layer(axum::Extension(state))
89}
90
91pub fn admin_router(state: Arc<ServerState>) -> Router {
92 admin::router(state)
93}
94
95pub async fn context_middleware<B>(
96 axum::extract::Extension(state): axum::extract::Extension<Arc<ServerState>>,
97 mut req: axum::http::Request<B>,
98 next: middleware::Next<B>,
99) -> Response {
100 let correlation_id =
101 extract_correlation_id(req.headers()).unwrap_or_else(|| Uuid::new_v4().to_string());
102
103 let actor = match state.auth.validate_http(req.headers()) {
104 Ok(actor) => actor,
105 Err(err) => {
106 if state.config.audit_log_enabled {
107 state.audit.log(crate::audit::AuditLogEntry {
108 event_type: crate::audit::AuditEventType::AuthFailure,
109 actor: None,
110 target: "auth".into(),
111 correlation_id: correlation_id.clone(),
112 timestamp: chrono::Utc::now(),
113 details: serde_json::json!({ "error": err.to_string() }),
114 });
115 }
116 return auth_error_response(err, &correlation_id);
117 }
118 };
119
120 req.extensions_mut().insert(RequestContext {
121 correlation_id: correlation_id.clone(),
122 actor,
123 });
124
125 let mut res = next.run(req).await;
126 let _ = res.headers_mut().insert(
127 "x-correlation-id",
128 HeaderValue::from_str(&correlation_id).unwrap_or_else(|_| HeaderValue::from_static("")),
129 );
130 res
131}
132
133pub async fn connection_middleware<B>(
134 axum::extract::Extension(state): axum::extract::Extension<Arc<ServerState>>,
135 req: axum::http::Request<B>,
136 next: middleware::Next<B>,
137) -> Response {
138 state.metrics.record_connection(1);
139 let res = next.run(req).await;
140 state.metrics.record_connection(-1);
141 res
142}
143
144fn auth_error_response(err: AuthError, correlation_id: &str) -> Response {
145 let message = err.to_string();
146 let body = Json(ErrorResponse {
147 error: ErrorBody {
148 code: "UNAUTHORIZED".to_string(),
149 message,
150 correlation_id: correlation_id.to_string(),
151 },
152 });
153 (StatusCode::UNAUTHORIZED, body).into_response()
154}
155
156pub fn error_response(err: ServerError, ctx: &RequestContext) -> Response {
157 let body = Json(ErrorResponse {
158 error: ErrorBody {
159 code: err.error_code(),
160 message: err.to_string(),
161 correlation_id: ctx.correlation_id.clone(),
162 },
163 });
164 (err.status_code(), body).into_response()
165}
166
167fn make_trace_span<B>(request: &axum::http::Request<B>) -> Span {
168 let correlation_id = request
169 .extensions()
170 .get::<RequestContext>()
171 .map(|ctx| ctx.correlation_id.clone())
172 .or_else(|| extract_correlation_id(request.headers()))
173 .unwrap_or_else(|| Uuid::new_v4().to_string());
174 let traceparent = request
175 .headers()
176 .get("traceparent")
177 .and_then(|v| v.to_str().ok())
178 .unwrap_or("");
179 tracing::info_span!(
180 "http_request",
181 correlation_id = %correlation_id,
182 traceparent = %traceparent,
183 method = %request.method(),
184 path = %request.uri().path()
185 )
186}
187
188pub fn json_response<T: Serialize>(value: T, max_size: usize, ctx: &RequestContext) -> Response {
189 match serde_json::to_vec(&value) {
190 Ok(bytes) if bytes.len() <= max_size => (StatusCode::OK, Json(value)).into_response(),
191 Ok(_) => error_response(
192 ServerError::PayloadTooLarge("response size exceeds limit".into()),
193 ctx,
194 ),
195 Err(err) => error_response(ServerError::Internal(err.to_string()), ctx),
196 }
197}
198
199fn extract_correlation_id(headers: &axum::http::HeaderMap) -> Option<String> {
200 headers
201 .get("x-correlation-id")
202 .and_then(|v| v.to_str().ok())
203 .map(|v| v.to_string())
204 .or_else(|| {
205 headers
206 .get("x-request-id")
207 .and_then(|v| v.to_str().ok())
208 .map(|v| v.to_string())
209 })
210}