Skip to main content

datafusion_physical_expr/expressions/
negative.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Negation (-) expression
19
20use std::hash::Hash;
21use std::sync::Arc;
22
23use crate::PhysicalExpr;
24
25use arrow::datatypes::FieldRef;
26use arrow::{
27    compute::kernels::numeric::neg_wrapping,
28    datatypes::{DataType, Schema},
29    record_batch::RecordBatch,
30};
31use datafusion_common::{Result, internal_err, plan_err};
32use datafusion_expr::interval_arithmetic::Interval;
33use datafusion_expr::sort_properties::ExprProperties;
34#[expect(deprecated)]
35use datafusion_expr::statistics::Distribution::{
36    self, Bernoulli, Exponential, Gaussian, Generic, Uniform,
37};
38use datafusion_expr::{
39    ColumnarValue,
40    type_coercion::{is_interval, is_signed_numeric, is_timestamp},
41};
42
43/// Negative expression
44#[derive(Debug, Eq)]
45pub struct NegativeExpr {
46    /// Input expression
47    arg: Arc<dyn PhysicalExpr>,
48}
49
50// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
51impl PartialEq for NegativeExpr {
52    fn eq(&self, other: &Self) -> bool {
53        self.arg.eq(&other.arg)
54    }
55}
56
57impl Hash for NegativeExpr {
58    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
59        self.arg.hash(state);
60    }
61}
62
63impl NegativeExpr {
64    /// Create new not expression
65    pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
66        Self { arg }
67    }
68
69    /// Get the input expression
70    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
71        &self.arg
72    }
73}
74
75impl std::fmt::Display for NegativeExpr {
76    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
77        write!(f, "(- {})", self.arg)
78    }
79}
80
81impl PhysicalExpr for NegativeExpr {
82    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
83        self.arg.data_type(input_schema)
84    }
85
86    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
87        self.arg.nullable(input_schema)
88    }
89
90    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
91        match self.arg.evaluate(batch)? {
92            ColumnarValue::Array(array) => {
93                let result = neg_wrapping(array.as_ref())?;
94                Ok(ColumnarValue::Array(result))
95            }
96            ColumnarValue::Scalar(scalar) => {
97                Ok(ColumnarValue::Scalar(scalar.arithmetic_negate()?))
98            }
99        }
100    }
101
102    fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
103        self.arg.return_field(input_schema)
104    }
105
106    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
107        vec![&self.arg]
108    }
109
110    fn with_new_children(
111        self: Arc<Self>,
112        children: Vec<Arc<dyn PhysicalExpr>>,
113    ) -> Result<Arc<dyn PhysicalExpr>> {
114        Ok(Arc::new(NegativeExpr::new(Arc::clone(&children[0]))))
115    }
116
117    /// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval.
118    /// It replaces the upper and lower bounds after multiplying them with -1.
119    /// Ex: `(a, b]` => `[-b, -a)`
120    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
121        children[0].arithmetic_negate()
122    }
123
124    /// Returns a new [`Interval`] of a NegativeExpr  that has the existing `interval` given that
125    /// given the input interval is known to be `children`.
126    fn propagate_constraints(
127        &self,
128        interval: &Interval,
129        children: &[&Interval],
130    ) -> Result<Option<Vec<Interval>>> {
131        let negated_interval = interval.arithmetic_negate()?;
132
133        Ok(children[0]
134            .intersect(negated_interval)?
135            .map(|result| vec![result]))
136    }
137
138    #[expect(deprecated)]
139    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
140        match children[0] {
141            Uniform(u) => Distribution::new_uniform(u.range().arithmetic_negate()?),
142            Exponential(e) => Distribution::new_exponential(
143                e.rate().clone(),
144                e.offset().arithmetic_negate()?,
145                !e.positive_tail(),
146            ),
147            Gaussian(g) => Distribution::new_gaussian(
148                g.mean().arithmetic_negate()?,
149                g.variance().clone(),
150            ),
151            Bernoulli(_) => {
152                internal_err!("NegativeExpr cannot operate on Boolean datatypes")
153            }
154            Generic(u) => Distribution::new_generic(
155                u.mean().arithmetic_negate()?,
156                u.median().arithmetic_negate()?,
157                u.variance().clone(),
158                u.range().arithmetic_negate()?,
159            ),
160        }
161    }
162
163    /// The ordering of a [`NegativeExpr`] is simply the reverse of its child.
164    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
165        Ok(ExprProperties {
166            sort_properties: -children[0].sort_properties,
167            range: children[0].range.clone().arithmetic_negate()?,
168            preserves_lex_ordering: false,
169        })
170    }
171
172    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        write!(f, "(- ")?;
174        self.arg.fmt_sql(f)?;
175        write!(f, ")")
176    }
177}
178
179/// Creates a unary expression NEGATIVE
180///
181/// # Errors
182///
183/// This function errors when the argument's type is not signed numeric
184pub fn negative(
185    arg: Arc<dyn PhysicalExpr>,
186    input_schema: &Schema,
187) -> Result<Arc<dyn PhysicalExpr>> {
188    let data_type = arg.data_type(input_schema)?;
189    if data_type.is_null() {
190        Ok(arg)
191    } else if !is_signed_numeric(&data_type)
192        && !is_interval(&data_type)
193        && !is_timestamp(&data_type)
194    {
195        plan_err!("Negation only supports numeric, interval and timestamp types")
196    } else {
197        Ok(Arc::new(NegativeExpr::new(arg)))
198    }
199}
200
201#[cfg(test)]
202mod tests {
203    use super::*;
204    use crate::expressions::{Column, col};
205
206    use arrow::array::*;
207    use arrow::datatypes::DataType::{Float32, Float64, Int8, Int16, Int32, Int64};
208    use arrow::datatypes::*;
209    use datafusion_common::cast::as_primitive_array;
210    use datafusion_common::{DataFusionError, ScalarValue};
211
212    use datafusion_physical_expr_common::physical_expr::fmt_sql;
213
214    macro_rules! test_array_negative_op {
215        ($DATA_TY:tt, $ARRAY_TY:ty, $($VALUE:expr),*   ) => {
216            let schema = Schema::new(vec![Field::new("a", DataType::$DATA_TY, true)]);
217            let expr = negative(col("a", &schema)?, &schema)?;
218            assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TY);
219            assert!(expr.nullable(&schema)?);
220            let mut arr = Vec::new();
221            let mut arr_expected = Vec::new();
222            $(
223                arr.push(Some($VALUE));
224                arr_expected.push(Some(-$VALUE));
225            )+
226            arr.push(None);
227            arr_expected.push(None);
228            let input = <$ARRAY_TY>::from(arr);
229            let expected = &<$ARRAY_TY>::from(arr_expected);
230            let batch =
231                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
232            let result = expr.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array");
233            let result =
234                as_primitive_array(&result).expect(format!("failed to downcast to {:?}Array", $DATA_TY).as_str());
235            assert_eq!(result, expected);
236        };
237    }
238
239    #[test]
240    fn array_negative_op() -> Result<()> {
241        test_array_negative_op!(Int8, Int8Array, 2i8, 1i8);
242        test_array_negative_op!(Int16, Int16Array, 234i16, 123i16);
243        test_array_negative_op!(Int32, Int32Array, 2345i32, 1234i32);
244        test_array_negative_op!(Int64, Int64Array, 23456i64, 12345i64);
245        test_array_negative_op!(Float32, Float32Array, 2345.0f32, 1234.0f32);
246        test_array_negative_op!(Float64, Float64Array, 23456.0f64, 12345.0f64);
247        Ok(())
248    }
249
250    #[test]
251    fn test_evaluate_bounds() -> Result<()> {
252        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
253        let child_interval = Interval::make(Some(-2), Some(1))?;
254        let negative_expr_interval = Interval::make(Some(-1), Some(2))?;
255        assert_eq!(
256            negative_expr.evaluate_bounds(&[&child_interval])?,
257            negative_expr_interval
258        );
259        Ok(())
260    }
261
262    #[test]
263    #[expect(deprecated)]
264    fn test_evaluate_statistics() -> Result<()> {
265        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
266
267        // Uniform
268        assert_eq!(
269            negative_expr.evaluate_statistics(&[&Distribution::new_uniform(
270                Interval::make(Some(-2.), Some(3.))?
271            )?])?,
272            Distribution::new_uniform(Interval::make(Some(-3.), Some(2.))?)?
273        );
274
275        // Bernoulli
276        assert!(
277            negative_expr
278                .evaluate_statistics(&[&Distribution::new_bernoulli(ScalarValue::from(
279                    0.75
280                ))?])
281                .is_err()
282        );
283
284        // Exponential
285        assert_eq!(
286            negative_expr.evaluate_statistics(&[&Distribution::new_exponential(
287                ScalarValue::from(1.),
288                ScalarValue::from(1.),
289                true
290            )?])?,
291            Distribution::new_exponential(
292                ScalarValue::from(1.),
293                ScalarValue::from(-1.),
294                false
295            )?
296        );
297
298        // Gaussian
299        assert_eq!(
300            negative_expr.evaluate_statistics(&[&Distribution::new_gaussian(
301                ScalarValue::from(15),
302                ScalarValue::from(225),
303            )?])?,
304            Distribution::new_gaussian(ScalarValue::from(-15), ScalarValue::from(225),)?
305        );
306
307        // Unknown
308        assert_eq!(
309            negative_expr.evaluate_statistics(&[&Distribution::new_generic(
310                ScalarValue::from(15),
311                ScalarValue::from(15),
312                ScalarValue::from(10),
313                Interval::make(Some(10), Some(20))?
314            )?])?,
315            Distribution::new_generic(
316                ScalarValue::from(-15),
317                ScalarValue::from(-15),
318                ScalarValue::from(10),
319                Interval::make(Some(-20), Some(-10))?
320            )?
321        );
322
323        Ok(())
324    }
325
326    #[test]
327    fn test_propagate_constraints() -> Result<()> {
328        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
329        let original_child_interval = Interval::make(Some(-2), Some(3))?;
330        let negative_expr_interval = Interval::make(Some(0), Some(4))?;
331        let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]);
332        assert_eq!(
333            negative_expr.propagate_constraints(
334                &negative_expr_interval,
335                &[&original_child_interval]
336            )?,
337            after_propagation
338        );
339        Ok(())
340    }
341
342    #[test]
343    #[expect(deprecated)]
344    fn test_propagate_statistics_range_holders() -> Result<()> {
345        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
346        let original_child_interval = Interval::make(Some(-2), Some(3))?;
347        let after_propagation = Interval::make(Some(-2), Some(0))?;
348
349        let parent = Distribution::new_uniform(Interval::make(Some(0), Some(4))?)?;
350        let children: Vec<Vec<Distribution>> = vec![
351            vec![Distribution::new_uniform(original_child_interval.clone())?],
352            vec![Distribution::new_generic(
353                ScalarValue::from(0),
354                ScalarValue::from(0),
355                ScalarValue::Int32(None),
356                original_child_interval.clone(),
357            )?],
358        ];
359
360        for child_view in children {
361            let child_refs: Vec<_> = child_view.iter().collect();
362            let actual = negative_expr.propagate_statistics(&parent, &child_refs)?;
363            let expected = Some(vec![Distribution::new_from_interval(
364                after_propagation.clone(),
365            )?]);
366            assert_eq!(actual, expected);
367        }
368
369        Ok(())
370    }
371
372    #[test]
373    fn test_negation_valid_types() -> Result<()> {
374        let negatable_types = [
375            Int8,
376            DataType::Timestamp(TimeUnit::Second, None),
377            DataType::Interval(IntervalUnit::YearMonth),
378        ];
379        for negatable_type in negatable_types {
380            let schema = Schema::new(vec![Field::new("a", negatable_type, true)]);
381            let _expr = negative(col("a", &schema)?, &schema)?;
382        }
383        Ok(())
384    }
385
386    #[test]
387    fn test_negation_invalid_types() -> Result<()> {
388        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
389        let expr = negative(col("a", &schema)?, &schema).unwrap_err();
390        matches!(expr, DataFusionError::Plan(_));
391        Ok(())
392    }
393
394    #[test]
395    fn test_fmt_sql() -> Result<()> {
396        let expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
397        let display_string = expr.to_string();
398        assert_eq!(display_string, "(- a@0)");
399        let sql_string = fmt_sql(&expr).to_string();
400        assert_eq!(sql_string, "(- a)");
401
402        Ok(())
403    }
404}