1use std::any::Any;
21use std::fmt::{Debug, Formatter};
22use std::mem::align_of_val;
23use std::sync::Arc;
24
25use arrow::array::Float64Array;
26use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
27
28use datafusion_common::{internal_err, not_impl_err, Result};
29use datafusion_common::{plan_err, ScalarValue};
30use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
31use datafusion_expr::utils::format_state_name;
32use datafusion_expr::{
33 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
34 Volatility,
35};
36use datafusion_functions_aggregate_common::stats::StatsType;
37use datafusion_macros::user_doc;
38
39use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
40
// Generates the `stddev_udaf()` accessor for [`Stddev`] (the test module below
// receives its result as an `Arc<AggregateUDF>`); presumably it also generates a
// `stddev(expression)` expression builder — see the macro definition to confirm.
make_udaf_expr_and_func!(
    Stddev,
    stddev,
    expression,
    "Compute the standard deviation of a set of numbers",
    stddev_udaf
);

/// Sample standard deviation aggregate, named `stddev` with the alias
/// `stddev_samp`.
///
/// Its accumulators are built with `StatsType::Sample` and return the square
/// root of the variance computed by the wrapped variance accumulators.
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the standard deviation of a set of numbers.",
    syntax_example = "stddev(expression)",
    sql_example = r#"```sql
> SELECT stddev(column_name) FROM table_name;
+----------------------+
| stddev(column_name) |
+----------------------+
| 12.34 |
+----------------------+
```"#,
    standard_argument(name = "expression",)
)]
pub struct Stddev {
    // Built as `Signature::numeric(1, Volatility::Immutable)` in `Stddev::new`.
    signature: Signature,
    // Alternative SQL names returned by `aliases()`; `new` sets `["stddev_samp"]`.
    alias: Vec<String>,
}
68
69impl Debug for Stddev {
70 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("Stddev")
72 .field("name", &self.name())
73 .field("signature", &self.signature)
74 .finish()
75 }
76}
77
78impl Default for Stddev {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl Stddev {
85 pub fn new() -> Self {
87 Self {
88 signature: Signature::numeric(1, Volatility::Immutable),
89 alias: vec!["stddev_samp".to_string()],
90 }
91 }
92}
93
94impl AggregateUDFImpl for Stddev {
95 fn as_any(&self) -> &dyn Any {
97 self
98 }
99
100 fn name(&self) -> &str {
101 "stddev"
102 }
103
104 fn signature(&self) -> &Signature {
105 &self.signature
106 }
107
108 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
109 Ok(DataType::Float64)
110 }
111
112 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
113 Ok(vec![
114 Field::new(
115 format_state_name(args.name, "count"),
116 DataType::UInt64,
117 true,
118 ),
119 Field::new(
120 format_state_name(args.name, "mean"),
121 DataType::Float64,
122 true,
123 ),
124 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
125 ])
126 }
127
128 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
129 if acc_args.is_distinct {
130 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
131 }
132 Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
133 }
134
135 fn aliases(&self) -> &[String] {
136 &self.alias
137 }
138
139 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
140 !acc_args.is_distinct
141 }
142
143 fn create_groups_accumulator(
144 &self,
145 _args: AccumulatorArgs,
146 ) -> Result<Box<dyn GroupsAccumulator>> {
147 Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
148 }
149
150 fn documentation(&self) -> Option<&Documentation> {
151 self.doc()
152 }
153}
154
// Generates the `stddev_pop_udaf()` accessor for [`StddevPop`] (the test module
// below receives its result as an `Arc<AggregateUDF>`); presumably it also
// generates a `stddev_pop(expression)` expression builder — see the macro
// definition to confirm.
make_udaf_expr_and_func!(
    StddevPop,
    stddev_pop,
    expression,
    "Compute the population standard deviation of a set of numbers",
    stddev_pop_udaf
);

/// Population standard deviation aggregate, named `stddev_pop`.
///
/// Its accumulators are built with `StatsType::Population` and return the
/// square root of the variance computed by the wrapped variance accumulators.
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the population standard deviation of a set of numbers.",
    syntax_example = "stddev_pop(expression)",
    sql_example = r#"```sql
