1use std::collections::HashMap;
12
13use axum::{
14 Json,
15 extract::{Query, State},
16 http::{HeaderMap, StatusCode, header},
17 response::{IntoResponse, Response},
18};
19use reifydb_auth::service::AuthResponse as EngineAuthResponse;
20use reifydb_core::value::frame::response::{ResponseFrame, convert_frames};
21use reifydb_sub_server::{
22 auth::{AuthError, extract_identity_from_auth_header},
23 execute::execute,
24 interceptor::{Operation, Protocol, RequestContext, RequestMetadata},
25 response::resolve_response_json,
26 state::AppState,
27 wire::WireParams,
28};
29use reifydb_type::{params::Params, value::identity::IdentityId};
30use serde::{Deserialize, Serialize};
31
32use crate::error::AppError;
33
34#[derive(Debug, Deserialize)]
36pub struct StatementRequest {
37 pub statements: Vec<String>,
39 #[serde(default)]
41 pub params: Option<WireParams>,
42}
43
44#[derive(Debug, Serialize)]
46pub struct QueryResponse {
47 pub frames: Vec<ResponseFrame>,
49}
50
51#[derive(Debug, Deserialize)]
53pub struct FormatParams {
54 pub format: Option<String>,
55 pub unwrap: Option<bool>,
56}
57
58#[derive(Debug, Serialize)]
60pub struct HealthResponse {
61 pub status: &'static str,
62}
63
64pub async fn health() -> impl IntoResponse {
75 (
76 StatusCode::OK,
77 Json(HealthResponse {
78 status: "ok",
79 }),
80 )
81}
82
83#[derive(Debug, Serialize)]
85pub struct LogoutResponse {
86 pub status: String,
87}
88
89#[derive(Debug, Deserialize)]
91pub struct AuthenticateRequest {
92 pub method: String,
94 #[serde(default)]
96 pub credentials: HashMap<String, String>,
97}
98
99#[derive(Debug, Serialize)]
101pub struct AuthenticateResponse {
102 pub status: String,
104 #[serde(skip_serializing_if = "Option::is_none")]
106 pub token: Option<String>,
107 #[serde(skip_serializing_if = "Option::is_none")]
109 pub identity: Option<String>,
110 #[serde(skip_serializing_if = "Option::is_none")]
112 pub challenge_id: Option<String>,
113 #[serde(skip_serializing_if = "Option::is_none")]
115 pub payload: Option<HashMap<String, String>>,
116 #[serde(skip_serializing_if = "Option::is_none")]
118 pub reason: Option<String>,
119}
120
121pub async fn handle_authenticate(
122 State(state): State<AppState>,
123 Json(request): Json<AuthenticateRequest>,
124) -> Result<Response, AppError> {
125 match state.auth_service().authenticate(&request.method, request.credentials) {
126 Ok(EngineAuthResponse::Authenticated {
127 identity,
128 token,
129 }) => Ok((
130 StatusCode::OK,
131 Json(AuthenticateResponse {
132 status: "authenticated".to_string(),
133 token: Some(token),
134 identity: Some(identity.to_string()),
135 challenge_id: None,
136 payload: None,
137 reason: None,
138 }),
139 )
140 .into_response()),
141 Ok(EngineAuthResponse::Challenge {
142 challenge_id,
143 payload,
144 }) => Ok((
145 StatusCode::OK,
146 Json(AuthenticateResponse {
147 status: "challenge".to_string(),
148 token: None,
149 identity: None,
150 challenge_id: Some(challenge_id),
151 payload: Some(payload),
152 reason: None,
153 }),
154 )
155 .into_response()),
156 Ok(EngineAuthResponse::Failed {
157 reason,
158 }) => Ok((
159 StatusCode::UNAUTHORIZED,
160 Json(AuthenticateResponse {
161 status: "failed".to_string(),
162 token: None,
163 identity: None,
164 challenge_id: None,
165 payload: None,
166 reason: Some(reason),
167 }),
168 )
169 .into_response()),
170 Err(e) => Ok((
171 StatusCode::INTERNAL_SERVER_ERROR,
172 Json(AuthenticateResponse {
173 status: "failed".to_string(),
174 token: None,
175 identity: None,
176 challenge_id: None,
177 payload: None,
178 reason: Some(e.to_string()),
179 }),
180 )
181 .into_response()),
182 }
183}
184
185pub async fn handle_logout(State(state): State<AppState>, headers: HeaderMap) -> Result<Response, AppError> {
186 let auth_header = headers.get("authorization").ok_or(AppError::Auth(AuthError::MissingCredentials))?;
187 let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
188 let token = auth_str.strip_prefix("Bearer ").ok_or(AppError::Auth(AuthError::InvalidHeader))?.trim();
189
190 if token.is_empty() {
191 return Err(AppError::Auth(AuthError::InvalidToken));
192 }
193
194 let revoked = state.auth_service().revoke_token(token);
195
196 if revoked {
197 Ok((
198 StatusCode::OK,
199 Json(LogoutResponse {
200 status: "ok".to_string(),
201 }),
202 )
203 .into_response())
204 } else {
205 Err(AppError::Auth(AuthError::InvalidToken))
206 }
207}
208
209fn build_metadata(headers: &HeaderMap) -> RequestMetadata {
211 let mut metadata = RequestMetadata::new(Protocol::Http);
212 for (name, value) in headers.iter() {
213 if let Ok(v) = value.to_str() {
214 metadata.insert(name.as_str(), v);
215 }
216 }
217 metadata
218}
219
220pub async fn handle_query(
245 State(state): State<AppState>,
246 Query(format_params): Query<FormatParams>,
247 headers: HeaderMap,
248 Json(request): Json<StatementRequest>,
249) -> Result<Response, AppError> {
250 execute_and_respond(&state, Operation::Query, &headers, request, &format_params).await
251}
252
253pub async fn handle_admin(
264 State(state): State<AppState>,
265 Query(format_params): Query<FormatParams>,
266 headers: HeaderMap,
267 Json(request): Json<StatementRequest>,
268) -> Result<Response, AppError> {
269 execute_and_respond(&state, Operation::Admin, &headers, request, &format_params).await
270}
271
272pub async fn handle_command(
282 State(state): State<AppState>,
283 Query(format_params): Query<FormatParams>,
284 headers: HeaderMap,
285 Json(request): Json<StatementRequest>,
286) -> Result<Response, AppError> {
287 execute_and_respond(&state, Operation::Command, &headers, request, &format_params).await
288}
289
290async fn execute_and_respond(
292 state: &AppState,
293 operation: Operation,
294 headers: &HeaderMap,
295 request: StatementRequest,
296 format_params: &FormatParams,
297) -> Result<Response, AppError> {
298 let identity = extract_identity(state, headers)?;
299 let metadata = build_metadata(headers);
300 let params = match request.params {
301 None => Params::None,
302 Some(wp) => wp.into_params().map_err(|e| AppError::InvalidParams(e))?,
303 };
304
305 let ctx = RequestContext {
306 identity,
307 operation,
308 statements: request.statements,
309 params,
310 metadata,
311 };
312
313 let (frames, duration) = execute(
314 state.request_interceptors(),
315 state.actor_system(),
316 state.engine_clone(),
317 ctx,
318 state.query_timeout(),
319 state.clock(),
320 )
321 .await?;
322
323 let mut response = if format_params.format.as_deref() == Some("json") {
324 let resolved = resolve_response_json(frames, format_params.unwrap.unwrap_or(false))
325 .map_err(|e| AppError::BadRequest(e))?;
326 (StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body).into_response()
327 } else {
328 Json(QueryResponse {
329 frames: convert_frames(&frames),
330 })
331 .into_response()
332 };
333 response.headers_mut().insert("x-duration-ms", duration.as_millis().to_string().parse().unwrap());
334 Ok(response)
335}
336
337fn extract_identity(state: &AppState, headers: &HeaderMap) -> Result<IdentityId, AppError> {
343 if let Some(auth_header) = headers.get("authorization") {
345 let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
346
347 return extract_identity_from_auth_header(state.auth_service(), auth_str).map_err(AppError::Auth);
348 }
349
350 Ok(IdentityId::anonymous())
352}
353
#[cfg(test)]
pub mod tests {
	use serde_json::{from_str, to_string};

	use super::*;

	/// A body carrying only `statements` parses and leaves `params` unset.
	#[test]
	fn test_statement_request_deserialization() {
		let parsed: StatementRequest = from_str(r#"{"statements": ["SELECT 1"]}"#).unwrap();
		assert!(parsed.params.is_none());
		assert_eq!(parsed.statements, ["SELECT 1"]);
	}

	/// An empty result still serializes with a `frames` key.
	#[test]
	fn test_query_response_serialization() {
		let rendered = to_string(&QueryResponse {
			frames: vec![],
		})
		.unwrap();
		assert!(rendered.contains("frames"));
	}

	/// The health body serializes to exactly `{"status":"ok"}`.
	#[test]
	fn test_health_response_serialization() {
		let rendered = to_string(&HealthResponse {
			status: "ok",
		})
		.unwrap();
		assert_eq!(rendered, r#"{"status":"ok"}"#);
	}
}