1use axum::{
12 Json,
13 extract::{Query, State},
14 http::{HeaderMap, StatusCode, header},
15 response::{IntoResponse, Response},
16};
17use reifydb_core::value::frame::response::{ResponseFrame, convert_frames};
18use reifydb_sub_server::{
19 auth::{AuthError, extract_identity_from_api_key, extract_identity_from_auth_header},
20 execute::{execute_admin, execute_command, execute_query},
21 response::resolve_response_json,
22 state::AppState,
23 wire::WireParams,
24};
25use reifydb_type::{params::Params, value::identity::IdentityId};
26use serde::{Deserialize, Serialize};
27
28use crate::error::AppError;
29
30#[derive(Debug, Deserialize)]
32pub struct StatementRequest {
33 pub statements: Vec<String>,
35 #[serde(default)]
37 pub params: Option<WireParams>,
38}
39
40#[derive(Debug, Serialize)]
42pub struct QueryResponse {
43 pub frames: Vec<ResponseFrame>,
45}
46
47#[derive(Debug, Deserialize)]
49pub struct FormatParams {
50 pub format: Option<String>,
51 pub unwrap: Option<bool>,
52}
53
54#[derive(Debug, Serialize)]
56pub struct HealthResponse {
57 pub status: &'static str,
58}
59
60pub async fn health() -> impl IntoResponse {
71 (
72 StatusCode::OK,
73 Json(HealthResponse {
74 status: "ok",
75 }),
76 )
77}
78
79pub async fn handle_query(
104 State(state): State<AppState>,
105 Query(format_params): Query<FormatParams>,
106 headers: HeaderMap,
107 Json(request): Json<StatementRequest>,
108) -> Result<Response, AppError> {
109 let identity = extract_identity(&headers)?;
111
112 let query = request.statements.join("; ");
114
115 let params = match request.params {
117 None => Params::None,
118 Some(wp) => wp.into_params().map_err(|e| AppError::InvalidParams(e))?,
119 };
120
121 let frames = execute_query(
123 state.actor_system(),
124 state.engine_clone(),
125 query,
126 identity,
127 params,
128 state.query_timeout(),
129 )
130 .await?;
131
132 if format_params.format.as_deref() == Some("json") {
133 let resolved = resolve_response_json(frames, format_params.unwrap.unwrap_or(false))
134 .map_err(|e| AppError::BadRequest(e))?;
135 Ok((StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body).into_response())
136 } else {
137 Ok(Json(QueryResponse {
138 frames: convert_frames(&frames),
139 })
140 .into_response())
141 }
142}
143
144pub async fn handle_admin(
155 State(state): State<AppState>,
156 Query(format_params): Query<FormatParams>,
157 headers: HeaderMap,
158 Json(request): Json<StatementRequest>,
159) -> Result<Response, AppError> {
160 let identity = extract_identity(&headers)?;
162
163 let params = match request.params {
165 None => Params::None,
166 Some(wp) => wp.into_params().map_err(|e| AppError::InvalidParams(e))?,
167 };
168
169 let frames = execute_admin(
171 state.actor_system(),
172 state.engine_clone(),
173 request.statements,
174 identity,
175 params,
176 state.query_timeout(),
177 )
178 .await?;
179
180 if format_params.format.as_deref() == Some("json") {
181 let resolved = resolve_response_json(frames, format_params.unwrap.unwrap_or(false))
182 .map_err(|e| AppError::BadRequest(e))?;
183 Ok((StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body).into_response())
184 } else {
185 Ok(Json(QueryResponse {
186 frames: convert_frames(&frames),
187 })
188 .into_response())
189 }
190}
191
192pub async fn handle_command(
219 State(state): State<AppState>,
220 Query(format_params): Query<FormatParams>,
221 headers: HeaderMap,
222 Json(request): Json<StatementRequest>,
223) -> Result<Response, AppError> {
224 let identity = extract_identity(&headers)?;
226
227 let params = match request.params {
229 None => Params::None,
230 Some(wp) => wp.into_params().map_err(|e| AppError::InvalidParams(e))?,
231 };
232
233 let frames = execute_command(
235 state.actor_system(),
236 state.engine_clone(),
237 request.statements,
238 identity,
239 params,
240 state.query_timeout(),
241 )
242 .await?;
243
244 if format_params.format.as_deref() == Some("json") {
245 let resolved = resolve_response_json(frames, format_params.unwrap.unwrap_or(false))
246 .map_err(|e| AppError::BadRequest(e))?;
247 Ok((StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body).into_response())
248 } else {
249 Ok(Json(QueryResponse {
250 frames: convert_frames(&frames),
251 })
252 .into_response())
253 }
254}
255
256fn extract_identity(headers: &HeaderMap) -> Result<IdentityId, AppError> {
262 if let Some(auth_header) = headers.get("authorization") {
264 let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
265
266 return extract_identity_from_auth_header(auth_str).map_err(AppError::Auth);
267 }
268
269 if let Some(api_key) = headers.get("x-api-key") {
271 let key = api_key.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
272
273 return extract_identity_from_api_key(key).map_err(AppError::Auth);
274 }
275
276 Err(AppError::Auth(AuthError::MissingCredentials))
278}
279
// Unit tests for the request/response DTOs and header-based identity extraction.
// The module is private: test code is not part of the crate's public API.
#[cfg(test)]
mod tests {
    use serde_json::{from_str, to_string};

    use super::*;

    #[test]
    fn test_statement_request_deserialization() {
        let json = r#"{"statements": ["SELECT 1"]}"#;
        let request: StatementRequest = from_str(json).unwrap();
        assert_eq!(request.statements, vec!["SELECT 1"]);
        assert!(request.params.is_none());
    }

    #[test]
    fn test_statement_request_requires_statements() {
        // `statements` has no serde default, so a body without it must be rejected.
        assert!(from_str::<StatementRequest>("{}").is_err());
    }

    #[test]
    fn test_format_params_deserialization() {
        let params: FormatParams = from_str(r#"{"format": "json", "unwrap": true}"#).unwrap();
        assert_eq!(params.format.as_deref(), Some("json"));
        assert_eq!(params.unwrap, Some(true));

        let empty: FormatParams = from_str("{}").unwrap();
        assert!(empty.format.is_none());
        assert!(empty.unwrap.is_none());
    }

    #[test]
    fn test_query_response_serialization() {
        let response = QueryResponse {
            frames: Vec::new(),
        };
        let json = to_string(&response).unwrap();
        assert_eq!(json, r#"{"frames":[]}"#);
    }

    #[test]
    fn test_health_response_serialization() {
        let response = HealthResponse {
            status: "ok",
        };
        let json = to_string(&response).unwrap();
        assert_eq!(json, r#"{"status":"ok"}"#);
    }

    #[test]
    fn test_extract_identity_missing_credentials() {
        let result = extract_identity(&HeaderMap::new());
        assert!(matches!(result, Err(AppError::Auth(AuthError::MissingCredentials))));
    }

    #[test]
    fn test_extract_identity_rejects_non_text_header() {
        // 0xFF is an allowed header byte but not visible ASCII, so `to_str` fails.
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", axum::http::HeaderValue::from_bytes(b"\xff").unwrap());
        assert!(matches!(extract_identity(&headers), Err(AppError::Auth(AuthError::InvalidHeader))));
    }
}