Skip to main content

datafusion_functions_aggregate/
stddev.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines physical expressions that can be evaluated at runtime during query execution
19
20use std::any::Any;
21use std::fmt::Debug;
22use std::hash::Hash;
23use std::mem::align_of_val;
24use std::sync::Arc;
25
26use arrow::array::Float64Array;
27use arrow::datatypes::FieldRef;
28use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
29use datafusion_common::ScalarValue;
30use datafusion_common::{Result, internal_err, not_impl_err};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::utils::format_state_name;
33use datafusion_expr::{
34    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
35    Volatility,
36};
37use datafusion_functions_aggregate_common::stats::StatsType;
38use datafusion_macros::user_doc;
39
40use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
41
// Invokes `make_udaf_expr_and_func!` for `Stddev`. Presumably this generates
// the `stddev(expression)` expression helper and the `stddev_udaf()` accessor;
// the macro is defined outside this file, so confirm against its definition.
make_udaf_expr_and_func!(
    Stddev,
    stddev,
    expression,
    "Compute the standard deviation of a set of numbers",
    stddev_udaf
);
49
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the standard deviation of a set of numbers.",
    syntax_example = "stddev(expression)",
    sql_example = r#"```sql
> SELECT stddev(column_name) FROM table_name;
+----------------------+
| stddev(column_name)   |
+----------------------+
| 12.34                |
+----------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression.
///
/// Registered under the name `stddev` with the alias `stddev_samp`; its
/// accumulators are created with `StatsType::Sample`.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct Stddev {
    /// Signature (argument types and volatility) reported by `signature()`.
    signature: Signature,
    /// Additional names reported by `aliases()`.
    alias: Vec<String>,
}
70
71impl Default for Stddev {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77impl Stddev {
78    /// Create a new STDDEV aggregate function
79    pub fn new() -> Self {
80        Self {
81            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
82            alias: vec!["stddev_samp".to_string()],
83        }
84    }
85}
86
87impl AggregateUDFImpl for Stddev {
88    /// Return a reference to Any that can be used for downcasting
89    fn as_any(&self) -> &dyn Any {
90        self
91    }
92
93    fn name(&self) -> &str {
94        "stddev"
95    }
96
97    fn signature(&self) -> &Signature {
98        &self.signature
99    }
100
101    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
102        Ok(DataType::Float64)
103    }
104
105    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
106        Ok(vec![
107            Field::new(
108                format_state_name(args.name, "count"),
109                DataType::UInt64,
110                true,
111            ),
112            Field::new(
113                format_state_name(args.name, "mean"),
114                DataType::Float64,
115                true,
116            ),
117            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
118        ]
119        .into_iter()
120        .map(Arc::new)
121        .collect())
122    }
123
124    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
125        if acc_args.is_distinct {
126            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
127        }
128        Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
129    }
130
131    fn aliases(&self) -> &[String] {
132        &self.alias
133    }
134
135    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
136        !acc_args.is_distinct
137    }
138
139    fn create_groups_accumulator(
140        &self,
141        _args: AccumulatorArgs,
142    ) -> Result<Box<dyn GroupsAccumulator>> {
143        Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
144    }
145
146    fn documentation(&self) -> Option<&Documentation> {
147        self.doc()
148    }
149}
150
// Invokes `make_udaf_expr_and_func!` for `StddevPop`. The tests in this file
// call the resulting `stddev_pop_udaf()`; presumably the macro also generates
// the `stddev_pop(expression)` expression helper. The macro is defined outside
// this file, so confirm against its definition.
make_udaf_expr_and_func!(
    StddevPop,
    stddev_pop,
    expression,
    "Compute the population standard deviation of a set of numbers",
    stddev_pop_udaf
);
158
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the population standard deviation of a set of numbers.",
    syntax_example = "stddev_pop(expression)",
    sql_example = r#"```sql
