Skip to main content

datafusion_functions_aggregate/
stddev.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines physical expressions that can be evaluated at runtime during query execution:
//! the `stddev` (alias `stddev_samp`) and `stddev_pop` aggregate functions.
19
20use std::fmt::Debug;
21use std::hash::Hash;
22use std::mem::align_of_val;
23use std::sync::Arc;
24
25use arrow::array::Float64Array;
26use arrow::datatypes::FieldRef;
27use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
28use datafusion_common::ScalarValue;
29use datafusion_common::{Result, internal_err, not_impl_err};
30use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
31use datafusion_expr::utils::format_state_name;
32use datafusion_expr::{
33    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
34    Volatility,
35};
36use datafusion_functions_aggregate_common::stats::StatsType;
37use datafusion_macros::user_doc;
38
39use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
40
// Wires the `Stddev` UDAF into the crate via `make_udaf_expr_and_func!`.
// Judging by the arguments, this yields a `stddev` expression helper taking a
// single `expression` and a `stddev_udaf()` accessor; see the macro definition
// for the exact expansion.
make_udaf_expr_and_func!(
    Stddev,
    stddev,
    expression,
    "Compute the standard deviation of a set of numbers",
    stddev_udaf
);
48
49#[user_doc(
50    doc_section(label = "Statistical Functions"),
51    description = "Returns the standard deviation of a set of numbers.",
52    syntax_example = "stddev(expression)",
53    sql_example = r#"```sql
54> SELECT stddev(column_name) FROM table_name;
55+----------------------+
56| stddev(column_name)   |
57+----------------------+
58| 12.34                |
59+----------------------+
60```"#,
61    standard_argument(name = "expression",)
62)]
63/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression
64#[derive(PartialEq, Eq, Hash, Debug)]
65pub struct Stddev {
66    signature: Signature,
67    alias: Vec<String>,
68}
69
70impl Default for Stddev {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl Stddev {
77    /// Create a new STDDEV aggregate function
78    pub fn new() -> Self {
79        Self {
80            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
81            alias: vec!["stddev_samp".to_string()],
82        }
83    }
84}
85
86impl AggregateUDFImpl for Stddev {
87    fn name(&self) -> &str {
88        "stddev"
89    }
90
91    fn signature(&self) -> &Signature {
92        &self.signature
93    }
94
95    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
96        Ok(DataType::Float64)
97    }
98
99    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
100        Ok(vec![
101            Field::new(
102                format_state_name(args.name, "count"),
103                DataType::UInt64,
104                true,
105            ),
106            Field::new(
107                format_state_name(args.name, "mean"),
108                DataType::Float64,
109                true,
110            ),
111            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
112        ]
113        .into_iter()
114        .map(Arc::new)
115        .collect())
116    }
117
118    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
119        if acc_args.is_distinct {
120            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
121        }
122        Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
123    }
124
125    fn aliases(&self) -> &[String] {
126        &self.alias
127    }
128
129    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
130        !acc_args.is_distinct
131    }
132
133    fn create_groups_accumulator(
134        &self,
135        _args: AccumulatorArgs,
136    ) -> Result<Box<dyn GroupsAccumulator>> {
137        Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
138    }
139
140    fn documentation(&self) -> Option<&Documentation> {
141        self.doc()
142    }
143}
144
// Wires the `StddevPop` UDAF into the crate via `make_udaf_expr_and_func!`.
// Judging by the arguments, this yields a `stddev_pop` expression helper taking
// a single `expression` and a `stddev_pop_udaf()` accessor (the latter is used
// by the tests in this file); see the macro definition for the exact expansion.
make_udaf_expr_and_func!(
    StddevPop,
    stddev_pop,
    expression,
    "Compute the population standard deviation of a set of numbers",
    stddev_pop_udaf
);
152
153#[user_doc(
154    doc_section(label = "Statistical Functions"),
155    description = "Returns the population standard deviation of a set of numbers.",
156    syntax_example = "stddev_pop(expression)",
157    sql_example = r#"```sql
158> SELECT stddev_pop(column_name) FROM table_name;
159+--------------------------+
160| stddev_pop(column_name)   |
161+--------------------------+
162| 10.56                    |
163+--------------------------+
164```"#,
165    standard_argument(name = "expression",)
166)]
167/// STDDEV_POP population aggregate expression
168#[derive(PartialEq, Eq, Hash, Debug)]
169pub struct StddevPop {
170    signature: Signature,
171}
172
173impl Default for StddevPop {
174    fn default() -> Self {
175        Self::new()
176    }
177}
178
179impl StddevPop {
180    /// Create a new STDDEV_POP aggregate function
181    pub fn new() -> Self {
182        Self {
183            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
184        }
185    }
186}
187
188impl AggregateUDFImpl for StddevPop {
189    fn name(&self) -> &str {
190        "stddev_pop"
191    }
192
193    fn signature(&self) -> &Signature {
194        &self.signature
195    }
196
197    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
198        Ok(vec![
199            Field::new(
200                format_state_name(args.name, "count"),
201                DataType::UInt64,
202                true,
203            ),
204            Field::new(
205                format_state_name(args.name, "mean"),
206                DataType::Float64,
207                true,
208            ),
209            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
210        ]
211        .into_iter()
212        .map(Arc::new)
213        .collect())
214    }
215
216    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
217        if acc_args.is_distinct {
218            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
219        }
220        Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
221    }
222
223    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
224        Ok(DataType::Float64)
225    }
226
227    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
228        !acc_args.is_distinct
229    }
230
231    fn create_groups_accumulator(
232        &self,
233        _args: AccumulatorArgs,
234    ) -> Result<Box<dyn GroupsAccumulator>> {
235        Ok(Box::new(StddevGroupsAccumulator::new(
236            StatsType::Population,
237        )))
238    }
239
240    fn documentation(&self) -> Option<&Documentation> {
241        self.doc()
242    }
243}
244
245/// An accumulator to compute the average
246#[derive(Debug)]
247pub struct StddevAccumulator {
248    variance: VarianceAccumulator,
249}
250
251impl StddevAccumulator {
252    /// Creates a new `StddevAccumulator`
253    pub fn try_new(s_type: StatsType) -> Result<Self> {
254        Ok(Self {
255            variance: VarianceAccumulator::try_new(s_type)?,
256        })
257    }
258
259    pub fn get_m2(&self) -> f64 {
260        self.variance.get_m2()
261    }
262}
263
264impl Accumulator for StddevAccumulator {
265    fn state(&mut self) -> Result<Vec<ScalarValue>> {
266        Ok(vec![
267            ScalarValue::from(self.variance.get_count()),
268            ScalarValue::from(self.variance.get_mean()),
269            ScalarValue::from(self.variance.get_m2()),
270        ])
271    }
272
273    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
274        self.variance.update_batch(values)
275    }
276
277    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
278        self.variance.retract_batch(values)
279    }
280
281    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
282        self.variance.merge_batch(states)
283    }
284
285    fn evaluate(&mut self) -> Result<ScalarValue> {
286        let variance = self.variance.evaluate()?;
287        match variance {
288            ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)),
289            ScalarValue::Float64(Some(f)) => Ok(ScalarValue::Float64(Some(f.sqrt()))),
290            _ => internal_err!("Variance should be f64"),
291        }
292    }
293
294    fn size(&self) -> usize {
295        align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
296    }
297
298    fn supports_retract_batch(&self) -> bool {
299        self.variance.supports_retract_batch()
300    }
301}
302
303#[derive(Debug)]
304pub struct StddevGroupsAccumulator {
305    variance: VarianceGroupsAccumulator,
306}
307
308impl StddevGroupsAccumulator {
309    pub fn new(s_type: StatsType) -> Self {
310        Self {
311            variance: VarianceGroupsAccumulator::new(s_type),
312        }
313    }
314}
315
316impl GroupsAccumulator for StddevGroupsAccumulator {
317    fn update_batch(
318        &mut self,
319        values: &[ArrayRef],
320        group_indices: &[usize],
321        opt_filter: Option<&arrow::array::BooleanArray>,
322        total_num_groups: usize,
323    ) -> Result<()> {
324        self.variance
325            .update_batch(values, group_indices, opt_filter, total_num_groups)
326    }
327
328    fn merge_batch(
329        &mut self,
330        values: &[ArrayRef],
331        group_indices: &[usize],
332        opt_filter: Option<&arrow::array::BooleanArray>,
333        total_num_groups: usize,
334    ) -> Result<()> {
335        self.variance
336            .merge_batch(values, group_indices, opt_filter, total_num_groups)
337    }
338
339    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
340        let (mut variances, nulls) = self.variance.variance(emit_to);
341        variances.iter_mut().for_each(|v| *v = v.sqrt());
342        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
343    }
344
345    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
346        self.variance.state(emit_to)
347    }
348
349    fn size(&self) -> usize {
350        self.variance.size()
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;

    /// Merging the partial states of {1, 2, 3} and {4, 5} must give the
    /// population standard deviation of {1, 2, 3, 4, 5}:
    /// mean 3, variance (4 + 1 + 0 + 1 + 4) / 5 = 2, hence sqrt(2).
    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_pop_udaf();
        let agg2 = stddev_pop_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Merging in a partial aggregate built from a single NULL must not change
    /// the result: stddev_pop of {1, 2, 3, 4, 5} is still sqrt(2).
    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let b = Arc::new(Float64Array::from(vec![None]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_pop_udaf();
        let agg2 = stddev_pop_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Simulates two-phase aggregation over column `a` of `schema`.
    ///
    /// `batch1` is fed into an accumulator created from `agg1`, and `batch2`
    /// into one created from `agg2`. The second accumulator's state is then
    /// merged into the first, which is evaluated and returned.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        let expr = col("a", schema)?;
        let expr_field = expr.return_field(schema)?;

        // Two separate argument structs because each `accumulator` call
        // consumes its `AccumulatorArgs` by value.
        let args1 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            expr_fields: &[Arc::clone(&expr_field)],
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[Arc::clone(&expr)],
        };

        let args2 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            expr_fields: &[expr_field],
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[expr],
        };

        let mut accum1 = agg1.accumulator(args1)?;
        let mut accum2 = agg2.accumulator(args2)?;

        // Evaluate column `a` against each batch to get the raw input arrays.
        let value1 = vec![
            col("a", schema)?
                .evaluate(batch1)
                .and_then(|v| v.into_array(batch1.num_rows()))?,
        ];
        let value2 = vec![
            col("a", schema)?
                .evaluate(batch2)
                .and_then(|v| v.into_array(batch2.num_rows()))?,
        ];

        accum1.update_batch(&value1)?;
        accum2.update_batch(&value2)?;
        // Serialize accum2's intermediate state to arrays and merge it into accum1.
        let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?;
        accum1.merge_batch(&state2)?;
        let result = accum1.evaluate()?;
        Ok(result)
    }
}