1use std::collections::HashMap;
12
13use axum::{
14 Json,
15 extract::{Query, State},
16 http::{HeaderMap, HeaderValue, StatusCode, header},
17 response::{IntoResponse, Response},
18};
19use reifydb_core::actors::server::{Operation, ServerAuthResponse, ServerLogoutResponse, ServerMessage};
20use reifydb_runtime::actor::reply::reply_channel;
21use reifydb_sub_server::{
22 auth::{AuthError, extract_identity_from_auth_header},
23 dispatch::dispatch,
24 format::WireFormat,
25 interceptor::{Protocol, RequestContext, RequestMetadata},
26 response::{encode_frames_rbcf, resolve_response_json},
27 wire::WireParams,
28};
29use reifydb_type::{
30 params::Params,
31 value::{duration::Duration as ReifyDuration, identity::IdentityId},
32};
33use reifydb_wire_format::json::{to::convert_frames, types::ResponseFrame};
34use serde::{Deserialize, Serialize};
35
36use crate::{error::AppError, state::HttpServerState};
37
/// `Content-Type` sent with responses encoded in the binary RBCF wire format.
const CONTENT_TYPE_RBCF: &str = "application/vnd.reifydb.rbcf";
39
40#[derive(Debug, Deserialize)]
42pub struct StatementRequest {
43 pub rql: String,
45 #[serde(default)]
47 pub params: Option<WireParams>,
48}
49
50#[derive(Debug, Serialize)]
52pub struct QueryResponse {
53 pub frames: Vec<ResponseFrame>,
55}
56
57#[derive(Debug, Deserialize)]
59pub struct FormatParams {
60 #[serde(default)]
61 pub format: WireFormat,
62 pub unwrap: Option<bool>,
63}
64
65#[derive(Debug, Serialize)]
67pub struct HealthResponse {
68 pub status: &'static str,
69}
70
71pub async fn health() -> impl IntoResponse {
82 (
83 StatusCode::OK,
84 Json(HealthResponse {
85 status: "ok",
86 }),
87 )
88}
89
90#[derive(Debug, Serialize)]
92pub struct LogoutResponse {
93 pub status: String,
94}
95
96#[derive(Debug, Deserialize)]
98pub struct AuthenticateRequest {
99 pub method: String,
101 #[serde(default)]
103 pub credentials: HashMap<String, String>,
104}
105
106#[derive(Debug, Serialize)]
108pub struct AuthenticateResponse {
109 pub status: String,
111 #[serde(skip_serializing_if = "Option::is_none")]
113 pub token: Option<String>,
114 #[serde(skip_serializing_if = "Option::is_none")]
116 pub identity: Option<String>,
117 #[serde(skip_serializing_if = "Option::is_none")]
119 pub challenge_id: Option<String>,
120 #[serde(skip_serializing_if = "Option::is_none")]
122 pub payload: Option<HashMap<String, String>>,
123 #[serde(skip_serializing_if = "Option::is_none")]
125 pub reason: Option<String>,
126}
127
128pub async fn handle_authenticate(
129 State(state): State<HttpServerState>,
130 Json(request): Json<AuthenticateRequest>,
131) -> Result<Response, AppError> {
132 let (reply, receiver) = reply_channel();
133 let (actor_ref, _handle) = state.spawn_actor();
134 actor_ref
135 .send(ServerMessage::Authenticate {
136 method: request.method,
137 credentials: request.credentials,
138 reply,
139 })
140 .ok()
141 .ok_or_else(|| AppError::Internal("actor mailbox closed".into()))?;
142
143 let auth_response = receiver.recv().await.map_err(|_| AppError::Internal("actor stopped".into()))?;
144
145 match auth_response {
146 ServerAuthResponse::Authenticated {
147 identity,
148 token,
149 } => Ok((
150 StatusCode::OK,
151 Json(AuthenticateResponse {
152 status: "authenticated".to_string(),
153 token: Some(token),
154 identity: Some(identity.to_string()),
155 challenge_id: None,
156 payload: None,
157 reason: None,
158 }),
159 )
160 .into_response()),
161 ServerAuthResponse::Challenge {
162 challenge_id,
163 payload,
164 } => Ok((
165 StatusCode::OK,
166 Json(AuthenticateResponse {
167 status: "challenge".to_string(),
168 token: None,
169 identity: None,
170 challenge_id: Some(challenge_id),
171 payload: Some(payload),
172 reason: None,
173 }),
174 )
175 .into_response()),
176 ServerAuthResponse::Failed {
177 reason,
178 } => Ok((
179 StatusCode::UNAUTHORIZED,
180 Json(AuthenticateResponse {
181 status: "failed".to_string(),
182 token: None,
183 identity: None,
184 challenge_id: None,
185 payload: None,
186 reason: Some(reason),
187 }),
188 )
189 .into_response()),
190 ServerAuthResponse::Error(reason) => Ok((
191 StatusCode::INTERNAL_SERVER_ERROR,
192 Json(AuthenticateResponse {
193 status: "failed".to_string(),
194 token: None,
195 identity: None,
196 challenge_id: None,
197 payload: None,
198 reason: Some(reason),
199 }),
200 )
201 .into_response()),
202 }
203}
204
205pub async fn handle_logout(State(state): State<HttpServerState>, headers: HeaderMap) -> Result<Response, AppError> {
206 let auth_header = headers.get("authorization").ok_or(AppError::Auth(AuthError::MissingCredentials))?;
207 let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
208 let token = auth_str.strip_prefix("Bearer ").ok_or(AppError::Auth(AuthError::InvalidHeader))?.trim();
209
210 if token.is_empty() {
211 return Err(AppError::Auth(AuthError::InvalidToken));
212 }
213
214 let (reply, receiver) = reply_channel();
215 let (actor_ref, _handle) = state.spawn_actor();
216 actor_ref
217 .send(ServerMessage::Logout {
218 token: token.to_string(),
219 reply,
220 })
221 .ok()
222 .ok_or_else(|| AppError::Internal("actor mailbox closed".into()))?;
223
224 let logout_response = receiver.recv().await.map_err(|_| AppError::Internal("actor stopped".into()))?;
225
226 match logout_response {
227 ServerLogoutResponse::Ok => Ok((
228 StatusCode::OK,
229 Json(LogoutResponse {
230 status: "ok".to_string(),
231 }),
232 )
233 .into_response()),
234 ServerLogoutResponse::InvalidToken => Err(AppError::Auth(AuthError::InvalidToken)),
235 ServerLogoutResponse::Error(reason) => Err(AppError::Internal(reason)),
236 }
237}
238
239fn build_metadata(headers: &HeaderMap) -> RequestMetadata {
241 let mut metadata = RequestMetadata::new(Protocol::Http);
242 for (name, value) in headers.iter() {
243 if let Ok(v) = value.to_str() {
244 metadata.insert(name.as_str(), v);
245 }
246 }
247 metadata
248}
249
250pub async fn handle_query(
252 State(state): State<HttpServerState>,
253 Query(format_params): Query<FormatParams>,
254 headers: HeaderMap,
255 Json(request): Json<StatementRequest>,
256) -> Result<Response, AppError> {
257 execute_and_respond(&state, Operation::Query, &headers, request, &format_params).await
258}
259
260pub async fn handle_admin(
262 State(state): State<HttpServerState>,
263 Query(format_params): Query<FormatParams>,
264 headers: HeaderMap,
265 Json(request): Json<StatementRequest>,
266) -> Result<Response, AppError> {
267 execute_and_respond(&state, Operation::Admin, &headers, request, &format_params).await
268}
269
270pub async fn handle_command(
272 State(state): State<HttpServerState>,
273 Query(format_params): Query<FormatParams>,
274 headers: HeaderMap,
275 Json(request): Json<StatementRequest>,
276) -> Result<Response, AppError> {
277 execute_and_respond(&state, Operation::Command, &headers, request, &format_params).await
278}
279
280async fn execute_and_respond(
286 state: &HttpServerState,
287 operation: Operation,
288 headers: &HeaderMap,
289 request: StatementRequest,
290 format_params: &FormatParams,
291) -> Result<Response, AppError> {
292 let identity = extract_identity(state, headers)?;
293 let metadata = build_metadata(headers);
294 let params = match request.params {
295 None => Params::None,
296 Some(wp) => wp.into_params().map_err(AppError::InvalidParams)?,
297 };
298 let ctx = RequestContext {
299 identity,
300 operation,
301 rql: request.rql,
302 params,
303 metadata,
304 };
305
306 let (frames, wall_duration, metrics) = dispatch(state, ctx).await?;
307
308 let mut response = match format_params.format {
309 WireFormat::Rbcf => match encode_frames_rbcf(&frames) {
310 Ok(bytes) => (StatusCode::OK, [(header::CONTENT_TYPE, CONTENT_TYPE_RBCF.to_string())], bytes)
311 .into_response(),
312 Err(e) => return Err(AppError::BadRequest(format!("RBCF encode error: {}", e))),
313 },
314 WireFormat::Json => {
315 let resolved = resolve_response_json(frames, format_params.unwrap.unwrap_or(false))
316 .map_err(AppError::BadRequest)?;
317 (StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body).into_response()
318 }
319 WireFormat::Frames => Json(QueryResponse {
320 frames: convert_frames(&frames),
321 })
322 .into_response(),
323 };
324 let duration = ReifyDuration::from_nanoseconds(wall_duration.as_nanos() as i64).unwrap_or_default();
325 response.headers_mut().insert("x-fingerprint", HeaderValue::from_str(&metrics.fingerprint.to_hex()).unwrap());
326 response.headers_mut().insert("x-duration", HeaderValue::from_str(&duration.to_string()).unwrap());
327 Ok(response)
328}
329
330fn extract_identity(state: &HttpServerState, headers: &HeaderMap) -> Result<IdentityId, AppError> {
336 if let Some(auth_header) = headers.get("authorization") {
338 let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
339
340 return extract_identity_from_auth_header(state.auth_service(), auth_str).map_err(AppError::Auth);
341 }
342
343 Ok(IdentityId::anonymous())
345}
346
#[cfg(test)]
pub mod tests {
    use serde_json::{from_str, to_string};

