llkv_compute/
kernels.rs

1use arrow::array::{Array, ArrayRef, Scalar, new_null_array};
2use arrow::compute::kernels::cmp;
3use arrow::compute::{cast, kernels::numeric, nullif};
4use arrow::datatypes::DataType;
5use llkv_expr::expr::{BinaryOp, CompareOp};
6use llkv_result::Error;
7use std::sync::Arc;
8
9fn numeric_priority(dt: &DataType) -> Option<Numeric> {
10    match dt {
11        DataType::Int8 => Some(Numeric::Signed(8)),
12        DataType::Int16 => Some(Numeric::Signed(16)),
13        DataType::Int32 => Some(Numeric::Signed(32)),
14        DataType::Int64 => Some(Numeric::Signed(64)),
15        DataType::UInt8 => Some(Numeric::Unsigned(8)),
16        DataType::UInt16 => Some(Numeric::Unsigned(16)),
17        DataType::UInt32 => Some(Numeric::Unsigned(32)),
18        DataType::UInt64 => Some(Numeric::Unsigned(64)),
19        DataType::Float32 => Some(Numeric::F32),
20        DataType::Float64 => Some(Numeric::F64),
21        _ => None,
22    }
23}
24
/// Coarse numeric category of an Arrow primitive type, used for type coercion.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Numeric {
    /// Signed integer of the given bit width (8, 16, 32 or 64).
    Signed(u8),
    /// Unsigned integer of the given bit width (8, 16, 32 or 64).
    Unsigned(u8),
    /// 32-bit IEEE float.
    F32,
    /// 64-bit IEEE float.
    F64,
}
32
33fn coerce_decimals(lhs: (u8, i8), rhs: (u8, i8)) -> DataType {
34    let scale = lhs.1.max(rhs.1);
35    let lhs_int = i32::from(lhs.0) - i32::from(lhs.1);
36    let rhs_int = i32::from(rhs.0) - i32::from(rhs.1);
37    let int_digits = lhs_int.max(rhs_int);
38    let precision = (int_digits + i32::from(scale)).clamp(1, 38) as u8;
39    DataType::Decimal128(precision, scale)
40}
41
42pub fn compute_binary(lhs: &ArrayRef, rhs: &ArrayRef, op: BinaryOp) -> Result<ArrayRef, Error> {
43    // Coerce inputs to common type
44    let (lhs_arr, rhs_arr) = coerce_types(lhs, rhs, op)?;
45
46    if lhs_arr.data_type() == &DataType::Null {
47        return Ok(new_null_array(&DataType::Null, lhs_arr.len()));
48    }
49
50    let result_arr: ArrayRef = match op {
51        BinaryOp::Add => {
52            numeric::add(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?
53        }
54        BinaryOp::Subtract => {
55            numeric::sub(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?
56        }
57        BinaryOp::Multiply => {
58            numeric::mul(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?
59        }
60        BinaryOp::Divide => {
61            // Handle division by zero by treating 0s as NULLs
62            let zero = arrow::array::Int64Array::from(vec![0]);
63            let zero = cast(&zero, rhs_arr.data_type())
64                .map_err(|e| Error::Internal(format!("Failed to cast 0: {}", e)))?;
65            let zero_scalar = Scalar::new(zero);
66
67            let is_zero = cmp::eq(&rhs_arr, &zero_scalar)
68                .map_err(|e| Error::Internal(format!("Failed to compare with 0: {}", e)))?;
69
70            let safe_rhs = nullif(&rhs_arr, &is_zero)
71                .map_err(|e| Error::Internal(format!("Failed to nullif zeros: {}", e)))?;
72
73            numeric::div(&lhs_arr, &safe_rhs).map_err(|e| Error::Internal(e.to_string()))?
74        }
75        BinaryOp::Modulo => {
76            numeric::rem(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?
77        }
78        BinaryOp::And => {
79            let lhs_bool =
80                cast(&lhs_arr, &DataType::Boolean).map_err(|e| Error::Internal(e.to_string()))?;
81            let rhs_bool =
82                cast(&rhs_arr, &DataType::Boolean).map_err(|e| Error::Internal(e.to_string()))?;
83            let lhs_bool = lhs_bool
84                .as_any()
85                .downcast_ref::<arrow::array::BooleanArray>()
86                .unwrap();
87            let rhs_bool = rhs_bool
88                .as_any()
89                .downcast_ref::<arrow::array::BooleanArray>()
90                .unwrap();
91            let result = arrow::compute::kernels::boolean::and(lhs_bool, rhs_bool)
92                .map_err(|e| Error::Internal(e.to_string()))?;
93            Arc::new(result)
94        }
95        BinaryOp::Or => {
96            let lhs_bool =
97                cast(&lhs_arr, &DataType::Boolean).map_err(|e| Error::Internal(e.to_string()))?;
98            let rhs_bool =
99                cast(&rhs_arr, &DataType::Boolean).map_err(|e| Error::Internal(e.to_string()))?;
100            let lhs_bool = lhs_bool
101                .as_any()
102                .downcast_ref::<arrow::array::BooleanArray>()
103                .unwrap();
104            let rhs_bool = rhs_bool
105                .as_any()
106                .downcast_ref::<arrow::array::BooleanArray>()
107                .unwrap();
108            let result = arrow::compute::kernels::boolean::or(lhs_bool, rhs_bool)
109                .map_err(|e| Error::Internal(e.to_string()))?;
110            Arc::new(result)
111        }
112        _ => return Err(Error::Internal(format!("Unsupported binary op: {:?}", op))),
113    };
114
115    Ok(result_arr)
116}
117
118pub fn get_common_type(lhs_type: &DataType, rhs_type: &DataType) -> DataType {
119    if lhs_type == rhs_type {
120        return lhs_type.clone();
121    }
122
123    match (lhs_type, rhs_type) {
124        (DataType::Null, other) | (other, DataType::Null) => other.clone(),
125        (DataType::Boolean, DataType::Boolean) => DataType::Boolean,
126        (DataType::Decimal128(lp, ls), DataType::Decimal128(rp, rs)) => {
127            coerce_decimals((*lp, *ls), (*rp, *rs))
128        }
129        (DataType::Decimal128(p, s), other) | (other, DataType::Decimal128(p, s)) => match other {
130            DataType::Float64 | DataType::Float32 => DataType::Float64,
131            DataType::Int8
132            | DataType::Int16
133            | DataType::Int32
134            | DataType::Int64
135            | DataType::UInt8
136            | DataType::UInt16
137            | DataType::UInt32
138            | DataType::UInt64 => coerce_decimals((*p, *s), (38u8, 0)),
139            _ => DataType::Float64,
140        },
141        _ => match (numeric_priority(lhs_type), numeric_priority(rhs_type)) {
142            (Some(Numeric::F64), _) | (_, Some(Numeric::F64)) => DataType::Float64,
143            (Some(Numeric::F32), _) | (_, Some(Numeric::F32)) => DataType::Float64,
144            (Some(Numeric::Signed(lhs)), Some(Numeric::Unsigned(rhs)))
145            | (Some(Numeric::Unsigned(lhs)), Some(Numeric::Signed(rhs))) => {
146                let max = std::cmp::max(lhs, rhs);
147                if max >= 64 {
148                    DataType::Float64
149                } else {
150                    DataType::Int64
151                }
152            }
153            (Some(Numeric::Signed(lhs)), Some(Numeric::Signed(rhs))) => {
154                if lhs >= 64 || rhs >= 64 {
155                    DataType::Int64
156                } else if lhs >= 32 || rhs >= 32 {
157                    DataType::Int32
158                } else if lhs >= 16 || rhs >= 16 {
159                    DataType::Int16
160                } else {
161                    DataType::Int8
162                }
163            }
164            (Some(Numeric::Unsigned(lhs)), Some(Numeric::Unsigned(rhs))) => {
165                if lhs >= 64 || rhs >= 64 {
166                    DataType::UInt64
167                } else if lhs >= 32 || rhs >= 32 {
168                    DataType::UInt32
169                } else if lhs >= 16 || rhs >= 16 {
170                    DataType::UInt16
171                } else {
172                    DataType::UInt8
173                }
174            }
175            _ => DataType::Float64,
176        },
177    }
178}
179
180/// Common type for a binary operator.
181pub fn common_type_for_op(lhs_type: &DataType, rhs_type: &DataType, _op: BinaryOp) -> DataType {
182    get_common_type(lhs_type, rhs_type)
183}
184
185pub fn coerce_types(
186    lhs: &ArrayRef,
187    rhs: &ArrayRef,
188    op: BinaryOp,
189) -> Result<(ArrayRef, ArrayRef), Error> {
190    let lhs_type = lhs.data_type();
191    let rhs_type = rhs.data_type();
192
193    let target_type = common_type_for_op(lhs_type, rhs_type, op);
194
195    if lhs_type == rhs_type && lhs_type == &target_type {
196        return Ok((lhs.clone(), rhs.clone()));
197    }
198
199    let lhs_casted = cast(lhs, &target_type).map_err(|e| Error::Internal(e.to_string()))?;
200    let rhs_casted = cast(rhs, &target_type).map_err(|e| Error::Internal(e.to_string()))?;
201
202    Ok((lhs_casted, rhs_casted))
203}
204
205pub fn compute_compare(lhs: &ArrayRef, op: CompareOp, rhs: &ArrayRef) -> Result<ArrayRef, Error> {
206    // Coerce inputs to common type for comparison
207    // We can reuse coerce_types logic or similar.
208    // For comparison, we usually want common type.
209    // We can pass a dummy BinaryOp to coerce_types.
210    let (lhs_arr, rhs_arr) = coerce_types(lhs, rhs, BinaryOp::Add)?;
211
212    let result_arr: ArrayRef = match op {
213        CompareOp::Eq => {
214            Arc::new(cmp::eq(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
215        }
216        CompareOp::NotEq => {
217            Arc::new(cmp::neq(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
218        }
219        CompareOp::Lt => {
220            Arc::new(cmp::lt(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
221        }
222        CompareOp::LtEq => {
223            Arc::new(cmp::lt_eq(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
224        }
225        CompareOp::Gt => {
226            Arc::new(cmp::gt(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
227        }
228        CompareOp::GtEq => {
229            Arc::new(cmp::gt_eq(&lhs_arr, &rhs_arr).map_err(|e| Error::Internal(e.to_string()))?)
230        }
231    };
232    Ok(result_arr)
233}