Skip to main content

dbkit_core/
compile.rs

1use crate::expr::{BinaryOp, BoolOp, ExprNode, IntervalField, UnaryOp, Value, VectorBinaryOp};
2use crate::schema::ColumnRef;
3
4#[derive(Debug, Clone, PartialEq)]
5pub struct CompiledSql {
6    pub sql: String,
7    pub binds: Vec<Value>,
8}
9
10#[derive(Debug, Default)]
11pub struct SqlBuilder {
12    sql: String,
13    binds: Vec<Value>,
14}
15
16impl SqlBuilder {
17    pub fn new() -> Self {
18        Self::default()
19    }
20
21    pub fn push_sql(&mut self, fragment: &str) {
22        self.sql.push_str(fragment);
23    }
24
25    pub fn push_value(&mut self, value: Value) {
26        if value == Value::Null {
27            self.sql.push_str("NULL");
28            return;
29        }
30        let cast_as_vector = matches!(&value, Value::Vector(_));
31        let cast_as_interval = matches!(&value, Value::Interval(_));
32        let cast_as_enum = match &value {
33            Value::Enum { type_name, .. } => Some(*type_name),
34            _ => None,
35        };
36        let idx = if let Some(existing) = self.binds.iter().position(|item| item == &value) {
37            existing + 1
38        } else {
39            self.binds.push(value);
40            self.binds.len()
41        };
42        self.sql.push('$');
43        self.sql.push_str(&idx.to_string());
44        if cast_as_vector {
45            self.sql.push_str("::vector");
46        } else if cast_as_interval {
47            self.sql.push_str("::interval");
48        } else if let Some(type_name) = cast_as_enum {
49            self.sql.push_str("::");
50            self.sql.push_str(type_name);
51        }
52    }
53
54    pub fn push_column(&mut self, col: ColumnRef) {
55        self.sql.push_str(&col.qualified_name());
56    }
57
58    pub fn finish(self) -> CompiledSql {
59        CompiledSql {
60            sql: self.sql,
61            binds: self.binds,
62        }
63    }
64}
65
66pub trait ToSql {
67    fn to_sql(&self, builder: &mut SqlBuilder);
68}
69
70impl ToSql for ExprNode {
71    fn to_sql(&self, builder: &mut SqlBuilder) {
72        match self {
73            ExprNode::Column(col) => builder.push_column(*col),
74            ExprNode::Value(value) => builder.push_value(value.clone()),
75            ExprNode::Func { name, args } => {
76                builder.push_sql(name);
77                builder.push_sql("(");
78                for (idx, arg) in args.iter().enumerate() {
79                    if idx > 0 {
80                        builder.push_sql(", ");
81                    }
82                    arg.to_sql(builder);
83                }
84                builder.push_sql(")");
85            }
86            ExprNode::VectorBinary { left, op, right } => {
87                builder.push_sql("(");
88                left.to_sql(builder);
89                builder.push_sql(match op {
90                    VectorBinaryOp::L2Distance => " <-> ",
91                    VectorBinaryOp::CosineDistance => " <=> ",
92                    VectorBinaryOp::InnerProductDistance => " <#> ",
93                    VectorBinaryOp::L1Distance => " <+> ",
94                });
95                right.to_sql(builder);
96                builder.push_sql(")");
97            }
98            ExprNode::MakeInterval { field, value } => {
99                builder.push_sql("MAKE_INTERVAL(");
100                builder.push_sql(match field {
101                    IntervalField::Days => "days => ",
102                    IntervalField::Hours => "hours => ",
103                    IntervalField::Minutes => "mins => ",
104                    IntervalField::Seconds => "secs => ",
105                });
106                value.to_sql(builder);
107                builder.push_sql(")");
108            }
109            ExprNode::Binary { left, op, right } => {
110                builder.push_sql("(");
111                left.to_sql(builder);
112                builder.push_sql(match op {
113                    BinaryOp::Add => " + ",
114                    BinaryOp::Sub => " - ",
115                    BinaryOp::Eq => " = ",
116                    BinaryOp::Ne => " <> ",
117                    BinaryOp::IsDistinctFrom => " IS DISTINCT FROM ",
118                    BinaryOp::IsNotDistinctFrom => " IS NOT DISTINCT FROM ",
119                    BinaryOp::Lt => " < ",
120                    BinaryOp::Le => " <= ",
121                    BinaryOp::Gt => " > ",
122                    BinaryOp::Ge => " >= ",
123                });
124                right.to_sql(builder);
125                builder.push_sql(")");
126            }
127            ExprNode::Bool { left, op, right } => {
128                builder.push_sql("(");
129                left.to_sql(builder);
130                builder.push_sql(match op {
131                    BoolOp::And => " AND ",
132                    BoolOp::Or => " OR ",
133                });
134                right.to_sql(builder);
135                builder.push_sql(")");
136            }
137            ExprNode::Unary { op, expr } => {
138                builder.push_sql(match op {
139                    UnaryOp::Not => "NOT ",
140                });
141                builder.push_sql("(");
142                expr.to_sql(builder);
143                builder.push_sql(")");
144            }
145            ExprNode::In { expr, values } => {
146                if values.is_empty() {
147                    builder.push_sql("(FALSE)");
148                    return;
149                }
150                builder.push_sql("(");
151                expr.to_sql(builder);
152                builder.push_sql(" IN (");
153                for (idx, value) in values.iter().enumerate() {
154                    if idx > 0 {
155                        builder.push_sql(", ");
156                    }
157                    builder.push_value(value.clone());
158                }
159                builder.push_sql("))");
160            }
161            ExprNode::IsNull { expr, negated } => {
162                builder.push_sql("(");
163                expr.to_sql(builder);
164                if *negated {
165                    builder.push_sql(" IS NOT NULL)");
166                } else {
167                    builder.push_sql(" IS NULL)");
168                }
169            }
170            ExprNode::Like {
171                expr,
172                pattern,
173                case_insensitive,
174            } => {
175                builder.push_sql("(");
176                expr.to_sql(builder);
177                builder.push_sql(if *case_insensitive { " ILIKE " } else { " LIKE " });
178                builder.push_value(pattern.clone());
179                builder.push_sql(")");
180            }
181        }
182    }
183}