> SELECT stddev_pop(column_name) FROM table_name;
+--------------------------+
| stddev_pop(column_name) |
+--------------------------+
| 10.56 |
+--------------------------+
```"#,
    standard_argument(name = "expression",)
)]
pub struct StddevPop {
    // Built as `Signature::numeric(1, Volatility::Immutable)` in `StddevPop::new`.
    signature: Signature,
}
181
182impl Debug for StddevPop {
183 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
184 f.debug_struct("StddevPop")
185 .field("name", &self.name())
186 .field("signature", &self.signature)
187 .finish()
188 }
189}
190
191impl Default for StddevPop {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
197impl StddevPop {
198 pub fn new() -> Self {
200 Self {
201 signature: Signature::numeric(1, Volatility::Immutable),
202 }
203 }
204}
205
206impl AggregateUDFImpl for StddevPop {
207 fn as_any(&self) -> &dyn Any {
209 self
210 }
211
212 fn name(&self) -> &str {
213 "stddev_pop"
214 }
215
216 fn signature(&self) -> &Signature {
217 &self.signature
218 }
219
220 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
221 Ok(vec![
222 Field::new(
223 format_state_name(args.name, "count"),
224 DataType::UInt64,
225 true,
226 ),
227 Field::new(
228 format_state_name(args.name, "mean"),
229 DataType::Float64,
230 true,
231 ),
232 Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
233 ])
234 }
235
236 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
237 if acc_args.is_distinct {
238 return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
239 }
240 Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
241 }
242
243 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
244 if !arg_types[0].is_numeric() {
245 return plan_err!("StddevPop requires numeric input types");
246 }
247
248 Ok(DataType::Float64)
249 }
250
251 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
252 !acc_args.is_distinct
253 }
254
255 fn create_groups_accumulator(
256 &self,
257 _args: AccumulatorArgs,
258 ) -> Result<Box<dyn GroupsAccumulator>> {
259 Ok(Box::new(StddevGroupsAccumulator::new(
260 StatsType::Population,
261 )))
262 }
263
264 fn documentation(&self) -> Option<&Documentation> {
265 self.doc()
266 }
267}
268
269#[derive(Debug)]
271pub struct StddevAccumulator {
272 variance: VarianceAccumulator,
273}
274
275impl StddevAccumulator {
276 pub fn try_new(s_type: StatsType) -> Result<Self> {
278 Ok(Self {
279 variance: VarianceAccumulator::try_new(s_type)?,
280 })
281 }
282
283 pub fn get_m2(&self) -> f64 {
284 self.variance.get_m2()
285 }
286}
287
288impl Accumulator for StddevAccumulator {
289 fn state(&mut self) -> Result<Vec<ScalarValue>> {
290 Ok(vec![
291 ScalarValue::from(self.variance.get_count()),
292 ScalarValue::from(self.variance.get_mean()),
293 ScalarValue::from(self.variance.get_m2()),
294 ])
295 }
296
297 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
298 self.variance.update_batch(values)
299 }
300
301 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
302 self.variance.retract_batch(values)
303 }
304
305 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
306 self.variance.merge_batch(states)
307 }
308
309 fn evaluate(&mut self) -> Result<ScalarValue> {
310 let variance = self.variance.evaluate()?;
311 match variance {
312 ScalarValue::Float64(e) => {
313 if e.is_none() {
314 Ok(ScalarValue::Float64(None))
315 } else {
316 Ok(ScalarValue::Float64(e.map(|f| f.sqrt())))
317 }
318 }
319 _ => internal_err!("Variance should be f64"),
320 }
321 }
322
323 fn size(&self) -> usize {
324 align_of_val(self) - align_of_val(&self.variance) + self.variance.size()
325 }
326
327 fn supports_retract_batch(&self) -> bool {
328 self.variance.supports_retract_batch()
329 }
330}
331
332#[derive(Debug)]
333pub struct StddevGroupsAccumulator {
334 variance: VarianceGroupsAccumulator,
335}
336
337impl StddevGroupsAccumulator {
338 pub fn new(s_type: StatsType) -> Self {
339 Self {
340 variance: VarianceGroupsAccumulator::new(s_type),
341 }
342 }
343}
344
345impl GroupsAccumulator for StddevGroupsAccumulator {
346 fn update_batch(
347 &mut self,
348 values: &[ArrayRef],
349 group_indices: &[usize],
350 opt_filter: Option<&arrow::array::BooleanArray>,
351 total_num_groups: usize,
352 ) -> Result<()> {
353 self.variance
354 .update_batch(values, group_indices, opt_filter, total_num_groups)
355 }
356
357 fn merge_batch(
358 &mut self,
359 values: &[ArrayRef],
360 group_indices: &[usize],
361 opt_filter: Option<&arrow::array::BooleanArray>,
362 total_num_groups: usize,
363 ) -> Result<()> {
364 self.variance
365 .merge_batch(values, group_indices, opt_filter, total_num_groups)
366 }
367
368 fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
369 let (mut variances, nulls) = self.variance.variance(emit_to);
370 variances.iter_mut().for_each(|v| *v = v.sqrt());
371 Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
372 }
373
374 fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
375 self.variance.state(emit_to)
376 }
377
378 fn size(&self) -> usize {
379 self.variance.size()
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::LexOrdering;
    use std::sync::Arc;

    /// Population stddev of [1, 2, 3] merged with [4, 5] is sqrt(10 / 5) = sqrt(2).
    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_pop_udaf();
        let agg2 = stddev_pop_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Merging a batch that holds only NULL leaves the population stddev of
    /// [1, 2, 3, 4, 5] unchanged at sqrt(2).
    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let b = Arc::new(Float64Array::from(vec![None]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_pop_udaf();
        let agg2 = stddev_pop_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Sample stddev (`stddev`) of [1, 2, 3] merged with [4, 5].
    ///
    /// The accumulated `m2` is 10 and `count` is 5, so the result is
    /// sqrt(10 / 4) = sqrt(2.5). This covers the `StatsType::Sample` path,
    /// which the population tests above do not exercise. Every intermediate
    /// value is exactly representable, so exact equality is safe.
    #[test]
    fn stddev_samp_f64_merge() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_udaf();
        let agg2 = stddev_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(2.5_f64.sqrt()));

        Ok(())
    }

    /// Runs one accumulator from `agg1` over `batch1` and one from `agg2` over
    /// `batch2`.
    ///
    /// The second accumulator's state is then merged into the first, and the
    /// first's final value is returned.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        let args1 = AccumulatorArgs {
            return_type: &DataType::Float64,
            schema,
            ignore_nulls: false,
            ordering_req: &LexOrdering::default(),
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[col("a", schema)?],
        };

        let args2 = AccumulatorArgs {
            return_type: &DataType::Float64,
            schema,
            ignore_nulls: false,
            ordering_req: &LexOrdering::default(),
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[col("a", schema)?],
        };

        let mut accum1 = agg1.accumulator(args1)?;
        let mut accum2 = agg2.accumulator(args2)?;

        let value1 = vec![col("a", schema)?
            .evaluate(batch1)
            .and_then(|v| v.into_array(batch1.num_rows()))?];
        let value2 = vec![col("a", schema)?
            .evaluate(batch2)
            .and_then(|v| v.into_array(batch2.num_rows()))?];

        accum1.update_batch(&value1)?;
        accum2.update_batch(&value2)?;
        let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?;
        accum1.merge_batch(&state2)?;
        let result = accum1.evaluate()?;
        Ok(result)
    }
}