1use std::cell::RefCell;
4use std::fmt;
5
6use indexmap::IndexMap;
7
8use crate::expr::{ColumnRef, Expr};
9use crate::stmt::*;
10use crate::{Ident, ParamName, RenderedSql, escape_string};
11
12struct ParamState {
14 params: IndexMap<ParamName, usize>,
16 next_param_idx: usize,
18}
19
20impl ParamState {
21 fn new() -> Self {
22 Self {
23 params: IndexMap::new(),
24 next_param_idx: 1,
25 }
26 }
27
28 fn get_or_insert(&mut self, name: &ParamName) -> usize {
30 *self.params.entry(name.clone()).or_insert_with(|| {
31 let idx = self.next_param_idx;
32 self.next_param_idx += 1;
33 idx
34 })
35 }
36}
37
38pub struct RenderContext {
43 params: RefCell<ParamState>,
45}
46
47impl RenderContext {
48 pub fn new() -> Self {
49 Self {
50 params: RefCell::new(ParamState::new()),
51 }
52 }
53
54 fn param_idx(&self, name: &ParamName) -> usize {
56 self.params.borrow_mut().get_or_insert(name)
57 }
58
59 fn into_params(self) -> Vec<ParamName> {
61 self.params.into_inner().params.into_keys().collect()
62 }
63}
64
65impl Default for RenderContext {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71pub struct Fmt<'a, T: Render>(
75 &'a RenderContext,
77 &'a T,
79);
80
81impl<T: Render> fmt::Display for Fmt<'_, T> {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 self.1.render(self.0, f)
84 }
85}
86
/// A SQL syntax node that can write itself as query text.
///
/// Implementations write their SQL to `f` and register any bind parameters
/// they reference with `ctx`, which assigns each distinct parameter its `$N`
/// placeholder position. The implementations in this module emit
/// PostgreSQL-style syntax (`$N` placeholders, `::` casts, `ILIKE`, ...).
pub trait Render {
    /// Writes `self` as SQL to `f`, numbering bind parameters through `ctx`.
    fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result;
}
95
96impl Render for Expr {
97 fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 match self {
99 Expr::Param(name) => {
100 let idx = ctx.param_idx(name);
101 write!(f, "${idx}")
102 }
103 Expr::Column(col) => col.render(ctx, f),
104 Expr::String(s) => {
105 let escaped = escape_string(s);
106 write!(f, "{escaped}")
107 }
108 Expr::Int(n) => write!(f, "{n}"),
109 Expr::Bool(b) => write!(f, "{}", if *b { "TRUE" } else { "FALSE" }),
110 Expr::Null => write!(f, "NULL"),
111 Expr::Now => write!(f, "NOW()"),
112 Expr::Default => write!(f, "DEFAULT"),
113 Expr::BinOp { left, op, right } => {
114 let left = Fmt(ctx, left.as_ref());
115 let right = Fmt(ctx, right.as_ref());
116 let op = op.as_str();
117 write!(f, "{left} {op} {right}")
118 }
119 Expr::IsNull { expr, negated } => {
120 let expr = Fmt(ctx, expr.as_ref());
121 let suffix = if *negated { " IS NOT NULL" } else { " IS NULL" };
122 write!(f, "{expr}{suffix}")
123 }
124 Expr::Like { expr, pattern } => {
125 let expr = Fmt(ctx, expr.as_ref());
126 let pattern = Fmt(ctx, pattern.as_ref());
127 write!(f, "{expr} LIKE {pattern}")
128 }
129 Expr::ILike { expr, pattern } => {
130 let expr = Fmt(ctx, expr.as_ref());
131 let pattern = Fmt(ctx, pattern.as_ref());
132 write!(f, "{expr} ILIKE {pattern}")
133 }
134 Expr::Any { expr, array } => {
135 let expr = Fmt(ctx, expr.as_ref());
136 let array = Fmt(ctx, array.as_ref());
137 write!(f, "{expr} = ANY({array})")
138 }
139 Expr::JsonGet { expr, key } => {
140 let expr = Fmt(ctx, expr.as_ref());
141 let key = Fmt(ctx, key.as_ref());
142 write!(f, "{expr} -> {key}")
143 }
144 Expr::JsonGetText { expr, key } => {
145 let expr = Fmt(ctx, expr.as_ref());
146 let key = Fmt(ctx, key.as_ref());
147 write!(f, "{expr} ->> {key}")
148 }
149 Expr::Contains { expr, value } => {
150 let expr = Fmt(ctx, expr.as_ref());
151 let value = Fmt(ctx, value.as_ref());
152 write!(f, "{expr} @> {value}")
153 }
154 Expr::KeyExists { expr, key } => {
155 let expr = Fmt(ctx, expr.as_ref());
156 let key = Fmt(ctx, key.as_ref());
157 write!(f, "{expr} ? {key}")
158 }
159 Expr::Cast { expr, pg_type } => {
160 let expr = Fmt(ctx, expr.as_ref());
161 write!(f, "{expr}::{}", pg_type.as_str())
162 }
163 Expr::Excluded(column) => {
164 let column = Ident(column.as_str());
165 write!(f, "EXCLUDED.{column}")
166 }
167 Expr::FnCall { name, args } => {
168 write!(f, "{name}(")?;
169 for (i, arg) in args.iter().enumerate() {
170 if i > 0 {
171 write!(f, ", ")?;
172 }
173 write!(f, "{}", Fmt(ctx, arg))?;
174 }
175 write!(f, ")")
176 }
177 Expr::Count { table } => {
178 let table = Ident(table.as_str());
179 write!(f, "COUNT({table}.*)")
180 }
181 Expr::Raw(s) => write!(f, "{s}"),
182 }
183 }
184}
185
186impl Render for ColumnRef {
187 fn render(&self, _ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 if let Some(table) = &self.table {
189 let table = Ident(table.as_str());
190 write!(f, "{table}.")?;
191 }
192 let column = Ident(self.column.as_str());
193 write!(f, "{column}")
194 }
195}
196
197impl Render for SelectStmt {
198 fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 write!(f, "SELECT")?;
200
201 if !self.distinct_on.is_empty() {
203 write!(f, " DISTINCT ON (")?;
204 for (i, expr) in self.distinct_on.iter().enumerate() {
205 if i > 0 {
206 write!(f, ", ")?;
207 }
208 write!(f, "{}", Fmt(ctx, expr))?;
209 }
210 write!(f, ")")?;
211 } else if self.distinct {
212 write!(f, " DISTINCT")?;
213 }
214
215 if self.columns.is_empty() {
217 write!(f, " *")?;
218 } else {
219 for (i, col) in self.columns.iter().enumerate() {
220 if i > 0 {
221 write!(f, ",")?;
222 }
223 write!(f, " {}", Fmt(ctx, col))?;
224 }
225 }
226
227 if let Some(from) = &self.from {
229 let table = Ident(from.table.as_str());
230 write!(f, "\nFROM {table}")?;
231 if let Some(alias) = &from.alias {
232 let alias = Ident(alias.as_str());
233 write!(f, " {alias}")?;
234 }
235 }
236
237 for join in &self.joins {
239 let kind = join.kind.as_str();
240 let table = Ident(join.table.as_str());
241 write!(f, "\n{kind} {table}")?;
242 if let Some(alias) = &join.alias {
243 let alias = Ident(alias.as_str());
244 write!(f, " {alias}")?;
245 }
246 let on = Fmt(ctx, &join.on);
247 write!(f, " ON {on}")?;
248 }
249
250 if let Some(where_) = &self.where_ {
252 let where_ = Fmt(ctx, where_);
253 write!(f, "\nWHERE {where_}")?;
254 }
255
256 if !self.order_by.is_empty() {
258 write!(f, "\nORDER BY ")?;
259 for (i, order) in self.order_by.iter().enumerate() {
260 if i > 0 {
261 write!(f, ", ")?;
262 }
263 let expr = Fmt(ctx, &order.expr);
264 let dir = if order.desc { " DESC" } else { " ASC" };
265 write!(f, "{expr}{dir}")?;
266 if let Some(nulls) = &order.nulls {
267 write!(
268 f,
269 "{}",
270 match nulls {
271 NullsOrder::First => " NULLS FIRST",
272 NullsOrder::Last => " NULLS LAST",
273 }
274 )?;
275 }
276 }
277 }
278
279 if let Some(limit) = &self.limit {
281 let limit = Fmt(ctx, limit);
282 write!(f, "\nLIMIT {limit}")?;
283 }
284
285 if let Some(offset) = &self.offset {
287 let offset = Fmt(ctx, offset);
288 write!(f, "\nOFFSET {offset}")?;
289 }
290
291 Ok(())
292 }
293}
294
295impl Render for SelectColumn {
296 fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297 match self {
298 SelectColumn::Expr { expr, alias } => {
299 let expr = Fmt(ctx, expr);
300 write!(f, "{expr}")?;
301 if let Some(alias) = alias {
302 let alias = Ident(alias.as_str());
303 write!(f, " AS {alias}")?;
304 }
305 Ok(())
306 }
307 SelectColumn::AllFrom(table) => {
308 let table = Ident(table.as_str());
309 write!(f, "{table}.*")
310 }
311 }
312 }
313}
314
315impl Render for InsertStmt {
316 fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317 let table = Ident(self.table.as_str());
318 write!(f, "INSERT INTO {table} (")?;
319
320 for (i, col) in self.columns.iter().enumerate() {
322 if i > 0 {
323 write!(f, ", ")?;
324 }
325 let col = Ident(col.as_str());
326 write!(f, "{col}")?;
327 }
328 write!(f, ")")?;
329
330 write!(f, "\nVALUES (")?;
332 for (i, val) in self.values.iter().enumerate() {
333 if i > 0 {
334 write!(f, ", ")?;
335 }
336 write!(f, "{}", Fmt(ctx, val))?;
337 }
338 write!(f, ")")?;
339
340 if let Some(conflict) = &self.on_conflict {
342 write!(f, "\nON CONFLICT (")?;
343 for (i, col) in conflict.columns.iter().enumerate() {
344 if i > 0 {
345 write!(f, ", ")?;
346 }
347 let col = Ident(col.as_str());
348 write!(f, "{col}")?;
349 }
350 write!(f, ")")?;
351
352 match &conflict.action {
353 ConflictAction::DoNothing => {
354 write!(f, " DO NOTHING")?;
355 }
356 ConflictAction::DoUpdate(assignments) => {
357 write!(f, " DO UPDATE SET ")?;
358 for (i, assign) in assignments.iter().enumerate() {
359 if i > 0 {
360 write!(f, ", ")?;
361 }
362 let col = Ident(assign.column.as_str());
363 let val = Fmt(ctx, &assign.value);
364 write!(f, "{col} = {val}")?;
365 }
366 }
367 }
368 }
369
370 if !self.returning.is_empty() {
372 write!(f, "\nRETURNING ")?;
373 for (i, col) in self.returning.iter().enumerate() {
374 if i > 0 {
375 write!(f, ", ")?;
376 }
377 let col = Ident(col.as_str());
378 write!(f, "{col}")?;
379 }
380 }
381
382 Ok(())
383 }
384}
385
386impl Render for UpdateStmt {
387 fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388 let table = Ident(self.table.as_str());
389 write!(f, "UPDATE {table}")?;
390
391 write!(f, "\nSET ")?;
393 for (i, assign) in self.assignments.iter().enumerate() {
394 if i > 0 {
395 write!(f, ", ")?;
396 }
397 let col = Ident(assign.column.as_str());
398 let val = Fmt(ctx, &assign.value);
399 write!(f, "{col} = {val}")?;
400 }
401
402 if let Some(where_) = &self.where_ {
404 let where_ = Fmt(ctx, where_);
405 write!(f, "\nWHERE {where_}")?;
406 }
407
408 if !self.returning.is_empty() {
410 write!(f, "\nRETURNING ")?;
411 for (i, col) in self.returning.iter().enumerate() {
412 if i > 0 {
413 write!(f, ", ")?;
414 }
415 let col = Ident(col.as_str());
416 write!(f, "{col}")?;
417 }
418 }
419
420 Ok(())
421 }
422}
423
424impl Render for DeleteStmt {
425 fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426 let table = Ident(self.table.as_str());
427 write!(f, "DELETE FROM {table}")?;
428
429 if let Some(where_) = &self.where_ {
431 let where_ = Fmt(ctx, where_);
432 write!(f, "\nWHERE {where_}")?;
433 }
434
435 if !self.returning.is_empty() {
437 write!(f, "\nRETURNING ")?;
438 for (i, col) in self.returning.iter().enumerate() {
439 if i > 0 {
440 write!(f, ", ")?;
441 }
442 let col = Ident(col.as_str());
443 write!(f, "{col}")?;
444 }
445 }
446
447 Ok(())
448 }
449}
450
451impl Render for InsertSelectStmt {
452 fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
453 let table = Ident(self.table.as_str());
454 write!(f, "INSERT INTO {table} (")?;
455
456 for (i, col) in self.columns.iter().enumerate() {
458 if i > 0 {
459 write!(f, ", ")?;
460 }
461 let col = Ident(col.as_str());
462 write!(f, "{col}")?;
463 }
464 write!(f, ")")?;
465
466 write!(f, "\nSELECT ")?;
468 for (i, expr) in self.select_exprs.iter().enumerate() {
469 if i > 0 {
470 write!(f, ", ")?;
471 }
472 write!(f, "{}", Fmt(ctx, expr))?;
473 }
474
475 write!(f, "\nFROM UNNEST(")?;
477 for (i, param) in self.unnest.params.iter().enumerate() {
478 if i > 0 {
479 write!(f, ", ")?;
480 }
481 let idx = ctx.param_idx(¶m.name.as_str().into());
482 write!(f, "${}::{}", idx, param.pg_type.as_str())?;
483 }
484 let alias = Ident(self.unnest.alias.as_str());
485 write!(f, ") AS {alias}(")?;
486 for (i, param) in self.unnest.params.iter().enumerate() {
487 if i > 0 {
488 write!(f, ", ")?;
489 }
490 write!(f, "{}", param.name.as_str())?;
491 }
492 write!(f, ")")?;
493
494 if let Some(conflict) = &self.on_conflict {
496 write!(f, "\nON CONFLICT (")?;
497 for (i, col) in conflict.columns.iter().enumerate() {
498 if i > 0 {
499 write!(f, ", ")?;
500 }
501 let col = Ident(col.as_str());
502 write!(f, "{col}")?;
503 }
504 write!(f, ")")?;
505
506 match &conflict.action {
507 ConflictAction::DoNothing => {
508 write!(f, " DO NOTHING")?;
509 }
510 ConflictAction::DoUpdate(assignments) => {
511 write!(f, " DO UPDATE SET ")?;
512 for (i, assign) in assignments.iter().enumerate() {
513 if i > 0 {
514 write!(f, ", ")?;
515 }
516 let col = Ident(assign.column.as_str());
517 let val = Fmt(ctx, &assign.value);
518 write!(f, "{col} = {val}")?;
519 }
520 }
521 }
522 }
523
524 if !self.returning.is_empty() {
526 write!(f, "\nRETURNING ")?;
527 for (i, col) in self.returning.iter().enumerate() {
528 if i > 0 {
529 write!(f, ", ")?;
530 }
531 let col = Ident(col.as_str());
532 write!(f, "{col}")?;
533 }
534 }
535
536 Ok(())
537 }
538}
539
540impl Render for Stmt {
541 fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
542 match self {
543 Stmt::Select(s) => s.render(ctx, f),
544 Stmt::Insert(s) => s.render(ctx, f),
545 Stmt::InsertSelect(s) => s.render(ctx, f),
546 Stmt::Update(s) => s.render(ctx, f),
547 Stmt::Delete(s) => s.render(ctx, f),
548 }
549 }
550}
551
552pub fn render(stmt: &impl Render) -> RenderedSql {
558 let ctx = RenderContext::new();
559 let sql = format!("{}", Fmt(&ctx, stmt));
560 RenderedSql {
561 sql,
562 params: ctx.into_params(),
563 }
564}
565
// Unit tests live in a separate `tests` module file and are compiled only for
// test builds.
#[cfg(test)]
mod tests;