> SELECT stddev_pop(column_name) FROM table_name;
+--------------------------+
| stddev_pop(column_name)   |
+--------------------------+
| 10.56                    |
+--------------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// STDDEV_POP (population standard deviation) aggregate expression.
///
/// Registered under the name `stddev_pop`; its accumulators are created with
/// `StatsType::Population`.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct StddevPop {
    /// Signature (argument types and volatility) reported by `signature()`.
    signature: Signature,
}
178
179impl Default for StddevPop {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
185impl StddevPop {
186    /// Create a new STDDEV_POP aggregate function
187    pub fn new() -> Self {
188        Self {
189            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
190        }
191    }
192}
193
194impl AggregateUDFImpl for StddevPop {
195    /// Return a reference to Any that can be used for downcasting
196    fn as_any(&self) -> &dyn Any {
197        self
198    }
199
200    fn name(&self) -> &str {
201        "stddev_pop"
202    }
203
204    fn signature(&self) -> &Signature {
205        &self.signature
206    }
207
208    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
209        Ok(vec![
210            Field::new(
211                format_state_name(args.name, "count"),
212                DataType::UInt64,
213                true,
214            ),
215            Field::new(
216                format_state_name(args.name, "mean"),
217                DataType::Float64,
218                true,
219            ),
220            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
221        ]
222        .into_iter()
223        .map(Arc::new)
224        .collect())
225    }
226
227    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
228        if acc_args.is_distinct {
229            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
230        }
231        Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
232    }
233
234    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
235        Ok(DataType::Float64)
236    }
237
238    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
239        !acc_args.is_distinct
240    }
241
242    fn create_groups_accumulator(
243        &self,
244        _args: AccumulatorArgs,
245    ) -> Result<Box<dyn GroupsAccumulator>> {
246        Ok(Box::new(StddevGroupsAccumulator::new(
247            StatsType::Population,
248        )))
249    }
250
251    fn documentation(&self) -> Option<&Documentation> {
252        self.doc()
253    }
254}
255
/// An accumulator to compute the standard deviation.
///
/// Wraps a [`VarianceAccumulator`]; `evaluate` returns the square root of the
/// variance it reports.
#[derive(Debug)]
pub struct StddevAccumulator {
    /// Variance accumulator that holds the running state.
    variance: VarianceAccumulator,
}
261
262impl StddevAccumulator {
263    /// Creates a new `StddevAccumulator`
264    pub fn try_new(s_type: StatsType) -> Result<Self> {
265        Ok(Self {
266            variance: VarianceAccumulator::try_new(s_type)?,
267        })
268    }
269
270    pub fn get_m2(&self) -> f64 {
271        self.variance.get_m2()
272    }
273}
274
275impl Accumulator for StddevAccumulator {
276    fn state(&mut self) -> Result<Vec<ScalarValue>> {
277        Ok(vec![
278            ScalarValue::from(self.variance.get_count()),
279            ScalarValue::from(self.variance.get_mean()),
280            ScalarValue::from(self.variance.get_m2()),
281        ])
282    }
283
284    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
285        self.variance.update_batch(values)
286    }
287
288    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
289        self.variance.retract_batch(values)
290    }
291
292    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
293        self.variance.merge_batch(states)
294    }
295
296    fn evaluate(&mut self) -> Result<ScalarValue> {
297        let variance = self.variance.evaluate()?;
298        match variance {
299            ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)),
300            ScalarValue::Float64(Some(f)) => Ok(ScalarValue::Float64(Some(f.sqrt()))),
301            _ => internal_err!("Variance should be f64"),
302        }
303    }
304
305    fn size(&self) -> usize {
306        align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
307    }
308
309    fn supports_retract_batch(&self) -> bool {
310        self.variance.supports_retract_batch()
311    }
312}
313
/// Groups accumulator for the standard deviation.
///
/// Wraps a [`VarianceGroupsAccumulator`]; `evaluate` takes the square root of
/// each variance it reports.
#[derive(Debug)]
pub struct StddevGroupsAccumulator {
    /// Variance groups accumulator that holds the per-group state.
    variance: VarianceGroupsAccumulator,
}
318
319impl StddevGroupsAccumulator {
320    pub fn new(s_type: StatsType) -> Self {
321        Self {
322            variance: VarianceGroupsAccumulator::new(s_type),
323        }
324    }
325}
326
327impl GroupsAccumulator for StddevGroupsAccumulator {
328    fn update_batch(
329        &mut self,
330        values: &[ArrayRef],
331        group_indices: &[usize],
332        opt_filter: Option<&arrow::array::BooleanArray>,
333        total_num_groups: usize,
334    ) -> Result<()> {
335        self.variance
336            .update_batch(values, group_indices, opt_filter, total_num_groups)
337    }
338
339    fn merge_batch(
340        &mut self,
341        values: &[ArrayRef],
342        group_indices: &[usize],
343        opt_filter: Option<&arrow::array::BooleanArray>,
344        total_num_groups: usize,
345    ) -> Result<()> {
346        self.variance
347            .merge_batch(values, group_indices, opt_filter, total_num_groups)
348    }
349
350    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
351        let (mut variances, nulls) = self.variance.variance(emit_to);
352        variances.iter_mut().for_each(|v| *v = v.sqrt());
353        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
354    }
355
356    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
357        self.variance.state(emit_to)
358    }
359
360    fn size(&self) -> usize {
361        self.variance.size()
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;
    use std::sync::Arc;

    /// Population stddev of 1..=5, with the values split 3/2 across two
    /// accumulators that are then merged. The expected result is sqrt(2).
    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
        let schema_ref = Arc::new(schema.clone());

        let batch1 = RecordBatch::try_new(Arc::clone(&schema_ref), vec![a])?;
        let batch2 = RecordBatch::try_new(schema_ref, vec![b])?;

        let left = stddev_pop_udaf();
        let right = stddev_pop_udaf();

        let merged = merge(&batch1, &batch2, left, right, &schema)?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Same expectation as above, but all values sit in the first batch and
    /// the second batch holds a single null.
    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let b = Arc::new(Float64Array::from(vec![None]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let schema_ref = Arc::new(schema.clone());

        let batch1 = RecordBatch::try_new(Arc::clone(&schema_ref), vec![a])?;
        let batch2 = RecordBatch::try_new(schema_ref, vec![b])?;

        let left = stddev_pop_udaf();
        let right = stddev_pop_udaf();

        let merged = merge(&batch1, &batch2, left, right, &schema)?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Evaluates column `a` of `batch` and wraps the resulting array in the
    /// single-element argument list passed to `Accumulator::update_batch`.
    fn column_a(batch: &RecordBatch, schema: &Schema) -> Result<Vec<ArrayRef>> {
        let values = col("a", schema)?
            .evaluate(batch)?
            .into_array(batch.num_rows())?;
        Ok(vec![values])
    }

    /// Feeds `batch1` into an accumulator from `agg1` and `batch2` into one
    /// from `agg2`, merges the second accumulator's state into the first, and
    /// returns the first accumulator's final value.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        let column = col("a", schema)?;
        let column_field = column.return_field(schema)?;
        let exprs = [column];
        let expr_fields = [column_field];

        // Both accumulators are built from identical arguments.
        let new_args = || AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            expr_fields: &expr_fields,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs,
        };

        let mut accum1 = agg1.accumulator(new_args())?;
        let mut accum2 = agg2.accumulator(new_args())?;

        let value1 = column_a(batch1, schema)?;
        let value2 = column_a(batch2, schema)?;

        accum1.update_batch(&value1)?;
        accum2.update_batch(&value2)?;

        // Merge the second partial state into the first, then finalize.
        let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?;
        accum1.merge_batch(&state2)?;
        accum1.evaluate()
    }
}