datafusion_comet_spark_expr/nondetermenistic_funcs/
rand.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::hash_funcs::murmur3::spark_compatible_murmur3_hash;
19
20use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
21use arrow::array::RecordBatch;
22use arrow::datatypes::{DataType, Schema};
23use datafusion::common::Result;
24use datafusion::logical_expr::ColumnarValue;
25use datafusion::physical_expr::PhysicalExpr;
26use std::any::Any;
27use std::fmt::{Display, Formatter};
28use std::hash::{Hash, Hasher};
29use std::sync::{Arc, Mutex};
30
/// Normalization multiplier (2^-53) used to map a random integer value onto the f64 interval
/// [0.0, 1.0).
///
/// It is consumed by the Spark-compatible XOR-shift generator defined in this file; see
/// https://github.com/apache/spark/blob/91f3fdd25852b43095dd5273358fc394ffd11b66/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
/// The value corresponds to `DOUBLE_UNIT` in the Java implementation:
/// https://github.com/openjdk/jdk/blob/07c9f7138affdf0d42ecdc30adcb854515569985/src/java.base/share/classes/java/util/Random.java#L302
/// Rust has no hexadecimal float literals (Java writes `0x1.0p-53`), so the constant is spelled
/// as the equivalent expression `1.0 / (1 << 53)`, which equals 1.1102230246251565e-16 exactly.
const DOUBLE_UNIT: f64 = 1.0 / (1u64 << 53) as f64;
37
/// Murmur3 seed used when hashing the user-provided seed, kept identical to Spark.
///
/// Spark takes this value from the `arraySeed` constant of the Scala standard library
/// MurmurHash3 implementation. References:
/// https://github.com/apache/spark/blob/91f3fdd25852b43095dd5273358fc394ffd11b66/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala#L63
/// https://github.com/scala/scala/blob/360d5da544d84b821c40e4662ad08703b51a44e1/src/library/scala/util/hashing/MurmurHash3.scala#L331
const SPARK_MURMUR_ARRAY_SEED: u32 = 0x3c07_4a61;
43
44#[derive(Debug, Clone)]
45pub(crate) struct XorShiftRandom {
46    pub(crate) seed: i64,
47}
48
49impl XorShiftRandom {
50    fn next(&mut self, bits: u8) -> i32 {
51        let mut next_seed = self.seed ^ (self.seed << 21);
52        next_seed ^= ((next_seed as u64) >> 35) as i64;
53        next_seed ^= next_seed << 4;
54        self.seed = next_seed;
55        (next_seed & ((1i64 << bits) - 1)) as i32
56    }
57
58    pub fn next_f64(&mut self) -> f64 {
59        let a = self.next(26) as i64;
60        let b = self.next(27) as i64;
61        ((a << 27) + b) as f64 * DOUBLE_UNIT
62    }
63}
64
65impl StatefulSeedValueGenerator<i64, f64> for XorShiftRandom {
66    fn from_init_seed(init_seed: i64) -> Self {
67        let bytes_repr = init_seed.to_be_bytes();
68        let low_bits = spark_compatible_murmur3_hash(bytes_repr, SPARK_MURMUR_ARRAY_SEED);
69        let high_bits = spark_compatible_murmur3_hash(bytes_repr, low_bits);
70        let init_seed = ((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64);
71        XorShiftRandom { seed: init_seed }
72    }
73
74    fn from_stored_state(stored_state: i64) -> Self {
75        XorShiftRandom { seed: stored_state }
76    }
77
78    fn next_value(&mut self) -> f64 {
79        self.next_f64()
80    }
81
82    fn get_current_state(&self) -> i64 {
83        self.seed
84    }
85}
86
/// Physical expression that yields Spark-compatible `rand(seed)` values.
///
/// Two expressions are considered equal when their seeds are equal; the mutable generator
/// state takes no part in equality or hashing.
#[derive(Debug)]
pub struct RandExpr {
    /// Seed the expression was created with.
    seed: i64,
    /// Generator state kept between `evaluate` calls; `None` right after construction.
    state_holder: Arc<Mutex<Option<i64>>>,
}

impl RandExpr {
    /// Creates an expression for `seed` with no stored generator state yet.
    pub fn new(seed: i64) -> Self {
        Self {
            seed,
            state_holder: Arc::new(Mutex::new(None)),
        }
    }
}

impl Display for RandExpr {
    /// Renders the expression as `RAND(<seed>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "RAND({})", self.seed)
    }
}

