1use std::{collections::HashMap, sync::Arc};
12
13use axum::{
14 Json,
15 extract::{Path, Query, State},
16 http::{HeaderMap, HeaderValue, Method, StatusCode, header},
17 response::{IntoResponse, Response},
18};
19use reifydb_core::{
20 actors::server::{Operation, ServerAuthResponse, ServerLogoutResponse, ServerMessage},
21 interface::catalog::binding::{Binding, BindingFormat, BindingProtocol, HttpMethod},
22 metric::ExecutionMetrics,
23};
24use reifydb_runtime::actor::reply::reply_channel;
25use reifydb_sub_server::{
26 auth::{AuthError, extract_identity_from_auth_header},
27 binding::dispatch_binding,
28 dispatch::dispatch,
29 format::WireFormat,
30 interceptor::{Protocol, RequestContext, RequestMetadata},
31 response::{CONTENT_TYPE_FRAMES, CONTENT_TYPE_RBCF, encode_frames_rbcf, resolve_response_json},
32 wire::WireParams,
33};
34use reifydb_type::{
35 params::Params,
36 value::{Value, frame::frame::Frame, identity::IdentityId, r#type::Type},
37};
38use reifydb_wire_format::json::{to::convert_frames, types::ResponseFrame};
39use serde::{Deserialize, Serialize};
40use serde_json::to_string;
41
42use crate::{error::AppError, state::HttpServerState};
43
/// JSON request body accepted by the `handle_query`, `handle_admin` and
/// `handle_command` endpoints.
#[derive(Debug, Deserialize)]
pub struct StatementRequest {
	/// RQL statement text; moved unchanged into the `RequestContext` that is
	/// dispatched.
	pub rql: String,
	/// Optional statement parameters. When absent (or `null`) the statement
	/// runs with `Params::None`.
	#[serde(default)]
	pub params: Option<WireParams>,
}
53
/// Response body for the `Frames` output format: the result frames converted
/// to their JSON wire representation.
#[derive(Debug, Serialize)]
pub struct QueryResponse {
	/// Result frames as produced by `convert_frames`.
	pub frames: Vec<ResponseFrame>,
}
60
/// Query-string options that select how statement results are encoded.
#[derive(Debug, Deserialize)]
pub struct FormatParams {
	/// Response encoding; `WireFormat::default()` when the parameter is absent.
	#[serde(default)]
	pub format: WireFormat,
	/// Forwarded to `resolve_response_json` for the JSON format; treated as
	/// `false` when absent.
	pub unwrap: Option<bool>,
}
68
/// Body returned by the `health` endpoint.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
	/// Health indicator; `health` always sets this to `"ok"`.
	pub status: &'static str,
}
74
75pub async fn health() -> impl IntoResponse {
86 (
87 StatusCode::OK,
88 Json(HealthResponse {
89 status: "ok",
90 }),
91 )
92}
93
/// Body returned by `handle_logout` when the token was revoked.
#[derive(Debug, Serialize)]
pub struct LogoutResponse {
	/// Outcome marker; `handle_logout` sets it to `"ok"` on success.
	pub status: String,
}
99
/// JSON request body for `handle_authenticate`.
#[derive(Debug, Deserialize)]
pub struct AuthenticateRequest {
	/// Authentication method name, forwarded unchanged to the server actor.
	pub method: String,
	/// Method-specific credential fields; an empty map when omitted.
	#[serde(default)]
	pub credentials: HashMap<String, String>,
}
109
/// JSON response body for `handle_authenticate`.
///
/// Every optional field is omitted from the JSON when `None`, so each outcome
/// only carries the fields relevant to it.
#[derive(Debug, Serialize)]
pub struct AuthenticateResponse {
	/// Outcome: `"authenticated"`, `"challenge"` or `"failed"` as set by
	/// `handle_authenticate`.
	pub status: String,
	/// Session token, present when authenticated.
	#[serde(skip_serializing_if = "Option::is_none")]
	pub token: Option<String>,
	/// String form of the authenticated identity, present when authenticated.
	#[serde(skip_serializing_if = "Option::is_none")]
	pub identity: Option<String>,
	/// Identifier of a pending challenge, present for `"challenge"`.
	#[serde(skip_serializing_if = "Option::is_none")]
	pub challenge_id: Option<String>,
	/// Challenge data for the client, present for `"challenge"`.
	#[serde(skip_serializing_if = "Option::is_none")]
	pub payload: Option<HashMap<String, String>>,
	/// Failure explanation, present for `"failed"`.
	#[serde(skip_serializing_if = "Option::is_none")]
	pub reason: Option<String>,
}
131
132pub async fn handle_authenticate(
133 State(state): State<HttpServerState>,
134 Json(request): Json<AuthenticateRequest>,
135) -> Result<Response, AppError> {
136 let (reply, receiver) = reply_channel();
137 let (actor_ref, _handle) = state.spawn_actor();
138 actor_ref
139 .send(ServerMessage::Authenticate {
140 method: request.method,
141 credentials: request.credentials,
142 reply,
143 })
144 .ok()
145 .ok_or_else(|| AppError::Internal("actor mailbox closed".into()))?;
146
147 let auth_response = receiver.recv().await.map_err(|_| AppError::Internal("actor stopped".into()))?;
148
149 match auth_response {
150 ServerAuthResponse::Authenticated {
151 identity,
152 token,
153 } => Ok((
154 StatusCode::OK,
155 Json(AuthenticateResponse {
156 status: "authenticated".to_string(),
157 token: Some(token),
158 identity: Some(identity.to_string()),
159 challenge_id: None,
160 payload: None,
161 reason: None,
162 }),
163 )
164 .into_response()),
165 ServerAuthResponse::Challenge {
166 challenge_id,
167 payload,
168 } => Ok((
169 StatusCode::OK,
170 Json(AuthenticateResponse {
171 status: "challenge".to_string(),
172 token: None,
173 identity: None,
174 challenge_id: Some(challenge_id),
175 payload: Some(payload),
176 reason: None,
177 }),
178 )
179 .into_response()),
180 ServerAuthResponse::Failed {
181 reason,
182 } => Ok((
183 StatusCode::UNAUTHORIZED,
184 Json(AuthenticateResponse {
185 status: "failed".to_string(),
186 token: None,
187 identity: None,
188 challenge_id: None,
189 payload: None,
190 reason: Some(reason),
191 }),
192 )
193 .into_response()),
194 ServerAuthResponse::Error(reason) => Ok((
195 StatusCode::INTERNAL_SERVER_ERROR,
196 Json(AuthenticateResponse {
197 status: "failed".to_string(),
198 token: None,
199 identity: None,
200 challenge_id: None,
201 payload: None,
202 reason: Some(reason),
203 }),
204 )
205 .into_response()),
206 }
207}
208
209pub async fn handle_logout(State(state): State<HttpServerState>, headers: HeaderMap) -> Result<Response, AppError> {
210 let auth_header = headers.get("authorization").ok_or(AppError::Auth(AuthError::MissingCredentials))?;
211 let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
212 let token = auth_str.strip_prefix("Bearer ").ok_or(AppError::Auth(AuthError::InvalidHeader))?.trim();
213
214 if token.is_empty() {
215 return Err(AppError::Auth(AuthError::InvalidToken));
216 }
217
218 let (reply, receiver) = reply_channel();
219 let (actor_ref, _handle) = state.spawn_actor();
220 actor_ref
221 .send(ServerMessage::Logout {
222 token: token.to_string(),
223 reply,
224 })
225 .ok()
226 .ok_or_else(|| AppError::Internal("actor mailbox closed".into()))?;
227
228 let logout_response = receiver.recv().await.map_err(|_| AppError::Internal("actor stopped".into()))?;
229
230 match logout_response {
231 ServerLogoutResponse::Ok => Ok((
232 StatusCode::OK,
233 Json(LogoutResponse {
234 status: "ok".to_string(),
235 }),
236 )
237 .into_response()),
238 ServerLogoutResponse::InvalidToken => Err(AppError::Auth(AuthError::InvalidToken)),
239 ServerLogoutResponse::Error(reason) => Err(AppError::Internal(reason)),
240 }
241}
242
243fn build_metadata(headers: &HeaderMap) -> RequestMetadata {
245 let mut metadata = RequestMetadata::new(Protocol::Http);
246 for (name, value) in headers.iter() {
247 if let Ok(v) = value.to_str() {
248 metadata.insert(name.as_str(), v);
249 }
250 }
251 metadata
252}
253
254pub async fn handle_query(
256 State(state): State<HttpServerState>,
257 Query(format_params): Query<FormatParams>,
258 headers: HeaderMap,
259 Json(request): Json<StatementRequest>,
260) -> Result<Response, AppError> {
261 execute_and_respond(&state, Operation::Query, &headers, request, &format_params).await
262}
263
264pub async fn handle_admin(
266 State(state): State<HttpServerState>,
267 Query(format_params): Query<FormatParams>,
268 headers: HeaderMap,
269 Json(request): Json<StatementRequest>,
270) -> Result<Response, AppError> {
271 execute_and_respond(&state, Operation::Admin, &headers, request, &format_params).await
272}
273
274pub async fn handle_command(
276 State(state): State<HttpServerState>,
277 Query(format_params): Query<FormatParams>,
278 headers: HeaderMap,
279 Json(request): Json<StatementRequest>,
280) -> Result<Response, AppError> {
281 execute_and_respond(&state, Operation::Command, &headers, request, &format_params).await
282}
283
284async fn execute_and_respond(
290 state: &HttpServerState,
291 operation: Operation,
292 headers: &HeaderMap,
293 request: StatementRequest,
294 format_params: &FormatParams,
295) -> Result<Response, AppError> {
296 let identity = extract_identity(state, headers)?;
297 let metadata = build_metadata(headers);
298 let params = match request.params {
299 None => Params::None,
300 Some(wp) => wp.into_params().map_err(AppError::InvalidParams)?,
301 };
302 let ctx = RequestContext {
303 identity,
304 operation,
305 rql: request.rql,
306 params,
307 metadata,
308 };
309
310 let (frames, metrics) = dispatch(state, ctx).await?;
311
312 let mut response = match format_params.format {
313 WireFormat::Rbcf => match encode_frames_rbcf(&frames) {
314 Ok(bytes) => (StatusCode::OK, [(header::CONTENT_TYPE, CONTENT_TYPE_RBCF.to_string())], bytes)
315 .into_response(),
316 Err(e) => return Err(AppError::BadRequest(format!("RBCF encode error: {}", e))),
317 },
318 WireFormat::Json => {
319 let resolved = resolve_response_json(frames, format_params.unwrap.unwrap_or(false))
320 .map_err(AppError::BadRequest)?;
321 (StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body).into_response()
322 }
323 WireFormat::Frames => {
324 let body = to_string(&QueryResponse {
325 frames: convert_frames(&frames),
326 })
327 .map_err(|e| AppError::BadRequest(format!("JSON encode error: {}", e)))?;
328 (StatusCode::OK, [(header::CONTENT_TYPE, CONTENT_TYPE_FRAMES.to_string())], body)
329 .into_response()
330 }
331 };
332 insert_meta_headers(response.headers_mut(), &metrics);
333 Ok(response)
334}
335
336fn extract_identity(state: &HttpServerState, headers: &HeaderMap) -> Result<IdentityId, AppError> {
342 if let Some(auth_header) = headers.get("authorization") {
344 let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
345
346 return extract_identity_from_auth_header(state.auth_service(), auth_str).map_err(AppError::Auth);
347 }
348
349 Ok(IdentityId::anonymous())
351}
352
353pub async fn handle_binding(
360 State(state): State<HttpServerState>,
361 Path(path): Path<String>,
362 method: Method,
363 Query(query_params): Query<HashMap<String, String>>,
364 headers: HeaderMap,
365) -> Result<Response, AppError> {
366 let http_method = match method.as_str() {
367 "GET" => HttpMethod::Get,
368 "POST" => HttpMethod::Post,
369 "PUT" => HttpMethod::Put,
370 "PATCH" => HttpMethod::Patch,
371 "DELETE" => HttpMethod::Delete,
372 _ => return Err(AppError::MethodNotAllowed(format!("method `{}` is not supported", method))),
373 };
374 let request_path = format!("/{}", path);
375
376 let bindings = state.engine().materialized_catalog().list_http_bindings();
378 let mut any_path_match = false;
379 let mut matched: Option<(Binding, HashMap<String, String>)> = None;
380 for b in &bindings {
381 let BindingProtocol::Http {
382 method: binding_method,
383 path: binding_path,
384 } = &b.protocol
385 else {
386 unreachable!("list_http_bindings returns only HTTP bindings")
387 };
388 if let Some(captures) = match_http_path(binding_path, &request_path) {
389 any_path_match = true;
390 if binding_method == &http_method {
391 matched = Some((b.clone(), captures));
392 break;
393 }
394 }
395 }
396 let (binding, path_captures) = match matched {
397 Some(m) => m,
398 None if any_path_match => {
399 return Err(AppError::MethodNotAllowed(format!(
400 "no binding for method `{}` at `{}`",
401 method, request_path
402 )));
403 }
404 None => return Err(AppError::NotFound(format!("no binding for `{}`", request_path))),
405 };
406
407 let procedure =
409 state.engine().materialized_catalog().find_procedure(binding.procedure_id).ok_or_else(|| {
410 AppError::Internal(format!(
411 "binding references missing procedure id {:?}",
412 binding.procedure_id
413 ))
414 })?;
415 let namespace = state.engine().materialized_catalog().find_namespace(binding.namespace).ok_or_else(|| {
416 AppError::Internal(format!("binding references missing namespace id {:?}", binding.namespace))
417 })?;
418
419 let param_names: Vec<&str> = procedure.params().iter().map(|p| p.name.as_str()).collect();
420 for key in query_params.keys() {
421 if !param_names.contains(&key.as_str()) {
422 return Err(AppError::BadRequest(format!("unknown parameter `{}`", key)));
423 }
424 if path_captures.contains_key(key) {
425 return Err(AppError::BadRequest(format!("parameter `{}` given in both path and query", key)));
426 }
427 }
428
429 let mut params: HashMap<String, Value> = HashMap::with_capacity(procedure.params().len());
430 for p in procedure.params() {
431 let raw = match path_captures.get(&p.name).or_else(|| query_params.get(&p.name)) {
432 Some(v) => v,
433 None => {
434 return Err(AppError::BadRequest(format!("missing required parameter `{}`", p.name)));
435 }
436 };
437 let value = coerce_str_to_value(raw, p.param_type.get_type()).map_err(|e| {
438 AppError::BadRequest(format!(
439 "parameter `{}`: cannot coerce `{}` to {:?}: {}",
440 p.name,
441 raw,
442 p.param_type.get_type(),
443 e
444 ))
445 })?;
446 params.insert(p.name.clone(), value);
447 }
448 let params = if params.is_empty() {
449 Params::None
450 } else {
451 Params::Named(Arc::new(params))
452 };
453
454 let identity = extract_identity(&state, &headers)?;
455 let metadata = build_metadata(&headers);
456
457 let (frames, metrics) =
458 dispatch_binding(&state, namespace.name(), procedure.name(), params, identity, metadata).await?;
459
460 let mut response = encode_binding_response(frames, binding.format)?;
461 insert_meta_headers(response.headers_mut(), &metrics);
462 Ok(response)
463}
464
465fn insert_meta_headers(headers: &mut HeaderMap, metrics: &ExecutionMetrics) {
466 headers.insert("x-fingerprint", HeaderValue::from_str(&metrics.fingerprint.to_hex()).unwrap());
467 headers.insert("x-duration", HeaderValue::from_str(&metrics.total.to_string()).unwrap());
468}
469
/// Matches a request path against a binding path template.
///
/// Both paths are split on `/` with empty segments discarded, so leading,
/// trailing and repeated slashes are ignored. The two paths must have the same
/// number of segments. A template segment of the form `{name}` captures the
/// corresponding request segment under `name`. Any other template segment must
/// equal the request segment exactly (case-sensitive).
///
/// Returns the captures on a match (empty for a fully static template), or
/// `None` otherwise. If a name appears twice, the later segment's value wins.
fn match_http_path(template: &str, request: &str) -> Option<HashMap<String, String>> {
	let mut template_parts = template.split('/').filter(|seg| !seg.is_empty());
	let mut request_parts = request.split('/').filter(|seg| !seg.is_empty());
	let mut captures = HashMap::new();

	loop {
		match (template_parts.next(), request_parts.next()) {
			// Both exhausted together: same segment count, every segment agreed.
			(None, None) => return Some(captures),
			(Some(pattern), Some(actual)) => {
				match pattern.strip_prefix('{').and_then(|rest| rest.strip_suffix('}')) {
					Some(name) => {
						captures.insert(name.to_owned(), actual.to_owned());
					}
					None if pattern == actual => {}
					None => return None,
				}
			}
			// One side ran out first: segment counts differ.
			_ => return None,
		}
	}
}
489
490fn coerce_str_to_value(s: &str, ty: Type) -> Result<Value, String> {
491 match ty {
492 Type::Boolean => match s {
493 "true" | "1" => Ok(Value::Boolean(true)),
494 "false" | "0" => Ok(Value::Boolean(false)),
495 _ => Err("expected `true`/`false`".into()),
496 },
497 Type::Utf8 => Ok(Value::Utf8(s.to_string())),
498 Type::Int1 => s.parse::<i8>().map(Value::Int1).map_err(|e| e.to_string()),
499 Type::Int2 => s.parse::<i16>().map(Value::Int2).map_err(|e| e.to_string()),
500 Type::Int4 => s.parse::<i32>().map(Value::Int4).map_err(|e| e.to_string()),
501 Type::Int8 => s.parse::<i64>().map(Value::Int8).map_err(|e| e.to_string()),
502 Type::Int16 => s.parse::<i128>().map(Value::Int16).map_err(|e| e.to_string()),
503 Type::Uint1 => s.parse::<u8>().map(Value::Uint1).map_err(|e| e.to_string()),
504 Type::Uint2 => s.parse::<u16>().map(Value::Uint2).map_err(|e| e.to_string()),
505 Type::Uint4 => s.parse::<u32>().map(Value::Uint4).map_err(|e| e.to_string()),
506 Type::Uint8 => s.parse::<u64>().map(Value::Uint8).map_err(|e| e.to_string()),
507 Type::Uint16 => s.parse::<u128>().map(Value::Uint16).map_err(|e| e.to_string()),
508 Type::Float4 => s
509 .parse::<f32>()
510 .map_err(|e| e.to_string())
511 .and_then(|v| v.try_into().map(Value::Float4).map_err(|_| "invalid f32".to_string())),
512 Type::Float8 => s
513 .parse::<f64>()
514 .map_err(|e| e.to_string())
515 .and_then(|v| v.try_into().map(Value::Float8).map_err(|_| "invalid f64".to_string())),
516 other => Err(format!("coercion to {:?} not supported from URL strings", other)),
517 }
518}
519
520fn encode_binding_response(frames: Vec<Frame>, format: BindingFormat) -> Result<Response, AppError> {
521 match format {
522 BindingFormat::Rbcf => match encode_frames_rbcf(&frames) {
523 Ok(bytes) => {
524 Ok((StatusCode::OK, [(header::CONTENT_TYPE, CONTENT_TYPE_RBCF.to_string())], bytes)
525 .into_response())
526 }
527 Err(e) => Err(AppError::BadRequest(format!("RBCF encode error: {}", e))),
528 },
529 BindingFormat::Json => {
530 let resolved = resolve_response_json(frames, false).map_err(AppError::BadRequest)?;
531 Ok((StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body)
532 .into_response())
533 }
534 BindingFormat::Frames => Ok(Json(QueryResponse {
535 frames: convert_frames(&frames),
536 })
537 .into_response()),
538 }
539}
540
#[cfg(test)]
pub mod tests {
	use serde_json::from_str;

