Skip to main content

dbrest_postgres/
dialect.rs

1//! PostgreSQL SQL dialect implementation.
2
3use dbrest_core::backend::SqlDialect;
4use dbrest_core::plan::types::CoercibleField;
5use dbrest_core::query::sql_builder::{SqlBuilder, SqlParam};
6
/// PostgreSQL dialect — generates PG-specific SQL syntax.
///
/// This is a zero-sized marker type: all behavior lives in its `SqlDialect`
/// implementation, so values are free to construct, copy and compare.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct PgDialect;
10
11impl SqlDialect for PgDialect {
12    fn json_agg_with_columns(&self, b: &mut SqlBuilder, alias: &str, _columns: &[&str]) {
13        // PostgreSQL can aggregate entire row aliases — columns not needed.
14        b.push("coalesce(json_agg(");
15        b.push_ident(alias);
16        b.push("), '[]')::text");
17    }
18
19    fn row_to_json_with_columns(&self, b: &mut SqlBuilder, alias: &str, _columns: &[&str]) {
20        b.push("row_to_json(");
21        b.push_ident(alias);
22        b.push(")::text");
23    }
24
25    fn count_expr(&self, b: &mut SqlBuilder, expr: &str) {
26        b.push("pg_catalog.count(");
27        b.push_ident(expr);
28        b.push(")");
29    }
30
31    fn count_star(&self, b: &mut SqlBuilder) {
32        b.push("SELECT COUNT(*) AS ");
33        b.push_ident("dbrst_filtered_count");
34    }
35
36    fn set_session_var(&self, b: &mut SqlBuilder, key: &str, value: &str) {
37        b.push("set_config(");
38        b.push_literal(key);
39        b.push(", ");
40        b.push_literal(value);
41        b.push(", true)");
42    }
43
44    fn get_session_var(&self, b: &mut SqlBuilder, key: &str, column_alias: &str) {
45        b.push("nullif(current_setting('");
46        b.push(key);
47        b.push("', true), '') AS ");
48        b.push(column_alias);
49    }
50
51    fn type_cast(&self, b: &mut SqlBuilder, expr: &str, ty: &str) {
52        b.push(expr);
53        b.push("::");
54        b.push(ty);
55    }
56
57    fn from_json_body(&self, b: &mut SqlBuilder, columns: &[CoercibleField], json_bytes: &[u8]) {
58        let is_array = json_bytes.first().map(|&c| c == b'[').unwrap_or(false);
59        let func = if is_array {
60            "json_to_recordset"
61        } else {
62            "json_to_record"
63        };
64        b.push(func);
65        b.push("(");
66        b.push_param(SqlParam::Text(
67            String::from_utf8_lossy(json_bytes).into_owned(),
68        ));
69        b.push("::json) AS _(");
70        b.push_separated(", ", columns, |b, col| {
71            b.push_ident(&col.name);
72            b.push(" ");
73            b.push(col.base_type.as_deref().unwrap_or("text"));
74        });
75        b.push(")");
76    }
77
78    fn push_type_cast_suffix(&self, b: &mut SqlBuilder, ty: &str) {
79        b.push("::");
80        b.push(ty);
81    }
82
83    fn push_array_type_cast_suffix(&self, b: &mut SqlBuilder, ty: &str) {
84        b.push("::");
85        b.push(ty);
86        b.push("[]");
87    }
88
89    fn quote_ident(&self, ident: &str) -> String {
90        format!("\"{}\"", ident.replace('"', "\"\""))
91    }
92
93    fn quote_literal(&self, lit: &str) -> String {
94        format!("'{}'", lit.replace('\'', "''"))
95    }
96
97    fn supports_fts(&self) -> bool {
98        true
99    }
100
101    fn fts_predicate(
102        &self,
103        b: &mut SqlBuilder,
104        config: Option<&str>,
105        column: &str,
106        operator: &str,
107    ) {
108        b.push("to_tsvector(");
109        if let Some(cfg) = config {
110            b.push_literal(cfg);
111            b.push(", ");
112        }
113        b.push_ident(column);
114        b.push(") @@ ");
115        b.push(operator);
116        b.push("(");
117        if let Some(cfg) = config {
118            b.push_literal(cfg);
119            b.push(", ");
120        }
121    }
122
123    fn row_to_json_star(&self, b: &mut SqlBuilder, source: &str) {
124        b.push("row_to_json(");
125        b.push(source);
126        b.push(".*)::text");
127    }
128
129    fn count_star_from(&self, b: &mut SqlBuilder, source: &str) {
130        b.push("(SELECT pg_catalog.count(*) FROM ");
131        b.push(source);
132        b.push(")");
133    }
134
135    fn push_literal(&self, b: &mut SqlBuilder, s: &str) {
136        let has_backslash = s.contains('\\');
137        if has_backslash {
138            b.push("E");
139        }
140        b.push("'");
141        for ch in s.chars() {
142            if ch == '\'' {
143                b.push("'");
144            }
145            b.push_char(ch);
146        }
147        b.push("'");
148    }
149
150    fn supports_lateral_join(&self) -> bool {
151        true
152    }
153
154    fn named_param_assign(&self) -> &str {
155        " := "
156    }
157}