1use axum::{
12 Json,
13 extract::{Query, State},
14 http::{HeaderMap, StatusCode, header},
15 response::{IntoResponse, Response},
16};
17use reifydb_sub_server::{
18 auth::{AuthError, extract_identity_from_api_key, extract_identity_from_auth_header},
19 execute::{execute_admin, execute_command, execute_query},
20 response::{ResponseFrame, convert_frames, resolve_response_json},
21 state::AppState,
22 wire::WireParams,
23};
24use reifydb_type::{params::Params, value::identity::IdentityId};
25use serde::{Deserialize, Serialize};
26
27use crate::error::AppError;
28
29#[derive(Debug, Deserialize)]
31pub struct StatementRequest {
32 pub statements: Vec<String>,
34 #[serde(default)]
36 pub params: Option<WireParams>,
37}
38
39#[derive(Debug, Serialize)]
41pub struct QueryResponse {
42 pub frames: Vec<ResponseFrame>,
44}
45
46#[derive(Debug, Deserialize)]
48pub struct FormatParams {
49 pub format: Option<String>,
50 pub unwrap: Option<bool>,
51}
52
53#[derive(Debug, Serialize)]
55pub struct HealthResponse {
56 pub status: &'static str,
57}
58
59pub async fn health() -> impl IntoResponse {
70 (
71 StatusCode::OK,
72 Json(HealthResponse {
73 status: "ok",
74 }),
75 )
76}
77
78pub async fn handle_query(
103 State(state): State<AppState>,
104 Query(format_params): Query<FormatParams>,
105 headers: HeaderMap,
106 Json(request): Json<StatementRequest>,
107) -> Result<Response, AppError> {
108 let identity = extract_identity(&headers)?;
110
111 let query = request.statements.join("; ");
113
114 let params = match request.params {
116 None => Params::None,
117 Some(wp) => wp.into_params().map_err(|e| AppError::InvalidParams(e))?,
118 };
119
120 let frames = execute_query(
122 state.actor_system(),
123 state.engine_clone(),
124 query,
125 identity,
126 params,
127 state.query_timeout(),
128 )
129 .await?;
130
131 if format_params.format.as_deref() == Some("json") {
132 let resolved = resolve_response_json(frames, format_params.unwrap.unwrap_or(false))
133 .map_err(|e| AppError::BadRequest(e))?;
134 Ok((StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body).into_response())
135 } else {
136 Ok(Json(QueryResponse {
137 frames: convert_frames(frames),
138 })
139 .into_response())
140 }
141}
142
143pub async fn handle_admin(
154 State(state): State<AppState>,
155 Query(format_params): Query<FormatParams>,
156 headers: HeaderMap,
157 Json(request): Json<StatementRequest>,
158) -> Result<Response, AppError> {
159 let identity = extract_identity(&headers)?;
161
162 let params = match request.params {
164 None => Params::None,
165 Some(wp) => wp.into_params().map_err(|e| AppError::InvalidParams(e))?,
166 };
167
168 let frames = execute_admin(
170 state.actor_system(),
171 state.engine_clone(),
172 request.statements,
173 identity,
174 params,
175 state.query_timeout(),
176 )
177 .await?;
178
179 if format_params.format.as_deref() == Some("json") {
180 let resolved = resolve_response_json(frames, format_params.unwrap.unwrap_or(false))
181 .map_err(|e| AppError::BadRequest(e))?;
182 Ok((StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body).into_response())
183 } else {
184 Ok(Json(QueryResponse {
185 frames: convert_frames(frames),
186 })
187 .into_response())
188 }
189}
190
191pub async fn handle_command(
218 State(state): State<AppState>,
219 Query(format_params): Query<FormatParams>,
220 headers: HeaderMap,
221 Json(request): Json<StatementRequest>,
222) -> Result<Response, AppError> {
223 let identity = extract_identity(&headers)?;
225
226 let params = match request.params {
228 None => Params::None,
229 Some(wp) => wp.into_params().map_err(|e| AppError::InvalidParams(e))?,
230 };
231
232 let frames = execute_command(
234 state.actor_system(),
235 state.engine_clone(),
236 request.statements,
237 identity,
238 params,
239 state.query_timeout(),
240 )
241 .await?;
242
243 if format_params.format.as_deref() == Some("json") {
244 let resolved = resolve_response_json(frames, format_params.unwrap.unwrap_or(false))
245 .map_err(|e| AppError::BadRequest(e))?;
246 Ok((StatusCode::OK, [(header::CONTENT_TYPE, resolved.content_type)], resolved.body).into_response())
247 } else {
248 Ok(Json(QueryResponse {
249 frames: convert_frames(frames),
250 })
251 .into_response())
252 }
253}
254
255fn extract_identity(headers: &HeaderMap) -> Result<IdentityId, AppError> {
261 if let Some(auth_header) = headers.get("authorization") {
263 let auth_str = auth_header.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
264
265 return extract_identity_from_auth_header(auth_str).map_err(AppError::Auth);
266 }
267
268 if let Some(api_key) = headers.get("x-api-key") {
270 let key = api_key.to_str().map_err(|_| AppError::Auth(AuthError::InvalidHeader))?;
271
272 return extract_identity_from_api_key(key).map_err(AppError::Auth);
273 }
274
275 Err(AppError::Auth(AuthError::MissingCredentials))
277}
278
#[cfg(test)]
pub mod tests {
	use serde_json::{from_str, to_string};

	use super::*;

	#[test]
	fn test_statement_request_deserialization() {
		let json = r#"{"statements": ["SELECT 1"]}"#;
		let request: StatementRequest = from_str(json).unwrap();
		assert_eq!(request.statements, vec!["SELECT 1"]);
		assert!(request.params.is_none());
	}

	#[test]
	fn test_statement_request_multiple_statements() {
		let json = r#"{"statements": ["SELECT 1", "SELECT 2"]}"#;
		let request: StatementRequest = from_str(json).unwrap();
		assert_eq!(request.statements, vec!["SELECT 1", "SELECT 2"]);
		assert!(request.params.is_none());
	}

	#[test]
	fn test_statement_request_null_params() {
		// An explicit `null` must behave the same as an omitted field.
		let json = r#"{"statements": [], "params": null}"#;
		let request: StatementRequest = from_str(json).unwrap();
		assert!(request.statements.is_empty());
		assert!(request.params.is_none());
	}

	#[test]
	fn test_statement_request_requires_statements() {
		assert!(from_str::<StatementRequest>("{}").is_err());
	}

	#[test]
	fn test_format_params_default_to_none() {
		let params: FormatParams = from_str("{}").unwrap();
		assert!(params.format.is_none());
		assert!(params.unwrap.is_none());
	}

	#[test]
	fn test_format_params_deserialization() {
		let params: FormatParams = from_str(r#"{"format": "json", "unwrap": true}"#).unwrap();
		assert_eq!(params.format.as_deref(), Some("json"));
		assert_eq!(params.unwrap, Some(true));
	}

	#[test]
	fn test_query_response_serialization() {
		let response = QueryResponse {
			frames: Vec::new(),
		};
		let json = to_string(&response).unwrap();
		assert!(json.contains("frames"));
	}

	#[test]
	fn test_health_response_serialization() {
		let response = HealthResponse {
			status: "ok",
		};
		let json = to_string(&response).unwrap();
		assert_eq!(json, r#"{"status":"ok"}"#);
	}
}