datafusion_comet_spark_expr/nondetermenistic_funcs/
rand.rs1use crate::hash_funcs::murmur3::spark_compatible_murmur3_hash;
19
20use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
21use arrow::array::RecordBatch;
22use arrow::datatypes::{DataType, Schema};
23use datafusion::common::Result;
24use datafusion::logical_expr::ColumnarValue;
25use datafusion::physical_expr::PhysicalExpr;
26use std::any::Any;
27use std::fmt::{Display, Formatter};
28use std::hash::{Hash, Hasher};
29use std::sync::{Arc, Mutex};
30
/// Scale factor mapping a 53-bit integer onto a double in `[0, 1)`: exactly `2^-53`
/// (`1.1102230246251565e-16`), the same unit `java.util.Random#nextDouble` uses.
const DOUBLE_UNIT: f64 = 1.0 / ((1u64 << 53) as f64);

/// Seed for the first MurmurHash3 pass over the user seed; mirrors Scala's
/// `MurmurHash3.arraySeed`.
const SPARK_MURMUR_ARRAY_SEED: u32 = 0x3c07_4a61;

/// Xorshift pseudo-random generator compatible with Spark's `XORShiftRandom`.
///
/// The whole generator state is the single 64-bit `seed` field, which is overwritten on
/// every draw.
#[derive(Debug, Clone)]
pub(crate) struct XorShiftRandom {
    /// Current generator state.
    pub(crate) seed: i64,
}

impl XorShiftRandom {
    /// Applies one 21/35/4 xorshift step, stores the new state and returns its lowest
    /// `bits` bits.
    ///
    /// This file only calls it with 26 and 27; a `bits` of 63 or more would overflow the
    /// mask arithmetic.
    fn next(&mut self, bits: u8) -> i32 {
        let mut state = self.seed;
        state ^= state << 21;
        // Logical (unsigned) right shift, the equivalent of Java's `>>>`.
        state ^= ((state as u64) >> 35) as i64;
        state ^= state << 4;
        self.seed = state;
        let mask = (1i64 << bits) - 1;
        (state & mask) as i32
    }

    /// Returns a double in `[0, 1)` assembled from 26 high bits followed by 27 low bits.
    pub fn next_f64(&mut self) -> f64 {
        let upper = i64::from(self.next(26));
        let lower = i64::from(self.next(27));
        ((upper << 27) + lower) as f64 * DOUBLE_UNIT
    }
}
64
65impl StatefulSeedValueGenerator<i64, f64> for XorShiftRandom {
66 fn from_init_seed(init_seed: i64) -> Self {
67 let bytes_repr = init_seed.to_be_bytes();
68 let low_bits = spark_compatible_murmur3_hash(bytes_repr, SPARK_MURMUR_ARRAY_SEED);
69 let high_bits = spark_compatible_murmur3_hash(bytes_repr, low_bits);
70 let init_seed = ((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64);
71 XorShiftRandom { seed: init_seed }
72 }
73
74 fn from_stored_state(stored_state: i64) -> Self {
75 XorShiftRandom { seed: stored_state }
76 }
77
78 fn next_value(&mut self) -> f64 {
79 self.next_f64()
80 }
81
82 fn get_current_state(&self) -> i64 {
83 self.seed
84 }
85}
86
/// Physical expression implementing Spark's `rand(seed)`: one pseudo-random `Float64`
/// value per input row.
///
/// The generator state lives in `state_holder`, which is shared behind `Arc<Mutex<..>>`,
/// so repeated evaluations of the same instance continue one sequence rather than
/// restarting it.
#[derive(Debug)]
pub struct RandExpr {
    /// Seed the generator is initialised from.
    seed: i64,
    /// Saved generator state carried between evaluations; starts out as `None`.
    state_holder: Arc<Mutex<Option<i64>>>,
}

impl RandExpr {
    /// Creates an expression for `rand(seed)` with no saved generator state.
    pub fn new(seed: i64) -> Self {
        Self {
            seed,
            state_holder: Arc::new(Mutex::new(None::<i64>)),
        }
    }
}

impl Display for RandExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "RAND({})", self.seed)
    }
}

// Equality is defined by the seed alone; the mutable generator state is ignored.
impl PartialEq for RandExpr {
    fn eq(&self, other: &Self) -> bool {
        self.seed == other.seed
    }
}

impl Eq for RandExpr {}

// Hash exactly the field `PartialEq` compares. Equal expressions then hash equally, and
// expressions with different seeds no longer all collide. Hashing `children()` (always
// empty for this leaf) gave every `RandExpr` the same hash.
impl Hash for RandExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.seed.hash(state);
    }
}
121
122impl PhysicalExpr for RandExpr {
123 fn as_any(&self) -> &dyn Any {
124 self
125 }
126
127 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
128 Ok(DataType::Float64)
129 }
130
131 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
132 Ok(false)
133 }
134
135 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
136 evaluate_batch_for_rand::<XorShiftRandom, i64>(
137 &self.state_holder,
138 self.seed,
139 batch.num_rows(),
140 )
141 }
142
143 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
144 vec![]
145 }
146
147 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
148 unimplemented!()
149 }
150
151 fn with_new_children(
152 self: Arc<Self>,
153 _children: Vec<Arc<dyn PhysicalExpr>>,
154 ) -> Result<Arc<dyn PhysicalExpr>> {
155 Ok(Arc::new(RandExpr::new(self.seed)))
156 }
157}
158
159pub fn rand(seed: i64) -> Arc<dyn PhysicalExpr> {
160 Arc::new(RandExpr::new(seed))
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Float64Array, Int64Array, StringArray};
    use arrow::compute::concat;
    use arrow::datatypes::*;
    use datafusion::common::cast::as_float64_array;

    /// First five values Spark's `rand(42)` produces.
    const SPARK_SEED_42_FIRST_5: [f64; 5] = [
        0.619189370225301,
        0.5096018842446481,
        0.8325259388871524,
        0.26322809041172357,
        0.6702867696264135,
    ];

    /// Evaluates `expr` against `batch` and downcasts the output to a `Float64Array`.
    fn evaluate_rand(expr: &Arc<dyn PhysicalExpr>, batch: &RecordBatch) -> Result<Float64Array> {
        let values = expr.evaluate(batch)?.into_array(batch.num_rows())?;
        Ok(as_float64_array(&values)?.clone())
    }

    #[test]
    fn test_rand_single_batch() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
        let column = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(column)])?;

        let actual = evaluate_rand(&rand(42), &batch)?;

        assert_eq!(actual, Float64Array::from(SPARK_SEED_42_FIRST_5.to_vec()));
        Ok(())
    }

    #[test]
    fn test_rand_multi_batch() -> Result<()> {
        let expr = rand(42);
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
        let head = Int64Array::from(vec![Some(42), None]);
        let tail = Int64Array::from(vec![None, Some(-42), None]);
        let first = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(head)])?;
        let second = RecordBatch::try_new(schema, vec![Arc::new(tail)])?;

        // One expression instance is evaluated twice; the sequence must carry over.
        let first_values = evaluate_rand(&expr, &first)?;
        let second_values = evaluate_rand(&expr, &second)?;

        let pieces: [&dyn Array; 2] = [&first_values, &second_values];
        let combined = concat(&pieces)?;
        let expected = Float64Array::from(SPARK_SEED_42_FIRST_5.to_vec());
        assert_eq!(as_float64_array(&combined)?, &expected);
        Ok(())
    }
}