datafusion_functions_aggregate/
stddev.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines the `stddev` (alias `stddev_samp`) and `stddev_pop` aggregate functions, which can be evaluated at runtime during query execution
19
20use std::any::Any;
21use std::fmt::{Debug, Formatter};
22use std::mem::align_of_val;
23use std::sync::Arc;
24
25use arrow::array::Float64Array;
26use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
27
28use datafusion_common::{internal_err, not_impl_err, Result};
29use datafusion_common::{plan_err, ScalarValue};
30use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
31use datafusion_expr::utils::format_state_name;
32use datafusion_expr::{
33    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
34    Volatility,
35};
36use datafusion_functions_aggregate_common::stats::StatsType;
37use datafusion_macros::user_doc;
38
39use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
40
// Declares the `Stddev` aggregate through the crate's `make_udaf_expr_and_func!`
// macro: `stddev_udaf()` is the accessor for the shared UDAF instance and
// `stddev` is (judging by the macro name) the expression-building helper.
make_udaf_expr_and_func!(
    Stddev,
    stddev,
    expression,
    "Compute the standard deviation of a set of numbers",
    stddev_udaf
);

#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the standard deviation of a set of numbers.",
    syntax_example = "stddev(expression)",
    sql_example = r#"```sql
> SELECT stddev(column_name) FROM table_name;
+----------------------+
| stddev(column_name)   |
+----------------------+
| 12.34                |
+----------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// STDDEV and STDDEV_SAMP (sample standard deviation) aggregate expression.
///
/// Accumulation is delegated to the variance accumulators created with
/// [`StatsType::Sample`]; the final result is the square root of the variance.
pub struct Stddev {
    /// Single numeric argument, immutable volatility (see [`Stddev::new`]).
    signature: Signature,
    /// Alternative names under which this function is registered
    /// (`stddev_samp`).
    alias: Vec<String>,
}
68
69impl Debug for Stddev {
70    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("Stddev")
72            .field("name", &self.name())
73            .field("signature", &self.signature)
74            .finish()
75    }
76}
77
78impl Default for Stddev {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl Stddev {
85    /// Create a new STDDEV aggregate function
86    pub fn new() -> Self {
87        Self {
88            signature: Signature::numeric(1, Volatility::Immutable),
89            alias: vec!["stddev_samp".to_string()],
90        }
91    }
92}
93
94impl AggregateUDFImpl for Stddev {
95    /// Return a reference to Any that can be used for downcasting
96    fn as_any(&self) -> &dyn Any {
97        self
98    }
99
100    fn name(&self) -> &str {
101        "stddev"
102    }
103
104    fn signature(&self) -> &Signature {
105        &self.signature
106    }
107
108    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
109        Ok(DataType::Float64)
110    }
111
112    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
113        Ok(vec![
114            Field::new(
115                format_state_name(args.name, "count"),
116                DataType::UInt64,
117                true,
118            ),
119            Field::new(
120                format_state_name(args.name, "mean"),
121                DataType::Float64,
122                true,
123            ),
124            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
125        ])
126    }
127
128    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
129        if acc_args.is_distinct {
130            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
131        }
132        Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
133    }
134
135    fn aliases(&self) -> &[String] {
136        &self.alias
137    }
138
139    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
140        !acc_args.is_distinct
141    }
142
143    fn create_groups_accumulator(
144        &self,
145        _args: AccumulatorArgs,
146    ) -> Result<Box<dyn GroupsAccumulator>> {
147        Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
148    }
149
150    fn documentation(&self) -> Option<&Documentation> {
151        self.doc()
152    }
153}
154
// Declares the `StddevPop` aggregate through the crate's
// `make_udaf_expr_and_func!` macro: `stddev_pop_udaf()` returns the shared
// `AggregateUDF` (the tests call it that way) and `stddev_pop` is (judging by
// the macro name) the expression-building helper.
make_udaf_expr_and_func!(
    StddevPop,
    stddev_pop,
    expression,
    "Compute the population standard deviation of a set of numbers",
    stddev_pop_udaf
);

#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the population standard deviation of a set of numbers.",
    syntax_example = "stddev_pop(expression)",
    sql_example = r#"```sql