	use super::*;

	#[test]
	fn test_match_http_path_static() {
		assert_eq!(match_http_path("/users", "/users"), Some(HashMap::new()));
		assert_eq!(match_http_path("/users", "/other"), None);
	}

	#[test]
	fn test_match_http_path_capture() {
		let caps = match_http_path("/users/{id}", "/users/42").unwrap();
		assert_eq!(caps.get("id"), Some(&"42".to_string()));
	}

	#[test]
	fn test_match_http_path_multiple_captures() {
		let caps = match_http_path("/orgs/{org}/repos/{repo}", "/orgs/acme/repos/db").unwrap();
		assert_eq!(caps.len(), 2);
		assert_eq!(caps.get("org"), Some(&"acme".to_string()));
		assert_eq!(caps.get("repo"), Some(&"db".to_string()));
	}

	#[test]
	fn test_match_http_path_duplicate_capture_keeps_last() {
		let caps = match_http_path("/{x}/{x}", "/1/2").unwrap();
		assert_eq!(caps.len(), 1);
		assert_eq!(caps.get("x"), Some(&"2".to_string()));
	}

	#[test]
	fn test_match_http_path_ignores_empty_segments() {
		assert_eq!(match_http_path("/users/", "/users"), Some(HashMap::new()));
		assert_eq!(match_http_path("/users", "//users/"), Some(HashMap::new()));
	}

