1use std::any::Any;
21use std::fmt::{Debug, Formatter};
22use std::hash::Hash;
23use std::mem::align_of_val;
24use std::sync::Arc;
25
26use arrow::array::Float64Array;
27use arrow::datatypes::FieldRef;
28use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
29use datafusion_common::{internal_err, not_impl_err, Result};
30use datafusion_common::{plan_err, ScalarValue};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::utils::format_state_name;
33use datafusion_expr::{
34 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
35 Volatility,
36};
37use datafusion_functions_aggregate_common::stats::StatsType;
38use datafusion_macros::user_doc;
39
40use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
41
// Invokes `make_udaf_expr_and_func!` (defined outside this file) for `Stddev`,
// passing the identifiers `stddev` and `stddev_udaf`, a single argument named
// `expression`, and the description string below.
// NOTE(review): presumably this generates an expression builder
// `stddev(expression)` and a `stddev_udaf()` accessor, mirroring the
// `stddev_pop_udaf()` function that the tests in this file call — confirm
// against the macro definition.
make_udaf_expr_and_func!(
    Stddev,
    stddev,
    expression,
    "Compute the standard deviation of a set of numbers",
    stddev_udaf
);
49
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the standard deviation of a set of numbers.",
    syntax_example = "stddev(expression)",
    sql_example = r#"```sql
> SELECT stddev(column_name) FROM table_name;
+----------------------+
| stddev(column_name) |
+----------------------+
| 12.34 |
+----------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
/// Aggregate UDF registered under the name `stddev` (alias `stddev_samp`).
///
/// Its accumulators are created with `StatsType::Sample` and its result type
/// is `Float64`.
pub struct Stddev {
    /// Signature handed out by `AggregateUDFImpl::signature`.
    signature: Signature,
    /// Alternative names handed out by `AggregateUDFImpl::aliases`.
    alias: Vec<String>,
}
70
71impl Debug for Stddev {
72 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("Stddev")
74 .field("name", &self.name())
75 .field("signature", &self.signature)
76 .finish()
77 }
78}
79
80impl Default for Stddev {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl Stddev {
87 pub fn new() -> Self {
89 Self {
90 signature: Signature::numeric(1, Volatility::Immutable),
91 alias: vec!["stddev_samp".to_string()],
92 }
93 }
94}
95
96impl AggregateUDFImpl for Stddev {
97 fn as_any(&self) -> &dyn Any {
99 self
100 }
101
102 fn name(&self) -> &str {
103 "stddev"
104 }
105
106 fn signature(&self) -> &Signature {
107 &self.signature
108 }
109
110 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
111 Ok(DataType::Float64)
112 }
113
114 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
115 Ok(vec![
116 Field::new(
117 format_state_name(args.name, "count"),
118 DataType::UInt64,
119 true,
120 ),
121 Field::new(
122 format_state_name(args.name, "mean"),
123 DataType::Float64,
124 true,
125 ),
126 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
127 ]
128 .into_iter()
129 .map(Arc::new)
130 .collect())
131 }
132
133 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
134 if acc_args.is_distinct {
135 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
136 }
137 Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
138 }
139
140 fn aliases(&self) -> &[String] {
141 &self.alias
142 }
143
144 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
145 !acc_args.is_distinct
146 }
147
148 fn create_groups_accumulator(
149 &self,
150 _args: AccumulatorArgs,
151 ) -> Result<Box<dyn GroupsAccumulator>> {
152 Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
153 }
154
155 fn documentation(&self) -> Option<&Documentation> {
156 self.doc()
157 }
158}
159
// Invokes `make_udaf_expr_and_func!` (defined outside this file) for
// `StddevPop`, passing the identifiers `stddev_pop` and `stddev_pop_udaf`, a
// single argument named `expression`, and the description string below.
// The tests in this file call `stddev_pop_udaf()` and pass the result where an
// `Arc<AggregateUDF>` is expected.
// NOTE(review): presumably the macro also generates an expression builder
// `stddev_pop(expression)` — confirm against the macro definition.
make_udaf_expr_and_func!(
    StddevPop,
    stddev_pop,
    expression,
    "Compute the population standard deviation of a set of numbers",
    stddev_pop_udaf
);
167
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the population standard deviation of a set of numbers.",
    syntax_example = "stddev_pop(expression)",
    sql_example = r#"```sql
