polars_plan/dsl/function_expr/
fused.rs1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3
4use super::*;
/// Arithmetic operators that combine three input columns in a single fused step.
///
/// Each variant selects the matching column kernel (`fma_columns`,
/// `fsm_columns` or `fms_columns`) when dispatched by `fused`. The exact
/// operand order of each operation is defined by those kernels.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// `Eq` accompanies `Hash` so the operator can be used as a key in hashed
// collections and satisfies the `Hash`/`Eq` consistency contract.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum FusedOperator {
    /// Fused multiply-add (`fma`), evaluated by `fma_columns`.
    MultiplyAdd,
    /// Fused subtract-multiply (`fsm`), evaluated by `fsm_columns`.
    SubMultiply,
    /// Fused multiply-subtract (`fms`), evaluated by `fms_columns`.
    MultiplySub,
}
13
14impl Display for FusedOperator {
15 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
16 let s = match self {
17 FusedOperator::MultiplyAdd => "fma",
18 FusedOperator::SubMultiply => "fsm",
19 FusedOperator::MultiplySub => "fms",
20 };
21 write!(f, "{s}")
22 }
23}
24
25pub(super) fn fused(input: &[Column], op: FusedOperator) -> PolarsResult<Column> {
26 let s0 = &input[0];
27 let s1 = &input[1];
28 let s2 = &input[2];
29 match op {
30 FusedOperator::MultiplyAdd => Ok(fma_columns(s0, s1, s2)),
31 FusedOperator::SubMultiply => Ok(fsm_columns(s0, s1, s2)),
32 FusedOperator::MultiplySub => Ok(fms_columns(s0, s1, s2)),
33 }
34}