Skip to main content

datafusion_physical_expr/expressions/
not.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Not expression
19
20use std::fmt;
21use std::hash::Hash;
22use std::sync::Arc;
23
24use crate::PhysicalExpr;
25
26use arrow::datatypes::{DataType, FieldRef, Schema};
27use arrow::record_batch::RecordBatch;
28use datafusion_common::{Result, ScalarValue, cast::as_boolean_array, internal_err};
29use datafusion_expr::ColumnarValue;
30use datafusion_expr::interval_arithmetic::Interval;
31#[expect(deprecated)]
32use datafusion_expr::statistics::Distribution::{self, Bernoulli};
33
34/// Not expression
35#[derive(Debug, Eq)]
36pub struct NotExpr {
37    /// Input expression
38    arg: Arc<dyn PhysicalExpr>,
39}
40
41// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
42impl PartialEq for NotExpr {
43    fn eq(&self, other: &Self) -> bool {
44        self.arg.eq(&other.arg)
45    }
46}
47
48impl Hash for NotExpr {
49    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
50        self.arg.hash(state);
51    }
52}
53
54impl NotExpr {
55    /// Create new not expression
56    pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
57        Self { arg }
58    }
59
60    /// Get the input expression
61    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
62        &self.arg
63    }
64}
65
66impl fmt::Display for NotExpr {
67    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68        write!(f, "NOT {}", self.arg)
69    }
70}
71
72impl PhysicalExpr for NotExpr {
73    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
74        Ok(DataType::Boolean)
75    }
76
77    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
78        self.arg.nullable(input_schema)
79    }
80
81    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
82        match self.arg.evaluate(batch)? {
83            ColumnarValue::Array(array) => {
84                let array = as_boolean_array(&array)?;
85                Ok(ColumnarValue::Array(Arc::new(
86                    arrow::compute::kernels::boolean::not(array)?,
87                )))
88            }
89            ColumnarValue::Scalar(scalar) => {
90                if scalar.is_null() {
91                    return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
92                }
93                let bool_value: bool = scalar.try_into()?;
94                Ok(ColumnarValue::Scalar(ScalarValue::from(!bool_value)))
95            }
96        }
97    }
98
99    fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
100        self.arg.return_field(input_schema)
101    }
102
103    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
104        vec![&self.arg]
105    }
106
107    fn with_new_children(
108        self: Arc<Self>,
109        children: Vec<Arc<dyn PhysicalExpr>>,
110    ) -> Result<Arc<dyn PhysicalExpr>> {
111        Ok(Arc::new(NotExpr::new(Arc::clone(&children[0]))))
112    }
113
114    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
115        children[0].not()
116    }
117
118    fn propagate_constraints(
119        &self,
120        interval: &Interval,
121        children: &[&Interval],
122    ) -> Result<Option<Vec<Interval>>> {
123        let complemented_interval = interval.not()?;
124
125        Ok(children[0]
126            .intersect(complemented_interval)?
127            .map(|result| vec![result]))
128    }
129
130    #[expect(deprecated)]
131    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
132        match children[0] {
133            Bernoulli(b) => {
134                let p_value = b.p_value();
135                if p_value.is_null() {
136                    Ok(children[0].clone())
137                } else {
138                    let one = ScalarValue::new_one(&p_value.data_type())?;
139                    Distribution::new_bernoulli(one.sub_checked(p_value)?)
140                }
141            }
142            _ => internal_err!("NotExpr can only operate on Boolean datatypes"),
143        }
144    }
145
146    #[expect(deprecated)]
147    fn propagate_statistics(
148        &self,
149        parent: &Distribution,
150        children: &[&Distribution],
151    ) -> Result<Option<Vec<Distribution>>> {
152        match (parent, children[0]) {
153            (Bernoulli(parent), Bernoulli(child)) => {
154                let parent_range = parent.range();
155                let result = if parent_range == Interval::TRUE {
156                    if child.range() == Interval::TRUE {
157                        None
158                    } else {
159                        Some(vec![Distribution::new_bernoulli(ScalarValue::new_zero(
160                            &child.data_type(),
161                        )?)?])
162                    }
163                } else if parent_range == Interval::FALSE {
164                    if child.range() == Interval::FALSE {
165                        None
166                    } else {
167                        Some(vec![Distribution::new_bernoulli(ScalarValue::new_one(
168                            &child.data_type(),
169                        )?)?])
170                    }
171                } else {
172                    Some(vec![])
173                };
174                Ok(result)
175            }
176            _ => internal_err!("NotExpr can only operate on Boolean datatypes"),
177        }
178    }
179
180    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        write!(f, "NOT ")?;
182        self.arg.fmt_sql(f)
183    }
184}
185
186/// Creates a unary expression NOT
187pub fn not(arg: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
188    Ok(Arc::new(NotExpr::new(arg)))
189}
190
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use super::*;
    use crate::expressions::{Column, col, lit};

    use arrow::{array::BooleanArray, datatypes::*};
    use datafusion_physical_expr_common::physical_expr::fmt_sql;

    #[test]
    fn neg_op() -> Result<()> {
        let schema = schema();

        let expr = not(col("a", &schema)?)?;
        assert_eq!(expr.data_type(&schema)?, DataType::Boolean);
        assert!(expr.nullable(&schema)?);

        let input = BooleanArray::from(vec![Some(true), None, Some(false)]);
        let expected = &BooleanArray::from(vec![Some(false), None, Some(true)]);

        let batch = RecordBatch::try_new(schema, vec![Arc::new(input)])?;

        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result =
            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
        assert_eq!(result, expected);

        Ok(())
    }

    /// Covers the scalar branch of `NotExpr::evaluate`. `neg_op` only exercises
    /// array input.
    #[test]
    fn neg_op_scalar() -> Result<()> {
        let batch = RecordBatch::try_new(
            schema(),
            vec![Arc::new(BooleanArray::from(vec![Some(true)]))],
        )?;

        let cases = [
            (ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false))),
            (ScalarValue::Boolean(Some(false)), ScalarValue::Boolean(Some(true))),
            // Both a typed boolean NULL and an untyped NULL negate to a boolean NULL.
            (ScalarValue::Boolean(None), ScalarValue::Boolean(None)),
            (ScalarValue::Null, ScalarValue::Boolean(None)),
        ];
        for (input, expected) in cases {
            let expr = not(lit(input))?;
            let ColumnarValue::Scalar(actual) = expr.evaluate(&batch)? else {
                panic!("NOT of a scalar literal should produce a scalar");
            };
            assert_eq!(actual, expected);
        }

        // Non-null scalars that are not booleans are rejected.
        assert!(
            not(lit(ScalarValue::Int32(Some(1))))?
                .evaluate(&batch)
                .is_err()
        );

        Ok(())
    }

    #[test]
    fn test_evaluate_bounds() -> Result<()> {
        // Note that `None` for boolean intervals is converted to `Some(false)`
        // / `Some(true)` by `Interval::make`, so it is not explicitly tested
        // here

        // if the bounds are all booleans (false, true) so is the negation
        assert_evaluate_bounds(
            Interval::make(Some(false), Some(true))?,
            Interval::make(Some(false), Some(true))?,
        )?;
        // (true, false) is not tested because it is not a valid interval (lower
        // bound is greater than upper bound)
        assert_evaluate_bounds(
            Interval::make(Some(true), Some(true))?,
            Interval::make(Some(false), Some(false))?,
        )?;
        assert_evaluate_bounds(
            Interval::make(Some(false), Some(false))?,
            Interval::make(Some(true), Some(true))?,
        )?;
        Ok(())
    }

    fn assert_evaluate_bounds(
        interval: Interval,
        expected_interval: Interval,
    ) -> Result<()> {
        let not_expr = not(col("a", &schema())?)?;
        assert_eq!(not_expr.evaluate_bounds(&[&interval])?, expected_interval);
        Ok(())
    }

    #[test]
    #[expect(deprecated)]
    fn test_evaluate_statistics() -> Result<()> {
        let a = Arc::new(Column::new("a", 0)) as _;
        let expr = not(a)?;

        // Uniform with non-boolean bounds
        assert!(
            expr.evaluate_statistics(&[&Distribution::new_uniform(
                Interval::make_unbounded(&DataType::Float64)?
            )?])
            .is_err()
        );

        // Exponential
        assert!(
            expr.evaluate_statistics(&[&Distribution::new_exponential(
                ScalarValue::from(1.0),
                ScalarValue::from(1.0),
                true
            )?])
            .is_err()
        );

        // Gaussian
        assert!(
            expr.evaluate_statistics(&[&Distribution::new_gaussian(
                ScalarValue::from(1.0),
                ScalarValue::from(1.0),
            )?])
            .is_err()
        );

        // Bernoulli
        assert_eq!(
            expr.evaluate_statistics(&[&Distribution::new_bernoulli(
                ScalarValue::from(0.0),
            )?])?,
            Distribution::new_bernoulli(ScalarValue::from(1.))?
        );

        assert_eq!(
            expr.evaluate_statistics(&[&Distribution::new_bernoulli(
                ScalarValue::from(1.0),
            )?])?,
            Distribution::new_bernoulli(ScalarValue::from(0.))?
        );

        assert_eq!(
            expr.evaluate_statistics(&[&Distribution::new_bernoulli(
                ScalarValue::from(0.25),
            )?])?,
            Distribution::new_bernoulli(ScalarValue::from(0.75))?
        );

        // Unknown with non-boolean (integer) interval as range
        assert!(
            expr.evaluate_statistics(&[&Distribution::new_generic(
                ScalarValue::Null,
                ScalarValue::Null,
                ScalarValue::Null,
                Interval::make_unbounded(&DataType::UInt8)?
            )?])
            .is_err()
        );

        // Unknown with non-boolean (float) interval as range
        assert!(
            expr.evaluate_statistics(&[&Distribution::new_generic(
                ScalarValue::Null,
                ScalarValue::Null,
                ScalarValue::Null,
                Interval::make_unbounded(&DataType::Float64)?
            )?])
            .is_err()
        );

        Ok(())
    }

    #[test]
    fn test_fmt_sql() -> Result<()> {
        let schema = schema();

        let expr = not(col("a", &schema)?)?;

        let display_string = expr.to_string();
        assert_eq!(display_string, "NOT a@0");

        let sql_string = fmt_sql(expr.as_ref()).to_string();
        assert_eq!(sql_string, "NOT a");

        Ok(())
    }

    fn schema() -> SchemaRef {
        static SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
            Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)]))
        });
        Arc::clone(&SCHEMA)
    }
}