1use crate::state::AppState;
4use axum::{
5 body::Body,
6 extract::{Request, State},
7 http::StatusCode,
8 response::{IntoResponse, Response},
9};
10use bytes::Bytes;
11use postrust_auth::authenticate;
12use postrust_core::{create_action_plan, parse_request, ActionPlan, ApiRequest};
13use postrust_response::{format_response, QueryResult, Response as PgrstResponse};
14use sqlx::Row;
15use std::sync::Arc;
16use tracing::{debug, error, info, warn};
17
18pub async fn handle_request(
20 State(state): State<Arc<AppState>>,
21 request: Request,
22) -> Response {
23 let method = request.method().clone();
24 let path = request.uri().path().to_string();
25
26 debug!("{} {}", method, path);
27
28 match process_request(state, request).await {
29 Ok(response) => response.into_response(),
30 Err(e) => error_response(e).into_response(),
31 }
32}
33
34async fn process_request(
36 state: Arc<AppState>,
37 request: Request,
38) -> Result<Response, postrust_core::Error> {
39 let auth_header = request
41 .headers()
42 .get("authorization")
43 .and_then(|v| v.to_str().ok());
44
45 let auth_result = authenticate(auth_header, &state.jwt_config)
47 .map_err(|e| postrust_core::Error::InvalidJwt(e.to_string()))?;
48
49 debug!("Authenticated as role: {}", auth_result.role);
50
51 let (parts, body) = request.into_parts();
53 let body_bytes = axum::body::to_bytes(body, 10 * 1024 * 1024)
54 .await
55 .map_err(|e| postrust_core::Error::InvalidBody(e.to_string()))?;
56
57 let mut builder = http::Request::builder()
59 .method(parts.method.clone())
60 .uri(parts.uri.clone());
61
62 for (key, value) in &parts.headers {
63 builder = builder.header(key, value);
64 }
65
66 let http_request = builder
67 .body(body_bytes.clone())
68 .map_err(|e| postrust_core::Error::Internal(e.to_string()))?;
69
70 let mut api_request = parse_request(
72 &http_request,
73 state.default_schema(),
74 state.schemas(),
75 )?;
76
77 if !body_bytes.is_empty() {
79 let payload = postrust_core::api_request::payload::parse_payload(
80 body_bytes,
81 &api_request.content_media_type,
82 )?;
83 api_request.payload = payload;
84 }
85
86 let schema_cache = state.schema_cache().await;
88
89 let plan = create_action_plan(&api_request, &schema_cache)?;
91
92 let result = execute_plan(&state, &api_request, &plan, &auth_result).await?;
94
95 let response = format_response(&api_request, &result)
97 .map_err(|e| postrust_core::Error::Internal(e.to_string()))?;
98
99 Ok(build_response(response))
100}
101
102async fn execute_plan(
104 state: &AppState,
105 request: &ApiRequest,
106 plan: &ActionPlan,
107 auth: &postrust_auth::AuthResult,
108) -> Result<QueryResult, postrust_core::Error> {
109 match plan {
110 ActionPlan::Db(db_plan) => {
111 let query = postrust_core::query::build_query(
113 &ActionPlan::Db(db_plan.clone()),
114 Some(&auth.role),
115 )?;
116
117 if !query.has_main() {
118 return Ok(QueryResult::default());
119 }
120
121 let (sql, params) = query.build_main();
122 debug!("Executing SQL: {}", sql);
123 debug!("With {} parameters", params.len());
124
125 let mut conn = state.pool.acquire().await
127 .map_err(|e| postrust_core::Error::ConnectionPool(e.to_string()))?;
128
129 sqlx::query(&format!(
131 "SET LOCAL ROLE {}",
132 postrust_sql::escape_ident(&auth.role)
133 ))
134 .execute(&mut *conn)
135 .await
136 .map_err(|e| postrust_core::Error::Database(postrust_core::error::DatabaseError {
137 code: "42501".into(),
138 message: e.to_string(),
139 details: None,
140 hint: None,
141 constraint: None,
142 table: None,
143 column: None,
144 }))?;
145
146 for (key, value) in &auth.claims {
148 let guc_key = format!("request.jwt.claims.{}", key);
149 let guc_value = match value {
150 serde_json::Value::String(s) => s.clone(),
151 other => other.to_string(),
152 };
153
154 sqlx::query("SELECT set_config($1, $2, true)")
155 .bind(&guc_key)
156 .bind(&guc_value)
157 .execute(&mut *conn)
158 .await
159 .ok(); }
161
162 let rows = bind_params(sqlx::query(&sql), ¶ms)
164 .fetch_all(&mut *conn)
165 .await
166 .map_err(|e| {
167 error!("Query error: {}", e);
168 map_sqlx_error(e)
169 })?;
170
171 let json_rows: Vec<serde_json::Value> = rows
173 .iter()
174 .map(|row| row_to_json(row))
175 .collect();
176
177 Ok(QueryResult {
178 status: StatusCode::OK,
179 rows: json_rows,
180 total_count: None,
181 content_range: None,
182 location: None,
183 guc_headers: None,
184 guc_status: None,
185 })
186 }
187 ActionPlan::Info(info_plan) => {
188 use postrust_core::plan::InfoPlan;
189
190 let response_data = match info_plan {
192 InfoPlan::OpenApiSpec => {
193 serde_json::json!({
195 "name": "postrust",
196 "version": env!("CARGO_PKG_VERSION"),
197 "description": "PostgREST-compatible REST API for PostgreSQL"
198 })
199 }
200 InfoPlan::RelationInfo(qi) => {
201 serde_json::json!({
202 "schema": qi.schema,
203 "name": qi.name,
204 "type": "relation"
205 })
206 }
207 InfoPlan::RoutineInfo(qi) => {
208 serde_json::json!({
209 "schema": qi.schema,
210 "name": qi.name,
211 "type": "routine"
212 })
213 }
214 };
215
216 Ok(QueryResult {
217 status: StatusCode::OK,
218 rows: vec![response_data],
219 ..Default::default()
220 })
221 }
222 }
223}
224
225fn row_to_json(row: &sqlx::postgres::PgRow) -> serde_json::Value {
227 use sqlx::{Column, Row, TypeInfo};
228
229 let mut map = serde_json::Map::new();
230
231 for column in row.columns() {
232 let name = column.name();
233 let type_name = column.type_info().name();
234
235 let value = match type_name {
236 "INT2" | "SMALLINT" => row
237 .try_get::<i16, _>(name)
238 .ok()
239 .map(|v| serde_json::Value::Number(v.into())),
240 "INT4" | "INT" | "INTEGER" => row
241 .try_get::<i32, _>(name)
242 .ok()
243 .map(|v| serde_json::Value::Number(v.into())),
244 "INT8" | "BIGINT" => row
245 .try_get::<i64, _>(name)
246 .ok()
247 .map(|v| serde_json::Value::Number(v.into())),
248 "FLOAT4" | "REAL" => row
249 .try_get::<f32, _>(name)
250 .ok()
251 .and_then(|v| serde_json::Number::from_f64(v as f64))
252 .map(serde_json::Value::Number),
253 "FLOAT8" | "DOUBLE PRECISION" => row
254 .try_get::<f64, _>(name)
255 .ok()
256 .and_then(|v| serde_json::Number::from_f64(v))
257 .map(serde_json::Value::Number),
258 "NUMERIC" | "DECIMAL" => row
259 .try_get::<sqlx::types::BigDecimal, _>(name)
260 .ok()
261 .map(|v| serde_json::Value::String(v.to_string())),
262 "BOOL" | "BOOLEAN" => row
263 .try_get::<bool, _>(name)
264 .ok()
265 .map(serde_json::Value::Bool),
266 "JSON" | "JSONB" => row.try_get::<serde_json::Value, _>(name).ok(),
267 "UUID" => row
268 .try_get::<sqlx::types::Uuid, _>(name)
269 .ok()
270 .map(|v| serde_json::Value::String(v.to_string())),
271 "TIMESTAMPTZ" | "TIMESTAMP WITH TIME ZONE" => row
272 .try_get::<chrono::DateTime<chrono::Utc>, _>(name)
273 .ok()
274 .map(|v| serde_json::Value::String(v.to_rfc3339())),
275 "TIMESTAMP" | "TIMESTAMP WITHOUT TIME ZONE" => row
276 .try_get::<chrono::NaiveDateTime, _>(name)
277 .ok()
278 .map(|v| serde_json::Value::String(v.to_string())),
279 "DATE" => row
280 .try_get::<chrono::NaiveDate, _>(name)
281 .ok()
282 .map(|v| serde_json::Value::String(v.to_string())),
283 "TIME" | "TIME WITHOUT TIME ZONE" => row
284 .try_get::<chrono::NaiveTime, _>(name)
285 .ok()
286 .map(|v| serde_json::Value::String(v.to_string())),
287 _ => row
288 .try_get::<String, _>(name)
289 .ok()
290 .map(serde_json::Value::String),
291 };
292
293 map.insert(name.to_string(), value.unwrap_or(serde_json::Value::Null));
294 }
295
296 serde_json::Value::Object(map)
297}
298
299fn bind_params<'q>(
301 mut query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
302 params: &'q [postrust_sql::SqlParam],
303) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
304 use postrust_sql::SqlParam;
305
306 for param in params {
307 query = match param {
308 SqlParam::Null => query.bind(None::<String>),
309 SqlParam::Bool(b) => query.bind(b),
310 SqlParam::Int(n) => query.bind(n),
311 SqlParam::Float(f) => query.bind(f),
312 SqlParam::Text(s) => query.bind(s),
313 SqlParam::Bytes(b) => query.bind(b),
314 SqlParam::Json(j) => query.bind(j),
315 SqlParam::Uuid(u) => query.bind(u),
316 SqlParam::Timestamp(t) => query.bind(t),
317 SqlParam::Array(arr) => {
318 let strings: Vec<String> = arr
320 .iter()
321 .map(|p| match p {
322 SqlParam::Text(s) => s.clone(),
323 SqlParam::Int(n) => n.to_string(),
324 SqlParam::Bool(b) => b.to_string(),
325 other => format!("{:?}", other),
326 })
327 .collect();
328 query.bind(strings)
329 }
330 };
331 }
332
333 query
334}
335
336fn map_sqlx_error(e: sqlx::Error) -> postrust_core::Error {
338 match e {
339 sqlx::Error::Database(db_err) => {
340 let (details, hint) = db_err
342 .try_downcast_ref::<sqlx::postgres::PgDatabaseError>()
343 .map(|pg_err| (pg_err.detail().map(String::from), pg_err.hint().map(String::from)))
344 .unwrap_or((None, None));
345
346 postrust_core::Error::Database(postrust_core::error::DatabaseError {
347 code: db_err.code().map(|c| c.to_string()).unwrap_or_default(),
348 message: db_err.message().to_string(),
349 details,
350 hint,
351 constraint: db_err.constraint().map(|s| s.to_string()),
352 table: db_err.table().map(|s| s.to_string()),
353 column: None,
354 })
355 }
356 other => postrust_core::Error::Internal(other.to_string()),
357 }
358}
359
360fn build_response(response: PgrstResponse) -> Response {
362 let mut builder = Response::builder().status(response.status);
363
364 for (key, value) in &response.headers {
365 builder = builder.header(key, value);
366 }
367
368 builder
369 .body(Body::from(response.body))
370 .unwrap_or_else(|_| Response::new(Body::empty()))
371}
372
373fn error_response(error: postrust_core::Error) -> Response {
378 let status = error.status_code();
379
380 let debug_mode = std::env::var("PGRST_DEBUG")
382 .map(|v| v == "true" || v == "1")
383 .unwrap_or(false);
384
385 let body = if debug_mode {
386 serde_json::to_vec(&error.to_json()).unwrap_or_default()
388 } else {
389 let sanitized = serde_json::json!({
391 "code": error.code(),
392 "message": sanitize_error_message(&error),
393 "details": null,
394 "hint": null
395 });
396 serde_json::to_vec(&sanitized).unwrap_or_default()
397 };
398
399 Response::builder()
400 .status(status)
401 .header("content-type", "application/json")
402 .body(Body::from(body))
403 .unwrap_or_else(|_| Response::new(Body::empty()))
404}
405
406fn sanitize_error_message(error: &postrust_core::Error) -> &'static str {
408 use postrust_core::Error;
409 match error {
410 Error::TableNotFound(_) | Error::NotFound(_) => "Resource not found",
411 Error::FunctionNotFound(_) => "Function not found",
412 Error::ColumnNotFound(_) | Error::UnknownColumn(_) => "Column not found",
413 Error::RelationshipNotFound(_) => "Relationship not found",
414 Error::InvalidPath(_) => "Invalid request path",
415 Error::InvalidBody(_) => "Invalid request body",
416 Error::InvalidJwt(_) | Error::JwtExpired | Error::MissingAuth => "Unauthorized",
417 Error::InsufficientPermissions(_) => "Forbidden",
418 Error::UnacceptableSchema(_) => "Invalid schema",
419 Error::InvalidHeader(_) | Error::InvalidQueryParam(_) => "Invalid request",
420 Error::Database(_) => "Database error",
421 Error::ConnectionPool(_) => "Service temporarily unavailable",
422 Error::Internal(_) => "Internal server error",
423 _ => "An error occurred",
424 }
425}