datafusion_comet_spark_expr/nondetermenistic_funcs/internal/rand_utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::{Float64Array, Float64Builder};
19use datafusion::logical_expr::ColumnarValue;
20use std::ops::Deref;
21use std::sync::{Arc, Mutex};
22
23pub fn evaluate_batch_for_rand<R, S>(
24    state_holder: &Arc<Mutex<Option<S>>>,
25    seed: i64,
26    num_rows: usize,
27) -> datafusion::common::Result<ColumnarValue>
28where
29    R: StatefulSeedValueGenerator<S, f64>,
30    S: Copy,
31{
32    let seed_state = state_holder.lock().unwrap();
33    let mut rnd = R::from_state_ref(seed_state, seed);
34    let mut arr_builder = Float64Builder::with_capacity(num_rows);
35    std::iter::repeat_with(|| rnd.next_value())
36        .take(num_rows)
37        .for_each(|v| arr_builder.append_value(v));
38    let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
39    let mut seed_state = state_holder.lock().unwrap();
40    seed_state.replace(rnd.get_current_state());
41    Ok(ColumnarValue::Array(array_ref))
42}
43
/// A pseudo-random value generator whose progress can be captured as a small `Copy` snapshot
/// and later resumed, allowing a sequence to continue across batches.
///
/// `State` is the snapshot type persisted between uses; `Value` is the type of each value
/// produced by [`next_value`](Self::next_value).
pub trait StatefulSeedValueGenerator<State: Copy, Value>: Sized {
    /// Creates a generator starting from the initial seed.
    fn from_init_seed(init_seed: i64) -> Self;

    /// Recreates a generator from a snapshot, intended to be one previously returned by
    /// [`get_current_state`](Self::get_current_state).
    fn from_stored_state(stored_state: State) -> Self;

    /// Returns the next generated value; takes `&mut self` so the generator can advance.
    fn next_value(&mut self) -> Value;

    /// Returns a snapshot of the generator's current state, suitable for passing to
    /// [`from_stored_state`](Self::from_stored_state) later.
    fn get_current_state(&self) -> State;

    /// Builds a generator from an optional stored snapshot.
    ///
    /// If `state` dereferences to `Some(snapshot)`, the generator resumes via
    /// [`from_stored_state`](Self::from_stored_state) and `init_value` is ignored. If it
    /// dereferences to `None`, a fresh generator is created via
    /// [`from_init_seed`](Self::from_init_seed) using `init_value`.
    ///
    /// `state` may be any pointer-like value whose target is `Option<State>`, for example a
    /// `MutexGuard<Option<State>>` or a plain `&Option<State>`.
    fn from_state_ref(state: impl Deref<Target = Option<State>>, init_value: i64) -> Self {
        // `State: Copy`, so binding `stored` copies it out of the borrowed option.
        match *state {
            Some(stored) => Self::from_stored_state(stored),
            None => Self::from_init_seed(init_value),
        }
    }
}