1use std::collections::HashMap;
12
13use axum::{
14 Json,
15 extract::{Query, State},
16 http::{HeaderMap, StatusCode, header},
17 response::{IntoResponse, Response},
18};
19use reifydb_core::{
20 actors::server::{Operation, ServerAuthResponse, ServerLogoutResponse, ServerMessage},
21 value::frame::response::{ResponseFrame, convert_frames},
22};
23use reifydb_runtime::actor::reply::reply_channel;
24use reifydb_sub_server::{
25 auth::{AuthError, extract_identity_from_auth_header},
26 dispatch::dispatch,
27 interceptor::{Protocol, RequestContext, RequestMetadata},
28 response::resolve_response_json,
29 wire::WireParams,
30};
31use reifydb_type::{params::Params, value::identity::IdentityId};
32use serde::{Deserialize, Serialize};
33
34use crate::{error::AppError, state::HttpServerState};
35
36#[derive(Debug, Deserialize)]
38pub struct StatementRequest {
39 pub statements: Vec<String>,
41 #[serde(default)]
43 pub params: Option<WireParams>,
44}
45
46#[derive(Debug, Serialize)]
48pub struct QueryResponse {
49 pub frames: Vec<ResponseFrame>,
51}
52
53#[derive(Debug, Deserialize)]
55pub struct FormatParams {
56 pub format: Option<String>,
57 pub unwrap: Option<bool>,
58}
59
60#[derive(Debug, Serialize)]
62pub struct HealthResponse {
63 pub status: &'static str,
64}
65
66pub async fn health() -> impl IntoResponse {
77 (
78 StatusCode::OK,
79 Json(HealthResponse {
80 status: "ok",
81 }),
82 )
83}
84
85#[derive(Debug, Serialize)]
87pub struct LogoutResponse {
88 pub status: String,
89}
90
91#[derive(Debug, Deserialize)]
93pub struct AuthenticateRequest {
94 pub method: String,
96 #[serde(default)]
98 pub credentials: HashMap<String, String>,
99}
100
101#[derive(Debug, Serialize)]
103pub struct AuthenticateResponse {
104 pub status: String,
106 #[serde(skip_serializing_if = "Option::is_none")]
108 pub token: Option<String>,
109 #[serde(skip_serializing_if = "Option::is_none")]
111 pub identity: Option<String>,
112 #[serde(skip_serializing_if = "Option::is_none")]
114 pub challenge_id: Option<String>,
115 #[serde(skip_serializing_if = "Option::is_none")]
117 pub payload: Option<HashMap<String, String>>,
118 #[serde(skip_serializing_if = "Option::is_none")]
120 pub reason: Option<String>,
121}
122
123pub async fn handle_authenticate(
124 State(state): State<HttpServerState>,
125 Json(request): Json<AuthenticateRequest>,
126) -> Result<Response, AppError> {
127 let (reply, receiver) = reply_channel();
128 let (actor_ref, _handle) = state.spawn_actor();
129 actor_ref
130 .send(ServerMessage::Authenticate {
131 method: request.method,
132 credentials: request.credentials,
133 reply,
134 })
135 .ok()
136 .ok_or_else(|| AppError::Internal("actor mailbox closed".into()))?;
137
138 let auth_response = receiver.recv().await.map_err(|_| AppError::Internal("actor stopped".into()))?;
139
140 match auth_response {
141 ServerAuthResponse::Authenticated {
142 identity,
143 token,
144 } => Ok((
145 StatusCode::OK,
146 Json(AuthenticateResponse {
147 status: "authenticated".to_string(),
148 token: Some(token),
149 identity: Some(identity.to_string()),
150 challenge_id: None,
151 payload: None,
152 reason: None,
153 }),
154 )
155 .into_response()),
156 ServerAuthResponse::Challenge {
157 challenge_id,
158 payload,
159 } => Ok((
160 StatusCode::OK,
161 Json(AuthenticateResponse {
162 status: "challenge".to_string(),
163 token: None,
164 identity: None,
165 challenge_id: Some(challenge_id),
166 payload: Some(payload),
167 reason: None,
168 }),
169 )
170 .into_response()),
171 ServerAuthResponse::Failed {
172 reason,
173 } => Ok((
174 StatusCode::UNAUTHORIZED,
175 Json(AuthenticateResponse {
176 status: "failed".to_string(),
177 token: None,
178 identity: None,
179 challenge_id: None,
180 payload: None,
181 reason: Some(reason),
182 }),
183 )
184 .into_response()),
185 ServerAuthResponse::Error(reason) => Ok((
186 StatusCode::INTERNAL_SERVER_ERROR,
187 Json(AuthenticateResponse {
188 status: "failed".to_string(),
189 token: None,
190 identity: None,
191 challenge_id: None,
192 payload: None,
193 reason: Some(reason),
194 }),
195 )
196 .into_response()),
197 }
198}
199
200pub async fn handle_logout(State(state): State<HttpServerState>, headers: HeaderMap) -> Result<Response, AppError> {
201 let auth_header = headers.get("authorization").ok_or(AppError::Auth(AuthError::MissingCredentials))?;
202 let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
203 let token = auth_str.strip_prefix("Bearer ").ok_or(AppError::Auth(AuthError::InvalidHeader))?.trim();
204
205 if token.is_empty() {
206 return Err(AppError::Auth(AuthError::InvalidToken));
207 }
208
209 let (reply, receiver) = reply_channel();
210 let (actor_ref, _handle) = state.spawn_actor();
211 actor_ref
212 .send(ServerMessage::Logout {
213 token: token.to_string(),
214 reply,
215 })
216 .ok()
217 .ok_or_else(|| AppError::Internal("actor mailbox closed".into()))?;
218
219 let logout_response = receiver.recv().await.map_err(|_| AppError::Internal("actor stopped".into()))?;
220
221 match logout_response {
222 ServerLogoutResponse::Ok => Ok((
223 StatusCode::OK,
224 Json(LogoutResponse {
225 status: "ok".to_string(),
226 }),
227 )
228 .into_response()),
229 ServerLogoutResponse::InvalidToken => Err(AppError::Auth(AuthError::InvalidToken)),
230 ServerLogoutResponse::Error(reason) => Err(AppError::Internal(reason)),
231 }
232}
233
234fn build_metadata(headers: &HeaderMap) -> RequestMetadata {
236 let mut metadata = RequestMetadata::new(Protocol::Http);
237 for (name, value) in headers.iter() {
238 if let Ok(v) = value.to_str() {
239 metadata.insert(name.as_str(), v);
240 }
241 }
242 metadata
243}
244
245pub async fn handle_query(
247 State(state): State<HttpServerState>,
248 Query(format_params): Query<FormatParams>,
249 headers: HeaderMap,
250 Json(request): Json<StatementRequest>,
251) -> Result<Response, AppError> {
252 execute_and_respond(&state, Operation::Query, &headers, request, &format_params).await
253}
254
255pub async fn handle_admin(
257 State(state): State<HttpServerState>,
258 Query(format_params): Query<FormatParams>,
259 headers: HeaderMap,
260 Json(request): Json<StatementRequest>,
261) -> Result<Response, AppError> {
262 execute_and_respond(&state, Operation::Admin, &headers, request, &format_params).await
263}
264
265pub async fn handle_command(
267 State(state): State<HttpServerState>,
268 Query(format_params): Query<FormatParams>,
269 headers: HeaderMap,
270 Json(request): Json<StatementRequest>,
271) -> Result<Response, AppError> {
272 execute_and_respond(&state, Operation::Command, &headers, request, &format_params).await
273}
274
275async fn execute_and_respond(
281 state: &HttpServerState,
282 operation: Operation,
283 headers: &HeaderMap,
284 request: StatementRequest,
285 format_params: &FormatParams,
286) -> Result<Response, AppError> {
287 let identity = extract_identity(state, headers)?;
288 let metadata = build_metadata(headers);
289 let params = match request.params {
290 None => Params::None,
291 Some(wp) => wp.into_params().map_err(AppError::InvalidParams)?,
292 };
293 let ctx = RequestContext {
294 identity,
295 operation,
296 statements: request.statements,
297 params,
298 metadata,
299 };
300
301 let (frames, wall_duration) = dispatch(state, ctx).await?;
302
303 let mut response = if format_params.format.as_deref() == Some("json") {
305 let resolved = resolve_response_json(frames, format_params.unwrap.unwrap_or(false))
306 .map_err(AppError::BadRequest)?;
307 (StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body).into_response()
308 } else {
309 Json(QueryResponse {
310 frames: convert_frames(&frames),
311 })
312 .into_response()
313 };
314 response.headers_mut().insert("x-duration-ms", wall_duration.as_millis().to_string().parse().unwrap());
315 Ok(response)
316}
317
318fn extract_identity(state: &HttpServerState, headers: &HeaderMap) -> Result<IdentityId, AppError> {
324 if let Some(auth_header) = headers.get("authorization") {
326 let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
327
328 return extract_identity_from_auth_header(state.auth_service(), auth_str).map_err(AppError::Auth);
329 }
330
331 Ok(IdentityId::anonymous())
333}
334
/// Serde round-trip checks for the request/response bodies defined in this file.
#[cfg(test)]
pub mod tests {
    use serde_json::{from_str, to_string};

    use super::*;

    /// A body without `params` parses, leaving `params` as `None`.
    #[test]
    fn test_statement_request_deserialization() {
        let parsed = from_str::<StatementRequest>(r#"{"statements": ["SELECT 1"]}"#).unwrap();
        assert_eq!(parsed.statements, vec!["SELECT 1"]);
        assert!(parsed.params.is_none());
    }

    /// An empty response still serializes its `frames` key.
    #[test]
    fn test_query_response_serialization() {
        let encoded = to_string(&QueryResponse {
            frames: Vec::new(),
        })
        .unwrap();
        assert!(encoded.contains("frames"));
    }

    /// The health body serializes to exactly `{"status":"ok"}`.
    #[test]
    fn test_health_response_serialization() {
        let encoded = to_string(&HealthResponse {
            status: "ok",
        })
        .unwrap();
        assert_eq!(encoded, r#"{"status":"ok"}"#);
    }
}