1use std::any::Any;
21use std::fmt::Debug;
22use std::hash::Hash;
23use std::mem::align_of_val;
24use std::sync::Arc;
25
26use arrow::array::Float64Array;
27use arrow::datatypes::FieldRef;
28use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
29use datafusion_common::ScalarValue;
30use datafusion_common::{Result, internal_err, not_impl_err};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::utils::format_state_name;
33use datafusion_expr::{
34 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
35 Volatility,
36};
37use datafusion_functions_aggregate_common::stats::StatsType;
38use datafusion_macros::user_doc;
39
40use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
41
// Invokes the crate's `make_udaf_expr_and_func!` macro for `Stddev`, passing
// the identifiers `stddev` and `stddev_udaf`, the argument name `expression`,
// and a human-readable description.
// NOTE(review): the macro is defined outside this file; presumably it
// generates the `stddev(expression)` expression helper and a `stddev_udaf()`
// accessor — confirm against the macro definition.
make_udaf_expr_and_func!(
    Stddev,
    stddev,
    expression,
    "Compute the standard deviation of a set of numbers",
    stddev_udaf
);
49
50#[user_doc(
51 doc_section(label = "Statistical Functions"),
52 description = "Returns the standard deviation of a set of numbers.",
53 syntax_example = "stddev(expression)",
54 sql_example = r#"```sql
55> SELECT stddev(column_name) FROM table_name;
56+----------------------+
57| stddev(column_name) |
58+----------------------+
59| 12.34 |
60+----------------------+
61```"#,
62 standard_argument(name = "expression",)
63)]
64#[derive(PartialEq, Eq, Hash, Debug)]
66pub struct Stddev {
67 signature: Signature,
68 alias: Vec<String>,
69}
70
71impl Default for Stddev {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl Stddev {
78 pub fn new() -> Self {
80 Self {
81 signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
82 alias: vec!["stddev_samp".to_string()],
83 }
84 }
85}
86
87impl AggregateUDFImpl for Stddev {
88 fn as_any(&self) -> &dyn Any {
90 self
91 }
92
93 fn name(&self) -> &str {
94 "stddev"
95 }
96
97 fn signature(&self) -> &Signature {
98 &self.signature
99 }
100
101 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
102 Ok(DataType::Float64)
103 }
104
105 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
106 Ok(vec![
107 Field::new(
108 format_state_name(args.name, "count"),
109 DataType::UInt64,
110 true,
111 ),
112 Field::new(
113 format_state_name(args.name, "mean"),
114 DataType::Float64,
115 true,
116 ),
117 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
118 ]
119 .into_iter()
120 .map(Arc::new)
121 .collect())
122 }
123
124 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
125 if acc_args.is_distinct {
126 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
127 }
128 Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
129 }
130
131 fn aliases(&self) -> &[String] {
132 &self.alias
133 }
134
135 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
136 !acc_args.is_distinct
137 }
138
139 fn create_groups_accumulator(
140 &self,
141 _args: AccumulatorArgs,
142 ) -> Result<Box<dyn GroupsAccumulator>> {
143 Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
144 }
145
146 fn documentation(&self) -> Option<&Documentation> {
147 self.doc()
148 }
149}
150
// Invokes the crate's `make_udaf_expr_and_func!` macro for `StddevPop`,
// passing the identifiers `stddev_pop` and `stddev_pop_udaf`, the argument
// name `expression`, and a human-readable description.
// The tests in this file call `stddev_pop_udaf()` and use the result as an
// `Arc<AggregateUDF>`.
// NOTE(review): the macro is defined outside this file; presumably it also
// generates the `stddev_pop(expression)` expression helper — confirm against
// the macro definition.
make_udaf_expr_and_func!(
    StddevPop,
    stddev_pop,
    expression,
    "Compute the population standard deviation of a set of numbers",
    stddev_pop_udaf
);
158
159#[user_doc(
160 doc_section(label = "Statistical Functions"),
161 description = "Returns the population standard deviation of a set of numbers.",
162 syntax_example = "stddev_pop(expression)",
163 sql_example = r#"```sql
164> SELECT stddev_pop(column_name) FROM table_name;
165+--------------------------+
166| stddev_pop(column_name) |
167+--------------------------+
168| 10.56 |
169+--------------------------+
170```"#,
171 standard_argument(name = "expression",)
172)]
173#[derive(PartialEq, Eq, Hash, Debug)]
175pub struct StddevPop {
176 signature: Signature,
177}
178
179impl Default for StddevPop {
180 fn default() -> Self {
181 Self::new()
182 }
183}
184
185impl StddevPop {
186 pub fn new() -> Self {
188 Self {
189 signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
190 }
191 }
192}
193
194impl AggregateUDFImpl for StddevPop {
195 fn as_any(&self) -> &dyn Any {
197 self
198 }
199
200 fn name(&self) -> &str {
201 "stddev_pop"
202 }
203
204 fn signature(&self) -> &Signature {
205 &self.signature
206 }
207
208 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
209 Ok(vec![
210 Field::new(
211 format_state_name(args.name, "count"),
212 DataType::UInt64,
213 true,
214 ),
215 Field::new(
216 format_state_name(args.name, "mean"),
217 DataType::Float64,
218 true,
219 ),
220 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
221 ]
222 .into_iter()
223 .map(Arc::new)
224 .collect())
225 }
226
227 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
228 if acc_args.is_distinct {
229 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
230 }
231 Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
232 }
233
234 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
235 Ok(DataType::Float64)
236 }
237
238 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
239 !acc_args.is_distinct
240 }
241
242 fn create_groups_accumulator(
243 &self,
244 _args: AccumulatorArgs,
245 ) -> Result<Box<dyn GroupsAccumulator>> {
246 Ok(Box::new(StddevGroupsAccumulator::new(
247 StatsType::Population,
248 )))
249 }
250
251 fn documentation(&self) -> Option<&Documentation> {
252 self.doc()
253 }
254}
255
256#[derive(Debug)]
258pub struct StddevAccumulator {
259 variance: VarianceAccumulator,
260}
261
262impl StddevAccumulator {
263 pub fn try_new(s_type: StatsType) -> Result<Self> {
265 Ok(Self {
266 variance: VarianceAccumulator::try_new(s_type)?,
267 })
268 }
269
270 pub fn get_m2(&self) -> f64 {
271 self.variance.get_m2()
272 }
273}
274
275impl Accumulator for StddevAccumulator {
276 fn state(&mut self) -> Result<Vec<ScalarValue>> {
277 Ok(vec![
278 ScalarValue::from(self.variance.get_count()),
279 ScalarValue::from(self.variance.get_mean()),
280 ScalarValue::from(self.variance.get_m2()),
281 ])
282 }
283
284 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
285 self.variance.update_batch(values)
286 }
287
288 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
289 self.variance.retract_batch(values)
290 }
291
292 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
293 self.variance.merge_batch(states)
294 }
295
296 fn evaluate(&mut self) -> Result<ScalarValue> {
297 let variance = self.variance.evaluate()?;
298 match variance {
299 ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)),
300 ScalarValue::Float64(Some(f)) => Ok(ScalarValue::Float64(Some(f.sqrt()))),
301 _ => internal_err!("Variance should be f64"),
302 }
303 }
304
305 fn size(&self) -> usize {
306 align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
307 }
308
309 fn supports_retract_batch(&self) -> bool {
310 self.variance.supports_retract_batch()
311 }
312}
313
314#[derive(Debug)]
315pub struct StddevGroupsAccumulator {
316 variance: VarianceGroupsAccumulator,
317}
318
319impl StddevGroupsAccumulator {
320 pub fn new(s_type: StatsType) -> Self {
321 Self {
322 variance: VarianceGroupsAccumulator::new(s_type),
323 }
324 }
325}
326
327impl GroupsAccumulator for StddevGroupsAccumulator {
328 fn update_batch(
329 &mut self,
330 values: &[ArrayRef],
331 group_indices: &[usize],
332 opt_filter: Option<&arrow::array::BooleanArray>,
333 total_num_groups: usize,
334 ) -> Result<()> {
335 self.variance
336 .update_batch(values, group_indices, opt_filter, total_num_groups)
337 }
338
339 fn merge_batch(
340 &mut self,
341 values: &[ArrayRef],
342 group_indices: &[usize],
343 opt_filter: Option<&arrow::array::BooleanArray>,
344 total_num_groups: usize,
345 ) -> Result<()> {
346 self.variance
347 .merge_batch(values, group_indices, opt_filter, total_num_groups)
348 }
349
350 fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
351 let (mut variances, nulls) = self.variance.variance(emit_to);
352 variances.iter_mut().for_each(|v| *v = v.sqrt());
353 Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
354 }
355
356 fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
357 self.variance.state(emit_to)
358 }
359
360 fn size(&self) -> usize {
361 self.variance.size()
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;
    use std::sync::Arc;

    /// Merging `stddev_pop` states for `[1, 2, 3]` and `[4, 5]` yields sqrt(2).
    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
        let schema_ref = Arc::new(schema.clone());

        let first: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let second: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));

        let batch1 = RecordBatch::try_new(Arc::clone(&schema_ref), vec![first])?;
        let batch2 = RecordBatch::try_new(schema_ref, vec![second])?;

        let merged = merge(
            &batch1,
            &batch2,
            stddev_pop_udaf(),
            stddev_pop_udaf(),
            &schema,
        )?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Merging `stddev_pop` states for `[1, 2, 3, 4, 5]` and a single null
    /// yields sqrt(2).
    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let schema_ref = Arc::new(schema.clone());

        let first: ArrayRef =
            Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let second: ArrayRef = Arc::new(Float64Array::from(vec![None]));

        let batch1 = RecordBatch::try_new(Arc::clone(&schema_ref), vec![first])?;
        let batch2 = RecordBatch::try_new(schema_ref, vec![second])?;

        let merged = merge(
            &batch1,
            &batch2,
            stddev_pop_udaf(),
            stddev_pop_udaf(),
            &schema,
        )?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Feeds column `a` of each batch into its own accumulator, merges the
    /// second accumulator's state into the first, and returns the first
    /// accumulator's final value.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        let column = col("a", schema)?;
        let column_field = column.return_field(schema)?;

        // Both accumulators are created from identical arguments.
        let fields = [column_field];
        let exprs = [column];

        let mut target = agg1.accumulator(AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            expr_fields: &fields,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs,
        })?;
        let mut source = agg2.accumulator(AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            expr_fields: &fields,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs,
        })?;

        // Materializes column `a` of a batch as an array.
        let column_values = |batch: &RecordBatch| -> Result<ArrayRef> {
            col("a", schema)?
                .evaluate(batch)?
                .into_array(batch.num_rows())
        };
        let input1 = vec![column_values(batch1)?];
        let input2 = vec![column_values(batch2)?];

        target.update_batch(&input1)?;
        source.update_batch(&input2)?;

        let source_state = get_accum_scalar_values_as_arrays(source.as_mut())?;
        target.merge_batch(&source_state)?;

        target.evaluate()
    }
}
467}