> SELECT stddev_pop(column_name) FROM table_name;
+--------------------------+
| stddev_pop(column_name)   |
+--------------------------+
| 10.56                    |
+--------------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// STDDEV_POP population standard deviation aggregate expression.
///
/// Accumulation is delegated to the variance accumulators created with
/// [`StatsType::Population`]; the final result is the square root of the
/// variance.
pub struct StddevPop {
    /// Single numeric argument, immutable volatility (see [`StddevPop::new`]).
    signature: Signature,
}
181
182impl Debug for StddevPop {
183    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
184        f.debug_struct("StddevPop")
185            .field("name", &self.name())
186            .field("signature", &self.signature)
187            .finish()
188    }
189}
190
191impl Default for StddevPop {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
197impl StddevPop {
198    /// Create a new STDDEV_POP aggregate function
199    pub fn new() -> Self {
200        Self {
201            signature: Signature::numeric(1, Volatility::Immutable),
202        }
203    }
204}
205
206impl AggregateUDFImpl for StddevPop {
207    /// Return a reference to Any that can be used for downcasting
208    fn as_any(&self) -> &dyn Any {
209        self
210    }
211
212    fn name(&self) -> &str {
213        "stddev_pop"
214    }
215
216    fn signature(&self) -> &Signature {
217        &self.signature
218    }
219
220    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
221        Ok(vec![
222            Field::new(
223                format_state_name(args.name, "count"),
224                DataType::UInt64,
225                true,
226            ),
227            Field::new(
228                format_state_name(args.name, "mean"),
229                DataType::Float64,
230                true,
231            ),
232            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
233        ])
234    }
235
236    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
237        if acc_args.is_distinct {
238            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
239        }
240        Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
241    }
242
243    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
244        if !arg_types[0].is_numeric() {
245            return plan_err!("StddevPop requires numeric input types");
246        }
247
248        Ok(DataType::Float64)
249    }
250
251    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
252        !acc_args.is_distinct
253    }
254
255    fn create_groups_accumulator(
256        &self,
257        _args: AccumulatorArgs,
258    ) -> Result<Box<dyn GroupsAccumulator>> {
259        Ok(Box::new(StddevGroupsAccumulator::new(
260            StatsType::Population,
261        )))
262    }
263
264    fn documentation(&self) -> Option<&Documentation> {
265        self.doc()
266    }
267}
268
269/// An accumulator to compute the average
270#[derive(Debug)]
271pub struct StddevAccumulator {
272    variance: VarianceAccumulator,
273}
274
275impl StddevAccumulator {
276    /// Creates a new `StddevAccumulator`
277    pub fn try_new(s_type: StatsType) -> Result<Self> {
278        Ok(Self {
279            variance: VarianceAccumulator::try_new(s_type)?,
280        })
281    }
282
283    pub fn get_m2(&self) -> f64 {
284        self.variance.get_m2()
285    }
286}
287
288impl Accumulator for StddevAccumulator {
289    fn state(&mut self) -> Result<Vec<ScalarValue>> {
290        Ok(vec![
291            ScalarValue::from(self.variance.get_count()),
292            ScalarValue::from(self.variance.get_mean()),
293            ScalarValue::from(self.variance.get_m2()),
294        ])
295    }
296
297    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
298        self.variance.update_batch(values)
299    }
300
301    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
302        self.variance.retract_batch(values)
303    }
304
305    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
306        self.variance.merge_batch(states)
307    }
308
309    fn evaluate(&mut self) -> Result<ScalarValue> {
310        let variance = self.variance.evaluate()?;
311        match variance {
312            ScalarValue::Float64(e) => {
313                if e.is_none() {
314                    Ok(ScalarValue::Float64(None))
315                } else {
316                    Ok(ScalarValue::Float64(e.map(|f| f.sqrt())))
317                }
318            }
319            _ => internal_err!("Variance should be f64"),
320        }
321    }
322
323    fn size(&self) -> usize {
324        align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
325    }
326
327    fn supports_retract_batch(&self) -> bool {
328        self.variance.supports_retract_batch()
329    }
330}
331
332#[derive(Debug)]
333pub struct StddevGroupsAccumulator {
334    variance: VarianceGroupsAccumulator,
335}
336
337impl StddevGroupsAccumulator {
338    pub fn new(s_type: StatsType) -> Self {
339        Self {
340            variance: VarianceGroupsAccumulator::new(s_type),
341        }
342    }
343}
344
345impl GroupsAccumulator for StddevGroupsAccumulator {
346    fn update_batch(
347        &mut self,
348        values: &[ArrayRef],
349        group_indices: &[usize],
350        opt_filter: Option<&arrow::array::BooleanArray>,
351        total_num_groups: usize,
352    ) -> Result<()> {
353        self.variance
354            .update_batch(values, group_indices, opt_filter, total_num_groups)
355    }
356
357    fn merge_batch(
358        &mut self,
359        values: &[ArrayRef],
360        group_indices: &[usize],
361        opt_filter: Option<&arrow::array::BooleanArray>,
362        total_num_groups: usize,
363    ) -> Result<()> {
364        self.variance
365            .merge_batch(values, group_indices, opt_filter, total_num_groups)
366    }
367
368    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
369        let (mut variances, nulls) = self.variance.variance(emit_to);
370        variances.iter_mut().for_each(|v| *v = v.sqrt());
371        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
372    }
373
374    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
375        self.variance.state(emit_to)
376    }
377
378    fn size(&self) -> usize {
379        self.variance.size()
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::LexOrdering;
    use std::sync::Arc;

    /// Population stddev of [1, 2, 3] merged with [4, 5] is sqrt(2).
    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let left = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let right = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));
        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![left])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![right])?;

        let (agg1, agg2) = (stddev_pop_udaf(), stddev_pop_udaf());
        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));
        Ok(())
    }

    /// A batch containing only NULL does not change the result: population
    /// stddev of [1, 2, 3, 4, 5] is still sqrt(2).
    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let left = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let right = Arc::new(Float64Array::from(vec![None]));
        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![left])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![right])?;

        let (agg1, agg2) = (stddev_pop_udaf(), stddev_pop_udaf());
        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));
        Ok(())
    }

    /// Feeds `batch1` into an accumulator from `agg1` and `batch2` into one
    /// from `agg2`, merges the second accumulator's state into the first, and
    /// returns the first accumulator's final value.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        let ordering = LexOrdering::default();
        let exprs = [col("a", schema)?];

        let mut accum1 = agg1.accumulator(AccumulatorArgs {
            return_type: &DataType::Float64,
            schema,
            ignore_nulls: false,
            ordering_req: &ordering,
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs,
        })?;
        let mut accum2 = agg2.accumulator(AccumulatorArgs {
            return_type: &DataType::Float64,
            schema,
            ignore_nulls: false,
            ordering_req: &ordering,
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs,
        })?;

        let column = col("a", schema)?;
        let values1 = vec![column
            .evaluate(batch1)
            .and_then(|v| v.into_array(batch1.num_rows()))?];
        let values2 = vec![column
            .evaluate(batch2)
            .and_then(|v| v.into_array(batch2.num_rows()))?];

        accum1.update_batch(&values1)?;
        accum2.update_batch(&values2)?;

        let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?;
        accum1.merge_batch(&state2)?;
        accum1.evaluate()
    }
}