1use std::fmt::Debug;
21use std::hash::Hash;
22use std::mem::align_of_val;
23use std::sync::Arc;
24
25use arrow::array::Float64Array;
26use arrow::datatypes::FieldRef;
27use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
28use datafusion_common::ScalarValue;
29use datafusion_common::{Result, internal_err, not_impl_err};
30use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
31use datafusion_expr::utils::format_state_name;
32use datafusion_expr::{
33 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
34 Volatility,
35};
36use datafusion_functions_aggregate_common::stats::StatsType;
37use datafusion_macros::user_doc;
38
39use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
40
41make_udaf_expr_and_func!(
42 Stddev,
43 stddev,
44 expression,
45 "Compute the standard deviation of a set of numbers",
46 stddev_udaf
47);
48
49#[user_doc(
50 doc_section(label = "Statistical Functions"),
51 description = "Returns the standard deviation of a set of numbers.",
52 syntax_example = "stddev(expression)",
53 sql_example = r#"```sql
54> SELECT stddev(column_name) FROM table_name;
55+----------------------+
56| stddev(column_name) |
57+----------------------+
58| 12.34 |
59+----------------------+
60```"#,
61 standard_argument(name = "expression",)
62)]
63#[derive(PartialEq, Eq, Hash, Debug)]
65pub struct Stddev {
66 signature: Signature,
67 alias: Vec<String>,
68}
69
70impl Default for Stddev {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl Stddev {
77 pub fn new() -> Self {
79 Self {
80 signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
81 alias: vec!["stddev_samp".to_string()],
82 }
83 }
84}
85
86impl AggregateUDFImpl for Stddev {
87 fn name(&self) -> &str {
88 "stddev"
89 }
90
91 fn signature(&self) -> &Signature {
92 &self.signature
93 }
94
95 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
96 Ok(DataType::Float64)
97 }
98
99 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
100 Ok(vec![
101 Field::new(
102 format_state_name(args.name, "count"),
103 DataType::UInt64,
104 true,
105 ),
106 Field::new(
107 format_state_name(args.name, "mean"),
108 DataType::Float64,
109 true,
110 ),
111 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
112 ]
113 .into_iter()
114 .map(Arc::new)
115 .collect())
116 }
117
118 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
119 if acc_args.is_distinct {
120 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
121 }
122 Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
123 }
124
125 fn aliases(&self) -> &[String] {
126 &self.alias
127 }
128
129 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
130 !acc_args.is_distinct
131 }
132
133 fn create_groups_accumulator(
134 &self,
135 _args: AccumulatorArgs,
136 ) -> Result<Box<dyn GroupsAccumulator>> {
137 Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
138 }
139
140 fn documentation(&self) -> Option<&Documentation> {
141 self.doc()
142 }
143}
144
145make_udaf_expr_and_func!(
146 StddevPop,
147 stddev_pop,
148 expression,
149 "Compute the population standard deviation of a set of numbers",
150 stddev_pop_udaf
151);
152
153#[user_doc(
154 doc_section(label = "Statistical Functions"),
155 description = "Returns the population standard deviation of a set of numbers.",
156 syntax_example = "stddev_pop(expression)",
157 sql_example = r#"```sql
158> SELECT stddev_pop(column_name) FROM table_name;
159+--------------------------+
160| stddev_pop(column_name) |
161+--------------------------+
162| 10.56 |
163+--------------------------+
164```"#,
165 standard_argument(name = "expression",)
166)]
167#[derive(PartialEq, Eq, Hash, Debug)]
169pub struct StddevPop {
170 signature: Signature,
171}
172
173impl Default for StddevPop {
174 fn default() -> Self {
175 Self::new()
176 }
177}
178
179impl StddevPop {
180 pub fn new() -> Self {
182 Self {
183 signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
184 }
185 }
186}
187
188impl AggregateUDFImpl for StddevPop {
189 fn name(&self) -> &str {
190 "stddev_pop"
191 }
192
193 fn signature(&self) -> &Signature {
194 &self.signature
195 }
196
197 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
198 Ok(vec![
199 Field::new(
200 format_state_name(args.name, "count"),
201 DataType::UInt64,
202 true,
203 ),
204 Field::new(
205 format_state_name(args.name, "mean"),
206 DataType::Float64,
207 true,
208 ),
209 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
210 ]
211 .into_iter()
212 .map(Arc::new)
213 .collect())
214 }
215
216 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
217 if acc_args.is_distinct {
218 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
219 }
220 Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
221 }
222
223 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
224 Ok(DataType::Float64)
225 }
226
227 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
228 !acc_args.is_distinct
229 }
230
231 fn create_groups_accumulator(
232 &self,
233 _args: AccumulatorArgs,
234 ) -> Result<Box<dyn GroupsAccumulator>> {
235 Ok(Box::new(StddevGroupsAccumulator::new(
236 StatsType::Population,
237 )))
238 }
239
240 fn documentation(&self) -> Option<&Documentation> {
241 self.doc()
242 }
243}
244
245#[derive(Debug)]
247pub struct StddevAccumulator {
248 variance: VarianceAccumulator,
249}
250
251impl StddevAccumulator {
252 pub fn try_new(s_type: StatsType) -> Result<Self> {
254 Ok(Self {
255 variance: VarianceAccumulator::try_new(s_type)?,
256 })
257 }
258
259 pub fn get_m2(&self) -> f64 {
260 self.variance.get_m2()
261 }
262}
263
264impl Accumulator for StddevAccumulator {
265 fn state(&mut self) -> Result<Vec<ScalarValue>> {
266 Ok(vec![
267 ScalarValue::from(self.variance.get_count()),
268 ScalarValue::from(self.variance.get_mean()),
269 ScalarValue::from(self.variance.get_m2()),
270 ])
271 }
272
273 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
274 self.variance.update_batch(values)
275 }
276
277 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
278 self.variance.retract_batch(values)
279 }
280
281 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
282 self.variance.merge_batch(states)
283 }
284
285 fn evaluate(&mut self) -> Result<ScalarValue> {
286 let variance = self.variance.evaluate()?;
287 match variance {
288 ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)),
289 ScalarValue::Float64(Some(f)) => Ok(ScalarValue::Float64(Some(f.sqrt()))),
290 _ => internal_err!("Variance should be f64"),
291 }
292 }
293
294 fn size(&self) -> usize {
295 align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
296 }
297
298 fn supports_retract_batch(&self) -> bool {
299 self.variance.supports_retract_batch()
300 }
301}
302
303#[derive(Debug)]
304pub struct StddevGroupsAccumulator {
305 variance: VarianceGroupsAccumulator,
306}
307
308impl StddevGroupsAccumulator {
309 pub fn new(s_type: StatsType) -> Self {
310 Self {
311 variance: VarianceGroupsAccumulator::new(s_type),
312 }
313 }
314}
315
316impl GroupsAccumulator for StddevGroupsAccumulator {
317 fn update_batch(
318 &mut self,
319 values: &[ArrayRef],
320 group_indices: &[usize],
321 opt_filter: Option<&arrow::array::BooleanArray>,
322 total_num_groups: usize,
323 ) -> Result<()> {
324 self.variance
325 .update_batch(values, group_indices, opt_filter, total_num_groups)
326 }
327
328 fn merge_batch(
329 &mut self,
330 values: &[ArrayRef],
331 group_indices: &[usize],
332 opt_filter: Option<&arrow::array::BooleanArray>,
333 total_num_groups: usize,
334 ) -> Result<()> {
335 self.variance
336 .merge_batch(values, group_indices, opt_filter, total_num_groups)
337 }
338
339 fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
340 let (mut variances, nulls) = self.variance.variance(emit_to);
341 variances.iter_mut().for_each(|v| *v = v.sqrt());
342 Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
343 }
344
345 fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
346 self.variance.state(emit_to)
347 }
348
349 fn size(&self) -> usize {
350 self.variance.size()
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;

    /// Merging the partial states of `[1, 2, 3]` and `[4, 5]` must give the
    /// population standard deviation of `1..=5`, which is `sqrt(2)`.
    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
        let schema_ref = Arc::new(schema.clone());

        let first = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let second = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));

        let batch1 = RecordBatch::try_new(Arc::clone(&schema_ref), vec![first])?;
        let batch2 = RecordBatch::try_new(schema_ref, vec![second])?;

        let merged = merge(
            &batch1,
            &batch2,
            stddev_pop_udaf(),
            stddev_pop_udaf(),
            &schema,
        )?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Merging the state of `[1, 2, 3, 4, 5]` with the state of a batch that
    /// holds a single null must still give `sqrt(2)`.
    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let schema_ref = Arc::new(schema.clone());

        let first = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let second = Arc::new(Float64Array::from(vec![None]));

        let batch1 = RecordBatch::try_new(Arc::clone(&schema_ref), vec![first])?;
        let batch2 = RecordBatch::try_new(schema_ref, vec![second])?;

        let merged = merge(
            &batch1,
            &batch2,
            stddev_pop_udaf(),
            stddev_pop_udaf(),
            &schema,
        )?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Feeds column `a` of `batch1` into an accumulator from `agg1` and column
    /// `a` of `batch2` into an accumulator from `agg2`, merges the state of the
    /// second into the first, and returns the first accumulator's result.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        let expr = col("a", schema)?;
        let expr_field = expr.return_field(schema)?;

        // Both accumulators are described by the same expression and field,
        // so the two argument structs borrow the same one-element arrays.
        let fields = [expr_field];
        let exprs = [expr];

        let mut accum1 = agg1.accumulator(AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            expr_fields: &fields,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs,
        })?;
        let mut accum2 = agg2.accumulator(AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            expr_fields: &fields,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs,
        })?;

        // Evaluates column `a` against a batch and wraps it as the
        // single-element input expected by `update_batch`.
        let column_of = |batch: &RecordBatch| -> Result<Vec<ArrayRef>> {
            let array = col("a", schema)?
                .evaluate(batch)?
                .into_array(batch.num_rows())?;
            Ok(vec![array])
        };
        let value1 = column_of(batch1)?;
        let value2 = column_of(batch2)?;

        accum1.update_batch(&value1)?;
        accum2.update_batch(&value2)?;

        let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?;
        accum1.merge_batch(&state2)?;
        accum1.evaluate()
    }
}