datafusion_physical_expr/expressions/
negative.rs1use std::hash::Hash;
21use std::sync::Arc;
22
23use crate::PhysicalExpr;
24
25use arrow::datatypes::FieldRef;
26use arrow::{
27 compute::kernels::numeric::neg_wrapping,
28 datatypes::{DataType, Schema},
29 record_batch::RecordBatch,
30};
31use datafusion_common::{Result, internal_err, plan_err};
32use datafusion_expr::interval_arithmetic::Interval;
33use datafusion_expr::sort_properties::ExprProperties;
34#[expect(deprecated)]
35use datafusion_expr::statistics::Distribution::{
36 self, Bernoulli, Exponential, Gaussian, Generic, Uniform,
37};
38use datafusion_expr::{
39 ColumnarValue,
40 type_coercion::{is_interval, is_signed_numeric, is_timestamp},
41};
42
43#[derive(Debug, Eq)]
45pub struct NegativeExpr {
46 arg: Arc<dyn PhysicalExpr>,
48}
49
50impl PartialEq for NegativeExpr {
52 fn eq(&self, other: &Self) -> bool {
53 self.arg.eq(&other.arg)
54 }
55}
56
57impl Hash for NegativeExpr {
58 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
59 self.arg.hash(state);
60 }
61}
62
63impl NegativeExpr {
64 pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
66 Self { arg }
67 }
68
69 pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
71 &self.arg
72 }
73}
74
75impl std::fmt::Display for NegativeExpr {
76 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
77 write!(f, "(- {})", self.arg)
78 }
79}
80
81impl PhysicalExpr for NegativeExpr {
82 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
83 self.arg.data_type(input_schema)
84 }
85
86 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
87 self.arg.nullable(input_schema)
88 }
89
90 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
91 match self.arg.evaluate(batch)? {
92 ColumnarValue::Array(array) => {
93 let result = neg_wrapping(array.as_ref())?;
94 Ok(ColumnarValue::Array(result))
95 }
96 ColumnarValue::Scalar(scalar) => {
97 Ok(ColumnarValue::Scalar(scalar.arithmetic_negate()?))
98 }
99 }
100 }
101
102 fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
103 self.arg.return_field(input_schema)
104 }
105
106 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
107 vec![&self.arg]
108 }
109
110 fn with_new_children(
111 self: Arc<Self>,
112 children: Vec<Arc<dyn PhysicalExpr>>,
113 ) -> Result<Arc<dyn PhysicalExpr>> {
114 Ok(Arc::new(NegativeExpr::new(Arc::clone(&children[0]))))
115 }
116
117 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
121 children[0].arithmetic_negate()
122 }
123
124 fn propagate_constraints(
127 &self,
128 interval: &Interval,
129 children: &[&Interval],
130 ) -> Result<Option<Vec<Interval>>> {
131 let negated_interval = interval.arithmetic_negate()?;
132
133 Ok(children[0]
134 .intersect(negated_interval)?
135 .map(|result| vec![result]))
136 }
137
138 #[expect(deprecated)]
139 fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
140 match children[0] {
141 Uniform(u) => Distribution::new_uniform(u.range().arithmetic_negate()?),
142 Exponential(e) => Distribution::new_exponential(
143 e.rate().clone(),
144 e.offset().arithmetic_negate()?,
145 !e.positive_tail(),
146 ),
147 Gaussian(g) => Distribution::new_gaussian(
148 g.mean().arithmetic_negate()?,
149 g.variance().clone(),
150 ),
151 Bernoulli(_) => {
152 internal_err!("NegativeExpr cannot operate on Boolean datatypes")
153 }
154 Generic(u) => Distribution::new_generic(
155 u.mean().arithmetic_negate()?,
156 u.median().arithmetic_negate()?,
157 u.variance().clone(),
158 u.range().arithmetic_negate()?,
159 ),
160 }
161 }
162
163 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
165 Ok(ExprProperties {
166 sort_properties: -children[0].sort_properties,
167 range: children[0].range.clone().arithmetic_negate()?,
168 preserves_lex_ordering: false,
169 })
170 }
171
172 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 write!(f, "(- ")?;
174 self.arg.fmt_sql(f)?;
175 write!(f, ")")
176 }
177}
178
179pub fn negative(
185 arg: Arc<dyn PhysicalExpr>,
186 input_schema: &Schema,
187) -> Result<Arc<dyn PhysicalExpr>> {
188 let data_type = arg.data_type(input_schema)?;
189 if data_type.is_null() {
190 Ok(arg)
191 } else if !is_signed_numeric(&data_type)
192 && !is_interval(&data_type)
193 && !is_timestamp(&data_type)
194 {
195 plan_err!("Negation only supports numeric, interval and timestamp types")
196 } else {
197 Ok(Arc::new(NegativeExpr::new(arg)))
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{Column, col};

    use arrow::array::*;
    use arrow::datatypes::DataType::{Float32, Float64, Int8, Int16, Int32, Int64};
    use arrow::datatypes::*;
    use datafusion_common::cast::as_primitive_array;
    use datafusion_common::{DataFusionError, ScalarValue};

    use datafusion_physical_expr_common::physical_expr::fmt_sql;

    /// Negates a nullable `$DATA_TY` column holding the given `$VALUE`s
    /// followed by a NULL. Checks the result type, nullability and the
    /// element-wise result.
    macro_rules! test_array_negative_op {
        ($DATA_TY:tt, $ARRAY_TY:ty, $($VALUE:expr),* ) => {
            let schema = Schema::new(vec![Field::new("a", DataType::$DATA_TY, true)]);
            let expr = negative(col("a", &schema)?, &schema)?;
            assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TY);
            assert!(expr.nullable(&schema)?);
            let mut arr = Vec::new();
            let mut arr_expected = Vec::new();
            $(
                arr.push(Some($VALUE));
                arr_expected.push(Some(-$VALUE));
            )*
            arr.push(None);
            arr_expected.push(None);
            let input = <$ARRAY_TY>::from(arr);
            let expected = &<$ARRAY_TY>::from(arr_expected);
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
            let result = expr
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");
            // Build the panic message lazily so the passing path doesn't allocate.
            let result = as_primitive_array(&result).unwrap_or_else(|_| {
                panic!("failed to downcast to {:?}Array", $DATA_TY)
            });
            assert_eq!(result, expected);
        };
    }

    #[test]
    fn array_negative_op() -> Result<()> {
        test_array_negative_op!(Int8, Int8Array, 2i8, 1i8);
        test_array_negative_op!(Int16, Int16Array, 234i16, 123i16);
        test_array_negative_op!(Int32, Int32Array, 2345i32, 1234i32);
        test_array_negative_op!(Int64, Int64Array, 23456i64, 12345i64);
        test_array_negative_op!(Float32, Float32Array, 2345.0f32, 1234.0f32);
        test_array_negative_op!(Float64, Float64Array, 23456.0f64, 12345.0f64);
        Ok(())
    }

    #[test]
    fn test_evaluate_bounds() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let child_interval = Interval::make(Some(-2), Some(1))?;
        let negative_expr_interval = Interval::make(Some(-1), Some(2))?;
        assert_eq!(
            negative_expr.evaluate_bounds(&[&child_interval])?,
            negative_expr_interval
        );
        Ok(())
    }

    #[test]
    #[expect(deprecated)]
    fn test_evaluate_statistics() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));

        // Uniform: the range is mirrored around zero.
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_uniform(
                Interval::make(Some(-2.), Some(3.))?
            )?])?,
            Distribution::new_uniform(Interval::make(Some(-3.), Some(2.))?)?
        );

        // Bernoulli: boolean inputs cannot be negated.
        assert!(
            negative_expr
                .evaluate_statistics(&[&Distribution::new_bernoulli(ScalarValue::from(
                    0.75
                ))?])
                .is_err()
        );

        // Exponential: offset negated, tail direction flipped, rate kept.
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_exponential(
                ScalarValue::from(1.),
                ScalarValue::from(1.),
                true
            )?])?,
            Distribution::new_exponential(
                ScalarValue::from(1.),
                ScalarValue::from(-1.),
                false
            )?
        );

        // Gaussian: mean negated, variance kept.
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_gaussian(
                ScalarValue::from(15),
                ScalarValue::from(225),
            )?])?,
            Distribution::new_gaussian(ScalarValue::from(-15), ScalarValue::from(225),)?
        );

        // Generic: mean, median and range negated, variance kept.
        assert_eq!(
            negative_expr.evaluate_statistics(&[&Distribution::new_generic(
                ScalarValue::from(15),
                ScalarValue::from(15),
                ScalarValue::from(10),
                Interval::make(Some(10), Some(20))?
            )?])?,
            Distribution::new_generic(
                ScalarValue::from(-15),
                ScalarValue::from(-15),
                ScalarValue::from(10),
                Interval::make(Some(-20), Some(-10))?
            )?
        );

        Ok(())
    }

    #[test]
    fn test_propagate_constraints() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let original_child_interval = Interval::make(Some(-2), Some(3))?;
        let negative_expr_interval = Interval::make(Some(0), Some(4))?;
        let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]);
        assert_eq!(
            negative_expr.propagate_constraints(
                &negative_expr_interval,
                &[&original_child_interval]
            )?,
            after_propagation
        );
        Ok(())
    }

    #[test]
    #[expect(deprecated)]
    fn test_propagate_statistics_range_holders() -> Result<()> {
        let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let original_child_interval = Interval::make(Some(-2), Some(3))?;
        let after_propagation = Interval::make(Some(-2), Some(0))?;

        let parent = Distribution::new_uniform(Interval::make(Some(0), Some(4))?)?;
        let children: Vec<Vec<Distribution>> = vec![
            vec![Distribution::new_uniform(original_child_interval.clone())?],
            vec![Distribution::new_generic(
                ScalarValue::from(0),
                ScalarValue::from(0),
                ScalarValue::Int32(None),
                original_child_interval.clone(),
            )?],
        ];

        for child_view in children {
            let child_refs: Vec<_> = child_view.iter().collect();
            let actual = negative_expr.propagate_statistics(&parent, &child_refs)?;
            let expected = Some(vec![Distribution::new_from_interval(
                after_propagation.clone(),
            )?]);
            assert_eq!(actual, expected);
        }

        Ok(())
    }

    #[test]
    fn test_negation_valid_types() -> Result<()> {
        let negatable_types = [
            Int8,
            DataType::Timestamp(TimeUnit::Second, None),
            DataType::Interval(IntervalUnit::YearMonth),
        ];
        for negatable_type in negatable_types {
            let schema = Schema::new(vec![Field::new("a", negatable_type, true)]);
            let _expr = negative(col("a", &schema)?, &schema)?;
        }
        Ok(())
    }

    #[test]
    fn test_negation_invalid_types() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let err = negative(col("a", &schema)?, &schema).unwrap_err();
        // `matches!` only evaluates to a bool; it has to be asserted to check anything.
        assert!(
            matches!(err, DataFusionError::Plan(_)),
            "expected a planning error, got {err:?}"
        );
        Ok(())
    }

    #[test]
    fn test_fmt_sql() -> Result<()> {
        let expr = NegativeExpr::new(Arc::new(Column::new("a", 0)));
        let display_string = expr.to_string();
        assert_eq!(display_string, "(- a@0)");
        let sql_string = fmt_sql(&expr).to_string();
        assert_eq!(sql_string, "(- a)");

        Ok(())
    }
}