1use std::any::Any;
21use std::fmt::{Debug, Formatter};
22use std::hash::{DefaultHasher, Hash, Hasher};
23use std::mem::align_of_val;
24use std::sync::Arc;
25
26use arrow::array::Float64Array;
27use arrow::datatypes::FieldRef;
28use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
29use datafusion_common::{internal_err, not_impl_err, Result};
30use datafusion_common::{plan_err, ScalarValue};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::utils::format_state_name;
33use datafusion_expr::{
34 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
35 Volatility,
36};
37use datafusion_functions_aggregate_common::stats::StatsType;
38use datafusion_macros::user_doc;
39
40use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
41
// Expands (via the crate's `make_udaf_expr_and_func!` macro, defined elsewhere)
// into the `stddev` expression helper and the `stddev_udaf()` accessor for
// [`Stddev`]. Presumably mirrors `stddev_pop_udaf()`, which the tests below use
// as an `Arc<AggregateUDF>` — confirm against the macro definition.
make_udaf_expr_and_func!(
    Stddev,
    stddev,
    expression,
    "Compute the standard deviation of a set of numbers",
    stddev_udaf
);
48);
49
// Sample standard deviation aggregate, exposed as `stddev` and also under the
// `stddev_samp` alias; its accumulators are built with `StatsType::Sample`.
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the standard deviation of a set of numbers.",
    syntax_example = "stddev(expression)",
    sql_example = r#"```sql
> SELECT stddev(column_name) FROM table_name;
+----------------------+
| stddev(column_name) |
+----------------------+
| 12.34 |
+----------------------+
```"#,
    standard_argument(name = "expression",)
)]
pub struct Stddev {
    // Built as `Signature::numeric(1, Volatility::Immutable)` in `Stddev::new`.
    signature: Signature,
    // Alternative function names returned by `aliases()`.
    alias: Vec<String>,
}
69
70impl Debug for Stddev {
71 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
72 f.debug_struct("Stddev")
73 .field("name", &self.name())
74 .field("signature", &self.signature)
75 .finish()
76 }
77}
78
79impl Default for Stddev {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl Stddev {
86 pub fn new() -> Self {
88 Self {
89 signature: Signature::numeric(1, Volatility::Immutable),
90 alias: vec!["stddev_samp".to_string()],
91 }
92 }
93}
94
95impl AggregateUDFImpl for Stddev {
96 fn as_any(&self) -> &dyn Any {
98 self
99 }
100
101 fn name(&self) -> &str {
102 "stddev"
103 }
104
105 fn signature(&self) -> &Signature {
106 &self.signature
107 }
108
109 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
110 Ok(DataType::Float64)
111 }
112
113 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
114 Ok(vec![
115 Field::new(
116 format_state_name(args.name, "count"),
117 DataType::UInt64,
118 true,
119 ),
120 Field::new(
121 format_state_name(args.name, "mean"),
122 DataType::Float64,
123 true,
124 ),
125 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
126 ]
127 .into_iter()
128 .map(Arc::new)
129 .collect())
130 }
131
132 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
133 if acc_args.is_distinct {
134 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
135 }
136 Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
137 }
138
139 fn aliases(&self) -> &[String] {
140 &self.alias
141 }
142
143 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
144 !acc_args.is_distinct
145 }
146
147 fn create_groups_accumulator(
148 &self,
149 _args: AccumulatorArgs,
150 ) -> Result<Box<dyn GroupsAccumulator>> {
151 Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
152 }
153
154 fn documentation(&self) -> Option<&Documentation> {
155 self.doc()
156 }
157
158 fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
159 let Some(other) = other.as_any().downcast_ref::<Self>() else {
160 return false;
161 };
162 let Self { signature, alias } = self;
163 signature == &other.signature && alias == &other.alias
164 }
165
166 fn hash_value(&self) -> u64 {
167 let Self { signature, alias } = self;
168 let mut hasher = DefaultHasher::new();
169 std::any::type_name::<Self>().hash(&mut hasher);
170 signature.hash(&mut hasher);
171 alias.hash(&mut hasher);
172 hasher.finish()
173 }
174}
175
// Expands (via the crate's `make_udaf_expr_and_func!` macro, defined elsewhere)
// into the `stddev_pop` expression helper and the `stddev_pop_udaf()` accessor
// for [`StddevPop`]; the tests below use `stddev_pop_udaf()` as an
// `Arc<AggregateUDF>`.
make_udaf_expr_and_func!(
    StddevPop,
    stddev_pop,
    expression,
    "Compute the population standard deviation of a set of numbers",
    stddev_pop_udaf
);
183
// Population standard deviation aggregate (`stddev_pop`); its accumulators are
// built with `StatsType::Population`.
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the population standard deviation of a set of numbers.",
    syntax_example = "stddev_pop(expression)",
    sql_example = r#"```sql
