Skip to main content

dibs_sql/render/
mod.rs

1//! Render SQL AST to string.
2
3use std::cell::RefCell;
4use std::fmt;
5
6use indexmap::IndexMap;
7
8use crate::expr::{ColumnRef, Expr};
9use crate::stmt::*;
10use crate::{Ident, ParamName, RenderedSql, escape_string};
11
12/// Mutable parameter tracking state.
13struct ParamState {
14    /// Named parameters mapped to their assigned positional index.
15    params: IndexMap<ParamName, usize>,
16    /// Next parameter index to assign (starts at 1 for `$1`).
17    next_param_idx: usize,
18}
19
20impl ParamState {
21    fn new() -> Self {
22        Self {
23            params: IndexMap::new(),
24            next_param_idx: 1,
25        }
26    }
27
28    /// Get or create a parameter index.
29    fn get_or_insert(&mut self, name: &ParamName) -> usize {
30        *self.params.entry(name.clone()).or_insert_with(|| {
31            let idx = self.next_param_idx;
32            self.next_param_idx += 1;
33            idx
34        })
35    }
36}
37
38/// Rendering context that tracks parameter assignment.
39///
40/// Uses interior mutability (`RefCell`) so that `Render::render` can take `&self`,
41/// enabling the `Fmt` wrapper to implement `Display`.
42pub struct RenderContext {
43    /// Parameter tracking state, wrapped for interior mutability.
44    params: RefCell<ParamState>,
45}
46
47impl RenderContext {
48    pub fn new() -> Self {
49        Self {
50            params: RefCell::new(ParamState::new()),
51        }
52    }
53
54    /// Get or create a parameter placeholder index.
55    fn param_idx(&self, name: &ParamName) -> usize {
56        self.params.borrow_mut().get_or_insert(name)
57    }
58
59    /// Finish rendering and return the collected params.
60    fn into_params(self) -> Vec<ParamName> {
61        self.params.into_inner().params.into_keys().collect()
62    }
63}
64
65impl Default for RenderContext {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71/// Wrapper for rendering a `Render` type via `Display`.
72///
73/// Allows using `write!(f, "{}", Fmt(ctx, &expr))` in format strings.
74pub struct Fmt<'a, T: Render>(
75    /// The rendering context for parameter tracking.
76    &'a RenderContext,
77    /// The value to render.
78    &'a T,
79);
80
81impl<T: Render> fmt::Display for Fmt<'_, T> {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        self.1.render(self.0, f)
84    }
85}
86
87// ============================================================================
88// Render implementations
89// ============================================================================
90
91/// Trait for types that can be rendered to SQL.
92pub trait Render {
93    fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result;
94}
95
96impl Render for Expr {
97    fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        match self {
99            Expr::Param(name) => {
100                let idx = ctx.param_idx(name);
101                write!(f, "${idx}")
102            }
103            Expr::Column(col) => col.render(ctx, f),
104            Expr::String(s) => {
105                let escaped = escape_string(s);
106                write!(f, "{escaped}")
107            }
108            Expr::Int(n) => write!(f, "{n}"),
109            Expr::Bool(b) => write!(f, "{}", if *b { "TRUE" } else { "FALSE" }),
110            Expr::Null => write!(f, "NULL"),
111            Expr::Now => write!(f, "NOW()"),
112            Expr::Default => write!(f, "DEFAULT"),
113            Expr::BinOp { left, op, right } => {
114                let left = Fmt(ctx, left.as_ref());
115                let right = Fmt(ctx, right.as_ref());
116                let op = op.as_str();
117                write!(f, "{left} {op} {right}")
118            }
119            Expr::IsNull { expr, negated } => {
120                let expr = Fmt(ctx, expr.as_ref());
121                let suffix = if *negated { " IS NOT NULL" } else { " IS NULL" };
122                write!(f, "{expr}{suffix}")
123            }
124            Expr::Like { expr, pattern } => {
125                let expr = Fmt(ctx, expr.as_ref());
126                let pattern = Fmt(ctx, pattern.as_ref());
127                write!(f, "{expr} LIKE {pattern}")
128            }
129            Expr::ILike { expr, pattern } => {
130                let expr = Fmt(ctx, expr.as_ref());
131                let pattern = Fmt(ctx, pattern.as_ref());
132                write!(f, "{expr} ILIKE {pattern}")
133            }
134            Expr::Any { expr, array } => {
135                let expr = Fmt(ctx, expr.as_ref());
136                let array = Fmt(ctx, array.as_ref());
137                write!(f, "{expr} = ANY({array})")
138            }
139            Expr::JsonGet { expr, key } => {
140                let expr = Fmt(ctx, expr.as_ref());
141                let key = Fmt(ctx, key.as_ref());
142                write!(f, "{expr} -> {key}")
143            }
144            Expr::JsonGetText { expr, key } => {
145                let expr = Fmt(ctx, expr.as_ref());
146                let key = Fmt(ctx, key.as_ref());
147                write!(f, "{expr} ->> {key}")
148            }
149            Expr::Contains { expr, value } => {
150                let expr = Fmt(ctx, expr.as_ref());
151                let value = Fmt(ctx, value.as_ref());
152                write!(f, "{expr} @> {value}")
153            }
154            Expr::KeyExists { expr, key } => {
155                let expr = Fmt(ctx, expr.as_ref());
156                let key = Fmt(ctx, key.as_ref());
157                write!(f, "{expr} ? {key}")
158            }
159            Expr::Cast { expr, pg_type } => {
160                let expr = Fmt(ctx, expr.as_ref());
161                write!(f, "{expr}::{}", pg_type.as_str())
162            }
163            Expr::Excluded(column) => {
164                let column = Ident(column.as_str());
165                write!(f, "EXCLUDED.{column}")
166            }
167            Expr::FnCall { name, args } => {
168                write!(f, "{name}(")?;
169                for (i, arg) in args.iter().enumerate() {
170                    if i > 0 {
171                        write!(f, ", ")?;
172                    }
173                    write!(f, "{}", Fmt(ctx, arg))?;
174                }
175                write!(f, ")")
176            }
177            Expr::Count { table } => {
178                let table = Ident(table.as_str());
179                write!(f, "COUNT({table}.*)")
180            }
181            Expr::Raw(s) => write!(f, "{s}"),
182        }
183    }
184}
185
186impl Render for ColumnRef {
187    fn render(&self, _ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        if let Some(table) = &self.table {
189            let table = Ident(table.as_str());
190            write!(f, "{table}.")?;
191        }
192        let column = Ident(self.column.as_str());
193        write!(f, "{column}")
194    }
195}
196
197impl Render for SelectStmt {
198    fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        write!(f, "SELECT")?;
200
201        // DISTINCT ON (takes precedence over DISTINCT)
202        if !self.distinct_on.is_empty() {
203            write!(f, " DISTINCT ON (")?;
204            for (i, expr) in self.distinct_on.iter().enumerate() {
205                if i > 0 {
206                    write!(f, ", ")?;
207                }
208                write!(f, "{}", Fmt(ctx, expr))?;
209            }
210            write!(f, ")")?;
211        } else if self.distinct {
212            write!(f, " DISTINCT")?;
213        }
214
215        // Columns
216        if self.columns.is_empty() {
217            write!(f, " *")?;
218        } else {
219            for (i, col) in self.columns.iter().enumerate() {
220                if i > 0 {
221                    write!(f, ",")?;
222                }
223                write!(f, " {}", Fmt(ctx, col))?;
224            }
225        }
226
227        // FROM
228        if let Some(from) = &self.from {
229            let table = Ident(from.table.as_str());
230            write!(f, "\nFROM {table}")?;
231            if let Some(alias) = &from.alias {
232                let alias = Ident(alias.as_str());
233                write!(f, " {alias}")?;
234            }
235        }
236
237        // JOINs
238        for join in &self.joins {
239            let kind = join.kind.as_str();
240            let table = Ident(join.table.as_str());
241            write!(f, "\n{kind} {table}")?;
242            if let Some(alias) = &join.alias {
243                let alias = Ident(alias.as_str());
244                write!(f, " {alias}")?;
245            }
246            let on = Fmt(ctx, &join.on);
247            write!(f, " ON {on}")?;
248        }
249
250        // WHERE
251        if let Some(where_) = &self.where_ {
252            let where_ = Fmt(ctx, where_);
253            write!(f, "\nWHERE {where_}")?;
254        }
255
256        // ORDER BY
257        if !self.order_by.is_empty() {
258            write!(f, "\nORDER BY ")?;
259            for (i, order) in self.order_by.iter().enumerate() {
260                if i > 0 {
261                    write!(f, ", ")?;
262                }
263                let expr = Fmt(ctx, &order.expr);
264                let dir = if order.desc { " DESC" } else { " ASC" };
265                write!(f, "{expr}{dir}")?;
266                if let Some(nulls) = &order.nulls {
267                    write!(
268                        f,
269                        "{}",
270                        match nulls {
271                            NullsOrder::First => " NULLS FIRST",
272                            NullsOrder::Last => " NULLS LAST",
273                        }
274                    )?;
275                }
276            }
277        }
278
279        // LIMIT
280        if let Some(limit) = &self.limit {
281            let limit = Fmt(ctx, limit);
282            write!(f, "\nLIMIT {limit}")?;
283        }
284
285        // OFFSET
286        if let Some(offset) = &self.offset {
287            let offset = Fmt(ctx, offset);
288            write!(f, "\nOFFSET {offset}")?;
289        }
290
291        Ok(())
292    }
293}
294
295impl Render for SelectColumn {
296    fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297        match self {
298            SelectColumn::Expr { expr, alias } => {
299                let expr = Fmt(ctx, expr);
300                write!(f, "{expr}")?;
301                if let Some(alias) = alias {
302                    let alias = Ident(alias.as_str());
303                    write!(f, " AS {alias}")?;
304                }
305                Ok(())
306            }
307            SelectColumn::AllFrom(table) => {
308                let table = Ident(table.as_str());
309                write!(f, "{table}.*")
310            }
311        }
312    }
313}
314
315impl Render for InsertStmt {
316    fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317        let table = Ident(self.table.as_str());
318        write!(f, "INSERT INTO {table} (")?;
319
320        // Columns
321        for (i, col) in self.columns.iter().enumerate() {
322            if i > 0 {
323                write!(f, ", ")?;
324            }
325            let col = Ident(col.as_str());
326            write!(f, "{col}")?;
327        }
328        write!(f, ")")?;
329
330        // VALUES
331        write!(f, "\nVALUES (")?;
332        for (i, val) in self.values.iter().enumerate() {
333            if i > 0 {
334                write!(f, ", ")?;
335            }
336            write!(f, "{}", Fmt(ctx, val))?;
337        }
338        write!(f, ")")?;
339
340        // ON CONFLICT
341        if let Some(conflict) = &self.on_conflict {
342            write!(f, "\nON CONFLICT (")?;
343            for (i, col) in conflict.columns.iter().enumerate() {
344                if i > 0 {
345                    write!(f, ", ")?;
346                }
347                let col = Ident(col.as_str());
348                write!(f, "{col}")?;
349            }
350            write!(f, ")")?;
351
352            match &conflict.action {
353                ConflictAction::DoNothing => {
354                    write!(f, " DO NOTHING")?;
355                }
356                ConflictAction::DoUpdate(assignments) => {
357                    write!(f, " DO UPDATE SET ")?;
358                    for (i, assign) in assignments.iter().enumerate() {
359                        if i > 0 {
360                            write!(f, ", ")?;
361                        }
362                        let col = Ident(assign.column.as_str());
363                        let val = Fmt(ctx, &assign.value);
364                        write!(f, "{col} = {val}")?;
365                    }
366                }
367            }
368        }
369
370        // RETURNING
371        if !self.returning.is_empty() {
372            write!(f, "\nRETURNING ")?;
373            for (i, col) in self.returning.iter().enumerate() {
374                if i > 0 {
375                    write!(f, ", ")?;
376                }
377                let col = Ident(col.as_str());
378                write!(f, "{col}")?;
379            }
380        }
381
382        Ok(())
383    }
384}
385
386impl Render for UpdateStmt {
387    fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388        let table = Ident(self.table.as_str());
389        write!(f, "UPDATE {table}")?;
390
391        // SET
392        write!(f, "\nSET ")?;
393        for (i, assign) in self.assignments.iter().enumerate() {
394            if i > 0 {
395                write!(f, ", ")?;
396            }
397            let col = Ident(assign.column.as_str());
398            let val = Fmt(ctx, &assign.value);
399            write!(f, "{col} = {val}")?;
400        }
401
402        // WHERE
403        if let Some(where_) = &self.where_ {
404            let where_ = Fmt(ctx, where_);
405            write!(f, "\nWHERE {where_}")?;
406        }
407
408        // RETURNING
409        if !self.returning.is_empty() {
410            write!(f, "\nRETURNING ")?;
411            for (i, col) in self.returning.iter().enumerate() {
412                if i > 0 {
413                    write!(f, ", ")?;
414                }
415                let col = Ident(col.as_str());
416                write!(f, "{col}")?;
417            }
418        }
419
420        Ok(())
421    }
422}
423
424impl Render for DeleteStmt {
425    fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426        let table = Ident(self.table.as_str());
427        write!(f, "DELETE FROM {table}")?;
428
429        // WHERE
430        if let Some(where_) = &self.where_ {
431            let where_ = Fmt(ctx, where_);
432            write!(f, "\nWHERE {where_}")?;
433        }
434
435        // RETURNING
436        if !self.returning.is_empty() {
437            write!(f, "\nRETURNING ")?;
438            for (i, col) in self.returning.iter().enumerate() {
439                if i > 0 {
440                    write!(f, ", ")?;
441                }
442                let col = Ident(col.as_str());
443                write!(f, "{col}")?;
444            }
445        }
446
447        Ok(())
448    }
449}
450
451impl Render for InsertSelectStmt {
452    fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
453        let table = Ident(self.table.as_str());
454        write!(f, "INSERT INTO {table} (")?;
455
456        // Columns
457        for (i, col) in self.columns.iter().enumerate() {
458            if i > 0 {
459                write!(f, ", ")?;
460            }
461            let col = Ident(col.as_str());
462            write!(f, "{col}")?;
463        }
464        write!(f, ")")?;
465
466        // SELECT
467        write!(f, "\nSELECT ")?;
468        for (i, expr) in self.select_exprs.iter().enumerate() {
469            if i > 0 {
470                write!(f, ", ")?;
471            }
472            write!(f, "{}", Fmt(ctx, expr))?;
473        }
474
475        // FROM UNNEST
476        write!(f, "\nFROM UNNEST(")?;
477        for (i, param) in self.unnest.params.iter().enumerate() {
478            if i > 0 {
479                write!(f, ", ")?;
480            }
481            let idx = ctx.param_idx(&param.name.as_str().into());
482            write!(f, "${}::{}", idx, param.pg_type.as_str())?;
483        }
484        let alias = Ident(self.unnest.alias.as_str());
485        write!(f, ") AS {alias}(")?;
486        for (i, param) in self.unnest.params.iter().enumerate() {
487            if i > 0 {
488                write!(f, ", ")?;
489            }
490            write!(f, "{}", param.name.as_str())?;
491        }
492        write!(f, ")")?;
493
494        // ON CONFLICT
495        if let Some(conflict) = &self.on_conflict {
496            write!(f, "\nON CONFLICT (")?;
497            for (i, col) in conflict.columns.iter().enumerate() {
498                if i > 0 {
499                    write!(f, ", ")?;
500                }
501                let col = Ident(col.as_str());
502                write!(f, "{col}")?;
503            }
504            write!(f, ")")?;
505
506            match &conflict.action {
507                ConflictAction::DoNothing => {
508                    write!(f, " DO NOTHING")?;
509                }
510                ConflictAction::DoUpdate(assignments) => {
511                    write!(f, " DO UPDATE SET ")?;
512                    for (i, assign) in assignments.iter().enumerate() {
513                        if i > 0 {
514                            write!(f, ", ")?;
515                        }
516                        let col = Ident(assign.column.as_str());
517                        let val = Fmt(ctx, &assign.value);
518                        write!(f, "{col} = {val}")?;
519                    }
520                }
521            }
522        }
523
524        // RETURNING
525        if !self.returning.is_empty() {
526            write!(f, "\nRETURNING ")?;
527            for (i, col) in self.returning.iter().enumerate() {
528                if i > 0 {
529                    write!(f, ", ")?;
530                }
531                let col = Ident(col.as_str());
532                write!(f, "{col}")?;
533            }
534        }
535
536        Ok(())
537    }
538}
539
540impl Render for Stmt {
541    fn render(&self, ctx: &RenderContext, f: &mut fmt::Formatter<'_>) -> fmt::Result {
542        match self {
543            Stmt::Select(s) => s.render(ctx, f),
544            Stmt::Insert(s) => s.render(ctx, f),
545            Stmt::InsertSelect(s) => s.render(ctx, f),
546            Stmt::Update(s) => s.render(ctx, f),
547            Stmt::Delete(s) => s.render(ctx, f),
548        }
549    }
550}
551
552// ============================================================================
553// Convenience methods
554// ============================================================================
555
556/// Render a statement to SQL.
557pub fn render(stmt: &impl Render) -> RenderedSql {
558    let ctx = RenderContext::new();
559    let sql = format!("{}", Fmt(&ctx, stmt));
560    RenderedSql {
561        sql,
562        params: ctx.into_params(),
563    }
564}
565
566#[cfg(test)]
567mod tests;