impl PartialEq for RandExpr {
    /// Compares the seeds only.
    fn eq(&self, other: &Self) -> bool {
        self.seed == other.seed
    }
}

impl Eq for RandExpr {}

impl Hash for RandExpr {
    /// Hashes the seed, i.e. exactly the data `PartialEq` compares.
    ///
    /// Hashing the (always empty) list of children instead would give every `RandExpr` the
    /// same hash regardless of its seed.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.seed.hash(state);
    }
}
121
122impl PhysicalExpr for RandExpr {
123    fn as_any(&self) -> &dyn Any {
124        self
125    }
126
127    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
128        Ok(DataType::Float64)
129    }
130
131    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
132        Ok(false)
133    }
134
135    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
136        evaluate_batch_for_rand::<XorShiftRandom, i64>(
137            &self.state_holder,
138            self.seed,
139            batch.num_rows(),
140        )
141    }
142
143    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
144        vec![]
145    }
146
147    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
148        unimplemented!()
149    }
150
151    fn with_new_children(
152        self: Arc<Self>,
153        _children: Vec<Arc<dyn PhysicalExpr>>,
154    ) -> Result<Arc<dyn PhysicalExpr>> {
155        Ok(Arc::new(RandExpr::new(self.seed)))
156    }
157}
158
159pub fn rand(seed: i64) -> Arc<dyn PhysicalExpr> {
160    Arc::new(RandExpr::new(seed))
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Float64Array, Int64Array};
    use arrow::{array::StringArray, compute::concat, datatypes::*};
    use datafusion::common::cast::as_float64_array;

    /// Reference values: the first five numbers expected for seed 42.
    const SPARK_SEED_42_FIRST_5: [f64; 5] = [
        0.619189370225301,
        0.5096018842446481,
        0.8325259388871524,
        0.26322809041172357,
        0.6702867696264135,
    ];

    /// A single five-row batch must yield the five reference values.
    #[test]
    fn test_rand_single_batch() -> Result<()> {
        // The column content is irrelevant; only the number of rows matters.
        let column = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
        let batch = RecordBatch::try_new(schema, vec![Arc::new(column)])?;

        let generated = rand(42).evaluate(&batch)?.into_array(batch.num_rows())?;
        let actual = as_float64_array(&generated)?;

        let expected = Float64Array::from(SPARK_SEED_42_FIRST_5.to_vec());
        assert_eq!(actual, &expected);
        Ok(())
    }

    /// Two batches (2 + 3 rows) evaluated by the same expression must continue one sequence
    /// and together yield the same five reference values.
    #[test]
    fn test_rand_multi_batch() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
        let head_column = Int64Array::from(vec![Some(42), None]);
        let tail_column = Int64Array::from(vec![None, Some(-42), None]);
        let head = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(head_column)])?;
        let tail = RecordBatch::try_new(schema, vec![Arc::new(tail_column)])?;

        let rand_expr = rand(42);
        let head_values = rand_expr.evaluate(&head)?.into_array(head.num_rows())?;
        let tail_values = rand_expr.evaluate(&tail)?.into_array(tail.num_rows())?;

        let parts: Vec<&dyn Array> = vec![
            as_float64_array(&head_values)?,
            as_float64_array(&tail_values)?,
        ];
        let combined = concat(&parts)?;
        let actual = as_float64_array(&combined)?;

        let expected = Float64Array::from(SPARK_SEED_42_FIRST_5.to_vec());
        assert_eq!(actual, &expected);
        Ok(())
    }
}