// datafusion_comet_spark_expr/nondetermenistic_funcs/rand.rs

use crate::hash_funcs::murmur3::spark_compatible_murmur3_hash;
use arrow::array::{Float64Array, Float64Builder, RecordBatch};
use arrow::datatypes::{DataType, Schema};
use datafusion::common::Result;
use datafusion::common::ScalarValue;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExpr;
use std::any::Any;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};

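// 2^-53: scales a 53-bit random integer into a double in [0.0, 1.0), the same
// constant java.util.Random uses in nextDouble().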
const DOUBLE_UNIT: f64 = 1.1102230246251565e-16;

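// Seed for the murmur3 hash of the user-provided seed bytes; 0x3c074a61 is
// scala.util.hashing.MurmurHash3.arraySeed, which Spark's XORShiftRandom
// uses when hashing the initial seed.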
const SPARK_MURMUR_ARRAY_SEED: u32 = 0x3c074a61;

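/// Spark-compatible xorshift generator (a port of Spark's `XORShiftRandom`):
/// unlike `java.util.Random`, `next(bits)` returns the low `bits` bits of the
/// updated seed.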
#[derive(Debug, Clone)]
struct XorShiftRandom {
    seed: i64,
}

impl XorShiftRandom {
    fn from_init_seed(init_seed: i64) -> Self {
        XorShiftRandom {
            seed: Self::init_seed(init_seed),
        }
    }

    fn from_stored_seed(stored_seed: i64) -> Self {
        XorShiftRandom { seed: stored_seed }
    }

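    // One xorshift step (shifts 21, 35, 4); the `u64` cast reproduces Java's
    // logical right shift (`>>>`), and the mask keeps the low `bits` bits.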
    fn next(&mut self, bits: u8) -> i32 {
        let mut next_seed = self.seed ^ (self.seed << 21);
        next_seed ^= ((next_seed as u64) >> 35) as i64;
        next_seed ^= next_seed << 4;
        self.seed = next_seed;
        (next_seed & ((1i64 << bits) - 1)) as i32
    }

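    // Same construction as java.util.Random::nextDouble: 26 high bits and
    // 27 low bits form a 53-bit integer that is scaled by 2^-53.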
    pub fn next_f64(&mut self) -> f64 {
        let a = self.next(26) as i64;
        let b = self.next(27) as i64;
        ((a << 27) + b) as f64 * DOUBLE_UNIT
    }

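    // Hash the big-endian seed bytes twice with murmur3 and pack the two
    // 32-bit results into a 64-bit initial seed (mirroring Spark's
    // XORShiftRandom.hashSeed).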
    fn init_seed(init: i64) -> i64 {
        let bytes_repr = init.to_be_bytes();
        let low_bits = spark_compatible_murmur3_hash(bytes_repr, SPARK_MURMUR_ARRAY_SEED);
        let high_bits = spark_compatible_murmur3_hash(bytes_repr, low_bits);
        ((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64)
    }
}

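/// Spark-compatible `rand(seed)` physical expression. The xorshift state is
/// kept in `state_holder` so that consecutive batches evaluated by the same
/// expression instance continue one random sequence instead of restarting it.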
#[derive(Debug)]
pub struct RandExpr {
    seed: Arc<dyn PhysicalExpr>,
    init_seed_shift: i32,
    state_holder: Arc<Mutex<Option<i64>>>,
}

impl RandExpr {
    pub fn new(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Self {
        Self {
            seed,
            init_seed_shift,
            state_holder: Arc::new(Mutex::new(None::<i64>)),
        }
    }

    fn extract_init_state(seed: ScalarValue) -> Result<i64> {
        if let ScalarValue::Int64(seed_opt) = seed.cast_to(&DataType::Int64)? {
            Ok(seed_opt.unwrap_or(0))
        } else {
            Err(DataFusionError::Internal(
                "unexpected execution branch".to_string(),
            ))
        }
    }

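    // First call: derive the initial seed from the literal plus `init_seed_shift`
    // (with wrapping add). Later calls: resume from the stored xorshift seed,
    // which is written back after every batch.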
    fn evaluate_batch(&self, seed: ScalarValue, num_rows: usize) -> Result<ColumnarValue> {
        let mut seed_state = self.state_holder.lock().unwrap();
        let mut rnd = if seed_state.is_none() {
            let init_seed = RandExpr::extract_init_state(seed)?;
            let init_seed = init_seed.wrapping_add(self.init_seed_shift as i64);
            *seed_state = Some(init_seed);
            XorShiftRandom::from_init_seed(init_seed)
        } else {
            let stored_seed = seed_state.unwrap();
            XorShiftRandom::from_stored_seed(stored_seed)
        };

        let mut arr_builder = Float64Builder::with_capacity(num_rows);
        std::iter::repeat_with(|| rnd.next_f64())
            .take(num_rows)
            .for_each(|v| arr_builder.append_value(v));
        let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
        *seed_state = Some(rnd.seed);
        Ok(ColumnarValue::Array(array_ref))
    }
}

impl Display for RandExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "RAND({})", self.seed)
    }
}

impl PartialEq for RandExpr {
    fn eq(&self, other: &Self) -> bool {
        self.seed.eq(&other.seed) && self.init_seed_shift == other.init_seed_shift
    }
}

impl Eq for RandExpr {}

impl Hash for RandExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.children().hash(state);
    }
}

impl PhysicalExpr for RandExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(false)
    }

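    // Only a literal (scalar) seed is supported; a seed that evaluates to an
    // array is rejected with `NotImplemented`.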
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        match self.seed.evaluate(batch)? {
            ColumnarValue::Scalar(seed) => self.evaluate_batch(seed, batch.num_rows()),
            ColumnarValue::Array(_arr) => Err(DataFusionError::NotImplemented(format!(
                "Only literal seeds are supported for {self}"
            ))),
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.seed]
    }

    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
        unimplemented!()
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(RandExpr::new(
            Arc::clone(&children[0]),
            self.init_seed_shift,
        )))
    }
}

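/// Builds a Spark-compatible `rand` expression from a literal seed and an
/// `init_seed_shift` added to it (in Spark this shift corresponds to the
/// partition index).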
pub fn rand(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Result<Arc<dyn PhysicalExpr>> {
    Ok(Arc::new(RandExpr::new(seed, init_seed_shift)))
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, BooleanArray, Int64Array};
    use arrow::{array::StringArray, compute::concat, datatypes::*};
    use datafusion::common::cast::as_float64_array;
    use datafusion::physical_expr::expressions::lit;

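    // Expected output for seed 42 with no shift; per the constant's name these
    // are the first five values Spark's rand(42) produces.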
    const SPARK_SEED_42_FIRST_5: [f64; 5] = [
        0.619189370225301,
        0.5096018842446481,
        0.8325259388871524,
        0.26322809041172357,
        0.6702867696264135,
    ];

    #[test]
    fn test_rand_single_batch() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
        let rand_expr = rand(lit(42), 0)?;
        let result = rand_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        let result = as_float64_array(&result)?;
        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5));
        assert_eq!(result, expected);
        Ok(())
    }

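    // Splitting the same five rows across two batches must produce the same
    // sequence as a single batch, because the generator state carries over.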
    #[test]
    fn test_rand_multi_batch() -> Result<()> {
        let first_batch_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let first_batch_data = Int64Array::from(vec![Some(42), None]);
        let second_batch_schema = first_batch_schema.clone();
        let second_batch_data = Int64Array::from(vec![None, Some(-42), None]);
        let rand_expr = rand(lit(42), 0)?;
        let first_batch = RecordBatch::try_new(
            Arc::new(first_batch_schema),
            vec![Arc::new(first_batch_data)],
        )?;
        let first_batch_result = rand_expr
            .evaluate(&first_batch)?
            .into_array(first_batch.num_rows())?;
        let second_batch = RecordBatch::try_new(
            Arc::new(second_batch_schema),
            vec![Arc::new(second_batch_data)],
        )?;
        let second_batch_result = rand_expr
            .evaluate(&second_batch)?
            .into_array(second_batch.num_rows())?;
        let result_arrays: Vec<&dyn Array> = vec![
            as_float64_array(&first_batch_result)?,
            as_float64_array(&second_batch_result)?,
        ];
        let result_arrays = &concat(&result_arrays)?;
        let final_result = as_float64_array(result_arrays)?;
        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5));
        assert_eq!(final_result, expected);
        Ok(())
    }

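    // i64::MAX shifted by 1 wraps around to i64::MIN, so both expressions start
    // from the same effective seed and must produce identical results.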
    #[test]
    fn test_overflow_shift_seed() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
        let data = BooleanArray::from(vec![Some(true), Some(false)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
        let max_seed_and_shift_expr = rand(lit(i64::MAX), 1)?;
        let min_seed_no_shift_expr = rand(lit(i64::MIN), 0)?;
        let first_expr_result = max_seed_and_shift_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;
        let first_expr_result = as_float64_array(&first_expr_result)?;
        let second_expr_result = min_seed_no_shift_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;
        let second_expr_result = as_float64_array(&second_expr_result)?;
        assert_eq!(first_expr_result, second_expr_result);
        Ok(())
    }
}