Skip to main content

fraiseql_db/dialect/
postgres.rs

1//! PostgreSQL SQL dialect implementation.
2
3use std::{borrow::Cow, fmt::Write};
4
5use super::trait_def::{RowViewColumnType, SqlDialect, UnsupportedOperator};
6
/// PostgreSQL dialect for [`GenericWhereGenerator`].
///
/// The dialect carries no state, so it is trivially copyable and can be built with
/// `Default`.
///
/// [`GenericWhereGenerator`]: crate::where_generator::GenericWhereGenerator
#[derive(Debug, Clone, Copy, Default)]
pub struct PostgresDialect;
11
12impl SqlDialect for PostgresDialect {
13    fn name(&self) -> &'static str {
14        "PostgreSQL"
15    }
16
17    fn quote_identifier(&self, name: &str) -> String {
18        format!("\"{}\"", name.replace('"', "\"\""))
19    }
20
21    fn json_extract_scalar(&self, column: &str, path: &[String]) -> String {
22        use crate::path_escape::{escape_postgres_jsonb_path, escape_postgres_jsonb_segment};
23
24        if path.len() == 1 {
25            let escaped = escape_postgres_jsonb_segment(&path[0]);
26            format!("{column}->>'{escaped}'")
27        } else {
28            let escaped_path = escape_postgres_jsonb_path(path);
29            let mut result = column.to_owned();
30            for (i, segment) in escaped_path.iter().enumerate() {
31                if i < escaped_path.len() - 1 {
32                    write!(result, "->'{segment}'").expect("write to String");
33                } else {
34                    write!(result, "->>'{segment}'").expect("write to String");
35                }
36            }
37            result
38        }
39    }
40
41    fn placeholder(&self, n: usize) -> String {
42        format!("${n}")
43    }
44
45    fn cast_to_numeric<'a>(&self, expr: &'a str) -> Cow<'a, str> {
46        Cow::Owned(format!("({expr})::numeric"))
47    }
48
49    fn cast_to_boolean<'a>(&self, expr: &'a str) -> Cow<'a, str> {
50        Cow::Owned(format!("({expr})::boolean"))
51    }
52
53    fn cast_param_numeric<'a>(&self, placeholder: &'a str) -> Cow<'a, str> {
54        Cow::Owned(format!("({placeholder}::text)::numeric"))
55    }
56
57    fn ilike_sql(&self, lhs: &str, rhs: &str) -> String {
58        format!("{lhs} ILIKE {rhs}")
59    }
60
61    fn json_array_length(&self, expr: &str) -> String {
62        format!("jsonb_array_length({expr}::jsonb)")
63    }
64
65    fn array_contains_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
66        Ok(format!("{lhs}::jsonb @> {rhs}::jsonb"))
67    }
68
69    fn array_contained_by_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
70        Ok(format!("{lhs}::jsonb <@ {rhs}::jsonb"))
71    }
72
73    fn array_overlaps_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
74        Ok(format!("{lhs}::jsonb && {rhs}::jsonb"))
75    }
76
77    fn fts_matches_sql(&self, expr: &str, param: &str) -> Result<String, UnsupportedOperator> {
78        Ok(format!("to_tsvector({expr}) @@ to_tsquery({param})"))
79    }
80
81    fn fts_plain_query_sql(&self, expr: &str, param: &str) -> Result<String, UnsupportedOperator> {
82        Ok(format!("to_tsvector({expr}) @@ plainto_tsquery({param})"))
83    }
84
85    fn fts_phrase_query_sql(&self, expr: &str, param: &str) -> Result<String, UnsupportedOperator> {
86        Ok(format!("to_tsvector({expr}) @@ phraseto_tsquery({param})"))
87    }
88
89    fn fts_websearch_query_sql(
90        &self,
91        expr: &str,
92        param: &str,
93    ) -> Result<String, UnsupportedOperator> {
94        Ok(format!("to_tsvector({expr}) @@ websearch_to_tsquery({param})"))
95    }
96
97    fn regex_sql(
98        &self,
99        lhs: &str,
100        rhs: &str,
101        case_insensitive: bool,
102        negate: bool,
103    ) -> Result<String, UnsupportedOperator> {
104        let op = match (case_insensitive, negate) {
105            (false, false) => "~",
106            (true, false) => "~*",
107            (false, true) => "!~",
108            (true, true) => "!~*",
109        };
110        Ok(format!("{lhs} {op} {rhs}"))
111    }
112
113    // ── PostgreSQL-only operators ──────────────────────────────────────────────
114
115    fn vector_distance_sql(
116        &self,
117        pg_op: &str,
118        lhs: &str,
119        rhs: &str,
120    ) -> Result<String, UnsupportedOperator> {
121        Ok(format!("{lhs}::vector {pg_op} {rhs}::vector"))
122    }
123
124    fn jaccard_distance_sql(&self, lhs: &str, rhs: &str) -> Result<String, UnsupportedOperator> {
125        Ok(format!("({lhs})::text[] <%> ({rhs})::text[]"))
126    }
127
128    fn inet_check_sql(&self, lhs: &str, check_name: &str) -> Result<String, UnsupportedOperator> {
129        match check_name {
130            "IsIPv4" => Ok(format!("family({lhs}::inet) = 4")),
131            "IsIPv6" => Ok(format!("family({lhs}::inet) = 6")),
132            "IsPrivate" => Ok(format!(
133                "({lhs}::inet << '10.0.0.0/8'::inet OR {lhs}::inet << '172.16.0.0/12'::inet OR {lhs}::inet << '192.168.0.0/16'::inet OR {lhs}::inet << '169.254.0.0/16'::inet)"
134            )),
135            "IsPublic" => Ok(format!(
136                "NOT ({lhs}::inet << '10.0.0.0/8'::inet OR {lhs}::inet << '172.16.0.0/12'::inet OR {lhs}::inet << '192.168.0.0/16'::inet OR {lhs}::inet << '169.254.0.0/16'::inet)"
137            )),
138            "IsLoopback" => Ok(format!(
139                "(family({lhs}::inet) = 4 AND {lhs}::inet << '127.0.0.0/8'::inet) OR (family({lhs}::inet) = 6 AND {lhs}::inet << '::1/128'::inet)"
140            )),
141            _ => Err(UnsupportedOperator {
142                dialect:  self.name(),
143                operator: "InetCheck",
144            }),
145        }
146    }
147
148    fn inet_binary_sql(
149        &self,
150        pg_op: &str,
151        lhs: &str,
152        rhs: &str,
153    ) -> Result<String, UnsupportedOperator> {
154        Ok(format!("{lhs}::inet {pg_op} {rhs}::inet"))
155    }
156
157    fn ltree_binary_sql(
158        &self,
159        pg_op: &str,
160        lhs: &str,
161        rhs: &str,
162        rhs_type: &str,
163    ) -> Result<String, UnsupportedOperator> {
164        Ok(format!("{lhs}::ltree {pg_op} {rhs}::{rhs_type}"))
165    }
166
167    fn ltree_any_lquery_sql(
168        &self,
169        lhs: &str,
170        placeholders: &[String],
171    ) -> Result<String, UnsupportedOperator> {
172        Ok(format!("{lhs}::ltree ? ARRAY[{}]", placeholders.join(", ")))
173    }
174
175    fn ltree_depth_sql(
176        &self,
177        op: &str,
178        lhs: &str,
179        rhs: &str,
180    ) -> Result<String, UnsupportedOperator> {
181        Ok(format!("nlevel({lhs}::ltree) {op} {rhs}"))
182    }
183
184    fn ltree_lca_sql(
185        &self,
186        lhs: &str,
187        placeholders: &[String],
188    ) -> Result<String, UnsupportedOperator> {
189        Ok(format!("{lhs}::ltree = lca(ARRAY[{}])", placeholders.join(", ")))
190    }
191
192    fn row_view_column_expr(
193        &self,
194        json_column: &str,
195        field_name: &str,
196        col_type: &RowViewColumnType,
197    ) -> String {
198        let pg_type = match col_type {
199            RowViewColumnType::Text => "text",
200            RowViewColumnType::Int32 => "int",
201            RowViewColumnType::Int64 => "bigint",
202            RowViewColumnType::Float64 => "double precision",
203            RowViewColumnType::Boolean => "boolean",
204            RowViewColumnType::Uuid => "uuid",
205            RowViewColumnType::Timestamptz => "timestamptz",
206            RowViewColumnType::Date => "date",
207            RowViewColumnType::Json => "jsonb",
208        };
209        format!("({json_column}->>'{field_name}')::{pg_type}")
210    }
211
212    fn generate_extended_sql(
213        &self,
214        operator: &crate::filters::ExtendedOperator,
215        field_sql: &str,
216        params: &mut Vec<serde_json::Value>,
217    ) -> fraiseql_error::Result<String> {
218        use fraiseql_error::FraiseQLError;
219
220        use crate::filters::ExtendedOperator;
221        match operator {
222            ExtendedOperator::EmailDomainEq(domain) => {
223                params.push(serde_json::Value::String(domain.clone()));
224                let idx = params.len();
225                Ok(format!("SPLIT_PART({field_sql}, '@', 2) = ${idx}"))
226            },
227            ExtendedOperator::EmailDomainIn(domains) => {
228                let placeholders: Vec<_> = domains
229                    .iter()
230                    .map(|d| {
231                        params.push(serde_json::Value::String(d.clone()));
232                        format!("${}", params.len())
233                    })
234                    .collect();
235                Ok(format!("SPLIT_PART({field_sql}, '@', 2) IN ({})", placeholders.join(", ")))
236            },
237            ExtendedOperator::EmailDomainEndswith(suffix) => {
238                let escaped = crate::where_generator::generic::escape_like_literal(suffix);
239                params.push(serde_json::Value::String(escaped));
240                let idx = params.len();
241                Ok(format!("SPLIT_PART({field_sql}, '@', 2) LIKE '%' || ${idx}"))
242            },
243            ExtendedOperator::EmailLocalPartStartswith(prefix) => {
244                let escaped = crate::where_generator::generic::escape_like_literal(prefix);
245                params.push(serde_json::Value::String(escaped));
246                let idx = params.len();
247                Ok(format!("SPLIT_PART({field_sql}, '@', 1) LIKE ${idx} || '%'"))
248            },
249            ExtendedOperator::VinWmiEq(wmi) => {
250                params.push(serde_json::Value::String(wmi.clone()));
251                let idx = params.len();
252                Ok(format!("SUBSTRING({field_sql} FROM 1 FOR 3) = ${idx}"))
253            },
254            ExtendedOperator::IbanCountryEq(country) => {
255                params.push(serde_json::Value::String(country.clone()));
256                let idx = params.len();
257                Ok(format!("SUBSTRING({field_sql} FROM 1 FOR 2) = ${idx}"))
258            },
259            _ => Err(FraiseQLError::validation(format!(
260                "Extended operator not yet implemented for PostgreSQL: {operator}"
261            ))),
262        }
263    }
264}