datafusion_physical_expr/expressions/
negative.rs1use std::any::Any;
21use std::hash::Hash;
22use std::sync::Arc;
23
24use crate::PhysicalExpr;
25
26use arrow::{
27 compute::kernels::numeric::neg_wrapping,
28 datatypes::{DataType, Schema},
29 record_batch::RecordBatch,
30};
31use datafusion_common::{internal_err, plan_err, Result};
32use datafusion_expr::interval_arithmetic::Interval;
33use datafusion_expr::sort_properties::ExprProperties;
34use datafusion_expr::statistics::Distribution::{
35 self, Bernoulli, Exponential, Gaussian, Generic, Uniform,
36};
37use datafusion_expr::{
38 type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp},
39 ColumnarValue,
40};
41
42#[derive(Debug, Eq)]
44pub struct NegativeExpr {
45 arg: Arc<dyn PhysicalExpr>,
47}
48
49impl PartialEq for NegativeExpr {
51 fn eq(&self, other: &Self) -> bool {
52 self.arg.eq(&other.arg)
53 }
54}
55
56impl Hash for NegativeExpr {
57 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
58 self.arg.hash(state);
59 }
60}
61
62impl NegativeExpr {
63 pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
65 Self { arg }
66 }
67
68 pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
70 &self.arg
71 }
72}
73
74impl std::fmt::Display for NegativeExpr {
75 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
76 write!(f, "(- {})", self.arg)
77 }
78}
79
80impl PhysicalExpr for NegativeExpr {
81 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
87 self.arg.data_type(input_schema)
88 }
89
90 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
91 self.arg.nullable(input_schema)
92 }
93
94 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
95 match self.arg.evaluate(batch)? {
96 ColumnarValue::Array(array) => {
97 let result = neg_wrapping(array.as_ref())?;
98 Ok(ColumnarValue::Array(result))
99 }
100 ColumnarValue::Scalar(scalar) => {
101 Ok(ColumnarValue::Scalar(scalar.arithmetic_negate()?))
102 }
103 }
104 }
105
106 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
107 vec![&self.arg]
108 }
109
110 fn with_new_children(
111 self: Arc<Self>,
112 children: Vec<Arc<dyn PhysicalExpr>>,
113 ) -> Result<Arc<dyn PhysicalExpr>> {
114 Ok(Arc::new(NegativeExpr::new(Arc::clone(&children[0]))))
115 }
116
117 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
121 children[0].arithmetic_negate()
122 }
123
124 fn propagate_constraints(
127 &self,
128 interval: &Interval,
129 children: &[&Interval],
130 ) -> Result<Option<Vec<Interval>>> {
131 let negated_interval = interval.arithmetic_negate()?;
132
133 Ok(children[0]
134 .intersect(negated_interval)?
135 .map(|result| vec![result]))
136 }
137
138 fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
139 match children[0] {
140 Uniform(u) => Distribution::new_uniform(u.range().arithmetic_negate()?),
141 Exponential(e) => Distribution::new_exponential(
142 e.rate().clone(),
143 e.offset().arithmetic_negate()?,
144 !e.positive_tail(),
145 ),
146 Gaussian(g) => Distribution::new_gaussian(
147 g.mean().arithmetic_negate()?,
148 g.variance().clone(),
149 ),
150 Bernoulli(_) => {
151 internal_err!("NegativeExpr cannot operate on Boolean datatypes")
152 }
153 Generic(u) => Distribution::new_generic(
154 u.mean().arithmetic_negate()?,
155 u.median().arithmetic_negate()?,
156 u.variance().clone(),
157 u.range().arithmetic_negate()?,
158 ),
159 }
160 }
161
162 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
164 Ok(ExprProperties {
165 sort_properties: -children[0].sort_properties,
166 range: children[0].range.clone().arithmetic_negate()?,
167 preserves_lex_ordering: false,
168 })
169 }
170
171 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 write!(f, "(- ")?;
173 self.arg.fmt_sql(f)?;
174 write!(f, ")")
175 }
176}
177
178pub fn negative(
184 arg: Arc<dyn PhysicalExpr>,
185 input_schema: &Schema,
186) -> Result<Arc<dyn PhysicalExpr>> {
187 let data_type = arg.data_type(input_schema)?;
188 if is_null(&data_type) {
189 Ok(arg)
190 } else if !is_signed_numeric(&data_type)
191 && !is_interval(&data_type)
192 && !is_timestamp(&data_type)
193 {
194 plan_err!("Negation only supports numeric, interval and timestamp types")
195 } else {
196 Ok(Arc::new(NegativeExpr::new(arg)))
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, Column};

    use arrow::array::*;
    use arrow::datatypes::DataType::{Float32, Float64, Int16, Int32, Int64, Int8};
    use arrow::datatypes::*;
    use datafusion_common::cast::as_primitive_array;
    use datafusion_common::{DataFusionError, ScalarValue};

    use datafusion_physical_expr_common::physical_expr::fmt_sql;
    use paste::paste;

    /// Negates a nullable `$DATA_TY` column holding the given values followed by
    /// a null, and checks the output element-wise against `-$VALUE` (null stays
    /// null).
    macro_rules! test_array_negative_op {
        ($DATA_TY:tt, $($VALUE:expr),* ) => {
            let schema = Schema::new(vec![Field::new("a", DataType::$DATA_TY, true)]);
            let expr = negative(col("a", &schema)?, &schema)?;
            assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TY);
            assert!(expr.nullable(&schema)?);
            let mut arr = Vec::new();
            let mut arr_expected = Vec::new();
            $(
                arr.push(Some($VALUE));
                arr_expected.push(Some(-$VALUE));
            )*
            arr.push(None);
            arr_expected.push(None);
            let input = paste!{[<$DATA_TY Array>]::from(arr)};
            let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)};
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
            let result = expr
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");
            // Build the panic message lazily; `expect(format!(..))` would allocate
            // it even when the downcast succeeds.
            let result = as_primitive_array(&result).unwrap_or_else(|_| {
                panic!("failed to downcast to {:?}Array", $DATA_TY)
            });
            assert_eq!(result, expected);
        };
    }

    #[test]
    fn array_negative_op() -> Result<()> {
        test_array_negative_op!(Int8, 2i8, 1i8);
        test_array_negative_op!(Int16, 234i16, 123i16);
        test_array_negative_op!(Int32, 2345i32, 1234i32);
        test_array_negative_op!(Int64, 23456i64, 12345i64);
        test_array_negative_op!(Float32, 2345.0f32, 1234.0f32);
        test_array_negative_op!(Float64, 23456.0f64, 12345.0f64);
        Ok(())
    }

    #[test]
    fn test_evaluate_bounds() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let child_interval = Interval::make(Some(-2), Some(1))?;
        let negative_expr_interval = Interval::make(Some(-1), Some(2))?;
        assert_eq!(
            negative_expr.evaluate_bounds(&[&child_interval])?,
            negative_expr_interval
        );
        Ok(())
    }

    #[test]
    fn test_evaluate_statistics() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));

        // Uniform
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_uniform(
                Interval::make(Some(-2.), Some(3.))?
            )?])?,
            Distribution::new_uniform(Interval::make(Some(-3.), Some(2.))?)?
        );

        // Bernoulli
        assert!(negative_expr
            .evaluate_statistics(&[&Distribution::new_bernoulli(ScalarValue::from(
                0.75
            ))?])
            .is_err());

        // Exponential
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_exponential(
                ScalarValue::from(1.),
                ScalarValue::from(1.),
                true
            )?])?,
            Distribution::new_exponential(
                ScalarValue::from(1.),
                ScalarValue::from(-1.),
                false
            )?
        );

        // Gaussian
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_gaussian(
                ScalarValue::from(15),
                ScalarValue::from(225),
            )?])?,
            Distribution::new_gaussian(ScalarValue::from(-15), ScalarValue::from(225),)?
        );

        // Generic
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_generic(
                ScalarValue::from(15),
                ScalarValue::from(15),
                ScalarValue::from(10),
                Interval::make(Some(10), Some(20))?
            )?])?,
            Distribution::new_generic(
                ScalarValue::from(-15),
                ScalarValue::from(-15),
                ScalarValue::from(10),
                Interval::make(Some(-20), Some(-10))?
            )?
        );

        Ok(())
    }

    #[test]
    fn test_propagate_constraints() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let original_child_interval = Interval::make(Some(-2), Some(3))?;
        let negative_expr_interval = Interval::make(Some(0), Some(4))?;
        let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]);
        assert_eq!(
            negative_expr.propagate_constraints(
                &negative_expr_interval,
                &[&original_child_interval]
            )?,
            after_propagation
        );
        Ok(())
    }

    #[test]
    fn test_propagate_statistics_range_holders() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let original_child_interval = Interval::make(Some(-2), Some(3))?;
        let after_propagation = Interval::make(Some(-2), Some(0))?;

        let parent = Distribution::new_uniform(Interval::make(Some(0), Some(4))?)?;
        let children: Vec<Vec<Distribution>> = vec![
            vec![Distribution::new_uniform(original_child_interval.clone())?],
            vec![Distribution::new_generic(
                ScalarValue::from(0),
                ScalarValue::from(0),
                ScalarValue::Int32(None),
                original_child_interval.clone(),
            )?],
        ];

        for child_view in children {
            let child_refs: Vec<_> = child_view.iter().collect();
            let actual = negative_expr.propagate_statistics(&parent, &child_refs)?;
            let expected = Some(vec![Distribution::new_from_interval(
                after_propagation.clone(),
            )?]);
            assert_eq!(actual, expected);
        }

        Ok(())
    }

    #[test]
    fn test_negation_valid_types() -> Result<()> {
        let negatable_types = [
            Int8,
            DataType::Timestamp(TimeUnit::Second, None),
            DataType::Interval(IntervalUnit::YearMonth),
        ];
        for negatable_type in negatable_types {
            let schema = Schema::new(vec![Field::new("a", negatable_type, true)]);
            let _expr = negative(col("a", &schema)?, &schema)?;
        }
        Ok(())
    }

    #[test]
    fn test_negation_invalid_types() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let err = negative(col("a", &schema)?, &schema).unwrap_err();
        // A bare `matches!` only evaluates to a bool. It must be asserted, or the
        // test can never fail when the wrong error variant is returned.
        assert!(
            matches!(err, DataFusionError::Plan(_)),
            "expected a planning error, got {err:?}"
        );
        Ok(())
    }

    #[test]
    fn test_fmt_sql() -> Result<()> {
        let expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let display_string = expr.to_string();
        assert_eq!(display_string, "(- a@0)");
        let sql_string = fmt_sql(&expr).to_string();
        assert_eq!(sql_string, "(- a)");

        Ok(())
    }
}