Skip to main content

dbrest_core/query/
statements.rs

1//! Final SQL statement assembly.
2//!
3//! Wraps the SQL produced by [`super::builder`] into CTE-based final statements
4//! that return a uniform response shape: total count, page count, body,
5//! response headers, and response status. This module sits at the top of the
6//! SQL generation pipeline, right before the query is handed to the executor.
7//!
8//! # Pipeline
9//!
10//! ```text
11//! builder.rs (SELECT / INSERT / CALL …) ──▶ statements.rs (CTE wrapper) ──▶ executor
12//! ```
13//!
14//! # SQL Example
15//!
16//! ```sql
17//! WITH dbrst_source AS (
18//!   SELECT "public"."users"."id" AS "id", "public"."users"."name" AS "name"
19//!   FROM "public"."users"
20//! )
21//! SELECT
22//!   NULL AS total_result_set,
23//!   pg_catalog.count(_dbrst_t) AS page_total,
24//!   coalesce(json_agg(_dbrst_t), '[]')::text AS body,
25//!   nullif(current_setting('response.headers', true), '') AS response_headers,
26//!   nullif(current_setting('response.status', true), '') AS response_status
27//! FROM (SELECT * FROM dbrst_source) AS _dbrst_t
28//! ```
29
30use crate::api_request::preferences::PreferCount;
31use crate::backend::SqlDialect;
32use crate::plan::call_plan::CallPlan;
33use crate::plan::mutate_plan::MutatePlan;
34use crate::plan::read_plan::ReadPlanTree;
35
36use super::builder;
37use super::fragment;
38use super::sql_builder::SqlBuilder;
39
40// ==========================================================================
41// main_read — CTE wrapper for SELECT queries
42// ==========================================================================
43
44/// Build the final read statement with CTE wrapper.
45///
46/// Wraps a `ReadPlanTree` query in a CTE that returns the standard response
47/// shape: total count, page count, body, response headers, and response status.
48///
49/// # Behaviour
50///
51/// - If `prefer_count` is `Exact`, adds a count CTE for the total result set
52/// - If `prefer_count` is `Planned`, uses `EXPLAIN` row estimate
53/// - The `max_rows` config limit is applied as an additional cap
54/// - `headers_only` omits the body column value (for HEAD requests)
55/// - `handler` determines the output format (JSON, CSV, binary, etc.)
56///
57/// # SQL Example
58///
59/// ```sql
60/// WITH dbrst_source AS (
61///   SELECT … FROM "public"."users" WHERE …
62/// )
63/// SELECT
64///   NULL AS total_result_set,
65///   pg_catalog.count(_dbrst_t) AS page_total,
66///   coalesce(json_agg(_dbrst_t), '[]')::text AS body,
67///   nullif(current_setting('response.headers', true), '') AS response_headers,
68///   nullif(current_setting('response.status', true), '') AS response_status
69/// FROM (SELECT * FROM dbrst_source) AS _dbrst_t
70/// ```
71pub fn main_read(
72    read_plan: &ReadPlanTree,
73    prefer_count: Option<PreferCount>,
74    max_rows: Option<i64>,
75    headers_only: bool,
76    handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
77    dialect: &dyn SqlDialect,
78) -> SqlBuilder {
79    let inner = builder::read_plan_to_query(read_plan, dialect);
80    let mut b = SqlBuilder::new();
81
82    // CTE: dbrst_source
83    b.push("WITH dbrst_source AS (");
84    b.push_builder(&inner);
85    b.push(")");
86
87    // Optional count CTE
88    let has_exact_count = matches!(prefer_count, Some(PreferCount::Exact));
89    if has_exact_count {
90        let count_q = builder::read_plan_to_count_query(read_plan, dialect);
91        b.push(", dbrst_count AS (");
92        b.push_builder(&count_q);
93        b.push(")");
94    }
95
96    // Main SELECT
97    b.push(" SELECT ");
98
99    // total_result_set
100    if has_exact_count {
101        b.push("(SELECT ");
102        b.push_ident("dbrst_filtered_count");
103        b.push(" FROM dbrst_count)");
104    } else {
105        b.push("NULL");
106    }
107    b.push(" AS total_result_set");
108
109    // page_total
110    b.push(", ");
111    dialect.count_expr(&mut b, "_dbrst_t");
112    b.push(" AS page_total");
113
114    // Extract column names for non-PG backends that need them
115    let col_names = select_column_names(read_plan);
116    let col_refs: Vec<&str> = col_names.iter().map(|s| s.as_str()).collect();
117
118    // body
119    if headers_only {
120        b.push(", NULL AS body");
121    } else {
122        b.push(", ");
123        if let Some(h) = handler {
124            fragment::handler_agg_with_media_cols(&mut b, h, false, dialect, &col_refs);
125        } else {
126            fragment::handler_agg_cols(&mut b, false, dialect, &col_refs);
127        }
128        b.push(" AS body");
129    }
130
131    // response_headers & response_status
132    b.push(", ");
133    dialect.get_session_var(&mut b, "response.headers", "response_headers");
134    b.push(", ");
135    dialect.get_session_var(&mut b, "response.status", "response_status");
136
137    // FROM dbrst_source
138    b.push(" FROM (SELECT * FROM dbrst_source");
139
140    // Apply max_rows if configured
141    if let Some(max) = max_rows {
142        b.push(" LIMIT ");
143        b.push(&max.to_string());
144    }
145
146    b.push(") AS ");
147    b.push_ident("_dbrst_t");
148
149    b
150}
151
152/// Extract the output column names from a read plan's select list.
153///
154/// These are the names that will appear in the `_dbrst_t` alias and
155/// are needed by backends that cannot aggregate a row alias (e.g. SQLite).
156/// Returns an empty vec for `SELECT *` plans (full_row / star selects).
157fn select_column_names(tree: &ReadPlanTree) -> Vec<String> {
158    // If any field is a full_row select (i.e. `*`), we can't enumerate
159    // individual column names — return empty to trigger fallback.
160    if tree.node.select.iter().any(|sf| sf.field.full_row) {
161        return Vec::new();
162    }
163    tree.node
164        .select
165        .iter()
166        .map(|sf| {
167            sf.alias
168                .as_ref()
169                .map(|a| a.to_string())
170                .unwrap_or_else(|| sf.field.name.to_string())
171        })
172        .collect()
173}
174
175// ==========================================================================
176// main_write — CTE wrapper for mutation queries
177// ==========================================================================
178
179/// Build the final mutation statement with CTE wrapper.
180///
181/// Wraps a `MutatePlan` query in a CTE, optionally adding a read sub-select
182/// for `Prefer: return=representation`.
183///
184/// # Behaviour
185///
186/// - The mutation CTE (`dbrst_source`) contains the INSERT/UPDATE/DELETE
187/// - If `return_representation` is true, the response body includes the
188///   returned rows as JSON
189/// - The location header expression is included for INSERT operations
190///
191/// # SQL Example
192///
193/// ```sql
194/// WITH dbrst_source AS (
195///   INSERT INTO "public"."users"("name") VALUES ($1) RETURNING "id", "name"
196/// )
197/// SELECT
198///   '' AS total_result_set,
199///   pg_catalog.count(_dbrst_t) AS page_total,
200///   coalesce(json_agg(_dbrst_t), '[]')::text AS body,
201///   nullif(current_setting('response.headers', true), '') AS response_headers,
202///   nullif(current_setting('response.status', true), '') AS response_status
203/// FROM (SELECT * FROM dbrst_source) AS _dbrst_t
204/// ```
205pub fn main_write(
206    mutate_plan: &MutatePlan,
207    _read_plan: &ReadPlanTree,
208    return_representation: bool,
209    handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
210    dialect: &dyn SqlDialect,
211) -> SqlBuilder {
212    let inner = builder::mutate_plan_to_query(mutate_plan, dialect);
213    let has_returning = !mutate_plan.returning().is_empty();
214    let mut b = SqlBuilder::new();
215
216    b.push("WITH dbrst_source AS (");
217    b.push_builder(&inner);
218    if !has_returning {
219        b.push(" RETURNING 1");
220    }
221    b.push(")");
222
223    // Main SELECT
224    b.push(" SELECT ");
225
226    // total_result_set (mutations don't support count)
227    b.push("'' AS total_result_set");
228
229    // page_total
230    b.push(", ");
231    dialect.count_expr(&mut b, "_dbrst_t");
232    b.push(" AS page_total");
233
234    // Extract column names from the RETURNING clause for non-PG backends
235    let col_names: Vec<String> = mutate_plan
236        .returning()
237        .iter()
238        .map(|sf| {
239            sf.alias
240                .as_ref()
241                .map(|a| a.to_string())
242                .unwrap_or_else(|| sf.field.name.to_string())
243        })
244        .collect();
245    let col_refs: Vec<&str> = col_names.iter().map(|s| s.as_str()).collect();
246
247    // body
248    if return_representation && has_returning {
249        b.push(", ");
250        if let Some(h) = handler {
251            fragment::handler_agg_with_media_cols(&mut b, h, false, dialect, &col_refs);
252        } else {
253            fragment::handler_agg_cols(&mut b, false, dialect, &col_refs);
254        }
255        b.push(" AS body");
256    } else {
257        b.push(", NULL AS body");
258    }
259
260    // response_headers & response_status
261    b.push(", ");
262    dialect.get_session_var(&mut b, "response.headers", "response_headers");
263    b.push(", ");
264    dialect.get_session_var(&mut b, "response.status", "response_status");
265
266    // FROM dbrst_source
267    b.push(" FROM (SELECT * FROM dbrst_source) AS ");
268    b.push_ident("_dbrst_t");
269
270    b
271}
272
273/// Build split mutation + aggregation statements for backends without DML-in-CTE support.
274///
275/// Returns `(mutation, aggregation)` where:
276/// - `mutation` is the bare INSERT/UPDATE/DELETE with RETURNING
277/// - `aggregation` is a SELECT that aggregates rows from `_dbrst_mut` temp table
278///
279/// The executor is responsible for:
280/// 1. Creating `_dbrst_mut` temp table from the mutation RETURNING rows
281/// 2. Running the aggregation SELECT
282pub fn main_write_split(
283    mutate_plan: &MutatePlan,
284    _read_plan: &ReadPlanTree,
285    return_representation: bool,
286    handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
287    dialect: &dyn SqlDialect,
288) -> (SqlBuilder, SqlBuilder) {
289    let mut mutation = builder::mutate_plan_to_query(mutate_plan, dialect);
290    let has_returning = !mutate_plan.returning().is_empty();
291    if !has_returning {
292        mutation.push(" RETURNING 1");
293    }
294
295    // Build the aggregation SELECT from _dbrst_mut
296    let mut b = SqlBuilder::new();
297    b.push("SELECT ");
298
299    // total_result_set (mutations don't support count)
300    b.push("'' AS total_result_set");
301
302    // page_total
303    b.push(", ");
304    dialect.count_expr(&mut b, "_dbrst_t");
305    b.push(" AS page_total");
306
307    // Extract column names from the RETURNING clause
308    let col_names: Vec<String> = mutate_plan
309        .returning()
310        .iter()
311        .map(|sf| {
312            sf.alias
313                .as_ref()
314                .map(|a| a.to_string())
315                .unwrap_or_else(|| sf.field.name.to_string())
316        })
317        .collect();
318    let col_refs: Vec<&str> = col_names.iter().map(|s| s.as_str()).collect();
319
320    // body
321    if return_representation && has_returning {
322        b.push(", ");
323        if let Some(h) = handler {
324            fragment::handler_agg_with_media_cols(&mut b, h, false, dialect, &col_refs);
325        } else {
326            fragment::handler_agg_cols(&mut b, false, dialect, &col_refs);
327        }
328        b.push(" AS body");
329    } else {
330        b.push(", NULL AS body");
331    }
332
333    // response_headers & response_status
334    b.push(", ");
335    dialect.get_session_var(&mut b, "response.headers", "response_headers");
336    b.push(", ");
337    dialect.get_session_var(&mut b, "response.status", "response_status");
338
339    // FROM _dbrst_mut (the temp table populated by the executor)
340    b.push(" FROM ");
341    b.push_ident("_dbrst_mut");
342    b.push(" AS ");
343    b.push_ident("_dbrst_t");
344
345    (mutation, b)
346}
347
348// ==========================================================================
349// main_call — CTE wrapper for function call queries
350// ==========================================================================
351
352/// Build the final function call statement with CTE wrapper.
353///
354/// Wraps a `CallPlan` query in a CTE. Handles both scalar and set-returning
355/// functions.
356///
357/// # Behaviour
358///
359/// - Scalar functions: body is a single JSON value
360/// - Set-returning functions: body is a JSON array
361/// - The count CTE is included when `prefer_count` is `Exact`
362///
363/// # SQL Example
364///
365/// ```sql
366/// WITH dbrst_source AS (
367///   SELECT * FROM "public"."get_users"()
368/// )
369/// SELECT
370///   NULL AS total_result_set,
371///   pg_catalog.count(_dbrst_t) AS page_total,
372///   coalesce(json_agg(_dbrst_t), '[]')::text AS body,
373///   nullif(current_setting('response.headers', true), '') AS response_headers,
374///   nullif(current_setting('response.status', true), '') AS response_status
375/// FROM (SELECT * FROM dbrst_source) AS _dbrst_t
376/// ```
377pub fn main_call(
378    call_plan: &CallPlan,
379    prefer_count: Option<PreferCount>,
380    max_rows: Option<i64>,
381    handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
382    dialect: &dyn SqlDialect,
383) -> SqlBuilder {
384    let inner = builder::call_plan_to_query(call_plan, dialect);
385    let mut b = SqlBuilder::new();
386
387    // CTE: dbrst_source
388    b.push("WITH dbrst_source AS (");
389    b.push_builder(&inner);
390    b.push(")");
391
392    let has_exact_count = matches!(prefer_count, Some(PreferCount::Exact));
393
394    // Main SELECT
395    b.push(" SELECT ");
396
397    // total_result_set
398    if has_exact_count {
399        dialect.count_star_from(&mut b, "dbrst_source");
400    } else {
401        b.push("NULL");
402    }
403    b.push(" AS total_result_set");
404
405    // page_total
406    if call_plan.scalar {
407        b.push(", 1 AS page_total");
408    } else {
409        b.push(", ");
410        dialect.count_expr(&mut b, "_dbrst_t");
411        b.push(" AS page_total");
412    }
413
414    // body
415    b.push(", ");
416    if call_plan.scalar {
417        // Scalar function: convert the CTE row to JSON text
418        dialect.row_to_json_star(&mut b, "dbrst_source");
419    } else if let Some(h) = handler {
420        fragment::handler_agg_with_media(&mut b, h, false, dialect);
421    } else {
422        fragment::handler_agg(&mut b, false, dialect);
423    }
424    b.push(" AS body");
425
426    // response_headers & response_status
427    b.push(", ");
428    dialect.get_session_var(&mut b, "response.headers", "response_headers");
429    b.push(", ");
430    dialect.get_session_var(&mut b, "response.status", "response_status");
431
432    // FROM dbrst_source
433    if call_plan.scalar {
434        b.push(" FROM dbrst_source");
435    } else {
436        b.push(" FROM (SELECT * FROM dbrst_source");
437
438        if let Some(max) = max_rows {
439            b.push(" LIMIT ");
440            b.push(&max.to_string());
441        }
442
443        b.push(") AS ");
444        b.push_ident("_dbrst_t");
445    }
446
447    b
448}
449
450// ==========================================================================
451// Tests
452// ==========================================================================
453
#[cfg(test)]
mod tests {
    // Unit tests for the statement wrappers. They check the emitted SQL text
    // via substring assertions against the test PostgreSQL dialect.
    use super::*;
    use crate::api_request::types::Payload;
    use crate::plan::call_plan::{CallArgs, CallParams, CallPlan};
    use crate::plan::mutate_plan::{InsertPlan, MutatePlan};
    use crate::plan::read_plan::{ReadPlan, ReadPlanTree};
    use crate::plan::types::*;
    use crate::test_helpers::TestPgDialect;
    use crate::types::identifiers::QualifiedIdentifier;
    use bytes::Bytes;
    use smallvec::SmallVec;

