datafusion_functions_aggregate/
stddev.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines physical expressions that can be evaluated at runtime during query execution
19
20use std::any::Any;
21use std::fmt::{Debug, Formatter};
22use std::hash::{DefaultHasher, Hash, Hasher};
23use std::mem::align_of_val;
24use std::sync::Arc;
25
26use arrow::array::Float64Array;
27use arrow::datatypes::FieldRef;
28use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
29use datafusion_common::{internal_err, not_impl_err, Result};
30use datafusion_common::{plan_err, ScalarValue};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::utils::format_state_name;
33use datafusion_expr::{
34    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
35    Volatility,
36};
37use datafusion_functions_aggregate_common::stats::StatsType;
38use datafusion_macros::user_doc;
39
40use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
41
// Declares the expression-building helpers for `Stddev`: presumably the `stddev`
// expression function and the `stddev_udaf()` accessor. See the definition of
// `make_udaf_expr_and_func!` for the exact items generated.
make_udaf_expr_and_func!(
    Stddev,
    stddev,
    expression,
    "Compute the standard deviation of a set of numbers",
    stddev_udaf
);
49
50#[user_doc(
51    doc_section(label = "Statistical Functions"),
52    description = "Returns the standard deviation of a set of numbers.",
53    syntax_example = "stddev(expression)",
54    sql_example = r#"```sql
55> SELECT stddev(column_name) FROM table_name;
56+----------------------+
57| stddev(column_name)   |
58+----------------------+
59| 12.34                |
60+----------------------+
61```"#,
62    standard_argument(name = "expression",)
63)]
64/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression
65pub struct Stddev {
66    signature: Signature,
67    alias: Vec<String>,
68}
69
70impl Debug for Stddev {
71    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
72        f.debug_struct("Stddev")
73            .field("name", &self.name())
74            .field("signature", &self.signature)
75            .finish()
76    }
77}
78
79impl Default for Stddev {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl Stddev {
86    /// Create a new STDDEV aggregate function
87    pub fn new() -> Self {
88        Self {
89            signature: Signature::numeric(1, Volatility::Immutable),
90            alias: vec!["stddev_samp".to_string()],
91        }
92    }
93}
94
95impl AggregateUDFImpl for Stddev {
96    /// Return a reference to Any that can be used for downcasting
97    fn as_any(&self) -> &dyn Any {
98        self
99    }
100
101    fn name(&self) -> &str {
102        "stddev"
103    }
104
105    fn signature(&self) -> &Signature {
106        &self.signature
107    }
108
109    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
110        Ok(DataType::Float64)
111    }
112
113    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
114        Ok(vec![
115            Field::new(
116                format_state_name(args.name, "count"),
117                DataType::UInt64,
118                true,
119            ),
120            Field::new(
121                format_state_name(args.name, "mean"),
122                DataType::Float64,
123                true,
124            ),
125            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
126        ]
127        .into_iter()
128        .map(Arc::new)
129        .collect())
130    }
131
132    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
133        if acc_args.is_distinct {
134            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
135        }
136        Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
137    }
138
139    fn aliases(&self) -> &[String] {
140        &self.alias
141    }
142
143    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
144        !acc_args.is_distinct
145    }
146
147    fn create_groups_accumulator(
148        &self,
149        _args: AccumulatorArgs,
150    ) -> Result<Box<dyn GroupsAccumulator>> {
151        Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
152    }
153
154    fn documentation(&self) -> Option<&Documentation> {
155        self.doc()
156    }
157
158    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
159        let Some(other) = other.as_any().downcast_ref::<Self>() else {
160            return false;
161        };
162        let Self { signature, alias } = self;
163        signature == &other.signature && alias == &other.alias
164    }
165
166    fn hash_value(&self) -> u64 {
167        let Self { signature, alias } = self;
168        let mut hasher = DefaultHasher::new();
169        std::any::type_name::<Self>().hash(&mut hasher);
170        signature.hash(&mut hasher);
171        alias.hash(&mut hasher);
172        hasher.finish()
173    }
174}
175
// Declares the expression-building helpers for `StddevPop`: presumably the
// `stddev_pop` expression function and the `stddev_pop_udaf()` accessor (the
// latter is used by the tests below). See `make_udaf_expr_and_func!` for details.
make_udaf_expr_and_func!(
    StddevPop,
    stddev_pop,
    expression,
    "Compute the population standard deviation of a set of numbers",
    stddev_pop_udaf
);
183
184#[user_doc(
185    doc_section(label = "Statistical Functions"),
186    description = "Returns the population standard deviation of a set of numbers.",
187    syntax_example = "stddev_pop(expression)",
188    sql_example = r#"```sql
189> SELECT stddev_pop(column_name) FROM table_name;
190+--------------------------+
191| stddev_pop(column_name)   |
192+--------------------------+
193| 10.56                    |
194+--------------------------+
195```"#,
196    standard_argument(name = "expression",)
197)]
198/// STDDEV_POP population aggregate expression
199pub struct StddevPop {
200    signature: Signature,
201}
202
203impl Debug for StddevPop {
204    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
205        f.debug_struct("StddevPop")
206            .field("name", &self.name())
207            .field("signature", &self.signature)
208            .finish()
209    }
210}
211
212impl Default for StddevPop {
213    fn default() -> Self {
214        Self::new()
215    }
216}
217
218impl StddevPop {
219    /// Create a new STDDEV_POP aggregate function
220    pub fn new() -> Self {
221        Self {
222            signature: Signature::numeric(1, Volatility::Immutable),
223        }
224    }
225}
226
227impl AggregateUDFImpl for StddevPop {
228    /// Return a reference to Any that can be used for downcasting
229    fn as_any(&self) -> &dyn Any {
230        self
231    }
232
233    fn name(&self) -> &str {
234        "stddev_pop"
235    }
236
237    fn signature(&self) -> &Signature {
238        &self.signature
239    }
240
241    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
242        Ok(vec![
243            Field::new(
244                format_state_name(args.name, "count"),
245                DataType::UInt64,
246                true,
247            ),
248            Field::new(
249                format_state_name(args.name, "mean"),
250                DataType::Float64,
251                true,
252            ),
253            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
254        ]
255        .into_iter()
256        .map(Arc::new)
257        .collect())
258    }
259
260    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
261        if acc_args.is_distinct {
262            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
263        }
264        Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
265    }
266
267    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
268        if !arg_types[0].is_numeric() {
269            return plan_err!("StddevPop requires numeric input types");
270        }
271
272        Ok(DataType::Float64)
273    }
274
275    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
276        !acc_args.is_distinct
277    }
278
279    fn create_groups_accumulator(
280        &self,
281        _args: AccumulatorArgs,
282    ) -> Result<Box<dyn GroupsAccumulator>> {
283        Ok(Box::new(StddevGroupsAccumulator::new(
284            StatsType::Population,
285        )))
286    }
287
288    fn documentation(&self) -> Option<&Documentation> {
289        self.doc()
290    }
291}
292
293/// An accumulator to compute the average
294#[derive(Debug)]
295pub struct StddevAccumulator {
296    variance: VarianceAccumulator,
297}
298
299impl StddevAccumulator {
300    /// Creates a new `StddevAccumulator`
301    pub fn try_new(s_type: StatsType) -> Result<Self> {
302        Ok(Self {
303            variance: VarianceAccumulator::try_new(s_type)?,
304        })
305    }
306
307    pub fn get_m2(&self) -> f64 {
308        self.variance.get_m2()
309    }
310}
311
312impl Accumulator for StddevAccumulator {
313    fn state(&mut self) -> Result<Vec<ScalarValue>> {
314        Ok(vec![
315            ScalarValue::from(self.variance.get_count()),
316            ScalarValue::from(self.variance.get_mean()),
317            ScalarValue::from(self.variance.get_m2()),
318        ])
319    }
320
321    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
322        self.variance.update_batch(values)
323    }
324
325    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
326        self.variance.retract_batch(values)
327    }
328
329    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
330        self.variance.merge_batch(states)
331    }
332
333    fn evaluate(&mut self) -> Result<ScalarValue> {
334        let variance = self.variance.evaluate()?;
335        match variance {
336            ScalarValue::Float64(e) => {
337                if e.is_none() {
338                    Ok(ScalarValue::Float64(None))
339                } else {
340                    Ok(ScalarValue::Float64(e.map(|f| f.sqrt())))
341                }
342            }
343            _ => internal_err!("Variance should be f64"),
344        }
345    }
346
347    fn size(&self) -> usize {
348        align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
349    }
350
351    fn supports_retract_batch(&self) -> bool {
352        self.variance.supports_retract_batch()
353    }
354}
355
356#[derive(Debug)]
357pub struct StddevGroupsAccumulator {
358    variance: VarianceGroupsAccumulator,
359}
360
361impl StddevGroupsAccumulator {
362    pub fn new(s_type: StatsType) -> Self {
363        Self {
364            variance: VarianceGroupsAccumulator::new(s_type),
365        }
366    }
367}
368
369impl GroupsAccumulator for StddevGroupsAccumulator {
370    fn update_batch(
371        &mut self,
372        values: &[ArrayRef],
373        group_indices: &[usize],
374        opt_filter: Option<&arrow::array::BooleanArray>,
375        total_num_groups: usize,
376    ) -> Result<()> {
377        self.variance
378            .update_batch(values, group_indices, opt_filter, total_num_groups)
379    }
380
381    fn merge_batch(
382        &mut self,
383        values: &[ArrayRef],
384        group_indices: &[usize],
385        opt_filter: Option<&arrow::array::BooleanArray>,
386        total_num_groups: usize,
387    ) -> Result<()> {
388        self.variance
389            .merge_batch(values, group_indices, opt_filter, total_num_groups)
390    }
391
392    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
393        let (mut variances, nulls) = self.variance.variance(emit_to);
394        variances.iter_mut().for_each(|v| *v = v.sqrt());
395        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
396    }
397
398    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
399        self.variance.state(emit_to)
400    }
401
402    fn size(&self) -> usize {
403        self.variance.size()
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;
    use std::sync::Arc;

    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let left_values = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let right_values = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));
        let left = RecordBatch::try_new(Arc::new(schema.clone()), vec![left_values])?;
        let right = RecordBatch::try_new(Arc::new(schema.clone()), vec![right_values])?;

        let pop_left = stddev_pop_udaf();
        let pop_right = stddev_pop_udaf();

        let actual = merge(&left, &right, pop_left, pop_right, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let left_values =
            Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let right_values = Arc::new(Float64Array::from(vec![None]));
        let left = RecordBatch::try_new(Arc::new(schema.clone()), vec![left_values])?;
        let right = RecordBatch::try_new(Arc::new(schema.clone()), vec![right_values])?;

        let pop_left = stddev_pop_udaf();
        let pop_right = stddev_pop_udaf();

        let actual = merge(&left, &right, pop_left, pop_right, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Feeds `batch1` into an accumulator created from `agg1` and `batch2` into
    /// one created from `agg2`. It then merges the second accumulator's state
    /// into the first and returns the first accumulator's final value.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        let exprs = [col("a", schema)?];

        let args1 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs,
        };

        let args2 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs,
        };

        let mut target = agg1.accumulator(args1)?;
        let mut source = agg2.accumulator(args2)?;

        let target_input = vec![exprs[0]
            .evaluate(batch1)
            .and_then(|v| v.into_array(batch1.num_rows()))?];
        let source_input = vec![exprs[0]
            .evaluate(batch2)
            .and_then(|v| v.into_array(batch2.num_rows()))?];

        target.update_batch(&target_input)?;
        source.update_batch(&source_input)?;

        let source_state = get_accum_scalar_values_as_arrays(source.as_mut())?;
        target.merge_batch(&source_state)?;
        target.evaluate()
    }
}