1use kyu_common::id::{PropertyId, TableId};
8use kyu_parser::ast::{BinaryOp, ComparisonOp, StringOp, UnaryOp};
9use kyu_types::{LogicalType, TypedValue};
10use smol_str::SmolStr;
11
/// Numeric identifier of a function, wrapped in a newtype so it cannot be
/// confused with other `u32` values.
///
/// The wrapped value is public; which registry assigns it is not visible from
/// this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FunctionId(pub u32);
15
16#[derive(Clone, Debug)]
19pub enum BoundExpression {
20 Literal {
21 value: TypedValue,
22 result_type: LogicalType,
23 },
24 Variable {
25 index: u32,
26 result_type: LogicalType,
27 },
28 Property {
29 object: Box<BoundExpression>,
30 property_id: PropertyId,
31 property_name: SmolStr,
32 result_type: LogicalType,
33 },
34 Parameter {
35 name: SmolStr,
36 index: u32,
37 result_type: LogicalType,
38 },
39 UnaryOp {
40 op: UnaryOp,
41 operand: Box<BoundExpression>,
42 result_type: LogicalType,
43 },
44 BinaryOp {
45 op: BinaryOp,
46 left: Box<BoundExpression>,
47 right: Box<BoundExpression>,
48 result_type: LogicalType,
49 },
50 Comparison {
51 op: ComparisonOp,
52 left: Box<BoundExpression>,
53 right: Box<BoundExpression>,
54 },
55 IsNull {
56 expr: Box<BoundExpression>,
57 negated: bool,
58 },
59 InList {
60 expr: Box<BoundExpression>,
61 list: Vec<BoundExpression>,
62 negated: bool,
63 },
64 FunctionCall {
65 function_id: FunctionId,
66 function_name: SmolStr,
67 args: Vec<BoundExpression>,
68 distinct: bool,
69 result_type: LogicalType,
70 },
71 CountStar,
72 Case {
73 operand: Option<Box<BoundExpression>>,
74 whens: Vec<(BoundExpression, BoundExpression)>,
75 else_expr: Option<Box<BoundExpression>>,
76 result_type: LogicalType,
77 },
78 ListLiteral {
79 elements: Vec<BoundExpression>,
80 result_type: LogicalType,
81 },
82 MapLiteral {
83 entries: Vec<(BoundExpression, BoundExpression)>,
84 result_type: LogicalType,
85 },
86 Subscript {
87 expr: Box<BoundExpression>,
88 index: Box<BoundExpression>,
89 result_type: LogicalType,
90 },
91 Slice {
92 expr: Box<BoundExpression>,
93 from: Option<Box<BoundExpression>>,
94 to: Option<Box<BoundExpression>>,
95 result_type: LogicalType,
96 },
97 StringOp {
98 op: StringOp,
99 left: Box<BoundExpression>,
100 right: Box<BoundExpression>,
101 },
102 Cast {
103 expr: Box<BoundExpression>,
104 target_type: LogicalType,
105 },
106 HasLabel {
107 expr: Box<BoundExpression>,
108 table_ids: Vec<TableId>,
109 },
110}
111
112impl BoundExpression {
113 pub fn result_type(&self) -> &LogicalType {
115 match self {
116 Self::Literal { result_type, .. }
117 | Self::Variable { result_type, .. }
118 | Self::Property { result_type, .. }
119 | Self::Parameter { result_type, .. }
120 | Self::UnaryOp { result_type, .. }
121 | Self::BinaryOp { result_type, .. }
122 | Self::FunctionCall { result_type, .. }
123 | Self::Case { result_type, .. }
124 | Self::ListLiteral { result_type, .. }
125 | Self::MapLiteral { result_type, .. }
126 | Self::Subscript { result_type, .. }
127 | Self::Slice { result_type, .. } => result_type,
128
129 Self::Comparison { .. }
130 | Self::IsNull { .. }
131 | Self::InList { .. }
132 | Self::StringOp { .. }
133 | Self::HasLabel { .. } => &LogicalType::Bool,
134
135 Self::CountStar => &LogicalType::Int64,
136
137 Self::Cast { target_type, .. } => target_type,
138 }
139 }
140
141 pub fn is_constant(&self) -> bool {
143 matches!(self, Self::Literal { .. })
144 }
145
146 pub fn is_aggregate(&self) -> bool {
148 matches!(
149 self,
150 Self::CountStar | Self::FunctionCall { distinct: true, .. }
151 )
152 }
153}
154
#[cfg(test)]
mod tests {
    //! Unit tests for the type and constness queries on `BoundExpression`.

