Skip to main content

dbrest_postgres/
executor.rs

1//! PostgreSQL backend executor — implements [`DatabaseBackend`] for `sqlx::PgPool`.
2
3use std::time::Duration;
4
5use async_trait::async_trait;
6use sqlx::Row;
7use sqlx::postgres::PgPoolOptions;
8
9use crate::introspector::SqlxIntrospector;
10use dbrest_core::backend::{DatabaseBackend, DbVersion, StatementResult};
11use dbrest_core::error::Error;
12use dbrest_core::query::sql_builder::{SqlBuilder, SqlParam};
13use dbrest_core::schema_cache::db::DbIntrospector;
14
15/// PostgreSQL backend backed by `sqlx::PgPool`.
16pub struct PgBackend {
17    pool: sqlx::PgPool,
18}
19
20impl PgBackend {
21    /// Get a reference to the underlying pool (for callers that still need it
22    /// during the migration period).
23    pub fn pool(&self) -> &sqlx::PgPool {
24        &self.pool
25    }
26
27    /// Create from an existing pool (useful for tests and migration).
28    pub fn from_pool(pool: sqlx::PgPool) -> Self {
29        Self { pool }
30    }
31}
32
33// --------------------------------------------------------------------------
34// Helper: bind SqlParam values to a sqlx query
35// --------------------------------------------------------------------------
36
37fn bind_params<'q>(
38    mut q: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
39    params: &'q [SqlParam],
40) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
41    for p in params {
42        match p {
43            SqlParam::Text(t) => q = q.bind(t.as_str()),
44            SqlParam::Json(j) => q = q.bind(j.to_vec()),
45            SqlParam::Binary(b) => q = q.bind(b.to_vec()),
46            SqlParam::Null => q = q.bind(Option::<String>::None),
47        }
48    }
49    q
50}
51
52/// Map a `sqlx::Error` to our `Error` type, detecting PostgreSQL-specific
53/// constraint violations and error codes.
54pub fn map_sqlx_error(e: sqlx::Error) -> Error {
55    let (code, message, detail, hint) = match &e {
56        sqlx::Error::Database(db_err) => {
57            let code = db_err.code().map(|c| c.to_string());
58            let message = db_err.message().to_string();
59            let detail = db_err.constraint().map(|c| c.to_string());
60
61            let hint = if let Some(pg_err) =
62                db_err.try_downcast_ref::<sqlx::postgres::PgDatabaseError>()
63            {
64                pg_err.hint().map(|s| s.to_string())
65            } else {
66                None
67            };
68
69            (code, message, detail, hint)
70        }
71        _ => {
72            return Error::Database {
73                code: None,
74                message: e.to_string(),
75                detail: None,
76                hint: None,
77            };
78        }
79    };
80
81    if code.is_some() || !message.is_empty() {
82        match code.as_deref() {
83            // Constraint violations
84            Some("23505") => return Error::UniqueViolation(message),
85            Some("23503") => return Error::ForeignKeyViolation(message),
86            Some("23514") => return Error::CheckViolation(message),
87            Some("23502") => return Error::NotNullViolation(message),
88            Some("23P01") => return Error::ExclusionViolation(message),
89
90            // Permission errors
91            Some("42501") => {
92                let role =
93                    extract_role_from_message(&message).unwrap_or_else(|| "unknown".to_string());
94                return Error::PermissionDenied { role };
95            }
96
97            // Not found errors
98            Some("42883") => {
99                if message.contains("operator") {
100                    return Error::Database {
101                        code: Some("42883".to_string()),
102                        message: message.clone(),
103                        detail: Some(
104                            "Operator error: The requested operator is not available for the given data types."
105                                .to_string(),
106                        ),
107                        hint: Some(
108                            "Check that the filter operator and column types are compatible."
109                                .to_string(),
110                        ),
111                    };
112                }
113                let func_name =
114                    extract_name_from_message(&message, "function").unwrap_or_else(|| {
115                        tracing::debug!(
116                            "Could not extract function name from PostgreSQL error: {}",
117                            message
118                        );
119                        "unknown".to_string()
120                    });
121                return Error::FunctionNotFound { name: func_name };
122            }
123            Some("42P01") => {
124                let table_name = extract_name_from_message(&message, "relation")
125                    .unwrap_or_else(|| "unknown".to_string());
126                return Error::TableNotFound {
127                    name: table_name,
128                    suggestion: None,
129                };
130            }
131            Some("42703") => {
132                if let Some(col_start) = message.find("column ")
133                    && let Some(after_col) = message.get(col_start + 7..)
134                {
135                    let col_end = after_col.find(" does").unwrap_or(after_col.len());
136                    let col_ref = &after_col[..col_end];
137                    let col_ref = col_ref.trim();
138
139                    let (table_name, col_name) = if let Some(dot_pos) = col_ref.find('.') {
140                        let table = col_ref[..dot_pos].trim_matches('"').to_string();
141                        let col = col_ref[dot_pos + 1..].trim_matches('"').to_string();
142                        (table, col)
143                    } else {
144                        let col = col_ref.trim_matches('"').to_string();
145                        ("unknown".to_string(), col)
146                    };
147                    return Error::ColumnNotFound {
148                        table: table_name,
149                        column: col_name,
150                    };
151                }
152                return Error::InvalidQueryParam {
153                    param: "column".to_string(),
154                    message,
155                };
156            }
157
158            // RAISE exceptions
159            Some("P0001") => {
160                return Error::RaisedException {
161                    message,
162                    status: None,
163                };
164            }
165
166            // PostgREST custom codes (PT***)
167            Some(code) if code.starts_with("PT") => {
168                if let Some(status_str) = code.strip_prefix("PT")
169                    && let Ok(status) = status_str.parse::<u16>()
170                {
171                    return Error::DbrstRaise { message, status };
172                }
173            }
174
175            _ => {}
176        }
177
178        return Error::Database {
179            code,
180            message,
181            detail,
182            hint,
183        };
184    }
185
186    Error::Database {
187        code: None,
188        message: e.to_string(),
189        detail: None,
190        hint: None,
191    }
192}
193
/// Pull the token that follows the first `role ` in `msg`.
///
/// The token ends at the first space, `\n` or `\r`, or at the end of the
/// string. Surrounding quotes, if any, are kept. Returns `None` when `msg`
/// does not contain `role `.
fn extract_role_from_message(msg: &str) -> Option<String> {
    let (_, tail) = msg.split_once("role ")?;
    // `split` always yields at least one piece, so `unwrap_or` is only a
    // formality to avoid a panic path.
    let role = tail.split([' ', '\n', '\r']).next().unwrap_or(tail);
    Some(role.to_string())
}
204
/// Extract the identifier that follows the first occurrence of `keyword` in
/// `msg` (e.g. the relation name in `relation "users" does not exist`).
///
/// Resolution order:
/// 1. If the identifier is double-quoted and the closing quote is followed by
///    a delimiter or the end of the string, everything between the quotes is
///    returned, so names containing spaces, commas or parentheses survive.
/// 2. Otherwise the text up to the first space, `,`, `(`, `\n` or `\r` is
///    returned with surrounding quotes trimmed.
/// 3. Failing that, the first whitespace-delimited token is used.
///
/// Returns `None` when `keyword` is absent or no non-empty name can be found.
fn extract_name_from_message(msg: &str, keyword: &str) -> Option<String> {
    const DELIMITERS: [char; 5] = [' ', ',', '(', '\n', '\r'];

    let (_, rest) = msg.split_once(keyword)?;
    let rest = rest.trim_start();

    // Quoted identifier: honour the closing quote instead of cutting at the
    // first delimiter inside the name.
    if let Some((name, after)) = rest
        .strip_prefix('"')
        .and_then(|quoted| quoted.split_once('"'))
    {
        let properly_terminated = after.is_empty() || after.starts_with(DELIMITERS);
        if properly_terminated && !name.is_empty() {
            return Some(name.to_string());
        }
    }

    if let Some(end) = rest.find(DELIMITERS) {
        let name = rest[..end].trim_matches('"');
        if !name.is_empty() {
            return Some(name.to_string());
        }
    }

    let name = rest.split_whitespace().next()?.trim_matches('"');
    (!name.is_empty()).then(|| name.to_string())
}
226
227// --------------------------------------------------------------------------
228// Parse the standard 5-column result set
229// --------------------------------------------------------------------------
230
231fn parse_statement_row(row: &sqlx::postgres::PgRow) -> StatementResult {
232    let total: Option<i64> = row
233        .try_get::<String, _>("total_result_set")
234        .ok()
235        .and_then(|s| s.parse::<i64>().ok());
236
237    let page_total: i64 = row.try_get("page_total").unwrap_or(0);
238
239    let body_str: String = row.try_get("body").unwrap_or_else(|_| "[]".to_string());
240
241    let response_headers: Option<serde_json::Value> = row
242        .try_get::<Option<String>, _>("response_headers")
243        .ok()
244        .flatten()
245        .and_then(|s| {
246            if s.is_empty() {
247                None
248            } else {
249                serde_json::from_str(&s).ok()
250            }
251        });
252
253    let response_status: Option<i32> = row
254        .try_get::<Option<String>, _>("response_status")
255        .ok()
256        .flatten()
257        .and_then(|s| {
258            if s.is_empty() {
259                None
260            } else {
261                s.parse::<i32>().ok()
262            }
263        });
264
265    StatementResult {
266        total,
267        page_total,
268        body: body_str,
269        response_headers,
270        response_status,
271    }
272}
273
274// --------------------------------------------------------------------------
275// DatabaseBackend implementation
276// --------------------------------------------------------------------------
277
278#[async_trait]
279impl DatabaseBackend for PgBackend {
280    async fn connect(
281        uri: &str,
282        pool_size: u32,
283        acquire_timeout_secs: u64,
284        max_lifetime_secs: u64,
285        idle_timeout_secs: u64,
286    ) -> Result<Self, Error> {
287        let pool = PgPoolOptions::new()
288            .max_connections(pool_size)
289            .acquire_timeout(Duration::from_secs(acquire_timeout_secs))
290            .max_lifetime(Duration::from_secs(max_lifetime_secs))
291            .idle_timeout(Duration::from_secs(idle_timeout_secs))
292            .connect(uri)
293            .await
294            .map_err(|e| Error::DbConnection(e.to_string()))?;
295
296        Ok(Self { pool })
297    }
298
299    async fn version(&self) -> Result<DbVersion, Error> {
300        let row: (String,) = sqlx::query_as("SHOW server_version")
301            .fetch_one(&self.pool)
302            .await
303            .map_err(|e| Error::DbConnection(format!("Failed to query PG version: {}", e)))?;
304
305        let version_str = &row.0;
306        let parts: Vec<&str> = version_str.split('.').collect();
307        Ok(DbVersion {
308            major: parts.first().and_then(|s| s.parse().ok()).unwrap_or(0),
309            minor: parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0),
310            patch: parts
311                .get(2)
312                .and_then(|s| s.split_whitespace().next().and_then(|v| v.parse().ok()))
313                .unwrap_or(0),
314            engine: "PostgreSQL".to_string(),
315        })
316    }
317
318    fn min_version(&self) -> (u32, u32) {
319        (12, 0)
320    }
321
322    async fn exec_raw(&self, sql: &str, params: &[SqlParam]) -> Result<(), Error> {
323        let q = sqlx::query(sql);
324        let q = bind_params(q, params);
325        q.execute(&self.pool).await.map_err(map_sqlx_error)?;
326        Ok(())
327    }
328
329    async fn exec_statement(
330        &self,
331        sql: &str,
332        params: &[SqlParam],
333    ) -> Result<StatementResult, Error> {
334        let q = sqlx::query(sql);
335        let q = bind_params(q, params);
336        let rows = q.fetch_all(&self.pool).await.map_err(map_sqlx_error)?;
337
338        if rows.is_empty() {
339            return Ok(StatementResult::empty());
340        }
341
342        Ok(parse_statement_row(&rows[0]))
343    }
344
345    async fn exec_in_transaction(
346        &self,
347        tx_vars: Option<&SqlBuilder>,
348        pre_req: Option<&SqlBuilder>,
349        _mutation: Option<&SqlBuilder>,
350        main: Option<&SqlBuilder>,
351    ) -> Result<StatementResult, Error> {
352        let mut tx = self.pool.begin().await.map_err(|e| Error::Database {
353            code: None,
354            message: e.to_string(),
355            detail: None,
356            hint: None,
357        })?;
358
359        // 1. Set session variables
360        if let Some(tv) = tx_vars {
361            let q = sqlx::query(tv.sql());
362            let q = bind_params(q, tv.params());
363            q.execute(&mut *tx).await.map_err(map_sqlx_error)?;
364        }
365
366        // 2. Call pre-request function
367        if let Some(pr) = pre_req {
368            let q = sqlx::query(pr.sql());
369            let q = bind_params(q, pr.params());
370            q.execute(&mut *tx).await.map_err(map_sqlx_error)?;
371        }
372
373        // 3. Execute the main query
374        let result = if let Some(main_q) = main {
375            let q = sqlx::query(main_q.sql());
376            let q = bind_params(q, main_q.params());
377            let rows = q.fetch_all(&mut *tx).await.map_err(map_sqlx_error)?;
378
379            if rows.is_empty() {
380                StatementResult::empty()
381            } else {
382                parse_statement_row(&rows[0])
383            }
384        } else {
385            StatementResult::empty()
386        };
387
388        tx.commit().await.map_err(|e| Error::Database {
389            code: None,
390            message: e.to_string(),
391            detail: None,
392            hint: None,
393        })?;
394
395        Ok(result)
396    }
397
398    fn introspector(&self) -> Box<dyn DbIntrospector + '_> {
399        Box::new(SqlxIntrospector::new(&self.pool))
400    }
401
402    async fn start_listener(
403        &self,
404        channel: &str,
405        cancel: tokio::sync::watch::Receiver<bool>,
406        on_event: std::sync::Arc<dyn Fn(String) + Send + Sync>,
407    ) -> Result<(), Error> {
408        let mut listener = sqlx::postgres::PgListener::connect_with(&self.pool)
409            .await
410            .map_err(|e| Error::Database {
411                code: None,
412                message: e.to_string(),
413                detail: None,
414                hint: None,
415            })?;
416
417        listener
418            .listen(channel)
419            .await
420            .map_err(|e| Error::Database {
421                code: None,
422                message: e.to_string(),
423                detail: None,
424                hint: None,
425            })?;
426
427        tracing::info!(channel = channel, "Subscribed to NOTIFY channel");
428
429        // Process events in a sub-function to avoid borrow checker issues
430        // with on_event's drop order vs notification payload lifetime.
431        loop {
432            if *cancel.borrow() {
433                return Ok(());
434            }
435
436            let notification = tokio::time::timeout(Duration::from_secs(30), listener.recv()).await;
437
438            // Extract payload as an owned String before calling on_event,
439            // so the PgNotification (which borrows from the listener) is
440            // dropped before the closure is invoked.
441            let maybe_payload: Option<Result<String, sqlx::Error>> = match notification {
442                Ok(Ok(msg)) => Some(Ok(msg.payload().to_string())),
443                Ok(Err(e)) => Some(Err(e)),
444                Err(_) => None,
445            };
446
447            match maybe_payload {
448                Some(Ok(payload)) => {
449                    tracing::info!(payload = %payload, "Received NOTIFY");
450                    on_event(payload);
451                }
452                Some(Err(e)) => {
453                    return Err(Error::Database {
454                        code: None,
455                        message: e.to_string(),
456                        detail: None,
457                        hint: None,
458                    });
459                }
460                None => continue,
461            }
462        }
463    }
464
465    fn map_error(&self, err: Box<dyn std::error::Error + Send + Sync>) -> Error {
466        if let Ok(sqlx_err) = err.downcast::<sqlx::Error>() {
467            map_sqlx_error(*sqlx_err)
468        } else {
469            Error::Internal("Unknown database error".to_string())
470        }
471    }
472}