datafusion_functions_aggregate/
stddev.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines the standard deviation aggregate functions (`stddev`/`stddev_samp` and `stddev_pop`), which are evaluated at runtime during query execution.
19
20use std::any::Any;
21use std::fmt::{Debug, Formatter};
22use std::hash::Hash;
23use std::mem::align_of_val;
24use std::sync::Arc;
25
26use arrow::array::Float64Array;
27use arrow::datatypes::FieldRef;
28use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
29use datafusion_common::{internal_err, not_impl_err, Result};
30use datafusion_common::{plan_err, ScalarValue};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::utils::format_state_name;
33use datafusion_expr::{
34    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
35    Volatility,
36};
37use datafusion_functions_aggregate_common::stats::StatsType;
38use datafusion_macros::user_doc;
39
40use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
41
42make_udaf_expr_and_func!(
43    Stddev,
44    stddev,
45    expression,
46    "Compute the standard deviation of a set of numbers",
47    stddev_udaf
48);
49
50#[user_doc(
51    doc_section(label = "Statistical Functions"),
52    description = "Returns the standard deviation of a set of numbers.",
53    syntax_example = "stddev(expression)",
54    sql_example = r#"```sql
55> SELECT stddev(column_name) FROM table_name;
56+----------------------+
57| stddev(column_name)   |
58+----------------------+
59| 12.34                |
60+----------------------+
61```"#,
62    standard_argument(name = "expression",)
63)]
64/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression
65#[derive(PartialEq, Eq, Hash)]
66pub struct Stddev {
67    signature: Signature,
68    alias: Vec<String>,
69}
70
71impl Debug for Stddev {
72    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("Stddev")
74            .field("name", &self.name())
75            .field("signature", &self.signature)
76            .finish()
77    }
78}
79
80impl Default for Stddev {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl Stddev {
87    /// Create a new STDDEV aggregate function
88    pub fn new() -> Self {
89        Self {
90            signature: Signature::numeric(1, Volatility::Immutable),
91            alias: vec!["stddev_samp".to_string()],
92        }
93    }
94}
95
96impl AggregateUDFImpl for Stddev {
97    /// Return a reference to Any that can be used for downcasting
98    fn as_any(&self) -> &dyn Any {
99        self
100    }
101
102    fn name(&self) -> &str {
103        "stddev"
104    }
105
106    fn signature(&self) -> &Signature {
107        &self.signature
108    }
109
110    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
111        Ok(DataType::Float64)
112    }
113
114    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
115        Ok(vec![
116            Field::new(
117                format_state_name(args.name, "count"),
118                DataType::UInt64,
119                true,
120            ),
121            Field::new(
122                format_state_name(args.name, "mean"),
123                DataType::Float64,
124                true,
125            ),
126            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
127        ]
128        .into_iter()
129        .map(Arc::new)
130        .collect())
131    }
132
133    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
134        if acc_args.is_distinct {
135            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
136        }
137        Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
138    }
139
140    fn aliases(&self) -> &[String] {
141        &self.alias
142    }
143
144    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
145        !acc_args.is_distinct
146    }
147
148    fn create_groups_accumulator(
149        &self,
150        _args: AccumulatorArgs,
151    ) -> Result<Box<dyn GroupsAccumulator>> {
152        Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
153    }
154
155    fn documentation(&self) -> Option<&Documentation> {
156        self.doc()
157    }
158}
159
160make_udaf_expr_and_func!(
161    StddevPop,
162    stddev_pop,
163    expression,
164    "Compute the population standard deviation of a set of numbers",
165    stddev_pop_udaf
166);
167
168#[user_doc(
169    doc_section(label = "Statistical Functions"),
170    description = "Returns the population standard deviation of a set of numbers.",
171    syntax_example = "stddev_pop(expression)",
172    sql_example = r#"```sql
173> SELECT stddev_pop(column_name) FROM table_name;
174+--------------------------+
175| stddev_pop(column_name)   |
176+--------------------------+
177| 10.56                    |
178+--------------------------+
179```"#,
180    standard_argument(name = "expression",)
181)]
182/// STDDEV_POP population aggregate expression
183#[derive(PartialEq, Eq, Hash)]
184pub struct StddevPop {
185    signature: Signature,
186}
187
188impl Debug for StddevPop {
189    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
190        f.debug_struct("StddevPop")
191            .field("name", &self.name())
192            .field("signature", &self.signature)
193            .finish()
194    }
195}
196
197impl Default for StddevPop {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203impl StddevPop {
204    /// Create a new STDDEV_POP aggregate function
205    pub fn new() -> Self {
206        Self {
207            signature: Signature::numeric(1, Volatility::Immutable),
208        }
209    }
210}
211
212impl AggregateUDFImpl for StddevPop {
213    /// Return a reference to Any that can be used for downcasting
214    fn as_any(&self) -> &dyn Any {
215        self
216    }
217
218    fn name(&self) -> &str {
219        "stddev_pop"
220    }
221
222    fn signature(&self) -> &Signature {
223        &self.signature
224    }
225
226    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
227        Ok(vec![
228            Field::new(
229                format_state_name(args.name, "count"),
230                DataType::UInt64,
231                true,
232            ),
233            Field::new(
234                format_state_name(args.name, "mean"),
235                DataType::Float64,
236                true,
237            ),
238            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
239        ]
240        .into_iter()
241        .map(Arc::new)
242        .collect())
243    }
244
245    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
246        if acc_args.is_distinct {
247            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
248        }
249        Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
250    }
251
252    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
253        if !arg_types[0].is_numeric() {
254            return plan_err!("StddevPop requires numeric input types");
255        }
256
257        Ok(DataType::Float64)
258    }
259
260    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
261        !acc_args.is_distinct
262    }
263
264    fn create_groups_accumulator(
265        &self,
266        _args: AccumulatorArgs,
267    ) -> Result<Box<dyn GroupsAccumulator>> {
268        Ok(Box::new(StddevGroupsAccumulator::new(
269            StatsType::Population,
270        )))
271    }
272
273    fn documentation(&self) -> Option<&Documentation> {
274        self.doc()
275    }
276}
277
278/// An accumulator to compute the average
279#[derive(Debug)]
280pub struct StddevAccumulator {
281    variance: VarianceAccumulator,
282}
283
284impl StddevAccumulator {
285    /// Creates a new `StddevAccumulator`
286    pub fn try_new(s_type: StatsType) -> Result<Self> {
287        Ok(Self {
288            variance: VarianceAccumulator::try_new(s_type)?,
289        })
290    }
291
292    pub fn get_m2(&self) -> f64 {
293        self.variance.get_m2()
294    }
295}
296
297impl Accumulator for StddevAccumulator {
298    fn state(&mut self) -> Result<Vec<ScalarValue>> {
299        Ok(vec![
300            ScalarValue::from(self.variance.get_count()),
301            ScalarValue::from(self.variance.get_mean()),
302            ScalarValue::from(self.variance.get_m2()),
303        ])
304    }
305
306    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
307        self.variance.update_batch(values)
308    }
309
310    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
311        self.variance.retract_batch(values)
312    }
313
314    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
315        self.variance.merge_batch(states)
316    }
317
318    fn evaluate(&mut self) -> Result<ScalarValue> {
319        let variance = self.variance.evaluate()?;
320        match variance {
321            ScalarValue::Float64(e) => {
322                if e.is_none() {
323                    Ok(ScalarValue::Float64(None))
324                } else {
325                    Ok(ScalarValue::Float64(e.map(|f| f.sqrt())))
326                }
327            }
328            _ => internal_err!("Variance should be f64"),
329        }
330    }
331
332    fn size(&self) -> usize {
333        align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
334    }
335
336    fn supports_retract_batch(&self) -> bool {
337        self.variance.supports_retract_batch()
338    }
339}
340
341#[derive(Debug)]
342pub struct StddevGroupsAccumulator {
343    variance: VarianceGroupsAccumulator,
344}
345
346impl StddevGroupsAccumulator {
347    pub fn new(s_type: StatsType) -> Self {
348        Self {
349            variance: VarianceGroupsAccumulator::new(s_type),
350        }
351    }
352}
353
354impl GroupsAccumulator for StddevGroupsAccumulator {
355    fn update_batch(
356        &mut self,
357        values: &[ArrayRef],
358        group_indices: &[usize],
359        opt_filter: Option<&arrow::array::BooleanArray>,
360        total_num_groups: usize,
361    ) -> Result<()> {
362        self.variance
363            .update_batch(values, group_indices, opt_filter, total_num_groups)
364    }
365
366    fn merge_batch(
367        &mut self,
368        values: &[ArrayRef],
369        group_indices: &[usize],
370        opt_filter: Option<&arrow::array::BooleanArray>,
371        total_num_groups: usize,
372    ) -> Result<()> {
373        self.variance
374            .merge_batch(values, group_indices, opt_filter, total_num_groups)
375    }
376
377    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
378        let (mut variances, nulls) = self.variance.variance(emit_to);
379        variances.iter_mut().for_each(|v| *v = v.sqrt());
380        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
381    }
382
383    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
384        self.variance.state(emit_to)
385    }
386
387    fn size(&self) -> usize {
388        self.variance.size()
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;
    use std::sync::Arc;

    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_pop_udaf();
        let agg2 = stddev_pop_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let b = Arc::new(Float64Array::from(vec![None]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_pop_udaf();
        let agg2 = stddev_pop_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// The sample variant (`stddev` / `stddev_samp`) of 1..=5, split across
    /// two partial accumulators, is sqrt(10 / 4) = sqrt(2.5).
    #[test]
    fn stddev_samp_f64_merge() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let actual = merge(&batch1, &batch2, stddev_udaf(), stddev_udaf(), &schema)?;
        assert_eq!(actual, ScalarValue::from(2.5_f64.sqrt()));

        Ok(())
    }

    /// A single non-null value has no sample standard deviation, so the
    /// result is a `Float64` null rather than zero.
    #[test]
    fn stddev_samp_single_value_is_null() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64]));
        let b = Arc::new(Float64Array::from(vec![None]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let actual = merge(&batch1, &batch2, stddev_udaf(), stddev_udaf(), &schema)?;
        assert_eq!(actual, ScalarValue::Float64(None));

        Ok(())
    }

    /// Feeds `batch1` into an accumulator from `agg1` and `batch2` into one
    /// from `agg2`. It then merges the second accumulator's state into the
    /// first and returns the first accumulator's final value.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        let args1 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[col("a", schema)?],
        };

        let args2 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[col("a", schema)?],
        };

        let mut accum1 = agg1.accumulator(args1)?;
        let mut accum2 = agg2.accumulator(args2)?;

        let value1 = vec![col("a", schema)?
            .evaluate(batch1)
            .and_then(|v| v.into_array(batch1.num_rows()))?];
        let value2 = vec![col("a", schema)?
            .evaluate(batch2)
            .and_then(|v| v.into_array(batch2.num_rows()))?];

        accum1.update_batch(&value1)?;
        accum2.update_batch(&value2)?;
        let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?;
        accum1.merge_batch(&state2)?;
        let result = accum1.evaluate()?;
        Ok(result)
    }
}