// polars_plan/plans/functions/dsl.rs

use polars_compute::rolling::QuantileMethod;
use strum_macros::IntoStaticStr;

use super::*;

/// An opaque Python UDF carried through the DSL.
///
/// Only available with the `python` feature. "Opaque" because the planner
/// cannot inspect the Python callable; the flags below tell the optimizer
/// which transformations are still safe to apply around it.
#[cfg(feature = "python")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[derive(Clone)]
pub struct OpaquePythonUdf {
    /// The wrapped Python callable.
    pub function: PythonFunction,
    /// Output schema, if known up front; `None` presumably means it must be
    /// resolved later — TODO confirm against the IR conversion.
    pub schema: Option<SchemaRef>,
    /// Allow predicate pushdown optimizations.
    pub predicate_pd: bool,
    /// Allow projection pushdown optimizations.
    pub projection_pd: bool,
    /// Whether this UDF may be executed in a streaming fashion.
    pub streamable: bool,
    /// Whether the output of the UDF should be checked against `schema`.
    pub validate_output: bool,
}
20
// Except for Opaque functions, this only has the DSL name of the function.
#[derive(Clone, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
pub enum DslFunction {
    /// Add a row-index column starting at `offset` (or 0 when `None`).
    RowIndex {
        name: PlSmallStr,
        offset: Option<IdxSize>,
    },
    // This is both in DSL and IR because we want to be able to serialize it.
    #[cfg(feature = "python")]
    OpaquePython(OpaquePythonUdf),
    /// Explode the selected columns.
    Explode {
        columns: Vec<Selector>,
        /// Whether an empty selection is accepted rather than an error.
        allow_empty: bool,
    },
    #[cfg(feature = "pivot")]
    Unpivot {
        args: UnpivotArgsDSL,
    },
    /// Rename `existing` columns to `new` (pairwise).
    Rename {
        existing: Arc<[PlSmallStr]>,
        new: Arc<[PlSmallStr]>,
        /// If `true`, error when an `existing` name is missing from the schema.
        strict: bool,
    },
    /// Unnest (flatten) the selected struct columns.
    Unnest(Vec<Selector>),
    /// Frame-wide statistics (see [`StatsFunction`]).
    Stats(StatsFunction),
    /// FillValue
    FillNan(Expr),
    /// Drop columns (see [`DropFunction`]).
    Drop(DropFunction),
    // Function that is already converted to IR.
    // Not (de)serializable; skipped so the DSL stays serializable.
    #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
    FunctionIR(FunctionIR),
}
56
/// Arguments for the [`DslFunction::Drop`] operation.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct DropFunction {
    /// Columns that are going to be dropped
    pub(crate) to_drop: Vec<Selector>,
    /// If `true`, performs a check for each item in `to_drop` against the schema. Returns an
    /// `ColumnNotFound` error if the column does not exist in the schema.
    pub(crate) strict: bool,
}
67
/// Frame-wide statistic selected by [`DslFunction::Stats`].
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum StatsFunction {
    /// Variance with `ddof` delta degrees of freedom.
    Var {
        ddof: u8,
    },
    /// Standard deviation with `ddof` delta degrees of freedom.
    Std {
        ddof: u8,
    },
    /// Quantile; `quantile` is an expression so it can be resolved lazily.
    Quantile {
        quantile: Expr,
        method: QuantileMethod,
    },
    Median,
    Mean,
    Sum,
    Min,
    Max,
}
88
89pub(crate) fn validate_columns_in_input<S: AsRef<str>, I: IntoIterator<Item = S>>(
90    columns: I,
91    input_schema: &Schema,
92    operation_name: &str,
93) -> PolarsResult<()> {
94    let columns = columns.into_iter();
95    for c in columns {
96        polars_ensure!(input_schema.contains(c.as_ref()), ColumnNotFound: "'{}' on column: '{}' is invalid\n\nSchema at this point: {:?}", operation_name, c.as_ref(), input_schema)
97    }
98    Ok(())
99}
100
impl DslFunction {
    /// Lower this DSL function into its IR counterpart, resolving column
    /// selectors against `input_schema` and validating that the referenced
    /// columns exist.
    ///
    /// # Panics
    /// Panics on variants (`Stats`, `FillNan`, `Drop`, `Rename`, `Explode`)
    /// that are expected to be rewritten earlier in DSL→IR conversion and
    /// must never reach this function.
    pub(crate) fn into_function_ir(self, input_schema: &Schema) -> PolarsResult<FunctionIR> {
        let function = match self {
            #[cfg(feature = "pivot")]
            DslFunction::Unpivot { args } => {
                // Expand `on`/`index` selectors to concrete column names and
                // check that they exist in the input schema.
                let on = expand_selectors(args.on, input_schema, &[])?;
                let index = expand_selectors(args.index, input_schema, &[])?;
                validate_columns_in_input(on.as_ref(), input_schema, "unpivot")?;
                validate_columns_in_input(index.as_ref(), input_schema, "unpivot")?;

                let args = UnpivotArgsIR {
                    on: on.iter().cloned().collect(),
                    index: index.iter().cloned().collect(),
                    variable_name: args.variable_name.clone(),
                    value_name: args.value_name.clone(),
                };

                FunctionIR::Unpivot {
                    args: Arc::new(args),
                    // Output schema starts empty; filled in later.
                    schema: Default::default(),
                }
            },
            // Already IR — pass through unchanged.
            DslFunction::FunctionIR(func) => func,
            DslFunction::RowIndex { name, offset } => FunctionIR::RowIndex {
                name,
                offset,
                schema: Default::default(),
            },
            DslFunction::Unnest(selectors) => {
                let columns = expand_selectors(selectors, input_schema, &[])?;
                validate_columns_in_input(columns.as_ref(), input_schema, "unnest")?;
                FunctionIR::Unnest { columns }
            },
            #[cfg(feature = "python")]
            DslFunction::OpaquePython(inner) => FunctionIR::OpaquePython(inner),
            DslFunction::Stats(_)
            | DslFunction::FillNan(_)
            | DslFunction::Drop(_)
            | DslFunction::Rename { .. }
            | DslFunction::Explode { .. } => {
                // We should not reach this.
                panic!("impl error")
            },
        };
        Ok(function)
    }
}
148
149impl Debug for DslFunction {
150    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
151        write!(f, "{self}")
152    }
153}
154
155impl Display for DslFunction {
156    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
157        use DslFunction::*;
158        match self {
159            FunctionIR(inner) => write!(f, "{inner}"),
160            v => {
161                let s: &str = v.into();
162                write!(f, "{s}")
163            },
164        }
165    }
166}
167
168impl From<FunctionIR> for DslFunction {
169    fn from(value: FunctionIR) -> Self {
170        DslFunction::FunctionIR(value)
171    }
172}