1use serde_json::Value;
2
3use crate::{ast::*, renderer::Renderer};
4use std::fmt::{self, Write};
5
6#[cfg_attr(feature = "docs", doc(cfg(feature = "postgresql")))]
11pub struct Postgres {
12 query: String,
13 parameters: Vec<Value>,
14}
15
16impl<'a> Renderer<'a> for Postgres {
17 const C_BACKTICK_OPEN: &'static str = "\"";
18 const C_BACKTICK_CLOSE: &'static str = "\"";
19 const C_WILDCARD: &'static str = "%";
20
21 fn build<Q>(query: Q) -> (String, Vec<Value>)
22 where
23 Q: Into<Query<'a>>,
24 {
25 let mut postgres = Postgres {
26 query: String::with_capacity(4096),
27 parameters: Vec::with_capacity(128),
28 };
29
30 Postgres::visit_query(&mut postgres, query.into());
31
32 (postgres.query, postgres.parameters)
33 }
34
35 fn write<D: fmt::Display>(&mut self, s: D) {
36 write!(&mut self.query, "{s}")
37 .expect("we ran out of memory or something else why write failed");
38 }
39
40 fn add_parameter(&mut self, value: Value) {
41 self.parameters.push(value);
42 }
43
44 fn parameter_substitution(&mut self) {
45 self.write("$");
46 self.write(self.parameters.len())
47 }
48
49 fn visit_limit_and_offset(&mut self, limit: Option<u32>, offset: Option<u32>) {
50 match (limit, offset) {
51 (Some(limit), Some(offset)) => {
52 self.write(" LIMIT ");
53 self.visit_parameterized(Value::from(limit));
54
55 self.write(" OFFSET ");
56 self.visit_parameterized(Value::from(offset))
57 }
58 (None, Some(offset)) => {
59 self.write(" OFFSET ");
60 self.visit_parameterized(Value::from(offset))
61 }
62 (Some(limit), None) => {
63 self.write(" LIMIT ");
64 self.visit_parameterized(Value::from(limit))
65 }
66 (None, None) => (),
67 }
68 }
69
70 fn visit_insert(&mut self, insert: Insert<'a>) {
71 self.write("INSERT ");
72
73 if let Some(table) = insert.table.clone() {
74 self.write("INTO ");
75 self.visit_table(table, true);
76 }
77
78 match insert.values {
79 Expression {
80 kind: ExpressionKind::Row(row),
81 ..
82 } => {
83 if row.values.is_empty() {
84 self.write(" DEFAULT VALUES");
85 } else {
86 let columns = insert.columns.len();
87
88 self.write(" (");
89 for (i, c) in insert.columns.into_iter().enumerate() {
90 self.visit_column(c.name.into_owned().into());
91
92 if i < (columns - 1) {
93 self.write(",");
94 }
95 }
96
97 self.write(")");
98 self.write(" VALUES ");
99 self.visit_row(row);
100 }
101 }
102 Expression {
103 kind: ExpressionKind::Values(values),
104 ..
105 } => {
106 let columns = insert.columns.len();
107
108 self.write(" (");
109 for (i, c) in insert.columns.into_iter().enumerate() {
110 self.visit_column(c.name.into_owned().into());
111
112 if i < (columns - 1) {
113 self.write(",");
114 }
115 }
116
117 self.write(")");
118 self.write(" VALUES ");
119 let values_len = values.len();
120
121 for (i, row) in values.into_iter().enumerate() {
122 self.visit_row(row);
123
124 if i < (values_len - 1) {
125 self.write(", ");
126 }
127 }
128 }
129 expr => self.surround_with("(", ")", |ref mut s| s.visit_expression(expr)),
130 }
131
132 match insert.on_conflict {
133 Some(OnConflict::DoNothing) => self.write(" ON CONFLICT DO NOTHING"),
134 Some(OnConflict::Update(update, constraints)) => {
135 self.write(" ON CONFLICT");
136 self.columns_to_bracket_list(constraints);
137 self.write(" DO ");
138
139 self.visit_upsert(update);
140 }
141 None => (),
142 }
143
144 if let Some(returning) = insert.returning {
145 if !returning.is_empty() {
146 let values = returning.into_iter().map(|r| r.into()).collect();
147 self.write(" RETURNING ");
148 self.visit_columns(values);
149 }
150 };
151 }
152
153 fn visit_delete(&mut self, delete: Delete<'a>) {
154 self.write("DELETE FROM ");
155 self.visit_table(delete.table, true);
156
157 if let Some(conditions) = delete.conditions {
158 self.write(" WHERE ");
159 self.visit_conditions(conditions);
160 }
161
162 if let Some(returning) = delete.returning {
163 self.write(" RETURNING ");
164
165 let length = returning.len();
166
167 for (i, expression) in returning.into_iter().enumerate() {
168 self.visit_expression(expression);
169
170 if i < (length - 1) {
171 self.write(", ");
172 }
173 }
174 }
175 }
176
177 fn visit_aggregate_to_string(&mut self, value: Expression<'a>) {
178 self.write("ARRAY_TO_STRING");
179 self.write("(");
180 self.write("ARRAY_AGG");
181 self.write("(");
182 self.visit_expression(value);
183 self.write(")");
184 self.write("','");
185 self.write(")")
186 }
187
188 fn visit_equals(&mut self, left: Expression<'a>, right: Expression<'a>) {
189 self.visit_expression(left);
190 self.write(" = ");
191 self.visit_expression(right);
192 }
193
194 fn visit_not_equals(&mut self, left: Expression<'a>, right: Expression<'a>) {
195 self.visit_expression(left);
196 self.write(" <> ");
197 self.visit_expression(right);
198 }
199
200 fn visit_json_extract(&mut self, json_extract: JsonExtract<'a>) {
201 match json_extract.path {
202 #[cfg(feature = "mysql")]
203 JsonPath::String(_) => {
204 panic!("JSON path string notation is not supported for Postgres")
205 }
206 JsonPath::Array(json_path) => {
207 self.write("(");
208 self.visit_expression(*json_extract.column);
209
210 if json_extract.extract_as_string {
211 self.write("#>>");
212 } else {
213 self.write("#>");
214 }
215
216 self.surround_with("ARRAY[", "]::text[]", |s| {
220 let len = json_path.len();
221 for (index, path) in json_path.into_iter().enumerate() {
222 s.visit_parameterized(Value::String(path.to_string()));
223 if index < len - 1 {
224 s.write(", ");
225 }
226 }
227 });
228
229 self.write(")");
230
231 if !json_extract.extract_as_string {
232 self.write("::jsonb");
233 }
234 }
235 }
236 }
237
238 fn visit_json_unquote(&mut self, json_unquote: JsonUnquote<'a>) {
239 self.write("(");
240 self.visit_expression(*json_unquote.expr);
241 self.write("#>>ARRAY[]::text[]");
242 self.write(")");
243 }
244
245 fn visit_array_contains(&mut self, left: Expression<'a>, right: Expression<'a>, not: bool) {
246 if not {
247 self.write("( NOT ");
248 }
249
250 self.visit_expression(left);
251 self.write(" @> ");
252 self.visit_expression(right);
253
254 if not {
255 self.write(" )");
256 }
257 }
258
259 fn visit_array_contained(&mut self, left: Expression<'a>, right: Expression<'a>, not: bool) {
260 if not {
261 self.write("( NOT ");
262 }
263
264 self.visit_expression(left);
265 self.write(" <@ ");
266 self.visit_expression(right);
267
268 if not {
269 self.write(" )");
270 }
271 }
272
273 fn visit_array_overlaps(&mut self, left: Expression<'a>, right: Expression<'a>) {
274 self.visit_expression(left);
275 self.write(" && ");
276 self.visit_expression(right);
277 }
278
279 fn visit_json_extract_last_array_item(&mut self, extract: JsonExtractLastArrayElem<'a>) {
280 self.write("(");
281 self.visit_expression(*extract.expr);
282 self.write("->-1");
283 self.write(")");
284 }
285
286 fn visit_json_extract_first_array_item(&mut self, extract: JsonExtractFirstArrayElem<'a>) {
287 self.write("(");
288 self.visit_expression(*extract.expr);
289 self.write("->0");
290 self.write(")");
291 }
292
293 fn visit_json_type_equals(&mut self, left: Expression<'a>, json_type: JsonType<'a>, not: bool) {
294 self.write("JSONB_TYPEOF");
295 self.write("(");
296 self.visit_expression(left);
297 self.write(")");
298
299 if not {
300 self.write(" != ");
301 } else {
302 self.write(" = ");
303 }
304
305 match json_type {
306 JsonType::Array => self.visit_expression(Value::String("array".to_string()).into()),
307 JsonType::Boolean => self.visit_expression(Value::String("boolean".to_string()).into()),
308 JsonType::Number => self.visit_expression(Value::String("number".to_string()).into()),
309 JsonType::Object => self.visit_expression(Value::String("object".to_string()).into()),
310 JsonType::String => self.visit_expression(Value::String("string".to_string()).into()),
311 JsonType::Null => self.visit_expression(Value::String("null".to_string()).into()),
312 JsonType::ColumnRef(column) => {
313 self.write("JSONB_TYPEOF");
314 self.write("(");
315 self.visit_column(*column);
316 self.write("::jsonb)")
317 }
318 }
319 }
320
321 fn visit_like(&mut self, left: Expression<'a>, right: Expression<'a>) {
322 let need_cast = matches!(&left.kind, ExpressionKind::Column(_));
323 self.visit_expression(left);
324
325 if need_cast {
328 self.write("::text");
329 }
330
331 self.write(" LIKE ");
332 self.visit_expression(right);
333 }
334
335 fn visit_not_like(&mut self, left: Expression<'a>, right: Expression<'a>) {
336 let need_cast = matches!(&left.kind, ExpressionKind::Column(_));
337 self.visit_expression(left);
338
339 if need_cast {
342 self.write("::text");
343 }
344
345 self.write(" NOT LIKE ");
346 self.visit_expression(right);
347 }
348
349 fn visit_ordering(&mut self, ordering: Ordering<'a>) {
350 let len = ordering.0.len();
351
352 for (i, (value, ordering)) in ordering.0.into_iter().enumerate() {
353 let direction = ordering.map(|dir| match dir {
354 Order::Asc => " ASC",
355 Order::Desc => " DESC",
356 Order::AscNullsFirst => " ASC NULLS FIRST",
357 Order::AscNullsLast => " ASC NULLS LAST",
358 Order::DescNullsFirst => " DESC NULLS FIRST",
359 Order::DescNullsLast => " DESC NULLS LAST",
360 });
361
362 self.visit_expression(value);
363 self.write(direction.unwrap_or(""));
364
365 if i < (len - 1) {
366 self.write(", ");
367 }
368 }
369 }
370
371 fn visit_concat(&mut self, concat: Concat<'a>) {
372 let len = concat.exprs.len();
373
374 self.surround_with("(", ")", |s| {
375 for (i, expr) in concat.exprs.into_iter().enumerate() {
376 s.visit_expression(expr);
377
378 if i < (len - 1) {
379 s.write(" || ");
380 }
381 }
382 });
383 }
384
385 fn visit_to_jsonb(&mut self, to_jsonb: ToJsonb<'a>) {
386 self.write("to_jsonb(");
387 self.visit_table(to_jsonb.table, false);
388 self.write(".*)");
389 }
390
391 fn visit_json_build_object(&mut self, json_build_object: JsonBuildObject<'a>) {
392 let values_length = json_build_object.values.len();
393 self.write("json_build_object(");
394
395 for (i, (name, expression)) in json_build_object.values.into_iter().enumerate() {
396 self.surround_with("'", "'", |renderer| {
397 renderer.write(&name);
398 });
399
400 self.write(", ");
401 self.visit_expression(expression);
402
403 if i < (values_length - 1) {
404 self.write(",");
405 }
406 }
407
408 self.write(")");
409 }
410
411 fn visit_json_agg(&mut self, json_agg: JsonAgg<'a>) {
412 self.write("json_agg(");
413
414 if json_agg.distinct {
415 self.write("DISTINCT ");
416 }
417
418 self.visit_expression(json_agg.expression);
419
420 if let Some(ordering) = json_agg.order_by {
421 self.write(" ORDER BY ");
422 self.visit_ordering(ordering);
423 }
424
425 self.write(")");
426 }
427
428 fn visit_encode(&mut self, encode: Encode<'a>) {
429 self.write("encode(");
430 self.visit_expression(encode.expression);
431 self.write(", ");
432
433 match encode.format {
434 EncodeFormat::Base64 => self.write("'base64'"),
435 EncodeFormat::Escape => self.write("'escape'"),
436 EncodeFormat::Hex => self.write("'hex'"),
437 }
438
439 self.write(")");
440 }
441
442 fn visit_join_data(&mut self, data: JoinData<'a>) {
443 if data.lateral {
444 self.write(" LATERAL ");
445 }
446
447 self.visit_table(data.table, true);
448 self.write(" ON ");
449 self.visit_conditions(data.conditions)
450 }
451}
452
#[cfg(test)]
mod tests {
    use crate::{ast::json_build_object, renderer::*};

