dbrest_core/query/
pre_query.rs1use crate::backend::SqlDialect;
28use crate::config::AppConfig;
29use crate::types::identifiers::QualifiedIdentifier;
30
31use super::sql_builder::SqlBuilder;
32
33#[allow(clippy::too_many_arguments)]
68pub fn tx_var_query(
69 config: &AppConfig,
70 dialect: &dyn SqlDialect,
71 method: &str,
72 path: &str,
73 role: Option<&str>,
74 headers_json: Option<&str>,
75 cookies_json: Option<&str>,
76 claims_json: Option<&str>,
77) -> SqlBuilder {
78 let search_path = config
80 .db_schemas
81 .iter()
82 .map(|s| format!("\"{}\"", s))
83 .collect::<Vec<_>>()
84 .join(", ");
85
86 let mut vars: Vec<(&str, String)> = Vec::new();
87 vars.push(("search_path", search_path));
88
89 let effective_role = role
90 .map(|r| r.to_string())
91 .or_else(|| config.db_anon_role.clone());
92 if let Some(ref role_val) = effective_role {
93 vars.push(("role", role_val.clone()));
94 }
95
96 vars.push(("request.method", method.to_string()));
97 vars.push(("request.path", path.to_string()));
98
99 if let Some(headers) = headers_json {
100 vars.push(("request.headers", headers.to_string()));
101 }
102 if let Some(cookies) = cookies_json {
103 vars.push(("request.cookies", cookies.to_string()));
104 }
105 if let Some(claims) = claims_json {
106 vars.push(("request.jwt.claims", claims.to_string()));
107 }
108
109 let mut b = SqlBuilder::new();
110
111 if dialect.session_vars_are_select_exprs() {
112 b.push("SELECT ");
114 let mut first = true;
115 for (key, value) in &vars {
116 push_set_var(&mut b, dialect, key, value, &mut first);
117 }
118 } else {
119 let refs: Vec<(&str, &str)> = vars.iter().map(|(k, v)| (*k, v.as_str())).collect();
121 dialect.build_tx_vars_statement(&mut b, &refs);
122 }
123
124 b
125}
126
127fn push_set_var(
140 b: &mut SqlBuilder,
141 dialect: &dyn SqlDialect,
142 key: &str,
143 value: &str,
144 first: &mut bool,
145) {
146 if !*first {
147 b.push(", ");
148 }
149 *first = false;
150 dialect.set_session_var(b, key, value);
151}
152
153pub fn pre_req_query(pre_request: &QualifiedIdentifier) -> SqlBuilder {
172 let mut b = SqlBuilder::new();
173 b.push("SELECT ");
174 b.push_qi(pre_request);
175 b.push("()");
176 b
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestPgDialect;

    /// Baseline config: two exposed schemas plus an anonymous fallback role.
    fn test_config() -> AppConfig {
        let mut cfg = AppConfig::default();
        cfg.db_schemas = vec!["test_api".to_string(), "public".to_string()];
        cfg.db_anon_role = Some("web_anon".to_string());
        cfg
    }

    /// The PostgreSQL test dialect used by every case in this module.
    fn dialect() -> &'static dyn SqlDialect {
        &TestPgDialect
    }

    /// Render `tx_var_query` against the baseline config with no cookies.
    fn render_tx_vars(
        method: &str,
        path: &str,
        role: Option<&str>,
        headers_json: Option<&str>,
        claims_json: Option<&str>,
    ) -> String {
        let config = test_config();
        let query = tx_var_query(
            &config,
            dialect(),
            method,
            path,
            role,
            headers_json,
            None,
            claims_json,
        );
        query.sql().to_string()
    }

    /// Assert that every needle appears somewhere in `sql`.
    fn assert_contains_all(sql: &str, needles: &[&str]) {
        for needle in needles {
            assert!(sql.contains(*needle), "expected `{}` in: {}", needle, sql);
        }
    }

    #[test]
    fn test_tx_var_query_basic() {
        let sql = render_tx_vars("GET", "/users", None, None, None);

        assert!(sql.starts_with("SELECT set_config("));
        assert_contains_all(
            &sql,
            &["search_path", "request.method", "request.path", "'GET'", "'/users'"],
        );
    }

    #[test]
    fn test_tx_var_query_with_role() {
        let sql = render_tx_vars("POST", "/items", Some("admin"), None, None);

        assert_contains_all(&sql, &["'role'", "'admin'"]);
    }

    #[test]
    fn test_tx_var_query_with_headers() {
        let headers = r#"{"accept":"application/json"}"#;
        let sql = render_tx_vars("GET", "/users", None, Some(headers), None);

        assert_contains_all(&sql, &["request.headers", "application/json"]);
    }

    #[test]
    fn test_tx_var_query_with_claims() {
        let claims = r#"{"sub":"user123"}"#;
        let sql = render_tx_vars("GET", "/users", None, None, Some(claims));

        assert_contains_all(&sql, &["request.jwt.claims"]);
    }

    #[test]
    fn test_pre_req_query() {
        let target = QualifiedIdentifier::new("my_schema", "check_request");
        let call = pre_req_query(&target);
        assert_eq!(call.sql(), "SELECT \"my_schema\".\"check_request\"()");
    }

    #[test]
    fn test_pre_req_query_unqualified() {
        let target = QualifiedIdentifier::unqualified("pre_request_check");
        let call = pre_req_query(&target);
        assert_eq!(call.sql(), "SELECT \"pre_request_check\"()");
    }

    #[test]
    fn test_tx_var_query_search_path_format() {
        let sql = render_tx_vars("GET", "/", None, None, None);

        assert_contains_all(&sql, &["\"test_api\", \"public\""]);
    }
}