datafusion_comet_spark_expr/math_funcs/internal/
checkoverflow.rs1use arrow::{
19 array::{as_primitive_array, Array, ArrayRef, Decimal128Array},
20 datatypes::{Decimal128Type, DecimalType},
21 record_batch::RecordBatch,
22};
23use arrow_schema::{DataType, Schema};
24use datafusion::logical_expr::ColumnarValue;
25use datafusion_common::{DataFusionError, ScalarValue};
26use datafusion_physical_expr::PhysicalExpr;
27use std::hash::Hash;
28use std::{
29 any::Any,
30 fmt::{Display, Formatter},
31 sync::Arc,
32};
33
34#[derive(Debug, Eq)]
39pub struct CheckOverflow {
40 pub child: Arc<dyn PhysicalExpr>,
41 pub data_type: DataType,
42 pub fail_on_error: bool,
43}
44
45impl Hash for CheckOverflow {
46 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
47 self.child.hash(state);
48 self.data_type.hash(state);
49 self.fail_on_error.hash(state);
50 }
51}
52
53impl PartialEq for CheckOverflow {
54 fn eq(&self, other: &Self) -> bool {
55 self.child.eq(&other.child)
56 && self.data_type.eq(&other.data_type)
57 && self.fail_on_error.eq(&other.fail_on_error)
58 }
59}
60
61impl CheckOverflow {
62 pub fn new(child: Arc<dyn PhysicalExpr>, data_type: DataType, fail_on_error: bool) -> Self {
63 Self {
64 child,
65 data_type,
66 fail_on_error,
67 }
68 }
69}
70
71impl Display for CheckOverflow {
72 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
73 write!(
74 f,
75 "CheckOverflow [datatype: {}, fail_on_error: {}, child: {}]",
76 self.data_type, self.fail_on_error, self.child
77 )
78 }
79}
80
81impl PhysicalExpr for CheckOverflow {
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn data_type(&self, _: &Schema) -> datafusion_common::Result<DataType> {
87 Ok(self.data_type.clone())
88 }
89
90 fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
91 Ok(true)
92 }
93
94 fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
95 let arg = self.child.evaluate(batch)?;
96 match arg {
97 ColumnarValue::Array(array)
98 if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
99 {
100 let (precision, scale) = match &self.data_type {
101 DataType::Decimal128(p, s) => (p, s),
102 dt => {
103 return Err(DataFusionError::Execution(format!(
104 "CheckOverflow expects only Decimal128, but got {:?}",
105 dt
106 )))
107 }
108 };
109
110 let decimal_array = as_primitive_array::<Decimal128Type>(&array);
111
112 let casted_array = if self.fail_on_error {
113 decimal_array.validate_decimal_precision(*precision)?;
115 decimal_array
116 } else {
117 &decimal_array.null_if_overflow_precision(*precision)
119 };
120
121 let new_array = Decimal128Array::from(casted_array.into_data())
122 .with_precision_and_scale(*precision, *scale)
123 .map(|a| Arc::new(a) as ArrayRef)?;
124
125 Ok(ColumnarValue::Array(new_array))
126 }
127 ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
128 assert!(
131 !self.fail_on_error,
132 "fail_on_error (ANSI mode) is not supported yet"
133 );
134
135 let new_v: Option<i128> = v.and_then(|v| {
136 Decimal128Type::validate_decimal_precision(v, precision)
137 .map(|_| v)
138 .ok()
139 });
140
141 Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
142 new_v, precision, scale,
143 )))
144 }
145 v => Err(DataFusionError::Execution(format!(
146 "CheckOverflow's child expression should be decimal array, but found {:?}",
147 v
148 ))),
149 }
150 }
151
152 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
153 vec![&self.child]
154 }
155
156 fn with_new_children(
157 self: Arc<Self>,
158 children: Vec<Arc<dyn PhysicalExpr>>,
159 ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
160 Ok(Arc::new(CheckOverflow::new(
161 Arc::clone(&children[0]),
162 self.data_type.clone(),
163 self.fail_on_error,
164 )))
165 }
166}