1use kyu_common::id::{PropertyId, TableId};
8use kyu_parser::ast::{BinaryOp, ComparisonOp, StringOp, UnaryOp};
9use kyu_types::{LogicalType, TypedValue};
10use smol_str::SmolStr;
11
/// Opaque numeric identifier of a function, as stored in
/// `BoundExpression::FunctionCall::function_id`.
///
/// NOTE(review): presumably an index/key into a function catalog — confirm
/// against the binder.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct FunctionId(pub u32);
15
/// A bound (resolved) expression tree.
///
/// Operator kinds are reused from `kyu_parser::ast`. References are carried as
/// resolved ids/indices (`PropertyId`, `FunctionId`, `TableId`, variable and
/// parameter indices), in some cases alongside their display names. Most
/// variants store an explicit `result_type`. The variants that do not
/// (`Comparison`, `IsNull`, `InList`, `StringOp`, `HasLabel`, `CountStar`,
/// `Cast`) get their type from [`BoundExpression::result_type`].
#[derive(Clone, Debug)]
pub enum BoundExpression {
    /// A constant `value` together with its logical type.
    Literal {
        value: TypedValue,
        result_type: LogicalType,
    },
    /// A reference to a bound variable by `index`.
    ///
    /// NOTE(review): presumably an index into the binder's variable scope —
    /// confirm.
    Variable {
        index: u32,
        result_type: LogicalType,
    },
    /// Access of property `property_name` (resolved to `property_id`) on
    /// `object`.
    Property {
        object: Box<BoundExpression>,
        property_id: PropertyId,
        property_name: SmolStr,
        result_type: LogicalType,
    },
    /// A query parameter identified by `name` and a resolved `index`.
    ///
    /// NOTE(review): presumably corresponds to `$name` syntax — confirm.
    Parameter {
        name: SmolStr,
        index: u32,
        result_type: LogicalType,
    },
    /// Application of a unary operator to `operand`.
    UnaryOp {
        op: UnaryOp,
        operand: Box<BoundExpression>,
        result_type: LogicalType,
    },
    /// Application of a binary operator to `left` and `right`.
    BinaryOp {
        op: BinaryOp,
        left: Box<BoundExpression>,
        right: Box<BoundExpression>,
        result_type: LogicalType,
    },
    /// Comparison of `left` against `right`; always typed `Bool`.
    Comparison {
        op: ComparisonOp,
        left: Box<BoundExpression>,
        right: Box<BoundExpression>,
    },
    /// Null test on `expr`; always typed `Bool`.
    ///
    /// `negated` presumably selects `IS NOT NULL` instead of `IS NULL`.
    IsNull {
        expr: Box<BoundExpression>,
        negated: bool,
    },
    /// Membership test of `expr` in `list`; always typed `Bool`.
    ///
    /// `negated` presumably selects `NOT IN` — confirm.
    InList {
        expr: Box<BoundExpression>,
        list: Vec<BoundExpression>,
        negated: bool,
    },
    /// Call of the function resolved to `function_id` with `args`.
    ///
    /// `distinct` presumably reflects `f(DISTINCT ...)` syntax. Note that
    /// `is_aggregate` treats any call with `distinct` set as an aggregate.
    FunctionCall {
        function_id: FunctionId,
        function_name: SmolStr,
        args: Vec<BoundExpression>,
        distinct: bool,
        result_type: LogicalType,
    },
    /// The `count(*)`-style aggregate; always typed `Int64`.
    CountStar,
    /// A `CASE` expression.
    ///
    /// NOTE(review): presumably `operand` is `Some` for the simple form
    /// (`CASE x WHEN ...`) and `None` for the searched form. `whens` holds
    /// (when, then) pairs, and `else_expr` is the optional fallback.
    Case {
        operand: Option<Box<BoundExpression>>,
        whens: Vec<(BoundExpression, BoundExpression)>,
        else_expr: Option<Box<BoundExpression>>,
        result_type: LogicalType,
    },
    /// A list literal built from `elements`.
    ListLiteral {
        elements: Vec<BoundExpression>,
        result_type: LogicalType,
    },
    /// A map literal; `entries` are presumably (key, value) pairs.
    MapLiteral {
        entries: Vec<(BoundExpression, BoundExpression)>,
        result_type: LogicalType,
    },
    /// Element access `expr[index]`.
    Subscript {
        expr: Box<BoundExpression>,
        index: Box<BoundExpression>,
        result_type: LogicalType,
    },
    /// Range slice of `expr`; each bound (`from`, `to`) is optional.
    Slice {
        expr: Box<BoundExpression>,
        from: Option<Box<BoundExpression>>,
        to: Option<Box<BoundExpression>>,
        result_type: LogicalType,
    },
    /// String predicate (e.g. `StartsWith`) on `left` and `right`; always
    /// typed `Bool`.
    StringOp {
        op: StringOp,
        left: Box<BoundExpression>,
        right: Box<BoundExpression>,
    },
    /// Conversion of `expr` to `target_type`, which is also its result type.
    Cast {
        expr: Box<BoundExpression>,
        target_type: LogicalType,
    },
    /// Label test of `expr` against `table_ids`; always typed `Bool`.
    ///
    /// NOTE(review): presumably true when the entity belongs to any of the
    /// listed tables — confirm.
    HasLabel {
        expr: Box<BoundExpression>,
        table_ids: Vec<TableId>,
    },
}
111
112impl BoundExpression {
113 pub fn result_type(&self) -> &LogicalType {
115 match self {
116 Self::Literal { result_type, .. }
117 | Self::Variable { result_type, .. }
118 | Self::Property { result_type, .. }
119 | Self::Parameter { result_type, .. }
120 | Self::UnaryOp { result_type, .. }
121 | Self::BinaryOp { result_type, .. }
122 | Self::FunctionCall { result_type, .. }
123 | Self::Case { result_type, .. }
124 | Self::ListLiteral { result_type, .. }
125 | Self::MapLiteral { result_type, .. }
126 | Self::Subscript { result_type, .. }
127 | Self::Slice { result_type, .. } => result_type,
128
129 Self::Comparison { .. }
130 | Self::IsNull { .. }
131 | Self::InList { .. }
132 | Self::StringOp { .. }
133 | Self::HasLabel { .. } => &LogicalType::Bool,
134
135 Self::CountStar => &LogicalType::Int64,
136
137 Self::Cast { target_type, .. } => target_type,
138 }
139 }
140
141 pub fn is_constant(&self) -> bool {
143 matches!(self, Self::Literal { .. })
144 }
145
146 pub fn is_aggregate(&self) -> bool {
148 matches!(self, Self::CountStar | Self::FunctionCall { distinct: true, .. })
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Int64` literal.
    fn lit_int(v: i64) -> BoundExpression {
        BoundExpression::Literal {
            value: TypedValue::Int64(v),
            result_type: LogicalType::Int64,
        }
    }

    /// Builds a `String` literal.
    fn lit_str(s: &str) -> BoundExpression {
        BoundExpression::Literal {
            value: TypedValue::String(SmolStr::new(s)),
            result_type: LogicalType::String,
        }
    }

    /// Builds a `Bool` literal.
    fn lit_bool(v: bool) -> BoundExpression {
        BoundExpression::Literal {
            value: TypedValue::Bool(v),
            result_type: LogicalType::Bool,
        }
    }

    /// Builds an `Int64`-typed function call named `name`, with the given
    /// `distinct` flag.
    fn call_int(name: &str, args: Vec<BoundExpression>, distinct: bool) -> BoundExpression {
        BoundExpression::FunctionCall {
            function_id: FunctionId(0),
            function_name: SmolStr::new(name),
            args,
            distinct,
            result_type: LogicalType::Int64,
        }
    }

    #[test]
    fn literal_result_type() {
        assert_eq!(lit_int(42).result_type(), &LogicalType::Int64);
        assert_eq!(lit_str("hello").result_type(), &LogicalType::String);
        assert_eq!(lit_bool(true).result_type(), &LogicalType::Bool);
    }

    #[test]
    fn variable_result_type() {
        let var = BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Node,
        };
        assert_eq!(var.result_type(), &LogicalType::Node);
    }

    #[test]
    fn parameter_result_type_and_not_constant() {
        let param = BoundExpression::Parameter {
            name: SmolStr::new("limit"),
            index: 0,
            result_type: LogicalType::Int64,
        };
        assert_eq!(param.result_type(), &LogicalType::Int64);
        // Only literals are considered constant.
        assert!(!param.is_constant());
    }

    #[test]
    fn comparison_always_bool() {
        let cmp = BoundExpression::Comparison {
            op: ComparisonOp::Gt,
            left: Box::new(lit_int(1)),
            right: Box::new(lit_int(2)),
        };
        assert_eq!(cmp.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn is_null_always_bool() {
        let expr = BoundExpression::IsNull {
            expr: Box::new(lit_int(1)),
            negated: false,
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn in_list_always_bool() {
        let expr = BoundExpression::InList {
            expr: Box::new(lit_int(1)),
            list: vec![lit_int(1), lit_int(2)],
            negated: true,
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn string_op_always_bool() {
        let expr = BoundExpression::StringOp {
            op: StringOp::StartsWith,
            left: Box::new(lit_str("hello")),
            right: Box::new(lit_str("he")),
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn count_star_always_int64() {
        assert_eq!(BoundExpression::CountStar.result_type(), &LogicalType::Int64);
    }

    #[test]
    fn cast_result_type() {
        let expr = BoundExpression::Cast {
            expr: Box::new(lit_int(42)),
            target_type: LogicalType::Double,
        };
        assert_eq!(expr.result_type(), &LogicalType::Double);
    }

    #[test]
    fn case_result_type() {
        let expr = BoundExpression::Case {
            operand: None,
            whens: vec![(lit_bool(true), lit_str("yes"))],
            else_expr: Some(Box::new(lit_str("no"))),
            result_type: LogicalType::String,
        };
        assert_eq!(expr.result_type(), &LogicalType::String);
    }

    #[test]
    fn slice_result_type() {
        let expr = BoundExpression::Slice {
            expr: Box::new(lit_str("hello")),
            from: Some(Box::new(lit_int(1))),
            to: None,
            result_type: LogicalType::String,
        };
        assert_eq!(expr.result_type(), &LogicalType::String);
    }

    #[test]
    fn is_constant() {
        assert!(lit_int(1).is_constant());
        assert!(!BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Int64,
        }
        .is_constant());
    }

    #[test]
    fn has_label_always_bool() {
        let expr = BoundExpression::HasLabel {
            expr: Box::new(BoundExpression::Variable {
                index: 0,
                result_type: LogicalType::Node,
            }),
            table_ids: vec![TableId(1)],
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn function_call_result_type() {
        let expr = BoundExpression::FunctionCall {
            function_id: FunctionId(0),
            function_name: SmolStr::new("abs"),
            args: vec![lit_int(-5)],
            distinct: false,
            result_type: LogicalType::Int64,
        };
        assert_eq!(expr.result_type(), &LogicalType::Int64);
    }

    #[test]
    fn is_aggregate_count_star() {
        assert!(BoundExpression::CountStar.is_aggregate());
    }

    #[test]
    fn is_aggregate_distinct_function_call() {
        assert!(call_int("count", vec![lit_int(1)], true).is_aggregate());
    }

    #[test]
    fn is_aggregate_false_for_non_aggregates() {
        assert!(!lit_int(1).is_aggregate());
        assert!(!call_int("abs", vec![lit_int(-5)], false).is_aggregate());
    }
}