	#[test]
	fn test_match_http_path_static_is_case_sensitive() {
		assert_eq!(match_http_path("/Users", "/users"), None);
	}

	#[test]
	fn test_match_http_path_mismatch_length() {
		assert!(match_http_path("/users/{id}", "/users").is_none());
		assert!(match_http_path("/users/{id}", "/users/").is_none());
		assert!(match_http_path("/users/{id}", "/users/42/extra").is_none());
	}

	#[test]
	fn test_coerce_numeric() {
		assert_eq!(coerce_str_to_value("42", Type::Int8).unwrap(), Value::Int8(42));
		assert!(coerce_str_to_value("xx", Type::Int8).is_err());
	}

	#[test]
	fn test_coerce_integer_range() {
		assert_eq!(coerce_str_to_value("-128", Type::Int1).unwrap(), Value::Int1(-128));
		assert!(coerce_str_to_value("128", Type::Int1).is_err());
		assert!(coerce_str_to_value("256", Type::Uint1).is_err());
		assert!(coerce_str_to_value("-1", Type::Uint8).is_err());
	}

	#[test]
	fn test_coerce_bool() {
		assert_eq!(coerce_str_to_value("true", Type::Boolean).unwrap(), Value::Boolean(true));
		assert_eq!(coerce_str_to_value("1", Type::Boolean).unwrap(), Value::Boolean(true));
		assert_eq!(coerce_str_to_value("false", Type::Boolean).unwrap(), Value::Boolean(false));
		assert_eq!(coerce_str_to_value("0", Type::Boolean).unwrap(), Value::Boolean(false));
		assert!(coerce_str_to_value("maybe", Type::Boolean).is_err());
	}

