1use arrow::array::{Array, ArrayRef, Scalar, new_null_array};
2use arrow::compute::kernels::cmp;
3use arrow::compute::{cast, kernels::numeric, nullif};
4use arrow::datatypes::DataType;
5use llkv_expr::expr::{BinaryOp, CompareOp};
6use llkv_result::Error;
7use std::sync::Arc;
8
9fn numeric_priority(dt: &DataType) -> Option<Numeric> {
10 match dt {
11 DataType::Int8 => Some(Numeric::Signed(8)),
12 DataType::Int16 => Some(Numeric::Signed(16)),
13 DataType::Int32 => Some(Numeric::Signed(32)),
14 DataType::Int64 => Some(Numeric::Signed(64)),
15 DataType::UInt8 => Some(Numeric::Unsigned(8)),
16 DataType::UInt16 => Some(Numeric::Unsigned(16)),
17 DataType::UInt32 => Some(Numeric::Unsigned(32)),
18 DataType::UInt64 => Some(Numeric::Unsigned(64)),
19 DataType::Float32 => Some(Numeric::F32),
20 DataType::Float64 => Some(Numeric::F64),
21 _ => None,
22 }
23}
24
/// Coarse classification of a primitive Arrow numeric type, as produced by `numeric_priority`.
///
/// The `u8` payload of the integer variants is the bit width (8, 16, 32 or 64).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Numeric {
    /// Signed integer with the given bit width.
    Signed(u8),
    /// Unsigned integer with the given bit width.
    Unsigned(u8),
    /// 32-bit floating point.
    F32,
    /// 64-bit floating point.
    F64,
}
32
33fn coerce_decimals(lhs: (u8, i8), rhs: (u8, i8)) -> DataType {
34 let scale = lhs.1.max(rhs.1);
35 let lhs_int = i32::from(lhs.0) - i32::from(lhs.1);
36 let rhs_int = i32::from(rhs.0) - i32::from(rhs.1);
37 let int_digits = lhs_int.max(rhs_int);
38 let precision = (int_digits + i32::from(scale)).clamp(1, 38) as u8;
39 DataType::Decimal128(precision, scale)
40}
41
42pub fn compute_binary(lhs: &ArrayRef, rhs: &ArrayRef, op: BinaryOp) -> Result<ArrayRef, Error> {
43 let (lhs_arr, rhs_arr) = coerce_types(lhs, rhs, op)?;
45
46 if lhs_arr.data_type() == &DataType::Null {
47 return Ok(new_null_array(&DataType::Null, lhs_arr.len()));
48 }
49
50 let result_arr: ArrayRef = match op {
51 BinaryOp::Add => {
52 numeric::add(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?
53 }
54 BinaryOp::Subtract => {
55 numeric::sub(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?
56 }
57 BinaryOp::Multiply => {
58 numeric::mul(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?
59 }
60 BinaryOp::Divide => {
61 let zero = arrow::array::Int64Array::from(vec![0]);
63 let zero = cast(&zero, rhs_arr.data_type())
64 .map_err(|e| Error::Internal(format!("Failed to cast 0: {}", e)))?;
65 let zero_scalar = Scalar::new(zero);
66
67 let is_zero = cmp::eq(&rhs_arr, &zero_scalar)
68 .map_err(|e| Error::Internal(format!("Failed to compare with 0: {}", e)))?;
69
70 let safe_rhs = nullif(&rhs_arr, &is_zero)
71 .map_err(|e| Error::Internal(format!("Failed to nullif zeros: {}", e)))?;
72
73 numeric::div(&lhs_arr, &safe_rhs).map_err(|e| Error::Internal(e.to_string()))?
74 }
75 BinaryOp::Modulo => {
76 numeric::rem(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?
77 }
78 BinaryOp::And => {
79 let lhs_bool =
80 cast(&lhs_arr, &DataType::Boolean).map_err(|e| Error::Internal(e.to_string()))?;
81 let rhs_bool =
82 cast(&rhs_arr, &DataType::Boolean).map_err(|e| Error::Internal(e.to_string()))?;
83 let lhs_bool = lhs_bool
84 .as_any()
85 .downcast_ref::<arrow::array::BooleanArray>()
86 .unwrap();
87 let rhs_bool = rhs_bool
88 .as_any()
89 .downcast_ref::<arrow::array::BooleanArray>()
90 .unwrap();
91 let result = arrow::compute::kernels::boolean::and(lhs_bool, rhs_bool)
92 .map_err(|e| Error::Internal(e.to_string()))?;
93 Arc::new(result)
94 }
95 BinaryOp::Or => {
96 let lhs_bool =
97 cast(&lhs_arr, &DataType::Boolean).map_err(|e| Error::Internal(e.to_string()))?;
98 let rhs_bool =
99 cast(&rhs_arr, &DataType::Boolean).map_err(|e| Error::Internal(e.to_string()))?;
100 let lhs_bool = lhs_bool
101 .as_any()
102 .downcast_ref::<arrow::array::BooleanArray>()
103 .unwrap();
104 let rhs_bool = rhs_bool
105 .as_any()
106 .downcast_ref::<arrow::array::BooleanArray>()
107 .unwrap();
108 let result = arrow::compute::kernels::boolean::or(lhs_bool, rhs_bool)
109 .map_err(|e| Error::Internal(e.to_string()))?;
110 Arc::new(result)
111 }
112 _ => return Err(Error::Internal(format!("Unsupported binary op: {:?}", op))),
113 };
114
115 Ok(result_arr)
116}
117
118pub fn get_common_type(lhs_type: &DataType, rhs_type: &DataType) -> DataType {
119 if lhs_type == rhs_type {
120 return lhs_type.clone();
121 }
122
123 match (lhs_type, rhs_type) {
124 (DataType::Null, other) | (other, DataType::Null) => other.clone(),
125 (DataType::Boolean, DataType::Boolean) => DataType::Boolean,
126 (DataType::Decimal128(lp, ls), DataType::Decimal128(rp, rs)) => {
127 coerce_decimals((*lp, *ls), (*rp, *rs))
128 }
129 (DataType::Decimal128(p, s), other) | (other, DataType::Decimal128(p, s)) => match other {
130 DataType::Float64 | DataType::Float32 => DataType::Float64,
131 DataType::Int8
132 | DataType::Int16
133 | DataType::Int32
134 | DataType::Int64
135 | DataType::UInt8
136 | DataType::UInt16
137 | DataType::UInt32
138 | DataType::UInt64 => coerce_decimals((*p, *s), (38u8, 0)),
139 _ => DataType::Float64,
140 },
141 _ => match (numeric_priority(lhs_type), numeric_priority(rhs_type)) {
142 (Some(Numeric::F64), _) | (_, Some(Numeric::F64)) => DataType::Float64,
143 (Some(Numeric::F32), _) | (_, Some(Numeric::F32)) => DataType::Float64,
144 (Some(Numeric::Signed(lhs)), Some(Numeric::Unsigned(rhs)))
145 | (Some(Numeric::Unsigned(lhs)), Some(Numeric::Signed(rhs))) => {
146 let max = std::cmp::max(lhs, rhs);
147 if max >= 64 {
148 DataType::Float64
149 } else {
150 DataType::Int64
151 }
152 }
153 (Some(Numeric::Signed(lhs)), Some(Numeric::Signed(rhs))) => {
154 if lhs >= 64 || rhs >= 64 {
155 DataType::Int64
156 } else if lhs >= 32 || rhs >= 32 {
157 DataType::Int32
158 } else if lhs >= 16 || rhs >= 16 {
159 DataType::Int16
160 } else {
161 DataType::Int8
162 }
163 }
164 (Some(Numeric::Unsigned(lhs)), Some(Numeric::Unsigned(rhs))) => {
165 if lhs >= 64 || rhs >= 64 {
166 DataType::UInt64
167 } else if lhs >= 32 || rhs >= 32 {
168 DataType::UInt32
169 } else if lhs >= 16 || rhs >= 16 {
170 DataType::UInt16
171 } else {
172 DataType::UInt8
173 }
174 }
175 _ => DataType::Float64,
176 },
177 }
178}
179
180pub fn common_type_for_op(lhs_type: &DataType, rhs_type: &DataType, _op: BinaryOp) -> DataType {
182 get_common_type(lhs_type, rhs_type)
183}
184
185pub fn coerce_types(
186 lhs: &ArrayRef,
187 rhs: &ArrayRef,
188 op: BinaryOp,
189) -> Result<(ArrayRef, ArrayRef), Error> {
190 let lhs_type = lhs.data_type();
191 let rhs_type = rhs.data_type();
192
193 let target_type = common_type_for_op(lhs_type, rhs_type, op);
194
195 if lhs_type == rhs_type && lhs_type == &target_type {
196 return Ok((lhs.clone(), rhs.clone()));
197 }
198
199 let lhs_casted = cast(lhs, &target_type).map_err(|e| Error::Internal(e.to_string()))?;
200 let rhs_casted = cast(rhs, &target_type).map_err(|e| Error::Internal(e.to_string()))?;
201
202 Ok((lhs_casted, rhs_casted))
203}
204
205pub fn compute_compare(lhs: &ArrayRef, op: CompareOp, rhs: &ArrayRef) -> Result<ArrayRef, Error> {
206 let (lhs_arr, rhs_arr) = coerce_types(lhs, rhs, BinaryOp::Add)?;
211
212 let result_arr: ArrayRef = match op {
213 CompareOp::Eq => {
214 Arc::new(cmp::eq(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
215 }
216 CompareOp::NotEq => {
217 Arc::new(cmp::neq(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
218 }
219 CompareOp::Lt => {
220 Arc::new(cmp::lt(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
221 }
222 CompareOp::LtEq => {
223 Arc::new(cmp::lt_eq(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
224 }
225 CompareOp::Gt => {
226 Arc::new(cmp::gt(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
227 }
228 CompareOp::GtEq => {
229 Arc::new(cmp::gt_eq(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
230 }
231 };
232 Ok(result_arr)
233}