polars_plan/dsl/function_expr/
fused.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3
4use super::*;
5
/// Selects which fused three-operand arithmetic kernel to run.
///
/// Each variant has a short mnemonic produced by the `Display` impl and is
/// dispatched to a matching `*_columns` kernel by `fused`.
///
/// `Eq` is derived together with `PartialEq` and `Hash`: the enum is fieldless,
/// so equality is total, and having `Eq` lets the operator be used as a key in
/// hashed collections and inside other types that derive `Eq`.
///
/// NOTE(review): the exact operand order of each operation is defined by the
/// `fma_columns` / `fsm_columns` / `fms_columns` kernels, which are not visible
/// here — confirm against those before relying on it.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum FusedOperator {
    /// Fused multiply-add; displayed as `fma`.
    MultiplyAdd,
    /// Fused subtract-multiply; displayed as `fsm`.
    SubMultiply,
    /// Fused multiply-subtract; displayed as `fms`.
    MultiplySub,
}
13
14impl Display for FusedOperator {
15    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
16        let s = match self {
17            FusedOperator::MultiplyAdd => "fma",
18            FusedOperator::SubMultiply => "fsm",
19            FusedOperator::MultiplySub => "fms",
20        };
21        write!(f, "{s}")
22    }
23}
24
25pub(super) fn fused(input: &[Column], op: FusedOperator) -> PolarsResult<Column> {
26    let s0 = &input[0];
27    let s1 = &input[1];
28    let s2 = &input[2];
29    match op {
30        FusedOperator::MultiplyAdd => Ok(fma_columns(s0, s1, s2)),
31        FusedOperator::SubMultiply => Ok(fsm_columns(s0, s1, s2)),
32        FusedOperator::MultiplySub => Ok(fms_columns(s0, s1, s2)),
33    }
34}