datafusion_comet_spark_expr/nondetermenistic_funcs/rand.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::hash_funcs::murmur3::spark_compatible_murmur3_hash;
use arrow::array::{Float64Array, Float64Builder, RecordBatch};
use arrow::datatypes::{DataType, Schema};
use datafusion::common::Result;
use datafusion::common::ScalarValue;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExpr;
use std::any::Any;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};

/// Normalization multiplier used to map a random i64 value to an f64 in the half-open
/// interval [0.0, 1.0). It corresponds to the Java implementation:
/// https://github.com/openjdk/jdk/blob/07c9f7138affdf0d42ecdc30adcb854515569985/src/java.base/share/classes/java/util/Random.java#L302
/// Rust has no hexadecimal float literals, so scientific notation is used instead; the value
/// is exactly 2^-53 (see the sanity-check test at the bottom of this file).
const DOUBLE_UNIT: f64 = 1.1102230246251565e-16;

/// Seed passed to the MurmurHash3 hash when deriving the generator's initial state. It comes
/// from the Scala standard library's MurmurHash3 implementation and matches what Spark's
/// XORShiftRandom uses. References:
/// https://github.com/apache/spark/blob/91f3fdd25852b43095dd5273358fc394ffd11b66/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala#L63
/// https://github.com/scala/scala/blob/360d5da544d84b821c40e4662ad08703b51a44e1/src/library/scala/util/hashing/MurmurHash3.scala#L331
const SPARK_MURMUR_ARRAY_SEED: u32 = 0x3c074a61;

/// Adaptation of the XOR-shift pseudo-random generator used by Apache Spark.
/// See: https://github.com/apache/spark/blob/91f3fdd25852b43095dd5273358fc394ffd11b66/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
#[derive(Debug, Clone)]
struct XorShiftRandom {
    seed: i64,
}

impl XorShiftRandom {
    /// Creates a generator from a user-provided seed, hashing it the same way Spark does.
    fn from_init_seed(init_seed: i64) -> Self {
        XorShiftRandom {
            seed: Self::init_seed(init_seed),
        }
    }

    /// Restores a generator from a previously stored internal seed without re-hashing it.
    fn from_stored_seed(stored_seed: i64) -> Self {
        XorShiftRandom { seed: stored_seed }
    }

    /// Advances the XOR-shift state and returns the lowest `bits` bits of the new seed.
    fn next(&mut self, bits: u8) -> i32 {
        let mut next_seed = self.seed ^ (self.seed << 21);
        next_seed ^= ((next_seed as u64) >> 35) as i64;
        next_seed ^= next_seed << 4;
        self.seed = next_seed;
        (next_seed & ((1i64 << bits) - 1)) as i32
    }

    /// Combines 26 + 27 = 53 random bits into an f64 in [0.0, 1.0), mirroring
    /// `java.util.Random#nextDouble`.
    pub fn next_f64(&mut self) -> f64 {
        let a = self.next(26) as i64;
        let b = self.next(27) as i64;
        ((a << 27) + b) as f64 * DOUBLE_UNIT
    }

    /// Hashes the user-provided seed into the generator's initial state, mirroring
    /// Spark's `XORShiftRandom.hashSeed`.
    fn init_seed(init: i64) -> i64 {
        let bytes_repr = init.to_be_bytes();
        let low_bits = spark_compatible_murmur3_hash(bytes_repr, SPARK_MURMUR_ARRAY_SEED);
        let high_bits = spark_compatible_murmur3_hash(bytes_repr, low_bits);
        ((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64)
    }
}

/// Spark-compatible `rand()` physical expression. It produces a deterministic sequence of
/// Float64 values in the interval [0.0, 1.0) from a literal seed and keeps the generator
/// state across record batches so that consecutive batches continue the same sequence.
#[derive(Debug)]
pub struct RandExpr {
    /// Expression that must evaluate to a literal (scalar) seed.
    seed: Arc<dyn PhysicalExpr>,
    /// Shift added to the initial seed (e.g. a partition index) before seeding the generator.
    init_seed_shift: i32,
    /// Stored internal generator seed, populated when the first batch is evaluated.
    state_holder: Arc<Mutex<Option<i64>>>,
}

impl RandExpr {
    pub fn new(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Self {
        Self {
            seed,
            init_seed_shift,
            state_holder: Arc::new(Mutex::new(None::<i64>)),
        }
    }

    /// Extracts the i64 seed value from the evaluated seed literal, treating a NULL seed as 0.
    fn extract_init_state(seed: ScalarValue) -> Result<i64> {
        if let ScalarValue::Int64(seed_opt) = seed.cast_to(&DataType::Int64)? {
            Ok(seed_opt.unwrap_or(0))
        } else {
            Err(DataFusionError::Internal(
                "unexpected execution branch".to_string(),
            ))
        }
    }

    /// Produces `num_rows` random values, seeding the generator on the first call and
    /// resuming from the stored seed on subsequent calls (see the multi-batch test below).
    fn evaluate_batch(&self, seed: ScalarValue, num_rows: usize) -> Result<ColumnarValue> {
        let mut seed_state = self.state_holder.lock().unwrap();
        let mut rnd = if seed_state.is_none() {
            // First batch: hash the literal seed (plus the shift) into the initial state.
            let init_seed = RandExpr::extract_init_state(seed)?;
            let init_seed = init_seed.wrapping_add(self.init_seed_shift as i64);
            *seed_state = Some(init_seed);
            XorShiftRandom::from_init_seed(init_seed)
        } else {
            // Subsequent batches: resume from the seed stored after the previous batch.
            let stored_seed = seed_state.unwrap();
            XorShiftRandom::from_stored_seed(stored_seed)
        };

        let mut arr_builder = Float64Builder::with_capacity(num_rows);
        std::iter::repeat_with(|| rnd.next_f64())
            .take(num_rows)
            .for_each(|v| arr_builder.append_value(v));
        let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
        *seed_state = Some(rnd.seed);
        Ok(ColumnarValue::Array(array_ref))
    }
}

impl Display for RandExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "RAND({})", self.seed)
    }
}

impl PartialEq for RandExpr {
    fn eq(&self, other: &Self) -> bool {
        self.seed.eq(&other.seed) && self.init_seed_shift == other.init_seed_shift
    }
}

impl Eq for RandExpr {}

impl Hash for RandExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.children().hash(state);
    }
}

impl PhysicalExpr for RandExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(false)
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        match self.seed.evaluate(batch)? {
            ColumnarValue::Scalar(seed) => self.evaluate_batch(seed, batch.num_rows()),
            ColumnarValue::Array(_arr) => Err(DataFusionError::NotImplemented(format!(
                "Only literal seeds are supported for {self}"
            ))),
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.seed]
    }

    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
        unimplemented!()
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(RandExpr::new(
            Arc::clone(&children[0]),
            self.init_seed_shift,
        )))
    }
}

/// Creates a Spark-compatible `rand()` physical expression from a literal seed expression
/// and an initial seed shift.
pub fn rand(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Result<Arc<dyn PhysicalExpr>> {
    Ok(Arc::new(RandExpr::new(seed, init_seed_shift)))
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, BooleanArray, Int64Array};
    use arrow::{array::StringArray, compute::concat, datatypes::*};
    use datafusion::common::cast::as_float64_array;
    use datafusion::physical_expr::expressions::lit;

    /// First five values produced by Spark's rand() with seed 42, used as the expected output.
    const SPARK_SEED_42_FIRST_5: [f64; 5] = [
        0.619189370225301,
        0.5096018842446481,
        0.8325259388871524,
        0.26322809041172357,
        0.6702867696264135,
    ];

    #[test]
    fn test_rand_single_batch() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
        let rand_expr = rand(lit(42), 0)?;
        let result = rand_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        let result = as_float64_array(&result)?;
        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5));
        assert_eq!(result, expected);
        Ok(())
    }

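    // Added sketch (not in the original test suite): `RandExpr::evaluate` only accepts a
    // literal seed, so a per-row seed column is expected to surface the NotImplemented
    // error. The `col` helper from datafusion::physical_expr::expressions is assumed here.
    #[test]
    fn test_rand_rejects_non_literal_seed() -> Result<()> {
        use datafusion::physical_expr::expressions::col;

        let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let data = Int64Array::from(vec![Some(1), Some(2)]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(data)])?;
        // The seed expression resolves to an array rather than a scalar, so evaluation fails.
        let rand_expr = rand(col("a", &schema)?, 0)?;
        assert!(rand_expr.evaluate(&batch).is_err());
        Ok(())
    }
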
    #[test]
    fn test_rand_multi_batch() -> Result<()> {
        let first_batch_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let first_batch_data = Int64Array::from(vec![Some(42), None]);
        let second_batch_schema = first_batch_schema.clone();
        let second_batch_data = Int64Array::from(vec![None, Some(-42), None]);
        let rand_expr = rand(lit(42), 0)?;
        let first_batch = RecordBatch::try_new(
            Arc::new(first_batch_schema),
            vec![Arc::new(first_batch_data)],
        )?;
        let first_batch_result = rand_expr
            .evaluate(&first_batch)?
            .into_array(first_batch.num_rows())?;
        let second_batch = RecordBatch::try_new(
            Arc::new(second_batch_schema),
            vec![Arc::new(second_batch_data)],
        )?;
        let second_batch_result = rand_expr
            .evaluate(&second_batch)?
            .into_array(second_batch.num_rows())?;
        let result_arrays: Vec<&dyn Array> = vec![
            as_float64_array(&first_batch_result)?,
            as_float64_array(&second_batch_result)?,
        ];
        let result_arrays = &concat(&result_arrays)?;
        let final_result = as_float64_array(result_arrays)?;
        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5));
        assert_eq!(final_result, expected);
        Ok(())
    }

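    // Added sketch (not in the original test suite): restoring the generator from a stored
    // seed continues the same sequence, which is the property `evaluate_batch` relies on to
    // stay consistent across record batches.
    #[test]
    fn test_xorshift_continues_from_stored_seed() {
        let mut reference = XorShiftRandom::from_init_seed(42);
        let expected: Vec<f64> = (0..5).map(|_| reference.next_f64()).collect();

        // Generate two values, persist the internal seed, then resume from it.
        let mut first = XorShiftRandom::from_init_seed(42);
        let mut values: Vec<f64> = (0..2).map(|_| first.next_f64()).collect();
        let mut second = XorShiftRandom::from_stored_seed(first.seed);
        values.extend((0..3).map(|_| second.next_f64()));

        assert_eq!(values, expected);
    }
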
    #[test]
    fn test_overflow_shift_seed() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
        let data = BooleanArray::from(vec![Some(true), Some(false)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
        let max_seed_and_shift_expr = rand(lit(i64::MAX), 1)?;
        let min_seed_no_shift_expr = rand(lit(i64::MIN), 0)?;
        let first_expr_result = max_seed_and_shift_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;
        let first_expr_result = as_float64_array(&first_expr_result)?;
        let second_expr_result = min_seed_no_shift_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;
        let second_expr_result = as_float64_array(&second_expr_result)?;
        assert_eq!(first_expr_result, second_expr_result);
        Ok(())
    }
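
    // Added sanity check (not in the original test suite): DOUBLE_UNIT is the decimal
    // spelling of 2^-53 (the hexadecimal float literal 0x1.0p-53 in the JDK source), so
    // 53 random bits scaled by it always land in the half-open interval [0.0, 1.0).
    #[test]
    fn test_double_unit_is_two_pow_neg_53() {
        assert_eq!(DOUBLE_UNIT, 1.0 / (1u64 << 53) as f64);

        let mut rng = XorShiftRandom::from_init_seed(42);
        for _ in 0..1_000 {
            let v = rng.next_f64();
            assert!((0.0..1.0).contains(&v));
        }
    }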
}