	#[test]
	fn test_coerce_utf8_is_verbatim() {
		assert_eq!(coerce_str_to_value("a b", Type::Utf8).unwrap(), Value::Utf8("a b".to_string()));
	}

	#[test]
	fn test_statement_request_deserialization() {
		let json = r#"{"rql": "SELECT 1"}"#;
		let request: StatementRequest = from_str(json).unwrap();
		assert_eq!(request.rql, "SELECT 1");
		assert!(request.params.is_none());
	}

	#[test]
	fn test_authenticate_request_defaults_credentials() {
		let request: AuthenticateRequest = from_str(r#"{"method":"token"}"#).unwrap();
		assert_eq!(request.method, "token");
		assert!(request.credentials.is_empty());
	}

	#[test]
	fn test_authenticate_response_omits_absent_fields() {
		let response = AuthenticateResponse {
			status: "failed".to_string(),
			token: None,
			identity: None,
			challenge_id: None,
			payload: None,
			reason: Some("bad".to_string()),
		};
		assert_eq!(to_string(&response).unwrap(), r#"{"status":"failed","reason":"bad"}"#);
	}

	#[test]
	fn test_query_response_serialization() {
		let response = QueryResponse {
			frames: Vec::new(),
		};
		let json = to_string(&response).unwrap();
		assert!(json.contains("frames"));
	}

	#[test]
	fn test_health_response_serialization() {
		let response = HealthResponse {
			status: "ok",
		};
		let json = to_string(&response).unwrap();
		assert_eq!(json, r#"{"status":"ok"}"#);
	}
}