use std::collections::HashMap;
use axum::{
Json,
extract::{Query, State},
http::{HeaderMap, StatusCode, header},
response::{IntoResponse, Response},
};
use reifydb_auth::service::AuthResponse as EngineAuthResponse;
use reifydb_core::value::frame::response::{ResponseFrame, convert_frames};
use reifydb_sub_server::{
auth::{AuthError, extract_identity_from_auth_header},
execute::execute,
interceptor::{Operation, Protocol, RequestContext, RequestMetadata},
response::resolve_response_json,
state::AppState,
wire::WireParams,
};
use reifydb_type::{params::Params, value::identity::IdentityId};
use serde::{Deserialize, Serialize};
use crate::error::AppError;
/// JSON body accepted by the query, command and admin statement endpoints.
#[derive(Debug, Deserialize)]
pub struct StatementRequest {
/// Statements to execute; moved unchanged into the request context.
pub statements: Vec<String>,
/// Optional wire-encoded parameters. When omitted, the statements run with
/// `Params::None`; when present they are converted via `into_params`, and a
/// conversion failure is reported as an invalid-params error.
#[serde(default)]
pub params: Option<WireParams>,
}
/// Default JSON body for statement results, used unless `?format=json` is requested.
#[derive(Debug, Serialize)]
pub struct QueryResponse {
/// Result frames as produced by `convert_frames`.
pub frames: Vec<ResponseFrame>,
}
/// Query-string options controlling how statement results are rendered.
#[derive(Debug, Deserialize)]
pub struct FormatParams {
/// `Some("json")` selects `resolve_response_json`; any other value (or none)
/// yields the default `QueryResponse` body.
pub format: Option<String>,
/// Flag forwarded to `resolve_response_json`. It is only consulted when
/// `format` is `"json"` and is treated as `false` when omitted.
pub unwrap: Option<bool>,
}
/// JSON body returned by the `health` handler.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
/// Static status string; `health` always sets it to `"ok"`.
pub status: &'static str,
}
/// Liveness endpoint.
///
/// Unconditionally answers `200 OK` with the JSON body `{"status":"ok"}`.
pub async fn health() -> impl IntoResponse {
	let body = HealthResponse {
		status: "ok",
	};
	(StatusCode::OK, Json(body))
}
/// JSON body returned by `handle_logout` after a token has been revoked.
#[derive(Debug, Serialize)]
pub struct LogoutResponse {
/// Set to `"ok"` on the success path.
pub status: String,
}
/// JSON body accepted by `handle_authenticate`.
#[derive(Debug, Deserialize)]
pub struct AuthenticateRequest {
/// Name of the authentication method, passed to the auth service as-is.
pub method: String,
/// Method-specific credential key/value pairs; an empty map when omitted.
#[serde(default)]
pub credentials: HashMap<String, String>,
}
/// JSON body returned by `handle_authenticate`.
///
/// Every `None` field is omitted from the serialized output, so each outcome
/// only carries the fields relevant to it.
#[derive(Debug, Serialize)]
pub struct AuthenticateResponse {
/// `handle_authenticate` sets this to `"authenticated"`, `"challenge"` or `"failed"`.
pub status: String,
/// Session token; set by `handle_authenticate` only for `"authenticated"`.
#[serde(skip_serializing_if = "Option::is_none")]
pub token: Option<String>,
/// `Display` form of the authenticated identity; set only for `"authenticated"`.
#[serde(skip_serializing_if = "Option::is_none")]
pub identity: Option<String>,
/// Challenge identifier; set only for `"challenge"`.
#[serde(skip_serializing_if = "Option::is_none")]
pub challenge_id: Option<String>,
/// Challenge payload; set only for `"challenge"`.
#[serde(skip_serializing_if = "Option::is_none")]
pub payload: Option<HashMap<String, String>>,
/// Failure reason; set only for `"failed"`. It holds either the auth service's
/// reason or the text of an internal error.
#[serde(skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
}
/// Authenticates a caller using the requested method and credentials.
///
/// The auth service's outcome is mapped onto an HTTP response:
/// - authenticated: `200` with `token` and `identity`;
/// - challenge: `200` with `challenge_id` and `payload`;
/// - failed: `401` with `reason`;
/// - service error: `500` with the error text as `reason`.
///
/// This handler never returns `Err`; every outcome is encoded in the response.
pub async fn handle_authenticate(
	State(state): State<AppState>,
	Json(request): Json<AuthenticateRequest>,
) -> Result<Response, AppError> {
	// Both failure paths share one body shape; only the status code differs.
	let failed = |reason: String| AuthenticateResponse {
		status: "failed".to_string(),
		token: None,
		identity: None,
		challenge_id: None,
		payload: None,
		reason: Some(reason),
	};

	let outcome = state.auth_service().authenticate(&request.method, request.credentials);

	let (status_code, body) = match outcome {
		Ok(EngineAuthResponse::Authenticated {
			identity,
			token,
		}) => (
			StatusCode::OK,
			AuthenticateResponse {
				status: "authenticated".to_string(),
				token: Some(token),
				identity: Some(identity.to_string()),
				challenge_id: None,
				payload: None,
				reason: None,
			},
		),
		Ok(EngineAuthResponse::Challenge {
			challenge_id,
			payload,
		}) => (
			StatusCode::OK,
			AuthenticateResponse {
				status: "challenge".to_string(),
				token: None,
				identity: None,
				challenge_id: Some(challenge_id),
				payload: Some(payload),
				reason: None,
			},
		),
		Ok(EngineAuthResponse::Failed {
			reason,
		}) => (StatusCode::UNAUTHORIZED, failed(reason)),
		Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, failed(e.to_string())),
	};

	Ok((status_code, Json(body)).into_response())
}
pub async fn handle_logout(State(state): State<AppState>, headers: HeaderMap) -> Result<Response, AppError> {
let auth_header = headers.get("authorization").ok_or(AppError::Auth(AuthError::MissingCredentials))?;
let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
let token = auth_str.strip_prefix("Bearer ").ok_or(AppError::Auth(AuthError::InvalidHeader))?.trim();
if token.is_empty() {
return Err(AppError::Auth(AuthError::InvalidToken));
}
let revoked = state.auth_service().revoke_token(token);
if revoked {
Ok((
StatusCode::OK,
Json(LogoutResponse {
status: "ok".to_string(),
}),
)
.into_response())
} else {
Err(AppError::Auth(AuthError::InvalidToken))
}
}
/// Builds HTTP request metadata from the incoming headers.
///
/// Each header whose value `HeaderValue::to_str` accepts is inserted under its
/// name. Headers that fail that conversion are silently skipped.
fn build_metadata(headers: &HeaderMap) -> RequestMetadata {
	let mut metadata = RequestMetadata::new(Protocol::Http);
	let readable = headers.iter().filter_map(|(name, value)| Some((name.as_str(), value.to_str().ok()?)));
	for (key, text) in readable {
		metadata.insert(key, text);
	}
	metadata
}
/// HTTP handler that runs the body's statements as an `Operation::Query`.
///
/// Rendering is controlled by the `format` / `unwrap` query-string options.
pub async fn handle_query(
	State(state): State<AppState>,
	Query(format_params): Query<FormatParams>,
	headers: HeaderMap,
	Json(request): Json<StatementRequest>,
) -> Result<Response, AppError> {
	let operation = Operation::Query;
	execute_and_respond(&state, operation, &headers, request, &format_params).await
}
/// HTTP handler that runs the body's statements as an `Operation::Admin`.
///
/// Rendering is controlled by the `format` / `unwrap` query-string options.
pub async fn handle_admin(
	State(state): State<AppState>,
	Query(format_params): Query<FormatParams>,
	headers: HeaderMap,
	Json(request): Json<StatementRequest>,
) -> Result<Response, AppError> {
	let operation = Operation::Admin;
	execute_and_respond(&state, operation, &headers, request, &format_params).await
}
/// HTTP handler that runs the body's statements as an `Operation::Command`.
///
/// Rendering is controlled by the `format` / `unwrap` query-string options.
pub async fn handle_command(
	State(state): State<AppState>,
	Query(format_params): Query<FormatParams>,
	headers: HeaderMap,
	Json(request): Json<StatementRequest>,
) -> Result<Response, AppError> {
	let operation = Operation::Command;
	execute_and_respond(&state, operation, &headers, request, &format_params).await
}
/// Shared implementation behind `handle_query`, `handle_admin` and `handle_command`.
///
/// Steps:
/// 1. Resolve the caller's identity from the `Authorization` header
///    (anonymous when the header is absent).
/// 2. Collect the request headers into metadata.
/// 3. Convert the wire parameters.
/// 4. Run the statements through the interceptor chain and engine, using the
///    configured query timeout.
///
/// The result is rendered with `resolve_response_json` when `?format=json` is
/// given, otherwise as a `QueryResponse`. Execution time is reported in the
/// `x-duration-ms` response header.
///
/// # Errors
///
/// Propagates identity-resolution and execution errors. Returns
/// `AppError::InvalidParams` for unconvertible parameters and
/// `AppError::BadRequest` when JSON rendering fails.
async fn execute_and_respond(
	state: &AppState,
	operation: Operation,
	headers: &HeaderMap,
	request: StatementRequest,
	format_params: &FormatParams,
) -> Result<Response, AppError> {
	let identity = extract_identity(state, headers)?;
	let metadata = build_metadata(headers);
	// Omitted params mean "no parameters"; malformed ones are a client error.
	let params = match request.params {
		None => Params::None,
		Some(wp) => wp.into_params().map_err(|e| AppError::InvalidParams(e))?,
	};
	let ctx = RequestContext {
		identity,
		operation,
		statements: request.statements,
		params,
		metadata,
	};
	let (frames, duration) = execute(
		state.request_interceptors(),
		state.actor_system(),
		state.engine_clone(),
		ctx,
		state.query_timeout(),
		state.clock(),
	)
	.await?;
	let mut response = if format_params.format.as_deref() == Some("json") {
		let resolved = resolve_response_json(frames, format_params.unwrap.unwrap_or(false))
			.map_err(|e| AppError::BadRequest(e))?;
		(StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body).into_response()
	} else {
		Json(QueryResponse {
			frames: convert_frames(&frames),
		})
		.into_response()
	};
	// `HeaderValue: From<u64>` renders the same decimal text as the old
	// string-parse approach, without a panic path. Saturating is harmless:
	// u64::MAX milliseconds is far beyond any real query duration.
	let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
	response.headers_mut().insert("x-duration-ms", header::HeaderValue::from(duration_ms));
	Ok(response)
}
fn extract_identity(state: &AppState, headers: &HeaderMap) -> Result<IdentityId, AppError> {
if let Some(auth_header) = headers.get("authorization") {
let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
return extract_identity_from_auth_header(state.auth_service(), auth_str).map_err(AppError::Auth);
}
Ok(IdentityId::anonymous())
}
/// Unit tests for the JSON wire types used by the HTTP handlers.
///
/// These pin the serde contract (required fields, defaults and omitted
/// optional fields) that HTTP clients depend on.
#[cfg(test)]
pub mod tests {
	use serde_json::{from_str, to_string};

