// alopex_server/http/mod.rs

1pub mod admin;
2pub mod session;
3pub mod sql;
4pub mod vector;
5
6use std::sync::Arc;
7
8use axum::http::{HeaderValue, StatusCode};
9use axum::middleware;
10use axum::response::{IntoResponse, Response};
11use axum::{Json, Router};
12use serde::Serialize;
13use tower::ServiceBuilder;
14use tower_http::limit::RequestBodyLimitLayer;
15use tower_http::trace::TraceLayer;
16use tracing::Span;
17use uuid::Uuid;
18
19use crate::auth::AuthError;
20use crate::error::ServerError;
21use crate::server::ServerState;
22
/// Per-request metadata inserted into the request extensions by `context_middleware`.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Identifier that ties together the trace span, audit entries, and the
    /// response (header and error bodies) for a single request.
    pub correlation_id: String,
    /// Actor reported by the auth validator for this request, if any.
    pub actor: Option<String>,
}
28
29#[derive(Serialize)]
30struct ErrorBody {
31    code: String,
32    message: String,
33    correlation_id: String,
34}
35
36#[derive(Serialize)]
37struct ErrorResponse {
38    error: ErrorBody,
39}
40
41pub fn router(state: Arc<ServerState>) -> Router {
42    let api = Router::new()
43        .route("/sql", axum::routing::post(sql::handle))
44        .route("/vector/search", axum::routing::post(vector::search))
45        .route("/vector/upsert", axum::routing::post(vector::upsert))
46        .route("/vector/delete", axum::routing::post(vector::delete))
47        .route(
48            "/vector/index/create",
49            axum::routing::post(vector::index_create),
50        )
51        .route(
52            "/vector/index/update",
53            axum::routing::post(vector::index_update),
54        )
55        .route(
56            "/vector/index/delete",
57            axum::routing::post(vector::index_delete),
58        )
59        .route(
60            "/vector/index/compact",
61            axum::routing::post(vector::index_compact),
62        )
63        .route("/session/begin", axum::routing::post(session::begin))
64        .route("/session/:id/commit", axum::routing::post(session::commit))
65        .route(
66            "/session/:id/rollback",
67            axum::routing::post(session::rollback),
68        );
69
70    let api = if state.config.api_prefix.is_empty() {
71        api
72    } else {
73        Router::new().nest(&state.config.api_prefix, api)
74    };
75
76    let middleware = middleware::from_fn(context_middleware);
77    let connection_middleware = middleware::from_fn(connection_middleware);
78    api.layer(
79        ServiceBuilder::new()
80            .layer(RequestBodyLimitLayer::new(state.config.max_request_size))
81            .layer(tower::limit::ConcurrencyLimitLayer::new(
82                state.config.max_connections,
83            ))
84            .layer(TraceLayer::new_for_http().make_span_with(make_trace_span))
85            .layer(middleware)
86            .layer(connection_middleware),
87    )
88    .layer(axum::Extension(state))
89}
90
91pub fn admin_router(state: Arc<ServerState>) -> Router {
92    admin::router(state)
93}
94
95pub async fn context_middleware<B>(
96    axum::extract::Extension(state): axum::extract::Extension<Arc<ServerState>>,
97    mut req: axum::http::Request<B>,
98    next: middleware::Next<B>,
99) -> Response {
100    let correlation_id =
101        extract_correlation_id(req.headers()).unwrap_or_else(|| Uuid::new_v4().to_string());
102
103    let actor = match state.auth.validate_http(req.headers()) {
104        Ok(actor) => actor,
105        Err(err) => {
106            if state.config.audit_log_enabled {
107                state.audit.log(crate::audit::AuditLogEntry {
108                    event_type: crate::audit::AuditEventType::AuthFailure,
109                    actor: None,
110                    target: "auth".into(),
111                    correlation_id: correlation_id.clone(),
112                    timestamp: chrono::Utc::now(),
113                    details: serde_json::json!({ "error": err.to_string() }),
114                });
115            }
116            return auth_error_response(err, &correlation_id);
117        }
118    };
119
120    req.extensions_mut().insert(RequestContext {
121        correlation_id: correlation_id.clone(),
122        actor,
123    });
124
125    let mut res = next.run(req).await;
126    let _ = res.headers_mut().insert(
127        "x-correlation-id",
128        HeaderValue::from_str(&correlation_id).unwrap_or_else(|_| HeaderValue::from_static("")),
129    );
130    res
131}
132
133pub async fn connection_middleware<B>(
134    axum::extract::Extension(state): axum::extract::Extension<Arc<ServerState>>,
135    req: axum::http::Request<B>,
136    next: middleware::Next<B>,
137) -> Response {
138    state.metrics.record_connection(1);
139    let res = next.run(req).await;
140    state.metrics.record_connection(-1);
141    res
142}
143
144fn auth_error_response(err: AuthError, correlation_id: &str) -> Response {
145    let message = err.to_string();
146    let body = Json(ErrorResponse {
147        error: ErrorBody {
148            code: "UNAUTHORIZED".to_string(),
149            message,
150            correlation_id: correlation_id.to_string(),
151        },
152    });
153    (StatusCode::UNAUTHORIZED, body).into_response()
154}
155
156pub fn error_response(err: ServerError, ctx: &RequestContext) -> Response {
157    let body = Json(ErrorResponse {
158        error: ErrorBody {
159            code: err.error_code(),
160            message: err.to_string(),
161            correlation_id: ctx.correlation_id.clone(),
162        },
163    });
164    (err.status_code(), body).into_response()
165}
166
167fn make_trace_span<B>(request: &axum::http::Request<B>) -> Span {
168    let correlation_id = request
169        .extensions()
170        .get::<RequestContext>()
171        .map(|ctx| ctx.correlation_id.clone())
172        .or_else(|| extract_correlation_id(request.headers()))
173        .unwrap_or_else(|| Uuid::new_v4().to_string());
174    let traceparent = request
175        .headers()
176        .get("traceparent")
177        .and_then(|v| v.to_str().ok())
178        .unwrap_or("");
179    tracing::info_span!(
180        "http_request",
181        correlation_id = %correlation_id,
182        traceparent = %traceparent,
183        method = %request.method(),
184        path = %request.uri().path()
185    )
186}
187
188pub fn json_response<T: Serialize>(value: T, max_size: usize, ctx: &RequestContext) -> Response {
189    match serde_json::to_vec(&value) {
190        Ok(bytes) if bytes.len() <= max_size => (StatusCode::OK, Json(value)).into_response(),
191        Ok(_) => error_response(
192            ServerError::PayloadTooLarge("response size exceeds limit".into()),
193            ctx,
194        ),
195        Err(err) => error_response(ServerError::Internal(err.to_string()), ctx),
196    }
197}
198
199fn extract_correlation_id(headers: &axum::http::HeaderMap) -> Option<String> {
200    headers
201        .get("x-correlation-id")
202        .and_then(|v| v.to_str().ok())
203        .map(|v| v.to_string())
204        .or_else(|| {
205            headers
206                .get("x-request-id")
207                .and_then(|v| v.to_str().ok())
208                .map(|v| v.to_string())
209        })
210}