    use super::*;

    /// Builds an `Int64` literal expression.
    fn int_literal(n: i64) -> BoundExpression {
        let value = TypedValue::Int64(n);
        BoundExpression::Literal {
            value,
            result_type: LogicalType::Int64,
        }
    }

    /// Builds a `String` literal expression.
    fn str_literal(text: &str) -> BoundExpression {
        let value = TypedValue::String(SmolStr::new(text));
        BoundExpression::Literal {
            value,
            result_type: LogicalType::String,
        }
    }

    /// Builds a `Bool` literal expression.
    fn bool_literal(flag: bool) -> BoundExpression {
        let value = TypedValue::Bool(flag);
        BoundExpression::Literal {
            value,
            result_type: LogicalType::Bool,
        }
    }

    /// Builds a variable with index 0 and the given type.
    fn variable(result_type: LogicalType) -> BoundExpression {
        BoundExpression::Variable {
            index: 0,
            result_type,
        }
    }

    #[test]
    fn literal_result_type() {
        let int = int_literal(42);
        let string = str_literal("hello");
        let boolean = bool_literal(true);

        assert_eq!(int.result_type(), &LogicalType::Int64);
        assert_eq!(string.result_type(), &LogicalType::String);
        assert_eq!(boolean.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn variable_result_type() {
        let node = variable(LogicalType::Node);
        assert_eq!(node.result_type(), &LogicalType::Node);
    }

    #[test]
    fn comparison_always_bool() {
        let left = Box::new(int_literal(1));
        let right = Box::new(int_literal(2));
        let greater = BoundExpression::Comparison {
            op: ComparisonOp::Gt,
            left,
            right,
        };
        assert_eq!(greater.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn is_null_always_bool() {
        let null_check = BoundExpression::IsNull {
            expr: Box::new(int_literal(1)),
            negated: false,
        };
        assert_eq!(null_check.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn string_op_always_bool() {
        let left = Box::new(str_literal("hello"));
        let right = Box::new(str_literal("he"));
        let starts_with = BoundExpression::StringOp {
            op: StringOp::StartsWith,
            left,
            right,
        };
        assert_eq!(starts_with.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn count_star_always_int64() {
        let count = BoundExpression::CountStar;
        assert_eq!(count.result_type(), &LogicalType::Int64);
    }

    #[test]
    fn cast_result_type() {
        let cast = BoundExpression::Cast {
            expr: Box::new(int_literal(42)),
            target_type: LogicalType::Double,
        };
        assert_eq!(cast.result_type(), &LogicalType::Double);
    }

    #[test]
    fn is_constant() {
        assert!(int_literal(1).is_constant());

        let int_var = variable(LogicalType::Int64);
        assert!(!int_var.is_constant());
    }

    #[test]
    fn has_label_always_bool() {
        let label_check = BoundExpression::HasLabel {
            expr: Box::new(variable(LogicalType::Node)),
            table_ids: vec![TableId(1)],
        };
        assert_eq!(label_check.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn function_call_result_type() {
        let call = BoundExpression::FunctionCall {
            function_id: FunctionId(0),
            function_name: SmolStr::new("abs"),
            args: vec![int_literal(-5)],
            distinct: false,
            result_type: LogicalType::Int64,
        };
        assert_eq!(call.result_type(), &LogicalType::Int64);
    }
}