datafusion_comet_spark_expr/agg_funcs/
stddev.rs1use std::{any::Any, sync::Arc};
19
20use crate::agg_funcs::variance::VarianceAccumulator;
21use arrow::{
22 array::ArrayRef,
23 datatypes::{DataType, Field},
24};
25use datafusion::logical_expr::Accumulator;
26use datafusion_common::types::NativeType;
27use datafusion_common::{internal_err, Result, ScalarValue};
28use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
29use datafusion_expr::{AggregateUDFImpl, Signature, Volatility};
30use datafusion_expr_common::signature::Coercion;
31use datafusion_physical_expr::expressions::format_state_name;
32use datafusion_physical_expr::expressions::StatsType;
33
34#[derive(Debug)]
40pub struct Stddev {
41 name: String,
42 signature: Signature,
43 stats_type: StatsType,
44 null_on_divide_by_zero: bool,
45}
46
47impl Stddev {
48 pub fn new(
50 name: impl Into<String>,
51 data_type: DataType,
52 stats_type: StatsType,
53 null_on_divide_by_zero: bool,
54 ) -> Self {
55 assert!(matches!(data_type, DataType::Float64));
57 Self {
58 name: name.into(),
59 signature: Signature::coercible(
60 vec![Coercion::new_exact(
61 datafusion_expr_common::signature::TypeSignatureClass::Native(Arc::new(
62 NativeType::Float64,
63 )),
64 )],
65 Volatility::Immutable,
66 ),
67 stats_type,
68 null_on_divide_by_zero,
69 }
70 }
71}
72
73impl AggregateUDFImpl for Stddev {
74 fn as_any(&self) -> &dyn Any {
76 self
77 }
78
79 fn name(&self) -> &str {
80 &self.name
81 }
82
83 fn signature(&self) -> &Signature {
84 &self.signature
85 }
86
87 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
88 Ok(DataType::Float64)
89 }
90
91 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
92 Ok(Box::new(StddevAccumulator::try_new(
93 self.stats_type,
94 self.null_on_divide_by_zero,
95 )?))
96 }
97
98 fn create_sliding_accumulator(
99 &self,
100 _acc_args: AccumulatorArgs,
101 ) -> Result<Box<dyn Accumulator>> {
102 Ok(Box::new(StddevAccumulator::try_new(
103 self.stats_type,
104 self.null_on_divide_by_zero,
105 )?))
106 }
107
108 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
109 Ok(vec![
110 Field::new(
111 format_state_name(&self.name, "count"),
112 DataType::Float64,
113 true,
114 ),
115 Field::new(
116 format_state_name(&self.name, "mean"),
117 DataType::Float64,
118 true,
119 ),
120 Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true),
121 ])
122 }
123
124 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
125 Ok(ScalarValue::Float64(None))
126 }
127}
128
129#[derive(Debug)]
131pub struct StddevAccumulator {
132 variance: VarianceAccumulator,
133}
134
135impl StddevAccumulator {
136 pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result<Self> {
138 Ok(Self {
139 variance: VarianceAccumulator::try_new(s_type, null_on_divide_by_zero)?,
140 })
141 }
142
143 pub fn get_m2(&self) -> f64 {
144 self.variance.get_m2()
145 }
146}
147
148impl Accumulator for StddevAccumulator {
149 fn state(&mut self) -> Result<Vec<ScalarValue>> {
150 Ok(vec![
151 ScalarValue::from(self.variance.get_count()),
152 ScalarValue::from(self.variance.get_mean()),
153 ScalarValue::from(self.variance.get_m2()),
154 ])
155 }
156
157 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
158 self.variance.update_batch(values)
159 }
160
161 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
162 self.variance.retract_batch(values)
163 }
164
165 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
166 self.variance.merge_batch(states)
167 }
168
169 fn evaluate(&mut self) -> Result<ScalarValue> {
170 let variance = self.variance.evaluate()?;
171 match variance {
172 ScalarValue::Float64(Some(e)) => Ok(ScalarValue::Float64(Some(e.sqrt()))),
173 ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)),
174 _ => internal_err!("Variance should be f64"),
175 }
176 }
177
178 fn size(&self) -> usize {
179 std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance) + self.variance.size()
180 }
181}