datafusion_physical_expr/expressions/
negative.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Negation (-) expression
19
20use std::any::Any;
21use std::hash::Hash;
22use std::sync::Arc;
23
24use crate::PhysicalExpr;
25
26use arrow::{
27    compute::kernels::numeric::neg_wrapping,
28    datatypes::{DataType, Schema},
29    record_batch::RecordBatch,
30};
31use datafusion_common::{internal_err, plan_err, Result};
32use datafusion_expr::interval_arithmetic::Interval;
33use datafusion_expr::sort_properties::ExprProperties;
34use datafusion_expr::statistics::Distribution::{
35    self, Bernoulli, Exponential, Gaussian, Generic, Uniform,
36};
37use datafusion_expr::{
38    type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp},
39    ColumnarValue,
40};
41
42/// Negative expression
43#[derive(Debug, Eq)]
44pub struct NegativeExpr {
45    /// Input expression
46    arg: Arc<dyn PhysicalExpr>,
47}
48
49// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
50impl PartialEq for NegativeExpr {
51    fn eq(&self, other: &Self) -> bool {
52        self.arg.eq(&other.arg)
53    }
54}
55
56impl Hash for NegativeExpr {
57    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
58        self.arg.hash(state);
59    }
60}
61
62impl NegativeExpr {
63    /// Create new not expression
64    pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
65        Self { arg }
66    }
67
68    /// Get the input expression
69    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
70        &self.arg
71    }
72}
73
74impl std::fmt::Display for NegativeExpr {
75    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
76        write!(f, "(- {})", self.arg)
77    }
78}
79
80impl PhysicalExpr for NegativeExpr {
81    /// Return a reference to Any that can be used for downcasting
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85
86    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
87        self.arg.data_type(input_schema)
88    }
89
90    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
91        self.arg.nullable(input_schema)
92    }
93
94    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
95        match self.arg.evaluate(batch)? {
96            ColumnarValue::Array(array) => {
97                let result = neg_wrapping(array.as_ref())?;
98                Ok(ColumnarValue::Array(result))
99            }
100            ColumnarValue::Scalar(scalar) => {
101                Ok(ColumnarValue::Scalar(scalar.arithmetic_negate()?))
102            }
103        }
104    }
105
106    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
107        vec![&self.arg]
108    }
109
110    fn with_new_children(
111        self: Arc<Self>,
112        children: Vec<Arc<dyn PhysicalExpr>>,
113    ) -> Result<Arc<dyn PhysicalExpr>> {
114        Ok(Arc::new(NegativeExpr::new(Arc::clone(&children[0]))))
115    }
116
117    /// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval.
118    /// It replaces the upper and lower bounds after multiplying them with -1.
119    /// Ex: `(a, b]` => `[-b, -a)`
120    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
121        children[0].arithmetic_negate()
122    }
123
124    /// Returns a new [`Interval`] of a NegativeExpr  that has the existing `interval` given that
125    /// given the input interval is known to be `children`.
126    fn propagate_constraints(
127        &self,
128        interval: &Interval,
129        children: &[&Interval],
130    ) -> Result<Option<Vec<Interval>>> {
131        let negated_interval = interval.arithmetic_negate()?;
132
133        Ok(children[0]
134            .intersect(negated_interval)?
135            .map(|result| vec![result]))
136    }
137
138    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
139        match children[0] {
140            Uniform(u) => Distribution::new_uniform(u.range().arithmetic_negate()?),
141            Exponential(e) => Distribution::new_exponential(
142                e.rate().clone(),
143                e.offset().arithmetic_negate()?,
144                !e.positive_tail(),
145            ),
146            Gaussian(g) => Distribution::new_gaussian(
147                g.mean().arithmetic_negate()?,
148                g.variance().clone(),
149            ),
150            Bernoulli(_) => {
151                internal_err!("NegativeExpr cannot operate on Boolean datatypes")
152            }
153            Generic(u) => Distribution::new_generic(
154                u.mean().arithmetic_negate()?,
155                u.median().arithmetic_negate()?,
156                u.variance().clone(),
157                u.range().arithmetic_negate()?,
158            ),
159        }
160    }
161
162    /// The ordering of a [`NegativeExpr`] is simply the reverse of its child.
163    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
164        Ok(ExprProperties {
165            sort_properties: -children[0].sort_properties,
166            range: children[0].range.clone().arithmetic_negate()?,
167            preserves_lex_ordering: false,
168        })
169    }
170
171    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        write!(f, "(- ")?;
173        self.arg.fmt_sql(f)?;
174        write!(f, ")")
175    }
176}
177
178/// Creates a unary expression NEGATIVE
179///
180/// # Errors
181///
182/// This function errors when the argument's type is not signed numeric
183pub fn negative(
184    arg: Arc<dyn PhysicalExpr>,
185    input_schema: &Schema,
186) -> Result<Arc<dyn PhysicalExpr>> {
187    let data_type = arg.data_type(input_schema)?;
188    if is_null(&data_type) {
189        Ok(arg)
190    } else if !is_signed_numeric(&data_type)
191        && !is_interval(&data_type)
192        && !is_timestamp(&data_type)
193    {
194        plan_err!("Negation only supports numeric, interval and timestamp types")
195    } else {
196        Ok(Arc::new(NegativeExpr::new(arg)))
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, Column};

    use arrow::array::*;
    use arrow::datatypes::DataType::{Float32, Float64, Int16, Int32, Int64, Int8};
    use arrow::datatypes::*;
    use datafusion_common::cast::as_primitive_array;
    use datafusion_common::{DataFusionError, ScalarValue};

    use datafusion_physical_expr_common::physical_expr::fmt_sql;
    use paste::paste;

    /// Builds a nullable column of `$DATA_TY` from the given values plus a
    /// trailing null, negates it, and checks every value is negated and the
    /// null is preserved.
    macro_rules! test_array_negative_op {
        ($DATA_TY:tt, $($VALUE:expr),*   ) => {
            let schema = Schema::new(vec![Field::new("a", DataType::$DATA_TY, true)]);
            let expr = negative(col("a", &schema)?, &schema)?;
            assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TY);
            assert!(expr.nullable(&schema)?);
            let mut arr = Vec::new();
            let mut arr_expected = Vec::new();
            $(
                arr.push(Some($VALUE));
                arr_expected.push(Some(-$VALUE));
            )+
            arr.push(None);
            arr_expected.push(None);
            let input = paste!{[<$DATA_TY Array>]::from(arr)};
            let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)};
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
            let result = expr
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");
            // Build the panic message lazily, only if the downcast fails.
            let result = as_primitive_array(&result).unwrap_or_else(|_| {
                panic!("failed to downcast to {:?}Array", $DATA_TY)
            });
            assert_eq!(result, expected);
        };
    }

    #[test]
    fn array_negative_op() -> Result<()> {
        test_array_negative_op!(Int8, 2i8, 1i8);
        test_array_negative_op!(Int16, 234i16, 123i16);
        test_array_negative_op!(Int32, 2345i32, 1234i32);
        test_array_negative_op!(Int64, 23456i64, 12345i64);
        test_array_negative_op!(Float32, 2345.0f32, 1234.0f32);
        test_array_negative_op!(Float64, 23456.0f64, 12345.0f64);
        Ok(())
    }

    #[test]
    fn test_evaluate_bounds() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let child_interval = Interval::make(Some(-2), Some(1))?;
        let negative_expr_interval = Interval::make(Some(-1), Some(2))?;
        assert_eq!(
            negative_expr.evaluate_bounds(&[&child_interval])?,
            negative_expr_interval
        );
        Ok(())
    }

    #[test]
    fn test_evaluate_statistics() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));

        // Uniform
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_uniform(
                Interval::make(Some(-2.), Some(3.))?
            )?])?,
            Distribution::new_uniform(Interval::make(Some(-3.), Some(2.))?)?
        );

        // Bernoulli
        assert!(negative_expr
            .evaluate_statistics(&[&Distribution::new_bernoulli(ScalarValue::from(
                0.75
            ))?])
            .is_err());

        // Exponential
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_exponential(
                ScalarValue::from(1.),
                ScalarValue::from(1.),
                true
            )?])?,
            Distribution::new_exponential(
                ScalarValue::from(1.),
                ScalarValue::from(-1.),
                false
            )?
        );

        // Gaussian
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_gaussian(
                ScalarValue::from(15),
                ScalarValue::from(225),
            )?])?,
            Distribution::new_gaussian(ScalarValue::from(-15), ScalarValue::from(225),)?
        );

        // Unknown
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_generic(
                ScalarValue::from(15),
                ScalarValue::from(15),
                ScalarValue::from(10),
                Interval::make(Some(10), Some(20))?
            )?])?,
            Distribution::new_generic(
                ScalarValue::from(-15),
                ScalarValue::from(-15),
                ScalarValue::from(10),
                Interval::make(Some(-20), Some(-10))?
            )?
        );

        Ok(())
    }

    #[test]
    fn test_propagate_constraints() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let original_child_interval = Interval::make(Some(-2), Some(3))?;
        let negative_expr_interval = Interval::make(Some(0), Some(4))?;
        let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]);
        assert_eq!(
            negative_expr.propagate_constraints(
                &negative_expr_interval,
                &[&original_child_interval]
            )?,
            after_propagation
        );
        Ok(())
    }

    #[test]
    fn test_propagate_statistics_range_holders() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let original_child_interval = Interval::make(Some(-2), Some(3))?;
        let after_propagation = Interval::make(Some(-2), Some(0))?;

        let parent = Distribution::new_uniform(Interval::make(Some(0), Some(4))?)?;
        let children: Vec<Vec<Distribution>> = vec![
            vec![Distribution::new_uniform(original_child_interval.clone())?],
            vec![Distribution::new_generic(
                ScalarValue::from(0),
                ScalarValue::from(0),
                ScalarValue::Int32(None),
                original_child_interval.clone(),
            )?],
        ];

        for child_view in children {
            let child_refs: Vec<_> = child_view.iter().collect();
            let actual = negative_expr.propagate_statistics(&parent, &child_refs)?;
            let expected = Some(vec![Distribution::new_from_interval(
                after_propagation.clone(),
            )?]);
            assert_eq!(actual, expected);
        }

        Ok(())
    }

    #[test]
    fn test_negation_valid_types() -> Result<()> {
        let negatable_types = [
            Int8,
            DataType::Timestamp(TimeUnit::Second, None),
            DataType::Interval(IntervalUnit::YearMonth),
        ];
        for negatable_type in negatable_types {
            let schema = Schema::new(vec![Field::new("a", negatable_type, true)]);
            let _expr = negative(col("a", &schema)?, &schema)?;
        }
        Ok(())
    }

    #[test]
    fn test_negation_null_type_returns_arg() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Null, true)]);
        let arg = col("a", &schema)?;
        let expr = negative(Arc::clone(&arg), &schema)?;
        // A null-typed argument is passed through rather than wrapped.
        assert!(expr.as_any().downcast_ref::<NegativeExpr>().is_none());
        assert!(Arc::ptr_eq(&arg, &expr));
        Ok(())
    }

    #[test]
    fn test_negation_invalid_types() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let err = negative(col("a", &schema)?, &schema).unwrap_err();
        // `matches!` alone only evaluates to a bool; it must be asserted.
        assert!(
            matches!(err, DataFusionError::Plan(_)),
            "expected a plan error, got {:?}",
            err
        );
        Ok(())
    }

    #[test]
    fn test_fmt_sql() -> Result<()> {
        let expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let display_string = expr.to_string();
        assert_eq!(display_string, "(- a@0)");
        let sql_string = fmt_sql(&expr).to_string();
        assert_eq!(sql_string, "(- a)");

        Ok(())
    }
}