Skip to main content

pglite_oxide/pglite/
templating.rs

1use anyhow::{Result, anyhow};
2use regex::Regex;
3use serde_json::Value;
4use std::sync::LazyLock;
5
6use crate::pglite::client::Pglite;
7use crate::pglite::interface::QueryOptions;
8use crate::pglite::types::TEXT;
9
10#[derive(Debug, Clone)]
11pub struct TemplatedQuery {
12    pub query: String,
13    pub params: Vec<Value>,
14}
15
16#[derive(Debug, Default, Clone)]
17pub struct QueryTemplate {
18    sql: String,
19    params: Vec<Value>,
20}
21
22impl QueryTemplate {
23    pub fn new() -> Self {
24        Self::default()
25    }
26
27    pub fn push_sql(&mut self, sql: impl AsRef<str>) {
28        self.sql.push_str(sql.as_ref());
29    }
30
31    pub fn push_raw(&mut self, sql: impl AsRef<str>) {
32        self.push_sql(sql);
33    }
34
35    pub fn push_identifier(&mut self, identifier: &str) {
36        self.sql.push_str(&quote_identifier(identifier));
37    }
38
39    pub fn push_param(&mut self, value: Value) {
40        let placeholder = format!("${}", self.params.len() + 1);
41        self.sql.push_str(&placeholder);
42        self.params.push(value);
43    }
44
45    pub fn build(self) -> TemplatedQuery {
46        TemplatedQuery {
47            query: self.sql,
48            params: self.params,
49        }
50    }
51}
52
53static DOLLAR_RE: LazyLock<Regex> =
54    LazyLock::new(|| Regex::new(r"\$(\d+)").expect("invalid regex"));
55
/// Renders `ident` as a double-quoted SQL identifier.
///
/// The result is wrapped in `"` and every `"` inside `ident` is doubled, so
/// the text cannot terminate the quoted identifier early.
pub fn quote_identifier(ident: &str) -> String {
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('"');
    for ch in ident.chars() {
        if ch == '"' {
            // A doubled quote is the escape for a literal quote character.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
60
61pub fn format_query(pg: &mut Pglite, query: &str, params: &[Value]) -> Result<String> {
62    if params.is_empty() {
63        return Ok(query.to_string());
64    }
65
66    let described = pg.describe_query(query, None)?;
67    let data_type_ids = described
68        .query_params
69        .iter()
70        .map(|param| param.data_type_id)
71        .collect::<Vec<_>>();
72
73    let formatted = DOLLAR_RE
74        .replace_all(query, |caps: &regex::Captures| format!("%{}L", &caps[1]))
75        .to_string();
76
77    let mut sql = String::from("SELECT format($1");
78    for idx in 0..params.len() {
79        sql.push_str(", $");
80        sql.push_str(&(idx as i32 + 2).to_string());
81    }
82    sql.push_str(") AS query");
83
84    let mut arguments: Vec<Value> = Vec::with_capacity(params.len() + 1);
85    arguments.push(Value::String(formatted));
86    arguments.extend(params.iter().cloned());
87
88    let mut param_types = Vec::with_capacity(arguments.len());
89    param_types.push(TEXT);
90    param_types
91        .extend((0..params.len()).map(|idx| data_type_ids.get(idx).copied().unwrap_or(TEXT)));
92    let options = QueryOptions {
93        param_types,
94        ..QueryOptions::default()
95    };
96
97    let results = pg.query(&sql, &arguments, Some(&options))?;
98    let row = results
99        .rows
100        .first()
101        .ok_or_else(|| anyhow!("format query returned no rows"))?;
102    if let Value::Object(map) = row
103        && let Some(Value::String(formatted)) = map.get("query")
104    {
105        return Ok(formatted.clone());
106    }
107
108    Err(anyhow!("unexpected format query result"))
109}
110
#[cfg(test)]
mod tests {
    use super::{QueryTemplate, quote_identifier};
    use serde_json::json;

    #[test]
    fn template_builder_adds_params() {
        let mut tpl = QueryTemplate::new();
        tpl.push_sql("SELECT ");
        tpl.push_identifier("foo");
        tpl.push_sql(" WHERE id = ");
        tpl.push_param(json!(42));
        let built = tpl.build();
        assert_eq!(built.query, "SELECT \"foo\" WHERE id = $1");
        assert_eq!(built.params.len(), 1);
    }

    /// Placeholders must be numbered in push order, with params kept in the
    /// same order so `$n` binds `params[n - 1]`.
    #[test]
    fn template_builder_numbers_params_in_push_order() {
        let mut tpl = QueryTemplate::new();
        tpl.push_raw("SELECT * FROM ");
        tpl.push_identifier("my table");
        tpl.push_sql(" WHERE a = ");
        tpl.push_param(json!("x"));
        tpl.push_sql(" AND b = ");
        tpl.push_param(json!(null));
        let built = tpl.build();
        assert_eq!(built.query, "SELECT * FROM \"my table\" WHERE a = $1 AND b = $2");
        assert_eq!(built.params, vec![json!("x"), json!(null)]);
    }

    #[test]
    fn quote_identifier_escapes_quotes() {
        assert_eq!(quote_identifier("Foo"), "\"Foo\"");
        assert_eq!(quote_identifier("a\"b"), "\"a\"\"b\"");
        assert_eq!(quote_identifier(""), "\"\"");
        assert_eq!(quote_identifier("\"\""), "\"\"\"\"\"\"");
    }
}