polars_plan/dsl/function_expr/
random.rs

1use polars_core::prelude::DataType::Float64;
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use strum_macros::IntoStaticStr;
5
6use super::*;
7
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9#[derive(Copy, Clone, PartialEq, Debug, IntoStaticStr)]
10#[strum(serialize_all = "snake_case")]
11pub enum RandomMethod {
12    Shuffle,
13    Sample {
14        is_fraction: bool,
15        with_replacement: bool,
16        shuffle: bool,
17    },
18}
19
20impl Hash for RandomMethod {
21    fn hash<H: Hasher>(&self, state: &mut H) {
22        std::mem::discriminant(self).hash(state)
23    }
24}
25
26pub(super) fn shuffle(s: &Column, seed: Option<u64>) -> PolarsResult<Column> {
27    Ok(s.shuffle(seed))
28}
29
30pub(super) fn sample_frac(
31    s: &[Column],
32    with_replacement: bool,
33    shuffle: bool,
34    seed: Option<u64>,
35) -> PolarsResult<Column> {
36    let src = &s[0];
37    let frac_s = &s[1];
38
39    polars_ensure!(
40        frac_s.len() == 1,
41        ComputeError: "Sample fraction must be a single value."
42    );
43
44    let frac_s = frac_s.cast(&Float64)?;
45    let frac = frac_s.f64()?;
46
47    match frac.get(0) {
48        Some(frac) => src.sample_frac(frac, with_replacement, shuffle, seed),
49        None => Ok(Column::new_empty(src.name().clone(), src.dtype())),
50    }
51}
52
53pub(super) fn sample_n(
54    s: &[Column],
55    with_replacement: bool,
56    shuffle: bool,
57    seed: Option<u64>,
58) -> PolarsResult<Column> {
59    let src = &s[0];
60    let n_s = &s[1];
61
62    polars_ensure!(
63        n_s.len() == 1,
64        ComputeError: "Sample size must be a single value."
65    );
66
67    let n_s = n_s.cast(&IDX_DTYPE)?;
68    let n = n_s.idx()?;
69
70    match n.get(0) {
71        Some(n) => src.sample_n(n as usize, with_replacement, shuffle, seed),
72        None => Ok(Column::new_empty(src.name().clone(), src.dtype())),
73    }
74}