postrust_sql/
builder.rs

1//! Core SQL builder types.
2
3use crate::param::SqlParam;
4use std::fmt::Write;
5
6/// A SQL fragment with its associated parameters.
7///
8/// This is the core type for building SQL queries safely. It maintains
9/// a SQL string with parameter placeholders ($1, $2, etc.) and a vector
10/// of parameter values.
11#[derive(Clone, Debug, Default)]
12pub struct SqlFragment {
13    sql: String,
14    params: Vec<SqlParam>,
15}
16
17impl SqlFragment {
18    /// Create a new empty SQL fragment.
19    pub fn new() -> Self {
20        Self::default()
21    }
22
23    /// Create a SQL fragment from raw SQL (no parameters).
24    ///
25    /// # Warning
26    ///
27    /// Only use this for known-safe SQL strings (e.g., keywords, operators).
28    /// Never use this with user input.
29    pub fn raw(sql: impl Into<String>) -> Self {
30        Self {
31            sql: sql.into(),
32            params: Vec::new(),
33        }
34    }
35
36    /// Create a SQL fragment with a single parameter.
37    pub fn param(value: impl Into<SqlParam>) -> Self {
38        let mut frag = Self::new();
39        frag.push_param(value);
40        frag
41    }
42
43    /// Get the SQL string.
44    pub fn sql(&self) -> &str {
45        &self.sql
46    }
47
48    /// Get the parameters.
49    pub fn params(&self) -> &[SqlParam] {
50        &self.params
51    }
52
53    /// Get the current parameter count.
54    pub fn param_count(&self) -> usize {
55        self.params.len()
56    }
57
58    /// Check if the fragment is empty.
59    pub fn is_empty(&self) -> bool {
60        self.sql.is_empty()
61    }
62
63    /// Push raw SQL (no parameters).
64    pub fn push(&mut self, sql: &str) -> &mut Self {
65        self.sql.push_str(sql);
66        self
67    }
68
69    /// Push a character.
70    pub fn push_char(&mut self, c: char) -> &mut Self {
71        self.sql.push(c);
72        self
73    }
74
75    /// Push a parameter and its placeholder.
76    pub fn push_param(&mut self, value: impl Into<SqlParam>) -> &mut Self {
77        let param_num = self.params.len() + 1;
78        write!(self.sql, "${}", param_num).unwrap();
79        self.params.push(value.into());
80        self
81    }
82
83    /// Push a typed parameter with explicit cast.
84    pub fn push_typed_param(&mut self, value: impl Into<SqlParam>, pg_type: &str) -> &mut Self {
85        let param_num = self.params.len() + 1;
86        write!(self.sql, "${}::{}", param_num, pg_type).unwrap();
87        self.params.push(value.into());
88        self
89    }
90
91    /// Append another SQL fragment.
92    ///
93    /// This renumbers the parameters in the appended fragment to continue
94    /// from the current count.
95    pub fn append(&mut self, other: SqlFragment) -> &mut Self {
96        let offset = self.params.len();
97
98        // Renumber parameters in the other fragment
99        let renumbered_sql = renumber_params(&other.sql, offset);
100        self.sql.push_str(&renumbered_sql);
101        self.params.extend(other.params);
102        self
103    }
104
105    /// Append with a separator if not empty.
106    pub fn append_sep(&mut self, sep: &str, other: SqlFragment) -> &mut Self {
107        if !self.is_empty() && !other.is_empty() {
108            self.push(sep);
109        }
110        self.append(other)
111    }
112
113    /// Join multiple fragments with a separator.
114    pub fn join(sep: &str, fragments: impl IntoIterator<Item = SqlFragment>) -> Self {
115        let mut result = Self::new();
116        let mut first = true;
117
118        for frag in fragments {
119            if frag.is_empty() {
120                continue;
121            }
122            if !first {
123                result.push(sep);
124            }
125            result.append(frag);
126            first = false;
127        }
128
129        result
130    }
131
132    /// Wrap in parentheses.
133    pub fn parens(mut self) -> Self {
134        self.sql = format!("({})", self.sql);
135        self
136    }
137
138    /// Build the final SQL and parameters.
139    pub fn build(self) -> (String, Vec<SqlParam>) {
140        (self.sql, self.params)
141    }
142}
143
/// Renumber parameter placeholders in a SQL string.
///
/// Every bare `$N` placeholder is rewritten as `$(N + offset)`. A `$` that is
/// not followed by digits, or whose digits do not fit in a `usize`, is copied
/// unchanged.
///
/// Text inside double-quoted identifiers (`"col$1"`) and single-quoted string
/// literals (`'costs $5'`) is copied verbatim, because a `$N` there is data
/// rather than a placeholder. Doubled quotes (`''`, `""`) are treated as
/// escaped quotes, and backslash escapes are honoured inside `E'...'` escape
/// strings. An unterminated quote protects the rest of the input.
///
/// Dollar-quoted strings (`$tag$...$tag$`) and SQL comments are not
/// recognised; placeholders inside them are still renumbered.
fn renumber_params(sql: &str, offset: usize) -> String {
    let bytes = sql.as_bytes();
    let mut result = String::with_capacity(sql.len());
    // Start of the input that has not yet been copied into `result`.
    let mut pending = 0;
    let mut pos = 0;

    // Scanning bytes is safe for UTF-8 input: every byte acted on here is
    // ASCII, and ASCII bytes never occur inside a multi-byte character, so all
    // slice boundaries below fall on character boundaries.
    while pos < bytes.len() {
        match bytes[pos] {
            b'\'' | b'"' => pos = skip_quoted(bytes, pos),
            b'$' => {
                let digits_start = pos + 1;
                let digits_end = bytes[digits_start..]
                    .iter()
                    .position(|byte| !byte.is_ascii_digit())
                    .map_or(bytes.len(), |len| digits_start + len);

                if let Ok(num) = sql[digits_start..digits_end].parse::<usize>() {
                    result.push_str(&sql[pending..pos]);
                    write!(result, "${}", num + offset)
                        .expect("writing to a String cannot fail");
                    pending = digits_end;
                }
                pos = digits_end;
            }
            _ => pos += 1,
        }
    }

    result.push_str(&sql[pending..]);
    result
}

