datafusion_comet_spark_expr/nondetermenistic_funcs/internal/
rand_utils.rs1use arrow::array::{Float64Array, Float64Builder};
19use datafusion::logical_expr::ColumnarValue;
20use std::ops::Deref;
21use std::sync::{Arc, Mutex};
22
23pub fn evaluate_batch_for_rand<R, S>(
24 state_holder: &Arc<Mutex<Option<S>>>,
25 seed: i64,
26 num_rows: usize,
27) -> datafusion::common::Result<ColumnarValue>
28where
29 R: StatefulSeedValueGenerator<S, f64>,
30 S: Copy,
31{
32 let seed_state = state_holder.lock().unwrap();
33 let mut rnd = R::from_state_ref(seed_state, seed);
34 let mut arr_builder = Float64Builder::with_capacity(num_rows);
35 std::iter::repeat_with(|| rnd.next_value())
36 .take(num_rows)
37 .for_each(|v| arr_builder.append_value(v));
38 let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
39 let mut seed_state = state_holder.lock().unwrap();
40 seed_state.replace(rnd.get_current_state());
41 Ok(ColumnarValue::Array(array_ref))
42}
43
/// A value generator whose internal state can be exported and later restored.
///
/// `State` is the copyable snapshot of the generator, and `Value` is the type of the
/// items it produces.
pub trait StatefulSeedValueGenerator<State: Copy, Value>: Sized {
    /// Creates a generator from an initial seed.
    fn from_init_seed(init_seed: i64) -> Self;

    /// Creates a generator from a previously exported state.
    fn from_stored_state(stored_state: State) -> Self;

    /// Produces the next value, advancing the generator.
    fn next_value(&mut self) -> Value;

    /// Returns a snapshot of the generator's current state.
    fn get_current_state(&self) -> State;

    /// Builds a generator from an optional stored state.
    ///
    /// When `state` dereferences to `Some`, the generator is restored from that state via
    /// [`Self::from_stored_state`] and `init_value` is ignored. When it is `None`, the
    /// generator is created from `init_value` via [`Self::from_init_seed`].
    ///
    /// `state` may be anything that dereferences to `Option<State>`, such as a plain
    /// reference or a mutex guard.
    fn from_state_ref(state: impl Deref<Target = Option<State>>, init_value: i64) -> Self {
        // `State: Copy`, so the stored value can be copied out from behind the deref.
        match *state {
            Some(stored_state) => Self::from_stored_state(stored_state),
            None => Self::from_init_seed(init_value),
        }
    }
}