datafusion_physical_expr/expressions/not.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Not expression

use std::any::Any;
use std::fmt;
use std::hash::Hash;
use std::sync::Arc;

use crate::PhysicalExpr;

use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{cast::as_boolean_array, internal_err, Result, ScalarValue};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::statistics::Distribution::{self, Bernoulli};
use datafusion_expr::ColumnarValue;

/// NOT expression: negates a Boolean argument.
///
/// A `NULL` input evaluates to a `NULL` output, following SQL three-valued
/// logic.
#[derive(Debug, Eq)]
pub struct NotExpr {
    /// Input expression
    arg: Arc<dyn PhysicalExpr>,
}

// Manually implement PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
impl PartialEq for NotExpr {
    fn eq(&self, other: &Self) -> bool {
        self.arg.eq(&other.arg)
    }
}

impl Hash for NotExpr {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.arg.hash(state);
    }
}

impl NotExpr {
    /// Create new not expression
    pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
        Self { arg }
    }

    /// Get the input expression
    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
        &self.arg
    }
}

impl fmt::Display for NotExpr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "NOT {}", self.arg)
    }
}

impl PhysicalExpr for NotExpr {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        self.arg.nullable(input_schema)
    }

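    // Negate the argument. Array inputs use the Arrow `not` kernel, which
    // preserves the null mask; scalar inputs follow SQL three-valued logic,
    // so a NULL input produces a NULL output.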
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        match self.arg.evaluate(batch)? {
            ColumnarValue::Array(array) => {
                let array = as_boolean_array(&array)?;
                Ok(ColumnarValue::Array(Arc::new(
                    arrow::compute::kernels::boolean::not(array)?,
                )))
            }
            ColumnarValue::Scalar(scalar) => {
                if scalar.is_null() {
                    return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
                }
                let bool_value: bool = scalar.try_into()?;
                Ok(ColumnarValue::Scalar(ScalarValue::from(!bool_value)))
            }
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.arg]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(NotExpr::new(Arc::clone(&children[0]))))
    }

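    // Interval arithmetic: the bounds of `NOT x` are the complement of the
    // bounds of `x`, e.g. [true, true] becomes [false, false], while the
    // uncertain interval [false, true] maps to itself.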
    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
        children[0].not()
    }

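    // Given a known interval for `NOT x`, narrow the interval of `x` by
    // intersecting it with the complement of the parent interval.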
    fn propagate_constraints(
        &self,
        interval: &Interval,
        children: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        let complemented_interval = interval.not()?;

        Ok(children[0]
            .intersect(complemented_interval)?
            .map(|result| vec![result]))
    }

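    // Statistics: for a Bernoulli input with success probability `p`, the
    // output of NOT is Bernoulli with probability `1 - p`. An unknown
    // (NULL) `p` is passed through unchanged.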
    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
        match children[0] {
            Bernoulli(b) => {
                let p_value = b.p_value();
                if p_value.is_null() {
                    Ok(children[0].clone())
                } else {
                    let one = ScalarValue::new_one(&p_value.data_type())?;
                    Distribution::new_bernoulli(one.sub_checked(p_value)?)
                }
            }
            _ => internal_err!("NotExpr can only operate on Boolean datatypes"),
        }
    }

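    // If the parent pins `NOT x` to certainly true (or false), then `x` must
    // be certainly false (or true). A child that already contradicts the
    // parent yields `None`; an uncertain parent yields no refinement.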
    fn propagate_statistics(
        &self,
        parent: &Distribution,
        children: &[&Distribution],
    ) -> Result<Option<Vec<Distribution>>> {
        match (parent, children[0]) {
            (Bernoulli(parent), Bernoulli(child)) => {
                let parent_range = parent.range();
                let result = if parent_range == Interval::CERTAINLY_TRUE {
                    if child.range() == Interval::CERTAINLY_TRUE {
                        None
                    } else {
                        Some(vec![Distribution::new_bernoulli(ScalarValue::new_zero(
                            &child.data_type(),
                        )?)?])
                    }
                } else if parent_range == Interval::CERTAINLY_FALSE {
                    if child.range() == Interval::CERTAINLY_FALSE {
                        None
                    } else {
                        Some(vec![Distribution::new_bernoulli(ScalarValue::new_one(
                            &child.data_type(),
                        )?)?])
                    }
                } else {
                    Some(vec![])
                };
                Ok(result)
            }
            _ => internal_err!("NotExpr can only operate on Boolean datatypes"),
        }
    }

    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "NOT ")?;
        self.arg.fmt_sql(f)
    }
}

/// Creates a unary NOT expression.
pub fn not(arg: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
    Ok(Arc::new(NotExpr::new(arg)))
}

#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use super::*;
    use crate::expressions::{col, Column};

    use arrow::{array::BooleanArray, datatypes::*};
    use datafusion_physical_expr_common::physical_expr::fmt_sql;

    #[test]
    fn neg_op() -> Result<()> {
        let schema = schema();

        let expr = not(col("a", &schema)?)?;
        assert_eq!(expr.data_type(&schema)?, DataType::Boolean);
        assert!(expr.nullable(&schema)?);

        let input = BooleanArray::from(vec![Some(true), None, Some(false)]);
        let expected = &BooleanArray::from(vec![Some(false), None, Some(true)]);

        let batch = RecordBatch::try_new(schema, vec![Arc::new(input)])?;

        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result =
            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
        assert_eq!(result, expected);

        Ok(())
    }

    #[test]
    fn test_evaluate_bounds() -> Result<()> {
        // Note that `None` for boolean intervals is converted to `Some(false)`
        // / `Some(true)` by `Interval::make`, so it is not explicitly tested
        // here.

        // If the bounds span both booleans (false, true), so does the negation.
        assert_evaluate_bounds(
            Interval::make(Some(false), Some(true))?,
            Interval::make(Some(false), Some(true))?,
        )?;
        // (true, false) is not tested because it is not a valid interval (lower
        // bound is greater than upper bound).
        assert_evaluate_bounds(
            Interval::make(Some(true), Some(true))?,
            Interval::make(Some(false), Some(false))?,
        )?;
        assert_evaluate_bounds(
            Interval::make(Some(false), Some(false))?,
            Interval::make(Some(true), Some(true))?,
        )?;
        Ok(())
    }

    fn assert_evaluate_bounds(
        interval: Interval,
        expected_interval: Interval,
    ) -> Result<()> {
        let not_expr = not(col("a", &schema())?)?;
        assert_eq!(not_expr.evaluate_bounds(&[&interval])?, expected_interval);
        Ok(())
    }

    #[test]
    fn test_evaluate_statistics() -> Result<()> {
        let _schema = &Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
        let a = Arc::new(Column::new("a", 0)) as _;
        let expr = not(a)?;

        // Uniform with non-boolean bounds
        assert!(expr
            .evaluate_statistics(&[&Distribution::new_uniform(
                Interval::make_unbounded(&DataType::Float64)?
            )?])
            .is_err());

        // Exponential
        assert!(expr
            .evaluate_statistics(&[&Distribution::new_exponential(
                ScalarValue::from(1.0),
                ScalarValue::from(1.0),
                true
            )?])
            .is_err());

        // Gaussian
        assert!(expr
            .evaluate_statistics(&[&Distribution::new_gaussian(
                ScalarValue::from(1.0),
                ScalarValue::from(1.0),
            )?])
            .is_err());

        // Bernoulli
        assert_eq!(
            expr.evaluate_statistics(&[&Distribution::new_bernoulli(
                ScalarValue::from(0.0),
            )?])?,
            Distribution::new_bernoulli(ScalarValue::from(1.))?
        );

        assert_eq!(
            expr.evaluate_statistics(&[&Distribution::new_bernoulli(
                ScalarValue::from(1.0),
            )?])?,
            Distribution::new_bernoulli(ScalarValue::from(0.))?
        );

        assert_eq!(
            expr.evaluate_statistics(&[&Distribution::new_bernoulli(
                ScalarValue::from(0.25),
            )?])?,
            Distribution::new_bernoulli(ScalarValue::from(0.75))?
        );

        // Generic (unknown) distribution with a non-boolean range
        assert!(expr
            .evaluate_statistics(&[&Distribution::new_generic(
                ScalarValue::Null,
                ScalarValue::Null,
                ScalarValue::Null,
                Interval::make_unbounded(&DataType::UInt8)?
            )?])
            .is_err());

        // Generic (unknown) distribution with a non-boolean interval as range
        assert!(expr
            .evaluate_statistics(&[&Distribution::new_generic(
                ScalarValue::Null,
                ScalarValue::Null,
                ScalarValue::Null,
                Interval::make_unbounded(&DataType::Float64)?
            )?])
            .is_err());

        Ok(())
    }

    #[test]
    fn test_fmt_sql() -> Result<()> {
        let schema = schema();

        let expr = not(col("a", &schema)?)?;

        let display_string = expr.to_string();
        assert_eq!(display_string, "NOT a@0");

        let sql_string = fmt_sql(expr.as_ref()).to_string();
        assert_eq!(sql_string, "NOT a");

        Ok(())
    }

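    // Illustrative addition, not part of the original tests: exercises the
    // scalar branch of `evaluate`. It assumes `Literal` is re-exported from
    // `crate::expressions`, like `col` and `Column` above.
    #[test]
    fn neg_op_scalar() -> Result<()> {
        let expr = not(Arc::new(crate::expressions::Literal::new(
            ScalarValue::Boolean(Some(true)),
        )))?;

        // A literal reads no columns, so an empty batch is sufficient.
        let batch = RecordBatch::new_empty(schema());
        match expr.evaluate(&batch)? {
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(value))) => assert!(!value),
            other => panic!("expected a non-null Boolean scalar, got {other:?}"),
        }

        Ok(())
    }
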
    fn schema() -> SchemaRef {
        static SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
            Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)]))
        });
        Arc::clone(&SCHEMA)
    }
}