    /// Pairs the expected SQL with its expected parameters, converted to `Value`.
    fn expected_values<T>(sql: &'static str, params: Vec<T>) -> (String, Vec<Value>)
    where
        T: Into<Value>,
    {
        (
            String::from(sql),
            params.into_iter().map(|p| p.into()).collect(),
        )
    }

    /// Expected parameter list for a query.
    ///
    /// The Postgres renderer adds no implicit parameters, so this is exactly
    /// `additional`.
    fn default_params(additional: Vec<Value>) -> Vec<Value> {
        additional
    }

    #[test]
    fn test_single_row_insert_default_values() {
        let query = Insert::single_into("users");
        let (sql, params) = Postgres::build(query);

        assert_eq!("INSERT INTO \"users\" DEFAULT VALUES", sql);
        assert_eq!(default_params(vec![]), params);
    }

    #[test]
    fn test_single_row_insert() {
        let expected = expected_values("INSERT INTO \"users\" (\"foo\") VALUES ($1)", vec![10]);

        let mut insert = Insert::single_into("users");
        insert.value("foo", 10);

        let (sql, params) = Postgres::build(insert);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    #[test]
    #[cfg(feature = "postgresql")]
    fn test_returning_insert() {
        let expected = expected_values(
            "INSERT INTO \"users\" (\"foo\") VALUES ($1) RETURNING \"foo\"",
            vec![10],
        );

        let mut query = Insert::single_into("users");
        query.value("foo", 10);

        let mut query = query.build();
        query.returning(vec!["foo"]);

        let (sql, params) = Postgres::build(query);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    #[test]
    #[cfg(feature = "postgresql")]
    fn test_insert_on_conflict_update() {
        let expected = expected_values(
            "INSERT INTO \"users\" (\"foo\") VALUES ($1) ON CONFLICT (\"foo\") DO UPDATE SET \"foo\" = $2 WHERE \"users\".\"foo\" = $3 RETURNING \"foo\"",
            vec![10, 3, 1],
        );

        let mut update = Update::table("users");
        update.set("foo", 3);
        update.so_that(("users", "foo").equals(1));

        let mut insert = Insert::single_into("users");
        insert.value("foo", 10);

        let mut insert = insert.build();
        insert.returning(vec!["foo"]);
        insert.on_conflict(OnConflict::Update(update, Vec::from(["foo".into()])));

        let (sql, params) = Postgres::build(insert);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    #[test]
    fn test_multi_row_insert() {
        let expected = expected_values(
            "INSERT INTO \"users\" (\"foo\") VALUES ($1), ($2)",
            vec![10, 11],
        );

        let mut insert = Insert::multi_into("users", vec!["foo"]);
        insert.values(vec![10]);
        insert.values(vec![11]);

        let (sql, params) = Postgres::build(insert);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    #[test]
    fn test_limit_and_offset_when_both_are_set() {
        let expected = expected_values(
            "SELECT \"users\".* FROM \"users\" LIMIT $1 OFFSET $2",
            vec![10_i64, 2_i64],
        );

        let mut query = Select::from_table("users");
        query.limit(10);
        query.offset(2);

        let (sql, params) = Postgres::build(query);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    #[test]
    fn test_limit_and_offset_when_only_offset_is_set() {
        let expected = expected_values("SELECT \"users\".* FROM \"users\" OFFSET $1", vec![10_i64]);

        let mut query = Select::from_table("users");
        query.offset(10);

        let (sql, params) = Postgres::build(query);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    #[test]
    fn test_limit_and_offset_when_only_limit_is_set() {
        let expected = expected_values("SELECT \"users\".* FROM \"users\" LIMIT $1", vec![10_i64]);

        let mut query = Select::from_table("users");
        query.limit(10);

        let (sql, params) = Postgres::build(query);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    #[test]
    fn test_distinct() {
        let expected_sql = "SELECT DISTINCT \"bar\" FROM \"test\"";

        let mut query = Select::from_table("test");
        query.column(Column::new("bar"));
        query.distinct();

        let (sql, _) = Postgres::build(query);

        assert_eq!(expected_sql, sql);
    }

    #[test]
    fn test_distinct_with_subquery() {
        let expected_sql = "SELECT DISTINCT (SELECT $1 FROM \"test2\"), \"bar\" FROM \"test\"";

        let mut query = Select::from_table("test");

        query.value({
            let mut query = Select::from_table("test2");
            query.value(1);

            query
        });

        query.column(Column::new("bar"));
        query.distinct();

        let (sql, _) = Postgres::build(query);

        assert_eq!(expected_sql, sql);
    }

    #[test]
    fn test_from() {
        let expected_sql =
            "SELECT \"foo\".*, \"bar\".\"a\" FROM \"foo\", (SELECT \"a\" FROM \"baz\") AS \"bar\"";

        let mut query = Select::default();
        query.and_from("foo");

        query.and_from(
            Table::from({
                let mut query = Select::from_table("baz");
                query.column("a");
                query
            })
            .alias("bar"),
        );

        query.value(Table::from("foo").asterisk());
        query.column(("bar", "a"));

        let (sql, _) = Postgres::build(query);
        assert_eq!(expected_sql, sql);
    }

    #[test]
    fn test_like_cast_to_string() {
        let expected = expected_values(
            r#"SELECT "test".* FROM "test" WHERE "jsonField"::text LIKE $1"#,
            vec!["%foo%"],
        );

        let mut query = Select::from_table("test");
        query.so_that(Column::from("jsonField").like("%foo%"));

        let (sql, params) = Postgres::build(query);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    #[test]
    fn test_not_like_cast_to_string() {
        let expected = expected_values(
            r#"SELECT "test".* FROM "test" WHERE "jsonField"::text NOT LIKE $1"#,
            vec!["%foo%"],
        );

        let mut query = Select::from_table("test");
        query.so_that(Column::from("jsonField").not_like("%foo%"));

        let (sql, params) = Postgres::build(query);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    // "Begins with foo" is the pattern `foo%`.
    #[test]
    fn test_begins_with_cast_to_string() {
        let expected = expected_values(
            r#"SELECT "test".* FROM "test" WHERE "jsonField"::text LIKE $1"#,
            vec!["foo%"],
        );

        let mut query = Select::from_table("test");
        query.so_that(Column::from("jsonField").like("foo%"));

        let (sql, params) = Postgres::build(query);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    #[test]
    fn test_not_begins_with_cast_to_string() {
        let expected = expected_values(
            r#"SELECT "test".* FROM "test" WHERE "jsonField"::text NOT LIKE $1"#,
            vec!["foo%"],
        );

        let mut query = Select::from_table("test");
        query.so_that(Column::from("jsonField").not_like("foo%"));

        let (sql, params) = Postgres::build(query);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    // "Ends with foo" is the pattern `%foo`.
    #[test]
    fn test_ends_with_cast_to_string() {
        let expected = expected_values(
            r#"SELECT "test".* FROM "test" WHERE "jsonField"::text LIKE $1"#,
            vec!["%foo"],
        );

        let mut query = Select::from_table("test");
        query.so_that(Column::from("jsonField").like("%foo"));

        let (sql, params) = Postgres::build(query);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    #[test]
    fn test_not_ends_with_cast_to_string() {
        let expected = expected_values(
            r#"SELECT "test".* FROM "test" WHERE "jsonField"::text NOT LIKE $1"#,
            vec!["%foo"],
        );

        let mut query = Select::from_table("test");
        query.so_that(Column::from("jsonField").not_like("%foo"));

        let (sql, params) = Postgres::build(query);

        assert_eq!(expected.0, sql);
        assert_eq!(expected.1, params);
    }

    #[test]
    fn test_default_insert() {
        let mut insert = Insert::single_into("foo");
        insert.value("foo", "bar");
        insert.value("baz", default_value());

        let (sql, _) = Postgres::build(insert);

        assert_eq!(
            "INSERT INTO \"foo\" (\"foo\",\"baz\") VALUES ($1,DEFAULT)",
            sql
        );
    }

    #[test]
    fn join_is_inserted_positionally() {
        let joined_table = Table::from("User").left_join(
            Table::from("Post")
                .alias("p")
                .on(("p", "userId").equals(Column::from(("User", "id")))),
        );

        let mut q = Select::from_table(joined_table);
        q.and_from("Toto");

        let (sql, _) = Postgres::build(q);

        assert_eq!("SELECT \"User\".*, \"Toto\".* FROM \"User\" LEFT JOIN \"Post\" AS \"p\" ON \"p\".\"userId\" = \"User\".\"id\", \"Toto\"", sql);
    }

    #[test]
    fn test_json_build_object_raw_value() {
        let mut select = Select::default();
        select.value(json_build_object([("id", raw("1"))]));

        let (sql, _) = Postgres::build(select);

        assert_eq!(r#"SELECT json_build_object('id', 1)"#, sql);
    }

    #[test]
    fn test_json_build_object_column() {
        let mut select = Select::from_table("User");
        select.value(json_build_object([("name", Column::from("name"))]));

        let (sql, _) = Postgres::build(select);

        assert_eq!(
            r#"SELECT json_build_object('name', "name") FROM "User""#,
            sql
        );
    }

    #[test]
    fn test_cte() {
        let mut insert = Insert::single_into("User");
        insert.value("name", "Musti");

        let mut insert = insert.build();
        insert.returning(["id", "name"]);

        let mut select = Select::from_table("public_user");
        select.with(CommonTableExpression::new("public_user", insert));
        select.columns(["id", "name"]);

        let (sql, _) = Postgres::build(select);

        assert_eq!(
            r#"WITH "public_user" AS (INSERT INTO "User" ("name") VALUES ($1) RETURNING "id", "name") SELECT "id", "name" FROM "public_user""#,
            sql
        );
    }
}