Skip to main content

dbrest_core/query/
pre_query.rs

1//! Transaction and session variable setup queries.
2//!
3//! Before executing the main SQL query, dbrest sets session variables
4//! (on PostgreSQL via `set_config()`) to communicate request context to
5//! database functions and triggers. This module generates those setup
6//! queries, plus the optional pre-request function call.
7//!
8//! # Pipeline
9//!
10//! ```text
11//! HTTP request ──▶ tx_var_query()  ──▶ SET search_path, role, method, path, headers, cookies, claims …
12//!               ──▶ pre_req_query() ──▶ SELECT pre_request_fn()
13//! ```
14//!
15//! # SQL Example
16//!
17//! ```sql
18//! SELECT
19//!   set_config('search_path', '"test_api", "public"', true),
20//!   set_config('role', 'web_anon', true),
21//!   set_config('request.method', 'GET', true),
22//!   set_config('request.path', '/users', true),
23//!   set_config('request.headers', '{"accept":"application/json"}', true),
24//!   set_config('request.cookies', '{}', true)
25//! ```
26
27use crate::backend::SqlDialect;
28use crate::config::AppConfig;
29use crate::types::identifiers::QualifiedIdentifier;
30
31use super::sql_builder::SqlBuilder;
32
33// ==========================================================================
34// tx_var_query — session variable setup
35// ==========================================================================
36
37/// Build the session variable setup query.
38///
39/// Generates a `SELECT set_config(...)` call for each session variable that
40/// must be set before executing the main query. These variables are available
41/// to PostgreSQL functions, triggers, and RLS policies via
42/// `current_setting('variable.name')`.
43///
44/// # Behaviour
45///
46/// The following variables are set:
47/// - `search_path` — from the configured schemas
48/// - `role` — the anonymous role or authenticated role
49/// - `request.method` — HTTP method (GET, POST, etc.)
50/// - `request.path` — URL path
51/// - `request.headers` — serialized request headers as JSON
52/// - `request.cookies` — serialized cookies as JSON
53/// - `request.jwt.claims` — JWT claims as JSON
54///
55/// All values are set with `is_local = true` so they only apply to the
56/// current transaction.
57///
58/// # SQL Example
59///
60/// ```sql
61/// SELECT
62///   set_config('search_path', '"test_api", "public"', true),
63///   set_config('role', 'web_anon', true),
64///   set_config('request.method', 'GET', true),
65///   set_config('request.path', '/users', true)
66/// ```
67#[allow(clippy::too_many_arguments)]
68pub fn tx_var_query(
69    config: &AppConfig,
70    dialect: &dyn SqlDialect,
71    method: &str,
72    path: &str,
73    role: Option<&str>,
74    headers_json: Option<&str>,
75    cookies_json: Option<&str>,
76    claims_json: Option<&str>,
77) -> SqlBuilder {
78    // Collect all key-value pairs
79    let search_path = config
80        .db_schemas
81        .iter()
82        .map(|s| format!("\"{}\"", s))
83        .collect::<Vec<_>>()
84        .join(", ");
85
86    let mut vars: Vec<(&str, String)> = Vec::new();
87    vars.push(("search_path", search_path));
88
89    let effective_role = role
90        .map(|r| r.to_string())
91        .or_else(|| config.db_anon_role.clone());
92    if let Some(ref role_val) = effective_role {
93        vars.push(("role", role_val.clone()));
94    }
95
96    vars.push(("request.method", method.to_string()));
97    vars.push(("request.path", path.to_string()));
98
99    if let Some(headers) = headers_json {
100        vars.push(("request.headers", headers.to_string()));
101    }
102    if let Some(cookies) = cookies_json {
103        vars.push(("request.cookies", cookies.to_string()));
104    }
105    if let Some(claims) = claims_json {
106        vars.push(("request.jwt.claims", claims.to_string()));
107    }
108
109    let mut b = SqlBuilder::new();
110
111    if dialect.session_vars_are_select_exprs() {
112        // PostgreSQL-style: SELECT set_config(...), set_config(...), ...
113        b.push("SELECT ");
114        let mut first = true;
115        for (key, value) in &vars {
116            push_set_var(&mut b, dialect, key, value, &mut first);
117        }
118    } else {
119        // Batch-style: dialect produces a single statement for all vars
120        let refs: Vec<(&str, &str)> = vars.iter().map(|(k, v)| (*k, v.as_str())).collect();
121        dialect.build_tx_vars_statement(&mut b, &refs);
122    }
123
124    b
125}
126
127/// Append a session variable assignment expression via the dialect.
128///
129/// # Behaviour
130///
131/// - Prepends a comma separator when this is not the first call
132/// - Delegates to `dialect.set_session_var()` for database-specific syntax
133///
134/// # SQL Example (PostgreSQL)
135///
136/// ```sql
137/// set_config('request.method', 'GET', true)
138/// ```
139fn push_set_var(
140    b: &mut SqlBuilder,
141    dialect: &dyn SqlDialect,
142    key: &str,
143    value: &str,
144    first: &mut bool,
145) {
146    if !*first {
147        b.push(", ");
148    }
149    *first = false;
150    dialect.set_session_var(b, key, value);
151}
152
153// ==========================================================================
154// pre_req_query — pre-request function call
155// ==========================================================================
156
157/// Build the pre-request function call query.
158///
159/// If the configuration specifies a `db_pre_request` function, this generates
160/// a `SELECT` call to that function. The function is invoked after session
161/// variables are set but before the main query, allowing it to perform
162/// custom authorization checks or request validation.
163///
164/// Returns `None` if no pre-request function is configured.
165///
166/// # SQL Example
167///
168/// ```sql
169/// SELECT "my_schema"."check_request"()
170/// ```
171pub fn pre_req_query(pre_request: &QualifiedIdentifier) -> SqlBuilder {
172    let mut b = SqlBuilder::new();
173    b.push("SELECT ");
174    b.push_qi(pre_request);
175    b.push("()");
176    b
177}
178
179// ==========================================================================
180// Tests
181// ==========================================================================
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestPgDialect;

    /// Config with two schemas and an anonymous role, shared by most tests.
    fn test_config() -> AppConfig {
        let mut config = AppConfig::default();
        config.db_schemas = vec!["test_api".to_string(), "public".to_string()];
        config.db_anon_role = Some("web_anon".to_string());
        config
    }

    fn dialect() -> &'static dyn SqlDialect {
        &TestPgDialect
    }

    #[test]
    fn test_tx_var_query_basic() {
        let config = test_config();
        let b = tx_var_query(&config, dialect(), "GET", "/users", None, None, None, None);
        let sql = b.sql();

        assert!(sql.starts_with("SELECT set_config("));
        assert!(sql.contains("search_path"));
        assert!(sql.contains("request.method"));
        assert!(sql.contains("request.path"));
        assert!(sql.contains("'GET'"));
        assert!(sql.contains("'/users'"));
    }

    #[test]
    fn test_tx_var_query_omits_absent_optional_vars() {
        let config = test_config();
        let b = tx_var_query(&config, dialect(), "GET", "/users", None, None, None, None);
        let sql = b.sql();

        assert!(!sql.contains("request.headers"));
        assert!(!sql.contains("request.cookies"));
        assert!(!sql.contains("request.jwt.claims"));
    }

    #[test]
    fn test_tx_var_query_with_role() {
        let config = test_config();
        let b = tx_var_query(
            &config,
            dialect(),
            "POST",
            "/items",
            Some("admin"),
            None,
            None,
            None,
        );
        let sql = b.sql();

        assert!(sql.contains("'role'"));
        assert!(sql.contains("'admin'"));
        // An explicit role must replace, not accompany, the anonymous role.
        assert!(!sql.contains("web_anon"));
    }

    #[test]
    fn test_tx_var_query_defaults_to_anon_role() {
        let config = test_config();
        let b = tx_var_query(&config, dialect(), "GET", "/users", None, None, None, None);
        let sql = b.sql();

        assert!(sql.contains("'role'"));
        assert!(sql.contains("'web_anon'"));
    }

    #[test]
    fn test_tx_var_query_without_any_role() {
        let mut config = test_config();
        config.db_anon_role = None;
        let b = tx_var_query(&config, dialect(), "GET", "/users", None, None, None, None);
        let sql = b.sql();

        assert!(!sql.contains("'role'"));
    }

    #[test]
    fn test_tx_var_query_with_headers() {
        let config = test_config();
        let b = tx_var_query(
            &config,
            dialect(),
            "GET",
            "/users",
            None,
            Some(r#"{"accept":"application/json"}"#),
            None,
            None,
        );
        let sql = b.sql();

        assert!(sql.contains("request.headers"));
        assert!(sql.contains("application/json"));
    }

    #[test]
    fn test_tx_var_query_with_cookies() {
        let config = test_config();
        let b = tx_var_query(
            &config,
            dialect(),
            "GET",
            "/users",
            None,
            None,
            Some(r#"{"session":"abc123"}"#),
            None,
        );
        let sql = b.sql();

        assert!(sql.contains("request.cookies"));
        assert!(sql.contains("abc123"));
    }

    #[test]
    fn test_tx_var_query_with_claims() {
        let config = test_config();
        let b = tx_var_query(
            &config,
            dialect(),
            "GET",
            "/users",
            None,
            None,
            None,
            Some(r#"{"sub":"user123"}"#),
        );
        let sql = b.sql();

        assert!(sql.contains("request.jwt.claims"));
    }

    #[test]
    fn test_pre_req_query() {
        let qi = QualifiedIdentifier::new("my_schema", "check_request");
        let b = pre_req_query(&qi);
        assert_eq!(b.sql(), "SELECT \"my_schema\".\"check_request\"()");
    }

    #[test]
    fn test_pre_req_query_unqualified() {
        let qi = QualifiedIdentifier::unqualified("pre_request_check");
        let b = pre_req_query(&qi);
        assert_eq!(b.sql(), "SELECT \"pre_request_check\"()");
    }

    #[test]
    fn test_tx_var_query_search_path_format() {
        let config = test_config();
        let b = tx_var_query(&config, dialect(), "GET", "/", None, None, None, None);
        let sql = b.sql();

        // Should contain quoted schema names
        assert!(sql.contains("\"test_api\", \"public\""));
    }
}