datafusion_comet_spark_expr/nondetermenistic_funcs/
randn.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::nondetermenistic_funcs::rand::XorShiftRandom;
19
20use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
21use arrow::array::RecordBatch;
22use arrow::datatypes::{DataType, Schema};
23use datafusion::logical_expr::ColumnarValue;
24use datafusion::physical_expr::PhysicalExpr;
25use std::any::Any;
26use std::fmt::{Display, Formatter};
27use std::hash::{Hash, Hasher};
28use std::sync::{Arc, Mutex};
29
30/// Stateful extension of the Marsaglia polar method (https://en.wikipedia.org/wiki/Marsaglia_polar_method)
31/// to convert uniform distribution to the standard normal one used by Apache Spark.
/// For correct processing of batches with an odd number of elements, the generated-but-not-yet-consumed value must be kept as part of the state.
33/// Note about Comet <-> Spark equivalence:
34/// Under the hood, the spark algorithm refers to java.util.Random relying on a module StrictMath. The latter uses
35/// native implementations of floating-point operations (ln, exp, sin, cos) and ensures
36/// they are stable across different platforms.
37/// See: https://github.com/openjdk/jdk/blob/07c9f7138affdf0d42ecdc30adcb854515569985/src/java.base/share/classes/java/util/Random.java#L745
38/// Yet, for the Rust standard library this stability is not guaranteed (https://doc.rust-lang.org/std/primitive.f64.html#method.ln)
39/// Moreover, potential usage of external library like rug (https://docs.rs/rug/latest/rug/) doesn't help because still there is no
40/// guarantee it matches the StrictMath jvm implementation.
/// So we can only guarantee equivalence between Rust and Spark (JVM) within some error tolerance.
42
#[derive(Debug, Clone)]
struct XorShiftRandomForGaussian {
    /// Underlying uniform PRNG (Spark-compatible XorShift) supplying the raw `f64`s.
    base_generator: XorShiftRandom,
    /// The second Gaussian produced by one round of the polar method, cached
    /// here so it can be returned by the next call (needed for odd-length batches).
    next_gaussian: Option<f64>,
}
48
49impl XorShiftRandomForGaussian {
50    pub fn next_gaussian(&mut self) -> f64 {
51        if let Some(stored_value) = self.next_gaussian {
52            self.next_gaussian = None;
53            return stored_value;
54        }
55        let mut v1: f64;
56        let mut v2: f64;
57        let mut s: f64;
58        loop {
59            v1 = 2f64 * self.base_generator.next_f64() - 1f64;
60            v2 = 2f64 * self.base_generator.next_f64() - 1f64;
61            s = v1 * v1 + v2 * v2;
62            if s < 1f64 && s != 0f64 {
63                break;
64            }
65        }
66        let multiplier = (-2f64 * s.ln() / s).sqrt();
67        self.next_gaussian = Some(v2 * multiplier);
68        v1 * multiplier
69    }
70}
71
72type RandomGaussianState = (i64, Option<f64>);
73
74impl StatefulSeedValueGenerator<RandomGaussianState, f64> for XorShiftRandomForGaussian {
75    fn from_init_seed(init_value: i64) -> Self {
76        XorShiftRandomForGaussian {
77            base_generator: XorShiftRandom::from_init_seed(init_value),
78            next_gaussian: None,
79        }
80    }
81
82    fn from_stored_state(stored_state: RandomGaussianState) -> Self {
83        XorShiftRandomForGaussian {
84            base_generator: XorShiftRandom::from_stored_state(stored_state.0),
85            next_gaussian: stored_state.1,
86        }
87    }
88
89    fn next_value(&mut self) -> f64 {
90        self.next_gaussian()
91    }
92
93    fn get_current_state(&self) -> RandomGaussianState {
94        (self.base_generator.seed, self.next_gaussian)
95    }
96}
97
/// Physical expression implementing Spark's `randn(seed)`: one standard-normal
/// `f64` per input row, deterministic for a given seed and row order.
#[derive(Debug, Clone)]
pub struct RandnExpr {
    /// Initial seed used to build the generator for the first batch.
    seed: i64,
    /// Generator state carried across batches so consecutive batches continue
    /// the same deterministic sequence; shared via `Arc<Mutex<..>>` because
    /// `evaluate` takes `&self`.
    state_holder: Arc<Mutex<Option<RandomGaussianState>>>,
}
103
104impl RandnExpr {
105    pub fn new(seed: i64) -> Self {
106        Self {
107            seed,
108            state_holder: Arc::new(Mutex::new(None)),
109        }
110    }
111}
112
113impl Display for RandnExpr {
114    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
115        write!(f, "RANDN({})", self.seed)
116    }
117}
118
119impl PartialEq for RandnExpr {
120    fn eq(&self, other: &Self) -> bool {
121        self.seed.eq(&other.seed)
122    }
123}
124
125impl Eq for RandnExpr {}
126
127impl Hash for RandnExpr {
128    fn hash<H: Hasher>(&self, state: &mut H) {
129        self.children().hash(state);
130    }
131}
132
impl PhysicalExpr for RandnExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// `randn` always yields `Float64`, regardless of the input schema.
    fn data_type(&self, _input_schema: &Schema) -> datafusion::common::Result<DataType> {
        Ok(DataType::Float64)
    }

    /// Generated values are never null.
    fn nullable(&self, _input_schema: &Schema) -> datafusion::common::Result<bool> {
        Ok(false)
    }

    /// Produces one Gaussian value per row of `batch`. The shared state holder
    /// lets consecutive batches continue the same deterministic sequence
    /// (including a cached half-pair from an odd-length batch).
    fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
        evaluate_batch_for_rand::<XorShiftRandomForGaussian, RandomGaussianState>(
            &self.state_holder,
            self.seed,
            batch.num_rows(),
        )
    }

    // Leaf expression: no child expressions.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![]
    }

    // NOTE(review): panics if ever called — SQL rendering is not implemented
    // for this expression. Confirm no code path formats it as SQL.
    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
        unimplemented!()
    }

    /// Leaf node: ignores `_children` and returns a fresh expression with the
    /// same seed. Note this also resets the generator state holder.
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(RandnExpr::new(self.seed)))
    }
}
169
170pub fn randn(seed: i64) -> Arc<dyn PhysicalExpr> {
171    Arc::new(RandnExpr::new(seed))
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Float64Array, Int64Array};
    use arrow::{array::StringArray, compute::concat, datatypes::*};
    use datafusion::common::cast::as_float64_array;

    /// Maximum absolute difference tolerated versus Spark's JVM output; exact
    /// equality cannot be guaranteed because Rust's `ln`/`sqrt` are not
    /// bit-identical to the JVM's `StrictMath` (see module docs).
    const PRECISION_TOLERANCE: f64 = 1e-6;

    /// First five values produced by Spark's `randn` with seed 42 on the JVM.
    const SPARK_SEED_42_FIRST_5_GAUSSIAN: [f64; 5] = [
        2.384479054241165,
        0.1920934041293524,
        0.7337336533286575,
        -0.5224480195716871,
        2.060084179317831,
    ];

    /// A single 5-row batch must reproduce Spark's first five Gaussian values
    /// (input column contents and nulls are irrelevant — only row count matters).
    // Renamed from `test_rand_single_batch`: this file tests RANDN, not RAND.
    #[test]
    fn test_randn_single_batch() -> datafusion::common::Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
        let randn_expr = randn(42);
        let result = randn_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        let result = as_float64_array(&result)?;
        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5_GAUSSIAN));
        assert_eq_with_tolerance(result, expected);
        Ok(())
    }

    /// State must carry across batches: a 3-row batch followed by a 2-row batch
    /// must yield the same sequence as one 5-row batch (this exercises the
    /// cached half-pair left over after an odd-length batch).
    // Renamed from `test_rand_multi_batch`: this file tests RANDN, not RAND.
    #[test]
    fn test_randn_multi_batch() -> datafusion::common::Result<()> {
        let first_batch_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let first_batch_data = Int64Array::from(vec![Some(24), None, None]);
        let second_batch_schema = first_batch_schema.clone();
        let second_batch_data = Int64Array::from(vec![None, Some(22)]);
        let randn_expr = randn(42);
        let first_batch = RecordBatch::try_new(
            Arc::new(first_batch_schema),
            vec![Arc::new(first_batch_data)],
        )?;
        let first_batch_result = randn_expr
            .evaluate(&first_batch)?
            .into_array(first_batch.num_rows())?;
        let second_batch = RecordBatch::try_new(
            Arc::new(second_batch_schema),
            vec![Arc::new(second_batch_data)],
        )?;
        let second_batch_result = randn_expr
            .evaluate(&second_batch)?
            .into_array(second_batch.num_rows())?;
        let result_arrays: Vec<&dyn Array> = vec![
            as_float64_array(&first_batch_result)?,
            as_float64_array(&second_batch_result)?,
        ];
        let result_arrays = &concat(&result_arrays)?;
        let final_result = as_float64_array(result_arrays)?;
        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5_GAUSSIAN));
        assert_eq_with_tolerance(final_result, expected);
        Ok(())
    }

    /// Element-wise comparison of two equal-length f64 arrays within
    /// `PRECISION_TOLERANCE`; panics (with both values) on the first mismatch.
    fn assert_eq_with_tolerance(left: &Float64Array, right: &Float64Array) {
        assert_eq!(left.len(), right.len());
        left.iter().zip(right.iter()).for_each(|(l, r)| {
            assert!(
                (l.unwrap() - r.unwrap()).abs() < PRECISION_TOLERANCE,
                "difference between {:?} and {:?} is larger than acceptable precision",
                l.unwrap(),
                r.unwrap()
            )
        })
    }
}
247}