// kyu_expression/bound_expr.rs

//! Bound expression types — resolved, typed expressions ready for planning/execution.
//!
//! Every variant either carries a pre-computed `result_type: LogicalType` or has
//! a type fixed by its kind (predicates are `Bool`, `COUNT(*)` is `Int64`, a cast
//! has its target type), so type queries are O(1). Variables are referenced by
//! `u32` index rather than by name, so the hot path performs no string comparisons.

use kyu_common::id::{PropertyId, TableId};
use kyu_parser::ast::{BinaryOp, ComparisonOp, StringOp, UnaryOp};
use kyu_types::{LogicalType, TypedValue};
use smol_str::SmolStr;

/// Unique function identifier for O(1) registry lookup after resolution.
///
/// A thin newtype over `u32`. It is `Copy`, hashable and compared by value,
/// so it can be used directly as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct FunctionId(pub u32);

16/// A resolved, typed expression. All names are resolved to IDs,
17/// all types are pre-computed. Consumed by the planner and executor.
18#[derive(Clone, Debug)]
19pub enum BoundExpression {
20    Literal {
21        value: TypedValue,
22        result_type: LogicalType,
23    },
24    Variable {
25        index: u32,
26        result_type: LogicalType,
27    },
28    Property {
29        object: Box<BoundExpression>,
30        property_id: PropertyId,
31        property_name: SmolStr,
32        result_type: LogicalType,
33    },
34    Parameter {
35        name: SmolStr,
36        index: u32,
37        result_type: LogicalType,
38    },
39    UnaryOp {
40        op: UnaryOp,
41        operand: Box<BoundExpression>,
42        result_type: LogicalType,
43    },
44    BinaryOp {
45        op: BinaryOp,
46        left: Box<BoundExpression>,
47        right: Box<BoundExpression>,
48        result_type: LogicalType,
49    },
50    Comparison {
51        op: ComparisonOp,
52        left: Box<BoundExpression>,
53        right: Box<BoundExpression>,
54    },
55    IsNull {
56        expr: Box<BoundExpression>,
57        negated: bool,
58    },
59    InList {
60        expr: Box<BoundExpression>,
61        list: Vec<BoundExpression>,
62        negated: bool,
63    },
64    FunctionCall {
65        function_id: FunctionId,
66        function_name: SmolStr,
67        args: Vec<BoundExpression>,
68        distinct: bool,
69        result_type: LogicalType,
70    },
71    CountStar,
72    Case {
73        operand: Option<Box<BoundExpression>>,
74        whens: Vec<(BoundExpression, BoundExpression)>,
75        else_expr: Option<Box<BoundExpression>>,
76        result_type: LogicalType,
77    },
78    ListLiteral {
79        elements: Vec<BoundExpression>,
80        result_type: LogicalType,
81    },
82    MapLiteral {
83        entries: Vec<(BoundExpression, BoundExpression)>,
84        result_type: LogicalType,
85    },
86    Subscript {
87        expr: Box<BoundExpression>,
88        index: Box<BoundExpression>,
89        result_type: LogicalType,
90    },
91    Slice {
92        expr: Box<BoundExpression>,
93        from: Option<Box<BoundExpression>>,
94        to: Option<Box<BoundExpression>>,
95        result_type: LogicalType,
96    },
97    StringOp {
98        op: StringOp,
99        left: Box<BoundExpression>,
100        right: Box<BoundExpression>,
101    },
102    Cast {
103        expr: Box<BoundExpression>,
104        target_type: LogicalType,
105    },
106    HasLabel {
107        expr: Box<BoundExpression>,
108        table_ids: Vec<TableId>,
109    },
110}
111
112impl BoundExpression {
113    /// Get the pre-computed result type of this expression.
114    pub fn result_type(&self) -> &LogicalType {
115        match self {
116            Self::Literal { result_type, .. }
117            | Self::Variable { result_type, .. }
118            | Self::Property { result_type, .. }
119            | Self::Parameter { result_type, .. }
120            | Self::UnaryOp { result_type, .. }
121            | Self::BinaryOp { result_type, .. }
122            | Self::FunctionCall { result_type, .. }
123            | Self::Case { result_type, .. }
124            | Self::ListLiteral { result_type, .. }
125            | Self::MapLiteral { result_type, .. }
126            | Self::Subscript { result_type, .. }
127            | Self::Slice { result_type, .. } => result_type,
128
129            Self::Comparison { .. }
130            | Self::IsNull { .. }
131            | Self::InList { .. }
132            | Self::StringOp { .. }
133            | Self::HasLabel { .. } => &LogicalType::Bool,
134
135            Self::CountStar => &LogicalType::Int64,
136
137            Self::Cast { target_type, .. } => target_type,
138        }
139    }
140
141    /// Whether this expression is a constant (literal or constant-folded).
142    pub fn is_constant(&self) -> bool {
143        matches!(self, Self::Literal { .. })
144    }
145
146    /// Whether this expression is an aggregate function call or COUNT(*).
147    pub fn is_aggregate(&self) -> bool {
148        matches!(
149            self,
150            Self::CountStar | Self::FunctionCall { distinct: true, .. }
151        )
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer literal helper.
    fn lit_int(v: i64) -> BoundExpression {
        BoundExpression::Literal {
            value: TypedValue::Int64(v),
            result_type: LogicalType::Int64,
        }
    }

    /// String literal helper.
    fn lit_str(s: &str) -> BoundExpression {
        BoundExpression::Literal {
            value: TypedValue::String(SmolStr::new(s)),
            result_type: LogicalType::String,
        }
    }

    /// Boolean literal helper.
    fn lit_bool(v: bool) -> BoundExpression {
        BoundExpression::Literal {
            value: TypedValue::Bool(v),
            result_type: LogicalType::Bool,
        }
    }

    /// Builds a function call with the given name and DISTINCT flag.
    fn call(name: &str, distinct: bool) -> BoundExpression {
        BoundExpression::FunctionCall {
            function_id: FunctionId(0),
            function_name: SmolStr::new(name),
            args: vec![lit_int(1)],
            distinct,
            result_type: LogicalType::Int64,
        }
    }

    #[test]
    fn literal_result_type() {
        assert_eq!(lit_int(42).result_type(), &LogicalType::Int64);
        assert_eq!(lit_str("hello").result_type(), &LogicalType::String);
        assert_eq!(lit_bool(true).result_type(), &LogicalType::Bool);
    }

    #[test]
    fn variable_result_type() {
        let var = BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Node,
        };
        assert_eq!(var.result_type(), &LogicalType::Node);
    }

    #[test]
    fn comparison_always_bool() {
        let cmp = BoundExpression::Comparison {
            op: ComparisonOp::Gt,
            left: Box::new(lit_int(1)),
            right: Box::new(lit_int(2)),
        };
        assert_eq!(cmp.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn is_null_always_bool() {
        let expr = BoundExpression::IsNull {
            expr: Box::new(lit_int(1)),
            negated: false,
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn in_list_always_bool() {
        let expr = BoundExpression::InList {
            expr: Box::new(lit_int(1)),
            list: vec![lit_int(1), lit_int(2)],
            negated: true,
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn string_op_always_bool() {
        let expr = BoundExpression::StringOp {
            op: StringOp::StartsWith,
            left: Box::new(lit_str("hello")),
            right: Box::new(lit_str("he")),
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn count_star_always_int64() {
        assert_eq!(
            BoundExpression::CountStar.result_type(),
            &LogicalType::Int64
        );
    }

    #[test]
    fn cast_result_type() {
        let expr = BoundExpression::Cast {
            expr: Box::new(lit_int(42)),
            target_type: LogicalType::Double,
        };
        assert_eq!(expr.result_type(), &LogicalType::Double);
    }

    #[test]
    fn case_result_type() {
        let expr = BoundExpression::Case {
            operand: None,
            whens: vec![(lit_bool(true), lit_str("yes"))],
            else_expr: Some(Box::new(lit_str("no"))),
            result_type: LogicalType::String,
        };
        assert_eq!(expr.result_type(), &LogicalType::String);
    }

    #[test]
    fn is_constant() {
        assert!(lit_int(1).is_constant());
        assert!(
            !BoundExpression::Variable {
                index: 0,
                result_type: LogicalType::Int64,
            }
            .is_constant()
        );
        // A cast of a literal is not itself a literal.
        assert!(
            !BoundExpression::Cast {
                expr: Box::new(lit_int(1)),
                target_type: LogicalType::Double,
            }
            .is_constant()
        );
    }

    #[test]
    fn has_label_always_bool() {
        let expr = BoundExpression::HasLabel {
            expr: Box::new(BoundExpression::Variable {
                index: 0,
                result_type: LogicalType::Node,
            }),
            table_ids: vec![TableId(1)],
        };
        assert_eq!(expr.result_type(), &LogicalType::Bool);
    }

    #[test]
    fn function_call_result_type() {
        let expr = BoundExpression::FunctionCall {
            function_id: FunctionId(0),
            function_name: SmolStr::new("abs"),
            args: vec![lit_int(-5)],
            distinct: false,
            result_type: LogicalType::Int64,
        };
        assert_eq!(expr.result_type(), &LogicalType::Int64);
    }

    #[test]
    fn count_star_and_distinct_calls_are_aggregates() {
        assert!(BoundExpression::CountStar.is_aggregate());
        assert!(call("count", true).is_aggregate());
    }

    #[test]
    fn non_aggregates() {
        assert!(!lit_int(1).is_aggregate());
        assert!(!call("abs", false).is_aggregate());
    }
}