pglite_oxide/pglite/
templating.rs1use anyhow::{Result, anyhow};
2use regex::Regex;
3use serde_json::Value;
4use std::sync::LazyLock;
5
6use crate::pglite::client::Pglite;
7use crate::pglite::interface::QueryOptions;
8use crate::pglite::types::TEXT;
9
10#[derive(Debug, Clone)]
11pub struct TemplatedQuery {
12 pub query: String,
13 pub params: Vec<Value>,
14}
15
16#[derive(Debug, Default, Clone)]
17pub struct QueryTemplate {
18 sql: String,
19 params: Vec<Value>,
20}
21
22impl QueryTemplate {
23 pub fn new() -> Self {
24 Self::default()
25 }
26
27 pub fn push_sql(&mut self, sql: impl AsRef<str>) {
28 self.sql.push_str(sql.as_ref());
29 }
30
31 pub fn push_raw(&mut self, sql: impl AsRef<str>) {
32 self.push_sql(sql);
33 }
34
35 pub fn push_identifier(&mut self, identifier: &str) {
36 self.sql.push_str("e_identifier(identifier));
37 }
38
39 pub fn push_param(&mut self, value: Value) {
40 let placeholder = format!("${}", self.params.len() + 1);
41 self.sql.push_str(&placeholder);
42 self.params.push(value);
43 }
44
45 pub fn build(self) -> TemplatedQuery {
46 TemplatedQuery {
47 query: self.sql,
48 params: self.params,
49 }
50 }
51}
52
53static DOLLAR_RE: LazyLock<Regex> =
54 LazyLock::new(|| Regex::new(r"\$(\d+)").expect("invalid regex"));
55
/// Wraps `ident` in double quotes, doubling every embedded double quote.
///
/// For example `a"b` becomes `"a""b"`.
pub fn quote_identifier(ident: &str) -> String {
    // Two extra bytes for the surrounding quotes; embedded quotes may grow it.
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('"');
    for ch in ident.chars() {
        if ch == '"' {
            // An embedded quote is escaped by writing it twice.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
60
61pub fn format_query(pg: &mut Pglite, query: &str, params: &[Value]) -> Result<String> {
62 if params.is_empty() {
63 return Ok(query.to_string());
64 }
65
66 let described = pg.describe_query(query, None)?;
67 let data_type_ids = described
68 .query_params
69 .iter()
70 .map(|param| param.data_type_id)
71 .collect::<Vec<_>>();
72
73 let formatted = DOLLAR_RE
74 .replace_all(query, |caps: ®ex::Captures| format!("%{}L", &caps[1]))
75 .to_string();
76
77 let mut sql = String::from("SELECT format($1");
78 for idx in 0..params.len() {
79 sql.push_str(", $");
80 sql.push_str(&(idx as i32 + 2).to_string());
81 }
82 sql.push_str(") AS query");
83
84 let mut arguments: Vec<Value> = Vec::with_capacity(params.len() + 1);
85 arguments.push(Value::String(formatted));
86 arguments.extend(params.iter().cloned());
87
88 let mut param_types = Vec::with_capacity(arguments.len());
89 param_types.push(TEXT);
90 param_types
91 .extend((0..params.len()).map(|idx| data_type_ids.get(idx).copied().unwrap_or(TEXT)));
92 let options = QueryOptions {
93 param_types,
94 ..QueryOptions::default()
95 };
96
97 let results = pg.query(&sql, &arguments, Some(&options))?;
98 let row = results
99 .rows
100 .first()
101 .ok_or_else(|| anyhow!("format query returned no rows"))?;
102 if let Value::Object(map) = row
103 && let Some(Value::String(formatted)) = map.get("query")
104 {
105 return Ok(formatted.clone());
106 }
107
108 Err(anyhow!("unexpected format query result"))
109}
110
#[cfg(test)]
mod tests {
    use super::{QueryTemplate, quote_identifier};
    use serde_json::json;

    /// Pushing SQL, an identifier and a parameter yields the expected text and
    /// records exactly one bound value.
    #[test]
    fn template_builder_adds_params() {
        let mut template = QueryTemplate::new();
        template.push_sql("SELECT ");
        template.push_identifier("foo");
        template.push_sql(" WHERE id = ");
        template.push_param(json!(42));

        let rendered = template.build();
        assert_eq!(rendered.query, r#"SELECT "foo" WHERE id = $1"#);
        assert_eq!(rendered.params.len(), 1);
    }

    /// Identifiers are wrapped in double quotes and embedded quotes doubled.
    #[test]
    fn quote_identifier_escapes_quotes() {
        let cases = [("Foo", r#""Foo""#), ("a\"b", r#""a""b""#)];
        for (input, expected) in cases {
            assert_eq!(quote_identifier(input), expected);
        }
    }
}