1use crate::api_request::preferences::PreferCount;
31use crate::backend::SqlDialect;
32use crate::plan::call_plan::CallPlan;
33use crate::plan::mutate_plan::MutatePlan;
34use crate::plan::read_plan::ReadPlanTree;
35
36use super::builder;
37use super::fragment;
38use super::sql_builder::SqlBuilder;
39
40pub fn main_read(
72 read_plan: &ReadPlanTree,
73 prefer_count: Option<PreferCount>,
74 max_rows: Option<i64>,
75 headers_only: bool,
76 handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
77 dialect: &dyn SqlDialect,
78) -> SqlBuilder {
79 let inner = builder::read_plan_to_query(read_plan, dialect);
80 let mut b = SqlBuilder::new();
81
82 b.push("WITH dbrst_source AS (");
84 b.push_builder(&inner);
85 b.push(")");
86
87 let has_exact_count = matches!(prefer_count, Some(PreferCount::Exact));
89 if has_exact_count {
90 let count_q = builder::read_plan_to_count_query(read_plan, dialect);
91 b.push(", dbrst_count AS (");
92 b.push_builder(&count_q);
93 b.push(")");
94 }
95
96 b.push(" SELECT ");
98
99 if has_exact_count {
101 b.push("(SELECT ");
102 b.push_ident("dbrst_filtered_count");
103 b.push(" FROM dbrst_count)");
104 } else {
105 b.push("NULL");
106 }
107 b.push(" AS total_result_set");
108
109 b.push(", ");
111 dialect.count_expr(&mut b, "_dbrst_t");
112 b.push(" AS page_total");
113
114 let col_names = select_column_names(read_plan);
116 let col_refs: Vec<&str> = col_names.iter().map(|s| s.as_str()).collect();
117
118 if headers_only {
120 b.push(", NULL AS body");
121 } else {
122 b.push(", ");
123 if let Some(h) = handler {
124 fragment::handler_agg_with_media_cols(&mut b, h, false, dialect, &col_refs);
125 } else {
126 fragment::handler_agg_cols(&mut b, false, dialect, &col_refs);
127 }
128 b.push(" AS body");
129 }
130
131 b.push(", ");
133 dialect.get_session_var(&mut b, "response.headers", "response_headers");
134 b.push(", ");
135 dialect.get_session_var(&mut b, "response.status", "response_status");
136
137 b.push(" FROM (SELECT * FROM dbrst_source");
139
140 if let Some(max) = max_rows {
142 b.push(" LIMIT ");
143 b.push(&max.to_string());
144 }
145
146 b.push(") AS ");
147 b.push_ident("_dbrst_t");
148
149 b
150}
151
152fn select_column_names(tree: &ReadPlanTree) -> Vec<String> {
158 if tree.node.select.iter().any(|sf| sf.field.full_row) {
161 return Vec::new();
162 }
163 tree.node
164 .select
165 .iter()
166 .map(|sf| {
167 sf.alias
168 .as_ref()
169 .map(|a| a.to_string())
170 .unwrap_or_else(|| sf.field.name.to_string())
171 })
172 .collect()
173}
174
175pub fn main_write(
206 mutate_plan: &MutatePlan,
207 _read_plan: &ReadPlanTree,
208 return_representation: bool,
209 handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
210 dialect: &dyn SqlDialect,
211) -> SqlBuilder {
212 let inner = builder::mutate_plan_to_query(mutate_plan, dialect);
213 let has_returning = !mutate_plan.returning().is_empty();
214 let mut b = SqlBuilder::new();
215
216 b.push("WITH dbrst_source AS (");
217 b.push_builder(&inner);
218 if !has_returning {
219 b.push(" RETURNING 1");
220 }
221 b.push(")");
222
223 b.push(" SELECT ");
225
226 b.push("'' AS total_result_set");
228
229 b.push(", ");
231 dialect.count_expr(&mut b, "_dbrst_t");
232 b.push(" AS page_total");
233
234 let col_names: Vec<String> = mutate_plan
236 .returning()
237 .iter()
238 .map(|sf| {
239 sf.alias
240 .as_ref()
241 .map(|a| a.to_string())
242 .unwrap_or_else(|| sf.field.name.to_string())
243 })
244 .collect();
245 let col_refs: Vec<&str> = col_names.iter().map(|s| s.as_str()).collect();
246
247 if return_representation && has_returning {
249 b.push(", ");
250 if let Some(h) = handler {
251 fragment::handler_agg_with_media_cols(&mut b, h, false, dialect, &col_refs);
252 } else {
253 fragment::handler_agg_cols(&mut b, false, dialect, &col_refs);
254 }
255 b.push(" AS body");
256 } else {
257 b.push(", NULL AS body");
258 }
259
260 b.push(", ");
262 dialect.get_session_var(&mut b, "response.headers", "response_headers");
263 b.push(", ");
264 dialect.get_session_var(&mut b, "response.status", "response_status");
265
266 b.push(" FROM (SELECT * FROM dbrst_source) AS ");
268 b.push_ident("_dbrst_t");
269
270 b
271}
272
273pub fn main_write_split(
283 mutate_plan: &MutatePlan,
284 _read_plan: &ReadPlanTree,
285 return_representation: bool,
286 handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
287 dialect: &dyn SqlDialect,
288) -> (SqlBuilder, SqlBuilder) {
289 let mut mutation = builder::mutate_plan_to_query(mutate_plan, dialect);
290 let has_returning = !mutate_plan.returning().is_empty();
291 if !has_returning {
292 mutation.push(" RETURNING 1");
293 }
294
295 let mut b = SqlBuilder::new();
297 b.push("SELECT ");
298
299 b.push("'' AS total_result_set");
301
302 b.push(", ");
304 dialect.count_expr(&mut b, "_dbrst_t");
305 b.push(" AS page_total");
306
307 let col_names: Vec<String> = mutate_plan
309 .returning()
310 .iter()
311 .map(|sf| {
312 sf.alias
313 .as_ref()
314 .map(|a| a.to_string())
315 .unwrap_or_else(|| sf.field.name.to_string())
316 })
317 .collect();
318 let col_refs: Vec<&str> = col_names.iter().map(|s| s.as_str()).collect();
319
320 if return_representation && has_returning {
322 b.push(", ");
323 if let Some(h) = handler {
324 fragment::handler_agg_with_media_cols(&mut b, h, false, dialect, &col_refs);
325 } else {
326 fragment::handler_agg_cols(&mut b, false, dialect, &col_refs);
327 }
328 b.push(" AS body");
329 } else {
330 b.push(", NULL AS body");
331 }
332
333 b.push(", ");
335 dialect.get_session_var(&mut b, "response.headers", "response_headers");
336 b.push(", ");
337 dialect.get_session_var(&mut b, "response.status", "response_status");
338
339 b.push(" FROM ");
341 b.push_ident("_dbrst_mut");
342 b.push(" AS ");
343 b.push_ident("_dbrst_t");
344
345 (mutation, b)
346}
347
348pub fn main_call(
378 call_plan: &CallPlan,
379 prefer_count: Option<PreferCount>,
380 max_rows: Option<i64>,
381 handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
382 dialect: &dyn SqlDialect,
383) -> SqlBuilder {
384 let inner = builder::call_plan_to_query(call_plan, dialect);
385 let mut b = SqlBuilder::new();
386
387 b.push("WITH dbrst_source AS (");
389 b.push_builder(&inner);
390 b.push(")");
391
392 let has_exact_count = matches!(prefer_count, Some(PreferCount::Exact));
393
394 b.push(" SELECT ");
396
397 if has_exact_count {
399 dialect.count_star_from(&mut b, "dbrst_source");
400 } else {
401 b.push("NULL");
402 }
403 b.push(" AS total_result_set");
404
405 if call_plan.scalar {
407 b.push(", 1 AS page_total");
408 } else {
409 b.push(", ");
410 dialect.count_expr(&mut b, "_dbrst_t");
411 b.push(" AS page_total");
412 }
413
414 b.push(", ");
416 if call_plan.scalar {
417 dialect.row_to_json_star(&mut b, "dbrst_source");
419 } else if let Some(h) = handler {
420 fragment::handler_agg_with_media(&mut b, h, false, dialect);
421 } else {
422 fragment::handler_agg(&mut b, false, dialect);
423 }
424 b.push(" AS body");
425
426 b.push(", ");
428 dialect.get_session_var(&mut b, "response.headers", "response_headers");
429 b.push(", ");
430 dialect.get_session_var(&mut b, "response.status", "response_status");
431
432 if call_plan.scalar {
434 b.push(" FROM dbrst_source");
435 } else {
436 b.push(" FROM (SELECT * FROM dbrst_source");
437
438 if let Some(max) = max_rows {
439 b.push(" LIMIT ");
440 b.push(&max.to_string());
441 }
442
443 b.push(") AS ");
444 b.push_ident("_dbrst_t");
445 }
446
447 b
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api_request::types::Payload;
    use crate::plan::call_plan::{CallArgs, CallParams, CallPlan};
    use crate::plan::mutate_plan::{InsertPlan, MutatePlan};
    use crate::plan::read_plan::{ReadPlan, ReadPlanTree};
    use crate::plan::types::*;
    use crate::test_helpers::TestPgDialect;
    use crate::types::identifiers::QualifiedIdentifier;
    use bytes::Bytes;
    use smallvec::SmallVec;

    // ---- fixtures --------------------------------------------------------

    /// The PostgreSQL test dialect shared by every case.
    fn dialect() -> &'static dyn SqlDialect {
        &TestPgDialect
    }

    /// `public.users`, the relation targeted by the read and write fixtures.
    fn test_qi() -> QualifiedIdentifier {
        QualifiedIdentifier::new("public", "users")
    }

    /// A plain select field: no aggregate, no cast, no alias.
    fn select_field(name: &str) -> CoercibleSelectField {
        CoercibleSelectField {
            field: CoercibleField::unknown(name.into(), SmallVec::new()),
            agg_function: None,
            agg_cast: None,
            cast: None,
            alias: None,
        }
    }

    /// A column field with a known base type.
    fn typed_field(name: &str, base_type: &str) -> CoercibleField {
        CoercibleField::from_column(name.into(), SmallVec::new(), base_type.into())
    }

    /// A single-node read tree over `public.users` with the default select list.
    fn users_tree() -> ReadPlanTree {
        ReadPlanTree::leaf(ReadPlan::root(test_qi()))
    }

    /// A call plan for `qi` with no parameters, arguments, filters or returning list.
    fn call_plan(qi: QualifiedIdentifier, scalar: bool) -> CallPlan {
        CallPlan {
            qi,
            params: CallParams::KeyParams(vec![]),
            args: CallArgs::JsonArgs(None),
            scalar,
            set_of_scalar: false,
            filter_fields: vec![],
            returning: vec![],
        }
    }

    // ---- main_read -------------------------------------------------------

    #[test]
    fn test_main_read_basic() {
        let mut plan = ReadPlan::root(test_qi());
        plan.select = vec![select_field("id"), select_field("name")];
        let tree = ReadPlanTree::leaf(plan);

        let built = main_read(&tree, None, None, false, None, dialect());
        let sql = built.sql();

        assert!(sql.starts_with("WITH dbrst_source AS ("));
        for needle in [
            "AS total_result_set",
            "AS page_total",
            "AS body",
            "AS response_headers",
            "AS response_status",
        ] {
            assert!(sql.contains(needle), "missing {:?} in {}", needle, sql);
        }
    }

    #[test]
    fn test_main_read_with_exact_count() {
        let tree = users_tree();

        let built = main_read(
            &tree,
            Some(PreferCount::Exact),
            None,
            false,
            None,
            dialect(),
        );
        let sql = built.sql();

        assert!(sql.contains("dbrst_count"));
        assert!(sql.contains("dbrst_filtered_count"));
    }

    #[test]
    fn test_main_read_headers_only() {
        let tree = users_tree();

        let built = main_read(&tree, None, None, true, None, dialect());
        let sql = built.sql();

        assert!(sql.contains("NULL AS body"));
    }

    #[test]
    fn test_main_read_with_max_rows() {
        let tree = users_tree();

        let built = main_read(&tree, None, Some(100), false, None, dialect());
        let sql = built.sql();

        assert!(sql.contains("LIMIT 100"));
    }

    // ---- main_write ------------------------------------------------------

    #[test]
    fn test_main_write_basic() {
        let mutate = MutatePlan::Insert(InsertPlan {
            into: test_qi(),
            columns: vec![typed_field("name", "text")],
            body: Payload::RawJSON(Bytes::from(r#"[{"name":"test"}]"#)),
            on_conflict: None,
            where_: vec![],
            returning: vec![select_field("id")],
            pk_cols: vec!["id".into()],
            apply_defaults: false,
        });
        let read = users_tree();

        let built = main_write(&mutate, &read, true, None, dialect());
        let sql = built.sql();

        assert!(sql.starts_with("WITH dbrst_source AS ("));
        assert!(sql.contains("INSERT INTO"));
        assert!(sql.contains("AS body"));
    }

    #[test]
    fn test_main_write_no_representation() {
        let mutate = MutatePlan::Insert(InsertPlan {
            into: test_qi(),
            columns: vec![],
            body: Payload::RawJSON(Bytes::from("{}")),
            on_conflict: None,
            where_: vec![],
            returning: vec![],
            pk_cols: vec![],
            apply_defaults: false,
        });
        let read = users_tree();

        let built = main_write(&mutate, &read, false, None, dialect());
        let sql = built.sql();

        assert!(sql.contains("NULL AS body"));
    }

    // ---- main_call -------------------------------------------------------

    #[test]
    fn test_main_call_basic() {
        let call = call_plan(QualifiedIdentifier::new("public", "get_time"), false);

        let built = main_call(&call, None, None, None, dialect());
        let sql = built.sql();

        assert!(sql.starts_with("WITH dbrst_source AS ("));
        assert!(sql.contains("get_time"));
        assert!(sql.contains("AS body"));
    }

    #[test]
    fn test_main_call_scalar() {
        let call = call_plan(QualifiedIdentifier::new("public", "add_numbers"), true);

        let built = main_call(&call, None, None, None, dialect());
        let sql = built.sql();

        assert!(sql.contains("row_to_json"));
    }

    #[test]
    fn test_main_call_with_count() {
        let call = call_plan(QualifiedIdentifier::new("public", "get_data"), false);

        let built = main_call(&call, Some(PreferCount::Exact), None, None, dialect());
        let sql = built.sql();

        assert!(sql.contains("pg_catalog.count(*)"));
    }

    #[test]
    fn test_main_call_with_max_rows() {
        let call = call_plan(QualifiedIdentifier::new("public", "get_data"), false);

        let built = main_call(&call, None, Some(50), None, dialect());
        let sql = built.sql();

        assert!(sql.contains("LIMIT 50"));
    }
}