	use super::*;

	#[test]
	fn test_statement_request_deserialization() {
		let json = r#"{"statements": ["SELECT 1"]}"#;
		let request: StatementRequest = from_str(json).unwrap();
		assert_eq!(request.statements, vec!["SELECT 1"]);
		assert!(request.params.is_none());
	}

	#[test]
	fn test_statement_request_requires_statements() {
		let result: Result<StatementRequest, _> = from_str("{}");
		assert!(result.is_err());
	}

	#[test]
	fn test_query_response_serialization() {
		let response = QueryResponse {
			frames: Vec::new(),
		};
		let json = to_string(&response).unwrap();
		assert!(json.contains("frames"));
	}

	#[test]
	fn test_health_response_serialization() {
		let response = HealthResponse {
			status: "ok",
		};
		let json = to_string(&response).unwrap();
		assert_eq!(json, r#"{"status":"ok"}"#);
	}

	#[test]
	fn test_format_params_are_optional() {
		let params: FormatParams = from_str("{}").unwrap();
		assert!(params.format.is_none());
		assert!(params.unwrap.is_none());
	}

	#[test]
	fn test_authenticate_request_credentials_default_to_empty() {
		let request: AuthenticateRequest = from_str(r#"{"method": "token"}"#).unwrap();
		assert_eq!(request.method, "token");
		assert!(request.credentials.is_empty());
	}

	#[test]
	fn test_authenticate_response_omits_absent_fields() {
		let response = AuthenticateResponse {
			status: "authenticated".to_string(),
			token: Some("abc".to_string()),
			identity: None,
			challenge_id: None,
			payload: None,
			reason: None,
		};
		let json = to_string(&response).unwrap();
		assert_eq!(json, r#"{"status":"authenticated","token":"abc"}"#);
	}

	#[test]
	fn test_logout_response_serialization() {
		let response = LogoutResponse {
			status: "ok".to_string(),
		};
		let json = to_string(&response).unwrap();
		assert_eq!(json, r#"{"status":"ok"}"#);
	}
}