polars_plan/dsl/function_expr/
random.rs1use polars_core::prelude::DataType::Float64;
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use strum_macros::IntoStaticStr;
5
6use super::*;
7
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9#[derive(Copy, Clone, PartialEq, Debug, IntoStaticStr)]
10#[strum(serialize_all = "snake_case")]
11pub enum RandomMethod {
12 Shuffle,
13 Sample {
14 is_fraction: bool,
15 with_replacement: bool,
16 shuffle: bool,
17 },
18}
19
20impl Hash for RandomMethod {
21 fn hash<H: Hasher>(&self, state: &mut H) {
22 std::mem::discriminant(self).hash(state)
23 }
24}
25
26pub(super) fn shuffle(s: &Column, seed: Option<u64>) -> PolarsResult<Column> {
27 Ok(s.shuffle(seed))
28}
29
30pub(super) fn sample_frac(
31 s: &[Column],
32 with_replacement: bool,
33 shuffle: bool,
34 seed: Option<u64>,
35) -> PolarsResult<Column> {
36 let src = &s[0];
37 let frac_s = &s[1];
38
39 polars_ensure!(
40 frac_s.len() == 1,
41 ComputeError: "Sample fraction must be a single value."
42 );
43
44 let frac_s = frac_s.cast(&Float64)?;
45 let frac = frac_s.f64()?;
46
47 match frac.get(0) {
48 Some(frac) => src.sample_frac(frac, with_replacement, shuffle, seed),
49 None => Ok(Column::new_empty(src.name().clone(), src.dtype())),
50 }
51}
52
53pub(super) fn sample_n(
54 s: &[Column],
55 with_replacement: bool,
56 shuffle: bool,
57 seed: Option<u64>,
58) -> PolarsResult<Column> {
59 let src = &s[0];
60 let n_s = &s[1];
61
62 polars_ensure!(
63 n_s.len() == 1,
64 ComputeError: "Sample size must be a single value."
65 );
66
67 let n_s = n_s.cast(&IDX_DTYPE)?;
68 let n = n_s.idx()?;
69
70 match n.get(0) {
71 Some(n) => src.sample_n(n as usize, with_replacement, shuffle, seed),
72 None => Ok(Column::new_empty(src.name().clone(), src.dtype())),
73 }
74}