1use std::any::Any;
21use std::fmt::{Debug, Formatter};
22use std::mem::align_of_val;
23use std::sync::Arc;
24
25use arrow::array::Float64Array;
26use arrow::datatypes::FieldRef;
27use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
28use datafusion_common::{internal_err, not_impl_err, Result};
29use datafusion_common::{plan_err, ScalarValue};
30use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
31use datafusion_expr::utils::format_state_name;
32use datafusion_expr::{
33 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
34 Volatility,
35};
36use datafusion_functions_aggregate_common::stats::StatsType;
37use datafusion_macros::user_doc;
38
39use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
40
41make_udaf_expr_and_func!(
42 Stddev,
43 stddev,
44 expression,
45 "Compute the standard deviation of a set of numbers",
46 stddev_udaf
47);
48
49#[user_doc(
50 doc_section(label = "Statistical Functions"),
51 description = "Returns the standard deviation of a set of numbers.",
52 syntax_example = "stddev(expression)",
53 sql_example = r#"```sql
54> SELECT stddev(column_name) FROM table_name;
55+----------------------+
56| stddev(column_name) |
57+----------------------+
58| 12.34 |
59+----------------------+
60```"#,
61 standard_argument(name = "expression",)
62)]
63pub struct Stddev {
65 signature: Signature,
66 alias: Vec<String>,
67}
68
69impl Debug for Stddev {
70 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("Stddev")
72 .field("name", &self.name())
73 .field("signature", &self.signature)
74 .finish()
75 }
76}
77
78impl Default for Stddev {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl Stddev {
85 pub fn new() -> Self {
87 Self {
88 signature: Signature::numeric(1, Volatility::Immutable),
89 alias: vec!["stddev_samp".to_string()],
90 }
91 }
92}
93
94impl AggregateUDFImpl for Stddev {
95 fn as_any(&self) -> &dyn Any {
97 self
98 }
99
100 fn name(&self) -> &str {
101 "stddev"
102 }
103
104 fn signature(&self) -> &Signature {
105 &self.signature
106 }
107
108 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
109 Ok(DataType::Float64)
110 }
111
112 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
113 Ok(vec![
114 Field::new(
115 format_state_name(args.name, "count"),
116 DataType::UInt64,
117 true,
118 ),
119 Field::new(
120 format_state_name(args.name, "mean"),
121 DataType::Float64,
122 true,
123 ),
124 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
125 ]
126 .into_iter()
127 .map(Arc::new)
128 .collect())
129 }
130
131 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
132 if acc_args.is_distinct {
133 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
134 }
135 Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
136 }
137
138 fn aliases(&self) -> &[String] {
139 &self.alias
140 }
141
142 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
143 !acc_args.is_distinct
144 }
145
146 fn create_groups_accumulator(
147 &self,
148 _args: AccumulatorArgs,
149 ) -> Result<Box<dyn GroupsAccumulator>> {
150 Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
151 }
152
153 fn documentation(&self) -> Option<&Documentation> {
154 self.doc()
155 }
156}
157
158make_udaf_expr_and_func!(
159 StddevPop,
160 stddev_pop,
161 expression,
162 "Compute the population standard deviation of a set of numbers",
163 stddev_pop_udaf
164);
165
166#[user_doc(
167 doc_section(label = "Statistical Functions"),
168 description = "Returns the population standard deviation of a set of numbers.",
169 syntax_example = "stddev_pop(expression)",
170 sql_example = r#"```sql
171> SELECT stddev_pop(column_name) FROM table_name;
172+--------------------------+
173| stddev_pop(column_name) |
174+--------------------------+
175| 10.56 |
176+--------------------------+
177```"#,
178 standard_argument(name = "expression",)
179)]
180pub struct StddevPop {
182 signature: Signature,
183}
184
185impl Debug for StddevPop {
186 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
187 f.debug_struct("StddevPop")
188 .field("name", &self.name())
189 .field("signature", &self.signature)
190 .finish()
191 }
192}
193
194impl Default for StddevPop {
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
200impl StddevPop {
201 pub fn new() -> Self {
203 Self {
204 signature: Signature::numeric(1, Volatility::Immutable),
205 }
206 }
207}
208
209impl AggregateUDFImpl for StddevPop {
210 fn as_any(&self) -> &dyn Any {
212 self
213 }
214
215 fn name(&self) -> &str {
216 "stddev_pop"
217 }
218
219 fn signature(&self) -> &Signature {
220 &self.signature
221 }
222
223 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
224 Ok(vec![
225 Field::new(
226 format_state_name(args.name, "count"),
227 DataType::UInt64,
228 true,
229 ),
230 Field::new(
231 format_state_name(args.name, "mean"),
232 DataType::Float64,
233 true,
234 ),
235 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
236 ]
237 .into_iter()
238 .map(Arc::new)
239 .collect())
240 }
241
242 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
243 if acc_args.is_distinct {
244 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
245 }
246 Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
247 }
248
249 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
250 if !arg_types[0].is_numeric() {
251 return plan_err!("StddevPop requires numeric input types");
252 }
253
254 Ok(DataType::Float64)
255 }
256
257 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
258 !acc_args.is_distinct
259 }
260
261 fn create_groups_accumulator(
262 &self,
263 _args: AccumulatorArgs,
264 ) -> Result<Box<dyn GroupsAccumulator>> {
265 Ok(Box::new(StddevGroupsAccumulator::new(
266 StatsType::Population,
267 )))
268 }
269
270 fn documentation(&self) -> Option<&Documentation> {
271 self.doc()
272 }
273}
274
275#[derive(Debug)]
277pub struct StddevAccumulator {
278 variance: VarianceAccumulator,
279}
280
281impl StddevAccumulator {
282 pub fn try_new(s_type: StatsType) -> Result<Self> {
284 Ok(Self {
285 variance: VarianceAccumulator::try_new(s_type)?,
286 })
287 }
288
289 pub fn get_m2(&self) -> f64 {
290 self.variance.get_m2()
291 }
292}
293
294impl Accumulator for StddevAccumulator {
295 fn state(&mut self) -> Result<Vec<ScalarValue>> {
296 Ok(vec![
297 ScalarValue::from(self.variance.get_count()),
298 ScalarValue::from(self.variance.get_mean()),
299 ScalarValue::from(self.variance.get_m2()),
300 ])
301 }
302
303 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
304 self.variance.update_batch(values)
305 }
306
307 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
308 self.variance.retract_batch(values)
309 }
310
311 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
312 self.variance.merge_batch(states)
313 }
314
315 fn evaluate(&mut self) -> Result<ScalarValue> {
316 let variance = self.variance.evaluate()?;
317 match variance {
318 ScalarValue::Float64(e) => {
319 if e.is_none() {
320 Ok(ScalarValue::Float64(None))
321 } else {
322 Ok(ScalarValue::Float64(e.map(|f| f.sqrt())))
323 }
324 }
325 _ => internal_err!("Variance should be f64"),
326 }
327 }
328
329 fn size(&self) -> usize {
330 align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
331 }
332
333 fn supports_retract_batch(&self) -> bool {
334 self.variance.supports_retract_batch()
335 }
336}
337
338#[derive(Debug)]
339pub struct StddevGroupsAccumulator {
340 variance: VarianceGroupsAccumulator,
341}
342
343impl StddevGroupsAccumulator {
344 pub fn new(s_type: StatsType) -> Self {
345 Self {
346 variance: VarianceGroupsAccumulator::new(s_type),
347 }
348 }
349}
350
351impl GroupsAccumulator for StddevGroupsAccumulator {
352 fn update_batch(
353 &mut self,
354 values: &[ArrayRef],
355 group_indices: &[usize],
356 opt_filter: Option<&arrow::array::BooleanArray>,
357 total_num_groups: usize,
358 ) -> Result<()> {
359 self.variance
360 .update_batch(values, group_indices, opt_filter, total_num_groups)
361 }
362
363 fn merge_batch(
364 &mut self,
365 values: &[ArrayRef],
366 group_indices: &[usize],
367 opt_filter: Option<&arrow::array::BooleanArray>,
368 total_num_groups: usize,
369 ) -> Result<()> {
370 self.variance
371 .merge_batch(values, group_indices, opt_filter, total_num_groups)
372 }
373
374 fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
375 let (mut variances, nulls) = self.variance.variance(emit_to);
376 variances.iter_mut().for_each(|v| *v = v.sqrt());
377 Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
378 }
379
380 fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
381 self.variance.state(emit_to)
382 }
383
384 fn size(&self) -> usize {
385 self.variance.size()
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::LexOrdering;
    use std::sync::Arc;

    /// `[1, 2, 3]` and `[4, 5]` accumulated separately and then merged must
    /// give the population standard deviation of `1..=5`, i.e. `sqrt(2)`.
    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let first = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let second = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));
        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![first])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![second])?;

        let merged = merge(
            &batch1,
            &batch2,
            stddev_pop_udaf(),
            stddev_pop_udaf(),
            &schema,
        )?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Merging in the state built from a single null value must leave the
    /// result for `1..=5` unchanged at `sqrt(2)`.
    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let first = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let second = Arc::new(Float64Array::from(vec![None]));
        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![first])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![second])?;

        let merged = merge(
            &batch1,
            &batch2,
            stddev_pop_udaf(),
            stddev_pop_udaf(),
            &schema,
        )?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Feeds column `a` of `batch1` into an accumulator from `agg1` and column
    /// `a` of `batch2` into one from `agg2`, merges the second accumulator's
    /// state into the first, and returns the first accumulator's result.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        // Both argument sets borrow the same ordering and expression list.
        let ordering = LexOrdering::default();
        let exprs = [col("a", schema)?];

        let args1 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            ordering_req: &ordering,
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs,
        };
        let args2 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            ordering_req: &ordering,
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs,
        };

        let mut accum1 = agg1.accumulator(args1)?;
        let mut accum2 = agg2.accumulator(args2)?;

        let column = col("a", schema)?;
        let value1 = vec![column.evaluate(batch1)?.into_array(batch1.num_rows())?];
        let value2 = vec![column.evaluate(batch2)?.into_array(batch2.num_rows())?];

        accum1.update_batch(&value1)?;
        accum2.update_batch(&value2)?;

        let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?;
        accum1.merge_batch(&state2)?;

        accum1.evaluate()
    }
}