proof_of_sql/sql/
scale.rs

1use crate::{
2    base::{database::ColumnType, math::decimal::Precision},
3    sql::{
4        proof_exprs::{DynProofExpr, ProofExpr},
5        AnalyzeError, AnalyzeResult,
6    },
7};
8use alloc::string::ToString;
9use core::cmp::Ordering;
10
11/// Add a layer of decimal scaling cast to the expression
12/// so that we can do binary operations on it
13#[expect(clippy::missing_panics_doc, reason = "Precision can not be invalid")]
14fn decimal_scale_cast_expr(
15    from_proof_expr: DynProofExpr,
16    from_scale: i8,
17    to_scale: i8,
18) -> AnalyzeResult<DynProofExpr> {
19    if !from_proof_expr.data_type().is_numeric() {
20        return Err(AnalyzeError::DataTypeMismatch {
21            left_type: from_proof_expr.data_type().to_string(),
22            right_type: "Some numeric type".to_string(),
23        });
24    }
25    let from_precision_value = from_proof_expr.data_type().precision_value().unwrap_or(0);
26    let to_precision_value = u8::try_from(
27        i16::from(from_precision_value) + i16::from(to_scale - from_scale).min(75_i16),
28    )
29    .expect("Precision is definitely valid");
30    DynProofExpr::try_new_scaling_cast(
31        from_proof_expr,
32        ColumnType::Decimal75(
33            Precision::new(to_precision_value).expect("Precision is definitely valid"),
34            to_scale,
35        ),
36    )
37}
38
39/// Scale cast one side so that both sides have the same scale
40///
41/// We use this function so that binary ops for numeric types no longer
42/// need to keep track of scale
43pub fn scale_cast_binary_op(
44    left_proof_expr: DynProofExpr,
45    right_proof_expr: DynProofExpr,
46) -> AnalyzeResult<(DynProofExpr, DynProofExpr)> {
47    let left_type = left_proof_expr.data_type();
48    let right_type = right_proof_expr.data_type();
49    let left_scale = left_type.scale().unwrap_or(0);
50    let right_scale = right_type.scale().unwrap_or(0);
51    let scale = left_scale.max(right_scale);
52    match left_scale.cmp(&right_scale) {
53        Ordering::Less => Ok((
54            if matches!(left_type, ColumnType::TimestampTZ(_, _)) {
55                DynProofExpr::try_new_scaling_cast(left_proof_expr, right_type)?
56            } else {
57                decimal_scale_cast_expr(left_proof_expr, left_scale, scale)?
58            },
59            right_proof_expr,
60        )),
61        Ordering::Greater => Ok((
62            left_proof_expr,
63            if matches!(right_type, ColumnType::TimestampTZ(_, _)) {
64                DynProofExpr::try_new_scaling_cast(right_proof_expr, left_type)?
65            } else {
66                decimal_scale_cast_expr(right_proof_expr, right_scale, scale)?
67            },
68        )),
69        Ordering::Equal => Ok((left_proof_expr, right_proof_expr)),
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::database::{ColumnRef, TableRef};

    /// Column expression for `namespace.table_name.<name>` with the given type.
    fn column_expr(name: &'static str, column_type: ColumnType) -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            name.into(),
            column_type,
        ))
    }

    /// `Decimal75(precision, scale)`; panics on an out-of-range precision (test-only helper).
    fn decimal_type(precision: u8, scale: i8) -> ColumnType {
        ColumnType::Decimal75(
            Precision::new(precision).expect("Precision is definitely valid"),
            scale,
        )
    }

    fn column1_boolean() -> DynProofExpr {
        column_expr("column1", ColumnType::Boolean)
    }

    fn column1_smallint() -> DynProofExpr {
        column_expr("column1", ColumnType::SmallInt)
    }

    fn column1_decimal_3_minus2() -> DynProofExpr {
        column_expr("column1", decimal_type(3, -2))
    }

    fn column1_decimal_10_5() -> DynProofExpr {
        column_expr("column1", decimal_type(10, 5))
    }

    fn column2_decimal_25_5() -> DynProofExpr {
        column_expr("column2", decimal_type(25, 5))
    }

    fn column3_decimal_75_10() -> DynProofExpr {
        column_expr("column3", decimal_type(75, 10))
    }

    /// Operands whose scale is below 10; each must be upcast to `Decimal75(15, 10)`
    /// when paired with `column3_decimal_75_10()`.
    fn low_scale_columns() -> [DynProofExpr; 3] {
        [
            column1_smallint(),
            column1_decimal_10_5(),
            column1_decimal_3_minus2(),
        ]
    }

    // decimal_scale_cast_expr

    #[test]
    fn we_can_convert_decimal_scale_cast_expr() {
        let actual = decimal_scale_cast_expr(column1_smallint(), 0, 5).unwrap();
        let expected =
            DynProofExpr::try_new_scaling_cast(column1_smallint(), decimal_type(10, 5)).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn we_cannot_convert_nonnumeric_types_using_decimal_scale_cast_expr() {
        let result = decimal_scale_cast_expr(column1_boolean(), 0, 5);
        assert!(matches!(
            result,
            Err(AnalyzeError::DataTypeMismatch { .. })
        ));
    }

    // scale_cast_binary_op

    #[test]
    fn we_can_convert_scale_cast_binary_op_upcasting_left() {
        let right = column3_decimal_75_10();
        for left in low_scale_columns() {
            let expected_left =
                DynProofExpr::try_new_scaling_cast(left.clone(), decimal_type(15, 10)).unwrap();
            let actual = scale_cast_binary_op(left, right.clone()).unwrap();
            assert_eq!(actual, (expected_left, column3_decimal_75_10()));
        }
    }

    #[test]
    fn we_can_convert_scale_cast_binary_op_upcasting_right() {
        let left = column3_decimal_75_10();
        for right in low_scale_columns() {
            let expected_right =
                DynProofExpr::try_new_scaling_cast(right.clone(), decimal_type(15, 10)).unwrap();
            let actual = scale_cast_binary_op(left.clone(), right).unwrap();
            assert_eq!(actual, (column3_decimal_75_10(), expected_right));
        }
    }

    #[test]
    fn we_can_convert_scale_cast_binary_op_equal() {
        let (left, right) = (column1_decimal_10_5(), column2_decimal_25_5());
        let actual = scale_cast_binary_op(left.clone(), right.clone()).unwrap();
        assert_eq!(actual, (left, right));
    }
}