datafusion_comet_spark_expr/nondetermenistic_funcs/
randn.rs1use crate::nondetermenistic_funcs::rand::XorShiftRandom;
19
20use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
21use arrow::array::RecordBatch;
22use arrow::datatypes::{DataType, Schema};
23use datafusion::logical_expr::ColumnarValue;
24use datafusion::physical_expr::PhysicalExpr;
25use std::any::Any;
26use std::fmt::{Display, Formatter};
27use std::hash::{Hash, Hasher};
28use std::sync::{Arc, Mutex};
29
30#[derive(Debug, Clone)]
44struct XorShiftRandomForGaussian {
45 base_generator: XorShiftRandom,
46 next_gaussian: Option<f64>,
47}
48
49impl XorShiftRandomForGaussian {
50 pub fn next_gaussian(&mut self) -> f64 {
51 if let Some(stored_value) = self.next_gaussian {
52 self.next_gaussian = None;
53 return stored_value;
54 }
55 let mut v1: f64;
56 let mut v2: f64;
57 let mut s: f64;
58 loop {
59 v1 = 2f64 * self.base_generator.next_f64() - 1f64;
60 v2 = 2f64 * self.base_generator.next_f64() - 1f64;
61 s = v1 * v1 + v2 * v2;
62 if s < 1f64 && s != 0f64 {
63 break;
64 }
65 }
66 let multiplier = (-2f64 * s.ln() / s).sqrt();
67 self.next_gaussian = Some(v2 * multiplier);
68 v1 * multiplier
69 }
70}
71
/// Generator snapshot carried between batches: the base generator's current
/// seed, plus the cached spare Gaussian sample if one is pending.
type RandomGaussianState = (i64, Option<f64>);
73
74impl StatefulSeedValueGenerator<RandomGaussianState, f64> for XorShiftRandomForGaussian {
75 fn from_init_seed(init_value: i64) -> Self {
76 XorShiftRandomForGaussian {
77 base_generator: XorShiftRandom::from_init_seed(init_value),
78 next_gaussian: None,
79 }
80 }
81
82 fn from_stored_state(stored_state: RandomGaussianState) -> Self {
83 XorShiftRandomForGaussian {
84 base_generator: XorShiftRandom::from_stored_state(stored_state.0),
85 next_gaussian: stored_state.1,
86 }
87 }
88
89 fn next_value(&mut self) -> f64 {
90 self.next_gaussian()
91 }
92
93 fn get_current_state(&self) -> RandomGaussianState {
94 (self.base_generator.seed, self.next_gaussian)
95 }
96}
97
98#[derive(Debug, Clone)]
99pub struct RandnExpr {
100 seed: i64,
101 state_holder: Arc<Mutex<Option<RandomGaussianState>>>,
102}
103
104impl RandnExpr {
105 pub fn new(seed: i64) -> Self {
106 Self {
107 seed,
108 state_holder: Arc::new(Mutex::new(None)),
109 }
110 }
111}
112
113impl Display for RandnExpr {
114 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
115 write!(f, "RANDN({})", self.seed)
116 }
117}
118
119impl PartialEq for RandnExpr {
120 fn eq(&self, other: &Self) -> bool {
121 self.seed.eq(&other.seed)
122 }
123}
124
125impl Eq for RandnExpr {}
126
127impl Hash for RandnExpr {
128 fn hash<H: Hasher>(&self, state: &mut H) {
129 self.children().hash(state);
130 }
131}
132
133impl PhysicalExpr for RandnExpr {
134 fn as_any(&self) -> &dyn Any {
135 self
136 }
137
138 fn data_type(&self, _input_schema: &Schema) -> datafusion::common::Result<DataType> {
139 Ok(DataType::Float64)
140 }
141
142 fn nullable(&self, _input_schema: &Schema) -> datafusion::common::Result<bool> {
143 Ok(false)
144 }
145
146 fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
147 evaluate_batch_for_rand::<XorShiftRandomForGaussian, RandomGaussianState>(
148 &self.state_holder,
149 self.seed,
150 batch.num_rows(),
151 )
152 }
153
154 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
155 vec![]
156 }
157
158 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
159 unimplemented!()
160 }
161
162 fn with_new_children(
163 self: Arc<Self>,
164 _children: Vec<Arc<dyn PhysicalExpr>>,
165 ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
166 Ok(Arc::new(RandnExpr::new(self.seed)))
167 }
168}
169
170pub fn randn(seed: i64) -> Arc<dyn PhysicalExpr> {
171 Arc::new(RandnExpr::new(seed))
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Float64Array, Int64Array};
    use arrow::{array::StringArray, compute::concat, datatypes::*};
    use datafusion::common::cast::as_float64_array;

    /// Maximum absolute difference tolerated when comparing against Spark output.
    const PRECISION_TOLERANCE: f64 = 1e-6;

    /// The first five values Spark produces for `randn(42)`.
    const SPARK_SEED_42_FIRST_5_GAUSSIAN: [f64; 5] = [
        2.384479054241165,
        0.1920934041293524,
        0.7337336533286575,
        -0.5224480195716871,
        2.060084179317831,
    ];

    /// Spark's reference sequence for seed 42 as an Arrow array.
    fn expected_seed_42() -> Float64Array {
        Float64Array::from(SPARK_SEED_42_FIRST_5_GAUSSIAN.to_vec())
    }

    /// Evaluates `expr` on `batch` and materialises the result as one array per row.
    fn evaluate_to_array(
        expr: &Arc<dyn PhysicalExpr>,
        batch: &RecordBatch,
    ) -> datafusion::common::Result<Arc<dyn Array>> {
        expr.evaluate(batch)?.into_array(batch.num_rows())
    }

    #[test]
    fn test_rand_single_batch() -> datafusion::common::Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let column = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column)])?;

        let values = evaluate_to_array(&randn(42), &batch)?;

        assert_eq_with_tolerance(as_float64_array(&values)?, &expected_seed_42());
        Ok(())
    }

    #[test]
    fn test_rand_multi_batch() -> datafusion::common::Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let randn_expr = randn(42);

        // The same expression is evaluated twice, so the second batch must pick up
        // where the first one stopped.
        let head = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(Int64Array::from(vec![Some(24), None, None]))],
        )?;
        let head_values = evaluate_to_array(&randn_expr, &head)?;

        let tail = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(Int64Array::from(vec![None, Some(22)]))],
        )?;
        let tail_values = evaluate_to_array(&randn_expr, &tail)?;

        let combined = concat(&[
            as_float64_array(&head_values)? as &dyn Array,
            as_float64_array(&tail_values)?,
        ])?;

        assert_eq_with_tolerance(as_float64_array(&combined)?, &expected_seed_42());
        Ok(())
    }

    /// Asserts that both arrays have the same length and that each pair of values
    /// differs by less than `PRECISION_TOLERANCE`. Panics on any null slot.
    fn assert_eq_with_tolerance(left: &Float64Array, right: &Float64Array) {
        assert_eq!(left.len(), right.len());
        for (actual, expected) in left.iter().zip(right.iter()) {
            let actual = actual.unwrap();
            let expected = expected.unwrap();
            assert!(
                (actual - expected).abs() < PRECISION_TOLERANCE,
                "difference between {:?} and {:?} is larger than acceptable precision",
                actual,
                expected
            );
        }
    }
}