Skip to main content

fraiseql_db/dialect/
postgres.rs

//! PostgreSQL SQL dialect implementation.
2
3use std::{borrow::Cow, fmt::Write};
4
5use super::trait_def::{RowViewColumnType, SqlDialect, UnsupportedOperator};
6
/// PostgreSQL dialect for [`GenericWhereGenerator`].
///
/// A stateless unit struct. Its [`SqlDialect`] implementation renders
/// PostgreSQL-specific SQL:
/// - `$N` placeholders and `::` casts;
/// - JSONB operators;
/// - extension operators such as `ltree`, `inet` and `vector`.
///
/// [`GenericWhereGenerator`]: crate::where_generator::GenericWhereGenerator
#[derive(Debug, Clone, Copy, Default)]
pub struct PostgresDialect;

12impl SqlDialect for PostgresDialect {
13    fn name(&self) -> &'static str {
14        "PostgreSQL"
15    }
16
17    fn quote_identifier(&self, name: &str) -> String {
18        format!("\"{}\"", name.replace('"', "\"\""))
19    }
20
21    fn json_extract_scalar(&self, column: &str, path: &[String]) -> String {
22        use crate::path_escape::{escape_postgres_jsonb_path, escape_postgres_jsonb_segment};
23
24        if path.len() == 1 {
25            let escaped = escape_postgres_jsonb_segment(&path[0]);
26            format!("{column}->>'{escaped}'")
27        } else {
28            let escaped_path = escape_postgres_jsonb_path(path);
29            let mut result = column.to_owned();
30            for (i, segment) in escaped_path.iter().enumerate() {
31                if i < escaped_path.len() - 1 {
32                    write!(result, "->'{segment}'").expect("write to String");
33                } else {
34                    write!(result, "->>'{segment}'").expect("write to String");
35                }
36            }
37            result
38        }
39    }
40
41    fn placeholder(&self, n: usize) -> String {
42        format!("${n}")
43    }
44
45    fn cast_to_numeric<'a>(&self, expr: &'a str) -> Cow<'a, str> {
46        Cow::Owned(format!("({expr})::numeric"))
47    }
48
49    fn cast_to_boolean<'a>(&self, expr: &'a str) -> Cow<'a, str> {
50        Cow::Owned(format!("({expr})::boolean"))
51    }
52
53    fn cast_param_numeric<'a>(&self, placeholder: &'a str) -> Cow<'a, str> {
54        Cow::Owned(format!("({placeholder}::text)::numeric"))
55    }
56
57    fn cast_native_param(&self, placeholder: &str, native_type: &str) -> String {
58        match native_type.to_lowercase().as_str() {
59            // bool uses QueryParam::Bool which encodes correctly in binary — no intermediate text.
60            "boolean" | "bool" => format!("{placeholder}::bool"),
61            // text/varchar/char(n) — no cast needed.
62            "text" | "varchar" | "character varying" | "char" | "bpchar" | "name" => {
63                placeholder.to_string()
64            },
65            // Everything else: two-step cast forces $N to be resolved as text by the
66            // server, avoiding binary-encoding mismatches for uuid, timestamps, ints, etc.
67            _ => format!("{placeholder}::text::{native_type}"),
68        }
69    }
70
71    fn ilike_sql(&self, lhs: &str, rhs: &str) -> String {
72        format!("{lhs} ILIKE {rhs}")
73    }
74
75    fn json_array_length(&self, expr: &str) -> String {
76        format!("jsonb_array_length({expr}::jsonb)")
77    }
78
79    fn array_contains_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
80        Ok(format!("{lhs}::jsonb @> {rhs}::jsonb"))
81    }
82
83    fn array_contained_by_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
84        Ok(format!("{lhs}::jsonb <@ {rhs}::jsonb"))
85    }
86
87    fn array_overlaps_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
88        Ok(format!("{lhs}::jsonb && {rhs}::jsonb"))
89    }
90
91    fn fts_matches_sql(&self, expr: &str, param: &str) -> Result<String, UnsupportedOperator> {
92        Ok(format!("to_tsvector({expr}) @@ to_tsquery({param})"))
93    }
94
95    fn fts_plain_query_sql(&self, expr: &str, param: &str) -> Result<String, UnsupportedOperator> {
96        Ok(format!("to_tsvector({expr}) @@ plainto_tsquery({param})"))
97    }
98
99    fn fts_phrase_query_sql(&self, expr: &str, param: &str) -> Result<String, UnsupportedOperator> {
100        Ok(format!("to_tsvector({expr}) @@ phraseto_tsquery({param})"))
101    }
102
103    fn fts_websearch_query_sql(
104        &self,
105        expr: &str,
106        param: &str,
107    ) -> Result<String, UnsupportedOperator> {
108        Ok(format!("to_tsvector({expr}) @@ websearch_to_tsquery({param})"))
109    }
110
111    fn regex_sql(
112        &self,
113        lhs: &str,
114        rhs: &str,
115        case_insensitive: bool,
116        negate: bool,
117    ) -> Result<String, UnsupportedOperator> {
118        let op = match (case_insensitive, negate) {
119            (false, false) => "~",
120            (true, false) => "~*",
121            (false, true) => "!~",
122            (true, true) => "!~*",
123        };
124        Ok(format!("{lhs} {op} {rhs}"))
125    }
126
127    // ── PostgreSQL-only operators ──────────────────────────────────────────────
128
129    fn vector_distance_sql(
130        &self,
131        pg_op: &str,
132        lhs: &str,
133        rhs: &str,
134    ) -> Result<String, UnsupportedOperator> {
135        Ok(format!("{lhs}::vector {pg_op} {rhs}::vector"))
136    }
137
138    fn jaccard_distance_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
139        Ok(format!("({lhs})::text[] <%> ({rhs})::text[]"))
140    }
141
142    fn inet_check_sql(&self, lhs: &str, check_name: &str) -> Result<String, UnsupportedOperator> {
143        match check_name {
144            "IsIPv4" => Ok(format!("family({lhs}::inet) = 4")),
145            "IsIPv6" => Ok(format!("family({lhs}::inet) = 6")),
146            "IsPrivate" => Ok(format!(
147                "({lhs}::inet << '10.0.0.0/8'::inet OR {lhs}::inet << '172.16.0.0/12'::inet OR {lhs}::inet << '192.168.0.0/16'::inet OR {lhs}::inet << '169.254.0.0/16'::inet)"
148            )),
149            "IsPublic" => Ok(format!(
150                "NOT ({lhs}::inet << '10.0.0.0/8'::inet OR {lhs}::inet << '172.16.0.0/12'::inet OR {lhs}::inet << '192.168.0.0/16'::inet OR {lhs}::inet << '169.254.0.0/16'::inet)"
151            )),
152            "IsLoopback" => Ok(format!(
153                "(family({lhs}::inet) = 4 AND {lhs}::inet << '127.0.0.0/8'::inet) OR (family({lhs}::inet) = 6 AND {lhs}::inet << '::1/128'::inet)"
154            )),
155            _ => Err(UnsupportedOperator {
156                dialect:  self.name(),
157                operator: "InetCheck",
158            }),
159        }
160    }
161
162    fn inet_binary_sql(
163        &self,
164        pg_op: &str,
165        lhs: &str,
166        rhs: &str,
167    ) -> Result<String, UnsupportedOperator> {
168        Ok(format!("{lhs}::inet {pg_op} {rhs}::inet"))
169    }
170
171    fn ltree_binary_sql(
172        &self,
173        pg_op: &str,
174        lhs: &str,
175        rhs: &str,
176        rhs_type: &str,
177    ) -> Result<String, UnsupportedOperator> {
178        Ok(format!("{lhs}::ltree {pg_op} {rhs}::{rhs_type}"))
179    }
180
181    fn ltree_any_lquery_sql(
182        &self,
183        lhs: &str,
184        placeholders: &[String],
185    ) -> Result<String, UnsupportedOperator> {
186        Ok(format!("{lhs}::ltree ? ARRAY[{}]", placeholders.join(", ")))
187    }
188
189    fn ltree_depth_sql(
190        &self,
191        op: &str,
192        lhs: &str,
193        rhs: &str,
194    ) -> Result<String, UnsupportedOperator> {
195        Ok(format!("nlevel({lhs}::ltree) {op} {rhs}"))
196    }
197
198    fn ltree_lca_sql(
199        &self,
200        lhs: &str,
201        placeholders: &[String],
202    ) -> Result<String, UnsupportedOperator> {
203        Ok(format!("{lhs}::ltree = lca(ARRAY[{}])", placeholders.join(", ")))
204    }
205
206    fn row_view_column_expr(
207        &self,
208        json_column: &str,
209        field_name: &str,
210        col_type: &RowViewColumnType,
211    ) -> String {
212        let pg_type = match col_type {
213            RowViewColumnType::Text => "text",
214            RowViewColumnType::Int32 => "int",
215            RowViewColumnType::Int64 => "bigint",
216            RowViewColumnType::Float64 => "double precision",
217            RowViewColumnType::Boolean => "boolean",
218            RowViewColumnType::Uuid => "uuid",
219            RowViewColumnType::Timestamptz => "timestamptz",
220            RowViewColumnType::Date => "date",
221            RowViewColumnType::Json => "jsonb",
222        };
223        format!("({json_column}->>'{field_name}')::{pg_type}")
224    }
225
226    fn generate_extended_sql(
227        &self,
228        operator: &crate::filters::ExtendedOperator,
229        field_sql: &str,
230        params: &mut Vec<serde_json::Value>,
231    ) -> fraiseql_error::Result<String> {
232        use fraiseql_error::FraiseQLError;
233
234        use crate::filters::ExtendedOperator;
235        match operator {
236            ExtendedOperator::EmailDomainEq(domain) => {
237                params.push(serde_json::Value::String(domain.clone()));
238                let idx = params.len();
239                Ok(format!("SPLIT_PART({field_sql}, '@', 2) = ${idx}"))
240            },
241            ExtendedOperator::EmailDomainIn(domains) => {
242                let placeholders: Vec<_> = domains
243                    .iter()
244                    .map(|d| {
245                        params.push(serde_json::Value::String(d.clone()));
246                        format!("${}", params.len())
247                    })
248                    .collect();
249                Ok(format!("SPLIT_PART({field_sql}, '@', 2) IN ({})", placeholders.join(", ")))
250            },
251            ExtendedOperator::EmailDomainEndswith(suffix) => {
252                let escaped = crate::where_generator::generic::escape_like_literal(suffix);
253                params.push(serde_json::Value::String(escaped));
254                let idx = params.len();
255                Ok(format!("SPLIT_PART({field_sql}, '@', 2) LIKE '%' || ${idx}"))
256            },
257            ExtendedOperator::EmailLocalPartStartswith(prefix) => {
258                let escaped = crate::where_generator::generic::escape_like_literal(prefix);
259                params.push(serde_json::Value::String(escaped));
260                let idx = params.len();
261                Ok(format!("SPLIT_PART({field_sql}, '@', 1) LIKE ${idx} || '%'"))
262            },
263            ExtendedOperator::VinWmiEq(wmi) => {
264                params.push(serde_json::Value::String(wmi.clone()));
265                let idx = params.len();
266                Ok(format!("SUBSTRING({field_sql} FROM 1 FOR 3) = ${idx}"))
267            },
268            ExtendedOperator::IbanCountryEq(country) => {
269                params.push(serde_json::Value::String(country.clone()));
270                let idx = params.len();
271                Ok(format!("SUBSTRING({field_sql} FROM 1 FOR 2) = ${idx}"))
272            },
273            _ => Err(FraiseQLError::validation(format!(
274                "Extended operator not yet implemented for PostgreSQL: {operator}"
275            ))),
276        }
277    }
278}