datafusion_physical_expr/expressions/
negative.rs1use std::any::Any;
21use std::hash::Hash;
22use std::sync::Arc;
23
24use crate::PhysicalExpr;
25
26use arrow::datatypes::FieldRef;
27use arrow::{
28 compute::kernels::numeric::neg_wrapping,
29 datatypes::{DataType, Schema},
30 record_batch::RecordBatch,
31};
32use datafusion_common::{Result, internal_err, plan_err};
33use datafusion_expr::interval_arithmetic::Interval;
34use datafusion_expr::sort_properties::ExprProperties;
35use datafusion_expr::statistics::Distribution::{
36 self, Bernoulli, Exponential, Gaussian, Generic, Uniform,
37};
38use datafusion_expr::{
39 ColumnarValue,
40 type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp},
41};
42
43#[derive(Debug, Eq)]
45pub struct NegativeExpr {
46 arg: Arc<dyn PhysicalExpr>,
48}
49
50impl PartialEq for NegativeExpr {
52 fn eq(&self, other: &Self) -> bool {
53 self.arg.eq(&other.arg)
54 }
55}
56
57impl Hash for NegativeExpr {
58 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
59 self.arg.hash(state);
60 }
61}
62
63impl NegativeExpr {
64 pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
66 Self { arg }
67 }
68
69 pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
71 &self.arg
72 }
73}
74
75impl std::fmt::Display for NegativeExpr {
76 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
77 write!(f, "(- {})", self.arg)
78 }
79}
80
81impl PhysicalExpr for NegativeExpr {
82 fn as_any(&self) -> &dyn Any {
84 self
85 }
86
87 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
88 self.arg.data_type(input_schema)
89 }
90
91 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
92 self.arg.nullable(input_schema)
93 }
94
95 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
96 match self.arg.evaluate(batch)? {
97 ColumnarValue::Array(array) => {
98 let result = neg_wrapping(array.as_ref())?;
99 Ok(ColumnarValue::Array(result))
100 }
101 ColumnarValue::Scalar(scalar) => {
102 Ok(ColumnarValue::Scalar(scalar.arithmetic_negate()?))
103 }
104 }
105 }
106
107 fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
108 self.arg.return_field(input_schema)
109 }
110
111 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
112 vec![&self.arg]
113 }
114
115 fn with_new_children(
116 self: Arc<Self>,
117 children: Vec<Arc<dyn PhysicalExpr>>,
118 ) -> Result<Arc<dyn PhysicalExpr>> {
119 Ok(Arc::new(NegativeExpr::new(Arc::clone(&children[0]))))
120 }
121
122 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
126 children[0].arithmetic_negate()
127 }
128
129 fn propagate_constraints(
132 &self,
133 interval: &Interval,
134 children: &[&Interval],
135 ) -> Result<Option<Vec<Interval>>> {
136 let negated_interval = interval.arithmetic_negate()?;
137
138 Ok(children[0]
139 .intersect(negated_interval)?
140 .map(|result| vec![result]))
141 }
142
143 fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
144 match children[0] {
145 Uniform(u) => Distribution::new_uniform(u.range().arithmetic_negate()?),
146 Exponential(e) => Distribution::new_exponential(
147 e.rate().clone(),
148 e.offset().arithmetic_negate()?,
149 !e.positive_tail(),
150 ),
151 Gaussian(g) => Distribution::new_gaussian(
152 g.mean().arithmetic_negate()?,
153 g.variance().clone(),
154 ),
155 Bernoulli(_) => {
156 internal_err!("NegativeExpr cannot operate on Boolean datatypes")
157 }
158 Generic(u) => Distribution::new_generic(
159 u.mean().arithmetic_negate()?,
160 u.median().arithmetic_negate()?,
161 u.variance().clone(),
162 u.range().arithmetic_negate()?,
163 ),
164 }
165 }
166
167 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
169 Ok(ExprProperties {
170 sort_properties: -children[0].sort_properties,
171 range: children[0].range.clone().arithmetic_negate()?,
172 preserves_lex_ordering: false,
173 })
174 }
175
176 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 write!(f, "(- ")?;
178 self.arg.fmt_sql(f)?;
179 write!(f, ")")
180 }
181}
182
183pub fn negative(
189 arg: Arc<dyn PhysicalExpr>,
190 input_schema: &Schema,
191) -> Result<Arc<dyn PhysicalExpr>> {
192 let data_type = arg.data_type(input_schema)?;
193 if is_null(&data_type) {
194 Ok(arg)
195 } else if !is_signed_numeric(&data_type)
196 && !is_interval(&data_type)
197 && !is_timestamp(&data_type)
198 {
199 plan_err!("Negation only supports numeric, interval and timestamp types")
200 } else {
201 Ok(Arc::new(NegativeExpr::new(arg)))
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{Column, col};

    use arrow::array::*;
    use arrow::datatypes::DataType::{Float32, Float64, Int8, Int16, Int32, Int64};
    use arrow::datatypes::*;
    use datafusion_common::cast::as_primitive_array;
    use datafusion_common::{DataFusionError, ScalarValue};

    use datafusion_physical_expr_common::physical_expr::fmt_sql;
    use paste::paste;

    /// Builds a nullable `$DATA_TY` column from the given values plus a trailing
    /// NULL, negates it via `negative`, and checks that:
    /// - the type and nullability are preserved;
    /// - every value is negated;
    /// - the NULL stays NULL.
    macro_rules! test_array_negative_op {
        ($DATA_TY:tt, $($VALUE:expr),* ) => {
            let schema = Schema::new(vec![Field::new("a", DataType::$DATA_TY, true)]);
            let expr = negative(col("a", &schema)?, &schema)?;
            assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TY);
            assert!(expr.nullable(&schema)?);
            let mut arr = Vec::new();
            let mut arr_expected = Vec::new();
            $(
                arr.push(Some($VALUE));
                arr_expected.push(Some(-$VALUE));
            )+
            arr.push(None);
            arr_expected.push(None);
            let input = paste!{[<$DATA_TY Array>]::from(arr)};
            let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)};
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
            let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
            // Build the panic message lazily (clippy `expect_fun_call`).
            let result = as_primitive_array(&result).unwrap_or_else(|e| {
                panic!("failed to downcast to {:?}Array: {e}", $DATA_TY)
            });
            assert_eq!(result, expected);
        };
    }

    #[test]
    fn array_negative_op() -> Result<()> {
        test_array_negative_op!(Int8, 2i8, 1i8);
        test_array_negative_op!(Int16, 234i16, 123i16);
        test_array_negative_op!(Int32, 2345i32, 1234i32);
        test_array_negative_op!(Int64, 23456i64, 12345i64);
        test_array_negative_op!(Float32, 2345.0f32, 1234.0f32);
        test_array_negative_op!(Float64, 23456.0f64, 12345.0f64);
        Ok(())
    }

    #[test]
    fn test_evaluate_bounds() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let child_interval = Interval::make(Some(-2), Some(1))?;
        let negative_expr_interval = Interval::make(Some(-1), Some(2))?;
        assert_eq!(
            negative_expr.evaluate_bounds(&[&child_interval])?,
            negative_expr_interval
        );
        Ok(())
    }

    #[test]
    fn test_evaluate_statistics() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));

        // Uniform: the range is mirrored.
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_uniform(
                Interval::make(Some(-2.), Some(3.))?
            )?])?,
            Distribution::new_uniform(Interval::make(Some(-3.), Some(2.))?)?
        );

        // Bernoulli: boolean inputs cannot be negated.
        assert!(
            negative_expr
                .evaluate_statistics(&[&Distribution::new_bernoulli(ScalarValue::from(
                    0.75
                ))?])
                .is_err()
        );

        // Exponential: offset negated, tail direction flipped, rate kept.
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_exponential(
                ScalarValue::from(1.),
                ScalarValue::from(1.),
                true
            )?])?,
            Distribution::new_exponential(
                ScalarValue::from(1.),
                ScalarValue::from(-1.),
                false
            )?
        );

        // Gaussian: mean negated, variance kept.
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_gaussian(
                ScalarValue::from(15),
                ScalarValue::from(225),
            )?])?,
            Distribution::new_gaussian(ScalarValue::from(-15), ScalarValue::from(225),)?
        );

        // Generic: mean, median and range negated, variance kept.
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_generic(
                ScalarValue::from(15),
                ScalarValue::from(15),
                ScalarValue::from(10),
                Interval::make(Some(10), Some(20))?
            )?])?,
            Distribution::new_generic(
                ScalarValue::from(-15),
                ScalarValue::from(-15),
                ScalarValue::from(10),
                Interval::make(Some(-20), Some(-10))?
            )?
        );

        Ok(())
    }

    #[test]
    fn test_propagate_constraints() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let original_child_interval = Interval::make(Some(-2), Some(3))?;
        let negative_expr_interval = Interval::make(Some(0), Some(4))?;
        let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]);
        assert_eq!(
            negative_expr.propagate_constraints(
                &negative_expr_interval,
                &[&original_child_interval]
            )?,
            after_propagation
        );
        Ok(())
    }

    #[test]
    fn test_propagate_statistics_range_holders() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let original_child_interval = Interval::make(Some(-2), Some(3))?;
        let after_propagation = Interval::make(Some(-2), Some(0))?;

        let parent = Distribution::new_uniform(Interval::make(Some(0), Some(4))?)?;
        let children: Vec<Vec<Distribution>> = vec![
            vec![Distribution::new_uniform(original_child_interval.clone())?],
            vec![Distribution::new_generic(
                ScalarValue::from(0),
                ScalarValue::from(0),
                ScalarValue::Int32(None),
                original_child_interval.clone(),
            )?],
        ];

        for child_view in children {
            let child_refs: Vec<_> = child_view.iter().collect();
            let actual = negative_expr.propagate_statistics(&parent, &child_refs)?;
            let expected = Some(vec![Distribution::new_from_interval(
                after_propagation.clone(),
            )?]);
            assert_eq!(actual, expected);
        }

        Ok(())
    }

    #[test]
    fn test_negation_valid_types() -> Result<()> {
        let negatable_types = [
            Int8,
            DataType::Timestamp(TimeUnit::Second, None),
            DataType::Interval(IntervalUnit::YearMonth),
        ];
        for negatable_type in negatable_types {
            let schema = Schema::new(vec![Field::new("a", negatable_type, true)]);
            let expr = negative(col("a", &schema)?, &schema)?;
            assert!(expr.as_any().downcast_ref::<NegativeExpr>().is_some());
        }
        Ok(())
    }

    #[test]
    fn test_negation_null_type_is_passthrough() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Null, true)]);
        let expr = negative(col("a", &schema)?, &schema)?;
        // A Null-typed argument is returned unwrapped.
        assert!(expr.as_any().downcast_ref::<Column>().is_some());
        Ok(())
    }

    #[test]
    fn test_negation_invalid_types() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let err = negative(col("a", &schema)?, &schema).unwrap_err();
        // `matches!` alone only produces a bool; it must be asserted to check anything.
        assert!(
            matches!(err, DataFusionError::Plan(_)),
            "expected a plan error, got {err:?}"
        );
        Ok(())
    }

    #[test]
    fn test_fmt_sql() -> Result<()> {
        let expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let display_string = expr.to_string();
        assert_eq!(display_string, "(- a@0)");
        let sql_string = fmt_sql(&expr).to_string();
        assert_eq!(sql_string, "(- a)");

        Ok(())
    }
}