postrust_server/
app.rs

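//! HTTP request handling: authentication, request parsing, planning, execution against PostgreSQL, and response formatting.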
2
3use crate::state::AppState;
4use axum::{
5    body::Body,
6    extract::{Request, State},
7    http::StatusCode,
8    response::{IntoResponse, Response},
9};
10use bytes::Bytes;
11use postrust_auth::authenticate;
12use postrust_core::{create_action_plan, parse_request, ActionPlan, ApiRequest};
13use postrust_response::{format_response, QueryResult, Response as PgrstResponse};
14use sqlx::Row;
15use std::sync::Arc;
16use tracing::{debug, error, info, warn};
17
18/// Main request handler.
19pub async fn handle_request(
20    State(state): State<Arc<AppState>>,
21    request: Request,
22) -> Response {
23    let method = request.method().clone();
24    let path = request.uri().path().to_string();
25
26    debug!("{} {}", method, path);
27
28    match process_request(state, request).await {
29        Ok(response) => response.into_response(),
30        Err(e) => error_response(e).into_response(),
31    }
32}
33
34/// Process a request and return a response.
35async fn process_request(
36    state: Arc<AppState>,
37    request: Request,
38) -> Result<Response, postrust_core::Error> {
39    // Extract auth header
40    let auth_header = request
41        .headers()
42        .get("authorization")
43        .and_then(|v| v.to_str().ok());
44
45    // Authenticate
46    let auth_result = authenticate(auth_header, &state.jwt_config)
47        .map_err(|e| postrust_core::Error::InvalidJwt(e.to_string()))?;
48
49    debug!("Authenticated as role: {}", auth_result.role);
50
51    // Parse request
52    let (parts, body) = request.into_parts();
53    let body_bytes = axum::body::to_bytes(body, 10 * 1024 * 1024)
54        .await
55        .map_err(|e| postrust_core::Error::InvalidBody(e.to_string()))?;
56
57    // Build HTTP request for parsing
58    let mut builder = http::Request::builder()
59        .method(parts.method.clone())
60        .uri(parts.uri.clone());
61
62    for (key, value) in &parts.headers {
63        builder = builder.header(key, value);
64    }
65
66    let http_request = builder
67        .body(body_bytes.clone())
68        .map_err(|e| postrust_core::Error::Internal(e.to_string()))?;
69
70    // Parse API request
71    let mut api_request = parse_request(
72        &http_request,
73        state.default_schema(),
74        state.schemas(),
75    )?;
76
77    // Parse payload
78    if !body_bytes.is_empty() {
79        let payload = postrust_core::api_request::payload::parse_payload(
80            body_bytes,
81            &api_request.content_media_type,
82        )?;
83        api_request.payload = payload;
84    }
85
86    // Get schema cache
87    let schema_cache = state.schema_cache().await;
88
89    // Create execution plan
90    let plan = create_action_plan(&api_request, &schema_cache)?;
91
92    // Execute plan
93    let result = execute_plan(&state, &api_request, &plan, &auth_result).await?;
94
95    // Format response
96    let response = format_response(&api_request, &result)
97        .map_err(|e| postrust_core::Error::Internal(e.to_string()))?;
98
99    Ok(build_response(response))
100}
101
102/// Execute an action plan.
103async fn execute_plan(
104    state: &AppState,
105    request: &ApiRequest,
106    plan: &ActionPlan,
107    auth: &postrust_auth::AuthResult,
108) -> Result<QueryResult, postrust_core::Error> {
109    match plan {
110        ActionPlan::Db(db_plan) => {
111            // Build SQL
112            let query = postrust_core::query::build_query(
113                &ActionPlan::Db(db_plan.clone()),
114                Some(&auth.role),
115            )?;
116
117            if !query.has_main() {
118                return Ok(QueryResult::default());
119            }
120
121            let (sql, params) = query.build_main();
122            debug!("Executing SQL: {}", sql);
123
124            // Execute query
125            let mut conn = state.pool.acquire().await
126                .map_err(|e| postrust_core::Error::ConnectionPool(e.to_string()))?;
127
128            // Set role
129            sqlx::query(&format!(
130                "SET LOCAL ROLE {}",
131                postrust_sql::escape_ident(&auth.role)
132            ))
133            .execute(&mut *conn)
134            .await
135            .map_err(|e| postrust_core::Error::Database(postrust_core::error::DatabaseError {
136                code: "42501".into(),
137                message: e.to_string(),
138                details: None,
139                hint: None,
140                constraint: None,
141                table: None,
142                column: None,
143            }))?;
144
145            // Set claims as GUC
146            for (key, value) in &auth.claims {
147                let guc_key = format!("request.jwt.claims.{}", key);
148                let guc_value = match value {
149                    serde_json::Value::String(s) => s.clone(),
150                    other => other.to_string(),
151                };
152
153                sqlx::query("SELECT set_config($1, $2, true)")
154                    .bind(&guc_key)
155                    .bind(&guc_value)
156                    .execute(&mut *conn)
157                    .await
158                    .ok(); // Ignore errors for individual claims
159            }
160
161            // Execute main query
162            let rows = sqlx::query(&sql)
163                .fetch_all(&mut *conn)
164                .await
165                .map_err(|e| {
166                    error!("Query error: {}", e);
167                    map_sqlx_error(e)
168                })?;
169
170            // Convert rows to JSON
171            let json_rows: Vec<serde_json::Value> = rows
172                .iter()
173                .map(|row| row_to_json(row))
174                .collect();
175
176            Ok(QueryResult {
177                status: StatusCode::OK,
178                rows: json_rows,
179                total_count: None,
180                content_range: None,
181                location: None,
182                guc_headers: None,
183                guc_status: None,
184            })
185        }
186        ActionPlan::Info(info_plan) => {
187            use postrust_core::plan::InfoPlan;
188
189            // Return appropriate metadata based on the info type
190            let response_data = match info_plan {
191                InfoPlan::OpenApiSpec => {
192                    // Return basic server info for root endpoint
193                    serde_json::json!({
194                        "name": "postrust",
195                        "version": env!("CARGO_PKG_VERSION"),
196                        "description": "PostgREST-compatible REST API for PostgreSQL"
197                    })
198                }
199                InfoPlan::RelationInfo(qi) => {
200                    serde_json::json!({
201                        "schema": qi.schema,
202                        "name": qi.name,
203                        "type": "relation"
204                    })
205                }
206                InfoPlan::RoutineInfo(qi) => {
207                    serde_json::json!({
208                        "schema": qi.schema,
209                        "name": qi.name,
210                        "type": "routine"
211                    })
212                }
213            };
214
215            Ok(QueryResult {
216                status: StatusCode::OK,
217                rows: vec![response_data],
218                ..Default::default()
219            })
220        }
221    }
222}
223
224/// Convert a sqlx row to JSON.
225fn row_to_json(row: &sqlx::postgres::PgRow) -> serde_json::Value {
226    use sqlx::{Column, Row, TypeInfo};
227
228    let mut map = serde_json::Map::new();
229
230    for column in row.columns() {
231        let name = column.name();
232        let type_name = column.type_info().name();
233
234        let value = match type_name {
235            "INT2" | "SMALLINT" => row
236                .try_get::<i16, _>(name)
237                .ok()
238                .map(|v| serde_json::Value::Number(v.into())),
239            "INT4" | "INT" | "INTEGER" => row
240                .try_get::<i32, _>(name)
241                .ok()
242                .map(|v| serde_json::Value::Number(v.into())),
243            "INT8" | "BIGINT" => row
244                .try_get::<i64, _>(name)
245                .ok()
246                .map(|v| serde_json::Value::Number(v.into())),
247            "FLOAT4" | "REAL" => row
248                .try_get::<f32, _>(name)
249                .ok()
250                .and_then(|v| serde_json::Number::from_f64(v as f64))
251                .map(serde_json::Value::Number),
252            "FLOAT8" | "DOUBLE PRECISION" => row
253                .try_get::<f64, _>(name)
254                .ok()
255                .and_then(|v| serde_json::Number::from_f64(v))
256                .map(serde_json::Value::Number),
257            "NUMERIC" | "DECIMAL" => row
258                .try_get::<sqlx::types::BigDecimal, _>(name)
259                .ok()
260                .map(|v| serde_json::Value::String(v.to_string())),
261            "BOOL" | "BOOLEAN" => row
262                .try_get::<bool, _>(name)
263                .ok()
264                .map(serde_json::Value::Bool),
265            "JSON" | "JSONB" => row.try_get::<serde_json::Value, _>(name).ok(),
266            "UUID" => row
267                .try_get::<sqlx::types::Uuid, _>(name)
268                .ok()
269                .map(|v| serde_json::Value::String(v.to_string())),
270            "TIMESTAMPTZ" | "TIMESTAMP WITH TIME ZONE" => row
271                .try_get::<chrono::DateTime<chrono::Utc>, _>(name)
272                .ok()
273                .map(|v| serde_json::Value::String(v.to_rfc3339())),
274            "TIMESTAMP" | "TIMESTAMP WITHOUT TIME ZONE" => row
275                .try_get::<chrono::NaiveDateTime, _>(name)
276                .ok()
277                .map(|v| serde_json::Value::String(v.to_string())),
278            "DATE" => row
279                .try_get::<chrono::NaiveDate, _>(name)
280                .ok()
281                .map(|v| serde_json::Value::String(v.to_string())),
282            "TIME" | "TIME WITHOUT TIME ZONE" => row
283                .try_get::<chrono::NaiveTime, _>(name)
284                .ok()
285                .map(|v| serde_json::Value::String(v.to_string())),
286            _ => row
287                .try_get::<String, _>(name)
288                .ok()
289                .map(serde_json::Value::String),
290        };
291
292        map.insert(name.to_string(), value.unwrap_or(serde_json::Value::Null));
293    }
294
295    serde_json::Value::Object(map)
296}
297
298/// Map sqlx error to our error type.
299fn map_sqlx_error(e: sqlx::Error) -> postrust_core::Error {
300    match e {
301        sqlx::Error::Database(db_err) => {
302            // Try to downcast to Postgres-specific error for additional details
303            let (details, hint) = db_err
304                .try_downcast_ref::<sqlx::postgres::PgDatabaseError>()
305                .map(|pg_err| (pg_err.detail().map(String::from), pg_err.hint().map(String::from)))
306                .unwrap_or((None, None));
307
308            postrust_core::Error::Database(postrust_core::error::DatabaseError {
309                code: db_err.code().map(|c| c.to_string()).unwrap_or_default(),
310                message: db_err.message().to_string(),
311                details,
312                hint,
313                constraint: db_err.constraint().map(|s| s.to_string()),
314                table: db_err.table().map(|s| s.to_string()),
315                column: None,
316            })
317        }
318        other => postrust_core::Error::Internal(other.to_string()),
319    }
320}
321
322/// Build an HTTP response from our response type.
323fn build_response(response: PgrstResponse) -> Response {
324    let mut builder = Response::builder().status(response.status);
325
326    for (key, value) in &response.headers {
327        builder = builder.header(key, value);
328    }
329
330    builder
331        .body(Body::from(response.body))
332        .unwrap_or_else(|_| Response::new(Body::empty()))
333}
334
335/// Build an error response.
336///
337/// In production mode (PGRST_DEBUG=false or unset), sensitive error details
338/// are hidden to prevent information leakage.
339fn error_response(error: postrust_core::Error) -> Response {
340    let status = error.status_code();
341
342    // Check if debug mode is enabled
343    let debug_mode = std::env::var("PGRST_DEBUG")
344        .map(|v| v == "true" || v == "1")
345        .unwrap_or(false);
346
347    let body = if debug_mode {
348        // Full error details in debug mode
349        serde_json::to_vec(&error.to_json()).unwrap_or_default()
350    } else {
351        // Sanitized error in production
352        let sanitized = serde_json::json!({
353            "code": error.code(),
354            "message": sanitize_error_message(&error),
355            "details": null,
356            "hint": null
357        });
358        serde_json::to_vec(&sanitized).unwrap_or_default()
359    };
360
361    Response::builder()
362        .status(status)
363        .header("content-type", "application/json")
364        .body(Body::from(body))
365        .unwrap_or_else(|_| Response::new(Body::empty()))
366}
367
368/// Sanitize error messages for production.
369fn sanitize_error_message(error: &postrust_core::Error) -> &'static str {
370    use postrust_core::Error;
371    match error {
372        Error::TableNotFound(_) | Error::NotFound(_) => "Resource not found",
373        Error::FunctionNotFound(_) => "Function not found",
374        Error::ColumnNotFound(_) | Error::UnknownColumn(_) => "Column not found",
375        Error::RelationshipNotFound(_) => "Relationship not found",
376        Error::InvalidPath(_) => "Invalid request path",
377        Error::InvalidBody(_) => "Invalid request body",
378        Error::InvalidJwt(_) | Error::JwtExpired | Error::MissingAuth => "Unauthorized",
379        Error::InsufficientPermissions(_) => "Forbidden",
380        Error::UnacceptableSchema(_) => "Invalid schema",
381        Error::InvalidHeader(_) | Error::InvalidQueryParam(_) => "Invalid request",
382        Error::Database(_) => "Database error",
383        Error::ConnectionPool(_) => "Service temporarily unavailable",
384        Error::Internal(_) => "Internal server error",
385        _ => "An error occurred",
386    }
387}