> SELECT stddev_pop(column_name) FROM table_name;
+--------------------------+
| stddev_pop(column_name) |
+--------------------------+
| 10.56 |
+--------------------------+
```"#,
    standard_argument(name = "expression",)
)]
pub struct StddevPop {
    // Built as `Signature::numeric(1, Volatility::Immutable)` in `StddevPop::new`.
    signature: Signature,
}
202
203impl Debug for StddevPop {
204 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
205 f.debug_struct("StddevPop")
206 .field("name", &self.name())
207 .field("signature", &self.signature)
208 .finish()
209 }
210}
211
212impl Default for StddevPop {
213 fn default() -> Self {
214 Self::new()
215 }
216}
217
218impl StddevPop {
219 pub fn new() -> Self {
221 Self {
222 signature: Signature::numeric(1, Volatility::Immutable),
223 }
224 }
225}
226
227impl AggregateUDFImpl for StddevPop {
228 fn as_any(&self) -> &dyn Any {
230 self
231 }
232
233 fn name(&self) -> &str {
234 "stddev_pop"
235 }
236
237 fn signature(&self) -> &Signature {
238 &self.signature
239 }
240
241 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
242 Ok(vec![
243 Field::new(
244 format_state_name(args.name, "count"),
245 DataType::UInt64,
246 true,
247 ),
248 Field::new(
249 format_state_name(args.name, "mean"),
250 DataType::Float64,
251 true,
252 ),
253 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
254 ]
255 .into_iter()
256 .map(Arc::new)
257 .collect())
258 }
259
260 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
261 if acc_args.is_distinct {
262 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
263 }
264 Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
265 }
266
267 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
268 if !arg_types[0].is_numeric() {
269 return plan_err!("StddevPop requires numeric input types");
270 }
271
272 Ok(DataType::Float64)
273 }
274
275 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
276 !acc_args.is_distinct
277 }
278
279 fn create_groups_accumulator(
280 &self,
281 _args: AccumulatorArgs,
282 ) -> Result<Box<dyn GroupsAccumulator>> {
283 Ok(Box::new(StddevGroupsAccumulator::new(
284 StatsType::Population,
285 )))
286 }
287
288 fn documentation(&self) -> Option<&Documentation> {
289 self.doc()
290 }
291}
292
/// Standard deviation accumulator: delegates all state tracking to an inner
/// [`VarianceAccumulator`] and takes the square root of its result on evaluation.
#[derive(Debug)]
pub struct StddevAccumulator {
    // Holds the count / mean / m2 state exposed through `state()`.
    variance: VarianceAccumulator,
}
298
299impl StddevAccumulator {
300 pub fn try_new(s_type: StatsType) -> Result<Self> {
302 Ok(Self {
303 variance: VarianceAccumulator::try_new(s_type)?,
304 })
305 }
306
307 pub fn get_m2(&self) -> f64 {
308 self.variance.get_m2()
309 }
310}
311
312impl Accumulator for StddevAccumulator {
313 fn state(&mut self) -> Result<Vec<ScalarValue>> {
314 Ok(vec![
315 ScalarValue::from(self.variance.get_count()),
316 ScalarValue::from(self.variance.get_mean()),
317 ScalarValue::from(self.variance.get_m2()),
318 ])
319 }
320
321 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
322 self.variance.update_batch(values)
323 }
324
325 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
326 self.variance.retract_batch(values)
327 }
328
329 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
330 self.variance.merge_batch(states)
331 }
332
333 fn evaluate(&mut self) -> Result<ScalarValue> {
334 let variance = self.variance.evaluate()?;
335 match variance {
336 ScalarValue::Float64(e) => {
337 if e.is_none() {
338 Ok(ScalarValue::Float64(None))
339 } else {
340 Ok(ScalarValue::Float64(e.map(|f| f.sqrt())))
341 }
342 }
343 _ => internal_err!("Variance should be f64"),
344 }
345 }
346
347 fn size(&self) -> usize {
348 align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
349 }
350
351 fn supports_retract_batch(&self) -> bool {
352 self.variance.supports_retract_batch()
353 }
354}
355
/// Grouped standard deviation accumulator: wraps a [`VarianceGroupsAccumulator`]
/// and takes the square root of each group's variance on evaluation.
#[derive(Debug)]
pub struct StddevGroupsAccumulator {
    // Per-group variance state; every update/merge/state call is forwarded here.
    variance: VarianceGroupsAccumulator,
}
360
361impl StddevGroupsAccumulator {
362 pub fn new(s_type: StatsType) -> Self {
363 Self {
364 variance: VarianceGroupsAccumulator::new(s_type),
365 }
366 }
367}
368
369impl GroupsAccumulator for StddevGroupsAccumulator {
370 fn update_batch(
371 &mut self,
372 values: &[ArrayRef],
373 group_indices: &[usize],
374 opt_filter: Option<&arrow::array::BooleanArray>,
375 total_num_groups: usize,
376 ) -> Result<()> {
377 self.variance
378 .update_batch(values, group_indices, opt_filter, total_num_groups)
379 }
380
381 fn merge_batch(
382 &mut self,
383 values: &[ArrayRef],
384 group_indices: &[usize],
385 opt_filter: Option<&arrow::array::BooleanArray>,
386 total_num_groups: usize,
387 ) -> Result<()> {
388 self.variance
389 .merge_batch(values, group_indices, opt_filter, total_num_groups)
390 }
391
392 fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
393 let (mut variances, nulls) = self.variance.variance(emit_to);
394 variances.iter_mut().for_each(|v| *v = v.sqrt());
395 Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
396 }
397
398 fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
399 self.variance.state(emit_to)
400 }
401
402 fn size(&self) -> usize {
403 self.variance.size()
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;
    use std::sync::Arc;

    /// Population stddev of 1..=5 split across two batches ([1, 2, 3] and
    /// [4, 5]) and merged: variance is 10 / 5 = 2, so the result is sqrt(2).
    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_pop_udaf();
        let agg2 = stddev_pop_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// All of 1..=5 in one batch, merged with a batch holding a single NULL:
    /// the NULL contributes nothing, so the result is still sqrt(2).
    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let b = Arc::new(Float64Array::from(vec![None]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_pop_udaf();
        let agg2 = stddev_pop_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Feeds column `a` of each batch into its own accumulator, merges the
    /// second accumulator's state into the first, and returns the first
    /// accumulator's final value.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        let args1 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[col("a", schema)?],
        };

        let args2 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[col("a", schema)?],
        };

        let mut accum1 = agg1.accumulator(args1)?;
        let mut accum2 = agg2.accumulator(args2)?;

        // Evaluate the column expression against each batch to get the input arrays.
        let value1 = vec![col("a", schema)?
            .evaluate(batch1)
            .and_then(|v| v.into_array(batch1.num_rows()))?];
        let value2 = vec![col("a", schema)?
            .evaluate(batch2)
            .and_then(|v| v.into_array(batch2.num_rows()))?];

        accum1.update_batch(&value1)?;
        accum2.update_batch(&value2)?;
        // Round-trip accum2's intermediate state through arrays, as a partial
        // aggregate would, then merge it into accum1.
        let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?;
        accum1.merge_batch(&state2)?;
        let result = accum1.evaluate()?;
        Ok(result)
    }
}