1use crate::param::SqlParam;
4use std::fmt::Write;
5
/// A piece of SQL text together with the values bound to its positional placeholders.
///
/// Placeholders are written as `$1`, `$2`, … and are numbered from 1 within the
/// fragment. When one fragment is appended to another, the appended fragment's
/// placeholders are shifted by the number of parameters already present, so that
/// each `$n` keeps pointing at the right entry of the combined parameter list.
#[derive(Clone, Debug, Default)]
pub struct SqlFragment {
    /// SQL text, including any `$n` placeholders.
    sql: String,
    /// Bound values; the `n`-th entry (1-based) backs placeholder `$n`.
    params: Vec<SqlParam>,
}
16
17impl SqlFragment {
18 pub fn new() -> Self {
20 Self::default()
21 }
22
23 pub fn raw(sql: impl Into<String>) -> Self {
30 Self {
31 sql: sql.into(),
32 params: Vec::new(),
33 }
34 }
35
36 pub fn param(value: impl Into<SqlParam>) -> Self {
38 let mut frag = Self::new();
39 frag.push_param(value);
40 frag
41 }
42
43 pub fn sql(&self) -> &str {
45 &self.sql
46 }
47
48 pub fn params(&self) -> &[SqlParam] {
50 &self.params
51 }
52
53 pub fn param_count(&self) -> usize {
55 self.params.len()
56 }
57
58 pub fn is_empty(&self) -> bool {
60 self.sql.is_empty()
61 }
62
63 pub fn push(&mut self, sql: &str) -> &mut Self {
65 self.sql.push_str(sql);
66 self
67 }
68
69 pub fn push_char(&mut self, c: char) -> &mut Self {
71 self.sql.push(c);
72 self
73 }
74
75 pub fn push_param(&mut self, value: impl Into<SqlParam>) -> &mut Self {
77 let param_num = self.params.len() + 1;
78 write!(self.sql, "${}", param_num).unwrap();
79 self.params.push(value.into());
80 self
81 }
82
83 pub fn push_typed_param(&mut self, value: impl Into<SqlParam>, pg_type: &str) -> &mut Self {
85 let param_num = self.params.len() + 1;
86 write!(self.sql, "${}::{}", param_num, pg_type).unwrap();
87 self.params.push(value.into());
88 self
89 }
90
91 pub fn append(&mut self, other: SqlFragment) -> &mut Self {
96 let offset = self.params.len();
97
98 let renumbered_sql = renumber_params(&other.sql, offset);
100 self.sql.push_str(&renumbered_sql);
101 self.params.extend(other.params);
102 self
103 }
104
105 pub fn append_sep(&mut self, sep: &str, other: SqlFragment) -> &mut Self {
107 if !self.is_empty() && !other.is_empty() {
108 self.push(sep);
109 }
110 self.append(other)
111 }
112
113 pub fn join(sep: &str, fragments: impl IntoIterator<Item = SqlFragment>) -> Self {
115 let mut result = Self::new();
116 let mut first = true;
117
118 for frag in fragments {
119 if frag.is_empty() {
120 continue;
121 }
122 if !first {
123 result.push(sep);
124 }
125 result.append(frag);
126 first = false;
127 }
128
129 result
130 }
131
132 pub fn parens(mut self) -> Self {
134 self.sql = format!("({})", self.sql);
135 self
136 }
137
138 pub fn build(self) -> (String, Vec<SqlParam>) {
140 (self.sql, self.params)
141 }
142}
143
/// Shifts every positional placeholder (`$1`, `$2`, …) in `sql` up by `offset`.
///
/// [`SqlFragment::append`] uses this so that an appended fragment's locally
/// numbered placeholders line up with their position in the combined parameter
/// list.
///
/// Only real placeholders are rewritten. The scan follows enough of PostgreSQL's
/// lexical rules to copy look-alike text unchanged:
///
/// * single-quoted literals, including `E'...'` escape strings (backslash escapes);
/// * double-quoted identifiers;
/// * dollar-quoted strings (`$$...$$`, `$tag$...$tag$`);
/// * `--` line comments and nestable `/* ... */` block comments;
/// * unquoted identifiers that contain `$`, such as `col$1`.
///
/// A placeholder whose shifted number would overflow `usize` is copied unchanged
/// instead of panicking or wrapping.
///
/// Assumptions: each fragment is lexically self-contained, meaning any quote or
/// comment it opens is also closed within it. An unterminated one causes the rest
/// of the fragment to be copied unchanged. The scan also assumes
/// `standard_conforming_strings` is on (the PostgreSQL default), so a backslash is
/// an escape only inside `E'...'`.
fn renumber_params(sql: &str, offset: usize) -> String {
    // First byte of an unquoted identifier (PostgreSQL also accepts any non-ASCII byte).
    fn is_ident_start(b: u8) -> bool {
        b.is_ascii_alphabetic() || b == b'_' || b >= 0x80
    }

    // Bytes allowed after the first one in a dollar-quote tag.
    fn is_tag_cont(b: u8) -> bool {
        is_ident_start(b) || b.is_ascii_digit()
    }

    // Bytes allowed after the first one in an unquoted identifier (`$` included).
    fn is_ident_cont(b: u8) -> bool {
        is_tag_cont(b) || b == b'$'
    }

    // Index of the first byte at or after `from` that does not satisfy `pred`.
    fn scan_while(bytes: &[u8], from: usize, pred: impl Fn(u8) -> bool) -> usize {
        bytes[from..]
            .iter()
            .position(|&b| !pred(b))
            .map_or(bytes.len(), |n| from + n)
    }

    // End (exclusive) of the quoted token opening at `open`. A doubled quote stays
    // inside the token; with `backslash_escapes`, a backslash also escapes the next byte.
    fn skip_quoted(bytes: &[u8], open: usize, quote: u8, backslash_escapes: bool) -> usize {
        let mut i = open + 1;
        while i < bytes.len() {
            match bytes[i] {
                b'\\' if backslash_escapes => i += 2,
                b if b == quote => {
                    if bytes.get(i + 1) == Some(&quote) {
                        i += 2;
                    } else {
                        return i + 1;
                    }
                }
                _ => i += 1,
            }
        }
        bytes.len()
    }

    // End (exclusive) of the `--` comment starting at `open`, including its newline.
    fn skip_line_comment(bytes: &[u8], open: usize) -> usize {
        bytes[open..]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(bytes.len(), |n| open + n + 1)
    }

    // End (exclusive) of the possibly nested `/* ... */` comment starting at `open`.
    fn skip_block_comment(bytes: &[u8], open: usize) -> usize {
        let mut depth = 0usize;
        let mut i = open;
        while i < bytes.len() {
            match (bytes[i], bytes.get(i + 1).copied()) {
                (b'/', Some(b'*')) => {
                    depth += 1;
                    i += 2;
                }
                (b'*', Some(b'/')) => {
                    depth -= 1;
                    i += 2;
                    if depth == 0 {
                        return i;
                    }
                }
                _ => i += 1,
            }
        }
        bytes.len()
    }

    // If a dollar-quote opening tag (`$$` or `$tag$`) starts at `open`, the index
    // just past it.
    fn dollar_tag_end(bytes: &[u8], open: usize) -> Option<usize> {
        match bytes.get(open + 1).copied() {
            Some(b'$') => Some(open + 2),
            Some(b) if is_ident_start(b) => {
                let end = scan_while(bytes, open + 2, is_tag_cont);
                if bytes.get(end) == Some(&b'$') {
                    Some(end + 1)
                } else {
                    None
                }
            }
            _ => None,
        }
    }

    // End (exclusive) of the dollar-quoted string whose opening tag spans
    // `open..tag_end`, or the input length if the closing tag is missing.
    fn skip_dollar_quoted(bytes: &[u8], open: usize, tag_end: usize) -> usize {
        let tag = &bytes[open..tag_end];
        bytes[tag_end..]
            .windows(tag.len())
            .position(|window| window == tag)
            .map_or(bytes.len(), |n| tag_end + n + tag.len())
    }

    if offset == 0 {
        return sql.to_owned();
    }

    // Every index the scan stops at is either the end of input or the position
    // of an ASCII byte. Those are always char boundaries, so the `&sql[..]`
    // slices below cannot split a UTF-8 sequence.
    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len());
    let mut copied = 0; // prefix of `sql` already written to `out`
    let mut i = 0;

    while i < bytes.len() {
        let byte = bytes[i];
        i = match byte {
            b'\'' | b'"' => skip_quoted(bytes, i, byte, false),
            b'-' if bytes.get(i + 1) == Some(&b'-') => skip_line_comment(bytes, i),
            b'/' if bytes.get(i + 1) == Some(&b'*') => skip_block_comment(bytes, i),
            // Reached only at a token start, because identifiers are consumed whole.
            b'E' | b'e' if bytes.get(i + 1) == Some(&b'\'') => {
                skip_quoted(bytes, i + 1, b'\'', true)
            }
            b'$' => {
                let digits_end = scan_while(bytes, i + 1, |b| b.is_ascii_digit());
                if digits_end == i + 1 {
                    // Not a placeholder: either a dollar-quoted string or a lone `$`.
                    match dollar_tag_end(bytes, i) {
                        Some(tag_end) => skip_dollar_quoted(bytes, i, tag_end),
                        None => i + 1,
                    }
                } else {
                    let shifted = sql[i + 1..digits_end]
                        .parse::<usize>()
                        .ok()
                        .and_then(|n| n.checked_add(offset));
                    // On overflow the placeholder falls through and is copied as written.
                    if let Some(n) = shifted {
                        out.push_str(&sql[copied..i]);
                        write!(out, "${}", n).expect("writing to a String cannot fail");
                        copied = digits_end;
                    }
                    digits_end
                }
            }
            _ if is_ident_start(byte) => scan_while(bytes, i + 1, is_ident_cont),
            _ => i + 1,
        };
    }

    out.push_str(&sql[copied..]);
    out
}
174
/// Types that can render themselves as an [`SqlFragment`].
pub trait SqlBuilder {
    /// Produces the SQL fragment for this value.
    ///
    /// The fragment is returned by value; `self` is only borrowed.
    fn build_sql(&self) -> SqlFragment;
}
180
181impl SqlBuilder for SqlFragment {
182 fn build_sql(&self) -> SqlFragment {
183 self.clone()
184 }
185}
186
187impl SqlBuilder for &str {
188 fn build_sql(&self) -> SqlFragment {
189 SqlFragment::raw(*self)
190 }
191}
192
193impl SqlBuilder for String {
194 fn build_sql(&self) -> SqlFragment {
195 SqlFragment::raw(self.clone())
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sql_fragment_raw() {
        let frag = SqlFragment::raw("SELECT * FROM users");
        assert_eq!(frag.sql(), "SELECT * FROM users");
        assert!(frag.params().is_empty());
    }

    #[test]
    fn test_sql_fragment_param() {
        let mut frag = SqlFragment::new();
        frag.push("SELECT * FROM users WHERE id = ");
        frag.push_param(42i64);

        assert_eq!(frag.sql(), "SELECT * FROM users WHERE id = $1");
        assert_eq!(frag.params().len(), 1);
    }

    #[test]
    fn test_sql_fragment_typed_param() {
        let mut frag = SqlFragment::raw("SELECT * FROM users WHERE id = ");
        frag.push_typed_param(42i64, "int8");

        assert_eq!(frag.sql(), "SELECT * FROM users WHERE id = $1::int8");
        assert_eq!(frag.param_count(), 1);
    }

    #[test]
    fn test_sql_fragment_append() {
        let mut frag1 = SqlFragment::new();
        frag1.push("SELECT * FROM users WHERE id = ");
        frag1.push_param(42i64);

        let mut frag2 = SqlFragment::new();
        frag2.push(" AND name = ");
        frag2.push_param("John");

        frag1.append(frag2);

        assert_eq!(
            frag1.sql(),
            "SELECT * FROM users WHERE id = $1 AND name = $2"
        );
        assert_eq!(frag1.params().len(), 2);
    }

    #[test]
    fn test_sql_fragment_append_sep() {
        let mut frag = SqlFragment::new();
        // No separator before the first non-empty piece...
        frag.append_sep(" AND ", SqlFragment::raw("a = 1"));
        // ...and none around an empty piece.
        frag.append_sep(" AND ", SqlFragment::new());
        frag.append_sep(" AND ", SqlFragment::param(2i64));

        assert_eq!(frag.sql(), "a = 1 AND $1");
        assert_eq!(frag.param_count(), 1);
    }

    #[test]
    fn test_sql_fragment_join() {
        // Each piece numbers its own placeholder `$1`. `join` must renumber them
        // and skip empty pieces without leaving a dangling separator.
        let eq = |column: &str, value: i64| {
            let mut frag = SqlFragment::raw(format!("{} = ", column));
            frag.push_param(value);
            frag
        };

        let joined = SqlFragment::join(
            " AND ",
            vec![eq("a", 1), SqlFragment::new(), eq("b", 2), eq("c", 3)],
        );

        assert_eq!(joined.sql(), "a = $1 AND b = $2 AND c = $3");
        assert_eq!(joined.param_count(), 3);
    }

    #[test]
    fn test_sql_fragment_join_empty() {
        let joined = SqlFragment::join(" AND ", Vec::<SqlFragment>::new());
        assert!(joined.is_empty());
        assert_eq!(joined.param_count(), 0);
    }

    #[test]
    fn test_renumber_params() {
        assert_eq!(renumber_params("$1", 2), "$3");
        assert_eq!(renumber_params("$1 AND $2", 5), "$6 AND $7");
        assert_eq!(renumber_params("no params", 5), "no params");
        assert_eq!(renumber_params("$1::int4", 1), "$2::int4");
    }

    #[test]
    fn test_sql_fragment_parens() {
        let frag = SqlFragment::raw("a OR b").parens();
        assert_eq!(frag.sql(), "(a OR b)");
    }

    #[test]
    fn test_sql_fragment_build() {
        let mut frag = SqlFragment::raw("id = ");
        frag.push_param(7i64);

        let (sql, params) = frag.parens().build();
        assert_eq!(sql, "(id = $1)");
        assert_eq!(params.len(), 1);
    }
}