use crate::nondetermenistic_funcs::rand::XorShiftRandom;
use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
use arrow::array::RecordBatch;
use arrow::datatypes::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExpr;
use std::any::Any;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
/// Gaussian sampler layered on top of the uniform [`XorShiftRandom`] generator.
///
/// Samples are produced with the Marsaglia polar method, which yields two normal
/// deviates per accepted pair of uniform draws. The second deviate is cached in
/// `next_gaussian` and returned by the following call.
#[derive(Debug, Clone)]
struct XorShiftRandomForGaussian {
    /// Uniform source; `next_f64` supplies the raw samples (assumed uniform on [0, 1) — TODO confirm).
    base_generator: XorShiftRandom,
    /// Spare deviate left over from the previous draw, if it has not been consumed yet.
    next_gaussian: Option<f64>,
}

impl XorShiftRandomForGaussian {
    /// Returns the next Gaussian sample.
    ///
    /// If a spare value is cached it is returned (and cleared) without touching
    /// the base generator. Otherwise two uniform values are drawn repeatedly until
    /// they form a point strictly inside the unit circle, excluding the origin.
    /// That point is turned into two deviates: one is returned and one is cached.
    pub fn next_gaussian(&mut self) -> f64 {
        // `take` both reads the cached spare and resets the slot to `None`.
        if let Some(spare) = self.next_gaussian.take() {
            return spare;
        }

        // Rejection sampling: keep drawing until 0 < x^2 + y^2 < 1.
        // The x draw happens before the y draw, so the consumption order of the
        // base generator is fixed.
        let (x, y, radius_sq) = loop {
            let x = 2f64 * self.base_generator.next_f64() - 1f64;
            let y = 2f64 * self.base_generator.next_f64() - 1f64;
            let radius_sq = x * x + y * y;
            if radius_sq < 1f64 && radius_sq != 0f64 {
                break (x, y, radius_sq);
            }
        };

        let scale = (-2f64 * radius_sq.ln() / radius_sq).sqrt();
        self.next_gaussian = Some(y * scale);
        x * scale
    }
}
type RandomGaussianState = (i64, Option<f64>);
impl StatefulSeedValueGenerator<RandomGaussianState, f64> for XorShiftRandomForGaussian {
fn from_init_seed(init_value: i64) -> Self {
XorShiftRandomForGaussian {
base_generator: XorShiftRandom::from_init_seed(init_value),
next_gaussian: None,
}
}
fn from_stored_state(stored_state: RandomGaussianState) -> Self {
XorShiftRandomForGaussian {
base_generator: XorShiftRandom::from_stored_state(stored_state.0),
next_gaussian: stored_state.1,
}
}
fn next_value(&mut self) -> f64 {
self.next_gaussian()
}
fn get_current_state(&self) -> RandomGaussianState {
(self.base_generator.seed, self.next_gaussian)
}
}
/// Physical expression that produces one `Float64` Gaussian pseudo-random
/// value per input row, driven by a fixed seed.
#[derive(Debug, Clone)]
pub struct RandnExpr {
    /// Seed used when no generator state has been stored yet.
    seed: i64,
    /// Generator state handed to `evaluate_batch_for_rand` on every `evaluate` call.
    /// It sits behind an `Arc`, so clones of this expression share the same state.
    state_holder: Arc<Mutex<Option<RandomGaussianState>>>,
}

impl RandnExpr {
    /// Builds an expression for `seed` with an empty (not yet initialized) state holder.
    pub fn new(seed: i64) -> Self {
        let state_holder = Arc::new(Mutex::new(None));
        Self { seed, state_holder }
    }
}
impl Display for RandnExpr {
    /// Renders the expression as `RANDN(<seed>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let seed = self.seed;
        write!(f, "RANDN({seed})")
    }
}
/// Two `RandnExpr`s compare equal when they were built with the same seed;
/// the shared generator state is not part of the comparison.
impl PartialEq for RandnExpr {
    fn eq(&self, other: &Self) -> bool {
        self.seed == other.seed
    }
}

impl Eq for RandnExpr {}

