1use crate::state::AppState;
4use axum::{
5 body::Body,
6 extract::{Request, State},
7 http::StatusCode,
8 response::{IntoResponse, Response},
9};
10use bytes::Bytes;
11use postrust_auth::authenticate;
12use postrust_core::{create_action_plan, parse_request, ActionPlan, ApiRequest};
13use postrust_response::{format_response, QueryResult, Response as PgrstResponse};
14use sqlx::Row;
15use std::sync::Arc;
16use tracing::{debug, error, info, warn};
17
18pub async fn handle_request(
20 State(state): State<Arc<AppState>>,
21 request: Request,
22) -> Response {
23 let method = request.method().clone();
24 let path = request.uri().path().to_string();
25
26 debug!("{} {}", method, path);
27
28 match process_request(state, request).await {
29 Ok(response) => response.into_response(),
30 Err(e) => error_response(e).into_response(),
31 }
32}
33
34async fn process_request(
36 state: Arc<AppState>,
37 request: Request,
38) -> Result<Response, postrust_core::Error> {
39 let auth_header = request
41 .headers()
42 .get("authorization")
43 .and_then(|v| v.to_str().ok());
44
45 let auth_result = authenticate(auth_header, &state.jwt_config)
47 .map_err(|e| postrust_core::Error::InvalidJwt(e.to_string()))?;
48
49 debug!("Authenticated as role: {}", auth_result.role);
50
51 let (parts, body) = request.into_parts();
53 let body_bytes = axum::body::to_bytes(body, 10 * 1024 * 1024)
54 .await
55 .map_err(|e| postrust_core::Error::InvalidBody(e.to_string()))?;
56
57 let mut builder = http::Request::builder()
59 .method(parts.method.clone())
60 .uri(parts.uri.clone());
61
62 for (key, value) in &parts.headers {
63 builder = builder.header(key, value);
64 }
65
66 let http_request = builder
67 .body(body_bytes.clone())
68 .map_err(|e| postrust_core::Error::Internal(e.to_string()))?;
69
70 let mut api_request = parse_request(
72 &http_request,
73 state.default_schema(),
74 state.schemas(),
75 )?;
76
77 if !body_bytes.is_empty() {
79 let payload = postrust_core::api_request::payload::parse_payload(
80 body_bytes,
81 &api_request.content_media_type,
82 )?;
83 api_request.payload = payload;
84 }
85
86 let schema_cache = state.schema_cache().await;
88
89 let plan = create_action_plan(&api_request, &schema_cache)?;
91
92 let result = execute_plan(&state, &api_request, &plan, &auth_result).await?;
94
95 let response = format_response(&api_request, &result)
97 .map_err(|e| postrust_core::Error::Internal(e.to_string()))?;
98
99 Ok(build_response(response))
100}
101
102async fn execute_plan(
104 state: &AppState,
105 request: &ApiRequest,
106 plan: &ActionPlan,
107 auth: &postrust_auth::AuthResult,
108) -> Result<QueryResult, postrust_core::Error> {
109 match plan {
110 ActionPlan::Db(db_plan) => {
111 let query = postrust_core::query::build_query(
113 &ActionPlan::Db(db_plan.clone()),
114 Some(&auth.role),
115 )?;
116
117 if !query.has_main() {
118 return Ok(QueryResult::default());
119 }
120
121 let (sql, params) = query.build_main();
122 debug!("Executing SQL: {}", sql);
123
124 let mut conn = state.pool.acquire().await
126 .map_err(|e| postrust_core::Error::ConnectionPool(e.to_string()))?;
127
128 sqlx::query(&format!(
130 "SET LOCAL ROLE {}",
131 postrust_sql::escape_ident(&auth.role)
132 ))
133 .execute(&mut *conn)
134 .await
135 .map_err(|e| postrust_core::Error::Database(postrust_core::error::DatabaseError {
136 code: "42501".into(),
137 message: e.to_string(),
138 details: None,
139 hint: None,
140 constraint: None,
141 table: None,
142 column: None,
143 }))?;
144
145 for (key, value) in &auth.claims {
147 let guc_key = format!("request.jwt.claims.{}", key);
148 let guc_value = match value {
149 serde_json::Value::String(s) => s.clone(),
150 other => other.to_string(),
151 };
152
153 sqlx::query("SELECT set_config($1, $2, true)")
154 .bind(&guc_key)
155 .bind(&guc_value)
156 .execute(&mut *conn)
157 .await
158 .ok(); }
160
161 let rows = sqlx::query(&sql)
163 .fetch_all(&mut *conn)
164 .await
165 .map_err(|e| {
166 error!("Query error: {}", e);
167 map_sqlx_error(e)
168 })?;
169
170 let json_rows: Vec<serde_json::Value> = rows
172 .iter()
173 .map(|row| row_to_json(row))
174 .collect();
175
176 Ok(QueryResult {
177 status: StatusCode::OK,
178 rows: json_rows,
179 total_count: None,
180 content_range: None,
181 location: None,
182 guc_headers: None,
183 guc_status: None,
184 })
185 }
186 ActionPlan::Info(info_plan) => {
187 use postrust_core::plan::InfoPlan;
188
189 let response_data = match info_plan {
191 InfoPlan::OpenApiSpec => {
192 serde_json::json!({
194 "name": "postrust",
195 "version": env!("CARGO_PKG_VERSION"),
196 "description": "PostgREST-compatible REST API for PostgreSQL"
197 })
198 }
199 InfoPlan::RelationInfo(qi) => {
200 serde_json::json!({
201 "schema": qi.schema,
202 "name": qi.name,
203 "type": "relation"
204 })
205 }
206 InfoPlan::RoutineInfo(qi) => {
207 serde_json::json!({
208 "schema": qi.schema,
209 "name": qi.name,
210 "type": "routine"
211 })
212 }
213 };
214
215 Ok(QueryResult {
216 status: StatusCode::OK,
217 rows: vec![response_data],
218 ..Default::default()
219 })
220 }
221 }
222}
223
224fn row_to_json(row: &sqlx::postgres::PgRow) -> serde_json::Value {
226 use sqlx::{Column, Row, TypeInfo};
227
228 let mut map = serde_json::Map::new();
229
230 for column in row.columns() {
231 let name = column.name();
232 let type_name = column.type_info().name();
233
234 let value = match type_name {
235 "INT2" | "SMALLINT" => row
236 .try_get::<i16, _>(name)
237 .ok()
238 .map(|v| serde_json::Value::Number(v.into())),
239 "INT4" | "INT" | "INTEGER" => row
240 .try_get::<i32, _>(name)
241 .ok()
242 .map(|v| serde_json::Value::Number(v.into())),
243 "INT8" | "BIGINT" => row
244 .try_get::<i64, _>(name)
245 .ok()
246 .map(|v| serde_json::Value::Number(v.into())),
247 "FLOAT4" | "REAL" => row
248 .try_get::<f32, _>(name)
249 .ok()
250 .and_then(|v| serde_json::Number::from_f64(v as f64))
251 .map(serde_json::Value::Number),
252 "FLOAT8" | "DOUBLE PRECISION" => row
253 .try_get::<f64, _>(name)
254 .ok()
255 .and_then(|v| serde_json::Number::from_f64(v))
256 .map(serde_json::Value::Number),
257 "NUMERIC" | "DECIMAL" => row
258 .try_get::<sqlx::types::BigDecimal, _>(name)
259 .ok()
260 .map(|v| serde_json::Value::String(v.to_string())),
261 "BOOL" | "BOOLEAN" => row
262 .try_get::<bool, _>(name)
263 .ok()
264 .map(serde_json::Value::Bool),
265 "JSON" | "JSONB" => row.try_get::<serde_json::Value, _>(name).ok(),
266 "UUID" => row
267 .try_get::<sqlx::types::Uuid, _>(name)
268 .ok()
269 .map(|v| serde_json::Value::String(v.to_string())),
270 "TIMESTAMPTZ" | "TIMESTAMP WITH TIME ZONE" => row
271 .try_get::<chrono::DateTime<chrono::Utc>, _>(name)
272 .ok()
273 .map(|v| serde_json::Value::String(v.to_rfc3339())),
274 "TIMESTAMP" | "TIMESTAMP WITHOUT TIME ZONE" => row
275 .try_get::<chrono::NaiveDateTime, _>(name)
276 .ok()
277 .map(|v| serde_json::Value::String(v.to_string())),
278 "DATE" => row
279 .try_get::<chrono::NaiveDate, _>(name)
280 .ok()
281 .map(|v| serde_json::Value::String(v.to_string())),
282 "TIME" | "TIME WITHOUT TIME ZONE" => row
283 .try_get::<chrono::NaiveTime, _>(name)
284 .ok()
285 .map(|v| serde_json::Value::String(v.to_string())),
286 _ => row
287 .try_get::<String, _>(name)
288 .ok()
289 .map(serde_json::Value::String),
290 };
291
292 map.insert(name.to_string(), value.unwrap_or(serde_json::Value::Null));
293 }
294
295 serde_json::Value::Object(map)
296}
297
298fn map_sqlx_error(e: sqlx::Error) -> postrust_core::Error {
300 match e {
301 sqlx::Error::Database(db_err) => {
302 let (details, hint) = db_err
304 .try_downcast_ref::<sqlx::postgres::PgDatabaseError>()
305 .map(|pg_err| (pg_err.detail().map(String::from), pg_err.hint().map(String::from)))
306 .unwrap_or((None, None));
307
308 postrust_core::Error::Database(postrust_core::error::DatabaseError {
309 code: db_err.code().map(|c| c.to_string()).unwrap_or_default(),
310 message: db_err.message().to_string(),
311 details,
312 hint,
313 constraint: db_err.constraint().map(|s| s.to_string()),
314 table: db_err.table().map(|s| s.to_string()),
315 column: None,
316 })
317 }
318 other => postrust_core::Error::Internal(other.to_string()),
319 }
320}
321
322fn build_response(response: PgrstResponse) -> Response {
324 let mut builder = Response::builder().status(response.status);
325
326 for (key, value) in &response.headers {
327 builder = builder.header(key, value);
328 }
329
330 builder
331 .body(Body::from(response.body))
332 .unwrap_or_else(|_| Response::new(Body::empty()))
333}
334
335fn error_response(error: postrust_core::Error) -> Response {
340 let status = error.status_code();
341
342 let debug_mode = std::env::var("PGRST_DEBUG")
344 .map(|v| v == "true" || v == "1")
345 .unwrap_or(false);
346
347 let body = if debug_mode {
348 serde_json::to_vec(&error.to_json()).unwrap_or_default()
350 } else {
351 let sanitized = serde_json::json!({
353 "code": error.code(),
354 "message": sanitize_error_message(&error),
355 "details": null,
356 "hint": null
357 });
358 serde_json::to_vec(&sanitized).unwrap_or_default()
359 };
360
361 Response::builder()
362 .status(status)
363 .header("content-type", "application/json")
364 .body(Body::from(body))
365 .unwrap_or_else(|_| Response::new(Body::empty()))
366}
367
368fn sanitize_error_message(error: &postrust_core::Error) -> &'static str {
370 use postrust_core::Error;
371 match error {
372 Error::TableNotFound(_) | Error::NotFound(_) => "Resource not found",
373 Error::FunctionNotFound(_) => "Function not found",
374 Error::ColumnNotFound(_) | Error::UnknownColumn(_) => "Column not found",
375 Error::RelationshipNotFound(_) => "Relationship not found",
376 Error::InvalidPath(_) => "Invalid request path",
377 Error::InvalidBody(_) => "Invalid request body",
378 Error::InvalidJwt(_) | Error::JwtExpired | Error::MissingAuth => "Unauthorized",
379 Error::InsufficientPermissions(_) => "Forbidden",
380 Error::UnacceptableSchema(_) => "Invalid schema",
381 Error::InvalidHeader(_) | Error::InvalidQueryParam(_) => "Invalid request",
382 Error::Database(_) => "Database error",
383 Error::ConnectionPool(_) => "Service temporarily unavailable",
384 Error::Internal(_) => "Internal server error",
385 _ => "An error occurred",
386 }
387}