datafusion_comet_spark_expr/math_funcs/internal/
checkoverflow.rs1use arrow::datatypes::{DataType, Schema};
19use arrow::{
20 array::{as_primitive_array, Array, ArrayRef, Decimal128Array},
21 datatypes::{Decimal128Type, DecimalType},
22 record_batch::RecordBatch,
23};
24use datafusion::common::{DataFusionError, ScalarValue};
25use datafusion::logical_expr::ColumnarValue;
26use datafusion::physical_expr::PhysicalExpr;
27use std::hash::Hash;
28use std::{
29 any::Any,
30 fmt::{Display, Formatter},
31 sync::Arc,
32};
33
34#[derive(Debug, Eq)]
39pub struct CheckOverflow {
40 pub child: Arc<dyn PhysicalExpr>,
41 pub data_type: DataType,
42 pub fail_on_error: bool,
43}
44
45impl Hash for CheckOverflow {
46 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
47 self.child.hash(state);
48 self.data_type.hash(state);
49 self.fail_on_error.hash(state);
50 }
51}
52
53impl PartialEq for CheckOverflow {
54 fn eq(&self, other: &Self) -> bool {
55 self.child.eq(&other.child)
56 && self.data_type.eq(&other.data_type)
57 && self.fail_on_error.eq(&other.fail_on_error)
58 }
59}
60
61impl CheckOverflow {
62 pub fn new(child: Arc<dyn PhysicalExpr>, data_type: DataType, fail_on_error: bool) -> Self {
63 Self {
64 child,
65 data_type,
66 fail_on_error,
67 }
68 }
69}
70
71impl Display for CheckOverflow {
72 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
73 write!(
74 f,
75 "CheckOverflow [datatype: {}, fail_on_error: {}, child: {}]",
76 self.data_type, self.fail_on_error, self.child
77 )
78 }
79}
80
81impl PhysicalExpr for CheckOverflow {
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
87 unimplemented!()
88 }
89
90 fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
91 Ok(self.data_type.clone())
92 }
93
94 fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
95 Ok(true)
96 }
97
98 fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
99 let arg = self.child.evaluate(batch)?;
100 match arg {
101 ColumnarValue::Array(array)
102 if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
103 {
104 let (precision, scale) = match &self.data_type {
105 DataType::Decimal128(p, s) => (p, s),
106 dt => {
107 return Err(DataFusionError::Execution(format!(
108 "CheckOverflow expects only Decimal128, but got {dt:?}"
109 )))
110 }
111 };
112
113 let decimal_array = as_primitive_array::<Decimal128Type>(&array);
114
115 let casted_array = if self.fail_on_error {
116 decimal_array.validate_decimal_precision(*precision)?;
118 decimal_array
119 } else {
120 &decimal_array.null_if_overflow_precision(*precision)
122 };
123
124 let new_array = Decimal128Array::from(casted_array.into_data())
125 .with_precision_and_scale(*precision, *scale)
126 .map(|a| Arc::new(a) as ArrayRef)?;
127
128 Ok(ColumnarValue::Array(new_array))
129 }
130 ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
131 assert!(
134 !self.fail_on_error,
135 "fail_on_error (ANSI mode) is not supported yet"
136 );
137
138 let new_v: Option<i128> = v.and_then(|v| {
139 Decimal128Type::validate_decimal_precision(v, precision)
140 .map(|_| v)
141 .ok()
142 });
143
144 Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
145 new_v, precision, scale,
146 )))
147 }
148 v => Err(DataFusionError::Execution(format!(
149 "CheckOverflow's child expression should be decimal array, but found {v:?}"
150 ))),
151 }
152 }
153
154 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
155 vec![&self.child]
156 }
157
158 fn with_new_children(
159 self: Arc<Self>,
160 children: Vec<Arc<dyn PhysicalExpr>>,
161 ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
162 Ok(Arc::new(CheckOverflow::new(
163 Arc::clone(&children[0]),
164 self.data_type.clone(),
165 self.fail_on_error,
166 )))
167 }
168}