> SELECT stddev_pop(column_name) FROM table_name;
+--------------------------+
| stddev_pop(column_name) |
+--------------------------+
| 10.56 |
+--------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
/// Aggregate UDF registered under the name `stddev_pop`.
///
/// Its accumulators are created with `StatsType::Population` and its result
/// type is `Float64`.
pub struct StddevPop {
    /// Signature handed out by `AggregateUDFImpl::signature`.
    signature: Signature,
}
187
188impl Debug for StddevPop {
189 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
190 f.debug_struct("StddevPop")
191 .field("name", &self.name())
192 .field("signature", &self.signature)
193 .finish()
194 }
195}
196
197impl Default for StddevPop {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203impl StddevPop {
204 pub fn new() -> Self {
206 Self {
207 signature: Signature::numeric(1, Volatility::Immutable),
208 }
209 }
210}
211
212impl AggregateUDFImpl for StddevPop {
213 fn as_any(&self) -> &dyn Any {
215 self
216 }
217
218 fn name(&self) -> &str {
219 "stddev_pop"
220 }
221
222 fn signature(&self) -> &Signature {
223 &self.signature
224 }
225
226 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
227 Ok(vec![
228 Field::new(
229 format_state_name(args.name, "count"),
230 DataType::UInt64,
231 true,
232 ),
233 Field::new(
234 format_state_name(args.name, "mean"),
235 DataType::Float64,
236 true,
237 ),
238 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
239 ]
240 .into_iter()
241 .map(Arc::new)
242 .collect())
243 }
244
245 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
246 if acc_args.is_distinct {
247 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
248 }
249 Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
250 }
251
252 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
253 if !arg_types[0].is_numeric() {
254 return plan_err!("StddevPop requires numeric input types");
255 }
256
257 Ok(DataType::Float64)
258 }
259
260 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
261 !acc_args.is_distinct
262 }
263
264 fn create_groups_accumulator(
265 &self,
266 _args: AccumulatorArgs,
267 ) -> Result<Box<dyn GroupsAccumulator>> {
268 Ok(Box::new(StddevGroupsAccumulator::new(
269 StatsType::Population,
270 )))
271 }
272
273 fn documentation(&self) -> Option<&Documentation> {
274 self.doc()
275 }
276}
277
/// Row-wise standard deviation accumulator.
///
/// Wraps a `VarianceAccumulator`; `evaluate` returns the square root of the
/// variance it produces.
#[derive(Debug)]
pub struct StddevAccumulator {
    /// Wrapped variance accumulator that holds all running state.
    variance: VarianceAccumulator,
}
283
284impl StddevAccumulator {
285 pub fn try_new(s_type: StatsType) -> Result<Self> {
287 Ok(Self {
288 variance: VarianceAccumulator::try_new(s_type)?,
289 })
290 }
291
292 pub fn get_m2(&self) -> f64 {
293 self.variance.get_m2()
294 }
295}
296
297impl Accumulator for StddevAccumulator {
298 fn state(&mut self) -> Result<Vec<ScalarValue>> {
299 Ok(vec![
300 ScalarValue::from(self.variance.get_count()),
301 ScalarValue::from(self.variance.get_mean()),
302 ScalarValue::from(self.variance.get_m2()),
303 ])
304 }
305
306 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
307 self.variance.update_batch(values)
308 }
309
310 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
311 self.variance.retract_batch(values)
312 }
313
314 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
315 self.variance.merge_batch(states)
316 }
317
318 fn evaluate(&mut self) -> Result<ScalarValue> {
319 let variance = self.variance.evaluate()?;
320 match variance {
321 ScalarValue::Float64(e) => {
322 if e.is_none() {
323 Ok(ScalarValue::Float64(None))
324 } else {
325 Ok(ScalarValue::Float64(e.map(|f| f.sqrt())))
326 }
327 }
328 _ => internal_err!("Variance should be f64"),
329 }
330 }
331
332 fn size(&self) -> usize {
333 align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
334 }
335
336 fn supports_retract_batch(&self) -> bool {
337 self.variance.supports_retract_batch()
338 }
339}
340
/// Grouped standard deviation accumulator.
///
/// Wraps a `VarianceGroupsAccumulator`; `evaluate` takes the square root of
/// each variance it produces.
#[derive(Debug)]
pub struct StddevGroupsAccumulator {
    /// Wrapped grouped variance accumulator that holds all running state.
    variance: VarianceGroupsAccumulator,
}
345
346impl StddevGroupsAccumulator {
347 pub fn new(s_type: StatsType) -> Self {
348 Self {
349 variance: VarianceGroupsAccumulator::new(s_type),
350 }
351 }
352}
353
354impl GroupsAccumulator for StddevGroupsAccumulator {
355 fn update_batch(
356 &mut self,
357 values: &[ArrayRef],
358 group_indices: &[usize],
359 opt_filter: Option<&arrow::array::BooleanArray>,
360 total_num_groups: usize,
361 ) -> Result<()> {
362 self.variance
363 .update_batch(values, group_indices, opt_filter, total_num_groups)
364 }
365
366 fn merge_batch(
367 &mut self,
368 values: &[ArrayRef],
369 group_indices: &[usize],
370 opt_filter: Option<&arrow::array::BooleanArray>,
371 total_num_groups: usize,
372 ) -> Result<()> {
373 self.variance
374 .merge_batch(values, group_indices, opt_filter, total_num_groups)
375 }
376
377 fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
378 let (mut variances, nulls) = self.variance.variance(emit_to);
379 variances.iter_mut().for_each(|v| *v = v.sqrt());
380 Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
381 }
382
383 fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
384 self.variance.state(emit_to)
385 }
386
387 fn size(&self) -> usize {
388 self.variance.size()
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;
    use std::sync::Arc;

    /// Values 1..=5 split over two non-null batches: the merged population
    /// standard deviation must equal `SQRT_2`.
    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let first = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let second = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));
        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![first])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![second])?;

        let merged = merge(
            &batch1,
            &batch2,
            stddev_pop_udaf(),
            stddev_pop_udaf(),
            &schema,
        )?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Values 1..=5 in one batch merged with a batch holding a single null:
    /// the merged population standard deviation must still equal `SQRT_2`.
    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let first = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let second = Arc::new(Float64Array::from(vec![None]));
        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![first])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![second])?;

        let merged = merge(
            &batch1,
            &batch2,
            stddev_pop_udaf(),
            stddev_pop_udaf(),
            &schema,
        )?;
        assert_eq!(merged, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Feeds `batch1` and `batch2` into separate accumulators created from
    /// `agg1` and `agg2`, merges the second accumulator's state into the
    /// first, and returns the first accumulator's final value.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        // Each set of accumulator arguments borrows its own expression list.
        let exprs1 = [col("a", schema)?];
        let exprs2 = [col("a", schema)?];

        let args1 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs1,
        };
        let args2 = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema,
            ignore_nulls: false,
            order_bys: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &exprs2,
        };

        let mut accum1 = agg1.accumulator(args1)?;
        let mut accum2 = agg2.accumulator(args2)?;

        // Materialize column "a" of each batch as a single-element input list.
        let value1 = vec![col("a", schema)?
            .evaluate(batch1)?
            .into_array(batch1.num_rows())?];
        let value2 = vec![col("a", schema)?
            .evaluate(batch2)?
            .into_array(batch2.num_rows())?];

        accum1.update_batch(&value1)?;
        accum2.update_batch(&value2)?;

        let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?;
        accum1.merge_batch(&state2)?;

        accum1.evaluate()
    }
}