datafusion_functions_aggregate/
stddev.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines physical expressions that can be evaluated at runtime during query execution
19
20use std::any::Any;
21use std::fmt::{Debug, Formatter};
22use std::mem::align_of_val;
23use std::sync::Arc;
24
25use arrow::array::Float64Array;
26use arrow::datatypes::FieldRef;
27use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
28use datafusion_common::{internal_err, not_impl_err, Result};
29use datafusion_common::{plan_err, ScalarValue};
30use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
31use datafusion_expr::utils::format_state_name;
32use datafusion_expr::{
33    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
34    Volatility,
35};
36use datafusion_functions_aggregate_common::stats::StatsType;
37use datafusion_macros::user_doc;
38
39use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
40
// Registers the `Stddev` aggregate with the crate's UDAF helper macro.
// NOTE(review): the macro is defined elsewhere; presumably it generates the
// `stddev(expression)` expression builder and the `stddev_udaf()` accessor
// named below — confirm against the macro definition.
make_udaf_expr_and_func!(
    Stddev,
    stddev,
    expression,
    "Compute the standard deviation of a set of numbers",
    stddev_udaf
);
48
49#[user_doc(
50    doc_section(label = "Statistical Functions"),
51    description = "Returns the standard deviation of a set of numbers.",
52    syntax_example = "stddev(expression)",
53    sql_example = r#"```sql
54> SELECT stddev(column_name) FROM table_name;
55+----------------------+
56| stddev(column_name)   |
57+----------------------+
58| 12.34                |
59+----------------------+
60```"#,
61    standard_argument(name = "expression",)
62)]
63/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression
64pub struct Stddev {
65    signature: Signature,
66    alias: Vec<String>,
67}
68
69impl Debug for Stddev {
70    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("Stddev")
72            .field("name", &self.name())
73            .field("signature", &self.signature)
74            .finish()
75    }
76}
77
78impl Default for Stddev {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl Stddev {
85    /// Create a new STDDEV aggregate function
86    pub fn new() -> Self {
87        Self {
88            signature: Signature::numeric(1, Volatility::Immutable),
89            alias: vec!["stddev_samp".to_string()],
90        }
91    }
92}
93
94impl AggregateUDFImpl for Stddev {
95    /// Return a reference to Any that can be used for downcasting
96    fn as_any(&self) -> &dyn Any {
97        self
98    }
99
100    fn name(&self) -> &str {
101        "stddev"
102    }
103
104    fn signature(&self) -> &Signature {
105        &self.signature
106    }
107
108    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
109        Ok(DataType::Float64)
110    }
111
112    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
113        Ok(vec![
114            Field::new(
115                format_state_name(args.name, "count"),
116                DataType::UInt64,
117                true,
118            ),
119            Field::new(
120                format_state_name(args.name, "mean"),
121                DataType::Float64,
122                true,
123            ),
124            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
125        ]
126        .into_iter()
127        .map(Arc::new)
128        .collect())
129    }
130
131    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
132        if acc_args.is_distinct {
133            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
134        }
135        Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
136    }
137
138    fn aliases(&self) -> &[String] {
139        &self.alias
140    }
141
142    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
143        !acc_args.is_distinct
144    }
145
146    fn create_groups_accumulator(
147        &self,
148        _args: AccumulatorArgs,
149    ) -> Result<Box<dyn GroupsAccumulator>> {
150        Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
151    }
152
153    fn documentation(&self) -> Option<&Documentation> {
154        self.doc()
155    }
156}
157
// Registers the `StddevPop` aggregate with the crate's UDAF helper macro.
// The tests in this file call `stddev_pop_udaf()` and pass the result where an
// `Arc<AggregateUDF>` is expected.
// NOTE(review): the macro is defined elsewhere; presumably it also generates
// the `stddev_pop(expression)` expression builder — confirm against the macro.
make_udaf_expr_and_func!(
    StddevPop,
    stddev_pop,
    expression,
    "Compute the population standard deviation of a set of numbers",
    stddev_pop_udaf
);
165
166#[user_doc(
167    doc_section(label = "Statistical Functions"),
168    description = "Returns the population standard deviation of a set of numbers.",
169    syntax_example = "stddev_pop(expression)",
170    sql_example = r#"```sql
171> SELECT stddev_pop(column_name) FROM table_name;
172+--------------------------+
173| stddev_pop(column_name)   |
174+--------------------------+
175| 10.56                    |
176+--------------------------+
177```"#,
178    standard_argument(name = "expression",)
179)]
180/// STDDEV_POP population aggregate expression
181pub struct StddevPop {
182    signature: Signature,
183}
184
185impl Debug for StddevPop {
186    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
187        f.debug_struct("StddevPop")
188            .field("name", &self.name())
189            .field("signature", &self.signature)
190            .finish()
191    }
192}
193
194impl Default for StddevPop {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
200impl StddevPop {
201    /// Create a new STDDEV_POP aggregate function
202    pub fn new() -> Self {
203        Self {
204            signature: Signature::numeric(1, Volatility::Immutable),
205        }
206    }
207}
208
209impl AggregateUDFImpl for StddevPop {
210    /// Return a reference to Any that can be used for downcasting
211    fn as_any(&self) -> &dyn Any {
212        self
213    }
214
215    fn name(&self) -> &str {
216        "stddev_pop"
217    }
218
219    fn signature(&self) -> &Signature {
220        &self.signature
221    }
222
223    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
224        Ok(vec![
225            Field::new(
226                format_state_name(args.name, "count"),
227                DataType::UInt64,
228                true,
229            ),
230            Field::new(
231                format_state_name(args.name, "mean"),
232                DataType::Float64,
233                true,
234            ),
235            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
236        ]
237        .into_iter()
238        .map(Arc::new)
239        .collect())
240    }
241
242    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
243        if acc_args.is_distinct {
244            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
245        }
246        Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
247    }
248
249    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
250        if !arg_types[0].is_numeric() {
251            return plan_err!("StddevPop requires numeric input types");
252        }
253
254        Ok(DataType::Float64)
255    }
256
257    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
258        !acc_args.is_distinct
259    }
260
261    fn create_groups_accumulator(
262        &self,
263        _args: AccumulatorArgs,
264    ) -> Result<Box<dyn GroupsAccumulator>> {
265        Ok(Box::new(StddevGroupsAccumulator::new(
266            StatsType::Population,
267        )))
268    }
269
270    fn documentation(&self) -> Option<&Documentation> {
271        self.doc()
272    }
273}
274
/// An accumulator to compute the standard deviation
///
/// All state handling is delegated to the wrapped [`VarianceAccumulator`];
/// the final value is the square root of the variance it evaluates to.
#[derive(Debug)]
pub struct StddevAccumulator {
    /// Variance accumulator that holds the running state.
    variance: VarianceAccumulator,
}
280
281impl StddevAccumulator {
282    /// Creates a new `StddevAccumulator`
283    pub fn try_new(s_type: StatsType) -> Result<Self> {
284        Ok(Self {
285            variance: VarianceAccumulator::try_new(s_type)?,
286        })
287    }
288
289    pub fn get_m2(&self) -> f64 {
290        self.variance.get_m2()
291    }
292}
293
294impl Accumulator for StddevAccumulator {
295    fn state(&mut self) -> Result<Vec<ScalarValue>> {
296        Ok(vec![
297            ScalarValue::from(self.variance.get_count()),
298            ScalarValue::from(self.variance.get_mean()),
299            ScalarValue::from(self.variance.get_m2()),
300        ])
301    }
302
303    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
304        self.variance.update_batch(values)
305    }
306
307    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
308        self.variance.retract_batch(values)
309    }
310
311    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
312        self.variance.merge_batch(states)
313    }
314
315    fn evaluate(&mut self) -> Result<ScalarValue> {
316        let variance = self.variance.evaluate()?;
317        match variance {
318            ScalarValue::Float64(e) => {
319                if e.is_none() {
320                    Ok(ScalarValue::Float64(None))
321                } else {
322                    Ok(ScalarValue::Float64(e.map(|f| f.sqrt())))
323                }
324            }
325            _ => internal_err!("Variance should be f64"),
326        }
327    }
328
329    fn size(&self) -> usize {
330        align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
331    }
332
333    fn supports_retract_batch(&self) -> bool {
334        self.variance.supports_retract_batch()
335    }
336}
337
/// Grouped accumulator for standard deviation.
///
/// Updates, merges and state are delegated to the wrapped
/// [`VarianceGroupsAccumulator`]; evaluation takes the square root of each
/// variance it produces.
#[derive(Debug)]
pub struct StddevGroupsAccumulator {
    /// Grouped variance accumulator that holds the per-group state.
    variance: VarianceGroupsAccumulator,
}
342
343impl StddevGroupsAccumulator {
344    pub fn new(s_type: StatsType) -> Self {
345        Self {
346            variance: VarianceGroupsAccumulator::new(s_type),
347        }
348    }
349}
350
351impl GroupsAccumulator for StddevGroupsAccumulator {
352    fn update_batch(
353        &mut self,
354        values: &[ArrayRef],
355        group_indices: &[usize],
356        opt_filter: Option<&arrow::array::BooleanArray>,
357        total_num_groups: usize,
358    ) -> Result<()> {
359        self.variance
360            .update_batch(values, group_indices, opt_filter, total_num_groups)
361    }
362
363    fn merge_batch(
364        &mut self,
365        values: &[ArrayRef],
366        group_indices: &[usize],
367        opt_filter: Option<&arrow::array::BooleanArray>,
368        total_num_groups: usize,
369    ) -> Result<()> {
370        self.variance
371            .merge_batch(values, group_indices, opt_filter, total_num_groups)
372    }
373
374    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
375        let (mut variances, nulls) = self.variance.variance(emit_to);
376        variances.iter_mut().for_each(|v| *v = v.sqrt());
377        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
378    }
379
380    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
381        self.variance.state(emit_to)
382    }
383
384    fn size(&self) -> usize {
385        self.variance.size()
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::LexOrdering;
    use std::sync::Arc;

    /// Values 1..=5 split over two non-null batches: merging the partial
    /// states must give the population standard deviation `sqrt(2)`.
    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let first = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let second = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));
        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![first])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![second])?;

        let lhs_udaf = stddev_pop_udaf();
        let rhs_udaf = stddev_pop_udaf();

        let merged = merge(&batch1, &batch2, lhs_udaf, rhs_udaf, &schema)?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Values 1..=5 in one batch and a single null in the other: merging the
    /// partial states must still give `sqrt(2)`.
    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let first =
            Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let second = Arc::new(Float64Array::from(vec![None]));
        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![first])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![second])?;

        let lhs_udaf = stddev_pop_udaf();
        let rhs_udaf = stddev_pop_udaf();

        let merged = merge(&batch1, &batch2, lhs_udaf, rhs_udaf, &schema)?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Feeds column `a` of each batch into its own accumulator, merges the
    /// second accumulator's state into the first, and evaluates the first.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        let args1 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            ordering_req: &LexOrdering::default(),
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[col("a", schema)?],
        };

        let args2 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            ordering_req: &LexOrdering::default(),
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[col("a", schema)?],
        };

        let mut target = agg1.accumulator(args1)?;
        let mut partial = agg2.accumulator(args2)?;

        let input1 = vec![col("a", schema)?
            .evaluate(batch1)?
            .into_array(batch1.num_rows())?];
        let input2 = vec![col("a", schema)?
            .evaluate(batch2)?
            .into_array(batch2.num_rows())?];

        target.update_batch(&input1)?;
        partial.update_batch(&input2)?;

        // Only the second accumulator's serialized state is merged into the first.
        let partial_state = get_accum_scalar_values_as_arrays(partial.as_mut())?;
        target.merge_batch(&partial_state)?;

        target.evaluate()
    }
}
483}