// datafusion_physical_expr/expressions/not.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Not expression
20use std::any::Any;
21use std::fmt;
22use std::hash::Hash;
23use std::sync::Arc;
24
25use crate::PhysicalExpr;
26
27use arrow::datatypes::{DataType, FieldRef, Schema};
28use arrow::record_batch::RecordBatch;
29use datafusion_common::{cast::as_boolean_array, internal_err, Result, ScalarValue};
30use datafusion_expr::interval_arithmetic::Interval;
31use datafusion_expr::statistics::Distribution::{self, Bernoulli};
32use datafusion_expr::ColumnarValue;
33
/// Not expression
///
/// Physical expression for the unary `NOT` operator: wraps a boolean-valued
/// [`PhysicalExpr`] and evaluates to its logical negation, following SQL
/// three-valued logic (`NOT NULL` evaluates to `NULL`).
#[derive(Debug, Eq)]
pub struct NotExpr {
    /// Input expression; must evaluate to a boolean value.
    arg: Arc<dyn PhysicalExpr>,
}
40
41// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
42impl PartialEq for NotExpr {
43    fn eq(&self, other: &Self) -> bool {
44        self.arg.eq(&other.arg)
45    }
46}
47
48impl Hash for NotExpr {
49    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
50        self.arg.hash(state);
51    }
52}
53
54impl NotExpr {
55    /// Create new not expression
56    pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
57        Self { arg }
58    }
59
60    /// Get the input expression
61    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
62        &self.arg
63    }
64}
65
66impl fmt::Display for NotExpr {
67    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68        write!(f, "NOT {}", self.arg)
69    }
70}
71
impl PhysicalExpr for NotExpr {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// `NOT` always yields a boolean, regardless of the input schema.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// The result is `NULL` exactly when the input is `NULL`, so nullability
    /// is inherited from the inner expression.
    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        self.arg.nullable(input_schema)
    }

    /// Evaluates the inner expression against `batch` and negates the result.
    ///
    /// Errors if the inner expression does not produce a boolean value.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        match self.arg.evaluate(batch)? {
            ColumnarValue::Array(array) => {
                // Vectorized path: Arrow's boolean `not` kernel negates valid
                // slots and carries the null mask through unchanged.
                let array = as_boolean_array(&array)?;
                Ok(ColumnarValue::Array(Arc::new(
                    arrow::compute::kernels::boolean::not(array)?,
                )))
            }
            ColumnarValue::Scalar(scalar) => {
                // SQL three-valued logic: NOT NULL evaluates to NULL.
                if scalar.is_null() {
                    return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
                }
                let bool_value: bool = scalar.try_into()?;
                Ok(ColumnarValue::Scalar(ScalarValue::from(!bool_value)))
            }
        }
    }

    /// Reuses the child's output field (name, nullability, metadata).
    // NOTE(review): this assumes the child's field is already Boolean-typed,
    // consistent with `data_type` above — confirm for non-boolean children.
    fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
        self.arg.return_field(input_schema)
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.arg]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        // Expects exactly one child; indexing panics if `children` is empty.
        Ok(Arc::new(NotExpr::new(Arc::clone(&children[0]))))
    }

    /// Bounds of `NOT expr` are the negation of the child's boolean interval.
    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
        children[0].not()
    }

    /// Pushes a known bound on `NOT expr` back down to the child by
    /// intersecting the child's interval with the complement of `interval`.
    /// Propagates `None` (from `intersect`) when the intersection is empty,
    /// i.e. the constraint is infeasible.
    fn propagate_constraints(
        &self,
        interval: &Interval,
        children: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        let complemented_interval = interval.not()?;

        Ok(children[0]
            .intersect(complemented_interval)?
            .map(|result| vec![result]))
    }

    /// For a Bernoulli child with success probability `p`, `NOT` flips the
    /// probability to `1 - p`; a null `p` is passed through unchanged.
    /// Any non-Bernoulli input is rejected as an internal error.
    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
        match children[0] {
            Bernoulli(b) => {
                let p_value = b.p_value();
                if p_value.is_null() {
                    Ok(children[0].clone())
                } else {
                    let one = ScalarValue::new_one(&p_value.data_type())?;
                    Distribution::new_bernoulli(one.sub_checked(p_value)?)
                }
            }
            _ => internal_err!("NotExpr can only operate on Boolean datatypes"),
        }
    }

    /// Propagates a known distribution of `NOT expr` back to the child:
    /// - parent certainly true  => child becomes certainly false (p = 0),
    /// - parent certainly false => child becomes certainly true (p = 1),
    /// - otherwise `Some(vec![])`, meaning no refinement is possible.
    ///
    /// `None` signals a contradictory combination (e.g. parent certainly
    /// true while the child is also certainly true).
    fn propagate_statistics(
        &self,
        parent: &Distribution,
        children: &[&Distribution],
    ) -> Result<Option<Vec<Distribution>>> {
        match (parent, children[0]) {
            (Bernoulli(parent), Bernoulli(child)) => {
                let parent_range = parent.range();
                let result = if parent_range == Interval::CERTAINLY_TRUE {
                    if child.range() == Interval::CERTAINLY_TRUE {
                        None
                    } else {
                        Some(vec![Distribution::new_bernoulli(ScalarValue::new_zero(
                            &child.data_type(),
                        )?)?])
                    }
                } else if parent_range == Interval::CERTAINLY_FALSE {
                    if child.range() == Interval::CERTAINLY_FALSE {
                        None
                    } else {
                        Some(vec![Distribution::new_bernoulli(ScalarValue::new_one(
                            &child.data_type(),
                        )?)?])
                    }
                } else {
                    Some(vec![])
                };
                Ok(result)
            }
            _ => internal_err!("NotExpr can only operate on Boolean datatypes"),
        }
    }

    /// Writes the expression in SQL form (`NOT <arg>`), without the column
    /// index that `Display` includes.
    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "NOT ")?;
        self.arg.fmt_sql(f)
    }
}
188
189/// Creates a unary expression NOT
190pub fn not(arg: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
191    Ok(Arc::new(NotExpr::new(arg)))
192}
193
194#[cfg(test)]
195mod tests {
196    use std::sync::LazyLock;
197
198    use super::*;
199    use crate::expressions::{col, Column};
200
201    use arrow::{array::BooleanArray, datatypes::*};
202    use datafusion_physical_expr_common::physical_expr::fmt_sql;
203
204    #[test]
205    fn neg_op() -> Result<()> {
206        let schema = schema();
207
208        let expr = not(col("a", &schema)?)?;
209        assert_eq!(expr.data_type(&schema)?, DataType::Boolean);
210        assert!(expr.nullable(&schema)?);
211
212        let input = BooleanArray::from(vec![Some(true), None, Some(false)]);
213        let expected = &BooleanArray::from(vec![Some(false), None, Some(true)]);
214
215        let batch = RecordBatch::try_new(schema, vec![Arc::new(input)])?;
216
217        let result = expr
218            .evaluate(&batch)?
219            .into_array(batch.num_rows())
220            .expect("Failed to convert to array");
221        let result =
222            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
223        assert_eq!(result, expected);
224
225        Ok(())
226    }
227
228    #[test]
229    fn test_evaluate_bounds() -> Result<()> {
230        // Note that `None` for boolean intervals is converted to `Some(false)`
231        // / `Some(true)` by `Interval::make`, so it is not explicitly tested
232        // here
233
234        // if the bounds are all booleans (false, true) so is the negation
235        assert_evaluate_bounds(
236            Interval::make(Some(false), Some(true))?,
237            Interval::make(Some(false), Some(true))?,
238        )?;
239        // (true, false) is not tested because it is not a valid interval (lower
240        // bound is greater than upper bound)
241        assert_evaluate_bounds(
242            Interval::make(Some(true), Some(true))?,
243            Interval::make(Some(false), Some(false))?,
244        )?;
245        assert_evaluate_bounds(
246            Interval::make(Some(false), Some(false))?,
247            Interval::make(Some(true), Some(true))?,
248        )?;
249        Ok(())
250    }
251
252    fn assert_evaluate_bounds(
253        interval: Interval,
254        expected_interval: Interval,
255    ) -> Result<()> {
256        let not_expr = not(col("a", &schema())?)?;
257        assert_eq!(not_expr.evaluate_bounds(&[&interval])?, expected_interval);
258        Ok(())
259    }
260
261    #[test]
262    fn test_evaluate_statistics() -> Result<()> {
263        let _schema = &Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
264        let a = Arc::new(Column::new("a", 0)) as _;
265        let expr = not(a)?;
266
267        // Uniform with non-boolean bounds
268        assert!(expr
269            .evaluate_statistics(&[&Distribution::new_uniform(
270                Interval::make_unbounded(&DataType::Float64)?
271            )?])
272            .is_err());
273
274        // Exponential
275        assert!(expr
276            .evaluate_statistics(&[&Distribution::new_exponential(
277                ScalarValue::from(1.0),
278                ScalarValue::from(1.0),
279                true
280            )?])
281            .is_err());
282
283        // Gaussian
284        assert!(expr
285            .evaluate_statistics(&[&Distribution::new_gaussian(
286                ScalarValue::from(1.0),
287                ScalarValue::from(1.0),
288            )?])
289            .is_err());
290
291        // Bernoulli
292        assert_eq!(
293            expr.evaluate_statistics(&[&Distribution::new_bernoulli(
294                ScalarValue::from(0.0),
295            )?])?,
296            Distribution::new_bernoulli(ScalarValue::from(1.))?
297        );
298
299        assert_eq!(
300            expr.evaluate_statistics(&[&Distribution::new_bernoulli(
301                ScalarValue::from(1.0),
302            )?])?,
303            Distribution::new_bernoulli(ScalarValue::from(0.))?
304        );
305
306        assert_eq!(
307            expr.evaluate_statistics(&[&Distribution::new_bernoulli(
308                ScalarValue::from(0.25),
309            )?])?,
310            Distribution::new_bernoulli(ScalarValue::from(0.75))?
311        );
312
313        assert!(expr
314            .evaluate_statistics(&[&Distribution::new_generic(
315                ScalarValue::Null,
316                ScalarValue::Null,
317                ScalarValue::Null,
318                Interval::make_unbounded(&DataType::UInt8)?
319            )?])
320            .is_err());
321
322        // Unknown with non-boolean interval as range
323        assert!(expr
324            .evaluate_statistics(&[&Distribution::new_generic(
325                ScalarValue::Null,
326                ScalarValue::Null,
327                ScalarValue::Null,
328                Interval::make_unbounded(&DataType::Float64)?
329            )?])
330            .is_err());
331
332        Ok(())
333    }
334
335    #[test]
336    fn test_fmt_sql() -> Result<()> {
337        let schema = schema();
338
339        let expr = not(col("a", &schema)?)?;
340
341        let display_string = expr.to_string();
342        assert_eq!(display_string, "NOT a@0");
343
344        let sql_string = fmt_sql(expr.as_ref()).to_string();
345        assert_eq!(sql_string, "NOT a");
346
347        Ok(())
348    }
349
    /// Returns the shared test schema: a single nullable Boolean column `a`.
    /// Built once via `LazyLock` and cheaply cloned (`Arc`) per caller.
    fn schema() -> SchemaRef {
        static SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
            Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)]))
        });
        Arc::clone(&SCHEMA)
    }
356}