Skip to main content

fraiseql_db/dialect/
postgres.rs

1//! PostgreSQL SQL dialect implementation.
2
3use std::{borrow::Cow, fmt::Write};
4
5use super::trait_def::{RowViewColumnType, SqlDialect, UnsupportedOperator};
6
/// PostgreSQL dialect for [`GenericWhereGenerator`].
///
/// A stateless unit struct; it is `Copy`, so it can be passed by value freely.
///
/// [`GenericWhereGenerator`]: crate::where_generator::GenericWhereGenerator
#[derive(Debug, Clone, Copy)]
pub struct PostgresDialect;
11
12impl SqlDialect for PostgresDialect {
13    fn name(&self) -> &'static str {
14        "PostgreSQL"
15    }
16
17    fn quote_identifier(&self, name: &str) -> String {
18        format!("\"{}\"", name.replace('"', "\"\""))
19    }
20
21    fn json_extract_scalar(&self, column: &str, path: &[String]) -> String {
22        use crate::path_escape::{escape_postgres_jsonb_path, escape_postgres_jsonb_segment};
23
24        if path.len() == 1 {
25            let escaped = escape_postgres_jsonb_segment(&path[0]);
26            format!("{column}->>'{escaped}'")
27        } else {
28            let escaped_path = escape_postgres_jsonb_path(path);
29            let mut result = column.to_owned();
30            for (i, segment) in escaped_path.iter().enumerate() {
31                if i < escaped_path.len() - 1 {
32                    write!(result, "->'{segment}'").expect("write to String");
33                } else {
34                    write!(result, "->>'{segment}'").expect("write to String");
35                }
36            }
37            result
38        }
39    }
40
41    fn placeholder(&self, n: usize) -> String {
42        format!("${n}")
43    }
44
45    fn cast_to_numeric<'a>(&self, expr: &'a str) -> Cow<'a, str> {
46        Cow::Owned(format!("({expr})::numeric"))
47    }
48
49    fn cast_to_boolean<'a>(&self, expr: &'a str) -> Cow<'a, str> {
50        Cow::Owned(format!("({expr})::boolean"))
51    }
52
53    fn cast_param_numeric<'a>(&self, placeholder: &'a str) -> Cow<'a, str> {
54        Cow::Owned(format!("({placeholder}::text)::numeric"))
55    }
56
57    fn cast_native_param(&self, placeholder: &str, native_type: &str) -> String {
58        match native_type.to_lowercase().as_str() {
59            // bool uses QueryParam::Bool which encodes correctly in binary — no intermediate text.
60            "boolean" | "bool" => format!("{placeholder}::bool"),
61            // text/varchar/char(n) — no cast needed.
62            "text" | "varchar" | "character varying" | "char" | "bpchar" | "name" => {
63                placeholder.to_string()
64            },
65            // Everything else: two-step cast forces $N to be resolved as text by the
66            // server, avoiding binary-encoding mismatches for uuid, timestamps, ints, etc.
67            _ => format!("{placeholder}::text::{native_type}"),
68        }
69    }
70
71    fn ilike_sql(&self, lhs: &str, rhs: &str) -> String {
72        format!("{lhs} ILIKE {rhs}")
73    }
74
75    fn json_array_length(&self, expr: &str) -> String {
76        format!("jsonb_array_length({expr}::jsonb)")
77    }
78
79    fn array_contains_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
80        Ok(format!("{lhs}::jsonb @> {rhs}::jsonb"))
81    }
82
83    fn array_contained_by_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
84        Ok(format!("{lhs}::jsonb <@ {rhs}::jsonb"))
85    }
86
87    fn array_overlaps_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
88        Ok(format!("{lhs}::jsonb && {rhs}::jsonb"))
89    }
90
91    fn fts_matches_sql(&self, expr: &str, param: &str) -> Result<String, UnsupportedOperator> {
92        Ok(format!("to_tsvector({expr}) @@ to_tsquery({param})"))
93    }
94
95    fn fts_plain_query_sql(&self, expr: &str, param: &str) -> Result<String, UnsupportedOperator> {
96        Ok(format!("to_tsvector({expr}) @@ plainto_tsquery({param})"))
97    }
98
99    fn fts_phrase_query_sql(&self, expr: &str, param: &str) -> Result<String, UnsupportedOperator> {
100        Ok(format!("to_tsvector({expr}) @@ phraseto_tsquery({param})"))
101    }
102
103    fn fts_websearch_query_sql(
104        &self,
105        expr: &str,
106        param: &str,
107    ) -> Result<String, UnsupportedOperator> {
108        Ok(format!("to_tsvector({expr}) @@ websearch_to_tsquery({param})"))
109    }
110
111    fn regex_sql(
112        &self,
113        lhs: &str,
114        rhs: &str,
115        case_insensitive: bool,
116        negate: bool,
117    ) -> Result<String, UnsupportedOperator> {
118        let op = match (case_insensitive, negate) {
119            (false, false) => "~",
120            (true, false) => "~*",
121            (false, true) => "!~",
122            (true, true) => "!~*",
123        };
124        Ok(format!("{lhs} {op} {rhs}"))
125    }
126
127    // ── PostgreSQL-only operators ──────────────────────────────────────────────
128
129    fn vector_distance_sql(
130        &self,
131        pg_op: &str,
132        lhs: &str,
133        rhs: &str,
134    ) -> Result<String, UnsupportedOperator> {
135        Ok(format!("{lhs}::vector {pg_op} {rhs}::vector"))
136    }
137
138    fn jaccard_distance_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
139        Ok(format!("({lhs})::text[] <%> ({rhs})::text[]"))
140    }
141
142    fn inet_check_sql(&self, lhs: &str, check_name: &str) -> Result<String, UnsupportedOperator> {
143        match check_name {
144            "IsIPv4" => Ok(format!("family({lhs}::inet) = 4")),
145            "IsIPv6" => Ok(format!("family({lhs}::inet) = 6")),
146            "IsPrivate" => Ok(format!(
147                "({lhs}::inet << '10.0.0.0/8'::inet OR {lhs}::inet << '172.16.0.0/12'::inet OR {lhs}::inet << '192.168.0.0/16'::inet OR {lhs}::inet << 'fc00::/7'::inet)"
148            )),
149            "IsPublic" => Ok(format!(
150                "NOT ({lhs}::inet << '10.0.0.0/8'::inet OR {lhs}::inet << '172.16.0.0/12'::inet OR {lhs}::inet << '192.168.0.0/16'::inet OR {lhs}::inet << 'fc00::/7'::inet)"
151            )),
152            "IsLoopback" => Ok(format!(
153                "({lhs}::inet << '127.0.0.0/8'::inet OR {lhs}::inet << '::1/128'::inet)"
154            )),
155            "IsNotLoopback" => Ok(format!(
156                "NOT ({lhs}::inet << '127.0.0.0/8'::inet OR {lhs}::inet << '::1/128'::inet)"
157            )),
158            "IsMulticast" => Ok(format!(
159                "({lhs}::inet << '224.0.0.0/4'::inet OR {lhs}::inet << 'ff00::/8'::inet)"
160            )),
161            "IsNotMulticast" => Ok(format!(
162                "NOT ({lhs}::inet << '224.0.0.0/4'::inet OR {lhs}::inet << 'ff00::/8'::inet)"
163            )),
164            "IsLinkLocal" => Ok(format!(
165                "({lhs}::inet << '169.254.0.0/16'::inet OR {lhs}::inet << 'fe80::/10'::inet)"
166            )),
167            "IsNotLinkLocal" => Ok(format!(
168                "NOT ({lhs}::inet << '169.254.0.0/16'::inet OR {lhs}::inet << 'fe80::/10'::inet)"
169            )),
170            "IsDocumentation" => Ok(format!(
171                "({lhs}::inet << '192.0.2.0/24'::inet OR {lhs}::inet << '198.51.100.0/24'::inet OR {lhs}::inet << '203.0.113.0/24'::inet OR {lhs}::inet << '2001:db8::/32'::inet)"
172            )),
173            "IsNotDocumentation" => Ok(format!(
174                "NOT ({lhs}::inet << '192.0.2.0/24'::inet OR {lhs}::inet << '198.51.100.0/24'::inet OR {lhs}::inet << '203.0.113.0/24'::inet OR {lhs}::inet << '2001:db8::/32'::inet)"
175            )),
176            "IsCarrierGrade" => Ok(format!("({lhs}::inet << '100.64.0.0/10'::inet)")),
177            "IsNotCarrierGrade" => Ok(format!("NOT ({lhs}::inet << '100.64.0.0/10'::inet)")),
178            _ => Err(UnsupportedOperator {
179                dialect:  self.name(),
180                operator: "InetCheck",
181            }),
182        }
183    }
184
185    fn inet_binary_sql(
186        &self,
187        pg_op: &str,
188        lhs: &str,
189        rhs: &str,
190    ) -> Result<String, UnsupportedOperator> {
191        Ok(format!("{lhs}::inet {pg_op} {rhs}::inet"))
192    }
193
194    fn ltree_binary_sql(
195        &self,
196        pg_op: &str,
197        lhs: &str,
198        rhs: &str,
199        rhs_type: &str,
200    ) -> Result<String, UnsupportedOperator> {
201        Ok(format!("{lhs}::ltree {pg_op} {rhs}::{rhs_type}"))
202    }
203
204    fn ltree_any_lquery_sql(
205        &self,
206        lhs: &str,
207        placeholders: &[String],
208    ) -> Result<String, UnsupportedOperator> {
209        Ok(format!("{lhs}::ltree ? ARRAY[{}]", placeholders.join(", ")))
210    }
211
212    fn ltree_depth_sql(
213        &self,
214        op: &str,
215        lhs: &str,
216        rhs: &str,
217    ) -> Result<String, UnsupportedOperator> {
218        Ok(format!("nlevel({lhs}::ltree) {op} {rhs}"))
219    }
220
221    fn ltree_lca_sql(
222        &self,
223        lhs: &str,
224        placeholders: &[String],
225    ) -> Result<String, UnsupportedOperator> {
226        Ok(format!("{lhs}::ltree = lca(ARRAY[{}])", placeholders.join(", ")))
227    }
228
229    fn ltree_id_subquery_sql(
230        &self,
231        pg_op: &str,
232        field_expr: &str,
233        table: &str,
234        path_column: &str,
235        fk_column: Option<&str>,
236        param: &str,
237    ) -> Result<String, UnsupportedOperator> {
238        let qt = self.quote_identifier(table);
239        let qp = self.quote_identifier(path_column);
240        let qi = self.quote_identifier("id");
241        let path_subquery = format!("SELECT {qp} FROM {qt} WHERE {qi} = {param}");
242
243        if let Some(fk) = fk_column {
244            // Cross-table: fk IN (SELECT id FROM t WHERE path <op> (subquery))
245            let qfk = self.quote_identifier(fk);
246            Ok(format!("{qfk} IN (SELECT {qi} FROM {qt} WHERE {qp} {pg_op} ({path_subquery}))"))
247        } else {
248            // Self-referencing: field_expr <op> (SELECT path FROM t WHERE id = $N)
249            Ok(format!("{field_expr}::ltree {pg_op} ({path_subquery})"))
250        }
251    }
252
253    fn row_view_column_expr(
254        &self,
255        json_column: &str,
256        field_name: &str,
257        col_type: &RowViewColumnType,
258    ) -> String {
259        let pg_type = match col_type {
260            RowViewColumnType::Text => "text",
261            RowViewColumnType::Int32 => "int",
262            RowViewColumnType::Int64 => "bigint",
263            RowViewColumnType::Float64 => "double precision",
264            RowViewColumnType::Boolean => "boolean",
265            RowViewColumnType::Uuid => "uuid",
266            RowViewColumnType::Timestamptz => "timestamptz",
267            RowViewColumnType::Date => "date",
268            RowViewColumnType::Json => "jsonb",
269        };
270        format!("({json_column}->>'{field_name}')::{pg_type}")
271    }
272
273    fn generate_extended_sql(
274        &self,
275        operator: &crate::filters::ExtendedOperator,
276        field_sql: &str,
277        params: &mut Vec<serde_json::Value>,
278    ) -> fraiseql_error::Result<String> {
279        use fraiseql_error::FraiseQLError;
280
281        use crate::filters::ExtendedOperator;
282        match operator {
283            ExtendedOperator::EmailDomainEq(domain) => {
284                params.push(serde_json::Value::String(domain.clone()));
285                let idx = params.len();
286                Ok(format!("SPLIT_PART({field_sql}, '@', 2) = ${idx}"))
287            },
288            ExtendedOperator::EmailDomainIn(domains) => {
289                let placeholders: Vec<_> = domains
290                    .iter()
291                    .map(|d| {
292                        params.push(serde_json::Value::String(d.clone()));
293                        format!("${}", params.len())
294                    })
295                    .collect();
296                Ok(format!("SPLIT_PART({field_sql}, '@', 2) IN ({})", placeholders.join(", ")))
297            },
298            ExtendedOperator::EmailDomainEndswith(suffix) => {
299                let escaped = crate::where_generator::generic::escape_like_literal(suffix);
300                params.push(serde_json::Value::String(escaped));
301                let idx = params.len();
302                Ok(format!("SPLIT_PART({field_sql}, '@', 2) LIKE '%' || ${idx}"))
303            },
304            ExtendedOperator::EmailLocalPartStartswith(prefix) => {
305                let escaped = crate::where_generator::generic::escape_like_literal(prefix);
306                params.push(serde_json::Value::String(escaped));
307                let idx = params.len();
308                Ok(format!("SPLIT_PART({field_sql}, '@', 1) LIKE ${idx} || '%'"))
309            },
310            ExtendedOperator::VinWmiEq(wmi) => {
311                params.push(serde_json::Value::String(wmi.clone()));
312                let idx = params.len();
313                Ok(format!("SUBSTRING({field_sql} FROM 1 FOR 3) = ${idx}"))
314            },
315            ExtendedOperator::IbanCountryEq(country) => {
316                params.push(serde_json::Value::String(country.clone()));
317                let idx = params.len();
318                Ok(format!("SUBSTRING({field_sql} FROM 1 FOR 2) = ${idx}"))
319            },
320            _ => Err(FraiseQLError::validation(format!(
321                "Extended operator not yet implemented for PostgreSQL: {operator}"
322            ))),
323        }
324    }
325}