1use crate::{
2 base::{database::ColumnType, math::decimal::Precision},
3 sql::{
4 proof_exprs::{DynProofExpr, ProofExpr},
5 AnalyzeError, AnalyzeResult,
6 },
7};
8use alloc::string::ToString;
9use core::cmp::Ordering;
10
11#[expect(clippy::missing_panics_doc, reason = "Precision can not be invalid")]
14fn decimal_scale_cast_expr(
15 from_proof_expr: DynProofExpr,
16 from_scale: i8,
17 to_scale: i8,
18) -> AnalyzeResult<DynProofExpr> {
19 if !from_proof_expr.data_type().is_numeric() {
20 return Err(AnalyzeError::DataTypeMismatch {
21 left_type: from_proof_expr.data_type().to_string(),
22 right_type: "Some numeric type".to_string(),
23 });
24 }
25 let from_precision_value = from_proof_expr.data_type().precision_value().unwrap_or(0);
26 let to_precision_value = u8::try_from(
27 i16::from(from_precision_value) + i16::from(to_scale - from_scale).min(75_i16),
28 )
29 .expect("Precision is definitely valid");
30 DynProofExpr::try_new_scaling_cast(
31 from_proof_expr,
32 ColumnType::Decimal75(
33 Precision::new(to_precision_value).expect("Precision is definitely valid"),
34 to_scale,
35 ),
36 )
37}
38
39pub fn scale_cast_binary_op(
44 left_proof_expr: DynProofExpr,
45 right_proof_expr: DynProofExpr,
46) -> AnalyzeResult<(DynProofExpr, DynProofExpr)> {
47 let left_type = left_proof_expr.data_type();
48 let right_type = right_proof_expr.data_type();
49 let left_scale = left_type.scale().unwrap_or(0);
50 let right_scale = right_type.scale().unwrap_or(0);
51 let scale = left_scale.max(right_scale);
52 match left_scale.cmp(&right_scale) {
53 Ordering::Less => Ok((
54 if matches!(left_type, ColumnType::TimestampTZ(_, _)) {
55 DynProofExpr::try_new_scaling_cast(left_proof_expr, right_type)?
56 } else {
57 decimal_scale_cast_expr(left_proof_expr, left_scale, scale)?
58 },
59 right_proof_expr,
60 )),
61 Ordering::Greater => Ok((
62 left_proof_expr,
63 if matches!(right_type, ColumnType::TimestampTZ(_, _)) {
64 DynProofExpr::try_new_scaling_cast(right_proof_expr, left_type)?
65 } else {
66 decimal_scale_cast_expr(right_proof_expr, right_scale, scale)?
67 },
68 )),
69 Ordering::Equal => Ok((left_proof_expr, right_proof_expr)),
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::database::{ColumnRef, TableRef};

    /// Builds a column expression for `namespace.table_name.<name>` with the given type.
    fn column(name: &'static str, column_type: ColumnType) -> DynProofExpr {
        let table = TableRef::from_names(Some("namespace"), "table_name");
        DynProofExpr::new_column(ColumnRef::new(table, name.into(), column_type))
    }

    /// Builds a `Decimal75` column type with the given precision and scale.
    fn decimal(precision: u8, scale: i8) -> ColumnType {
        ColumnType::Decimal75(
            Precision::new(precision).expect("Precision is definitely valid"),
            scale,
        )
    }

    fn column1_boolean() -> DynProofExpr {
        column("column1", ColumnType::Boolean)
    }

    fn column1_smallint() -> DynProofExpr {
        column("column1", ColumnType::SmallInt)
    }

    fn column1_decimal_3_minus2() -> DynProofExpr {
        column("column1", decimal(3, -2))
    }

    fn column1_decimal_10_5() -> DynProofExpr {
        column("column1", decimal(10, 5))
    }

    fn column3_decimal_75_10() -> DynProofExpr {
        column("column3", decimal(75, 10))
    }

    fn column2_decimal_25_5() -> DynProofExpr {
        column("column2", decimal(25, 5))
    }

    /// Operands whose scale is lower than that of `column3_decimal_75_10` (scale 10).
    fn lower_scale_operands() -> [DynProofExpr; 3] {
        [
            column1_smallint(),
            column1_decimal_10_5(),
            column1_decimal_3_minus2(),
        ]
    }

    #[test]
    fn we_can_convert_decimal_scale_cast_expr() {
        let cast = decimal_scale_cast_expr(column1_smallint(), 0, 5).unwrap();
        let expected =
            DynProofExpr::try_new_scaling_cast(column1_smallint(), decimal(10, 5)).unwrap();
        assert_eq!(cast, expected);
    }

    #[test]
    fn we_cannot_convert_nonnumeric_types_using_decimal_scale_cast_expr() {
        let result = decimal_scale_cast_expr(column1_boolean(), 0, 5);
        assert!(matches!(
            result,
            Err(AnalyzeError::DataTypeMismatch { .. })
        ));
    }

    #[test]
    fn we_can_convert_scale_cast_binary_op_upcasting_left() {
        for left in lower_scale_operands() {
            let expected_left =
                DynProofExpr::try_new_scaling_cast(left.clone(), decimal(15, 10)).unwrap();
            assert_eq!(
                scale_cast_binary_op(left, column3_decimal_75_10()).unwrap(),
                (expected_left, column3_decimal_75_10())
            );
        }
    }

    #[test]
    fn we_can_convert_scale_cast_binary_op_upcasting_right() {
        for right in lower_scale_operands() {
            let expected_right =
                DynProofExpr::try_new_scaling_cast(right.clone(), decimal(15, 10)).unwrap();
            assert_eq!(
                scale_cast_binary_op(column3_decimal_75_10(), right).unwrap(),
                (column3_decimal_75_10(), expected_right)
            );
        }
    }

    #[test]
    fn we_can_convert_scale_cast_binary_op_equal() {
        let (left, right) = (column1_decimal_10_5(), column2_decimal_25_5());
        let result = scale_cast_binary_op(left.clone(), right.clone()).unwrap();
        assert_eq!(result, (left, right));
    }
}