    fn dialect() -> &'static dyn SqlDialect {
        &TestPgDialect
    }

    fn test_qi() -> QualifiedIdentifier {
        QualifiedIdentifier::new("public", "users")
    }

    fn select_field(name: &str) -> CoercibleSelectField {
        CoercibleSelectField {
            field: CoercibleField::unknown(name.into(), SmallVec::new()),
            agg_function: None,
            agg_cast: None,
            cast: None,
            alias: None,
        }
    }

    fn typed_field(name: &str, base_type: &str) -> CoercibleField {
        CoercibleField::from_column(name.into(), SmallVec::new(), base_type.into())
    }

    /// Single-row INSERT into `public.users` with the given RETURNING list.
    fn insert_plan(returning: Vec<CoercibleSelectField>) -> MutatePlan {
        MutatePlan::Insert(InsertPlan {
            into: test_qi(),
            columns: vec![typed_field("name", "text")],
            body: Payload::RawJSON(Bytes::from(r#"[{"name":"test"}]"#)),
            on_conflict: None,
            where_: vec![],
            returning,
            pk_cols: vec!["id".into()],
            apply_defaults: false,
        })
    }

    // ------------------------------------------------------------------
    // main_read tests
    // ------------------------------------------------------------------

    #[test]
    fn test_main_read_basic() {
        let mut plan = ReadPlan::root(test_qi());
        plan.select = vec![select_field("id"), select_field("name")];
        let tree = ReadPlanTree::leaf(plan);

        let b = main_read(&tree, None, None, false, None, dialect());
        let sql = b.sql();

        assert!(sql.starts_with("WITH dbrst_source AS ("));
        assert!(sql.contains("AS total_result_set"));
        assert!(sql.contains("AS page_total"));
        assert!(sql.contains("AS body"));
        assert!(sql.contains("AS response_headers"));
        assert!(sql.contains("AS response_status"));
    }

    #[test]
    fn test_main_read_with_exact_count() {
        let plan = ReadPlan::root(test_qi());
        let tree = ReadPlanTree::leaf(plan);

        let b = main_read(
            &tree,
            Some(PreferCount::Exact),
            None,
            false,
            None,
            dialect(),
        );
        let sql = b.sql();

        assert!(sql.contains("dbrst_count"));
        assert!(sql.contains("dbrst_filtered_count"));
    }

    #[test]
    fn test_main_read_headers_only() {
        let plan = ReadPlan::root(test_qi());
        let tree = ReadPlanTree::leaf(plan);

        let b = main_read(&tree, None, None, true, None, dialect());
        let sql = b.sql();

        assert!(sql.contains("NULL AS body"));
    }

    #[test]
    fn test_main_read_with_max_rows() {
        let plan = ReadPlan::root(test_qi());
        let tree = ReadPlanTree::leaf(plan);

        let b = main_read(&tree, None, Some(100), false, None, dialect());
        let sql = b.sql();

        assert!(sql.contains("LIMIT 100"));
    }

    // ------------------------------------------------------------------
    // main_write tests
    // ------------------------------------------------------------------

    #[test]
    fn test_main_write_basic() {
        let mutate = MutatePlan::Insert(InsertPlan {
            into: test_qi(),
            columns: vec![typed_field("name", "text")],
            body: Payload::RawJSON(Bytes::from(r#"[{"name":"test"}]"#)),
            on_conflict: None,
            where_: vec![],
            returning: vec![select_field("id")],
            pk_cols: vec!["id".into()],
            apply_defaults: false,
        });
        let read = ReadPlanTree::leaf(ReadPlan::root(test_qi()));

        let b = main_write(&mutate, &read, true, None, dialect());
        let sql = b.sql();

        assert!(sql.starts_with("WITH dbrst_source AS ("));
        assert!(sql.contains("INSERT INTO"));
        assert!(sql.contains("AS body"));
    }

    #[test]
    fn test_main_write_no_representation() {
        let mutate = MutatePlan::Insert(InsertPlan {
            into: test_qi(),
            columns: vec![],
            body: Payload::RawJSON(Bytes::from("{}")),
            on_conflict: None,
            where_: vec![],
            returning: vec![],
            pk_cols: vec![],
            apply_defaults: false,
        });
        let read = ReadPlanTree::leaf(ReadPlan::root(test_qi()));

        let b = main_write(&mutate, &read, false, None, dialect());
        let sql = b.sql();

        assert!(sql.contains("NULL AS body"));
    }

    #[test]
    fn test_main_write_without_returning_adds_returning_one() {
        let mutate = insert_plan(vec![]);
        let read = ReadPlanTree::leaf(ReadPlan::root(test_qi()));

        let b = main_write(&mutate, &read, true, None, dialect());
        let sql = b.sql();

        assert!(sql.contains("RETURNING 1"));
        // Representation was requested, but there are no columns to return.
        assert!(sql.contains("NULL AS body"));
    }

    // ------------------------------------------------------------------
    // main_write_split tests
    // ------------------------------------------------------------------

    #[test]
    fn test_main_write_split_basic() {
        let mutate = insert_plan(vec![select_field("id")]);
        let read = ReadPlanTree::leaf(ReadPlan::root(test_qi()));

        let (mutation, aggregation) = main_write_split(&mutate, &read, true, None, dialect());
        let mutation_sql = mutation.sql();
        let aggregation_sql = aggregation.sql();

        assert!(mutation_sql.contains("INSERT INTO"));
        assert!(aggregation_sql.starts_with("SELECT "));
        assert!(aggregation_sql.contains("'' AS total_result_set"));
        assert!(aggregation_sql.contains("AS page_total"));
        assert!(aggregation_sql.contains("_dbrst_mut"));
        assert!(aggregation_sql.contains("AS body"));
        assert!(!aggregation_sql.contains("NULL AS body"));
        assert!(aggregation_sql.contains("AS response_headers"));
        assert!(aggregation_sql.contains("AS response_status"));
    }

    #[test]
    fn test_main_write_split_no_representation() {
        let mutate = insert_plan(vec![]);
        let read = ReadPlanTree::leaf(ReadPlan::root(test_qi()));

        let (mutation, aggregation) = main_write_split(&mutate, &read, false, None, dialect());

        assert!(mutation.sql().contains("RETURNING 1"));
        assert!(aggregation.sql().contains("NULL AS body"));
    }

    // ------------------------------------------------------------------
    // main_call tests
    // ------------------------------------------------------------------

    #[test]
    fn test_main_call_basic() {
        let call = CallPlan {
            qi: QualifiedIdentifier::new("public", "get_time"),
            params: CallParams::KeyParams(vec![]),
            args: CallArgs::JsonArgs(None),
            scalar: false,
            set_of_scalar: false,
            filter_fields: vec![],
            returning: vec![],
        };

        let b = main_call(&call, None, None, None, dialect());
        let sql = b.sql();

        assert!(sql.starts_with("WITH dbrst_source AS ("));
        assert!(sql.contains("get_time"));
        assert!(sql.contains("AS body"));
    }

    #[test]
    fn test_main_call_scalar() {
        let call = CallPlan {
            qi: QualifiedIdentifier::new("public", "add_numbers"),
            params: CallParams::KeyParams(vec![]),
            args: CallArgs::JsonArgs(None),
            scalar: true,
            set_of_scalar: false,
            filter_fields: vec![],
            returning: vec![],
        };

        let b = main_call(&call, None, None, None, dialect());
        let sql = b.sql();

        // Scalar uses row_to_json instead of json_agg
        assert!(sql.contains("row_to_json"));
    }

    #[test]
    fn test_main_call_scalar_ignores_max_rows() {
        let call = CallPlan {
            qi: QualifiedIdentifier::new("public", "add_numbers"),
            params: CallParams::KeyParams(vec![]),
            args: CallArgs::JsonArgs(None),
            scalar: true,
            set_of_scalar: false,
            filter_fields: vec![],
            returning: vec![],
        };

        let b = main_call(&call, None, Some(50), None, dialect());
        let sql = b.sql();

        assert!(sql.contains("1 AS page_total"));
        assert!(!sql.contains("LIMIT 50"));
    }

    #[test]
    fn test_main_call_with_count() {
        let call = CallPlan {
            qi: QualifiedIdentifier::new("public", "get_data"),
            params: CallParams::KeyParams(vec![]),
            args: CallArgs::JsonArgs(None),
            scalar: false,
            set_of_scalar: false,
            filter_fields: vec![],
            returning: vec![],
        };

        let b = main_call(&call, Some(PreferCount::Exact), None, None, dialect());
        let sql = b.sql();

        assert!(sql.contains("pg_catalog.count(*)"));
    }

    #[test]
    fn test_main_call_with_max_rows() {
        let call = CallPlan {
            qi: QualifiedIdentifier::new("public", "get_data"),
            params: CallParams::KeyParams(vec![]),
            args: CallArgs::JsonArgs(None),
            scalar: false,
            set_of_scalar: false,
            filter_fields: vec![],
            returning: vec![],
        };

        let b = main_call(&call, None, Some(50), None, dialect());
        let sql = b.sql();

        assert!(sql.contains("LIMIT 50"));
    }
}