impl Hash for RandnExpr {
    /// Hashes the seed, which is exactly the field `eq` compares.
    /// `Hash` therefore agrees with `PartialEq`, and expressions with different
    /// seeds do not all collide. Hashing `children()` (always empty for this leaf)
    /// would give every instance the same hash.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.seed.hash(state);
    }
}
impl PhysicalExpr for RandnExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Always `Float64`, regardless of the input schema.
    fn data_type(&self, _input_schema: &Schema) -> datafusion::common::Result<DataType> {
        Ok(DataType::Float64)
    }

    /// Never nullable: every row receives a generated value.
    fn nullable(&self, _input_schema: &Schema) -> datafusion::common::Result<bool> {
        Ok(false)
    }

    /// Generates one Gaussian value per row of `batch`; only `batch.num_rows()` is
    /// used and no column is read.
    ///
    /// Generator state lives in `state_holder`, which is passed to
    /// `evaluate_batch_for_rand`. Presumably this lets consecutive batches continue
    /// one sequence; the multi-batch test expects that.
    fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
        evaluate_batch_for_rand::<XorShiftRandomForGaussian, RandomGaussianState>(
            &self.state_holder,
            self.seed,
            batch.num_rows(),
        )
    }

    /// Leaf expression: it has no child expressions.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![]
    }

    /// Writes the SQL form `randn(<seed>)`.
    ///
    /// This must not panic: it is called whenever a plan containing this
    /// expression is rendered in SQL form.
    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "randn({})", self.seed)
    }

    /// Returns a new expression with the same seed. The new expression gets a
    /// fresh, empty state holder, so generator progress recorded in `self` is
    /// not carried over. `_children` is ignored because this is a leaf.
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(RandnExpr::new(self.seed)))
    }
}
/// Builds a seeded `randn` expression and returns it as a type-erased
/// [`PhysicalExpr`].
pub fn randn(seed: i64) -> Arc<dyn PhysicalExpr> {
    let expr = RandnExpr::new(seed);
    Arc::new(expr)
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Float64Array, Int64Array};
    use arrow::{array::StringArray, compute::concat, datatypes::*};
    use datafusion::common::cast::as_float64_array;

    /// Maximum absolute difference tolerated between generated and expected values.
    const PRECISION_TOLERANCE: f64 = 1e-6;

    /// Reference sequence for seed 42 (per the name, the first five values Spark
    /// produces for `randn(42)`).
    const SPARK_SEED_42_FIRST_5_GAUSSIAN: [f64; 5] = [
        2.384479054241165,
        0.1920934041293524,
        0.7337336533286575,
        -0.5224480195716871,
        2.060084179317831,
    ];

    /// A single 5-row batch yields the full reference sequence. The batch's column
    /// contents (including nulls) are irrelevant; only the row count matters.
    #[test]
    fn test_randn_single_batch() -> datafusion::common::Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
        let randn_expr = randn(42);
        let result = randn_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        let result = as_float64_array(&result)?;
        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5_GAUSSIAN));
        assert_eq_with_tolerance(result, expected);
        Ok(())
    }

    /// Evaluating the same expression over a 3-row batch and then a 2-row batch
    /// continues one sequence. The concatenated output must equal the single-batch
    /// reference.
    #[test]
    fn test_randn_multi_batch() -> datafusion::common::Result<()> {
        let first_batch_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let first_batch_data = Int64Array::from(vec![Some(24), None, None]);
        let second_batch_schema = first_batch_schema.clone();
        let second_batch_data = Int64Array::from(vec![None, Some(22)]);
        let randn_expr = randn(42);
        let first_batch = RecordBatch::try_new(
            Arc::new(first_batch_schema),
            vec![Arc::new(first_batch_data)],
        )?;
        let first_batch_result = randn_expr
            .evaluate(&first_batch)?
            .into_array(first_batch.num_rows())?;
        let second_batch = RecordBatch::try_new(
            Arc::new(second_batch_schema),
            vec![Arc::new(second_batch_data)],
        )?;
        let second_batch_result = randn_expr
            .evaluate(&second_batch)?
            .into_array(second_batch.num_rows())?;
        let result_arrays: Vec<&dyn Array> = vec![
            as_float64_array(&first_batch_result)?,
            as_float64_array(&second_batch_result)?,
        ];
        let result_arrays = &concat(&result_arrays)?;
        let final_result = as_float64_array(result_arrays)?;
        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5_GAUSSIAN));
        assert_eq_with_tolerance(final_result, expected);
        Ok(())
    }

    /// Asserts that both arrays have the same length, contain no nulls, and agree
    /// element-wise within `PRECISION_TOLERANCE`. A failure reports the offending index.
    fn assert_eq_with_tolerance(left: &Float64Array, right: &Float64Array) {
        assert_eq!(left.len(), right.len(), "arrays differ in length");
        for (idx, (l, r)) in left.iter().zip(right.iter()).enumerate() {
            let l = l.unwrap_or_else(|| panic!("left value at index {idx} is null"));
            let r = r.unwrap_or_else(|| panic!("right value at index {idx} is null"));
            assert!(
                (l - r).abs() < PRECISION_TOLERANCE,
                "difference between {l:?} and {r:?} at index {idx} is larger than acceptable precision",
            );
        }
    }
}