postrust_server/
app.rs

1//! Request handling.
2
3use crate::state::AppState;
4use axum::{
5    body::Body,
6    extract::{Request, State},
7    http::StatusCode,
8    response::{IntoResponse, Response},
9};
10use bytes::Bytes;
11use postrust_auth::authenticate;
12use postrust_core::{create_action_plan, parse_request, ActionPlan, ApiRequest};
13use postrust_response::{format_response, QueryResult, Response as PgrstResponse};
14use sqlx::Row;
15use std::sync::Arc;
16use tracing::{debug, error, info, warn};
17
18/// Main request handler.
19pub async fn handle_request(
20    State(state): State<Arc<AppState>>,
21    request: Request,
22) -> Response {
23    let method = request.method().clone();
24    let path = request.uri().path().to_string();
25
26    debug!("{} {}", method, path);
27
28    match process_request(state, request).await {
29        Ok(response) => response.into_response(),
30        Err(e) => error_response(e).into_response(),
31    }
32}
33
34/// Process a request and return a response.
35async fn process_request(
36    state: Arc<AppState>,
37    request: Request,
38) -> Result<Response, postrust_core::Error> {
39    // Extract auth header
40    let auth_header = request
41        .headers()
42        .get("authorization")
43        .and_then(|v| v.to_str().ok());
44
45    // Authenticate
46    let auth_result = authenticate(auth_header, &state.jwt_config)
47        .map_err(|e| postrust_core::Error::InvalidJwt(e.to_string()))?;
48
49    debug!("Authenticated as role: {}", auth_result.role);
50
51    // Parse request
52    let (parts, body) = request.into_parts();
53    let body_bytes = axum::body::to_bytes(body, 10 * 1024 * 1024)
54        .await
55        .map_err(|e| postrust_core::Error::InvalidBody(e.to_string()))?;
56
57    // Build HTTP request for parsing
58    let mut builder = http::Request::builder()
59        .method(parts.method.clone())
60        .uri(parts.uri.clone());
61
62    for (key, value) in &parts.headers {
63        builder = builder.header(key, value);
64    }
65
66    let http_request = builder
67        .body(body_bytes.clone())
68        .map_err(|e| postrust_core::Error::Internal(e.to_string()))?;
69
70    // Parse API request
71    let mut api_request = parse_request(
72        &http_request,
73        state.default_schema(),
74        state.schemas(),
75    )?;
76
77    // Parse payload
78    if !body_bytes.is_empty() {
79        let payload = postrust_core::api_request::payload::parse_payload(
80            body_bytes,
81            &api_request.content_media_type,
82        )?;
83        api_request.payload = payload;
84    }
85
86    // Get schema cache
87    let schema_cache = state.schema_cache().await;
88
89    // Create execution plan
90    let plan = create_action_plan(&api_request, &schema_cache)?;
91
92    // Execute plan
93    let result = execute_plan(&state, &api_request, &plan, &auth_result).await?;
94
95    // Format response
96    let response = format_response(&api_request, &result)
97        .map_err(|e| postrust_core::Error::Internal(e.to_string()))?;
98
99    Ok(build_response(response))
100}
101
102/// Execute an action plan.
103async fn execute_plan(
104    state: &AppState,
105    request: &ApiRequest,
106    plan: &ActionPlan,
107    auth: &postrust_auth::AuthResult,
108) -> Result<QueryResult, postrust_core::Error> {
109    match plan {
110        ActionPlan::Db(db_plan) => {
111            // Build SQL
112            let query = postrust_core::query::build_query(
113                &ActionPlan::Db(db_plan.clone()),
114                Some(&auth.role),
115            )?;
116
117            if !query.has_main() {
118                return Ok(QueryResult::default());
119            }
120
121            let (sql, params) = query.build_main();
122            debug!("Executing SQL: {}", sql);
123            debug!("With {} parameters", params.len());
124
125            // Execute query
126            let mut conn = state.pool.acquire().await
127                .map_err(|e| postrust_core::Error::ConnectionPool(e.to_string()))?;
128
129            // Set role
130            sqlx::query(&format!(
131                "SET LOCAL ROLE {}",
132                postrust_sql::escape_ident(&auth.role)
133            ))
134            .execute(&mut *conn)
135            .await
136            .map_err(|e| postrust_core::Error::Database(postrust_core::error::DatabaseError {
137                code: "42501".into(),
138                message: e.to_string(),
139                details: None,
140                hint: None,
141                constraint: None,
142                table: None,
143                column: None,
144            }))?;
145
146            // Set claims as GUC
147            for (key, value) in &auth.claims {
148                let guc_key = format!("request.jwt.claims.{}", key);
149                let guc_value = match value {
150                    serde_json::Value::String(s) => s.clone(),
151                    other => other.to_string(),
152                };
153
154                sqlx::query("SELECT set_config($1, $2, true)")
155                    .bind(&guc_key)
156                    .bind(&guc_value)
157                    .execute(&mut *conn)
158                    .await
159                    .ok(); // Ignore errors for individual claims
160            }
161
162            // Execute main query with bound parameters
163            let rows = bind_params(sqlx::query(&sql), &params)
164                .fetch_all(&mut *conn)
165                .await
166                .map_err(|e| {
167                    error!("Query error: {}", e);
168                    map_sqlx_error(e)
169                })?;
170
171            // Convert rows to JSON
172            let json_rows: Vec<serde_json::Value> = rows
173                .iter()
174                .map(|row| row_to_json(row))
175                .collect();
176
177            Ok(QueryResult {
178                status: StatusCode::OK,
179                rows: json_rows,
180                total_count: None,
181                content_range: None,
182                location: None,
183                guc_headers: None,
184                guc_status: None,
185            })
186        }
187        ActionPlan::Info(info_plan) => {
188            use postrust_core::plan::InfoPlan;
189
190            // Return appropriate metadata based on the info type
191            let response_data = match info_plan {
192                InfoPlan::OpenApiSpec => {
193                    // Return basic server info for root endpoint
194                    serde_json::json!({
195                        "name": "postrust",
196                        "version": env!("CARGO_PKG_VERSION"),
197                        "description": "PostgREST-compatible REST API for PostgreSQL"
198                    })
199                }
200                InfoPlan::RelationInfo(qi) => {
201                    serde_json::json!({
202                        "schema": qi.schema,
203                        "name": qi.name,
204                        "type": "relation"
205                    })
206                }
207                InfoPlan::RoutineInfo(qi) => {
208                    serde_json::json!({
209                        "schema": qi.schema,
210                        "name": qi.name,
211                        "type": "routine"
212                    })
213                }
214            };
215
216            Ok(QueryResult {
217                status: StatusCode::OK,
218                rows: vec![response_data],
219                ..Default::default()
220            })
221        }
222    }
223}
224
225/// Convert a sqlx row to JSON.
226fn row_to_json(row: &sqlx::postgres::PgRow) -> serde_json::Value {
227    use sqlx::{Column, Row, TypeInfo};
228
229    let mut map = serde_json::Map::new();
230
231    for column in row.columns() {
232        let name = column.name();
233        let type_name = column.type_info().name();
234
235        let value = match type_name {
236            "INT2" | "SMALLINT" => row
237                .try_get::<i16, _>(name)
238                .ok()
239                .map(|v| serde_json::Value::Number(v.into())),
240            "INT4" | "INT" | "INTEGER" => row
241                .try_get::<i32, _>(name)
242                .ok()
243                .map(|v| serde_json::Value::Number(v.into())),
244            "INT8" | "BIGINT" => row
245                .try_get::<i64, _>(name)
246                .ok()
247                .map(|v| serde_json::Value::Number(v.into())),
248            "FLOAT4" | "REAL" => row
249                .try_get::<f32, _>(name)
250                .ok()
251                .and_then(|v| serde_json::Number::from_f64(v as f64))
252                .map(serde_json::Value::Number),
253            "FLOAT8" | "DOUBLE PRECISION" => row
254                .try_get::<f64, _>(name)
255                .ok()
256                .and_then(|v| serde_json::Number::from_f64(v))
257                .map(serde_json::Value::Number),
258            "NUMERIC" | "DECIMAL" => row
259                .try_get::<sqlx::types::BigDecimal, _>(name)
260                .ok()
261                .map(|v| serde_json::Value::String(v.to_string())),
262            "BOOL" | "BOOLEAN" => row
263                .try_get::<bool, _>(name)
264                .ok()
265                .map(serde_json::Value::Bool),
266            "JSON" | "JSONB" => row.try_get::<serde_json::Value, _>(name).ok(),
267            "UUID" => row
268                .try_get::<sqlx::types::Uuid, _>(name)
269                .ok()
270                .map(|v| serde_json::Value::String(v.to_string())),
271            "TIMESTAMPTZ" | "TIMESTAMP WITH TIME ZONE" => row
272                .try_get::<chrono::DateTime<chrono::Utc>, _>(name)
273                .ok()
274                .map(|v| serde_json::Value::String(v.to_rfc3339())),
275            "TIMESTAMP" | "TIMESTAMP WITHOUT TIME ZONE" => row
276                .try_get::<chrono::NaiveDateTime, _>(name)
277                .ok()
278                .map(|v| serde_json::Value::String(v.to_string())),
279            "DATE" => row
280                .try_get::<chrono::NaiveDate, _>(name)
281                .ok()
282                .map(|v| serde_json::Value::String(v.to_string())),
283            "TIME" | "TIME WITHOUT TIME ZONE" => row
284                .try_get::<chrono::NaiveTime, _>(name)
285                .ok()
286                .map(|v| serde_json::Value::String(v.to_string())),
287            _ => row
288                .try_get::<String, _>(name)
289                .ok()
290                .map(serde_json::Value::String),
291        };
292
293        map.insert(name.to_string(), value.unwrap_or(serde_json::Value::Null));
294    }
295
296    serde_json::Value::Object(map)
297}
298
299/// Bind SqlParam values to a sqlx query.
300fn bind_params<'q>(
301    mut query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
302    params: &'q [postrust_sql::SqlParam],
303) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
304    use postrust_sql::SqlParam;
305
306    for param in params {
307        query = match param {
308            SqlParam::Null => query.bind(None::<String>),
309            SqlParam::Bool(b) => query.bind(b),
310            SqlParam::Int(n) => query.bind(n),
311            SqlParam::Float(f) => query.bind(f),
312            SqlParam::Text(s) => query.bind(s),
313            SqlParam::Bytes(b) => query.bind(b),
314            SqlParam::Json(j) => query.bind(j),
315            SqlParam::Uuid(u) => query.bind(u),
316            SqlParam::Timestamp(t) => query.bind(t),
317            SqlParam::Array(arr) => {
318                // Convert array to Vec<String> for text arrays
319                let strings: Vec<String> = arr
320                    .iter()
321                    .map(|p| match p {
322                        SqlParam::Text(s) => s.clone(),
323                        SqlParam::Int(n) => n.to_string(),
324                        SqlParam::Bool(b) => b.to_string(),
325                        other => format!("{:?}", other),
326                    })
327                    .collect();
328                query.bind(strings)
329            }
330        };
331    }
332
333    query
334}
335
336/// Map sqlx error to our error type.
337fn map_sqlx_error(e: sqlx::Error) -> postrust_core::Error {
338    match e {
339        sqlx::Error::Database(db_err) => {
340            // Try to downcast to Postgres-specific error for additional details
341            let (details, hint) = db_err
342                .try_downcast_ref::<sqlx::postgres::PgDatabaseError>()
343                .map(|pg_err| (pg_err.detail().map(String::from), pg_err.hint().map(String::from)))
344                .unwrap_or((None, None));
345
346            postrust_core::Error::Database(postrust_core::error::DatabaseError {
347                code: db_err.code().map(|c| c.to_string()).unwrap_or_default(),
348                message: db_err.message().to_string(),
349                details,
350                hint,
351                constraint: db_err.constraint().map(|s| s.to_string()),
352                table: db_err.table().map(|s| s.to_string()),
353                column: None,
354            })
355        }
356        other => postrust_core::Error::Internal(other.to_string()),
357    }
358}
359
360/// Build an HTTP response from our response type.
361fn build_response(response: PgrstResponse) -> Response {
362    let mut builder = Response::builder().status(response.status);
363
364    for (key, value) in &response.headers {
365        builder = builder.header(key, value);
366    }
367
368    builder
369        .body(Body::from(response.body))
370        .unwrap_or_else(|_| Response::new(Body::empty()))
371}
372
373/// Build an error response.
374///
375/// In production mode (PGRST_DEBUG=false or unset), sensitive error details
376/// are hidden to prevent information leakage.
377fn error_response(error: postrust_core::Error) -> Response {
378    let status = error.status_code();
379
380    // Check if debug mode is enabled
381    let debug_mode = std::env::var("PGRST_DEBUG")
382        .map(|v| v == "true" || v == "1")
383        .unwrap_or(false);
384
385    let body = if debug_mode {
386        // Full error details in debug mode
387        serde_json::to_vec(&error.to_json()).unwrap_or_default()
388    } else {
389        // Sanitized error in production
390        let sanitized = serde_json::json!({
391            "code": error.code(),
392            "message": sanitize_error_message(&error),
393            "details": null,
394            "hint": null
395        });
396        serde_json::to_vec(&sanitized).unwrap_or_default()
397    };
398
399    Response::builder()
400        .status(status)
401        .header("content-type", "application/json")
402        .body(Body::from(body))
403        .unwrap_or_else(|_| Response::new(Body::empty()))
404}
405
406/// Sanitize error messages for production.
407fn sanitize_error_message(error: &postrust_core::Error) -> &'static str {
408    use postrust_core::Error;
409    match error {
410        Error::TableNotFound(_) | Error::NotFound(_) => "Resource not found",
411        Error::FunctionNotFound(_) => "Function not found",
412        Error::ColumnNotFound(_) | Error::UnknownColumn(_) => "Column not found",
413        Error::RelationshipNotFound(_) => "Relationship not found",
414        Error::InvalidPath(_) => "Invalid request path",
415        Error::InvalidBody(_) => "Invalid request body",
416        Error::InvalidJwt(_) | Error::JwtExpired | Error::MissingAuth => "Unauthorized",
417        Error::InsufficientPermissions(_) => "Forbidden",
418        Error::UnacceptableSchema(_) => "Invalid schema",
419        Error::InvalidHeader(_) | Error::InvalidQueryParam(_) => "Invalid request",
420        Error::Database(_) => "Database error",
421        Error::ConnectionPool(_) => "Service temporarily unavailable",
422        Error::Internal(_) => "Internal server error",
423        _ => "An error occurred",
424    }
425}