datafusion_comet_spark_expr/math_funcs/internal/
checkoverflow.rs1use arrow::datatypes::{DataType, Schema};
19use arrow::{
20 array::{as_primitive_array, Array, ArrayRef, Decimal128Array},
21 datatypes::{Decimal128Type, DecimalType},
22 record_batch::RecordBatch,
23};
24use datafusion::common::{DataFusionError, ScalarValue};
25use datafusion::logical_expr::ColumnarValue;
26use datafusion::physical_expr::PhysicalExpr;
27use std::hash::Hash;
28use std::{
29 any::Any,
30 fmt::{Display, Formatter},
31 sync::Arc,
32};
33
34#[derive(Debug, Eq)]
39pub struct CheckOverflow {
40 pub child: Arc<dyn PhysicalExpr>,
41 pub data_type: DataType,
42 pub fail_on_error: bool,
43}
44
45impl Hash for CheckOverflow {
46 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
47 self.child.hash(state);
48 self.data_type.hash(state);
49 self.fail_on_error.hash(state);
50 }
51}
52
53impl PartialEq for CheckOverflow {
54 fn eq(&self, other: &Self) -> bool {
55 self.child.eq(&other.child)
56 && self.data_type.eq(&other.data_type)
57 && self.fail_on_error.eq(&other.fail_on_error)
58 }
59}
60
61impl CheckOverflow {
62 pub fn new(child: Arc<dyn PhysicalExpr>, data_type: DataType, fail_on_error: bool) -> Self {
63 Self {
64 child,
65 data_type,
66 fail_on_error,
67 }
68 }
69}
70
71impl Display for CheckOverflow {
72 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
73 write!(
74 f,
75 "CheckOverflow [datatype: {}, fail_on_error: {}, child: {}]",
76 self.data_type, self.fail_on_error, self.child
77 )
78 }
79}
80
81impl PhysicalExpr for CheckOverflow {
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
87 unimplemented!()
88 }
89
90 fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
91 Ok(self.data_type.clone())
92 }
93
94 fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
95 Ok(true)
96 }
97
98 fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
99 let arg = self.child.evaluate(batch)?;
100 match arg {
101 ColumnarValue::Array(array)
102 if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
103 {
104 let (precision, scale) = match &self.data_type {
105 DataType::Decimal128(p, s) => (p, s),
106 dt => {
107 return Err(DataFusionError::Execution(format!(
108 "CheckOverflow expects only Decimal128, but got {:?}",
109 dt
110 )))
111 }
112 };
113
114 let decimal_array = as_primitive_array::<Decimal128Type>(&array);
115
116 let casted_array = if self.fail_on_error {
117 decimal_array.validate_decimal_precision(*precision)?;
119 decimal_array
120 } else {
121 &decimal_array.null_if_overflow_precision(*precision)
123 };
124
125 let new_array = Decimal128Array::from(casted_array.into_data())
126 .with_precision_and_scale(*precision, *scale)
127 .map(|a| Arc::new(a) as ArrayRef)?;
128
129 Ok(ColumnarValue::Array(new_array))
130 }
131 ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
132 assert!(
135 !self.fail_on_error,
136 "fail_on_error (ANSI mode) is not supported yet"
137 );
138
139 let new_v: Option<i128> = v.and_then(|v| {
140 Decimal128Type::validate_decimal_precision(v, precision)
141 .map(|_| v)
142 .ok()
143 });
144
145 Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
146 new_v, precision, scale,
147 )))
148 }
149 v => Err(DataFusionError::Execution(format!(
150 "CheckOverflow's child expression should be decimal array, but found {:?}",
151 v
152 ))),
153 }
154 }
155
156 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
157 vec![&self.child]
158 }
159
160 fn with_new_children(
161 self: Arc<Self>,
162 children: Vec<Arc<dyn PhysicalExpr>>,
163 ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
164 Ok(Arc::new(CheckOverflow::new(
165 Arc::clone(&children[0]),
166 self.data_type.clone(),
167 self.fail_on_error,
168 )))
169 }
170}