/// Return the index just past the quoted region that opens at `bytes[open]`,
/// or `bytes.len()` when the region is never closed.
fn skip_quoted(bytes: &[u8], open: usize) -> usize {
    let quote = bytes[open];
    let backslash_escapes = quote == b'\'' && is_escape_string(bytes, open);
    let mut pos = open + 1;

    while pos < bytes.len() {
        let byte = bytes[pos];
        if backslash_escapes && byte == b'\\' {
            // Skip the backslash together with the byte it escapes.
            pos += 2;
        } else if byte != quote {
            pos += 1;
        } else if bytes.get(pos + 1) == Some(&quote) {
            // A doubled quote is an escaped quote, not the end of the region.
            pos += 2;
        } else {
            return pos + 1;
        }
    }

    bytes.len()
}

/// Whether the single quote at `bytes[open]` starts an `E'...'` escape string,
/// i.e. is directly preceded by a standalone `E`/`e` prefix.
fn is_escape_string(bytes: &[u8], open: usize) -> bool {
    let mut before = bytes[..open].iter().rev().copied();
    if !matches!(before.next(), Some(b'E') | Some(b'e')) {
        return false;
    }
    // The prefix only counts when it is not the tail of a longer identifier.
    match before.next() {
        Some(byte) => {
            !(byte.is_ascii_alphanumeric() || byte == b'_' || byte == b'$' || byte >= 0x80)
        }
        None => true,
    }
}
174
175/// Trait for types that can be converted to SQL fragments.
176pub trait SqlBuilder {
177    /// Build the SQL fragment for this type.
178    fn build_sql(&self) -> SqlFragment;
179}
180
181impl SqlBuilder for SqlFragment {
182    fn build_sql(&self) -> SqlFragment {
183        self.clone()
184    }
185}
186
187impl SqlBuilder for &str {
188    fn build_sql(&self) -> SqlFragment {
189        SqlFragment::raw(*self)
190    }
191}
192
193impl SqlBuilder for String {
194    fn build_sql(&self) -> SqlFragment {
195        SqlFragment::raw(self.clone())
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw fragments keep their text verbatim and carry no parameters.
    #[test]
    fn test_sql_fragment_raw() {
        let frag = SqlFragment::raw("SELECT * FROM users");
        assert_eq!(frag.sql(), "SELECT * FROM users");
        assert!(frag.params().is_empty());
    }

    /// Pushing a parameter appends a `$N` placeholder and stores the value.
    #[test]
    fn test_sql_fragment_param() {
        let mut frag = SqlFragment::new();
        frag.push("SELECT * FROM users WHERE id = ");
        frag.push_param(42i64);

        assert_eq!(frag.sql(), "SELECT * FROM users WHERE id = $1");
        assert_eq!(frag.params().len(), 1);
    }

    /// Appending shifts the appended fragment's placeholders past the
    /// parameters already present.
    #[test]
    fn test_sql_fragment_append() {
        let mut frag1 = SqlFragment::new();
        frag1.push("SELECT * FROM users WHERE id = ");
        frag1.push_param(42i64);

        let mut frag2 = SqlFragment::new();
        frag2.push(" AND name = ");
        frag2.push_param("John");

        frag1.append(frag2);

        assert_eq!(
            frag1.sql(),
            "SELECT * FROM users WHERE id = $1 AND name = $2"
        );
        assert_eq!(frag1.params().len(), 2);
    }

    /// Joining independently built fragments (each numbered from `$1`)
    /// yields consecutively numbered placeholders.
    #[test]
    fn test_sql_fragment_join() {
        let frags = vec![("a", 1i64), ("b", 2i64), ("c", 3i64)]
            .into_iter()
            .map(|(column, value)| {
                let mut frag = SqlFragment::new();
                frag.push(column).push(" = ").push_param(value);
                frag
            });

        let joined = SqlFragment::join(" AND ", frags);

        assert_eq!(joined.sql(), "a = $1 AND b = $2 AND c = $3");
        assert_eq!(joined.params().len(), 3);
    }

    /// Empty fragments are skipped by `join` and produce no separator.
    #[test]
    fn test_sql_fragment_join_skips_empty() {
        let joined = SqlFragment::join(
            ", ",
            vec![
                SqlFragment::new(),
                SqlFragment::raw("a"),
                SqlFragment::new(),
                SqlFragment::raw("b"),
            ],
        );

        assert_eq!(joined.sql(), "a, b");
        assert!(joined.params().is_empty());
    }

    /// Placeholders are shifted by the offset; other text is left alone.
    #[test]
    fn test_renumber_params() {
        assert_eq!(renumber_params("$1", 2), "$3");
        assert_eq!(renumber_params("$1 AND $2", 5), "$6 AND $7");
        assert_eq!(renumber_params("no params", 5), "no params");
        assert_eq!(renumber_params("lone $ sign", 5), "lone $ sign");
    }

    /// `parens` wraps the SQL text in a single pair of parentheses.
    #[test]
    fn test_sql_fragment_parens() {
        let frag = SqlFragment::raw("a OR b").parens();
        assert_eq!(frag.sql(), "(a OR b)");
    }
}