use crate::hash_funcs::murmur3::spark_compatible_murmur3_hash;
use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
use arrow::array::RecordBatch;
use arrow::datatypes::{DataType, Schema};
use datafusion::common::Result;
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExpr;
use std::any::Any;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
/// 2^-53: multiplying a non-negative integer below 2^53 by this yields a value in `[0, 1)`.
const DOUBLE_UNIT: f64 = 1.1102230246251565e-16;
/// Seed for the first murmur3 pass when scrambling a user seed in `from_init_seed`.
const SPARK_MURMUR_ARRAY_SEED: u32 = 0x3c074a61;

/// Xorshift pseudo-random generator whose whole state is a single `i64`.
#[derive(Debug, Clone)]
pub(crate) struct XorShiftRandom {
    /// Current generator state; replaced on every call to `next`.
    pub(crate) seed: i64,
}

impl XorShiftRandom {
    /// Advances the state with one xorshift step and returns its low `bits` bits.
    ///
    /// The step is: left shift by 21, logical right shift by 35, left shift by 4,
    /// each XOR-ed back into the state. The mask keeps the result non-negative.
    /// `bits` is only ever 26 or 27 here, so the value always fits in an `i32`.
    fn next(&mut self, bits: u8) -> i32 {
        let mut state = self.seed;
        state ^= state << 21;
        // Logical (unsigned) right shift, matching the JVM's `>>>`.
        state ^= ((state as u64) >> 35) as i64;
        state ^= state << 4;
        self.seed = state;
        let low_mask = (1i64 << bits) - 1;
        (state & low_mask) as i32
    }

    /// Returns the next double in `[0, 1)`.
    ///
    /// It combines 26 high bits and 27 low bits into a 53-bit integer. This is the
    /// same construction as `java.util.Random#nextDouble`.
    pub fn next_f64(&mut self) -> f64 {
        // The draw order matters: the 26-bit value is taken first.
        let upper = i64::from(self.next(26));
        let lower = i64::from(self.next(27));
        let mantissa = (upper << 27) + lower;
        mantissa as f64 * DOUBLE_UNIT
    }
}
impl StatefulSeedValueGenerator<i64, f64> for XorShiftRandom {
fn from_init_seed(init_seed: i64) -> Self {
let bytes_repr = init_seed.to_be_bytes();
let low_bits = spark_compatible_murmur3_hash(bytes_repr, SPARK_MURMUR_ARRAY_SEED);
let high_bits = spark_compatible_murmur3_hash(bytes_repr, low_bits);
let init_seed = ((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64);
XorShiftRandom { seed: init_seed }
}
fn from_stored_state(stored_state: i64) -> Self {
XorShiftRandom { seed: stored_state }
}
fn next_value(&mut self) -> f64 {
self.next_f64()
}
fn get_current_state(&self) -> i64 {
self.seed
}
}
/// Physical expression for a seeded, Spark-compatible `rand(seed)`.
///
/// Equality and hashing consider only `seed`. The mutable generator state in
/// `state_holder` is deliberately ignored, so two expressions built from the same
/// seed compare and hash equal.
#[derive(Debug)]
pub struct RandExpr {
    /// User-supplied seed the random sequence is derived from.
    seed: i64,
    /// Shared generator state passed to `evaluate_batch_for_rand`. It starts as `None`;
    /// presumably it is filled in after the first evaluated batch (confirm in `internal`).
    state_holder: Arc<Mutex<Option<i64>>>,
}

impl RandExpr {
    /// Creates a `rand` expression for `seed` with empty (`None`) generator state.
    pub fn new(seed: i64) -> Self {
        Self {
            seed,
            state_holder: Arc::new(Mutex::new(None::<i64>)),
        }
    }
}

impl Display for RandExpr {
    /// Renders as `RAND(<seed>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "RAND({})", self.seed)
    }
}

impl PartialEq for RandExpr {
    /// Two expressions are equal when their seeds are equal; generator state is ignored.
    fn eq(&self, other: &Self) -> bool {
        self.seed == other.seed
    }
}

impl Eq for RandExpr {}

impl Hash for RandExpr {
    /// Hashes the seed, the same field `PartialEq` compares.
    ///
    /// The previous version hashed `children()`, which is always empty. That made every
    /// `RandExpr` collide regardless of seed. Hashing `seed` keeps the required
    /// `a == b => hash(a) == hash(b)` property and gives distinct seeds distinct hashes.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.seed.hash(state);
    }
}
impl PhysicalExpr for RandExpr {
fn as_any(&self) -> &dyn Any {
self
}
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(DataType::Float64)
}
fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
Ok(false)
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
evaluate_batch_for_rand::<XorShiftRandom, i64>(
&self.state_holder,
self.seed,
batch.num_rows(),
)
}
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![]
}
fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
unimplemented!()
}
fn with_new_children(
self: Arc<Self>,
_children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(RandExpr::new(self.seed)))
}
}
/// Convenience constructor: a type-erased `rand(seed)` physical expression with fresh state.
pub fn rand(seed: i64) -> Arc<dyn PhysicalExpr> {
    let expr = RandExpr::new(seed);
    Arc::new(expr)
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Float64Array, Int64Array};
    use arrow::{array::StringArray, compute::concat, datatypes::*};
    use datafusion::common::cast::as_float64_array;

    /// Reference output: the first five values of Spark's `rand(42)`.
    const SPARK_SEED_42_FIRST_5: [f64; 5] = [
        0.619189370225301,
        0.5096018842446481,
        0.8325259388871524,
        0.26322809041172357,
        0.6702867696264135,
    ];

    /// Evaluates `expr` over `batch` and returns an owned copy of the Float64 result.
    fn evaluate_as_f64(
        expr: &Arc<dyn PhysicalExpr>,
        batch: &RecordBatch,
    ) -> Result<Float64Array> {
        let values = expr.evaluate(batch)?.into_array(batch.num_rows())?;
        Ok(as_float64_array(&values)?.clone())
    }

    /// The Spark reference values as an Arrow array.
    fn expected_seed_42() -> Float64Array {
        Float64Array::from(SPARK_SEED_42_FIRST_5.to_vec())
    }

    #[test]
    fn test_rand_single_batch() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;

        let actual = evaluate_as_f64(&rand(42), &batch)?;
        assert_eq!(actual, expected_seed_42());
        Ok(())
    }

    #[test]
    fn test_rand_multi_batch() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
        let first_batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int64Array::from(vec![Some(42), None]))],
        )?;
        let second_batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from(vec![None, Some(-42), None]))],
        )?;

        // One expression instance evaluated twice: the second batch must continue the
        // sequence, so 2 + 3 rows together reproduce the first five reference values.
        let rand_expr = rand(42);
        let first = evaluate_as_f64(&rand_expr, &first_batch)?;
        let second = evaluate_as_f64(&rand_expr, &second_batch)?;

        let pieces: [&dyn Array; 2] = [&first, &second];
        let combined = concat(&pieces)?;
        assert_eq!(as_float64_array(&combined)?, &expected_seed_42());
        Ok(())
    }
}