datafusion_physical_expr/expressions/
negative.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Negation (-) expression
19
20use std::any::Any;
21use std::hash::Hash;
22use std::sync::Arc;
23
24use crate::PhysicalExpr;
25
26use arrow::datatypes::FieldRef;
27use arrow::{
28    compute::kernels::numeric::neg_wrapping,
29    datatypes::{DataType, Schema},
30    record_batch::RecordBatch,
31};
32use datafusion_common::{internal_err, plan_err, Result};
33use datafusion_expr::interval_arithmetic::Interval;
34use datafusion_expr::sort_properties::ExprProperties;
35use datafusion_expr::statistics::Distribution::{
36    self, Bernoulli, Exponential, Gaussian, Generic, Uniform,
37};
38use datafusion_expr::{
39    type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp},
40    ColumnarValue,
41};
42
43/// Negative expression
44#[derive(Debug, Eq)]
45pub struct NegativeExpr {
46    /// Input expression
47    arg: Arc<dyn PhysicalExpr>,
48}
49
50// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
51impl PartialEq for NegativeExpr {
52    fn eq(&self, other: &Self) -> bool {
53        self.arg.eq(&other.arg)
54    }
55}
56
57impl Hash for NegativeExpr {
58    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
59        self.arg.hash(state);
60    }
61}
62
63impl NegativeExpr {
64    /// Create new not expression
65    pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
66        Self { arg }
67    }
68
69    /// Get the input expression
70    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
71        &self.arg
72    }
73}
74
75impl std::fmt::Display for NegativeExpr {
76    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
77        write!(f, "(- {})", self.arg)
78    }
79}
80
81impl PhysicalExpr for NegativeExpr {
82    /// Return a reference to Any that can be used for downcasting
83    fn as_any(&self) -> &dyn Any {
84        self
85    }
86
87    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
88        self.arg.data_type(input_schema)
89    }
90
91    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
92        self.arg.nullable(input_schema)
93    }
94
95    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
96        match self.arg.evaluate(batch)? {
97            ColumnarValue::Array(array) => {
98                let result = neg_wrapping(array.as_ref())?;
99                Ok(ColumnarValue::Array(result))
100            }
101            ColumnarValue::Scalar(scalar) => {
102                Ok(ColumnarValue::Scalar(scalar.arithmetic_negate()?))
103            }
104        }
105    }
106
107    fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
108        self.arg.return_field(input_schema)
109    }
110
111    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
112        vec![&self.arg]
113    }
114
115    fn with_new_children(
116        self: Arc<Self>,
117        children: Vec<Arc<dyn PhysicalExpr>>,
118    ) -> Result<Arc<dyn PhysicalExpr>> {
119        Ok(Arc::new(NegativeExpr::new(Arc::clone(&children[0]))))
120    }
121
122    /// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval.
123    /// It replaces the upper and lower bounds after multiplying them with -1.
124    /// Ex: `(a, b]` => `[-b, -a)`
125    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
126        children[0].arithmetic_negate()
127    }
128
129    /// Returns a new [`Interval`] of a NegativeExpr  that has the existing `interval` given that
130    /// given the input interval is known to be `children`.
131    fn propagate_constraints(
132        &self,
133        interval: &Interval,
134        children: &[&Interval],
135    ) -> Result<Option<Vec<Interval>>> {
136        let negated_interval = interval.arithmetic_negate()?;
137
138        Ok(children[0]
139            .intersect(negated_interval)?
140            .map(|result| vec![result]))
141    }
142
143    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
144        match children[0] {
145            Uniform(u) => Distribution::new_uniform(u.range().arithmetic_negate()?),
146            Exponential(e) => Distribution::new_exponential(
147                e.rate().clone(),
148                e.offset().arithmetic_negate()?,
149                !e.positive_tail(),
150            ),
151            Gaussian(g) => Distribution::new_gaussian(
152                g.mean().arithmetic_negate()?,
153                g.variance().clone(),
154            ),
155            Bernoulli(_) => {
156                internal_err!("NegativeExpr cannot operate on Boolean datatypes")
157            }
158            Generic(u) => Distribution::new_generic(
159                u.mean().arithmetic_negate()?,
160                u.median().arithmetic_negate()?,
161                u.variance().clone(),
162                u.range().arithmetic_negate()?,
163            ),
164        }
165    }
166
167    /// The ordering of a [`NegativeExpr`] is simply the reverse of its child.
168    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
169        Ok(ExprProperties {
170            sort_properties: -children[0].sort_properties,
171            range: children[0].range.clone().arithmetic_negate()?,
172            preserves_lex_ordering: false,
173        })
174    }
175
176    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177        write!(f, "(- ")?;
178        self.arg.fmt_sql(f)?;
179        write!(f, ")")
180    }
181}
182
183/// Creates a unary expression NEGATIVE
184///
185/// # Errors
186///
187/// This function errors when the argument's type is not signed numeric
188pub fn negative(
189    arg: Arc<dyn PhysicalExpr>,
190    input_schema: &Schema,
191) -> Result<Arc<dyn PhysicalExpr>> {
192    let data_type = arg.data_type(input_schema)?;
193    if is_null(&data_type) {
194        Ok(arg)
195    } else if !is_signed_numeric(&data_type)
196        && !is_interval(&data_type)
197        && !is_timestamp(&data_type)
198    {
199        plan_err!("Negation only supports numeric, interval and timestamp types")
200    } else {
201        Ok(Arc::new(NegativeExpr::new(arg)))
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, Column};

    use arrow::array::*;
    use arrow::datatypes::DataType::{Float32, Float64, Int16, Int32, Int64, Int8};
    use arrow::datatypes::*;
    use datafusion_common::cast::as_primitive_array;
    use datafusion_common::{DataFusionError, ScalarValue};

    use datafusion_physical_expr_common::physical_expr::fmt_sql;
    use paste::paste;

    /// Builds a nullable single-column batch of type `$DATA_TY` from the given
    /// values plus one trailing null, negates the column, and checks the
    /// data type, nullability and every output value.
    ///
    /// Must be invoked inside a function returning `Result`, since it uses `?`.
    macro_rules! test_array_negative_op {
        ($DATA_TY:tt, $($VALUE:expr),*   ) => {
            let schema = Schema::new(vec![Field::new("a", DataType::$DATA_TY, true)]);
            let expr = negative(col("a", &schema)?, &schema)?;
            assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TY);
            assert!(expr.nullable(&schema)?);
            let mut arr = Vec::new();
            let mut arr_expected = Vec::new();
            $(
                arr.push(Some($VALUE));
                arr_expected.push(Some(-$VALUE));
            )*
            arr.push(None);
            arr_expected.push(None);
            let input = paste!{[<$DATA_TY Array>]::from(arr)};
            let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)};
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
            let result = expr.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array");
            // Build the panic message only on failure instead of eagerly
            // formatting it on every invocation.
            let result = as_primitive_array(&result).unwrap_or_else(|e| {
                panic!("failed to downcast to {:?}Array: {:?}", $DATA_TY, e)
            });
            assert_eq!(result, expected);
        };
    }

    /// Negating arrays of each signed integer and float type flips every
    /// value and keeps nulls.
    #[test]
    fn array_negative_op() -> Result<()> {
        test_array_negative_op!(Int8, 2i8, 1i8);
        test_array_negative_op!(Int16, 234i16, 123i16);
        test_array_negative_op!(Int32, 2345i32, 1234i32);
        test_array_negative_op!(Int64, 23456i64, 12345i64);
        test_array_negative_op!(Float32, 2345.0f32, 1234.0f32);
        test_array_negative_op!(Float64, 23456.0f64, 12345.0f64);
        Ok(())
    }

    /// The bounds of the negation are the child's bounds, negated and swapped.
    #[test]
    fn test_evaluate_bounds() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let child_interval = Interval::make(Some(-2), Some(1))?;
        let negative_expr_interval = Interval::make(Some(-1), Some(2))?;
        assert_eq!(
            negative_expr.evaluate_bounds(&[&child_interval])?,
            negative_expr_interval
        );
        Ok(())
    }

    /// Checks the derived distribution for every supported input distribution
    /// and that Bernoulli inputs are rejected.
    #[test]
    fn test_evaluate_statistics() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));

        // Uniform
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_uniform(
                Interval::make(Some(-2.), Some(3.))?
            )?])?,
            Distribution::new_uniform(Interval::make(Some(-3.), Some(2.))?)?
        );

        // Bernoulli
        assert!(negative_expr
            .evaluate_statistics(&[&Distribution::new_bernoulli(ScalarValue::from(
                0.75
            ))?])
            .is_err());

        // Exponential
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_exponential(
                ScalarValue::from(1.),
                ScalarValue::from(1.),
                true
            )?])?,
            Distribution::new_exponential(
                ScalarValue::from(1.),
                ScalarValue::from(-1.),
                false
            )?
        );

        // Gaussian
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_gaussian(
                ScalarValue::from(15),
                ScalarValue::from(225),
            )?])?,
            Distribution::new_gaussian(ScalarValue::from(-15), ScalarValue::from(225),)?
        );

        // Generic
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_generic(
                ScalarValue::from(15),
                ScalarValue::from(15),
                ScalarValue::from(10),
                Interval::make(Some(10), Some(20))?
            )?])?,
            Distribution::new_generic(
                ScalarValue::from(-15),
                ScalarValue::from(-15),
                ScalarValue::from(10),
                Interval::make(Some(-20), Some(-10))?
            )?
        );

        Ok(())
    }

    /// Knowing the negation lies in `[0, 4]` narrows a child in `[-2, 3]`
    /// down to `[-2, 0]`.
    #[test]
    fn test_propagate_constraints() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let original_child_interval = Interval::make(Some(-2), Some(3))?;
        let negative_expr_interval = Interval::make(Some(0), Some(4))?;
        let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]);
        assert_eq!(
            negative_expr.propagate_constraints(
                &negative_expr_interval,
                &[&original_child_interval]
            )?,
            after_propagation
        );
        Ok(())
    }

    /// Same narrowing as `test_propagate_constraints`, but driven through
    /// `propagate_statistics` with uniform and generic child distributions.
    #[test]
    fn test_propagate_statistics_range_holders() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let original_child_interval = Interval::make(Some(-2), Some(3))?;
        let after_propagation = Interval::make(Some(-2), Some(0))?;

        let parent = Distribution::new_uniform(Interval::make(Some(0), Some(4))?)?;
        let children: Vec<Vec<Distribution>> = vec![
            vec![Distribution::new_uniform(original_child_interval.clone())?],
            vec![Distribution::new_generic(
                ScalarValue::from(0),
                ScalarValue::from(0),
                ScalarValue::Int32(None),
                original_child_interval.clone(),
            )?],
        ];

        for child_view in children {
            let child_refs: Vec<_> = child_view.iter().collect();
            let actual = negative_expr.propagate_statistics(&parent, &child_refs)?;
            let expected = Some(vec![Distribution::new_from_interval(
                after_propagation.clone(),
            )?]);
            assert_eq!(actual, expected);
        }

        Ok(())
    }

    /// `negative` accepts signed numeric, timestamp and interval operands.
    #[test]
    fn test_negation_valid_types() -> Result<()> {
        let negatable_types = [
            Int8,
            DataType::Timestamp(TimeUnit::Second, None),
            DataType::Interval(IntervalUnit::YearMonth),
        ];
        for negatable_type in negatable_types {
            let schema = Schema::new(vec![Field::new("a", negatable_type, true)]);
            let _expr = negative(col("a", &schema)?, &schema)?;
        }
        Ok(())
    }

    /// `negative` rejects a string operand with a plan error.
    #[test]
    fn test_negation_invalid_types() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let err = negative(col("a", &schema)?, &schema).unwrap_err();
        // `matches!` only yields a bool; it has to be asserted to check anything.
        assert!(
            matches!(err, DataFusionError::Plan(_)),
            "expected a plan error, got: {err:?}"
        );
        Ok(())
    }

    /// `Display` includes the column index, while `fmt_sql` prints only the
    /// column name.
    #[test]
    fn test_fmt_sql() -> Result<()> {
        let expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let display_string = expr.to_string();
        assert_eq!(display_string, "(- a@0)");
        let sql_string = fmt_sql(&expr).to_string();
        assert_eq!(sql_string, "(- a)");

        Ok(())
    }
}