    use super::*;

    #[test]
    fn test_statement_request_deserialization() {
        let json = r#"{"rql": "SELECT 1"}"#;
        let request: StatementRequest = from_str(json).unwrap();
        assert_eq!(request.rql, "SELECT 1");
        assert!(request.params.is_none());
    }

    #[test]
    fn test_query_response_serialization() {
        let response = QueryResponse {
            frames: Vec::new(),
        };
        let json = to_string(&response).unwrap();
        assert!(json.contains("frames"));
    }

    #[test]
    fn test_health_response_serialization() {
        let response = HealthResponse {
            status: "ok",
        };
        let json = to_string(&response).unwrap();
        assert_eq!(json, r#"{"status":"ok"}"#);
    }

    #[test]
    fn test_logout_response_serialization() {
        let response = LogoutResponse {
            status: "ok".to_string(),
        };
        assert_eq!(to_string(&response).unwrap(), r#"{"status":"ok"}"#);
    }

    #[test]
    fn test_format_params_unwrap_flag() {
        let params: FormatParams = from_str(r#"{"unwrap": true}"#).unwrap();
        assert_eq!(params.unwrap, Some(true));

        let params: FormatParams = from_str("{}").unwrap();
        assert_eq!(params.unwrap, None);
    }

    #[test]
    fn test_authenticate_request_credentials_default_to_empty() {
        let request: AuthenticateRequest = from_str(r#"{"method": "token"}"#).unwrap();
        assert_eq!(request.method, "token");
        assert!(request.credentials.is_empty());
    }

    #[test]
    fn test_authenticate_response_omits_absent_fields() {
        let response = AuthenticateResponse {
            status: "failed".to_string(),
            token: None,
            identity: None,
            challenge_id: None,
            payload: None,
            reason: Some("bad password".to_string()),
        };
        let json = to_string(&response).unwrap();
        assert_eq!(json, r#"{"status":"failed","reason":"bad password"}"#);
    }

    #[test]
    fn test_authenticate_response_challenge_serialization() {
        let mut payload = HashMap::new();
        payload.insert("nonce".to_string(), "abc".to_string());
        let response = AuthenticateResponse {
            status: "challenge".to_string(),
            token: None,
            identity: None,
            challenge_id: Some("c1".to_string()),
            payload: Some(payload),
            reason: None,
        };
        let json = to_string(&response).unwrap();
        assert_eq!(json, r#"{"status":"challenge","challenge_id":"c1","payload":{"nonce":"abc"}}"#);
    }
}