Skip to main content

kyu_expression/
bound_expr.rs

//! Bound expression types — resolved, typed expressions ready for planning/execution.
//!
//! Every variant either carries a pre-computed `result_type: LogicalType` or has a
//! type fixed by its kind (e.g. comparisons are always `Bool`), so type queries are
//! O(1). Variables are referenced by `u32` index rather than by name, so the hot
//! path performs no string comparisons.
6
7use kyu_common::id::{PropertyId, TableId};
8use kyu_parser::ast::{BinaryOp, ComparisonOp, StringOp, UnaryOp};
9use kyu_types::{LogicalType, TypedValue};
10use smol_str::SmolStr;
11
/// Numeric identifier of a resolved function.
///
/// Once a call has been resolved, this index lets the function registry be
/// consulted in O(1) instead of by name.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub struct FunctionId(pub u32);
15
16/// A resolved, typed expression. All names are resolved to IDs,
17/// all types are pre-computed. Consumed by the planner and executor.
18#[derive(Clone, Debug)]
19pub enum BoundExpression {
20    Literal {
21        value: TypedValue,
22        result_type: LogicalType,
23    },
24    Variable {
25        index: u32,
26        result_type: LogicalType,
27    },
28    Property {
29        object: Box<BoundExpression>,
30        property_id: PropertyId,
31        property_name: SmolStr,
32        result_type: LogicalType,
33    },
34    Parameter {
35        name: SmolStr,
36        index: u32,
37        result_type: LogicalType,
38    },
39    UnaryOp {
40        op: UnaryOp,
41        operand: Box<BoundExpression>,
42        result_type: LogicalType,
43    },
44    BinaryOp {
45        op: BinaryOp,
46        left: Box<BoundExpression>,
47        right: Box<BoundExpression>,
48        result_type: LogicalType,
49    },
50    Comparison {
51        op: ComparisonOp,
52        left: Box<BoundExpression>,
53        right: Box<BoundExpression>,
54    },
55    IsNull {
56        expr: Box<BoundExpression>,
57        negated: bool,
58    },
59    InList {
60        expr: Box<BoundExpression>,
61        list: Vec<BoundExpression>,
62        negated: bool,
63    },
64    FunctionCall {
65        function_id: FunctionId,
66        function_name: SmolStr,
67        args: Vec<BoundExpression>,
68        distinct: bool,
69        result_type: LogicalType,
70    },
71    CountStar,
72    Case {
73        operand: Option<Box<BoundExpression>>,
74        whens: Vec<(BoundExpression, BoundExpression)>,
75        else_expr: Option<Box<BoundExpression>>,
76        result_type: LogicalType,
77    },
78    ListLiteral {
79        elements: Vec<BoundExpression>,
80        result_type: LogicalType,
81    },
82    MapLiteral {
83        entries: Vec<(BoundExpression, BoundExpression)>,
84        result_type: LogicalType,
85    },
86    Subscript {
87        expr: Box<BoundExpression>,
88        index: Box<BoundExpression>,
89        result_type: LogicalType,
90    },
91    Slice {
92        expr: Box<BoundExpression>,
93        from: Option<Box<BoundExpression>>,
94        to: Option<Box<BoundExpression>>,
95        result_type: LogicalType,
96    },
97    StringOp {
98        op: StringOp,
99        left: Box<BoundExpression>,
100        right: Box<BoundExpression>,
101    },
102    Cast {
103        expr: Box<BoundExpression>,
104        target_type: LogicalType,
105    },
106    HasLabel {
107        expr: Box<BoundExpression>,
108        table_ids: Vec<TableId>,
109    },
110}
111
112impl BoundExpression {
113    /// Get the pre-computed result type of this expression.
114    pub fn result_type(&self) -> &LogicalType {
115        match self {
116            Self::Literal { result_type, .. }
117            | Self::Variable { result_type, .. }
118            | Self::Property { result_type, .. }
119            | Self::Parameter { result_type, .. }
120            | Self::UnaryOp { result_type, .. }
121            | Self::BinaryOp { result_type, .. }
122            | Self::FunctionCall { result_type, .. }
123            | Self::Case { result_type, .. }
124            | Self::ListLiteral { result_type, .. }
125            | Self::MapLiteral { result_type, .. }
126            | Self::Subscript { result_type, .. }
127            | Self::Slice { result_type, .. } => result_type,
128
129            Self::Comparison { .. }
130            | Self::IsNull { .. }
131            | Self::InList { .. }
132            | Self::StringOp { .. }
133            | Self::HasLabel { .. } => &LogicalType::Bool,
134
135            Self::CountStar => &LogicalType::Int64,
136
137            Self::Cast { target_type, .. } => target_type,
138        }
139    }
140
141    /// Whether this expression is a constant (literal or constant-folded).
142    pub fn is_constant(&self) -> bool {
143        matches!(self, Self::Literal { .. })
144    }
145
146    /// Whether this expression is an aggregate function call or COUNT(*).
147    pub fn is_aggregate(&self) -> bool {
148        matches!(self, Self::CountStar | Self::FunctionCall { distinct: true, .. })
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    fn lit_int(v: i64) -> BoundExpression {
        BoundExpression::Literal {
            value: TypedValue::Int64(v),
            result_type: LogicalType::Int64,
        }
    }

    fn lit_str(s: &str) -> BoundExpression {
        BoundExpression::Literal {
            value: TypedValue::String(SmolStr::new(s)),
            result_type: LogicalType::String,
        }
    }

    fn lit_bool(v: bool) -> BoundExpression {
        BoundExpression::Literal {
            value: TypedValue::Bool(v),
            result_type: LogicalType::Bool,
        }
    }

    /// Builds a `FunctionCall` returning `Int64` with a single literal argument.
    fn call(name: &str, distinct: bool) -> BoundExpression {
        BoundExpression::FunctionCall {
            function_id: FunctionId(0),
            function_name: SmolStr::new(name),
            args: vec![lit_int(1)],
            distinct,
            result_type: LogicalType::Int64,
        }
    }

    #[test]
    fn literal_result_type() {
        assert_eq!(lit_int(42).result_type(), &LogicalType::Int64);
        assert_eq!(lit_str("hello").result_type(), &LogicalType::String);
        assert_eq!(lit_bool(true).result_type(), &LogicalType::Bool);
    }

    #[test]
    fn variable_result_type() {
        let var = BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Node,
        };
        assert_eq!(var.result_type(), &LogicalType::Node);
    }

    #[test]
    fn comparison_always_bool() {
        let cmp = BoundExpression::Comparison {
            op: ComparisonOp::Gt,
            left: Box::new(lit_int(1)),
            right: Box::new(lit_int(2)),
        };
        assert_eq!(cmp.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn is_null_always_bool() {
        let expr = BoundExpression::IsNull {
            expr: Box::new(lit_int(1)),
            negated: false,
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn in_list_always_bool() {
        let expr = BoundExpression::InList {
            expr: Box::new(lit_int(1)),
            list: vec![lit_int(1), lit_int(2)],
            negated: true,
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn string_op_always_bool() {
        let expr = BoundExpression::StringOp {
            op: StringOp::StartsWith,
            left: Box::new(lit_str("hello")),
            right: Box::new(lit_str("he")),
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn count_star_always_int64() {
        assert_eq!(BoundExpression::CountStar.result_type(), &LogicalType::Int64);
    }

    #[test]
    fn cast_result_type() {
        let expr = BoundExpression::Cast {
            expr: Box::new(lit_int(42)),
            target_type: LogicalType::Double,
        };
        assert_eq!(expr.result_type(), &LogicalType::Double);
    }

    #[test]
    fn case_uses_stored_result_type() {
        let expr = BoundExpression::Case {
            operand: None,
            whens: vec![(lit_bool(true), lit_int(1))],
            else_expr: Some(Box::new(lit_int(0))),
            result_type: LogicalType::Int64,
        };
        assert_eq!(expr.result_type(), &LogicalType::Int64);
    }

    #[test]
    fn subscript_and_slice_use_stored_result_type() {
        let sub = BoundExpression::Subscript {
            expr: Box::new(lit_str("hello")),
            index: Box::new(lit_int(0)),
            result_type: LogicalType::String,
        };
        assert_eq!(sub.result_type(), &LogicalType::String);

        let slice = BoundExpression::Slice {
            expr: Box::new(lit_str("hello")),
            from: Some(Box::new(lit_int(1))),
            to: None,
            result_type: LogicalType::String,
        };
        assert_eq!(slice.result_type(), &LogicalType::String);
    }

    #[test]
    fn is_constant() {
        assert!(lit_int(1).is_constant());
        assert!(!BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Int64,
        }
        .is_constant());
        assert!(!BoundExpression::CountStar.is_constant());
    }

    #[test]
    fn is_aggregate_for_count_star_and_distinct_call() {
        assert!(BoundExpression::CountStar.is_aggregate());
        assert!(call("count", true).is_aggregate());
    }

    #[test]
    fn is_aggregate_false_for_scalars_and_literals() {
        assert!(!call("abs", false).is_aggregate());
        assert!(!lit_int(1).is_aggregate());
    }

    #[test]
    fn has_label_always_bool() {
        let expr = BoundExpression::HasLabel {
            expr: Box::new(BoundExpression::Variable {
                index: 0,
                result_type: LogicalType::Node,
            }),
            table_ids: vec![TableId(1)],
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn function_call_result_type() {
        let expr = BoundExpression::FunctionCall {
            function_id: FunctionId(0),
            function_name: SmolStr::new("abs"),
            args: vec![lit_int(-5)],
            distinct: false,
            result_type: LogicalType::Int64,
        };
        assert_eq!(expr.result_type(), &LogicalType::Int64);
    }
}